123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- package validation
- import (
- "bytes"
- "fmt"
- "io"
- "testing"
- "context"
- "crypto/tls"
- "crypto/x509"
- "net"
- "net/http"
- "net/http/httptest"
- "net/url"
- "strings"
- "github.com/stretchr/testify/assert"
- )
- func TestValidateHostname(t *testing.T) {
- var inputHostname string
- hostname, err := ValidateHostname(inputHostname)
- assert.Equal(t, err, nil)
- assert.Empty(t, hostname)
- inputHostname = "hello.example.com"
- hostname, err = ValidateHostname(inputHostname)
- assert.Nil(t, err)
- assert.Equal(t, "hello.example.com", hostname)
- inputHostname = "http://hello.example.com"
- hostname, err = ValidateHostname(inputHostname)
- assert.Nil(t, err)
- assert.Equal(t, "hello.example.com", hostname)
- inputHostname = "bücher.example.com"
- hostname, err = ValidateHostname(inputHostname)
- assert.Nil(t, err)
- assert.Equal(t, "xn--bcher-kva.example.com", hostname)
- inputHostname = "http://bücher.example.com"
- hostname, err = ValidateHostname(inputHostname)
- assert.Nil(t, err)
- assert.Equal(t, "xn--bcher-kva.example.com", hostname)
- inputHostname = "http%3A%2F%2Fhello.example.com"
- hostname, err = ValidateHostname(inputHostname)
- assert.Nil(t, err)
- assert.Equal(t, "hello.example.com", hostname)
- }
- func TestValidateUrl(t *testing.T) {
- type testCase struct {
- input string
- expectedOutput string
- }
- testCases := []testCase{
- {"http://localhost", "http://localhost"},
- {"http://localhost/", "http://localhost"},
- {"http://localhost/api", "http://localhost"},
- {"http://localhost/api/", "http://localhost"},
- {"https://localhost", "https://localhost"},
- {"https://localhost/", "https://localhost"},
- {"https://localhost/api", "https://localhost"},
- {"https://localhost/api/", "https://localhost"},
- {"https://localhost:8080", "https://localhost:8080"},
- {"https://localhost:8080/", "https://localhost:8080"},
- {"https://localhost:8080/api", "https://localhost:8080"},
- {"https://localhost:8080/api/", "https://localhost:8080"},
- {"localhost", "http://localhost"},
- {"localhost/", "http://localhost/"},
- {"localhost/api", "http://localhost/api"},
- {"localhost/api/", "http://localhost/api/"},
- {"localhost:8080", "http://localhost:8080"},
- {"localhost:8080/", "http://localhost:8080/"},
- {"localhost:8080/api", "http://localhost:8080/api"},
- {"localhost:8080/api/", "http://localhost:8080/api/"},
- {"localhost:8080/api/?asdf", "http://localhost:8080/api/?asdf"},
- {"http://127.0.0.1:8080", "http://127.0.0.1:8080"},
- {"127.0.0.1:8080", "http://127.0.0.1:8080"},
- {"127.0.0.1", "http://127.0.0.1"},
- {"https://127.0.0.1:8080", "https://127.0.0.1:8080"},
- {"[::1]:8080", "http://[::1]:8080"},
- {"http://[::1]", "http://[::1]"},
- {"http://[::1]:8080", "http://[::1]:8080"},
- {"[::1]", "http://[::1]"},
- {"https://example.com", "https://example.com"},
- {"example.com", "http://example.com"},
- {"http://hello.example.com", "http://hello.example.com"},
- {"hello.example.com", "http://hello.example.com"},
- {"hello.example.com:8080", "http://hello.example.com:8080"},
- {"https://hello.example.com:8080", "https://hello.example.com:8080"},
- {"https://bücher.example.com", "https://xn--bcher-kva.example.com"},
- {"bücher.example.com", "http://xn--bcher-kva.example.com"},
- {"https%3A%2F%2Fhello.example.com", "https://hello.example.com"},
- {"https://alex:12345@hello.example.com:8080", "https://hello.example.com:8080"},
- }
- for i, testCase := range testCases {
- validUrl, err := ValidateUrl(testCase.input)
- assert.NoError(t, err, "test case %v", i)
- assert.Equal(t, testCase.expectedOutput, validUrl.String(), "test case %v", i)
- }
- validUrl, err := ValidateUrl("")
- assert.Equal(t, fmt.Errorf("URL should not be empty"), err)
- assert.Empty(t, validUrl)
- validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
- assert.Equal(t, "Currently Cloudflare Tunnel does not support ftp protocol.", err.Error())
- assert.Empty(t, validUrl)
- }
- func TestNewAccessValidatorOk(t *testing.T) {
- ctx := context.Background()
- url := "test.cloudflareaccess.com"
- access, err := NewAccessValidator(ctx, url, url, "")
- assert.NoError(t, err)
- assert.NotNil(t, access)
- assert.Error(t, access.Validate(ctx, ""))
- assert.Error(t, access.Validate(ctx, "invalid"))
- req := httptest.NewRequest("GET", "https://test.cloudflareaccess.com", nil)
- req.Header.Set(accessJwtHeader, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c")
- assert.Error(t, access.ValidateRequest(ctx, req))
- }
- func TestNewAccessValidatorErr(t *testing.T) {
- ctx := context.Background()
- urls := []string{
- "",
- "ftp://test.cloudflareaccess.com",
- "wss://cloudflarenone.com",
- }
- for _, url := range urls {
- access, err := NewAccessValidator(ctx, url, url, "")
- assert.Error(t, err, url)
- assert.Nil(t, access)
- }
- }
// testRoundTripper adapts a plain function to the http.RoundTripper interface,
// letting tests stub out the transport of an http.Client.
type testRoundTripper func(req *http.Request) (*http.Response, error)

// Compile-time check that testRoundTripper satisfies http.RoundTripper.
var _ http.RoundTripper = testRoundTripper(nil)

// RoundTrip satisfies http.RoundTripper by delegating to the wrapped function.
func (rt testRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	fn := rt
	return fn(r)
}
// emptyResponse builds a minimal *http.Response with the given status code. It
// has an empty but non-nil header map and a body that reports EOF immediately.
func emptyResponse(statusCode int) *http.Response {
	body := io.NopCloser(bytes.NewReader(nil))
	resp := new(http.Response)
	resp.StatusCode = statusCode
	resp.Header = http.Header{}
	resp.Body = body
	return resp
}
// createMockServerAndClient starts a plain-HTTP httptest server that runs
// handler. It returns a dedicated client whose transport sends every request
// through that server as an HTTP proxy, so requests to any host reach handler.
//
// The caller owns the returned server and must Close it. The returned client is
// freshly allocated; http.DefaultClient is left untouched.
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
	server := httptest.NewServer(handler)
	proxyURL, err := url.Parse(server.URL)
	if err != nil {
		server.Close()
		return nil, nil, err
	}
	// Build a new client instead of mutating http.DefaultClient. Mutating it
	// would leak this transport into every other user of the default client,
	// and a later call would silently replace this client's transport.
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
		},
	}
	return server, client, nil
}
// createSecureMockServerAndClient starts a TLS httptest server that runs
// handler. It returns a dedicated client that trusts the server's self-signed
// certificate. That client dials the server for every request, whatever address
// the request targets; TLS verification still uses the requested hostname.
//
// The caller owns the returned server and must Close it. The returned client is
// freshly allocated; http.DefaultClient is left untouched. If the server's
// certificate cannot be parsed, the server is closed and the error is returned.
func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
	server := httptest.NewTLSServer(handler)
	cert, err := x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0])
	if err != nil {
		server.Close()
		return nil, nil, err
	}
	certpool := x509.NewCertPool()
	certpool.AddCert(cert)

	// server.URL looks like "https://host:port". Everything after the last '/'
	// is the address the listener is bound to.
	addr := server.URL[strings.LastIndex(server.URL, "/")+1:]
	var dialer net.Dialer

	// Build a new client instead of mutating http.DefaultClient, which would
	// leak this transport into unrelated code sharing the default client.
	client := &http.Client{
		Transport: &http.Transport{
			// Ignore the requested address and always dial the mock server.
			// Honor ctx so request cancellation and timeouts can abort the dial.
			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
				return dialer.DialContext(ctx, "tcp", addr)
			},
			TLSClientConfig: &tls.Config{
				RootCAs: certpool,
			},
		},
	}
	return server, client, nil
}
- func FuzzNewAccessValidator(f *testing.F) {
- f.Fuzz(func(t *testing.T, domain string, issuer string, applicationAUD string) {
- ctx := context.Background()
- _, _ = NewAccessValidator(ctx, domain, issuer, applicationAUD)
- })
- }
|