supervisor.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. package supervisor
  2. import (
  3. "context"
  4. "errors"
  5. "net"
  6. "strings"
  7. "time"
  8. "github.com/prometheus/client_golang/prometheus"
  9. "github.com/quic-go/quic-go"
  10. "github.com/rs/zerolog"
  11. "github.com/cloudflare/cloudflared/connection"
  12. "github.com/cloudflare/cloudflared/edgediscovery"
  13. "github.com/cloudflare/cloudflared/ingress"
  14. "github.com/cloudflare/cloudflared/orchestration"
  15. v3 "github.com/cloudflare/cloudflared/quic/v3"
  16. "github.com/cloudflare/cloudflared/retry"
  17. "github.com/cloudflare/cloudflared/signal"
  18. "github.com/cloudflare/cloudflared/tunnelstate"
  19. )
  const (
  	// tunnelRetryDuration is the base delay handed to Run's reconnect backoff:
  	// the waiting time before a failed tunnel connection is retried.
  	tunnelRetryDuration = 10 * time.Second

  	// registrationInterval is the pause between starting successive HA
  	// connections once the first tunnel is up.
  	registrationInterval = 1 * time.Second
  )
  // Supervisor manages non-declarative tunnels. It establishes connections with
  // the edge (one per HA connection index) and reconnects them if they disconnect.
  type Supervisor struct {
  	config           *TunnelConfig
  	orchestrator     *orchestration.Orchestrator
  	edgeIPs          *edgediscovery.Edge
  	edgeTunnelServer TunnelServer

  	// tunnelErrors receives one result from each tunnel goroutine when it
  	// finishes. NewSupervisor creates it unbuffered, so each sender blocks
  	// until the supervisor receives.
  	tunnelErrors chan tunnelError
  	// tunnelsConnecting maps a connection index to the channel wrapped in that
  	// tunnel's "connected" signal, for tunnels Run is still waiting on to connect.
  	tunnelsConnecting map[int]chan struct{}
  	// tunnelsProtocolFallback holds per-connection-index backoff and protocol
  	// state, passed to the edge tunnel server on every Serve call.
  	tunnelsProtocolFallback map[int]*protocolFallback

  	// nextConnectedIndex and nextConnectedSignal are used to wait for all
  	// currently-connecting tunnels to finish connecting so we can reset backoff timer
  	nextConnectedIndex  int
  	nextConnectedSignal chan struct{}

  	log          *ConnAwareLogger
  	logTransport *zerolog.Logger

  	// reconnectCh and gracefulShutdownC are shared with the EdgeTunnelServer
  	// built by NewSupervisor.
  	reconnectCh       chan ReconnectSignal
  	gracefulShutdownC <-chan struct{}
  }
  // errEarlyShutdown is returned by initialize when a graceful shutdown is
  // requested before the first tunnel connects; Run treats it as a clean exit.
  var errEarlyShutdown = errors.New("shutdown started")

  // tunnelError is the result each tunnel goroutine sends on
  // Supervisor.tunnelErrors when it finishes.
  type tunnelError struct {
  	index int   // connection index of the tunnel that finished
  	err   error // error the tunnel stopped with; nil means it exited cleanly
  }
  50. func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
  51. isStaticEdge := len(config.EdgeAddrs) > 0
  52. var err error
  53. var edgeIPs *edgediscovery.Edge
  54. if isStaticEdge { // static edge addresses
  55. edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
  56. } else {
  57. edgeIPs, err = edgediscovery.ResolveEdge(config.Log, config.Region, config.EdgeIPVersion)
  58. }
  59. if err != nil {
  60. return nil, err
  61. }
  62. tracker := tunnelstate.NewConnTracker(config.Log)
  63. log := NewConnAwareLogger(config.Log, tracker, config.Observer)
  64. edgeAddrHandler := NewIPAddrFallback(config.MaxEdgeAddrRetries)
  65. edgeBindAddr := config.EdgeBindAddr
  66. datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
  67. sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetFlowLimiter())
  68. edgeTunnelServer := EdgeTunnelServer{
  69. config: config,
  70. orchestrator: orchestrator,
  71. sessionManager: sessionManager,
  72. datagramMetrics: datagramMetrics,
  73. edgeAddrs: edgeIPs,
  74. edgeAddrHandler: edgeAddrHandler,
  75. edgeBindAddr: edgeBindAddr,
  76. tracker: tracker,
  77. reconnectCh: reconnectCh,
  78. gracefulShutdownC: gracefulShutdownC,
  79. connAwareLogger: log,
  80. }
  81. return &Supervisor{
  82. config: config,
  83. orchestrator: orchestrator,
  84. edgeIPs: edgeIPs,
  85. edgeTunnelServer: &edgeTunnelServer,
  86. tunnelErrors: make(chan tunnelError),
  87. tunnelsConnecting: map[int]chan struct{}{},
  88. tunnelsProtocolFallback: map[int]*protocolFallback{},
  89. log: log,
  90. logTransport: config.LogTransport,
  91. reconnectCh: reconnectCh,
  92. gracefulShutdownC: gracefulShutdownC,
  93. }, nil
  94. }
  95. func (s *Supervisor) Run(
  96. ctx context.Context,
  97. connectedSignal *signal.Signal,
  98. ) error {
  99. if s.config.ICMPRouterServer != nil {
  100. go func() {
  101. if err := s.config.ICMPRouterServer.Serve(ctx); err != nil {
  102. if errors.Is(err, net.ErrClosed) {
  103. s.log.Logger().Info().Err(err).Msg("icmp router terminated")
  104. } else {
  105. s.log.Logger().Err(err).Msg("icmp router terminated")
  106. }
  107. }
  108. }()
  109. }
  110. if err := s.initialize(ctx, connectedSignal); err != nil {
  111. if err == errEarlyShutdown {
  112. return nil
  113. }
  114. return err
  115. }
  116. var tunnelsWaiting []int
  117. tunnelsActive := s.config.HAConnections
  118. backoff := retry.NewBackoff(s.config.Retries, tunnelRetryDuration, true)
  119. var backoffTimer <-chan time.Time
  120. shuttingDown := false
  121. for {
  122. select {
  123. // Context cancelled
  124. case <-ctx.Done():
  125. for tunnelsActive > 0 {
  126. <-s.tunnelErrors
  127. tunnelsActive--
  128. }
  129. return nil
  130. // startTunnel completed with a response
  131. // (note that this may also be caused by context cancellation)
  132. case tunnelError := <-s.tunnelErrors:
  133. tunnelsActive--
  134. if tunnelError.err != nil && !shuttingDown {
  135. switch tunnelError.err.(type) {
  136. case ReconnectSignal:
  137. // For tunnels that closed with reconnect signal, we reconnect immediately
  138. go s.startTunnel(ctx, tunnelError.index, s.newConnectedTunnelSignal(tunnelError.index))
  139. tunnelsActive++
  140. continue
  141. }
  142. // Make sure we don't continue if there is no more fallback allowed
  143. if _, retry := s.tunnelsProtocolFallback[tunnelError.index].GetMaxBackoffDuration(ctx); !retry {
  144. continue
  145. }
  146. s.log.ConnAwareLogger().Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
  147. tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
  148. s.waitForNextTunnel(tunnelError.index)
  149. if backoffTimer == nil {
  150. backoffTimer = backoff.BackoffTimer()
  151. }
  152. } else if tunnelsActive == 0 {
  153. s.log.ConnAwareLogger().Msg("no more connections active and exiting")
  154. // All connected tunnels exited gracefully, no more work to do
  155. return nil
  156. }
  157. // Backoff was set and its timer expired
  158. case <-backoffTimer:
  159. backoffTimer = nil
  160. for _, index := range tunnelsWaiting {
  161. go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
  162. }
  163. tunnelsActive += len(tunnelsWaiting)
  164. tunnelsWaiting = nil
  165. // Tunnel successfully connected
  166. case <-s.nextConnectedSignal:
  167. if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
  168. // No more tunnels outstanding, clear backoff timer
  169. backoff.SetGracePeriod()
  170. }
  171. case <-s.gracefulShutdownC:
  172. shuttingDown = true
  173. }
  174. }
  175. }
  176. // Returns nil if initialization succeeded, else the initialization error.
  177. // Attempts here will be made to connect one tunnel, if successful, it will
  178. // connect the available tunnels up to config.HAConnections.
  179. func (s *Supervisor) initialize(
  180. ctx context.Context,
  181. connectedSignal *signal.Signal,
  182. ) error {
  183. availableAddrs := s.edgeIPs.AvailableAddrs()
  184. if s.config.HAConnections > availableAddrs {
  185. s.log.Logger().Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
  186. s.config.HAConnections = availableAddrs
  187. }
  188. s.tunnelsProtocolFallback[0] = &protocolFallback{
  189. retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
  190. s.config.ProtocolSelector.Current(),
  191. false,
  192. }
  193. go s.startFirstTunnel(ctx, connectedSignal)
  194. // Wait for response from first tunnel before proceeding to attempt other HA edge tunnels
  195. select {
  196. case <-ctx.Done():
  197. <-s.tunnelErrors
  198. return ctx.Err()
  199. case tunnelError := <-s.tunnelErrors:
  200. return tunnelError.err
  201. case <-s.gracefulShutdownC:
  202. return errEarlyShutdown
  203. case <-connectedSignal.Wait():
  204. }
  205. // At least one successful connection, so start the rest
  206. for i := 1; i < s.config.HAConnections; i++ {
  207. s.tunnelsProtocolFallback[i] = &protocolFallback{
  208. retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
  209. // Set the protocol we know the first tunnel connected with.
  210. s.tunnelsProtocolFallback[0].protocol,
  211. false,
  212. }
  213. go s.startTunnel(ctx, i, s.newConnectedTunnelSignal(i))
  214. time.Sleep(registrationInterval)
  215. }
  216. return nil
  217. }
  218. // startTunnel starts the first tunnel connection. The resulting error will be sent on
  219. // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
  220. func (s *Supervisor) startFirstTunnel(
  221. ctx context.Context,
  222. connectedSignal *signal.Signal,
  223. ) {
  224. var (
  225. err error
  226. )
  227. const firstConnIndex = 0
  228. isStaticEdge := len(s.config.EdgeAddrs) > 0
  229. defer func() {
  230. s.tunnelErrors <- tunnelError{index: firstConnIndex, err: err}
  231. }()
  232. // If the first tunnel disconnects, keep restarting it.
  233. for {
  234. err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, s.tunnelsProtocolFallback[firstConnIndex], connectedSignal)
  235. if ctx.Err() != nil {
  236. return
  237. }
  238. if err == nil {
  239. return
  240. }
  241. // Make sure we don't continue if there is no more fallback allowed
  242. if _, retry := s.tunnelsProtocolFallback[firstConnIndex].GetMaxBackoffDuration(ctx); !retry {
  243. return
  244. }
  245. // Try again for Unauthorized errors because we hope them to be
  246. // transient due to edge propagation lag on new Tunnels.
  247. if strings.Contains(err.Error(), "Unauthorized") {
  248. continue
  249. }
  250. switch err.(type) {
  251. case edgediscovery.ErrNoAddressesLeft:
  252. // If your provided addresses are not available, we will keep trying regardless.
  253. if !isStaticEdge {
  254. return
  255. }
  256. case connection.DupConnRegisterTunnelError,
  257. *quic.IdleTimeoutError,
  258. *quic.ApplicationError,
  259. edgediscovery.DialError,
  260. *connection.EdgeQuicDialError:
  261. // Try again for these types of errors
  262. default:
  263. // Uncaught errors should bail startup
  264. return
  265. }
  266. }
  267. }
  268. // startTunnel starts a new tunnel connection. The resulting error will be sent on
  269. // s.tunnelError as this is expected to run in a goroutine.
  270. func (s *Supervisor) startTunnel(
  271. ctx context.Context,
  272. index int,
  273. connectedSignal *signal.Signal,
  274. ) {
  275. var (
  276. err error
  277. )
  278. defer func() {
  279. s.tunnelErrors <- tunnelError{index: index, err: err}
  280. }()
  281. // nolint: gosec
  282. err = s.edgeTunnelServer.Serve(ctx, uint8(index), s.tunnelsProtocolFallback[index], connectedSignal)
  283. }
  284. func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
  285. sig := make(chan struct{})
  286. s.tunnelsConnecting[index] = sig
  287. s.nextConnectedSignal = sig
  288. s.nextConnectedIndex = index
  289. return signal.New(sig)
  290. }
  291. func (s *Supervisor) waitForNextTunnel(index int) bool {
  292. delete(s.tunnelsConnecting, index)
  293. s.nextConnectedSignal = nil
  294. for k, v := range s.tunnelsConnecting {
  295. s.nextConnectedIndex = k
  296. s.nextConnectedSignal = v
  297. return true
  298. }
  299. return false
  300. }