protocol_test.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package connection
  2. import (
  3. "fmt"
  4. "testing"
  5. "github.com/stretchr/testify/assert"
  6. "github.com/cloudflare/cloudflared/edgediscovery"
  7. )
  // Fixtures shared by the protocol selector tests in this file.
  const (
  	// testNoTTL is the zero TTL that the refresh tests pass to
  	// NewProtocolSelector.
  	testNoTTL = 0
  	// testAccountTag is the account tag argument used for every
  	// NewProtocolSelector call in this file.
  	testAccountTag = "testAccountTag"
  )
  12. func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) edgediscovery.PercentageFetcher {
  13. return func() (edgediscovery.ProtocolPercents, error) {
  14. if getError {
  15. return nil, fmt.Errorf("failed to fetch percentage")
  16. }
  17. return protocolPercent, nil
  18. }
  19. }
  20. type dynamicMockFetcher struct {
  21. protocolPercents edgediscovery.ProtocolPercents
  22. err error
  23. }
  24. func (dmf *dynamicMockFetcher) fetch() edgediscovery.PercentageFetcher {
  25. return func() (edgediscovery.ProtocolPercents, error) {
  26. return dmf.protocolPercents, dmf.err
  27. }
  28. }
  29. func TestNewProtocolSelector(t *testing.T) {
  30. tests := []struct {
  31. name string
  32. protocol string
  33. tunnelTokenProvided bool
  34. needPQ bool
  35. expectedProtocol Protocol
  36. hasFallback bool
  37. expectedFallback Protocol
  38. wantErr bool
  39. }{
  40. {
  41. name: "named tunnel with unknown protocol",
  42. protocol: "unknown",
  43. wantErr: true,
  44. },
  45. {
  46. name: "named tunnel with h2mux: force to http2",
  47. protocol: "h2mux",
  48. expectedProtocol: HTTP2,
  49. },
  50. {
  51. name: "named tunnel with http2: no fallback",
  52. protocol: "http2",
  53. expectedProtocol: HTTP2,
  54. },
  55. {
  56. name: "named tunnel with auto: quic",
  57. protocol: AutoSelectFlag,
  58. expectedProtocol: QUIC,
  59. hasFallback: true,
  60. expectedFallback: HTTP2,
  61. },
  62. {
  63. name: "named tunnel (post quantum)",
  64. protocol: AutoSelectFlag,
  65. needPQ: true,
  66. expectedProtocol: QUIC,
  67. },
  68. {
  69. name: "named tunnel (post quantum) w/http2",
  70. protocol: "http2",
  71. needPQ: true,
  72. expectedProtocol: QUIC,
  73. },
  74. }
  75. fetcher := dynamicMockFetcher{
  76. protocolPercents: edgediscovery.ProtocolPercents{},
  77. }
  78. for _, test := range tests {
  79. t.Run(test.name, func(t *testing.T) {
  80. selector, err := NewProtocolSelector(test.protocol, testAccountTag, test.tunnelTokenProvided, test.needPQ, fetcher.fetch(), ResolveTTL, &log)
  81. if test.wantErr {
  82. assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
  83. } else {
  84. assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
  85. assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
  86. fallback, ok := selector.Fallback()
  87. assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
  88. if test.hasFallback {
  89. assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
  90. }
  91. }
  92. })
  93. }
  94. }
  95. func TestAutoProtocolSelectorRefresh(t *testing.T) {
  96. fetcher := dynamicMockFetcher{}
  97. selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log)
  98. assert.NoError(t, err)
  99. assert.Equal(t, QUIC, selector.Current())
  100. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
  101. assert.Equal(t, HTTP2, selector.Current())
  102. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
  103. assert.Equal(t, QUIC, selector.Current())
  104. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
  105. assert.Equal(t, HTTP2, selector.Current())
  106. fetcher.err = fmt.Errorf("failed to fetch")
  107. assert.Equal(t, HTTP2, selector.Current())
  108. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
  109. fetcher.err = nil
  110. assert.Equal(t, QUIC, selector.Current())
  111. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
  112. assert.Equal(t, QUIC, selector.Current())
  113. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
  114. assert.Equal(t, QUIC, selector.Current())
  115. }
  116. func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
  117. fetcher := dynamicMockFetcher{}
  118. // Since the user chooses http2 on purpose, we always stick to it.
  119. selector, err := NewProtocolSelector(HTTP2.String(), testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log)
  120. assert.NoError(t, err)
  121. assert.Equal(t, HTTP2, selector.Current())
  122. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
  123. assert.Equal(t, HTTP2, selector.Current())
  124. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
  125. assert.Equal(t, HTTP2, selector.Current())
  126. fetcher.err = fmt.Errorf("failed to fetch")
  127. assert.Equal(t, HTTP2, selector.Current())
  128. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
  129. fetcher.err = nil
  130. assert.Equal(t, HTTP2, selector.Current())
  131. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
  132. assert.Equal(t, HTTP2, selector.Current())
  133. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
  134. assert.Equal(t, HTTP2, selector.Current())
  135. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
  136. assert.Equal(t, HTTP2, selector.Current())
  137. }
  138. func TestAutoProtocolSelectorNoRefreshWithToken(t *testing.T) {
  139. fetcher := dynamicMockFetcher{}
  140. selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, true, false, fetcher.fetch(), testNoTTL, &log)
  141. assert.NoError(t, err)
  142. assert.Equal(t, QUIC, selector.Current())
  143. fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
  144. assert.Equal(t, QUIC, selector.Current())
  145. }