protocol.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. package connection
  2. import (
  3. "fmt"
  4. "hash/fnv"
  5. "sync"
  6. "time"
  7. "github.com/rs/zerolog"
  8. )
  9. const (
  10. AvailableProtocolFlagMessage = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux"
  11. // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
  12. edgeH2muxTLSServerName = "cftunnel.com"
  13. // edgeH2TLSServerName is the server name to establish http2 connection with edge
  14. edgeH2TLSServerName = "h2.cftunnel.com"
  15. // threshold to switch back to h2mux when the user intentionally pick --protocol http2
  16. explicitHTTP2FallbackThreshold = -1
  17. autoSelectFlag = "auto"
  18. )
  19. var (
  20. ProtocolList = []Protocol{H2mux, HTTP2}
  21. )
  22. type Protocol int64
  23. const (
  24. H2mux Protocol = iota
  25. HTTP2
  26. )
  27. func (p Protocol) ServerName() string {
  28. switch p {
  29. case H2mux:
  30. return edgeH2muxTLSServerName
  31. case HTTP2:
  32. return edgeH2TLSServerName
  33. default:
  34. return ""
  35. }
  36. }
  37. // Fallback returns the fallback protocol and whether the protocol has a fallback
  38. func (p Protocol) fallback() (Protocol, bool) {
  39. switch p {
  40. case H2mux:
  41. return 0, false
  42. case HTTP2:
  43. return H2mux, true
  44. default:
  45. return 0, false
  46. }
  47. }
  48. func (p Protocol) String() string {
  49. switch p {
  50. case H2mux:
  51. return "h2mux"
  52. case HTTP2:
  53. return "http2"
  54. default:
  55. return fmt.Sprintf("unknown protocol")
  56. }
  57. }
  58. type ProtocolSelector interface {
  59. Current() Protocol
  60. Fallback() (Protocol, bool)
  61. }
  62. type staticProtocolSelector struct {
  63. current Protocol
  64. }
  65. func (s *staticProtocolSelector) Current() Protocol {
  66. return s.current
  67. }
  68. func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
  69. return 0, false
  70. }
  71. type autoProtocolSelector struct {
  72. lock sync.RWMutex
  73. current Protocol
  74. switchThrehold int32
  75. fetchFunc PercentageFetcher
  76. refreshAfter time.Time
  77. ttl time.Duration
  78. log *zerolog.Logger
  79. }
  80. func newAutoProtocolSelector(
  81. current Protocol,
  82. switchThrehold int32,
  83. fetchFunc PercentageFetcher,
  84. ttl time.Duration,
  85. log *zerolog.Logger,
  86. ) *autoProtocolSelector {
  87. return &autoProtocolSelector{
  88. current: current,
  89. switchThrehold: switchThrehold,
  90. fetchFunc: fetchFunc,
  91. refreshAfter: time.Now().Add(ttl),
  92. ttl: ttl,
  93. log: log,
  94. }
  95. }
  96. func (s *autoProtocolSelector) Current() Protocol {
  97. s.lock.Lock()
  98. defer s.lock.Unlock()
  99. if time.Now().Before(s.refreshAfter) {
  100. return s.current
  101. }
  102. percentage, err := s.fetchFunc()
  103. if err != nil {
  104. s.log.Err(err).Msg("Failed to refresh protocol")
  105. return s.current
  106. }
  107. if s.switchThrehold < percentage {
  108. s.current = HTTP2
  109. } else {
  110. s.current = H2mux
  111. }
  112. s.refreshAfter = time.Now().Add(s.ttl)
  113. return s.current
  114. }
  115. func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
  116. s.lock.RLock()
  117. defer s.lock.RUnlock()
  118. return s.current.fallback()
  119. }
  120. type PercentageFetcher func() (int32, error)
  121. func NewProtocolSelector(
  122. protocolFlag string,
  123. warpRoutingEnabled bool,
  124. namedTunnel *NamedTunnelConfig,
  125. fetchFunc PercentageFetcher,
  126. ttl time.Duration,
  127. log *zerolog.Logger,
  128. ) (ProtocolSelector, error) {
  129. // Classic tunnel is only supported with h2mux
  130. if namedTunnel == nil {
  131. return &staticProtocolSelector{
  132. current: H2mux,
  133. }, nil
  134. }
  135. // warp routing can only be served over http2 connections
  136. if warpRoutingEnabled {
  137. if protocolFlag == H2mux.String() {
  138. log.Warn().Msg("Warp routing is only supported by http2 protocol. Upgrading protocol to http2")
  139. }
  140. return &staticProtocolSelector{
  141. current: HTTP2,
  142. }, nil
  143. }
  144. if protocolFlag == H2mux.String() {
  145. return &staticProtocolSelector{
  146. current: H2mux,
  147. }, nil
  148. }
  149. http2Percentage, err := fetchFunc()
  150. if err != nil {
  151. return nil, err
  152. }
  153. if protocolFlag == HTTP2.String() {
  154. if http2Percentage < 0 {
  155. return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
  156. }
  157. return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
  158. }
  159. if protocolFlag != autoSelectFlag {
  160. return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
  161. }
  162. threshold := switchThreshold(namedTunnel.Credentials.AccountTag)
  163. if threshold < http2Percentage {
  164. return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, log), nil
  165. }
  166. return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, log), nil
  167. }
  168. func switchThreshold(accountTag string) int32 {
  169. h := fnv.New32a()
  170. _, _ = h.Write([]byte(accountTag))
  171. return int32(h.Sum32() % 100)
  172. }