origin_proxy_test.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package ingress
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net"
  7. "net/http"
  8. "net/http/httptest"
  9. "net/url"
  10. "testing"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. "github.com/cloudflare/cloudflared/carrier"
  14. "github.com/cloudflare/cloudflared/websocket"
  15. )
  16. func TestRawTCPServiceEstablishConnection(t *testing.T) {
  17. originListener, err := net.Listen("tcp", "127.0.0.1:0")
  18. require.NoError(t, err)
  19. listenerClosed := make(chan struct{})
  20. tcpListenRoutine(originListener, listenerClosed)
  21. rawTCPService := &rawTCPService{name: ServiceWarpRouting}
  22. req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
  23. require.NoError(t, err)
  24. originListener.Close()
  25. <-listenerClosed
  26. req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
  27. require.NoError(t, err)
  28. // Origin not listening for new connection, should return an error
  29. _, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String(), TestLogger)
  30. require.Error(t, err)
  31. }
  32. func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
  33. originListener, err := net.Listen("tcp", "127.0.0.1:0")
  34. require.NoError(t, err)
  35. listenerClosed := make(chan struct{})
  36. tcpListenRoutine(originListener, listenerClosed)
  37. originURL := &url.URL{
  38. Scheme: "tcp",
  39. Host: originListener.Addr().String(),
  40. }
  41. baseReq, err := http.NewRequest(http.MethodGet, "https://place-holder", nil)
  42. require.NoError(t, err)
  43. baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
  44. bastionReq := baseReq.Clone(context.Background())
  45. carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
  46. tests := []struct {
  47. testCase string
  48. service *tcpOverWSService
  49. req *http.Request
  50. expectErr bool
  51. }{
  52. {
  53. testCase: "specific TCP service",
  54. service: newTCPOverWSService(originURL),
  55. req: baseReq,
  56. },
  57. {
  58. testCase: "bastion service",
  59. service: newBastionService(),
  60. req: bastionReq,
  61. },
  62. {
  63. testCase: "invalid bastion request",
  64. service: newBastionService(),
  65. req: baseReq,
  66. expectErr: true,
  67. },
  68. }
  69. for _, test := range tests {
  70. t.Run(test.testCase, func(t *testing.T) {
  71. if test.expectErr {
  72. bastionHost, _ := carrier.ResolveBastionDest(test.req)
  73. _, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
  74. assert.Error(t, err)
  75. }
  76. })
  77. }
  78. originListener.Close()
  79. <-listenerClosed
  80. for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
  81. // Origin not listening for new connection, should return an error
  82. bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
  83. _, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
  84. assert.Error(t, err)
  85. }
  86. }
  87. func TestHTTPServiceHostHeaderOverride(t *testing.T) {
  88. cfg := OriginRequestConfig{
  89. HTTPHostHeader: t.Name(),
  90. }
  91. handler := func(w http.ResponseWriter, r *http.Request) {
  92. require.Equal(t, r.Host, t.Name())
  93. if websocket.IsWebSocketUpgrade(r) {
  94. respHeaders := websocket.NewResponseHeader(r)
  95. for k, v := range respHeaders {
  96. w.Header().Set(k, v[0])
  97. }
  98. w.WriteHeader(http.StatusSwitchingProtocols)
  99. return
  100. }
  101. // return the X-Forwarded-Host header for assertions
  102. // as the httptest Server URL isn't available here yet
  103. w.Write([]byte(r.Header.Get("X-Forwarded-Host")))
  104. }
  105. origin := httptest.NewServer(http.HandlerFunc(handler))
  106. defer origin.Close()
  107. originURL, err := url.Parse(origin.URL)
  108. require.NoError(t, err)
  109. httpService := &httpService{
  110. url: originURL,
  111. }
  112. shutdownC := make(chan struct{})
  113. require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
  114. req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
  115. require.NoError(t, err)
  116. resp, err := httpService.RoundTrip(req)
  117. require.NoError(t, err)
  118. require.Equal(t, http.StatusOK, resp.StatusCode)
  119. respBody, err := io.ReadAll(resp.Body)
  120. require.NoError(t, err)
  121. require.Equal(t, respBody, []byte(originURL.Host))
  122. }
  123. // TestHTTPServiceUsesIngressRuleScheme makes sure httpService uses scheme defined in ingress rule and not by eyeball request
  124. func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
  125. handler := func(w http.ResponseWriter, r *http.Request) {
  126. require.NotNil(t, r.TLS)
  127. // Echo the X-Forwarded-Proto header for assertions
  128. w.Write([]byte(r.Header.Get("X-Forwarded-Proto")))
  129. }
  130. origin := httptest.NewTLSServer(http.HandlerFunc(handler))
  131. defer origin.Close()
  132. originURL, err := url.Parse(origin.URL)
  133. require.NoError(t, err)
  134. require.Equal(t, "https", originURL.Scheme)
  135. cfg := OriginRequestConfig{
  136. NoTLSVerify: true,
  137. }
  138. httpService := &httpService{
  139. url: originURL,
  140. }
  141. shutdownC := make(chan struct{})
  142. require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
  143. // Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
  144. protos := []string{"https", "http", "dne"}
  145. for _, p := range protos {
  146. req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
  147. require.NoError(t, err)
  148. req.Header.Add("X-Forwarded-Proto", p)
  149. resp, err := httpService.RoundTrip(req)
  150. require.NoError(t, err)
  151. require.Equal(t, http.StatusOK, resp.StatusCode)
  152. respBody, err := io.ReadAll(resp.Body)
  153. require.NoError(t, err)
  154. require.Equal(t, respBody, []byte(p))
  155. }
  156. }
  // tcpListenRoutine starts a background goroutine that accepts connections on
  // listener and closes each one right away. When Accept returns an error (for
  // example after listener.Close), the goroutine closes closeChan and exits.
  // Callers can therefore receive from closeChan to know the accept loop has stopped.
  func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
  	acceptLoop := func() {
  		defer close(closeChan)
  		for {
  			c, acceptErr := listener.Accept()
  			if acceptErr != nil {
  				return
  			}
  			// Close immediately, this test is not about testing read/write on connection
  			c.Close()
  		}
  	}
  	go acceptLoop()
  }