diff options
Diffstat (limited to 'vendor/github.com/go-ldap/ldap/conn.go')
-rw-r--r-- | vendor/github.com/go-ldap/ldap/conn.go | 147 |
1 files changed, 95 insertions, 52 deletions
diff --git a/vendor/github.com/go-ldap/ldap/conn.go b/vendor/github.com/go-ldap/ldap/conn.go index 6aad628be..b5bd99adb 100644 --- a/vendor/github.com/go-ldap/ldap/conn.go +++ b/vendor/github.com/go-ldap/ldap/conn.go @@ -17,18 +17,27 @@ import ( ) const ( - MessageQuit = 0 - MessageRequest = 1 + // MessageQuit causes the processMessages loop to exit + MessageQuit = 0 + // MessageRequest sends a request to the server + MessageRequest = 1 + // MessageResponse receives a response from the server MessageResponse = 2 - MessageFinish = 3 - MessageTimeout = 4 + // MessageFinish indicates the client considers a particular message ID to be finished + MessageFinish = 3 + // MessageTimeout indicates the client-specified timeout for a particular message ID has been reached + MessageTimeout = 4 ) +// PacketResponse contains the packet or error encountered reading a response type PacketResponse struct { + // Packet is the packet read from the server Packet *ber.Packet - Error error + // Error is an error encountered while reading + Error error } +// ReadPacket returns the packet or an error func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) { if (pr == nil) || (pr.Packet == nil && pr.Error == nil) { return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response")) @@ -36,11 +45,31 @@ func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) { return pr.Packet, pr.Error } +type messageContext struct { + id int64 + // close(done) should only be called from finishMessage() + done chan struct{} + // close(responses) should only be called from processMessages(), and only sent to from sendResponse() + responses chan *PacketResponse +} + +// sendResponse should only be called within the processMessages() loop which +// is also responsible for closing the responses channel. +func (msgCtx *messageContext) sendResponse(packet *PacketResponse) { + select { + case msgCtx.responses <- packet: + // Successfully sent packet to message handler. + case <-msgCtx.done: + // The request handler is done and will not receive more + // packets. + } +} + type messagePacket struct { Op int MessageID int64 Packet *ber.Packet - Channel chan *PacketResponse + Context *messageContext } type sendMessageFlags uint @@ -54,10 +83,11 @@ type Conn struct { conn net.Conn isTLS bool isClosing bool + closeErr error isStartingTLS bool Debug debugging chanConfirm chan bool - chanResults map[int64]chan *PacketResponse + messageContexts map[int64]*messageContext chanMessage chan *messagePacket chanMessageID chan int64 wgSender sync.WaitGroup @@ -111,16 +141,17 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { // NewConn returns a new Conn using conn for network I/O. func NewConn(conn net.Conn, isTLS bool) *Conn { return &Conn{ - conn: conn, - chanConfirm: make(chan bool), - chanMessageID: make(chan int64), - chanMessage: make(chan *messagePacket, 10), - chanResults: map[int64]chan *PacketResponse{}, - requestTimeout: 0, - isTLS: isTLS, + conn: conn, + chanConfirm: make(chan bool), + chanMessageID: make(chan int64), + chanMessage: make(chan *messagePacket, 10), + messageContexts: map[int64]*messageContext{}, + requestTimeout: 0, + isTLS: isTLS, } } +// Start initializes goroutines to read responses and process messages func (l *Conn) Start() { go l.reader() go l.processMessages() @@ -148,7 +179,7 @@ func (l *Conn) Close() { l.wgClose.Wait() } -// Sets the time after a request is sent that a MessageTimeout triggers +// SetTimeout sets the time after a request is sent that a MessageTimeout triggers func (l *Conn) SetTimeout(timeout time.Duration) { if timeout > 0 { l.requestTimeout = timeout @@ -167,35 +198,31 @@ func (l *Conn) nextMessageID() int64 { // StartTLS sends the command to start a TLS session and then creates a new TLS Client func (l *Conn) StartTLS(config *tls.Config) error { - messageID := l.nextMessageID() - if l.isTLS { return NewError(ErrorNetwork, errors.New("ldap: already encrypted")) } packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) packet.AppendChild(request) l.Debug.PrintPacket(packet) - channel, err := l.sendMessageWithFlags(packet, startTLS) + msgCtx, err := l.sendMessageWithFlags(packet, startTLS) if err != nil { return err } - if channel == nil { - return NewError(ErrorNetwork, errors.New("ldap: could not send message")) - } + defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", messageID) - defer l.finishMessage(messageID) - packetResponse, ok := <-channel + l.Debug.Printf("%d: waiting for response", msgCtx.id) + + packetResponse, ok := <-msgCtx.responses if !ok { - return NewError(ErrorNetwork, errors.New("ldap: channel closed")) + return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) } packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", messageID, packet) + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) if err != nil { return err } @@ -226,11 +253,11 @@ func (l *Conn) StartTLS(config *tls.Config) error { return nil } -func (l *Conn) sendMessage(packet *ber.Packet) (chan *PacketResponse, error) { +func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) { return l.sendMessageWithFlags(packet, 0) } -func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (chan *PacketResponse, error) { +func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { if l.isClosing { return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) } @@ -238,32 +265,38 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) l.Debug.Printf("flags&startTLS = %d", flags&startTLS) if l.isStartingTLS { l.messageMutex.Unlock() - return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase.")) + return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase")) } if flags&startTLS != 0 { if l.outstandingRequests != 0 { l.messageMutex.Unlock() return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests")) - } else { - l.isStartingTLS = true } + l.isStartingTLS = true } l.outstandingRequests++ l.messageMutex.Unlock() - out := make(chan *PacketResponse) + responses := make(chan *PacketResponse) + messageID := packet.Children[0].Value.(int64) message := &messagePacket{ Op: MessageRequest, - MessageID: packet.Children[0].Value.(int64), + MessageID: messageID, Packet: packet, - Channel: out, + Context: &messageContext{ + id: messageID, + done: make(chan struct{}), + responses: responses, + }, } l.sendProcessMessage(message) - return out, nil + return message.Context, nil } -func (l *Conn) finishMessage(messageID int64) { +func (l *Conn) finishMessage(msgCtx *messageContext) { + close(msgCtx.done) + if l.isClosing { return } @@ -277,7 +310,7 @@ func (l *Conn) finishMessage(messageID int64) { message := &messagePacket{ Op: MessageFinish, - MessageID: messageID, + MessageID: msgCtx.id, } l.sendProcessMessage(message) } @@ -297,10 +330,15 @@ func (l *Conn) processMessages() { if err := recover(); err != nil { log.Printf("ldap: recovered panic in processMessages: %v", err) } - for messageID, channel := range l.chanResults { + for messageID, msgCtx := range l.messageContexts { + // If we are closing due to an error, inform anyone who + // is waiting about the error. + if l.isClosing && l.closeErr != nil { + msgCtx.sendResponse(&PacketResponse{Error: l.closeErr}) + } l.Debug.Printf("Closing channel for MessageID %d", messageID) - close(channel) - delete(l.chanResults, messageID) + close(msgCtx.responses) + delete(l.messageContexts, messageID) } close(l.chanMessageID) l.chanConfirm <- true @@ -324,15 +362,20 @@ func (l *Conn) processMessages() { case MessageRequest: // Add to message list and write to network l.Debug.Printf("Sending message %d", message.MessageID) - l.chanResults[message.MessageID] = message.Channel buf := message.Packet.Bytes() _, err := l.conn.Write(buf) if err != nil { l.Debug.Printf("Error Sending Message: %s", err.Error()) + message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}) + close(message.Context.responses) break } + // Only add to messageContexts if we were able to + // successfully write the message. + l.messageContexts[message.MessageID] = message.Context + // Add timeout if defined if l.requestTimeout > 0 { go func() { @@ -351,8 +394,8 @@ func (l *Conn) processMessages() { } case MessageResponse: l.Debug.Printf("Receiving message %d", message.MessageID) - if chanResult, ok := l.chanResults[message.MessageID]; ok { - chanResult <- &PacketResponse{message.Packet, nil} + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { + msgCtx.sendResponse(&PacketResponse{message.Packet, nil}) } else { log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing) ber.PrintPacket(message.Packet) @@ -360,17 +403,17 @@ func (l *Conn) processMessages() { case MessageTimeout: // Handle the timeout by closing the channel // All reads will return immediately - if chanResult, ok := l.chanResults[message.MessageID]; ok { - chanResult <- &PacketResponse{message.Packet, errors.New("ldap: connection timed out")} + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { l.Debug.Printf("Receiving message timeout for %d", message.MessageID) - delete(l.chanResults, message.MessageID) - close(chanResult) + msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")}) + delete(l.messageContexts, message.MessageID) + close(msgCtx.responses) } case MessageFinish: l.Debug.Printf("Finished message %d", message.MessageID) - if chanResult, ok := l.chanResults[message.MessageID]; ok { - close(chanResult) - delete(l.chanResults, message.MessageID) + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { + delete(l.messageContexts, message.MessageID) + close(msgCtx.responses) } } } @@ -397,6 +440,7 @@ func (l *Conn) reader() { if err != nil { // A read error is expected here if we are closing the connection... if !l.isClosing { + l.closeErr = fmt.Errorf("unable to read LDAP response packet: %s", err) l.Debug.Printf("reader error: %s", err.Error()) } return @@ -419,6 +463,5 @@ func (l *Conn) reader() { if !l.sendProcessMessage(message) { return } - } } |