123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374 |
- package validation
- import (
- "bytes"
- "fmt"
- "io/ioutil"
- "testing"
- "time"
- "context"
- "crypto/tls"
- "crypto/x509"
- "net"
- "net/http"
- "net/http/httptest"
- "net/url"
- "strings"
- "github.com/stretchr/testify/assert"
- )
- func TestValidateHostname(t *testing.T) {
- var inputHostname string
- hostname, err := ValidateHostname(inputHostname)
- assert.Equal(t, err, nil)
- assert.Empty(t, hostname)
- inputHostname = "hello.example.com"
- hostname, err = ValidateHostname(inputHostname)
- assert.Nil(t, err)
- assert.Equal(t, "hello.example.com", hostname)
- inputHostname = "http://hello.example.com"
- hostname, err = ValidateHostname(inputHostname)
- assert.Nil(t, err)
- assert.Equal(t, "hello.example.com", hostname)
- inputHostname = "bücher.example.com"
- hostname, err = ValidateHostname(inputHostname)
- assert.Nil(t, err)
- assert.Equal(t, "xn--bcher-kva.example.com", hostname)
- inputHostname = "http://bücher.example.com"
- hostname, err = ValidateHostname(inputHostname)
- assert.Nil(t, err)
- assert.Equal(t, "xn--bcher-kva.example.com", hostname)
- inputHostname = "http%3A%2F%2Fhello.example.com"
- hostname, err = ValidateHostname(inputHostname)
- assert.Nil(t, err)
- assert.Equal(t, "hello.example.com", hostname)
- }
- func TestValidateUrl(t *testing.T) {
- type testCase struct {
- input string
- expectedOutput string
- }
- testCases := []testCase{
- {"http://localhost", "http://localhost"},
- {"http://localhost/", "http://localhost"},
- {"http://localhost/api", "http://localhost"},
- {"http://localhost/api/", "http://localhost"},
- {"https://localhost", "https://localhost"},
- {"https://localhost/", "https://localhost"},
- {"https://localhost/api", "https://localhost"},
- {"https://localhost/api/", "https://localhost"},
- {"https://localhost:8080", "https://localhost:8080"},
- {"https://localhost:8080/", "https://localhost:8080"},
- {"https://localhost:8080/api", "https://localhost:8080"},
- {"https://localhost:8080/api/", "https://localhost:8080"},
- {"localhost", "http://localhost"},
- {"localhost/", "http://localhost/"},
- {"localhost/api", "http://localhost/api"},
- {"localhost/api/", "http://localhost/api/"},
- {"localhost:8080", "http://localhost:8080"},
- {"localhost:8080/", "http://localhost:8080/"},
- {"localhost:8080/api", "http://localhost:8080/api"},
- {"localhost:8080/api/", "http://localhost:8080/api/"},
- {"localhost:8080/api/?asdf", "http://localhost:8080/api/?asdf"},
- {"http://127.0.0.1:8080", "http://127.0.0.1:8080"},
- {"127.0.0.1:8080", "http://127.0.0.1:8080"},
- {"127.0.0.1", "http://127.0.0.1"},
- {"https://127.0.0.1:8080", "https://127.0.0.1:8080"},
- {"[::1]:8080", "http://[::1]:8080"},
- {"http://[::1]", "http://[::1]"},
- {"http://[::1]:8080", "http://[::1]:8080"},
- {"[::1]", "http://[::1]"},
- {"https://example.com", "https://example.com"},
- {"example.com", "http://example.com"},
- {"http://hello.example.com", "http://hello.example.com"},
- {"hello.example.com", "http://hello.example.com"},
- {"hello.example.com:8080", "http://hello.example.com:8080"},
- {"https://hello.example.com:8080", "https://hello.example.com:8080"},
- {"https://bücher.example.com", "https://xn--bcher-kva.example.com"},
- {"bücher.example.com", "http://xn--bcher-kva.example.com"},
- {"https%3A%2F%2Fhello.example.com", "https://hello.example.com"},
- {"https://alex:12345@hello.example.com:8080", "https://hello.example.com:8080"},
- }
- for i, testCase := range testCases {
- validUrl, err := ValidateUrl(testCase.input)
- assert.NoError(t, err, "test case %v", i)
- assert.Equal(t, testCase.expectedOutput, validUrl.String(), "test case %v", i)
- }
- validUrl, err := ValidateUrl("")
- assert.Equal(t, fmt.Errorf("URL should not be empty"), err)
- assert.Empty(t, validUrl)
- validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
- assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error())
- assert.Empty(t, validUrl)
- }
- func TestToggleProtocol(t *testing.T) {
- assert.Equal(t, "https", toggleProtocol("http"))
- assert.Equal(t, "http", toggleProtocol("https"))
- assert.Equal(t, "random", toggleProtocol("random"))
- assert.Equal(t, "", toggleProtocol(""))
- }
- // Happy path 1: originURL is HTTP, and HTTP connections work
- func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
- originURL := "http://127.0.0.1/"
- hostname := "example.com"
- assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
- assert.Equal(t, req.Host, hostname)
- if req.URL.Scheme == "http" {
- return emptyResponse(200), nil
- }
- if req.URL.Scheme == "https" {
- t.Fatal("http works, shouldn't have tried with https")
- }
- panic("Shouldn't reach here")
- })))
- assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
- assert.Equal(t, req.Host, hostname)
- if req.URL.Scheme == "http" {
- return emptyResponse(503), nil
- }
- if req.URL.Scheme == "https" {
- t.Fatal("http works, shouldn't have tried with https")
- }
- panic("Shouldn't reach here")
- })))
- }
- // Happy path 2: originURL is HTTPS, and HTTPS connections work
- func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
- originURL := "https://127.0.0.1:1234/"
- hostname := "example.com"
- assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
- assert.Equal(t, req.Host, hostname)
- if req.URL.Scheme == "http" {
- t.Fatal("https works, shouldn't have tried with http")
- }
- if req.URL.Scheme == "https" {
- return emptyResponse(200), nil
- }
- panic("Shouldn't reach here")
- })))
- assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
- assert.Equal(t, req.Host, hostname)
- if req.URL.Scheme == "http" {
- t.Fatal("https works, shouldn't have tried with http")
- }
- if req.URL.Scheme == "https" {
- return emptyResponse(503), nil
- }
- panic("Shouldn't reach here")
- })))
- }
- // Error path 1: originURL is HTTPS, but HTTP connections work
- func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
- originURL := "https://127.0.0.1:1234/"
- hostname := "example.com"
- assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
- assert.Equal(t, req.Host, hostname)
- if req.URL.Scheme == "http" {
- return emptyResponse(200), nil
- }
- if req.URL.Scheme == "https" {
- return nil, assert.AnError
- }
- panic("Shouldn't reach here")
- })))
- assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
- assert.Equal(t, req.Host, hostname)
- if req.URL.Scheme == "http" {
- return emptyResponse(503), nil
- }
- if req.URL.Scheme == "https" {
- return nil, assert.AnError
- }
- panic("Shouldn't reach here")
- })))
- }
- // Error path 2: originURL is HTTP, but HTTPS connections work
- func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
- originURL := "http://127.0.0.1:1234/"
- hostname := "example.com"
- assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
- assert.Equal(t, req.Host, hostname)
- if req.URL.Scheme == "http" {
- return nil, assert.AnError
- }
- if req.URL.Scheme == "https" {
- return emptyResponse(200), nil
- }
- panic("Shouldn't reach here")
- })))
- assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
- assert.Equal(t, req.Host, hostname)
- if req.URL.Scheme == "http" {
- return nil, assert.AnError
- }
- if req.URL.Scheme == "https" {
- return emptyResponse(503), nil
- }
- panic("Shouldn't reach here")
- })))
- }
- // Ensure the client does not follow 302 responses
- func TestValidateHTTPService_NoFollowRedirects(t *testing.T) {
- hostname := "example.com"
- redirectServer, redirectClient, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path == "/followedRedirect" {
- t.Fatal("shouldn't have followed the 302")
- }
- if r.Method == "CONNECT" {
- assert.Equal(t, "127.0.0.1:443", r.Host)
- } else {
- assert.Equal(t, hostname, r.Host)
- }
- w.Header().Set("Location", "/followedRedirect")
- w.WriteHeader(302)
- }))
- assert.NoError(t, err)
- defer redirectServer.Close()
- assert.NoError(t, ValidateHTTPService(redirectServer.URL, hostname, redirectClient.Transport))
- }
- // Ensure validation times out when origin URL is nonresponsive
- func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) {
- originURL := "http://127.0.0.1/"
- hostname := "example.com"
- oldValidationTimeout := validationTimeout
- defer func() {
- validationTimeout = oldValidationTimeout
- }()
- validationTimeout = 500 * time.Millisecond
- // Use createMockServerAndClient, not createSecureMockServerAndClient.
- // The latter will bail with HTTP 400 immediately on an http:// request,
- // which defeats the purpose of a 'nonresponsive origin' test.
- server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.Method == "CONNECT" {
- assert.Equal(t, "127.0.0.1:443", r.Host)
- } else {
- assert.Equal(t, hostname, r.Host)
- }
- time.Sleep(1 * time.Second)
- w.WriteHeader(200)
- }))
- if !assert.NoError(t, err) {
- t.FailNow()
- }
- defer server.Close()
- err = ValidateHTTPService(originURL, hostname, client.Transport)
- fmt.Println(err)
- if err, ok := err.(net.Error); assert.True(t, ok) {
- assert.True(t, err.Timeout())
- }
- }
- func TestNewAccessValidatorOk(t *testing.T) {
- ctx := context.Background()
- url := "test.cloudflareaccess.com"
- access, err := NewAccessValidator(ctx, url, url, "")
- assert.NoError(t, err)
- assert.NotNil(t, access)
- assert.Error(t, access.Validate(ctx, ""))
- assert.Error(t, access.Validate(ctx, "invalid"))
- req := httptest.NewRequest("GET", "https://test.cloudflareaccess.com", nil)
- req.Header.Set(accessJwtHeader, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c")
- assert.Error(t, access.ValidateRequest(ctx, req))
- }
- func TestNewAccessValidatorErr(t *testing.T) {
- ctx := context.Background()
- urls := []string{
- "",
- "ftp://test.cloudflareaccess.com",
- "wss://cloudflarenone.com",
- }
- for _, url := range urls {
- access, err := NewAccessValidator(ctx, url, url, "")
- assert.Error(t, err, url)
- assert.Nil(t, access)
- }
- }
// testRoundTripper adapts a plain function to the http.RoundTripper interface
// so tests can stub out the transport with a closure.
type testRoundTripper func(req *http.Request) (*http.Response, error)

// RoundTrip hands the request to the wrapped function and returns its result
// unchanged.
func (rt testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := rt(req)
	return resp, err
}
// emptyResponse builds a minimal *http.Response carrying the given status
// code, an empty (non-nil) header map and an empty body.
func emptyResponse(statusCode int) *http.Response {
	resp := new(http.Response)
	resp.StatusCode = statusCode
	resp.Header = http.Header{}
	// A reader over no bytes, so the body reports EOF on the first read.
	resp.Body = ioutil.NopCloser(bytes.NewReader(nil))
	return resp
}
// createMockServerAndClient starts a plain-HTTP httptest server running
// handler and returns it together with a client whose transport proxies every
// request through that server. The caller must Close the returned server.
//
// The client is a fresh *http.Client; http.DefaultClient is left untouched so
// the proxy setting cannot leak into other code in the same process.
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
	server := httptest.NewServer(handler)
	// Parse the proxy URL once so a malformed URL is reported to the caller
	// instead of surfacing later on every request.
	proxyURL, err := url.Parse(server.URL)
	if err != nil {
		server.Close()
		return nil, nil, err
	}
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
		},
	}
	return server, client, nil
}
// createSecureMockServerAndClient starts a TLS httptest server running handler
// and returns it together with a client that trusts the server's certificate
// and dials the server for every connection, whatever address was requested.
// The caller must Close the returned server.
//
// The client is a fresh *http.Client; http.DefaultClient is left untouched so
// the custom dialer and root CAs cannot leak into other code in the process.
func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
	server := httptest.NewTLSServer(handler)
	cert, err := x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0])
	if err != nil {
		server.Close()
		return nil, nil, err
	}
	certpool := x509.NewCertPool()
	certpool.AddCert(cert)

	// server.URL is "https://host:port"; strip the scheme to get the dial address.
	serverAddr := strings.TrimPrefix(server.URL, "https://")
	dialer := &net.Dialer{}
	client := &http.Client{
		Transport: &http.Transport{
			// Ignore the requested network and address and always connect to
			// the mock server, honoring the request context for cancellation.
			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
				return dialer.DialContext(ctx, "tcp", serverAddr)
			},
			TLSClientConfig: &tls.Config{
				RootCAs: certpool,
			},
		},
	}
	return server, client, nil
}
|