123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 |
- package validation
- import (
- "context"
- "fmt"
- "net"
- "net/url"
- "strings"
- "time"
- "net/http"
- "github.com/pkg/errors"
- "golang.org/x/net/idna"
- "gopkg.in/coreos/go-oidc.v2"
- )
// defaultScheme is prepended by ValidateUrl and validateIP when the input
// carries no scheme of its own.
const defaultScheme = "http"

// Cloudflare Access constants.
const (
	// accessDomain is checked against the issuer of tokens accepted by Access.Validate.
	accessDomain = "cloudflareaccess.com"
	// accessCertPath is appended to the Access domain URL to locate the remote key set.
	accessCertPath = "/cdn-cgi/access/certs"
	// accessJwtHeader is the request header from which ValidateRequest reads the JWT.
	accessJwtHeader = "Cf-access-jwt-assertion"
)
var (
	// supportedProtocols lists the URL schemes accepted by validateScheme.
	supportedProtocols = []string{"http", "https", "rdp", "ssh", "smb", "tcp"}
	// validationTimeout is the http.Client timeout used by ValidateHTTPService.
	// 30 * time.Second is already a time.Duration; no conversion is needed.
	validationTimeout = 30 * time.Second
)
- func ValidateHostname(hostname string) (string, error) {
- if hostname == "" {
- return "", nil
- }
- // users gives url(contains schema) not just hostname
- if strings.Contains(hostname, ":") || strings.Contains(hostname, "%3A") {
- unescapeHostname, err := url.PathUnescape(hostname)
- if err != nil {
- return "", fmt.Errorf("Hostname(actually a URL) %s has invalid escape characters %s", hostname, unescapeHostname)
- }
- hostnameToURL, err := url.Parse(unescapeHostname)
- if err != nil {
- return "", fmt.Errorf("Hostname(actually a URL) %s has invalid format %s", hostname, hostnameToURL)
- }
- asciiHostname, err := idna.ToASCII(hostnameToURL.Hostname())
- if err != nil {
- return "", fmt.Errorf("Hostname(actually a URL) %s has invalid ASCII encdoing %s", hostname, asciiHostname)
- }
- return asciiHostname, nil
- }
- asciiHostname, err := idna.ToASCII(hostname)
- if err != nil {
- return "", fmt.Errorf("Hostname %s has invalid ASCII encdoing %s", hostname, asciiHostname)
- }
- hostnameToURL, err := url.Parse(asciiHostname)
- if err != nil {
- return "", fmt.Errorf("Hostname %s is not valid", hostnameToURL)
- }
- return hostnameToURL.RequestURI(), nil
- }
- // ValidateUrl returns a validated version of `originUrl` with a scheme prepended (by default http://).
- // Note: when originUrl contains a scheme, the path is removed:
- // ValidateUrl("https://localhost:8080/api/") => "https://localhost:8080"
- // but when it does not, the path is preserved:
- // ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
- // This is arguably a bug, but changing it might break some cloudflared users.
- func ValidateUrl(originUrl string) (string, error) {
- if originUrl == "" {
- return "", fmt.Errorf("URL should not be empty")
- }
- if net.ParseIP(originUrl) != nil {
- return validateIP("", originUrl, "")
- } else if strings.HasPrefix(originUrl, "[") && strings.HasSuffix(originUrl, "]") {
- // ParseIP doesn't recoginze [::1]
- return validateIP("", originUrl[1:len(originUrl)-1], "")
- }
- host, port, err := net.SplitHostPort(originUrl)
- // user might pass in an ip address like 127.0.0.1
- if err == nil && net.ParseIP(host) != nil {
- return validateIP("", host, port)
- }
- unescapedUrl, err := url.PathUnescape(originUrl)
- if err != nil {
- return "", fmt.Errorf("URL %s has invalid escape characters %s", originUrl, unescapedUrl)
- }
- parsedUrl, err := url.Parse(unescapedUrl)
- if err != nil {
- return "", fmt.Errorf("URL %s has invalid format", originUrl)
- }
- // if the url is in the form of host:port, IsAbs() will think host is the schema
- var hostname string
- hasScheme := parsedUrl.IsAbs() && parsedUrl.Host != ""
- if hasScheme {
- err := validateScheme(parsedUrl.Scheme)
- if err != nil {
- return "", err
- }
- // The earlier check for ip address will miss the case http://[::1]
- // and http://[::1]:8080
- if net.ParseIP(parsedUrl.Hostname()) != nil {
- return validateIP(parsedUrl.Scheme, parsedUrl.Hostname(), parsedUrl.Port())
- }
- hostname, err = ValidateHostname(parsedUrl.Hostname())
- if err != nil {
- return "", fmt.Errorf("URL %s has invalid format", originUrl)
- }
- if parsedUrl.Port() != "" {
- return fmt.Sprintf("%s://%s", parsedUrl.Scheme, net.JoinHostPort(hostname, parsedUrl.Port())), nil
- }
- return fmt.Sprintf("%s://%s", parsedUrl.Scheme, hostname), nil
- } else {
- if host == "" {
- hostname, err = ValidateHostname(originUrl)
- if err != nil {
- return "", fmt.Errorf("URL no %s has invalid format", originUrl)
- }
- return fmt.Sprintf("%s://%s", defaultScheme, hostname), nil
- } else {
- hostname, err = ValidateHostname(host)
- if err != nil {
- return "", fmt.Errorf("URL %s has invalid format", originUrl)
- }
- // This is why the path is preserved when `originUrl` doesn't have a schema.
- // Using `parsedUrl.Port()` here, instead of `port`, would remove the path
- return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
- }
- }
- }
- func validateScheme(scheme string) error {
- for _, protocol := range supportedProtocols {
- if scheme == protocol {
- return nil
- }
- }
- return fmt.Errorf("Currently Argo Tunnel does not support %s protocol.", scheme)
- }
- func validateIP(scheme, host, port string) (string, error) {
- if scheme == "" {
- scheme = defaultScheme
- }
- if port != "" {
- return fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(host, port)), nil
- } else if strings.Contains(host, ":") {
- // IPv6
- return fmt.Sprintf("%s://[%s]", scheme, host), nil
- }
- return fmt.Sprintf("%s://%s", scheme, host), nil
- }
- func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error {
- parsedURL, err := url.Parse(originURL)
- if err != nil {
- return err
- }
- client := &http.Client{
- Transport: transport,
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return http.ErrUseLastResponse
- },
- Timeout: validationTimeout,
- }
- initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
- if err != nil {
- return err
- }
- initialRequest.Host = hostname
- _, initialErr := client.Do(initialRequest)
- if initialErr != nil {
- // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
- oldScheme := parsedURL.Scheme
- parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
- secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
- if err != nil {
- return err
- }
- secondRequest.Host = hostname
- _, secondErr := client.Do(secondRequest)
- if secondErr == nil { // Worked this time--advise the user to switch protocols
- return errors.Errorf(
- "%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s",
- parsedURL.Host,
- oldScheme,
- parsedURL.Scheme,
- initialErr,
- parsedURL,
- )
- }
- }
- return initialErr
- }
// toggleProtocol swaps "http" and "https". Any other value, including schemes
// that differ only in letter case, is returned unchanged.
func toggleProtocol(httpProtocol string) string {
	if httpProtocol == "http" {
		return "https"
	}
	if httpProtocol == "https" {
		return "http"
	}
	return httpProtocol
}
- // Access checks if a JWT from Cloudflare Access is valid.
- type Access struct {
- verifier *oidc.IDTokenVerifier
- }
- func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) {
- domainURL, err := ValidateUrl(domain)
- if err != nil {
- return nil, err
- }
- issuerURL, err := ValidateUrl(issuer)
- if err != nil {
- return nil, err
- }
- // An issuerURL from Cloudflare Access will always use HTTPS.
- issuerURL = strings.Replace(issuerURL, "http:", "https:", 1)
- keySet := oidc.NewRemoteKeySet(ctx, domainURL+accessCertPath)
- return &Access{oidc.NewVerifier(issuerURL, keySet, &oidc.Config{ClientID: applicationAUD})}, nil
- }
- func (a *Access) Validate(ctx context.Context, jwt string) error {
- token, err := a.verifier.Verify(ctx, jwt)
- if err != nil {
- return errors.Wrapf(err, "token is invalid: %s", jwt)
- }
- // Perform extra sanity checks, just to be safe.
- if token == nil {
- return fmt.Errorf("token is nil: %s", jwt)
- }
- if !strings.HasSuffix(token.Issuer, accessDomain) {
- return fmt.Errorf("token has non-cloudflare issuer of %s: %s", token.Issuer, jwt)
- }
- return nil
- }
- func (a *Access) ValidateRequest(ctx context.Context, r *http.Request) error {
- return a.Validate(ctx, r.Header.Get(accessJwtHeader))
- }
|