connection_test.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. package connection
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "fmt"
  6. "io"
  7. "math/big"
  8. "net/http"
  9. "time"
  10. pkgerrors "github.com/pkg/errors"
  11. "github.com/rs/zerolog"
  12. cfdflow "github.com/cloudflare/cloudflared/flow"
  13. "github.com/cloudflare/cloudflared/stream"
  14. "github.com/cloudflare/cloudflared/tracing"
  15. tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  16. "github.com/cloudflare/cloudflared/websocket"
  17. )
  18. const (
  19. largeFileSize = 2 * 1024 * 1024
  20. testGracePeriod = time.Millisecond * 100
  21. )
  22. var (
  23. testOrchestrator = &mockOrchestrator{
  24. originProxy: &mockOriginProxy{},
  25. }
  26. log = zerolog.Nop()
  27. testLargeResp = make([]byte, largeFileSize)
  28. )
  29. var _ ReadWriteAcker = (*HTTPResponseReadWriteAcker)(nil)
  30. type testRequest struct {
  31. name string
  32. endpoint string
  33. expectedStatus int
  34. expectedBody []byte
  35. isProxyError bool
  36. }
  37. type mockOrchestrator struct {
  38. originProxy OriginProxy
  39. }
  40. func (mcr *mockOrchestrator) GetConfigJSON() ([]byte, error) {
  41. return nil, fmt.Errorf("not implemented")
  42. }
  43. func (*mockOrchestrator) UpdateConfig(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
  44. return &tunnelpogs.UpdateConfigurationResponse{
  45. LastAppliedVersion: version,
  46. }
  47. }
  48. func (mcr *mockOrchestrator) GetOriginProxy() (OriginProxy, error) {
  49. return mcr.originProxy, nil
  50. }
  51. func (mcr *mockOrchestrator) WarpRoutingEnabled() (enabled bool) {
  52. return true
  53. }
  54. type mockOriginProxy struct{}
  55. func (moc *mockOriginProxy) ProxyHTTP(
  56. w ResponseWriter,
  57. tr *tracing.TracedHTTPRequest,
  58. isWebsocket bool,
  59. ) error {
  60. req := tr.Request
  61. if isWebsocket {
  62. switch req.URL.Path {
  63. case "/ws/echo":
  64. return wsEchoEndpoint(w, req)
  65. case "/ws/flaky":
  66. return wsFlakyEndpoint(w, req)
  67. default:
  68. originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
  69. return fmt.Errorf("unknown websocket endpoint %s", req.URL.Path)
  70. }
  71. }
  72. switch req.URL.Path {
  73. case "/ok":
  74. originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
  75. case "/large_file":
  76. originRespEndpoint(w, http.StatusOK, testLargeResp)
  77. case "/400":
  78. originRespEndpoint(w, http.StatusBadRequest, []byte(http.StatusText(http.StatusBadRequest)))
  79. case "/500":
  80. originRespEndpoint(w, http.StatusInternalServerError, []byte(http.StatusText(http.StatusInternalServerError)))
  81. case "/error":
  82. return fmt.Errorf("Failed to proxy to origin")
  83. default:
  84. originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
  85. }
  86. return nil
  87. }
  88. func (moc *mockOriginProxy) ProxyTCP(
  89. ctx context.Context,
  90. rwa ReadWriteAcker,
  91. r *TCPRequest,
  92. ) error {
  93. if r.CfTraceID == "flow-rate-limited" {
  94. return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "tcp flow rate limited")
  95. }
  96. return nil
  97. }
  98. type echoPipe struct {
  99. reader *io.PipeReader
  100. writer *io.PipeWriter
  101. }
  102. func (ep *echoPipe) Read(p []byte) (int, error) {
  103. return ep.reader.Read(p)
  104. }
  105. func (ep *echoPipe) Write(p []byte) (int, error) {
  106. return ep.writer.Write(p)
  107. }
  108. // A mock origin that echos data by streaming like a tcpOverWSConnection
  109. // https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go
  110. func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
  111. resp := &http.Response{
  112. StatusCode: http.StatusSwitchingProtocols,
  113. }
  114. if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
  115. return err
  116. }
  117. wsCtx, cancel := context.WithCancel(r.Context())
  118. readPipe, writePipe := io.Pipe()
  119. wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
  120. go func() {
  121. select {
  122. case <-wsCtx.Done():
  123. case <-r.Context().Done():
  124. }
  125. readPipe.Close()
  126. writePipe.Close()
  127. }()
  128. originConn := &echoPipe{reader: readPipe, writer: writePipe}
  129. stream.Pipe(wsConn, originConn, &log)
  130. cancel()
  131. wsConn.Close()
  132. return nil
  133. }
  134. type flakyConn struct {
  135. closeAt time.Time
  136. }
  137. func (fc *flakyConn) Read(p []byte) (int, error) {
  138. if time.Now().After(fc.closeAt) {
  139. return 0, io.EOF
  140. }
  141. n := copy(p, "Read from flaky connection")
  142. return n, nil
  143. }
  144. func (fc *flakyConn) Write(p []byte) (int, error) {
  145. if time.Now().After(fc.closeAt) {
  146. return 0, fmt.Errorf("flaky connection closed")
  147. }
  148. return len(p), nil
  149. }
  150. func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
  151. resp := &http.Response{
  152. StatusCode: http.StatusSwitchingProtocols,
  153. }
  154. if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
  155. return err
  156. }
  157. wsCtx, cancel := context.WithCancel(r.Context())
  158. wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
  159. rInt, _ := rand.Int(rand.Reader, big.NewInt(50))
  160. closedAfter := time.Millisecond * time.Duration(rInt.Int64())
  161. originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
  162. stream.Pipe(wsConn, originConn, &log)
  163. cancel()
  164. wsConn.Close()
  165. return nil
  166. }
  167. func originRespEndpoint(w ResponseWriter, status int, data []byte) {
  168. resp := &http.Response{
  169. StatusCode: status,
  170. }
  171. _ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
  172. _, _ = w.Write(data)
  173. }
  174. type mockConnectedFuse struct{}
  175. func (mcf mockConnectedFuse) Connected() {}
  176. func (mcf mockConnectedFuse) IsConnected() bool {
  177. return true
  178. }