123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- package ingress
- import (
- "context"
- "fmt"
- "io"
- "net"
- "net/http"
- "net/http/httptest"
- "net/url"
- "testing"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "github.com/cloudflare/cloudflared/carrier"
- "github.com/cloudflare/cloudflared/websocket"
- )
- func TestRawTCPServiceEstablishConnection(t *testing.T) {
- originListener, err := net.Listen("tcp", "127.0.0.1:0")
- require.NoError(t, err)
- listenerClosed := make(chan struct{})
- tcpListenRoutine(originListener, listenerClosed)
- rawTCPService := &rawTCPService{name: ServiceWarpRouting}
- req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
- require.NoError(t, err)
- originListener.Close()
- <-listenerClosed
- req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
- require.NoError(t, err)
- // Origin not listening for new connection, should return an error
- _, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String(), TestLogger)
- require.Error(t, err)
- }
- func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
- originListener, err := net.Listen("tcp", "127.0.0.1:0")
- require.NoError(t, err)
- listenerClosed := make(chan struct{})
- tcpListenRoutine(originListener, listenerClosed)
- originURL := &url.URL{
- Scheme: "tcp",
- Host: originListener.Addr().String(),
- }
- baseReq, err := http.NewRequest(http.MethodGet, "https://place-holder", nil)
- require.NoError(t, err)
- baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
- bastionReq := baseReq.Clone(context.Background())
- carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
- tests := []struct {
- testCase string
- service *tcpOverWSService
- req *http.Request
- expectErr bool
- }{
- {
- testCase: "specific TCP service",
- service: newTCPOverWSService(originURL),
- req: baseReq,
- },
- {
- testCase: "bastion service",
- service: newBastionService(),
- req: bastionReq,
- },
- {
- testCase: "invalid bastion request",
- service: newBastionService(),
- req: baseReq,
- expectErr: true,
- },
- }
- for _, test := range tests {
- t.Run(test.testCase, func(t *testing.T) {
- if test.expectErr {
- bastionHost, _ := carrier.ResolveBastionDest(test.req)
- _, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
- assert.Error(t, err)
- }
- })
- }
- originListener.Close()
- <-listenerClosed
- for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
- // Origin not listening for new connection, should return an error
- bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
- _, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
- assert.Error(t, err)
- }
- }
- func TestHTTPServiceHostHeaderOverride(t *testing.T) {
- cfg := OriginRequestConfig{
- HTTPHostHeader: t.Name(),
- }
- handler := func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, r.Host, t.Name())
- if websocket.IsWebSocketUpgrade(r) {
- respHeaders := websocket.NewResponseHeader(r)
- for k, v := range respHeaders {
- w.Header().Set(k, v[0])
- }
- w.WriteHeader(http.StatusSwitchingProtocols)
- return
- }
- // return the X-Forwarded-Host header for assertions
- // as the httptest Server URL isn't available here yet
- w.Write([]byte(r.Header.Get("X-Forwarded-Host")))
- }
- origin := httptest.NewServer(http.HandlerFunc(handler))
- defer origin.Close()
- originURL, err := url.Parse(origin.URL)
- require.NoError(t, err)
- httpService := &httpService{
- url: originURL,
- }
- shutdownC := make(chan struct{})
- require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
- req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
- require.NoError(t, err)
- resp, err := httpService.RoundTrip(req)
- require.NoError(t, err)
- require.Equal(t, http.StatusOK, resp.StatusCode)
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
- require.Equal(t, respBody, []byte(originURL.Host))
- }
- // TestHTTPServiceUsesIngressRuleScheme makes sure httpService uses scheme defined in ingress rule and not by eyeball request
- func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
- handler := func(w http.ResponseWriter, r *http.Request) {
- require.NotNil(t, r.TLS)
- // Echo the X-Forwarded-Proto header for assertions
- w.Write([]byte(r.Header.Get("X-Forwarded-Proto")))
- }
- origin := httptest.NewTLSServer(http.HandlerFunc(handler))
- defer origin.Close()
- originURL, err := url.Parse(origin.URL)
- require.NoError(t, err)
- require.Equal(t, "https", originURL.Scheme)
- cfg := OriginRequestConfig{
- NoTLSVerify: true,
- }
- httpService := &httpService{
- url: originURL,
- }
- shutdownC := make(chan struct{})
- require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
- // Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
- protos := []string{"https", "http", "dne"}
- for _, p := range protos {
- req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
- require.NoError(t, err)
- req.Header.Add("X-Forwarded-Proto", p)
- resp, err := httpService.RoundTrip(req)
- require.NoError(t, err)
- require.Equal(t, http.StatusOK, resp.StatusCode)
- respBody, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
- require.Equal(t, respBody, []byte(p))
- }
- }
// tcpListenRoutine launches a background goroutine that accepts connections on
// listener and closes each one right away; the tests only care whether a
// connection can be established, not about reading or writing on it.
//
// The goroutine exits on the first Accept error (e.g. after listener.Close),
// and closeChan is closed on the way out so callers can wait for it.
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
	go func() {
		defer close(closeChan)
		for {
			c, acceptErr := listener.Accept()
			if acceptErr != nil {
				return
			}
			_ = c.Close()
		}
	}()
}
|