jwtvalidator_test.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. package middleware
  2. import (
  3. "context"
  4. "crypto"
  5. "crypto/ecdsa"
  6. "crypto/elliptic"
  7. "crypto/rand"
  8. "encoding/json"
  9. "fmt"
  10. "net/http/httptest"
  11. "testing"
  12. "time"
  13. "github.com/coreos/go-oidc/v3/oidc"
  14. "github.com/go-jose/go-jose/v4"
  15. "github.com/go-jose/go-jose/v4/jwt"
  16. "github.com/stretchr/testify/assert"
  17. "github.com/stretchr/testify/require"
  18. )
  19. var (
  20. issuer = fmt.Sprintf(cloudflareAccessCertsURL, "testteam")
  21. )
  22. type accessTokenClaims struct {
  23. Email string `json:"email"`
  24. Type string `json:"type"`
  25. jwt.Claims
  26. }
  27. func TestJWTValidator(t *testing.T) {
  28. req := httptest.NewRequest("GET", "http://example.com", nil)
  29. key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  30. require.NoError(t, err)
  31. issued := time.Now()
  32. claims := accessTokenClaims{
  33. Email: "test@example.com",
  34. Type: "app",
  35. Claims: jwt.Claims{
  36. Issuer: issuer,
  37. Subject: "ee239b7a-e3e6-4173-972a-8fbe9d99c04f",
  38. Audience: []string{""},
  39. Expiry: jwt.NewNumericDate(issued.Add(time.Hour)),
  40. IssuedAt: jwt.NewNumericDate(issued),
  41. },
  42. }
  43. token := signToken(t, claims, key)
  44. req.Header.Add(headerKeyAccessJWTAssertion, token)
  45. keySet := oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.Public()}}
  46. config := &oidc.Config{
  47. SkipClientIDCheck: true,
  48. SupportedSigningAlgs: []string{string(jose.ES256)},
  49. }
  50. verifier := oidc.NewVerifier(issuer, &keySet, config)
  51. tests := []struct {
  52. name string
  53. audTags []string
  54. aud jwt.Audience
  55. error bool
  56. }{
  57. {
  58. name: "valid",
  59. audTags: []string{
  60. "0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
  61. "d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
  62. },
  63. aud: jwt.Audience{"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba"},
  64. error: false,
  65. },
  66. {
  67. name: "invalid no match",
  68. audTags: []string{
  69. "0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
  70. "d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
  71. },
  72. aud: jwt.Audience{"09dc377143841843ecca28b196bdb1ec1675af38c8b7b60c7def5876c8877157"},
  73. error: true,
  74. },
  75. {
  76. name: "invalid empty check",
  77. audTags: []string{},
  78. aud: jwt.Audience{"09dc377143841843ecca28b196bdb1ec1675af38c8b7b60c7def5876c8877157"},
  79. error: true,
  80. },
  81. {
  82. name: "invalid absent aud",
  83. audTags: []string{
  84. "0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
  85. "d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
  86. },
  87. aud: jwt.Audience{""},
  88. error: true,
  89. },
  90. }
  91. for _, test := range tests {
  92. t.Run(test.name, func(t *testing.T) {
  93. validator := JWTValidator{
  94. IDTokenVerifier: verifier,
  95. audTags: test.audTags,
  96. }
  97. claims.Audience = test.aud
  98. token := signToken(t, claims, key)
  99. req.Header.Set(headerKeyAccessJWTAssertion, token)
  100. result, err := validator.Handle(context.Background(), req)
  101. assert.NoError(t, err)
  102. assert.Equal(t, test.error, result.ShouldFilterRequest)
  103. })
  104. }
  105. }
  106. func signToken(t *testing.T, token accessTokenClaims, key *ecdsa.PrivateKey) string {
  107. signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, &jose.SignerOptions{})
  108. require.NoError(t, err)
  109. payload, err := json.Marshal(token)
  110. require.NoError(t, err)
  111. jws, err := signer.Sign(payload)
  112. require.NoError(t, err)
  113. jwt, err := jws.CompactSerialize()
  114. require.NoError(t, err)
  115. return jwt
  116. }