jwtvalidator.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. package middleware
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "github.com/coreos/go-oidc/v3/oidc"
  7. )
  // headerKeyAccessJWTAssertion is the name of the request header from which
  // JWTValidator.Handle reads the Access JWT.
  const headerKeyAccessJWTAssertion = "Cf-Access-Jwt-Assertion"
  // cloudflareAccessCertsURL is the format template used to build the default
  // base URL when no explicit certs URL is supplied; its single %s verb is
  // filled with the team name.
  var cloudflareAccessCertsURL = "https://%s.cloudflareaccess.com"
  14. // JWTValidator is an implementation of Verifier that validates access based JWT tokens.
  15. type JWTValidator struct {
  16. *oidc.IDTokenVerifier
  17. audTags []string
  18. }
  19. func NewJWTValidator(teamName string, certsURL string, audTags []string) *JWTValidator {
  20. if certsURL == "" {
  21. certsURL = fmt.Sprintf(cloudflareAccessCertsURL, teamName)
  22. }
  23. certsEndpoint := fmt.Sprintf("%s/cdn-cgi/access/certs", certsURL)
  24. config := &oidc.Config{
  25. SkipClientIDCheck: true,
  26. }
  27. ctx := context.Background()
  28. keySet := oidc.NewRemoteKeySet(ctx, certsEndpoint)
  29. verifier := oidc.NewVerifier(certsURL, keySet, config)
  30. return &JWTValidator{
  31. IDTokenVerifier: verifier,
  32. audTags: audTags,
  33. }
  34. }
  35. func (v *JWTValidator) Name() string {
  36. return "AccessJWTValidator"
  37. }
  38. func (v *JWTValidator) Handle(ctx context.Context, r *http.Request) (*HandleResult, error) {
  39. accessJWT := r.Header.Get(headerKeyAccessJWTAssertion)
  40. if accessJWT == "" {
  41. // log the exact error message here. the message is specific to the handler implementation logic, we don't gain anything
  42. // in passing it upstream. and each handler impl know what logging level to use for each.
  43. return &HandleResult{
  44. ShouldFilterRequest: true,
  45. StatusCode: http.StatusForbidden,
  46. Reason: "no access token in request",
  47. }, nil
  48. }
  49. token, err := v.IDTokenVerifier.Verify(ctx, accessJWT)
  50. if err != nil {
  51. return nil, err
  52. }
  53. // We want at least one audTag to match
  54. for _, jwtAudTag := range token.Audience {
  55. for _, acceptedAudTag := range v.audTags {
  56. if acceptedAudTag == jwtAudTag {
  57. return &HandleResult{ShouldFilterRequest: false}, nil
  58. }
  59. }
  60. }
  61. return &HandleResult{
  62. ShouldFilterRequest: true,
  63. StatusCode: http.StatusForbidden,
  64. Reason: fmt.Sprintf("Invalid token in jwt: %v", token.Audience),
  65. }, nil
  66. }