middleware_test.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. package management
  import (
  	"encoding/json"
  	"io"
  	"net/http"
  	"net/http/httptest"
  	"testing"

  	"github.com/go-chi/chi/v5"
  	"github.com/stretchr/testify/assert"
  )
  10. func TestValidateAccessTokenQueryMiddleware(t *testing.T) {
  11. r := chi.NewRouter()
  12. r.Use(ValidateAccessTokenQueryMiddleware)
  13. r.Get("/valid", func(w http.ResponseWriter, r *http.Request) {
  14. claims, ok := r.Context().Value(accessClaimsCtxKey).(*managementTokenClaims)
  15. assert.True(t, ok)
  16. assert.True(t, claims.verify())
  17. w.WriteHeader(http.StatusOK)
  18. })
  19. r.Get("/invalid", func(w http.ResponseWriter, r *http.Request) {
  20. _, ok := r.Context().Value(accessClaimsCtxKey).(*managementTokenClaims)
  21. assert.False(t, ok)
  22. w.WriteHeader(http.StatusOK)
  23. })
  24. ts := httptest.NewServer(r)
  25. defer ts.Close()
  26. // valid: with access_token query param
  27. path := "/valid?access_token=" + validToken
  28. resp, _ := testRequest(t, ts, "GET", path, nil)
  29. assert.Equal(t, http.StatusOK, resp.StatusCode)
  30. // invalid: unset token
  31. path = "/invalid"
  32. resp, err := testRequest(t, ts, "GET", path, nil)
  33. assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
  34. assert.NotNil(t, err)
  35. assert.Equal(t, errMissingAccessToken, err.Errors[0])
  36. // invalid: invalid token
  37. path = "/invalid?access_token=eyJ"
  38. resp, err = testRequest(t, ts, "GET", path, nil)
  39. assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
  40. assert.NotNil(t, err)
  41. assert.Equal(t, errMissingAccessToken, err.Errors[0])
  42. }
  43. func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, *managementErrorResponse) {
  44. req, err := http.NewRequest(method, ts.URL+path, body)
  45. if err != nil {
  46. t.Fatal(err)
  47. }
  48. resp, err := ts.Client().Do(req)
  49. if err != nil {
  50. t.Fatal(err)
  51. }
  52. var claims managementErrorResponse
  53. err = json.NewDecoder(resp.Body).Decode(&claims)
  54. if err != nil {
  55. return resp, nil
  56. }
  57. defer resp.Body.Close()
  58. return resp, &claims
  59. }