protocol.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. package connection
  2. import (
  3. "fmt"
  4. "hash/fnv"
  5. "sync"
  6. "time"
  7. "github.com/rs/zerolog"
  8. "github.com/cloudflare/cloudflared/edgediscovery"
  9. )
  10. const (
  11. AvailableProtocolFlagMessage = "Available protocols: 'auto' - automatically chooses the best protocol over time (the default; and also the recommended one); 'quic' - based on QUIC, relying on UDP egress to Cloudflare edge; 'http2' - using Go's HTTP2 library, relying on TCP egress to Cloudflare edge"
  12. // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge (unused, but kept for legacy reference).
  13. _ = "cftunnel.com"
  14. // edgeH2TLSServerName is the server name to establish http2 connection with edge
  15. edgeH2TLSServerName = "h2.cftunnel.com"
  16. // edgeQUICServerName is the server name to establish quic connection with edge.
  17. edgeQUICServerName = "quic.cftunnel.com"
  18. AutoSelectFlag = "auto"
  19. // SRV and TXT record resolution TTL
  20. ResolveTTL = time.Hour
  21. )
// ProtocolList represents a list of supported protocols for communication with the edge
// in order of precedence for remote percentage fetcher.
// getProtocol walks the list front to back, so QUIC is considered before HTTP2.
var ProtocolList = []Protocol{QUIC, HTTP2}
// Protocol identifies the protocol used for connections to the edge.
// The zero value is HTTP2.
type Protocol int64

const (
	// HTTP2 using golang HTTP2 library for edge connections.
	HTTP2 Protocol = iota
	// QUIC using quic-go for edge connections.
	QUIC
)
  32. // Fallback returns the fallback protocol and whether the protocol has a fallback
  33. func (p Protocol) fallback() (Protocol, bool) {
  34. switch p {
  35. case HTTP2:
  36. return 0, false
  37. case QUIC:
  38. return HTTP2, true
  39. default:
  40. return 0, false
  41. }
  42. }
  43. func (p Protocol) String() string {
  44. switch p {
  45. case HTTP2:
  46. return "http2"
  47. case QUIC:
  48. return "quic"
  49. default:
  50. return "unknown protocol"
  51. }
  52. }
  53. func (p Protocol) TLSSettings() *TLSSettings {
  54. switch p {
  55. case HTTP2:
  56. return &TLSSettings{
  57. ServerName: edgeH2TLSServerName,
  58. }
  59. case QUIC:
  60. return &TLSSettings{
  61. ServerName: edgeQUICServerName,
  62. NextProtos: []string{"argotunnel"},
  63. }
  64. default:
  65. return nil
  66. }
  67. }
// TLSSettings holds the per-protocol TLS parameters returned by
// Protocol.TLSSettings.
type TLSSettings struct {
	// ServerName is the edge server name for the protocol.
	ServerName string
	// NextProtos is only populated for QUIC in this file.
	// NOTE(review): presumably the ALPN protocol list; confirm against the TLS config builder.
	NextProtos []string
}
// ProtocolSelector reports which protocol to use for edge connections and
// which protocol, if any, can be used as a fallback.
type ProtocolSelector interface {
	// Current returns the protocol that is currently selected.
	Current() Protocol
	// Fallback returns a protocol and whether it is a usable fallback.
	Fallback() (Protocol, bool)
}
  76. // staticProtocolSelector will not provide a different protocol for Fallback
  77. type staticProtocolSelector struct {
  78. current Protocol
  79. }
  80. func (s *staticProtocolSelector) Current() Protocol {
  81. return s.current
  82. }
  83. func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
  84. return s.current, false
  85. }
// remoteProtocolSelector will fetch a list of remote protocols to provide for edge discovery
type remoteProtocolSelector struct {
	// lock is held by Current and Fallback while they access the fields below.
	lock sync.RWMutex
	// current is the most recently selected protocol.
	current Protocol
	// protocolPool is desired protocols in the order of priority they should be picked in.
	protocolPool []Protocol
	// switchThreshold is the percentage a protocol has to exceed to be picked on refresh.
	switchThreshold int32
	// fetchFunc retrieves the protocol percentages consulted on refresh.
	fetchFunc edgediscovery.PercentageFetcher
	// refreshAfter is the instant from which Current attempts a refresh.
	refreshAfter time.Time
	// ttl is added to the current time to compute the next refreshAfter.
	ttl time.Duration
	// log receives refresh failures.
	log *zerolog.Logger
}
  98. func newRemoteProtocolSelector(
  99. current Protocol,
  100. protocolPool []Protocol,
  101. switchThreshold int32,
  102. fetchFunc edgediscovery.PercentageFetcher,
  103. ttl time.Duration,
  104. log *zerolog.Logger,
  105. ) *remoteProtocolSelector {
  106. return &remoteProtocolSelector{
  107. current: current,
  108. protocolPool: protocolPool,
  109. switchThreshold: switchThreshold,
  110. fetchFunc: fetchFunc,
  111. refreshAfter: time.Now().Add(ttl),
  112. ttl: ttl,
  113. log: log,
  114. }
  115. }
  116. func (s *remoteProtocolSelector) Current() Protocol {
  117. s.lock.Lock()
  118. defer s.lock.Unlock()
  119. if time.Now().Before(s.refreshAfter) {
  120. return s.current
  121. }
  122. protocol, err := getProtocol(s.protocolPool, s.fetchFunc, s.switchThreshold)
  123. if err != nil {
  124. s.log.Err(err).Msg("Failed to refresh protocol")
  125. return s.current
  126. }
  127. s.current = protocol
  128. s.refreshAfter = time.Now().Add(s.ttl)
  129. return s.current
  130. }
  131. func (s *remoteProtocolSelector) Fallback() (Protocol, bool) {
  132. s.lock.RLock()
  133. defer s.lock.RUnlock()
  134. return s.current.fallback()
  135. }
  136. func getProtocol(protocolPool []Protocol, fetchFunc edgediscovery.PercentageFetcher, switchThreshold int32) (Protocol, error) {
  137. protocolPercentages, err := fetchFunc()
  138. if err != nil {
  139. return 0, err
  140. }
  141. for _, protocol := range protocolPool {
  142. protocolPercentage := protocolPercentages.GetPercentage(protocol.String())
  143. if protocolPercentage > switchThreshold {
  144. return protocol, nil
  145. }
  146. }
  147. // Default to first index in protocolPool list
  148. return protocolPool[0], nil
  149. }
// defaultProtocolSelector will allow for a protocol to have a fallback
type defaultProtocolSelector struct {
	// lock is held by Current and Fallback while they read current.
	lock sync.RWMutex
	// current is the protocol the selector was created with.
	current Protocol
}
  155. func newDefaultProtocolSelector(
  156. current Protocol,
  157. ) *defaultProtocolSelector {
  158. return &defaultProtocolSelector{
  159. current: current,
  160. }
  161. }
  162. func (s *defaultProtocolSelector) Current() Protocol {
  163. s.lock.Lock()
  164. defer s.lock.Unlock()
  165. return s.current
  166. }
  167. func (s *defaultProtocolSelector) Fallback() (Protocol, bool) {
  168. s.lock.RLock()
  169. defer s.lock.RUnlock()
  170. return s.current.fallback()
  171. }
  172. func NewProtocolSelector(
  173. protocolFlag string,
  174. accountTag string,
  175. tunnelTokenProvided bool,
  176. needPQ bool,
  177. protocolFetcher edgediscovery.PercentageFetcher,
  178. resolveTTL time.Duration,
  179. log *zerolog.Logger,
  180. ) (ProtocolSelector, error) {
  181. // With --post-quantum, we force quic
  182. if needPQ {
  183. return &staticProtocolSelector{
  184. current: QUIC,
  185. }, nil
  186. }
  187. threshold := switchThreshold(accountTag)
  188. fetchedProtocol, err := getProtocol(ProtocolList, protocolFetcher, threshold)
  189. log.Debug().Msgf("Fetched protocol: %s", fetchedProtocol)
  190. if err != nil {
  191. log.Warn().Msg("Unable to lookup protocol percentage.")
  192. // Falling through here since 'auto' is handled in the switch and failing
  193. // to do the protocol lookup isn't a failure since it can be triggered again
  194. // after the TTL.
  195. }
  196. // If the user picks a protocol, then we stick to it no matter what.
  197. switch protocolFlag {
  198. case "h2mux":
  199. // Any users still requesting h2mux will be upgraded to http2 instead
  200. log.Warn().Msg("h2mux is no longer a supported protocol: upgrading edge connection to http2. Please remove '--protocol h2mux' from runtime arguments to remove this warning.")
  201. return &staticProtocolSelector{current: HTTP2}, nil
  202. case QUIC.String():
  203. return &staticProtocolSelector{current: QUIC}, nil
  204. case HTTP2.String():
  205. return &staticProtocolSelector{current: HTTP2}, nil
  206. case AutoSelectFlag:
  207. // When a --token is provided, we want to start with QUIC but have fallback to HTTP2
  208. if tunnelTokenProvided {
  209. return newDefaultProtocolSelector(QUIC), nil
  210. }
  211. return newRemoteProtocolSelector(fetchedProtocol, ProtocolList, threshold, protocolFetcher, resolveTTL, log), nil
  212. }
  213. return nil, fmt.Errorf("unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
  214. }
  215. func switchThreshold(accountTag string) int32 {
  216. h := fnv.New32a()
  217. _, _ = h.Write([]byte(accountTag))
  218. return int32(h.Sum32() % 100) // nolint: gosec
  219. }