validation_test.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. package validation
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "testing"
  7. "context"
  8. "crypto/tls"
  9. "crypto/x509"
  10. "net"
  11. "net/http"
  12. "net/http/httptest"
  13. "net/url"
  14. "strings"
  15. "github.com/stretchr/testify/assert"
  16. )
  17. func TestValidateHostname(t *testing.T) {
  18. var inputHostname string
  19. hostname, err := ValidateHostname(inputHostname)
  20. assert.Equal(t, err, nil)
  21. assert.Empty(t, hostname)
  22. inputHostname = "hello.example.com"
  23. hostname, err = ValidateHostname(inputHostname)
  24. assert.Nil(t, err)
  25. assert.Equal(t, "hello.example.com", hostname)
  26. inputHostname = "http://hello.example.com"
  27. hostname, err = ValidateHostname(inputHostname)
  28. assert.Nil(t, err)
  29. assert.Equal(t, "hello.example.com", hostname)
  30. inputHostname = "bücher.example.com"
  31. hostname, err = ValidateHostname(inputHostname)
  32. assert.Nil(t, err)
  33. assert.Equal(t, "xn--bcher-kva.example.com", hostname)
  34. inputHostname = "http://bücher.example.com"
  35. hostname, err = ValidateHostname(inputHostname)
  36. assert.Nil(t, err)
  37. assert.Equal(t, "xn--bcher-kva.example.com", hostname)
  38. inputHostname = "http%3A%2F%2Fhello.example.com"
  39. hostname, err = ValidateHostname(inputHostname)
  40. assert.Nil(t, err)
  41. assert.Equal(t, "hello.example.com", hostname)
  42. }
  43. func TestValidateUrl(t *testing.T) {
  44. type testCase struct {
  45. input string
  46. expectedOutput string
  47. }
  48. testCases := []testCase{
  49. {"http://localhost", "http://localhost"},
  50. {"http://localhost/", "http://localhost"},
  51. {"http://localhost/api", "http://localhost"},
  52. {"http://localhost/api/", "http://localhost"},
  53. {"https://localhost", "https://localhost"},
  54. {"https://localhost/", "https://localhost"},
  55. {"https://localhost/api", "https://localhost"},
  56. {"https://localhost/api/", "https://localhost"},
  57. {"https://localhost:8080", "https://localhost:8080"},
  58. {"https://localhost:8080/", "https://localhost:8080"},
  59. {"https://localhost:8080/api", "https://localhost:8080"},
  60. {"https://localhost:8080/api/", "https://localhost:8080"},
  61. {"localhost", "http://localhost"},
  62. {"localhost/", "http://localhost/"},
  63. {"localhost/api", "http://localhost/api"},
  64. {"localhost/api/", "http://localhost/api/"},
  65. {"localhost:8080", "http://localhost:8080"},
  66. {"localhost:8080/", "http://localhost:8080/"},
  67. {"localhost:8080/api", "http://localhost:8080/api"},
  68. {"localhost:8080/api/", "http://localhost:8080/api/"},
  69. {"localhost:8080/api/?asdf", "http://localhost:8080/api/?asdf"},
  70. {"http://127.0.0.1:8080", "http://127.0.0.1:8080"},
  71. {"127.0.0.1:8080", "http://127.0.0.1:8080"},
  72. {"127.0.0.1", "http://127.0.0.1"},
  73. {"https://127.0.0.1:8080", "https://127.0.0.1:8080"},
  74. {"[::1]:8080", "http://[::1]:8080"},
  75. {"http://[::1]", "http://[::1]"},
  76. {"http://[::1]:8080", "http://[::1]:8080"},
  77. {"[::1]", "http://[::1]"},
  78. {"https://example.com", "https://example.com"},
  79. {"example.com", "http://example.com"},
  80. {"http://hello.example.com", "http://hello.example.com"},
  81. {"hello.example.com", "http://hello.example.com"},
  82. {"hello.example.com:8080", "http://hello.example.com:8080"},
  83. {"https://hello.example.com:8080", "https://hello.example.com:8080"},
  84. {"https://bücher.example.com", "https://xn--bcher-kva.example.com"},
  85. {"bücher.example.com", "http://xn--bcher-kva.example.com"},
  86. {"https%3A%2F%2Fhello.example.com", "https://hello.example.com"},
  87. {"https://alex:12345@hello.example.com:8080", "https://hello.example.com:8080"},
  88. }
  89. for i, testCase := range testCases {
  90. validUrl, err := ValidateUrl(testCase.input)
  91. assert.NoError(t, err, "test case %v", i)
  92. assert.Equal(t, testCase.expectedOutput, validUrl.String(), "test case %v", i)
  93. }
  94. validUrl, err := ValidateUrl("")
  95. assert.Equal(t, fmt.Errorf("URL should not be empty"), err)
  96. assert.Empty(t, validUrl)
  97. validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
  98. assert.Equal(t, "Currently Cloudflare Tunnel does not support ftp protocol.", err.Error())
  99. assert.Empty(t, validUrl)
  100. }
  101. func TestNewAccessValidatorOk(t *testing.T) {
  102. ctx := context.Background()
  103. url := "test.cloudflareaccess.com"
  104. access, err := NewAccessValidator(ctx, url, url, "")
  105. assert.NoError(t, err)
  106. assert.NotNil(t, access)
  107. assert.Error(t, access.Validate(ctx, ""))
  108. assert.Error(t, access.Validate(ctx, "invalid"))
  109. req := httptest.NewRequest("GET", "https://test.cloudflareaccess.com", nil)
  110. req.Header.Set(accessJwtHeader, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c")
  111. assert.Error(t, access.ValidateRequest(ctx, req))
  112. }
  113. func TestNewAccessValidatorErr(t *testing.T) {
  114. ctx := context.Background()
  115. urls := []string{
  116. "",
  117. "ftp://test.cloudflareaccess.com",
  118. "wss://cloudflarenone.com",
  119. }
  120. for _, url := range urls {
  121. access, err := NewAccessValidator(ctx, url, url, "")
  122. assert.Error(t, err, url)
  123. assert.Nil(t, access)
  124. }
  125. }
  // testRoundTripper adapts a plain function to the RoundTrip method set, so a
  // test can supply an HTTP transport as an inline closure.
  type testRoundTripper func(req *http.Request) (*http.Response, error)

  // RoundTrip calls the underlying function with req and returns its results
  // unchanged.
  func (rt testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  	resp, err := rt(req)
  	return resp, err
  }
  // emptyResponse returns a minimal *http.Response that carries only the given
  // status code: its body is an empty reader wrapped in a no-op closer and its
  // header map is allocated but has no entries.
  func emptyResponse(statusCode int) *http.Response {
  	resp := new(http.Response)
  	resp.StatusCode = statusCode
  	resp.Header = http.Header{}
  	resp.Body = io.NopCloser(bytes.NewReader(nil))
  	return resp
  }
  // createMockServerAndClient starts a plain-HTTP httptest server running
  // handler and returns it together with a client whose transport proxies every
  // request through that server, whatever host the request names.
  //
  // The client is freshly allocated on each call; http.DefaultClient is left
  // untouched so other tests in the process keep the default transport. The
  // caller is responsible for closing the returned server. A non-nil error means
  // the server URL could not be parsed, in which case the server has already
  // been closed and both other results are nil.
  func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
  	server := httptest.NewServer(handler)

  	// Parse once up front instead of on every proxied request.
  	proxyURL, err := url.Parse(server.URL)
  	if err != nil {
  		server.Close()
  		return nil, nil, err
  	}

  	client := &http.Client{
  		Transport: &http.Transport{
  			Proxy: http.ProxyURL(proxyURL),
  		},
  	}
  	return server, client, nil
  }
  // createSecureMockServerAndClient starts a TLS httptest server running handler
  // and returns it together with a client wired to talk to it.
  //
  // The client's root pool contains only the server's own certificate, and its
  // dialer ignores the requested address and always connects to the test server,
  // so a request to any https host lands on handler. The client is freshly
  // allocated on each call; http.DefaultClient is left untouched. The caller is
  // responsible for closing the returned server. A non-nil error means the
  // server certificate could not be parsed, in which case the server has already
  // been closed and both other results are nil.
  func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
  	server := httptest.NewTLSServer(handler)

  	cert, err := x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0])
  	if err != nil {
  		server.Close()
  		return nil, nil, err
  	}
  	certpool := x509.NewCertPool()
  	certpool.AddCert(cert)

  	// server.URL has the form "https://host:port"; the dialer needs host:port.
  	serverAddr := strings.TrimPrefix(server.URL, "https://")
  	dialer := &net.Dialer{}

  	client := &http.Client{
  		Transport: &http.Transport{
  			// Dial through the request context so cancellation and
  			// deadlines are honoured.
  			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
  				return dialer.DialContext(ctx, "tcp", serverAddr)
  			},
  			TLSClientConfig: &tls.Config{
  				RootCAs: certpool,
  			},
  		},
  	}
  	return server, client, nil
  }
  167. func FuzzNewAccessValidator(f *testing.F) {
  168. f.Fuzz(func(t *testing.T, domain string, issuer string, applicationAUD string) {
  169. ctx := context.Background()
  170. _, _ = NewAccessValidator(ctx, domain, issuer, applicationAUD)
  171. })
  172. }