123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332 |
- package supervisor
- import (
- "context"
- "errors"
- "net"
- "strings"
- "time"
- "github.com/prometheus/client_golang/prometheus"
- "github.com/quic-go/quic-go"
- "github.com/rs/zerolog"
- "github.com/cloudflare/cloudflared/connection"
- "github.com/cloudflare/cloudflared/edgediscovery"
- "github.com/cloudflare/cloudflared/ingress"
- "github.com/cloudflare/cloudflared/orchestration"
- v3 "github.com/cloudflare/cloudflared/quic/v3"
- "github.com/cloudflare/cloudflared/retry"
- "github.com/cloudflare/cloudflared/signal"
- "github.com/cloudflare/cloudflared/tunnelstate"
- )
const (
	// tunnelRetryDuration is the base duration of the backoff Run applies before
	// restarting a failed tunnel connection.
	tunnelRetryDuration = 10 * time.Second
	// registrationInterval is the pause between starting successive HA connections.
	registrationInterval = 1 * time.Second
)
- // Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and
- // reconnects them if they disconnect.
- type Supervisor struct {
- config *TunnelConfig
- orchestrator *orchestration.Orchestrator
- edgeIPs *edgediscovery.Edge
- edgeTunnelServer TunnelServer
- tunnelErrors chan tunnelError
- tunnelsConnecting map[int]chan struct{}
- tunnelsProtocolFallback map[int]*protocolFallback
- // nextConnectedIndex and nextConnectedSignal are used to wait for all
- // currently-connecting tunnels to finish connecting so we can reset backoff timer
- nextConnectedIndex int
- nextConnectedSignal chan struct{}
- log *ConnAwareLogger
- logTransport *zerolog.Logger
- reconnectCh chan ReconnectSignal
- gracefulShutdownC <-chan struct{}
- }
var (
	// errEarlyShutdown is returned by initialize when a graceful shutdown is requested
	// before the first connection registers; Run treats it as a clean exit.
	errEarlyShutdown = errors.New("shutdown started")
)
// tunnelError is the final result a connection goroutine reports on
// Supervisor.tunnelErrors: which connection finished, and the error it finished
// with (nil on a clean exit).
type tunnelError struct {
	index int   // connection index
	err   error // result returned for that connection
}
- func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
- isStaticEdge := len(config.EdgeAddrs) > 0
- var err error
- var edgeIPs *edgediscovery.Edge
- if isStaticEdge { // static edge addresses
- edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
- } else {
- edgeIPs, err = edgediscovery.ResolveEdge(config.Log, config.Region, config.EdgeIPVersion)
- }
- if err != nil {
- return nil, err
- }
- tracker := tunnelstate.NewConnTracker(config.Log)
- log := NewConnAwareLogger(config.Log, tracker, config.Observer)
- edgeAddrHandler := NewIPAddrFallback(config.MaxEdgeAddrRetries)
- edgeBindAddr := config.EdgeBindAddr
- datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
- sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetFlowLimiter())
- edgeTunnelServer := EdgeTunnelServer{
- config: config,
- orchestrator: orchestrator,
- sessionManager: sessionManager,
- datagramMetrics: datagramMetrics,
- edgeAddrs: edgeIPs,
- edgeAddrHandler: edgeAddrHandler,
- edgeBindAddr: edgeBindAddr,
- tracker: tracker,
- reconnectCh: reconnectCh,
- gracefulShutdownC: gracefulShutdownC,
- connAwareLogger: log,
- }
- return &Supervisor{
- config: config,
- orchestrator: orchestrator,
- edgeIPs: edgeIPs,
- edgeTunnelServer: &edgeTunnelServer,
- tunnelErrors: make(chan tunnelError),
- tunnelsConnecting: map[int]chan struct{}{},
- tunnelsProtocolFallback: map[int]*protocolFallback{},
- log: log,
- logTransport: config.LogTransport,
- reconnectCh: reconnectCh,
- gracefulShutdownC: gracefulShutdownC,
- }, nil
- }
- func (s *Supervisor) Run(
- ctx context.Context,
- connectedSignal *signal.Signal,
- ) error {
- if s.config.ICMPRouterServer != nil {
- go func() {
- if err := s.config.ICMPRouterServer.Serve(ctx); err != nil {
- if errors.Is(err, net.ErrClosed) {
- s.log.Logger().Info().Err(err).Msg("icmp router terminated")
- } else {
- s.log.Logger().Err(err).Msg("icmp router terminated")
- }
- }
- }()
- }
- if err := s.initialize(ctx, connectedSignal); err != nil {
- if err == errEarlyShutdown {
- return nil
- }
- return err
- }
- var tunnelsWaiting []int
- tunnelsActive := s.config.HAConnections
- backoff := retry.NewBackoff(s.config.Retries, tunnelRetryDuration, true)
- var backoffTimer <-chan time.Time
- shuttingDown := false
- for {
- select {
- // Context cancelled
- case <-ctx.Done():
- for tunnelsActive > 0 {
- <-s.tunnelErrors
- tunnelsActive--
- }
- return nil
- // startTunnel completed with a response
- // (note that this may also be caused by context cancellation)
- case tunnelError := <-s.tunnelErrors:
- tunnelsActive--
- if tunnelError.err != nil && !shuttingDown {
- switch tunnelError.err.(type) {
- case ReconnectSignal:
- // For tunnels that closed with reconnect signal, we reconnect immediately
- go s.startTunnel(ctx, tunnelError.index, s.newConnectedTunnelSignal(tunnelError.index))
- tunnelsActive++
- continue
- }
- // Make sure we don't continue if there is no more fallback allowed
- if _, retry := s.tunnelsProtocolFallback[tunnelError.index].GetMaxBackoffDuration(ctx); !retry {
- continue
- }
- s.log.ConnAwareLogger().Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
- tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
- s.waitForNextTunnel(tunnelError.index)
- if backoffTimer == nil {
- backoffTimer = backoff.BackoffTimer()
- }
- } else if tunnelsActive == 0 {
- s.log.ConnAwareLogger().Msg("no more connections active and exiting")
- // All connected tunnels exited gracefully, no more work to do
- return nil
- }
- // Backoff was set and its timer expired
- case <-backoffTimer:
- backoffTimer = nil
- for _, index := range tunnelsWaiting {
- go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
- }
- tunnelsActive += len(tunnelsWaiting)
- tunnelsWaiting = nil
- // Tunnel successfully connected
- case <-s.nextConnectedSignal:
- if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
- // No more tunnels outstanding, clear backoff timer
- backoff.SetGracePeriod()
- }
- case <-s.gracefulShutdownC:
- shuttingDown = true
- }
- }
- }
- // Returns nil if initialization succeeded, else the initialization error.
- // Attempts here will be made to connect one tunnel, if successful, it will
- // connect the available tunnels up to config.HAConnections.
- func (s *Supervisor) initialize(
- ctx context.Context,
- connectedSignal *signal.Signal,
- ) error {
- availableAddrs := s.edgeIPs.AvailableAddrs()
- if s.config.HAConnections > availableAddrs {
- s.log.Logger().Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
- s.config.HAConnections = availableAddrs
- }
- s.tunnelsProtocolFallback[0] = &protocolFallback{
- retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
- s.config.ProtocolSelector.Current(),
- false,
- }
- go s.startFirstTunnel(ctx, connectedSignal)
- // Wait for response from first tunnel before proceeding to attempt other HA edge tunnels
- select {
- case <-ctx.Done():
- <-s.tunnelErrors
- return ctx.Err()
- case tunnelError := <-s.tunnelErrors:
- return tunnelError.err
- case <-s.gracefulShutdownC:
- return errEarlyShutdown
- case <-connectedSignal.Wait():
- }
- // At least one successful connection, so start the rest
- for i := 1; i < s.config.HAConnections; i++ {
- s.tunnelsProtocolFallback[i] = &protocolFallback{
- retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
- // Set the protocol we know the first tunnel connected with.
- s.tunnelsProtocolFallback[0].protocol,
- false,
- }
- go s.startTunnel(ctx, i, s.newConnectedTunnelSignal(i))
- time.Sleep(registrationInterval)
- }
- return nil
- }
- // startTunnel starts the first tunnel connection. The resulting error will be sent on
- // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
- func (s *Supervisor) startFirstTunnel(
- ctx context.Context,
- connectedSignal *signal.Signal,
- ) {
- var (
- err error
- )
- const firstConnIndex = 0
- isStaticEdge := len(s.config.EdgeAddrs) > 0
- defer func() {
- s.tunnelErrors <- tunnelError{index: firstConnIndex, err: err}
- }()
- // If the first tunnel disconnects, keep restarting it.
- for {
- err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, s.tunnelsProtocolFallback[firstConnIndex], connectedSignal)
- if ctx.Err() != nil {
- return
- }
- if err == nil {
- return
- }
- // Make sure we don't continue if there is no more fallback allowed
- if _, retry := s.tunnelsProtocolFallback[firstConnIndex].GetMaxBackoffDuration(ctx); !retry {
- return
- }
- // Try again for Unauthorized errors because we hope them to be
- // transient due to edge propagation lag on new Tunnels.
- if strings.Contains(err.Error(), "Unauthorized") {
- continue
- }
- switch err.(type) {
- case edgediscovery.ErrNoAddressesLeft:
- // If your provided addresses are not available, we will keep trying regardless.
- if !isStaticEdge {
- return
- }
- case connection.DupConnRegisterTunnelError,
- *quic.IdleTimeoutError,
- *quic.ApplicationError,
- edgediscovery.DialError,
- *connection.EdgeQuicDialError:
- // Try again for these types of errors
- default:
- // Uncaught errors should bail startup
- return
- }
- }
- }
- // startTunnel starts a new tunnel connection. The resulting error will be sent on
- // s.tunnelError as this is expected to run in a goroutine.
- func (s *Supervisor) startTunnel(
- ctx context.Context,
- index int,
- connectedSignal *signal.Signal,
- ) {
- var (
- err error
- )
- defer func() {
- s.tunnelErrors <- tunnelError{index: index, err: err}
- }()
- // nolint: gosec
- err = s.edgeTunnelServer.Serve(ctx, uint8(index), s.tunnelsProtocolFallback[index], connectedSignal)
- }
- func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
- sig := make(chan struct{})
- s.tunnelsConnecting[index] = sig
- s.nextConnectedSignal = sig
- s.nextConnectedIndex = index
- return signal.New(sig)
- }
- func (s *Supervisor) waitForNextTunnel(index int) bool {
- delete(s.tunnelsConnecting, index)
- s.nextConnectedSignal = nil
- for k, v := range s.tunnelsConnecting {
- s.nextConnectedIndex = k
- s.nextConnectedSignal = v
- return true
- }
- return false
- }
|