protocol_test.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. package connection
  2. import (
  3. "fmt"
  4. "testing"
  5. "time"
  6. "github.com/stretchr/testify/assert"
  7. )
  // Shared arguments for the NewProtocolSelector calls in this file.
  const (
  	// testNoTTL is the TTL argument used by tests that change the mock
  	// fetcher between consecutive Current() calls.
  	testNoTTL = 0

  	// noWarpRoutingEnabled is the warp-routing argument for tests that leave
  	// warp routing off.
  	noWarpRoutingEnabled = false
  )
  12. var (
  13. testNamedTunnelConfig = &NamedTunnelConfig{
  14. Credentials: Credentials{
  15. AccountTag: "testAccountTag",
  16. },
  17. }
  18. )
  19. func mockFetcher(percentage int32) PercentageFetcher {
  20. return func() (int32, error) {
  21. return percentage, nil
  22. }
  23. }
  24. func mockFetcherWithError() PercentageFetcher {
  25. return func() (int32, error) {
  26. return 0, fmt.Errorf("failed to fetch precentage")
  27. }
  28. }
  29. type dynamicMockFetcher struct {
  30. percentage int32
  31. err error
  32. }
  33. func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
  34. return func() (int32, error) {
  35. if dmf.err != nil {
  36. return 0, dmf.err
  37. }
  38. return dmf.percentage, nil
  39. }
  40. }
  41. func TestNewProtocolSelector(t *testing.T) {
  42. tests := []struct {
  43. name string
  44. protocol string
  45. expectedProtocol Protocol
  46. hasFallback bool
  47. expectedFallback Protocol
  48. warpRoutingEnabled bool
  49. namedTunnelConfig *NamedTunnelConfig
  50. fetchFunc PercentageFetcher
  51. wantErr bool
  52. }{
  53. {
  54. name: "classic tunnel",
  55. protocol: "h2mux",
  56. expectedProtocol: H2mux,
  57. namedTunnelConfig: nil,
  58. },
  59. {
  60. name: "named tunnel over h2mux",
  61. protocol: "h2mux",
  62. expectedProtocol: H2mux,
  63. namedTunnelConfig: testNamedTunnelConfig,
  64. },
  65. {
  66. name: "named tunnel over http2",
  67. protocol: "http2",
  68. expectedProtocol: HTTP2,
  69. hasFallback: true,
  70. expectedFallback: H2mux,
  71. fetchFunc: mockFetcher(0),
  72. namedTunnelConfig: testNamedTunnelConfig,
  73. },
  74. {
  75. name: "named tunnel http2 disabled",
  76. protocol: "http2",
  77. expectedProtocol: H2mux,
  78. fetchFunc: mockFetcher(-1),
  79. namedTunnelConfig: testNamedTunnelConfig,
  80. },
  81. {
  82. name: "named tunnel auto all http2 disabled",
  83. protocol: "auto",
  84. expectedProtocol: H2mux,
  85. fetchFunc: mockFetcher(-1),
  86. namedTunnelConfig: testNamedTunnelConfig,
  87. },
  88. {
  89. name: "named tunnel auto to h2mux",
  90. protocol: "auto",
  91. expectedProtocol: H2mux,
  92. fetchFunc: mockFetcher(0),
  93. namedTunnelConfig: testNamedTunnelConfig,
  94. },
  95. {
  96. name: "named tunnel auto to http2",
  97. protocol: "auto",
  98. expectedProtocol: HTTP2,
  99. hasFallback: true,
  100. expectedFallback: H2mux,
  101. fetchFunc: mockFetcher(100),
  102. namedTunnelConfig: testNamedTunnelConfig,
  103. },
  104. {
  105. name: "warp routing requesting h2mux",
  106. protocol: "h2mux",
  107. expectedProtocol: HTTP2,
  108. hasFallback: false,
  109. expectedFallback: H2mux,
  110. fetchFunc: mockFetcher(100),
  111. warpRoutingEnabled: true,
  112. namedTunnelConfig: testNamedTunnelConfig,
  113. },
  114. {
  115. name: "warp routing http2",
  116. protocol: "http2",
  117. expectedProtocol: HTTP2,
  118. hasFallback: false,
  119. expectedFallback: H2mux,
  120. fetchFunc: mockFetcher(100),
  121. warpRoutingEnabled: true,
  122. namedTunnelConfig: testNamedTunnelConfig,
  123. },
  124. {
  125. name: "warp routing auto",
  126. protocol: "auto",
  127. expectedProtocol: HTTP2,
  128. hasFallback: false,
  129. expectedFallback: H2mux,
  130. fetchFunc: mockFetcher(100),
  131. warpRoutingEnabled: true,
  132. namedTunnelConfig: testNamedTunnelConfig,
  133. },
  134. {
  135. // None named tunnel can only use h2mux, so specifying an unknown protocol is not an error
  136. name: "classic tunnel unknown protocol",
  137. protocol: "unknown",
  138. expectedProtocol: H2mux,
  139. },
  140. {
  141. name: "named tunnel unknown protocol",
  142. protocol: "unknown",
  143. fetchFunc: mockFetcher(100),
  144. namedTunnelConfig: testNamedTunnelConfig,
  145. wantErr: true,
  146. },
  147. {
  148. name: "named tunnel fetch error",
  149. protocol: "unknown",
  150. fetchFunc: mockFetcherWithError(),
  151. namedTunnelConfig: testNamedTunnelConfig,
  152. wantErr: true,
  153. },
  154. }
  155. for _, test := range tests {
  156. selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
  157. if test.wantErr {
  158. assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
  159. } else {
  160. assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
  161. assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
  162. fallback, ok := selector.Fallback()
  163. assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
  164. if test.hasFallback {
  165. assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
  166. }
  167. }
  168. }
  169. }
  170. func TestAutoProtocolSelectorRefresh(t *testing.T) {
  171. fetcher := dynamicMockFetcher{}
  172. selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
  173. assert.NoError(t, err)
  174. assert.Equal(t, H2mux, selector.Current())
  175. fetcher.percentage = 100
  176. assert.Equal(t, HTTP2, selector.Current())
  177. fetcher.percentage = 0
  178. assert.Equal(t, H2mux, selector.Current())
  179. fetcher.percentage = 100
  180. assert.Equal(t, HTTP2, selector.Current())
  181. fetcher.err = fmt.Errorf("failed to fetch")
  182. assert.Equal(t, HTTP2, selector.Current())
  183. fetcher.percentage = -1
  184. fetcher.err = nil
  185. assert.Equal(t, H2mux, selector.Current())
  186. fetcher.percentage = 0
  187. assert.Equal(t, H2mux, selector.Current())
  188. fetcher.percentage = 100
  189. assert.Equal(t, HTTP2, selector.Current())
  190. }
  191. func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
  192. fetcher := dynamicMockFetcher{}
  193. selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
  194. assert.NoError(t, err)
  195. assert.Equal(t, HTTP2, selector.Current())
  196. fetcher.percentage = 100
  197. assert.Equal(t, HTTP2, selector.Current())
  198. fetcher.percentage = 0
  199. assert.Equal(t, HTTP2, selector.Current())
  200. fetcher.err = fmt.Errorf("failed to fetch")
  201. assert.Equal(t, HTTP2, selector.Current())
  202. fetcher.percentage = -1
  203. fetcher.err = nil
  204. assert.Equal(t, H2mux, selector.Current())
  205. fetcher.percentage = 0
  206. assert.Equal(t, HTTP2, selector.Current())
  207. fetcher.percentage = 100
  208. assert.Equal(t, HTTP2, selector.Current())
  209. fetcher.percentage = -1
  210. assert.Equal(t, H2mux, selector.Current())
  211. }
  212. func TestProtocolSelectorRefreshTTL(t *testing.T) {
  213. fetcher := dynamicMockFetcher{percentage: 100}
  214. selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log)
  215. assert.NoError(t, err)
  216. assert.Equal(t, HTTP2, selector.Current())
  217. fetcher.percentage = 0
  218. assert.Equal(t, HTTP2, selector.Current())
  219. }