wstest.go 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. package test
  2. // copied from https://github.com/nhooyr/websocket/blob/master/internal/test/wstest/pipe.go
  3. import (
  4. "bufio"
  5. "context"
  6. "net"
  7. "net/http"
  8. "net/http/httptest"
  9. "nhooyr.io/websocket"
  10. )
  11. // Pipe is used to create an in memory connection
  12. // between two websockets analogous to net.Pipe.
  13. func WSPipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (clientConn, serverConn *websocket.Conn) {
  14. tt := fakeTransport{
  15. h: func(w http.ResponseWriter, r *http.Request) {
  16. serverConn, _ = websocket.Accept(w, r, acceptOpts)
  17. },
  18. }
  19. if dialOpts == nil {
  20. dialOpts = &websocket.DialOptions{}
  21. }
  22. dialOpts = &*dialOpts
  23. dialOpts.HTTPClient = &http.Client{
  24. Transport: tt,
  25. }
  26. clientConn, _, _ = websocket.Dial(context.Background(), "ws://example.com", dialOpts)
  27. return clientConn, serverConn
  28. }
  29. type fakeTransport struct {
  30. h http.HandlerFunc
  31. }
  32. func (t fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) {
  33. clientConn, serverConn := net.Pipe()
  34. hj := testHijacker{
  35. ResponseRecorder: httptest.NewRecorder(),
  36. serverConn: serverConn,
  37. }
  38. t.h.ServeHTTP(hj, r)
  39. resp := hj.ResponseRecorder.Result()
  40. if resp.StatusCode == http.StatusSwitchingProtocols {
  41. resp.Body = clientConn
  42. }
  43. return resp, nil
  44. }
  // testHijacker is an http.ResponseWriter that records the response through
  // its embedded httptest.ResponseRecorder and additionally implements
  // http.Hijacker by handing out serverConn.
  type testHijacker struct {
  	*httptest.ResponseRecorder
  	serverConn net.Conn
  }

  // Compile-time check that testHijacker satisfies http.Hijacker.
  var _ http.Hijacker = testHijacker{}

  // Hijack returns serverConn together with a freshly buffered reader/writer
  // wrapping it. It always reports a nil error.
  func (h testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  	conn := h.serverConn
  	reader := bufio.NewReader(conn)
  	writer := bufio.NewWriter(conn)
  	rw := bufio.NewReadWriter(reader, writer)
  	return conn, rw, nil
  }