conn_test.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. package ldap
  2. import (
  3. "bytes"
  4. "errors"
  5. "io"
  6. "net"
  7. "net/http"
  8. "net/http/httptest"
  9. "runtime"
  10. "sync"
  11. "testing"
  12. "time"
  13. "gopkg.in/asn1-ber.v1"
  14. )
  15. func TestUnresponsiveConnection(t *testing.T) {
  16. // The do-nothing server that accepts requests and does nothing
  17. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  18. }))
  19. defer ts.Close()
  20. c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
  21. if err != nil {
  22. t.Fatalf("error connecting to localhost tcp: %v", err)
  23. }
  24. // Create an Ldap connection
  25. conn := NewConn(c, false)
  26. conn.SetTimeout(time.Millisecond)
  27. conn.Start()
  28. defer conn.Close()
  29. // Mock a packet
  30. packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
  31. packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, conn.nextMessageID(), "MessageID"))
  32. bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
  33. bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
  34. packet.AppendChild(bindRequest)
  35. // Send packet and test response
  36. msgCtx, err := conn.sendMessage(packet)
  37. if err != nil {
  38. t.Fatalf("error sending message: %v", err)
  39. }
  40. defer conn.finishMessage(msgCtx)
  41. packetResponse, ok := <-msgCtx.responses
  42. if !ok {
  43. t.Fatalf("no PacketResponse in response channel")
  44. }
  45. packet, err = packetResponse.ReadPacket()
  46. if err == nil {
  47. t.Fatalf("expected timeout error")
  48. }
  49. if err.Error() != "ldap: connection timed out" {
  50. t.Fatalf("unexpected error: %v", err)
  51. }
  52. }
  53. // TestFinishMessage tests that we do not enter deadlock when a goroutine makes
  54. // a request but does not handle all responses from the server.
  55. func TestFinishMessage(t *testing.T) {
  56. ptc := newPacketTranslatorConn()
  57. defer ptc.Close()
  58. conn := NewConn(ptc, false)
  59. conn.Start()
  60. // Test sending 5 different requests in series. Ensure that we can
  61. // get a response packet from the underlying connection and also
  62. // ensure that we can gracefully ignore unhandled responses.
  63. for i := 0; i < 5; i++ {
  64. t.Logf("serial request %d", i)
  65. // Create a message and make sure we can receive responses.
  66. msgCtx := testSendRequest(t, ptc, conn)
  67. testReceiveResponse(t, ptc, msgCtx)
  68. // Send a few unhandled responses and finish the message.
  69. testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
  70. t.Logf("serial request %d done", i)
  71. }
  72. // Test sending 5 different requests in parallel.
  73. var wg sync.WaitGroup
  74. for i := 0; i < 5; i++ {
  75. wg.Add(1)
  76. go func(i int) {
  77. defer wg.Done()
  78. t.Logf("parallel request %d", i)
  79. // Create a message and make sure we can receive responses.
  80. msgCtx := testSendRequest(t, ptc, conn)
  81. testReceiveResponse(t, ptc, msgCtx)
  82. // Send a few unhandled responses and finish the message.
  83. testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
  84. t.Logf("parallel request %d done", i)
  85. }(i)
  86. }
  87. wg.Wait()
  88. // We cannot run Close() in a defer because t.FailNow() will run it and
  89. // it will block if the processMessage Loop is in a deadlock.
  90. conn.Close()
  91. }
  92. func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) {
  93. var msgID int64
  94. runWithTimeout(t, time.Second, func() {
  95. msgID = conn.nextMessageID()
  96. })
  97. requestPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
  98. requestPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgID, "MessageID"))
  99. var err error
  100. runWithTimeout(t, time.Second, func() {
  101. msgCtx, err = conn.sendMessage(requestPacket)
  102. if err != nil {
  103. t.Fatalf("unable to send request message: %s", err)
  104. }
  105. })
  106. // We should now be able to get this request packet out from the other
  107. // side.
  108. runWithTimeout(t, time.Second, func() {
  109. if _, err = ptc.ReceiveRequest(); err != nil {
  110. t.Fatalf("unable to receive request packet: %s", err)
  111. }
  112. })
  113. return msgCtx
  114. }
  115. func testReceiveResponse(t *testing.T, ptc *packetTranslatorConn, msgCtx *messageContext) {
  116. // Send a mock response packet.
  117. responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
  118. responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
  119. runWithTimeout(t, time.Second, func() {
  120. if err := ptc.SendResponse(responsePacket); err != nil {
  121. t.Fatalf("unable to send response packet: %s", err)
  122. }
  123. })
  124. // We should be able to receive the packet from the connection.
  125. runWithTimeout(t, time.Second, func() {
  126. if _, ok := <-msgCtx.responses; !ok {
  127. t.Fatal("response channel closed")
  128. }
  129. })
  130. }
  131. func testSendUnhandledResponsesAndFinish(t *testing.T, ptc *packetTranslatorConn, conn *Conn, msgCtx *messageContext, numResponses int) {
  132. // Send a mock response packet.
  133. responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
  134. responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
  135. // Send extra responses but do not attempt to receive them on the
  136. // client side.
  137. for i := 0; i < numResponses; i++ {
  138. runWithTimeout(t, time.Second, func() {
  139. if err := ptc.SendResponse(responsePacket); err != nil {
  140. t.Fatalf("unable to send response packet: %s", err)
  141. }
  142. })
  143. }
  144. // Finally, attempt to finish this message.
  145. runWithTimeout(t, time.Second, func() {
  146. conn.finishMessage(msgCtx)
  147. })
  148. }
  // runWithTimeout runs f on a separate goroutine and fails the test if f
  // has not returned within timeout. The timeout message names the file and
  // line of runWithTimeout's caller so the stalled step can be identified.
  //
  // f should not call t.Fatal or t.FailNow itself, because those must run on
  // the goroutine executing the test. If f does exit early via
  // runtime.Goexit (which t.FailNow uses), the early exit is propagated
  // immediately with t.FailNow. It is not misreported as a timeout once the
  // deadline expires.
  func runWithTimeout(t *testing.T, timeout time.Duration, f func()) {
  	// Buffered so the worker never blocks on its send, even after we have
  	// stopped waiting for it because of a timeout.
  	done := make(chan bool, 1)
  	go func() {
  		returned := false
  		// Deferred so a signal is sent even when f calls runtime.Goexit.
  		// returned distinguishes a normal return from an early exit.
  		defer func() { done <- returned }()
  		f()
  		returned = true
  	}()

  	select {
  	case returned := <-done:
  		if !returned {
  			// f exited early (e.g. t.FailNow); presumably it already
  			// recorded why, so just stop the caller here.
  			t.FailNow()
  		}
  	case <-time.After(timeout):
  		_, file, line, _ := runtime.Caller(1)
  		t.Fatalf("%s:%d timed out", file, line)
  	}
  }
  // packetTranslatorConn is an in-memory stand-in for a network connection,
  // meant to be passed to NewConn as the transport of an *ldap.Conn in tests.
  // Its address and deadline methods are stubs. The useful behavior lives in
  // four methods that move raw bytes through two buffers:
  //
  //   - SendResponse (test acting as server) -> responseBuf -> Read (client)
  //   - Write (client) -> requestBuf -> ReceiveRequest (test acting as server)
  //
  // SendResponse accepts a ber-encoded response packet, and ReceiveRequest
  // decodes one ber-encoded request packet per call. Read and ReceiveRequest
  // block until data is available or Close is called.
  type packetTranslatorConn struct {
  	// lock guards every field below; newPacketTranslatorConn wires both
  	// condition variables to it.
  	lock     sync.Mutex
  	isClosed bool // set by Close

  	responseCond sync.Cond    // broadcast when responseBuf gains data or on Close
  	requestCond  sync.Cond    // broadcast when requestBuf gains data or on Close
  	responseBuf  bytes.Buffer // filled by SendResponse, drained by Read
  	requestBuf   bytes.Buffer // filled by Write, consumed by ReceiveRequest
  }
  var (
  	// errPacketTranslatorConnClosed is what packetTranslatorConn's Read,
  	// Write, SendResponse and ReceiveRequest return once Close has been
  	// called.
  	errPacketTranslatorConnClosed = errors.New("connection closed")
  )
  181. func newPacketTranslatorConn() *packetTranslatorConn {
  182. conn := &packetTranslatorConn{}
  183. conn.responseCond = sync.Cond{L: &conn.lock}
  184. conn.requestCond = sync.Cond{L: &conn.lock}
  185. return conn
  186. }
  187. // Read is called by the reader() loop to receive response packets. It will
  188. // block until there are more packet bytes available or this connection is
  189. // closed.
  190. func (c *packetTranslatorConn) Read(b []byte) (n int, err error) {
  191. c.lock.Lock()
  192. defer c.lock.Unlock()
  193. for !c.isClosed {
  194. // Attempt to read data from the response buffer. If it fails
  195. // with an EOF, wait and try again.
  196. n, err = c.responseBuf.Read(b)
  197. if err != io.EOF {
  198. return n, err
  199. }
  200. c.responseCond.Wait()
  201. }
  202. return 0, errPacketTranslatorConnClosed
  203. }
  204. // SendResponse writes the given response packet to the response buffer for
  205. // this connection, signalling any goroutine waiting to read a response.
  206. func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error {
  207. c.lock.Lock()
  208. defer c.lock.Unlock()
  209. if c.isClosed {
  210. return errPacketTranslatorConnClosed
  211. }
  212. // Signal any goroutine waiting to read a response.
  213. defer c.responseCond.Broadcast()
  214. // Writes to the buffer should always succeed.
  215. c.responseBuf.Write(packet.Bytes())
  216. return nil
  217. }
  218. // Write is called by the processMessages() loop to send request packets.
  219. func (c *packetTranslatorConn) Write(b []byte) (n int, err error) {
  220. c.lock.Lock()
  221. defer c.lock.Unlock()
  222. if c.isClosed {
  223. return 0, errPacketTranslatorConnClosed
  224. }
  225. // Signal any goroutine waiting to read a request.
  226. defer c.requestCond.Broadcast()
  227. // Writes to the buffer should always succeed.
  228. return c.requestBuf.Write(b)
  229. }
  230. // ReceiveRequest attempts to read a request packet from this connection. It
  231. // will block until it is able to read a full request packet or until this
  232. // connection is closed.
  233. func (c *packetTranslatorConn) ReceiveRequest() (*ber.Packet, error) {
  234. c.lock.Lock()
  235. defer c.lock.Unlock()
  236. for !c.isClosed {
  237. // Attempt to parse a request packet from the request buffer.
  238. // If it fails with an unexpected EOF, wait and try again.
  239. requestReader := bytes.NewReader(c.requestBuf.Bytes())
  240. packet, err := ber.ReadPacket(requestReader)
  241. switch err {
  242. case io.EOF, io.ErrUnexpectedEOF:
  243. c.requestCond.Wait()
  244. case nil:
  245. // Advance the request buffer by the number of bytes
  246. // read to decode the request packet.
  247. c.requestBuf.Next(c.requestBuf.Len() - requestReader.Len())
  248. return packet, nil
  249. default:
  250. return nil, err
  251. }
  252. }
  253. return nil, errPacketTranslatorConnClosed
  254. }
  255. // Close closes this connection causing Read() and Write() calls to fail.
  256. func (c *packetTranslatorConn) Close() error {
  257. c.lock.Lock()
  258. defer c.lock.Unlock()
  259. c.isClosed = true
  260. c.responseCond.Broadcast()
  261. c.requestCond.Broadcast()
  262. return nil
  263. }
  264. func (c *packetTranslatorConn) LocalAddr() net.Addr {
  265. return (*net.TCPAddr)(nil)
  266. }
  267. func (c *packetTranslatorConn) RemoteAddr() net.Addr {
  268. return (*net.TCPAddr)(nil)
  269. }
  270. func (c *packetTranslatorConn) SetDeadline(t time.Time) error {
  271. return nil
  272. }
  273. func (c *packetTranslatorConn) SetReadDeadline(t time.Time) error {
  274. return nil
  275. }
  276. func (c *packetTranslatorConn) SetWriteDeadline(t time.Time) error {
  277. return nil
  278. }