validation_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. package validation
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io/ioutil"
  6. "testing"
  7. "time"
  8. "context"
  9. "crypto/tls"
  10. "crypto/x509"
  11. "net"
  12. "net/http"
  13. "net/http/httptest"
  14. "net/url"
  15. "strings"
  16. "github.com/stretchr/testify/assert"
  17. )
  18. func TestValidateHostname(t *testing.T) {
  19. var inputHostname string
  20. hostname, err := ValidateHostname(inputHostname)
  21. assert.Equal(t, err, nil)
  22. assert.Empty(t, hostname)
  23. inputHostname = "hello.example.com"
  24. hostname, err = ValidateHostname(inputHostname)
  25. assert.Nil(t, err)
  26. assert.Equal(t, "hello.example.com", hostname)
  27. inputHostname = "http://hello.example.com"
  28. hostname, err = ValidateHostname(inputHostname)
  29. assert.Nil(t, err)
  30. assert.Equal(t, "hello.example.com", hostname)
  31. inputHostname = "bücher.example.com"
  32. hostname, err = ValidateHostname(inputHostname)
  33. assert.Nil(t, err)
  34. assert.Equal(t, "xn--bcher-kva.example.com", hostname)
  35. inputHostname = "http://bücher.example.com"
  36. hostname, err = ValidateHostname(inputHostname)
  37. assert.Nil(t, err)
  38. assert.Equal(t, "xn--bcher-kva.example.com", hostname)
  39. inputHostname = "http%3A%2F%2Fhello.example.com"
  40. hostname, err = ValidateHostname(inputHostname)
  41. assert.Nil(t, err)
  42. assert.Equal(t, "hello.example.com", hostname)
  43. }
  // TestValidateUrl pins the URL normalisation ValidateUrl is expected to do.
// Per the table below:
//   - when a scheme is given, the path is dropped and any userinfo is removed;
//   - when no scheme is given, "http://" is prepended and the path is kept;
//   - IDN hosts are punycoded;
//   - percent-encoded input is decoded.
// Empty input and the ftp scheme must be rejected with specific messages.
func TestValidateUrl(t *testing.T) {
	type testCase struct {
		input          string
		expectedOutput string
	}
	testCases := []testCase{
		{"http://localhost", "http://localhost"},
		{"http://localhost/", "http://localhost"},
		{"http://localhost/api", "http://localhost"},
		{"http://localhost/api/", "http://localhost"},
		{"https://localhost", "https://localhost"},
		{"https://localhost/", "https://localhost"},
		{"https://localhost/api", "https://localhost"},
		{"https://localhost/api/", "https://localhost"},
		{"https://localhost:8080", "https://localhost:8080"},
		{"https://localhost:8080/", "https://localhost:8080"},
		{"https://localhost:8080/api", "https://localhost:8080"},
		{"https://localhost:8080/api/", "https://localhost:8080"},
		{"localhost", "http://localhost"},
		{"localhost/", "http://localhost/"},
		{"localhost/api", "http://localhost/api"},
		{"localhost/api/", "http://localhost/api/"},
		{"localhost:8080", "http://localhost:8080"},
		{"localhost:8080/", "http://localhost:8080/"},
		{"localhost:8080/api", "http://localhost:8080/api"},
		{"localhost:8080/api/", "http://localhost:8080/api/"},
		{"localhost:8080/api/?asdf", "http://localhost:8080/api/?asdf"},
		{"http://127.0.0.1:8080", "http://127.0.0.1:8080"},
		{"127.0.0.1:8080", "http://127.0.0.1:8080"},
		{"127.0.0.1", "http://127.0.0.1"},
		{"https://127.0.0.1:8080", "https://127.0.0.1:8080"},
		{"[::1]:8080", "http://[::1]:8080"},
		{"http://[::1]", "http://[::1]"},
		{"http://[::1]:8080", "http://[::1]:8080"},
		{"[::1]", "http://[::1]"},
		{"https://example.com", "https://example.com"},
		{"example.com", "http://example.com"},
		{"http://hello.example.com", "http://hello.example.com"},
		{"hello.example.com", "http://hello.example.com"},
		{"hello.example.com:8080", "http://hello.example.com:8080"},
		{"https://hello.example.com:8080", "https://hello.example.com:8080"},
		{"https://bücher.example.com", "https://xn--bcher-kva.example.com"},
		{"bücher.example.com", "http://xn--bcher-kva.example.com"},
		{"https%3A%2F%2Fhello.example.com", "https://hello.example.com"},
		{"https://alex:12345@hello.example.com:8080", "https://hello.example.com:8080"},
	}
	for i, tc := range testCases {
		msg := fmt.Sprintf("test case %d (%q)", i, tc.input)
		validURL, err := ValidateUrl(tc.input)
		// Only dereference the result when validation succeeded. Calling
		// String() on a nil result would panic and abort the whole table
		// instead of reporting this one case.
		if assert.NoError(t, err, msg) && assert.NotNil(t, validURL, msg) {
			assert.Equal(t, tc.expectedOutput, validURL.String(), msg)
		}
	}

	// EqualError fails cleanly (no nil-pointer panic) if err is unexpectedly nil.
	validURL, err := ValidateUrl("")
	assert.EqualError(t, err, "URL should not be empty")
	assert.Empty(t, validURL)

	validURL, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
	assert.EqualError(t, err, "Currently Argo Tunnel does not support ftp protocol.")
	assert.Empty(t, validURL)
}
  // TestToggleProtocol checks that toggleProtocol swaps "http" and "https",
// and leaves other inputs such as "random" and "" unchanged.
func TestToggleProtocol(t *testing.T) {
	for _, c := range []struct{ in, want string }{
		{in: "http", want: "https"},
		{in: "https", want: "http"},
		{in: "random", want: "random"},
		{in: "", want: ""},
	} {
		assert.Equal(t, c.want, toggleProtocol(c.in))
	}
}
  // Happy path 1: originURL is HTTP, and HTTP connections work.
// The test expects validation to succeed for both a 200 and a 503 response,
// without ValidateHTTPService ever trying HTTPS.
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
	originURL := "http://127.0.0.1/"
	hostname := "example.com"
	for _, statusCode := range []int{200, 503} {
		statusCode := statusCode // captured by the closure below (pre-Go 1.22 loop semantics)
		rt := testRoundTripper(func(req *http.Request) (*http.Response, error) {
			assert.Equal(t, hostname, req.Host)
			// Report unexpected requests with t.Error plus an error return rather
			// than t.Fatal/panic: the transport is not guaranteed to run on the
			// test goroutine, where FailNow is required to be called.
			switch req.URL.Scheme {
			case "http":
				return emptyResponse(statusCode), nil
			case "https":
				t.Error("http works, shouldn't have tried with https")
				return nil, assert.AnError
			default:
				t.Errorf("unexpected scheme %q", req.URL.Scheme)
				return nil, assert.AnError
			}
		})
		assert.NoError(t, ValidateHTTPService(originURL, hostname, rt), "origin status %d", statusCode)
	}
}
  // Happy path 2: originURL is HTTPS, and HTTPS connections work.
// The test expects validation to succeed for both a 200 and a 503 response,
// without ValidateHTTPService ever trying plain HTTP.
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
	originURL := "https://127.0.0.1:1234/"
	hostname := "example.com"
	for _, statusCode := range []int{200, 503} {
		statusCode := statusCode // captured by the closure below (pre-Go 1.22 loop semantics)
		rt := testRoundTripper(func(req *http.Request) (*http.Response, error) {
			assert.Equal(t, hostname, req.Host)
			// t.Error + error return instead of t.Fatal/panic: the transport may
			// run off the test goroutine, where FailNow must not be called.
			switch req.URL.Scheme {
			case "http":
				t.Error("https works, shouldn't have tried with http")
				return nil, assert.AnError
			case "https":
				return emptyResponse(statusCode), nil
			default:
				t.Errorf("unexpected scheme %q", req.URL.Scheme)
				return nil, assert.AnError
			}
		})
		assert.NoError(t, ValidateHTTPService(originURL, hostname, rt), "origin status %d", statusCode)
	}
}
  // Error path 1: originURL is HTTPS, but HTTP connections work.
// HTTPS requests fail and HTTP requests answer with 200 or 503; the test
// expects ValidateHTTPService to report an error in both cases.
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
	originURL := "https://127.0.0.1:1234/"
	hostname := "example.com"
	for _, statusCode := range []int{200, 503} {
		statusCode := statusCode // captured by the closure below (pre-Go 1.22 loop semantics)
		rt := testRoundTripper(func(req *http.Request) (*http.Response, error) {
			assert.Equal(t, hostname, req.Host)
			switch req.URL.Scheme {
			case "http":
				return emptyResponse(statusCode), nil
			case "https":
				return nil, assert.AnError
			default:
				// Report instead of panicking so the failure is attributed to this test.
				t.Errorf("unexpected scheme %q", req.URL.Scheme)
				return nil, assert.AnError
			}
		})
		assert.Error(t, ValidateHTTPService(originURL, hostname, rt), "http status %d", statusCode)
	}
}
  // Error path 2: originURL is HTTP, but HTTPS connections work.
// HTTP requests fail and HTTPS requests answer with 200 or 503; the test
// expects ValidateHTTPService to report an error in both cases.
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
	originURL := "http://127.0.0.1:1234/"
	hostname := "example.com"
	for _, statusCode := range []int{200, 503} {
		statusCode := statusCode // captured by the closure below (pre-Go 1.22 loop semantics)
		rt := testRoundTripper(func(req *http.Request) (*http.Response, error) {
			assert.Equal(t, hostname, req.Host)
			switch req.URL.Scheme {
			case "http":
				return nil, assert.AnError
			case "https":
				return emptyResponse(statusCode), nil
			default:
				// Report instead of panicking so the failure is attributed to this test.
				t.Errorf("unexpected scheme %q", req.URL.Scheme)
				return nil, assert.AnError
			}
		})
		assert.Error(t, ValidateHTTPService(originURL, hostname, rt), "https status %d", statusCode)
	}
}
  // Ensure the client does not follow 302 responses.
// The mock TLS origin always redirects to /followedRedirect. Any request for
// that path means ValidateHTTPService followed the redirect.
func TestValidateHTTPService_NoFollowRedirects(t *testing.T) {
	hostname := "example.com"
	redirectServer, redirectClient, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/followedRedirect" {
			// This handler runs on an httptest server goroutine, where
			// t.Fatal/FailNow must not be called. Record the failure and answer
			// normally so the client isn't left waiting for a response.
			t.Error("shouldn't have followed the 302")
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		if r.Method == http.MethodConnect {
			assert.Equal(t, "127.0.0.1:443", r.Host)
		} else {
			assert.Equal(t, hostname, r.Host)
		}
		w.Header().Set("Location", "/followedRedirect")
		w.WriteHeader(http.StatusFound)
	}))
	// On setup failure redirectServer is nil; bail out before deferring Close on it.
	if !assert.NoError(t, err) {
		return
	}
	defer redirectServer.Close()
	assert.NoError(t, ValidateHTTPService(redirectServer.URL, hostname, redirectClient.Transport))
}
  // Ensure validation times out when origin URL is nonresponsive.
// validationTimeout is temporarily lowered to 500ms. The mock origin stalls
// for up to 1s, and the returned error must be a net.Error reporting Timeout().
func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) {
	originURL := "http://127.0.0.1/"
	hostname := "example.com"
	oldValidationTimeout := validationTimeout
	defer func() {
		validationTimeout = oldValidationTimeout
	}()
	validationTimeout = 500 * time.Millisecond
	// Use createMockServerAndClient, not createSecureMockServerAndClient.
	// The latter will bail with HTTP 400 immediately on an http:// request,
	// which defeats the purpose of a 'nonresponsive origin' test.
	server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodConnect {
			assert.Equal(t, "127.0.0.1:443", r.Host)
		} else {
			assert.Equal(t, hostname, r.Host)
		}
		// Stall for longer than validationTimeout, but stop as soon as the
		// client gives up (the request context is canceled when its connection
		// closes). server.Close blocks until outstanding requests complete, so
		// a plain sleep would add the full delay to every run.
		select {
		case <-r.Context().Done():
			return
		case <-time.After(1 * time.Second):
		}
		w.WriteHeader(http.StatusOK)
	}))
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	defer server.Close()
	err = ValidateHTTPService(originURL, hostname, client.Transport)
	t.Logf("ValidateHTTPService returned: %v", err)
	netErr, ok := err.(net.Error)
	if assert.True(t, ok, "expected a net.Error, got %T: %v", err, err) {
		assert.True(t, netErr.Timeout())
	}
}
  // TestNewAccessValidatorOk checks that a validator can be built for an Access
// domain. It also checks that the validator rejects an empty token, a malformed
// token, and a request carrying a structurally valid JWT it did not issue.
func TestNewAccessValidatorOk(t *testing.T) {
	ctx := context.Background()
	// Named domain rather than url so the imported net/url package isn't shadowed.
	domain := "test.cloudflareaccess.com"
	access, err := NewAccessValidator(ctx, domain, domain, "")
	// Don't exercise a validator that failed to construct.
	if !assert.NoError(t, err) || !assert.NotNil(t, access) {
		return
	}
	assert.Error(t, access.Validate(ctx, ""))
	assert.Error(t, access.Validate(ctx, "invalid"))
	req := httptest.NewRequest("GET", "https://test.cloudflareaccess.com", nil)
	// An HS256-signed sample JWT (header {"alg":"HS256","typ":"JWT"}); the
	// validator is expected to reject it.
	req.Header.Set(accessJwtHeader, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c")
	assert.Error(t, access.ValidateRequest(ctx, req))
}
  // TestNewAccessValidatorErr checks that NewAccessValidator rejects each of the
// listed inputs (an empty domain, an ftp:// URL and a wss:// URL). For each,
// it expects an error and a nil validator.
func TestNewAccessValidatorErr(t *testing.T) {
	ctx := context.Background()
	invalidDomains := []string{
		"",
		"ftp://test.cloudflareaccess.com",
		"wss://cloudflarenone.com",
	}
	for i := range invalidDomains {
		// Not called url, to avoid shadowing the imported net/url package.
		domain := invalidDomains[i]
		validator, err := NewAccessValidator(ctx, domain, domain, "")
		assert.Error(t, err, domain)
		assert.Nil(t, validator)
	}
}
  283. type testRoundTripper func(req *http.Request) (*http.Response, error)
  284. func (f testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  285. return f(req)
  286. }
  // emptyResponse builds a minimal *http.Response with the given status code,
// an empty but non-nil header map, and an empty, closable body.
func emptyResponse(statusCode int) *http.Response {
	resp := &http.Response{Header: http.Header{}}
	resp.StatusCode = statusCode
	resp.Body = ioutil.NopCloser(bytes.NewReader(nil))
	return resp
}
  // createMockServerAndClient starts a plain-HTTP httptest server running
// handler. It returns the server and a dedicated client whose transport uses
// that server as its HTTP proxy, so the handler sees every request the client
// makes, whatever the request URL's host.
//
// A fresh *http.Client is returned rather than mutating http.DefaultClient.
// Reconfiguring the shared default client would leak this proxy, which points
// at a server the caller soon closes, into every other test in the package.
// The caller must Close the returned server.
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
	server := httptest.NewServer(handler)
	// The proxy address never changes, so parse it once instead of per request.
	proxyURL, err := url.Parse(server.URL)
	if err != nil {
		server.Close()
		return nil, nil, err
	}
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
		},
	}
	return server, client, nil
}
  // createSecureMockServerAndClient starts an httptest TLS server running handler.
// It returns the server and a dedicated client whose transport:
//   - trusts the server's self-signed certificate, and
//   - dials the server's listener regardless of the address requested,
// so https:// requests for other hosts are still served by handler.
//
// A fresh *http.Client is returned rather than mutating http.DefaultClient,
// which would leak this transport into every other test in the package.
// The caller must Close the returned server.
func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
	server := httptest.NewTLSServer(handler)
	cert, err := x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0])
	if err != nil {
		server.Close()
		return nil, nil, err
	}
	certpool := x509.NewCertPool()
	certpool.AddCert(cert)

	// server.URL has the form "https://host:port"; strip the scheme to get
	// the address to dial.
	serverAddr := strings.TrimPrefix(server.URL, "https://")
	dialer := &net.Dialer{}
	client := &http.Client{
		Transport: &http.Transport{
			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
				// Always connect to the mock server. Passing ctx through lets
				// request cancellation and timeouts abort the dial.
				return dialer.DialContext(ctx, "tcp", serverAddr)
			},
			TLSClientConfig: &tls.Config{
				RootCAs: certpool,
			},
		},
	}
	return server, client, nil
}