connection_test.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. package connection
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "math/rand"
  7. "net/http"
  8. "time"
  9. "github.com/rs/zerolog"
  10. "github.com/cloudflare/cloudflared/stream"
  11. "github.com/cloudflare/cloudflared/tracing"
  12. tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  13. "github.com/cloudflare/cloudflared/websocket"
  14. )
  15. const (
  16. largeFileSize = 2 * 1024 * 1024
  17. testGracePeriod = time.Millisecond * 100
  18. )
  19. var (
  20. testOrchestrator = &mockOrchestrator{
  21. originProxy: &mockOriginProxy{},
  22. }
  23. log = zerolog.Nop()
  24. testLargeResp = make([]byte, largeFileSize)
  25. )
  26. var _ ReadWriteAcker = (*HTTPResponseReadWriteAcker)(nil)
  27. type testRequest struct {
  28. name string
  29. endpoint string
  30. expectedStatus int
  31. expectedBody []byte
  32. isProxyError bool
  33. }
  34. type mockOrchestrator struct {
  35. originProxy OriginProxy
  36. }
  37. func (mcr *mockOrchestrator) GetConfigJSON() ([]byte, error) {
  38. return nil, fmt.Errorf("not implemented")
  39. }
  40. func (*mockOrchestrator) UpdateConfig(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
  41. return &tunnelpogs.UpdateConfigurationResponse{
  42. LastAppliedVersion: version,
  43. }
  44. }
  45. func (mcr *mockOrchestrator) GetOriginProxy() (OriginProxy, error) {
  46. return mcr.originProxy, nil
  47. }
  48. func (mcr *mockOrchestrator) WarpRoutingEnabled() (enabled bool) {
  49. return true
  50. }
  51. type mockOriginProxy struct{}
  52. func (moc *mockOriginProxy) ProxyHTTP(
  53. w ResponseWriter,
  54. tr *tracing.TracedHTTPRequest,
  55. isWebsocket bool,
  56. ) error {
  57. req := tr.Request
  58. if isWebsocket {
  59. switch req.URL.Path {
  60. case "/ws/echo":
  61. return wsEchoEndpoint(w, req)
  62. case "/ws/flaky":
  63. return wsFlakyEndpoint(w, req)
  64. default:
  65. originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
  66. return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
  67. }
  68. }
  69. switch req.URL.Path {
  70. case "/ok":
  71. originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
  72. case "/large_file":
  73. originRespEndpoint(w, http.StatusOK, testLargeResp)
  74. case "/400":
  75. originRespEndpoint(w, http.StatusBadRequest, []byte(http.StatusText(http.StatusBadRequest)))
  76. case "/500":
  77. originRespEndpoint(w, http.StatusInternalServerError, []byte(http.StatusText(http.StatusInternalServerError)))
  78. case "/error":
  79. return fmt.Errorf("Failed to proxy to origin")
  80. default:
  81. originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
  82. }
  83. return nil
  84. }
  85. func (moc *mockOriginProxy) ProxyTCP(
  86. ctx context.Context,
  87. rwa ReadWriteAcker,
  88. r *TCPRequest,
  89. ) error {
  90. return nil
  91. }
  92. type echoPipe struct {
  93. reader *io.PipeReader
  94. writer *io.PipeWriter
  95. }
  96. func (ep *echoPipe) Read(p []byte) (int, error) {
  97. return ep.reader.Read(p)
  98. }
  99. func (ep *echoPipe) Write(p []byte) (int, error) {
  100. return ep.writer.Write(p)
  101. }
  102. // A mock origin that echos data by streaming like a tcpOverWSConnection
  103. // https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go
  104. func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
  105. resp := &http.Response{
  106. StatusCode: http.StatusSwitchingProtocols,
  107. }
  108. if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
  109. return err
  110. }
  111. wsCtx, cancel := context.WithCancel(r.Context())
  112. readPipe, writePipe := io.Pipe()
  113. wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
  114. go func() {
  115. select {
  116. case <-wsCtx.Done():
  117. case <-r.Context().Done():
  118. }
  119. readPipe.Close()
  120. writePipe.Close()
  121. }()
  122. originConn := &echoPipe{reader: readPipe, writer: writePipe}
  123. stream.Pipe(wsConn, originConn, &log)
  124. cancel()
  125. wsConn.Close()
  126. return nil
  127. }
  128. type flakyConn struct {
  129. closeAt time.Time
  130. }
  131. func (fc *flakyConn) Read(p []byte) (int, error) {
  132. if time.Now().After(fc.closeAt) {
  133. return 0, io.EOF
  134. }
  135. n := copy(p, "Read from flaky connection")
  136. return n, nil
  137. }
  138. func (fc *flakyConn) Write(p []byte) (int, error) {
  139. if time.Now().After(fc.closeAt) {
  140. return 0, fmt.Errorf("flaky connection closed")
  141. }
  142. return len(p), nil
  143. }
  144. func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
  145. resp := &http.Response{
  146. StatusCode: http.StatusSwitchingProtocols,
  147. }
  148. if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
  149. return err
  150. }
  151. wsCtx, cancel := context.WithCancel(r.Context())
  152. wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
  153. closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
  154. originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
  155. stream.Pipe(wsConn, originConn, &log)
  156. cancel()
  157. wsConn.Close()
  158. return nil
  159. }
  160. func originRespEndpoint(w ResponseWriter, status int, data []byte) {
  161. resp := &http.Response{
  162. StatusCode: status,
  163. }
  164. _ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
  165. _, _ = w.Write(data)
  166. }
  167. type mockConnectedFuse struct{}
  168. func (mcf mockConnectedFuse) Connected() {}
  169. func (mcf mockConnectedFuse) IsConnected() bool {
  170. return true
  171. }