conn.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. package ldap
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "net"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. "gopkg.in/asn1-ber.v1"
  12. )
  13. const (
  14. // MessageQuit causes the processMessages loop to exit
  15. MessageQuit = 0
  16. // MessageRequest sends a request to the server
  17. MessageRequest = 1
  18. // MessageResponse receives a response from the server
  19. MessageResponse = 2
  20. // MessageFinish indicates the client considers a particular message ID to be finished
  21. MessageFinish = 3
  22. // MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
  23. MessageTimeout = 4
  24. )
// PacketResponse contains the packet or error encountered reading a response.
// Values are delivered to request handlers on messageContext.responses.
// Note: other code builds this struct with positional literals
// ({Packet, Error}), so the field order is significant.
type PacketResponse struct {
	// Packet is the packet read from the server
	Packet *ber.Packet
	// Error is an error encountered while reading
	Error error
}
  32. // ReadPacket returns the packet or an error
  33. func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
  34. if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
  35. return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
  36. }
  37. return pr.Packet, pr.Error
  38. }
// messageContext tracks one in-flight request: its message ID and the
// channels used to hand responses from processMessages to the goroutine
// waiting on the request.
type messageContext struct {
	// id is the LDAP message ID of the request.
	id int64
	// close(done) should only be called from finishMessage()
	// Closing it tells sendResponse that nobody will read further responses.
	done chan struct{}
	// close(responses) should only be called from processMessages(), and only sent to from sendResponse()
	responses chan *PacketResponse
}
  46. // sendResponse should only be called within the processMessages() loop which
  47. // is also responsible for closing the responses channel.
  48. func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
  49. select {
  50. case msgCtx.responses <- packet:
  51. // Successfully sent packet to message handler.
  52. case <-msgCtx.done:
  53. // The request handler is done and will not receive more
  54. // packets.
  55. }
  56. }
// messagePacket is a unit of work queued on Conn.chanMessage for the
// processMessages loop.
type messagePacket struct {
	// Op is one of the Message* constants.
	Op int
	// MessageID identifies the request this work item refers to
	// (unused for MessageQuit).
	MessageID int64
	// Packet is the request to send (MessageRequest) or the response that was
	// read (MessageResponse); nil for the other ops.
	Packet *ber.Packet
	// Context is set only for MessageRequest and is registered in
	// messageContexts once the request has been written.
	Context *messageContext
}
// sendMessageFlags is a bit set of options for sendMessageWithFlags.
type sendMessageFlags uint

const (
	// startTLS marks a StartTLS request. Such a request is refused while other
	// requests are outstanding, and while it is in flight new requests are
	// refused until finishMessage is called for it.
	startTLS sendMessageFlags = 1 << iota
)
// Conn represents an LDAP Connection
type Conn struct {
	// conn is the network connection; StartTLS replaces it with a *tls.Conn.
	conn net.Conn
	// isTLS reports whether conn is encrypted (set by NewConn or StartTLS).
	isTLS bool
	// closing is 0 or 1 and is only accessed atomically (isClosing/setClosing).
	closing uint32
	// closeErr holds the read error that stopped the reader, if any;
	// processMessages forwards it to waiting requests on shutdown.
	closeErr atomicValue
	// isStartingTLS is guarded by messageMutex.
	isStartingTLS bool
	Debug         debugging
	// chanConfirm is closed by processMessages when it exits; Close waits on it.
	chanConfirm chan struct{}
	// messageContexts is only accessed from the processMessages goroutine.
	messageContexts map[int64]*messageContext
	// chanMessage queues work for processMessages.
	chanMessage chan *messagePacket
	// chanMessageID yields successive message IDs from processMessages.
	chanMessageID chan int64
	// wgClose lets every Close call wait for the shutdown to finish.
	wgClose sync.WaitGroup
	// outstandingRequests is guarded by messageMutex.
	outstandingRequests uint
	messageMutex        sync.Mutex
	// requestTimeout is a time.Duration in nanoseconds, accessed atomically;
	// zero means requests never time out.
	requestTimeout int64
}
  84. var _ Client = &Conn{}
  85. // DefaultTimeout is a package-level variable that sets the timeout value
  86. // used for the Dial and DialTLS methods.
  87. //
  88. // WARNING: since this is a package-level variable, setting this value from
  89. // multiple places will probably result in undesired behaviour.
  90. var DefaultTimeout = 60 * time.Second
  91. // Dial connects to the given address on the given network using net.Dial
  92. // and then returns a new Conn for the connection.
  93. func Dial(network, addr string) (*Conn, error) {
  94. c, err := net.DialTimeout(network, addr, DefaultTimeout)
  95. if err != nil {
  96. return nil, NewError(ErrorNetwork, err)
  97. }
  98. conn := NewConn(c, false)
  99. conn.Start()
  100. return conn, nil
  101. }
  102. // DialTLS connects to the given address on the given network using tls.Dial
  103. // and then returns a new Conn for the connection.
  104. func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
  105. dc, err := net.DialTimeout(network, addr, DefaultTimeout)
  106. if err != nil {
  107. return nil, NewError(ErrorNetwork, err)
  108. }
  109. c := tls.Client(dc, config)
  110. err = c.Handshake()
  111. if err != nil {
  112. // Handshake error, close the established connection before we return an error
  113. dc.Close()
  114. return nil, NewError(ErrorNetwork, err)
  115. }
  116. conn := NewConn(c, true)
  117. conn.Start()
  118. return conn, nil
  119. }
  120. // NewConn returns a new Conn using conn for network I/O.
  121. func NewConn(conn net.Conn, isTLS bool) *Conn {
  122. return &Conn{
  123. conn: conn,
  124. chanConfirm: make(chan struct{}),
  125. chanMessageID: make(chan int64),
  126. chanMessage: make(chan *messagePacket, 10),
  127. messageContexts: map[int64]*messageContext{},
  128. requestTimeout: 0,
  129. isTLS: isTLS,
  130. }
  131. }
  132. // Start initializes goroutines to read responses and process messages
  133. func (l *Conn) Start() {
  134. go l.reader()
  135. go l.processMessages()
  136. l.wgClose.Add(1)
  137. }
  138. // isClosing returns whether or not we're currently closing.
  139. func (l *Conn) isClosing() bool {
  140. return atomic.LoadUint32(&l.closing) == 1
  141. }
  142. // setClosing sets the closing value to true
  143. func (l *Conn) setClosing() bool {
  144. return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
  145. }
// Close closes the connection.
//
// The first call (the one for which setClosing succeeds) stops
// processMessages by queueing a MessageQuit and waiting for chanConfirm,
// closes chanMessage, closes the network connection and releases wgClose.
// Every call, including repeated or concurrent ones, then blocks in
// wgClose.Wait until that shutdown has finished.
//
// messageMutex is held throughout. Because sendProcessMessage checks
// isClosing under the same mutex, nothing can send on chanMessage after it
// is closed here.
//
// NOTE(review): if Start was never called, wgClose.Done would drive the
// WaitGroup counter negative and panic.
func (l *Conn) Close() {
	l.messageMutex.Lock()
	defer l.messageMutex.Unlock()
	if l.setClosing() {
		l.Debug.Printf("Sending quit message and waiting for confirmation")
		l.chanMessage <- &messagePacket{Op: MessageQuit}
		// processMessages closes chanConfirm when it exits.
		<-l.chanConfirm
		close(l.chanMessage)
		l.Debug.Printf("Closing network connection")
		if err := l.conn.Close(); err != nil {
			log.Println(err)
		}
		// Balances the wgClose.Add(1) done in Start.
		l.wgClose.Done()
	}
	l.wgClose.Wait()
}
  163. // SetTimeout sets the time after a request is sent that a MessageTimeout triggers
  164. func (l *Conn) SetTimeout(timeout time.Duration) {
  165. if timeout > 0 {
  166. atomic.StoreInt64(&l.requestTimeout, int64(timeout))
  167. }
  168. }
  169. // Returns the next available messageID
  170. func (l *Conn) nextMessageID() int64 {
  171. if messageID, ok := <-l.chanMessageID; ok {
  172. return messageID
  173. }
  174. return 0
  175. }
  176. // StartTLS sends the command to start a TLS session and then creates a new TLS Client
  177. func (l *Conn) StartTLS(config *tls.Config) error {
  178. if l.isTLS {
  179. return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
  180. }
  181. packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
  182. packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
  183. request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
  184. request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
  185. packet.AppendChild(request)
  186. l.Debug.PrintPacket(packet)
  187. msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
  188. if err != nil {
  189. return err
  190. }
  191. defer l.finishMessage(msgCtx)
  192. l.Debug.Printf("%d: waiting for response", msgCtx.id)
  193. packetResponse, ok := <-msgCtx.responses
  194. if !ok {
  195. return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
  196. }
  197. packet, err = packetResponse.ReadPacket()
  198. l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
  199. if err != nil {
  200. return err
  201. }
  202. if l.Debug {
  203. if err := addLDAPDescriptions(packet); err != nil {
  204. l.Close()
  205. return err
  206. }
  207. ber.PrintPacket(packet)
  208. }
  209. if resultCode, message := getLDAPResultCode(packet); resultCode == LDAPResultSuccess {
  210. conn := tls.Client(l.conn, config)
  211. if err := conn.Handshake(); err != nil {
  212. l.Close()
  213. return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", err))
  214. }
  215. l.isTLS = true
  216. l.conn = conn
  217. } else {
  218. return NewError(resultCode, fmt.Errorf("ldap: cannot StartTLS (%s)", message))
  219. }
  220. go l.reader()
  221. return nil
  222. }
  223. func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
  224. return l.sendMessageWithFlags(packet, 0)
  225. }
  226. func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
  227. if l.isClosing() {
  228. return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
  229. }
  230. l.messageMutex.Lock()
  231. l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
  232. if l.isStartingTLS {
  233. l.messageMutex.Unlock()
  234. return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
  235. }
  236. if flags&startTLS != 0 {
  237. if l.outstandingRequests != 0 {
  238. l.messageMutex.Unlock()
  239. return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
  240. }
  241. l.isStartingTLS = true
  242. }
  243. l.outstandingRequests++
  244. l.messageMutex.Unlock()
  245. responses := make(chan *PacketResponse)
  246. messageID := packet.Children[0].Value.(int64)
  247. message := &messagePacket{
  248. Op: MessageRequest,
  249. MessageID: messageID,
  250. Packet: packet,
  251. Context: &messageContext{
  252. id: messageID,
  253. done: make(chan struct{}),
  254. responses: responses,
  255. },
  256. }
  257. l.sendProcessMessage(message)
  258. return message.Context, nil
  259. }
  260. func (l *Conn) finishMessage(msgCtx *messageContext) {
  261. close(msgCtx.done)
  262. if l.isClosing() {
  263. return
  264. }
  265. l.messageMutex.Lock()
  266. l.outstandingRequests--
  267. if l.isStartingTLS {
  268. l.isStartingTLS = false
  269. }
  270. l.messageMutex.Unlock()
  271. message := &messagePacket{
  272. Op: MessageFinish,
  273. MessageID: msgCtx.id,
  274. }
  275. l.sendProcessMessage(message)
  276. }
  277. func (l *Conn) sendProcessMessage(message *messagePacket) bool {
  278. l.messageMutex.Lock()
  279. defer l.messageMutex.Unlock()
  280. if l.isClosing() {
  281. return false
  282. }
  283. l.chanMessage <- message
  284. return true
  285. }
// processMessages is the single goroutine that owns messageContexts. It hands
// out increasing message IDs on chanMessageID and handles every messagePacket
// queued on chanMessage:
//   - MessageRequest: write the packet to l.conn; on success register its
//     context and, if a request timeout is set, schedule a MessageTimeout;
//   - MessageResponse: forward the packet to the matching context;
//   - MessageTimeout: deliver a timeout error and drop the context;
//   - MessageFinish: drop the context once its caller is done with it;
//   - MessageQuit: exit.
//
// On exit, including after a recovered panic, it closes the responses
// channel of every remaining context. If the connection is closing because
// of an error, it first delivers closeErr to each one. It then closes
// chanMessageID and chanConfirm, which unblocks nextMessageID and Close.
func (l *Conn) processMessages() {
	defer func() {
		if err := recover(); err != nil {
			log.Printf("ldap: recovered panic in processMessages: %v", err)
		}
		for messageID, msgCtx := range l.messageContexts {
			// If we are closing due to an error, inform anyone who
			// is waiting about the error.
			if l.isClosing() && l.closeErr.Load() != nil {
				msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
			}
			l.Debug.Printf("Closing channel for MessageID %d", messageID)
			close(msgCtx.responses)
			delete(l.messageContexts, messageID)
		}
		close(l.chanMessageID)
		close(l.chanConfirm)
	}()
	// NOTE(review): IDs start at 1; presumably because LDAP uses messageID 0
	// for unsolicited notifications (RFC 4511 §4.4) — confirm.
	var messageID int64 = 1
	for {
		select {
		case l.chanMessageID <- messageID:
			messageID++
		case message := <-l.chanMessage:
			switch message.Op {
			case MessageQuit:
				l.Debug.Printf("Shutting down - quit message received")
				return
			case MessageRequest:
				// Add to message list and write to network
				l.Debug.Printf("Sending message %d", message.MessageID)
				buf := message.Packet.Bytes()
				_, err := l.conn.Write(buf)
				if err != nil {
					l.Debug.Printf("Error Sending Message: %s", err.Error())
					message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
					close(message.Context.responses)
					// Leaves the switch only; the loop keeps serving other messages.
					break
				}
				// Only add to messageContexts if we were able to
				// successfully write the message.
				l.messageContexts[message.MessageID] = message.Context
				// Add timeout if defined
				requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
				if requestTimeout > 0 {
					// The timer goroutine sleeps for the full timeout. Its
					// MessageTimeout is ignored if the connection is closing
					// (sendProcessMessage returns false) or if the message has
					// already been removed from messageContexts.
					go func() {
						defer func() {
							if err := recover(); err != nil {
								log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
							}
						}()
						time.Sleep(requestTimeout)
						timeoutMessage := &messagePacket{
							Op:        MessageTimeout,
							MessageID: message.MessageID,
						}
						l.sendProcessMessage(timeoutMessage)
					}()
				}
			case MessageResponse:
				l.Debug.Printf("Receiving message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
				} else {
					log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing())
					ber.PrintPacket(message.Packet)
				}
			case MessageTimeout:
				// Handle the timeout by closing the channel
				// All reads will return immediately
				// (message.Packet is nil here: timeout messages carry no packet.)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
					msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			case MessageFinish:
				l.Debug.Printf("Finished message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			}
		}
	}
}
  372. func (l *Conn) reader() {
  373. cleanstop := false
  374. defer func() {
  375. if err := recover(); err != nil {
  376. log.Printf("ldap: recovered panic in reader: %v", err)
  377. }
  378. if !cleanstop {
  379. l.Close()
  380. }
  381. }()
  382. for {
  383. if cleanstop {
  384. l.Debug.Printf("reader clean stopping (without closing the connection)")
  385. return
  386. }
  387. packet, err := ber.ReadPacket(l.conn)
  388. if err != nil {
  389. // A read error is expected here if we are closing the connection...
  390. if !l.isClosing() {
  391. l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
  392. l.Debug.Printf("reader error: %s", err.Error())
  393. }
  394. return
  395. }
  396. addLDAPDescriptions(packet)
  397. if len(packet.Children) == 0 {
  398. l.Debug.Printf("Received bad ldap packet")
  399. continue
  400. }
  401. l.messageMutex.Lock()
  402. if l.isStartingTLS {
  403. cleanstop = true
  404. }
  405. l.messageMutex.Unlock()
  406. message := &messagePacket{
  407. Op: MessageResponse,
  408. MessageID: packet.Children[0].Value.(int64),
  409. Packet: packet,
  410. }
  411. if !l.sendProcessMessage(message) {
  412. return
  413. }
  414. }
  415. }