carrier_test.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. package carrier
  2. import (
  3. "bytes"
  4. "io"
  5. "net"
  6. "net/http"
  7. "net/http/httptest"
  8. "sync"
  9. "testing"
  10. ws "github.com/gorilla/websocket"
  11. "github.com/rs/zerolog"
  12. "github.com/stretchr/testify/assert"
  13. )
  // testSecWebsocketKey is the example Sec-WebSocket-Key value from RFC 6455
// (the base64 encoding of the 16-byte string "the sample nonce").
const testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
  // testStreamer is an in-memory io.ReadWriter backed by a bytes.Buffer. Every
// access is serialized by a mutex so one instance can be shared between
// goroutines in tests.
type testStreamer struct {
	buf *bytes.Buffer
	// l guards buf. Read needs the exclusive lock too: bytes.Buffer.Read
	// advances the buffer's read offset, so two concurrent Reads would race.
	l sync.Mutex
}

// newTestStream returns an empty testStreamer ready for use.
func newTestStream() *testStreamer {
	return &testStreamer{buf: new(bytes.Buffer)}
}

// Read consumes up to len(p) bytes from the buffer. As with bytes.Buffer.Read,
// it returns io.EOF when the buffer is empty (unless len(p) is zero).
func (s *testStreamer) Read(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Read(p)
}

// Write appends p to the buffer.
func (s *testStreamer) Write(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Write(p)
}
  35. func TestStartClient(t *testing.T) {
  36. message := "Good morning Austin! Time for another sunny day in the great state of Texas."
  37. log := zerolog.Nop()
  38. wsConn := NewWSConnection(&log)
  39. ts := newTestWebSocketServer()
  40. defer ts.Close()
  41. buf := newTestStream()
  42. options := &StartOptions{
  43. OriginURL: "http://" + ts.Listener.Addr().String(),
  44. Headers: nil,
  45. }
  46. err := StartClient(wsConn, buf, options)
  47. assert.NoError(t, err)
  48. _, _ = buf.Write([]byte(message))
  49. readBuffer := make([]byte, len(message))
  50. _, _ = buf.Read(readBuffer)
  51. assert.Equal(t, message, string(readBuffer))
  52. }
  53. func TestStartServer(t *testing.T) {
  54. listener, err := net.Listen("tcp", "localhost:")
  55. if err != nil {
  56. t.Fatalf("Error starting listener: %v", err)
  57. }
  58. message := "Good morning Austin! Time for another sunny day in the great state of Texas."
  59. log := zerolog.Nop()
  60. shutdownC := make(chan struct{})
  61. wsConn := NewWSConnection(&log)
  62. ts := newTestWebSocketServer()
  63. defer ts.Close()
  64. options := &StartOptions{
  65. OriginURL: "http://" + ts.Listener.Addr().String(),
  66. Headers: nil,
  67. }
  68. go func() {
  69. err := Serve(wsConn, listener, shutdownC, options)
  70. if err != nil {
  71. t.Errorf("Error running server: %v", err)
  72. return
  73. }
  74. }()
  75. conn, err := net.Dial("tcp", listener.Addr().String())
  76. _, _ = conn.Write([]byte(message))
  77. readBuffer := make([]byte, len(message))
  78. _, _ = conn.Read(readBuffer)
  79. assert.Equal(t, string(readBuffer), message)
  80. }
  81. func TestIsAccessResponse(t *testing.T) {
  82. validLocationHeader := http.Header{}
  83. validLocationHeader.Add("location", "https://test.cloudflareaccess.com/cdn-cgi/access/login/blahblah")
  84. invalidLocationHeader := http.Header{}
  85. invalidLocationHeader.Add("location", "https://google.com")
  86. testCases := []struct {
  87. Description string
  88. In *http.Response
  89. ExpectedOut bool
  90. }{
  91. {"nil response", nil, false},
  92. {"redirect with no location", &http.Response{StatusCode: http.StatusFound}, false},
  93. {"200 ok", &http.Response{StatusCode: http.StatusOK}, false},
  94. {"redirect with location", &http.Response{StatusCode: http.StatusFound, Header: validLocationHeader}, true},
  95. {"redirect with invalid location", &http.Response{StatusCode: http.StatusFound, Header: invalidLocationHeader}, false},
  96. }
  97. for i, tc := range testCases {
  98. if IsAccessResponse(tc.In) != tc.ExpectedOut {
  99. t.Fatalf("Failed case %d -- %s", i, tc.Description)
  100. }
  101. }
  102. }
  103. func newTestWebSocketServer() *httptest.Server {
  104. upgrader := ws.Upgrader{
  105. ReadBufferSize: 1024,
  106. WriteBufferSize: 1024,
  107. }
  108. return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  109. conn, _ := upgrader.Upgrade(w, r, nil)
  110. defer conn.Close()
  111. for {
  112. mt, message, err := conn.ReadMessage()
  113. if err != nil {
  114. break
  115. }
  116. if err := conn.WriteMessage(mt, []byte(message)); err != nil {
  117. break
  118. }
  119. }
  120. }))
  121. }
  122. func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
  123. req, err := http.NewRequest("GET", url, stream)
  124. if err != nil {
  125. t.Fatalf("testRequestHeader error")
  126. }
  127. req.Header.Add("Connection", "Upgrade")
  128. req.Header.Add("Upgrade", "WebSocket")
  129. req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
  130. req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
  131. req.Header.Add("Sec-Websocket-Version", "13")
  132. req.Header.Add("User-Agent", "curl/7.59.0")
  133. return req
  134. }
  135. func TestBastionDestination(t *testing.T) {
  136. tests := []struct {
  137. name string
  138. header http.Header
  139. expectedDest string
  140. wantErr bool
  141. }{
  142. {
  143. name: "hostname destination",
  144. header: http.Header{
  145. cfJumpDestinationHeader: []string{"localhost"},
  146. },
  147. expectedDest: "localhost",
  148. },
  149. {
  150. name: "hostname destination with port",
  151. header: http.Header{
  152. cfJumpDestinationHeader: []string{"localhost:9000"},
  153. },
  154. expectedDest: "localhost:9000",
  155. },
  156. {
  157. name: "hostname destination with scheme and port",
  158. header: http.Header{
  159. cfJumpDestinationHeader: []string{"ssh://localhost:9000"},
  160. },
  161. expectedDest: "localhost:9000",
  162. },
  163. {
  164. name: "full hostname url",
  165. header: http.Header{
  166. cfJumpDestinationHeader: []string{"ssh://localhost:9000/metrics"},
  167. },
  168. expectedDest: "localhost:9000",
  169. },
  170. {
  171. name: "hostname destination with port and path",
  172. header: http.Header{
  173. cfJumpDestinationHeader: []string{"localhost:9000/metrics"},
  174. },
  175. expectedDest: "localhost:9000",
  176. },
  177. {
  178. name: "ip destination",
  179. header: http.Header{
  180. cfJumpDestinationHeader: []string{"127.0.0.1"},
  181. },
  182. expectedDest: "127.0.0.1",
  183. },
  184. {
  185. name: "ip destination with port",
  186. header: http.Header{
  187. cfJumpDestinationHeader: []string{"127.0.0.1:9000"},
  188. },
  189. expectedDest: "127.0.0.1:9000",
  190. },
  191. {
  192. name: "ip destination with port and path",
  193. header: http.Header{
  194. cfJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"},
  195. },
  196. expectedDest: "127.0.0.1:9000",
  197. },
  198. {
  199. name: "ip destination with schem and port",
  200. header: http.Header{
  201. cfJumpDestinationHeader: []string{"tcp://127.0.0.1:9000"},
  202. },
  203. expectedDest: "127.0.0.1:9000",
  204. },
  205. {
  206. name: "full ip url",
  207. header: http.Header{
  208. cfJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"},
  209. },
  210. expectedDest: "127.0.0.1:9000",
  211. },
  212. {
  213. name: "no destination",
  214. wantErr: true,
  215. },
  216. }
  217. for _, test := range tests {
  218. r := &http.Request{
  219. Header: test.header,
  220. }
  221. dest, err := ResolveBastionDest(r)
  222. if test.wantErr {
  223. assert.Error(t, err, "Test %s expects error", test.name)
  224. } else {
  225. assert.NoError(t, err, "Test %s expects no error, got error %v", test.name, err)
  226. assert.Equal(t, test.expectedDest, dest, "Test %s expect dest %s, got %s", test.name, test.expectedDest, dest)
  227. }
  228. }
  229. }