carrier_test.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. package carrier
  2. import (
  3. "bytes"
  4. "io"
  5. "net"
  6. "net/http"
  7. "net/http/httptest"
  8. "sync"
  9. "testing"
  10. ws "github.com/gorilla/websocket"
  11. "github.com/rs/zerolog"
  12. "github.com/stretchr/testify/assert"
  13. )
  14. const (
  15. // example in Sec-Websocket-Key in rfc6455
  16. testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
  17. )
  18. type testStreamer struct {
  19. buf *bytes.Buffer
  20. l sync.RWMutex
  21. }
  22. func newTestStream() *testStreamer {
  23. return &testStreamer{buf: new(bytes.Buffer)}
  24. }
  25. func (s *testStreamer) Read(p []byte) (int, error) {
  26. s.l.RLock()
  27. defer s.l.RUnlock()
  28. return s.buf.Read(p)
  29. }
  30. func (s *testStreamer) Write(p []byte) (int, error) {
  31. s.l.Lock()
  32. defer s.l.Unlock()
  33. return s.buf.Write(p)
  34. }
  35. func TestStartClient(t *testing.T) {
  36. message := "Good morning Austin! Time for another sunny day in the great state of Texas."
  37. log := zerolog.Nop()
  38. wsConn := NewWSConnection(&log)
  39. ts := newTestWebSocketServer()
  40. defer ts.Close()
  41. buf := newTestStream()
  42. options := &StartOptions{
  43. OriginURL: "http://" + ts.Listener.Addr().String(),
  44. Headers: nil,
  45. }
  46. err := StartClient(wsConn, buf, options)
  47. assert.NoError(t, err)
  48. _, _ = buf.Write([]byte(message))
  49. readBuffer := make([]byte, len(message))
  50. _, _ = buf.Read(readBuffer)
  51. assert.Equal(t, message, string(readBuffer))
  52. }
  53. func TestStartServer(t *testing.T) {
  54. listener, err := net.Listen("tcp", "localhost:")
  55. if err != nil {
  56. t.Fatalf("Error starting listener: %v", err)
  57. }
  58. message := "Good morning Austin! Time for another sunny day in the great state of Texas."
  59. log := zerolog.Nop()
  60. shutdownC := make(chan struct{})
  61. wsConn := NewWSConnection(&log)
  62. ts := newTestWebSocketServer()
  63. defer ts.Close()
  64. options := &StartOptions{
  65. OriginURL: "http://" + ts.Listener.Addr().String(),
  66. Headers: nil,
  67. }
  68. go func() {
  69. err := Serve(wsConn, listener, shutdownC, options)
  70. if err != nil {
  71. t.Fatalf("Error running server: %v", err)
  72. }
  73. }()
  74. conn, err := net.Dial("tcp", listener.Addr().String())
  75. _, _ = conn.Write([]byte(message))
  76. readBuffer := make([]byte, len(message))
  77. _, _ = conn.Read(readBuffer)
  78. assert.Equal(t, string(readBuffer), message)
  79. }
  80. func TestIsAccessResponse(t *testing.T) {
  81. validLocationHeader := http.Header{}
  82. validLocationHeader.Add("location", "https://test.cloudflareaccess.com/cdn-cgi/access/login/blahblah")
  83. invalidLocationHeader := http.Header{}
  84. invalidLocationHeader.Add("location", "https://google.com")
  85. testCases := []struct {
  86. Description string
  87. In *http.Response
  88. ExpectedOut bool
  89. }{
  90. {"nil response", nil, false},
  91. {"redirect with no location", &http.Response{StatusCode: http.StatusFound}, false},
  92. {"200 ok", &http.Response{StatusCode: http.StatusOK}, false},
  93. {"redirect with location", &http.Response{StatusCode: http.StatusFound, Header: validLocationHeader}, true},
  94. {"redirect with invalid location", &http.Response{StatusCode: http.StatusFound, Header: invalidLocationHeader}, false},
  95. }
  96. for i, tc := range testCases {
  97. if IsAccessResponse(tc.In) != tc.ExpectedOut {
  98. t.Fatalf("Failed case %d -- %s", i, tc.Description)
  99. }
  100. }
  101. }
  102. func newTestWebSocketServer() *httptest.Server {
  103. upgrader := ws.Upgrader{
  104. ReadBufferSize: 1024,
  105. WriteBufferSize: 1024,
  106. }
  107. return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  108. conn, _ := upgrader.Upgrade(w, r, nil)
  109. defer conn.Close()
  110. for {
  111. mt, message, err := conn.ReadMessage()
  112. if err != nil {
  113. break
  114. }
  115. if err := conn.WriteMessage(mt, []byte(message)); err != nil {
  116. break
  117. }
  118. }
  119. }))
  120. }
  121. func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
  122. req, err := http.NewRequest("GET", url, stream)
  123. if err != nil {
  124. t.Fatalf("testRequestHeader error")
  125. }
  126. req.Header.Add("Connection", "Upgrade")
  127. req.Header.Add("Upgrade", "WebSocket")
  128. req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
  129. req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
  130. req.Header.Add("Sec-Websocket-Version", "13")
  131. req.Header.Add("User-Agent", "curl/7.59.0")
  132. return req
  133. }