protocol.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. package connection
  2. import (
  3. "fmt"
  4. "hash/fnv"
  5. "sync"
  6. "time"
  7. "github.com/rs/zerolog"
  8. )
  9. const (
  10. AvailableProtocolFlagMessage = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux"
  11. // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
  12. edgeH2muxTLSServerName = "cftunnel.com"
  13. // edgeH2TLSServerName is the server name to establish http2 connection with edge
  14. edgeH2TLSServerName = "h2.cftunnel.com"
  15. // threshold to switch back to h2mux when the user intentionally pick --protocol http2
  16. explicitHTTP2FallbackThreshold = -1
  17. autoSelectFlag = "auto"
  18. )
  19. var (
  20. ProtocolList = []Protocol{H2mux, HTTP2}
  21. )
  22. type Protocol int64
  23. const (
  24. H2mux Protocol = iota
  25. HTTP2
  26. )
  27. func (p Protocol) ServerName() string {
  28. switch p {
  29. case H2mux:
  30. return edgeH2muxTLSServerName
  31. case HTTP2:
  32. return edgeH2TLSServerName
  33. default:
  34. return ""
  35. }
  36. }
  37. // Fallback returns the fallback protocol and whether the protocol has a fallback
  38. func (p Protocol) fallback() (Protocol, bool) {
  39. switch p {
  40. case H2mux:
  41. return 0, false
  42. case HTTP2:
  43. return H2mux, true
  44. default:
  45. return 0, false
  46. }
  47. }
  48. func (p Protocol) String() string {
  49. switch p {
  50. case H2mux:
  51. return "h2mux"
  52. case HTTP2:
  53. return "http2"
  54. default:
  55. return fmt.Sprintf("unknown protocol")
  56. }
  57. }
  58. type ProtocolSelector interface {
  59. Current() Protocol
  60. Fallback() (Protocol, bool)
  61. }
  62. type staticProtocolSelector struct {
  63. current Protocol
  64. }
  65. func (s *staticProtocolSelector) Current() Protocol {
  66. return s.current
  67. }
  68. func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
  69. return 0, false
  70. }
  71. type autoProtocolSelector struct {
  72. lock sync.RWMutex
  73. current Protocol
  74. switchThrehold int32
  75. fetchFunc PercentageFetcher
  76. refreshAfter time.Time
  77. ttl time.Duration
  78. log *zerolog.Logger
  79. }
  80. func newAutoProtocolSelector(
  81. current Protocol,
  82. switchThrehold int32,
  83. fetchFunc PercentageFetcher,
  84. ttl time.Duration,
  85. log *zerolog.Logger,
  86. ) *autoProtocolSelector {
  87. return &autoProtocolSelector{
  88. current: current,
  89. switchThrehold: switchThrehold,
  90. fetchFunc: fetchFunc,
  91. refreshAfter: time.Now().Add(ttl),
  92. ttl: ttl,
  93. log: log,
  94. }
  95. }
  96. func (s *autoProtocolSelector) Current() Protocol {
  97. s.lock.Lock()
  98. defer s.lock.Unlock()
  99. if time.Now().Before(s.refreshAfter) {
  100. return s.current
  101. }
  102. percentage, err := s.fetchFunc()
  103. if err != nil {
  104. s.log.Err(err).Msg("Failed to refresh protocol")
  105. return s.current
  106. }
  107. if s.switchThrehold < percentage {
  108. s.current = HTTP2
  109. } else {
  110. s.current = H2mux
  111. }
  112. s.refreshAfter = time.Now().Add(s.ttl)
  113. return s.current
  114. }
  115. func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
  116. s.lock.RLock()
  117. defer s.lock.RUnlock()
  118. return s.current.fallback()
  119. }
  120. type PercentageFetcher func() (int32, error)
  121. func NewProtocolSelector(
  122. protocolFlag string,
  123. namedTunnel *NamedTunnelConfig,
  124. fetchFunc PercentageFetcher,
  125. ttl time.Duration,
  126. log *zerolog.Logger,
  127. ) (ProtocolSelector, error) {
  128. if namedTunnel == nil {
  129. return &staticProtocolSelector{
  130. current: H2mux,
  131. }, nil
  132. }
  133. if protocolFlag == H2mux.String() {
  134. return &staticProtocolSelector{
  135. current: H2mux,
  136. }, nil
  137. }
  138. http2Percentage, err := fetchFunc()
  139. if err != nil {
  140. return nil, err
  141. }
  142. if protocolFlag == HTTP2.String() {
  143. if http2Percentage < 0 {
  144. return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
  145. }
  146. return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
  147. }
  148. if protocolFlag != autoSelectFlag {
  149. return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
  150. }
  151. threshold := switchThreshold(namedTunnel.Credentials.AccountTag)
  152. if threshold < http2Percentage {
  153. return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, log), nil
  154. }
  155. return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, log), nil
  156. }
  157. func switchThreshold(accountTag string) int32 {
  158. h := fnv.New32a()
  159. _, _ = h.Write([]byte(accountTag))
  160. return int32(h.Sum32() % 100)
  161. }