validation.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. package validation
  2. import (
  3. "context"
  4. "fmt"
  5. "net"
  6. "net/http"
  7. "net/url"
  8. "strings"
  9. "time"
  10. "github.com/coreos/go-oidc/v3/oidc"
  11. "github.com/pkg/errors"
  12. "golang.org/x/net/idna"
  13. )
  // defaultScheme is the scheme assumed when an origin URL or IP address is given without one.
const defaultScheme = "http"

// Cloudflare Access settings used by NewAccessValidator and the Access validator.
const (
	// accessDomain is the domain that a verified token's issuer is required to end with.
	accessDomain = "cloudflareaccess.com"
	// accessCertPath is appended to the Access domain URL to form the key-set endpoint.
	accessCertPath = "/cdn-cgi/access/certs"
	// accessJwtHeader names the request header that carries the Access JWT.
	accessJwtHeader = "Cf-access-jwt-assertion"
)
  20. var (
  21. supportedProtocols = []string{"http", "https", "rdp", "ssh", "smb", "tcp"}
  22. validationTimeout = time.Duration(30 * time.Second)
  23. )
  // ValidateHostname converts a user-supplied hostname to ASCII (punycode) form
// using idna.ToASCII.
//
// An empty input returns "" and no error. Users sometimes pass a URL instead of
// a bare hostname, so if the input contains ":" (or its escaped form "%3A") it
// is path-unescaped, parsed with url.Parse, and only the URL's host (without
// userinfo or port) is converted and returned. Note that url.Parse reads input
// such as "localhost:8080" as scheme "localhost" with no host, so in that case
// an empty host is what gets converted. Any other input is converted as a whole
// and returned as the request URI of the parsed result.
//
// Returned errors name the original input and wrap the underlying
// unescape/parse/IDNA error, so callers can see why validation failed.
func ValidateHostname(hostname string) (string, error) {
	if hostname == "" {
		return "", nil
	}
	// The user gave a URL (containing a scheme), not just a hostname.
	if strings.Contains(hostname, ":") || strings.Contains(hostname, "%3A") {
		unescapedHostname, err := url.PathUnescape(hostname)
		if err != nil {
			return "", fmt.Errorf("Hostname(actually a URL) %s has invalid escape characters: %w", hostname, err)
		}
		hostnameToURL, err := url.Parse(unescapedHostname)
		if err != nil {
			return "", fmt.Errorf("Hostname(actually a URL) %s has invalid format: %w", hostname, err)
		}
		asciiHostname, err := idna.ToASCII(hostnameToURL.Hostname())
		if err != nil {
			return "", fmt.Errorf("Hostname(actually a URL) %s has invalid ASCII encoding: %w", hostname, err)
		}
		return asciiHostname, nil
	}
	asciiHostname, err := idna.ToASCII(hostname)
	if err != nil {
		return "", fmt.Errorf("Hostname %s has invalid ASCII encoding: %w", hostname, err)
	}
	hostnameToURL, err := url.Parse(asciiHostname)
	if err != nil {
		// Report the caller's input; hostnameToURL is nil here and would print as <nil>.
		return "", fmt.Errorf("Hostname %s is not valid: %w", hostname, err)
	}
	return hostnameToURL.RequestURI(), nil
}
  // ValidateUrl normalizes originUrl and parses it into a *url.URL, prepending a
// scheme (http:// by default) when none is given.
//
// The path is treated differently depending on whether a scheme was supplied.
// With a scheme, the path is dropped:
//
//	ValidateUrl("https://localhost:8080/api/") => "https://localhost:8080"
//
// Without one, the path is kept:
//
//	ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
//
// This is arguably a bug, but changing it might break some cloudflared users.
func ValidateUrl(originUrl string) (*url.URL, error) {
	normalized, err := validateUrlString(originUrl)
	if err == nil {
		return url.Parse(normalized)
	}
	return nil, err
}
  71. func validateUrlString(originUrl string) (string, error) {
  72. if originUrl == "" {
  73. return "", fmt.Errorf("URL should not be empty")
  74. }
  75. if net.ParseIP(originUrl) != nil {
  76. return validateIP("", originUrl, "")
  77. } else if strings.HasPrefix(originUrl, "[") && strings.HasSuffix(originUrl, "]") {
  78. // ParseIP doesn't recoginze [::1]
  79. return validateIP("", originUrl[1:len(originUrl)-1], "")
  80. }
  81. host, port, err := net.SplitHostPort(originUrl)
  82. // user might pass in an ip address like 127.0.0.1
  83. if err == nil && net.ParseIP(host) != nil {
  84. return validateIP("", host, port)
  85. }
  86. unescapedUrl, err := url.PathUnescape(originUrl)
  87. if err != nil {
  88. return "", fmt.Errorf("URL %s has invalid escape characters %s", originUrl, unescapedUrl)
  89. }
  90. parsedUrl, err := url.Parse(unescapedUrl)
  91. if err != nil {
  92. return "", fmt.Errorf("URL %s has invalid format", originUrl)
  93. }
  94. // if the url is in the form of host:port, IsAbs() will think host is the schema
  95. var hostname string
  96. hasScheme := parsedUrl.IsAbs() && parsedUrl.Host != ""
  97. if hasScheme {
  98. err := validateScheme(parsedUrl.Scheme)
  99. if err != nil {
  100. return "", err
  101. }
  102. // The earlier check for ip address will miss the case http://[::1]
  103. // and http://[::1]:8080
  104. if net.ParseIP(parsedUrl.Hostname()) != nil {
  105. return validateIP(parsedUrl.Scheme, parsedUrl.Hostname(), parsedUrl.Port())
  106. }
  107. hostname, err = ValidateHostname(parsedUrl.Hostname())
  108. if err != nil {
  109. return "", fmt.Errorf("URL %s has invalid format", originUrl)
  110. }
  111. if parsedUrl.Port() != "" {
  112. return fmt.Sprintf("%s://%s", parsedUrl.Scheme, net.JoinHostPort(hostname, parsedUrl.Port())), nil
  113. }
  114. return fmt.Sprintf("%s://%s", parsedUrl.Scheme, hostname), nil
  115. } else {
  116. if host == "" {
  117. hostname, err = ValidateHostname(originUrl)
  118. if err != nil {
  119. return "", fmt.Errorf("URL no %s has invalid format", originUrl)
  120. }
  121. return fmt.Sprintf("%s://%s", defaultScheme, hostname), nil
  122. } else {
  123. hostname, err = ValidateHostname(host)
  124. if err != nil {
  125. return "", fmt.Errorf("URL %s has invalid format", originUrl)
  126. }
  127. // This is why the path is preserved when `originUrl` doesn't have a schema.
  128. // Using `parsedUrl.Port()` here, instead of `port`, would remove the path
  129. return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
  130. }
  131. }
  132. }
  // validateScheme returns an error unless scheme exactly (case-sensitively)
// matches one of the entries in supportedProtocols.
func validateScheme(scheme string) error {
	supported := false
	for i := range supportedProtocols {
		if supportedProtocols[i] == scheme {
			supported = true
			break
		}
	}
	if !supported {
		return fmt.Errorf("Currently Cloudflare Tunnel does not support %s protocol.", scheme)
	}
	return nil
}
  // validateIP formats an IP address, plus an optional port, as a URL string.
// An empty scheme falls back to defaultScheme, and a host containing ':'
// (an IPv6 literal) is wrapped in brackets. It does not itself verify that
// host is an IP address, and the returned error is always nil.
func validateIP(scheme, host, port string) (string, error) {
	if scheme == "" {
		scheme = defaultScheme
	}
	authority := host
	switch {
	case port != "":
		// JoinHostPort brackets IPv6 hosts on its own.
		authority = net.JoinHostPort(host, port)
	case strings.Contains(host, ":"):
		// A bare IPv6 literal must be bracketed inside a URL.
		authority = "[" + host + "]"
	}
	return fmt.Sprintf("%s://%s", scheme, authority), nil
}
  153. // Access checks if a JWT from Cloudflare Access is valid.
  154. type Access struct {
  155. verifier *oidc.IDTokenVerifier
  156. }
  157. func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) {
  158. domainURL, err := validateUrlString(domain)
  159. if err != nil {
  160. return nil, err
  161. }
  162. issuerURL, err := validateUrlString(issuer)
  163. if err != nil {
  164. return nil, err
  165. }
  166. // An issuerURL from Cloudflare Access will always use HTTPS.
  167. issuerURL = strings.Replace(issuerURL, "http:", "https:", 1)
  168. keySet := oidc.NewRemoteKeySet(ctx, domainURL+accessCertPath)
  169. return &Access{oidc.NewVerifier(issuerURL, keySet, &oidc.Config{ClientID: applicationAUD})}, nil
  170. }
  171. func (a *Access) Validate(ctx context.Context, jwt string) error {
  172. token, err := a.verifier.Verify(ctx, jwt)
  173. if err != nil {
  174. return errors.Wrapf(err, "token is invalid: %s", jwt)
  175. }
  176. // Perform extra sanity checks, just to be safe.
  177. if token == nil {
  178. return fmt.Errorf("token is nil: %s", jwt)
  179. }
  180. if !strings.HasSuffix(token.Issuer, accessDomain) {
  181. return fmt.Errorf("token has non-cloudflare issuer of %s: %s", token.Issuer, jwt)
  182. }
  183. return nil
  184. }
  185. func (a *Access) ValidateRequest(ctx context.Context, r *http.Request) error {
  186. return a.Validate(ctx, r.Header.Get(accessJwtHeader))
  187. }