reconnect_test.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. package origin
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  10. )
  11. func TestRefreshAuthBackoff(t *testing.T) {
  12. rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
  13. var wait time.Duration
  14. timeAfter = func(d time.Duration) <-chan time.Time {
  15. wait = d
  16. return time.After(d)
  17. }
  18. backoff := &BackoffHandler{MaxRetries: 3}
  19. auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
  20. return nil, fmt.Errorf("authentication failure")
  21. }
  22. // authentication failures should consume the backoff
  23. for i := uint(0); i < backoff.MaxRetries; i++ {
  24. retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
  25. assert.NoError(t, err)
  26. assert.NotNil(t, retryChan)
  27. assert.Equal(t, (1<<i)*time.Second, wait)
  28. }
  29. retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
  30. assert.Error(t, err)
  31. assert.Nil(t, retryChan)
  32. // now we actually make contact with the remote server
  33. _, _ = rcm.RefreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
  34. return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
  35. })
  36. // The backoff timer should have been reset. To confirm this, make timeNow
  37. // return a value after the backoff timer's grace period
  38. timeNow = func() time.Time {
  39. expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
  40. return time.Now().Add(expectedGracePeriod * 2)
  41. }
  42. _, ok := backoff.GetBackoffDuration(context.Background())
  43. assert.True(t, ok)
  44. }
  45. func TestRefreshAuthSuccess(t *testing.T) {
  46. rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
  47. var wait time.Duration
  48. timeAfter = func(d time.Duration) <-chan time.Time {
  49. wait = d
  50. return time.After(d)
  51. }
  52. backoff := &BackoffHandler{MaxRetries: 3}
  53. auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
  54. return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
  55. }
  56. retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
  57. assert.NoError(t, err)
  58. assert.NotNil(t, retryChan)
  59. assert.Equal(t, 19*time.Hour, wait)
  60. token, err := rcm.ReconnectToken()
  61. assert.NoError(t, err)
  62. assert.Equal(t, []byte("jwt"), token)
  63. }
  64. func TestRefreshAuthUnknown(t *testing.T) {
  65. rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
  66. var wait time.Duration
  67. timeAfter = func(d time.Duration) <-chan time.Time {
  68. wait = d
  69. return time.After(d)
  70. }
  71. backoff := &BackoffHandler{MaxRetries: 3}
  72. auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
  73. return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
  74. }
  75. retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
  76. assert.NoError(t, err)
  77. assert.NotNil(t, retryChan)
  78. assert.Equal(t, 19*time.Hour, wait)
  79. token, err := rcm.ReconnectToken()
  80. assert.Equal(t, errJWTUnset, err)
  81. assert.Nil(t, token)
  82. }
  83. func TestRefreshAuthFail(t *testing.T) {
  84. rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
  85. backoff := &BackoffHandler{MaxRetries: 3}
  86. auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
  87. return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
  88. }
  89. retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
  90. assert.Error(t, err)
  91. assert.Nil(t, retryChan)
  92. token, err := rcm.ReconnectToken()
  93. assert.Equal(t, errJWTUnset, err)
  94. assert.Nil(t, token)
  95. }