protocol_test.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package connection
  2. import (
  3. "fmt"
  4. "testing"
  5. "time"
  6. "github.com/stretchr/testify/assert"
  7. )
  8. const (
  9. testNoTTL = 0
  10. )
  11. var (
  12. testNamedTunnelConfig = &NamedTunnelConfig{
  13. Credentials: Credentials{
  14. AccountTag: "testAccountTag",
  15. },
  16. }
  17. )
  18. func mockFetcher(percentage int32) PercentageFetcher {
  19. return func() (int32, error) {
  20. return percentage, nil
  21. }
  22. }
  23. func mockFetcherWithError() PercentageFetcher {
  24. return func() (int32, error) {
  25. return 0, fmt.Errorf("failed to fetch precentage")
  26. }
  27. }
  28. type dynamicMockFetcher struct {
  29. percentage int32
  30. err error
  31. }
  32. func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
  33. return func() (int32, error) {
  34. if dmf.err != nil {
  35. return 0, dmf.err
  36. }
  37. return dmf.percentage, nil
  38. }
  39. }
  40. func TestNewProtocolSelector(t *testing.T) {
  41. tests := []struct {
  42. name string
  43. protocol string
  44. expectedProtocol Protocol
  45. hasFallback bool
  46. expectedFallback Protocol
  47. namedTunnelConfig *NamedTunnelConfig
  48. fetchFunc PercentageFetcher
  49. wantErr bool
  50. }{
  51. {
  52. name: "classic tunnel",
  53. protocol: "h2mux",
  54. expectedProtocol: H2mux,
  55. namedTunnelConfig: nil,
  56. },
  57. {
  58. name: "named tunnel over h2mux",
  59. protocol: "h2mux",
  60. expectedProtocol: H2mux,
  61. namedTunnelConfig: testNamedTunnelConfig,
  62. },
  63. {
  64. name: "named tunnel over http2",
  65. protocol: "http2",
  66. expectedProtocol: HTTP2,
  67. hasFallback: true,
  68. expectedFallback: H2mux,
  69. fetchFunc: mockFetcher(0),
  70. namedTunnelConfig: testNamedTunnelConfig,
  71. },
  72. {
  73. name: "named tunnel http2 disabled",
  74. protocol: "http2",
  75. expectedProtocol: H2mux,
  76. fetchFunc: mockFetcher(-1),
  77. namedTunnelConfig: testNamedTunnelConfig,
  78. },
  79. {
  80. name: "named tunnel auto all http2 disabled",
  81. protocol: "auto",
  82. expectedProtocol: H2mux,
  83. fetchFunc: mockFetcher(-1),
  84. namedTunnelConfig: testNamedTunnelConfig,
  85. },
  86. {
  87. name: "named tunnel auto to h2mux",
  88. protocol: "auto",
  89. expectedProtocol: H2mux,
  90. fetchFunc: mockFetcher(0),
  91. namedTunnelConfig: testNamedTunnelConfig,
  92. },
  93. {
  94. name: "named tunnel auto to http2",
  95. protocol: "auto",
  96. expectedProtocol: HTTP2,
  97. hasFallback: true,
  98. expectedFallback: H2mux,
  99. fetchFunc: mockFetcher(100),
  100. namedTunnelConfig: testNamedTunnelConfig,
  101. },
  102. {
  103. // None named tunnel can only use h2mux, so specifying an unknown protocol is not an error
  104. name: "classic tunnel unknown protocol",
  105. protocol: "unknown",
  106. expectedProtocol: H2mux,
  107. },
  108. {
  109. name: "named tunnel unknown protocol",
  110. protocol: "unknown",
  111. fetchFunc: mockFetcher(100),
  112. namedTunnelConfig: testNamedTunnelConfig,
  113. wantErr: true,
  114. },
  115. {
  116. name: "named tunnel fetch error",
  117. protocol: "unknown",
  118. fetchFunc: mockFetcherWithError(),
  119. namedTunnelConfig: testNamedTunnelConfig,
  120. wantErr: true,
  121. },
  122. }
  123. for _, test := range tests {
  124. selector, err := NewProtocolSelector(test.protocol, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
  125. if test.wantErr {
  126. assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
  127. } else {
  128. assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
  129. assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
  130. fallback, ok := selector.Fallback()
  131. assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
  132. if test.hasFallback {
  133. assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
  134. }
  135. }
  136. }
  137. }
  138. func TestAutoProtocolSelectorRefresh(t *testing.T) {
  139. fetcher := dynamicMockFetcher{}
  140. selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
  141. assert.NoError(t, err)
  142. assert.Equal(t, H2mux, selector.Current())
  143. fetcher.percentage = 100
  144. assert.Equal(t, HTTP2, selector.Current())
  145. fetcher.percentage = 0
  146. assert.Equal(t, H2mux, selector.Current())
  147. fetcher.percentage = 100
  148. assert.Equal(t, HTTP2, selector.Current())
  149. fetcher.err = fmt.Errorf("failed to fetch")
  150. assert.Equal(t, HTTP2, selector.Current())
  151. fetcher.percentage = -1
  152. fetcher.err = nil
  153. assert.Equal(t, H2mux, selector.Current())
  154. fetcher.percentage = 0
  155. assert.Equal(t, H2mux, selector.Current())
  156. fetcher.percentage = 100
  157. assert.Equal(t, HTTP2, selector.Current())
  158. }
  159. func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
  160. fetcher := dynamicMockFetcher{}
  161. selector, err := NewProtocolSelector("http2", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
  162. assert.NoError(t, err)
  163. assert.Equal(t, HTTP2, selector.Current())
  164. fetcher.percentage = 100
  165. assert.Equal(t, HTTP2, selector.Current())
  166. fetcher.percentage = 0
  167. assert.Equal(t, HTTP2, selector.Current())
  168. fetcher.err = fmt.Errorf("failed to fetch")
  169. assert.Equal(t, HTTP2, selector.Current())
  170. fetcher.percentage = -1
  171. fetcher.err = nil
  172. assert.Equal(t, H2mux, selector.Current())
  173. fetcher.percentage = 0
  174. assert.Equal(t, HTTP2, selector.Current())
  175. fetcher.percentage = 100
  176. assert.Equal(t, HTTP2, selector.Current())
  177. fetcher.percentage = -1
  178. assert.Equal(t, H2mux, selector.Current())
  179. }
  180. func TestProtocolSelectorRefreshTTL(t *testing.T) {
  181. fetcher := dynamicMockFetcher{percentage: 100}
  182. selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log)
  183. assert.NoError(t, err)
  184. assert.Equal(t, HTTP2, selector.Current())
  185. fetcher.percentage = 0
  186. assert.Equal(t, HTTP2, selector.Current())
  187. }