123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346 |
- package origin
- import (
- "context"
- "errors"
- "net"
- "time"
- "github.com/google/uuid"
- "github.com/cloudflare/cloudflared/buffer"
- "github.com/cloudflare/cloudflared/connection"
- "github.com/cloudflare/cloudflared/edgediscovery"
- "github.com/cloudflare/cloudflared/h2mux"
- "github.com/cloudflare/cloudflared/logger"
- "github.com/cloudflare/cloudflared/signal"
- tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
- )
// Timing and retry parameters for tunnel supervision and auth refresh.
const (
	// tunnelRetryDuration is the base wait before a failed tunnel connection
	// is retried.
	tunnelRetryDuration = 10 * time.Second
	// resolveTTL is the TTL applied to SRV record resolution.
	resolveTTL = time.Hour
	// registrationInterval is the pause between registering new tunnels.
	registrationInterval = time.Second
	// subsystemRefreshAuth names the auth-refresh subsystem.
	subsystemRefreshAuth = "refresh_auth"
	// refreshAuthMaxBackoff is the maximum exponent for the 'Authenticate'
	// exponential backoff.
	refreshAuthMaxBackoff = 10
	// refreshAuthRetryDuration is the base wait before a failed
	// 'Authenticate' connection is retried.
	refreshAuthRetryDuration = 10 * time.Second
	// authTokenTimeout is the maximum time allowed for an Authenticate RPC.
	authTokenTimeout = 30 * time.Second
)
// errEventDigestUnset is a sentinel error reporting that an event digest has
// not been set.
var errEventDigestUnset = errors.New("event digest unset")
- // Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and
- // reconnects them if they disconnect.
- type Supervisor struct {
- cloudflaredUUID uuid.UUID
- config *TunnelConfig
- edgeIPs *edgediscovery.Edge
- lastResolve time.Time
- resolverC chan resolveResult
- tunnelErrors chan tunnelError
- tunnelsConnecting map[int]chan struct{}
- // nextConnectedIndex and nextConnectedSignal are used to wait for all
- // currently-connecting tunnels to finish connecting so we can reset backoff timer
- nextConnectedIndex int
- nextConnectedSignal chan struct{}
- logger logger.Service
- reconnectCredentialManager *reconnectCredentialManager
- bufferPool *buffer.Pool
- }
// resolveResult is the message delivered on Supervisor.resolverC once a
// service discovery refresh has finished.
type resolveResult struct {
	// err is nil when the refresh completed successfully.
	err error
}
// tunnelError is what a tunnel goroutine sends on Supervisor.tunnelErrors when
// it exits.
type tunnelError struct {
	// index is the connection index of the tunnel that exited.
	index int
	// addr is the edge address the tunnel was last using; it may be nil if no
	// address could be obtained.
	addr *net.TCPAddr
	// err is the reason the tunnel exited, or nil.
	err error
}
- func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) {
- var (
- edgeIPs *edgediscovery.Edge
- err error
- )
- if len(config.EdgeAddrs) > 0 {
- edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs)
- } else {
- edgeIPs, err = edgediscovery.ResolveEdge(config.Logger)
- }
- if err != nil {
- return nil, err
- }
- return &Supervisor{
- cloudflaredUUID: cloudflaredUUID,
- config: config,
- edgeIPs: edgeIPs,
- tunnelErrors: make(chan tunnelError),
- tunnelsConnecting: map[int]chan struct{}{},
- logger: config.Logger,
- reconnectCredentialManager: newReconnectCredentialManager(metricsNamespace, tunnelSubsystem, config.HAConnections),
- bufferPool: buffer.NewPool(512 * 1024),
- }, nil
- }
- func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
- logger := s.config.Logger
- if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
- return err
- }
- var tunnelsWaiting []int
- tunnelsActive := s.config.HAConnections
- backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
- var backoffTimer <-chan time.Time
- refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
- var refreshAuthBackoffTimer <-chan time.Time
- if s.config.UseReconnectToken {
- if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
- refreshAuthBackoffTimer = timer
- } else {
- logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
- refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
- }
- }
- for {
- select {
- // Context cancelled
- case <-ctx.Done():
- for tunnelsActive > 0 {
- <-s.tunnelErrors
- tunnelsActive--
- }
- return nil
- // startTunnel returned with error
- // (note that this may also be caused by context cancellation)
- case tunnelError := <-s.tunnelErrors:
- tunnelsActive--
- if tunnelError.err != nil {
- logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err)
- tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
- s.waitForNextTunnel(tunnelError.index)
- if backoffTimer == nil {
- backoffTimer = backoff.BackoffTimer()
- }
- // Previously we'd mark the edge address as bad here, but now we'll just silently use
- // another.
- }
- // Backoff was set and its timer expired
- case <-backoffTimer:
- backoffTimer = nil
- for _, index := range tunnelsWaiting {
- go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh)
- }
- tunnelsActive += len(tunnelsWaiting)
- tunnelsWaiting = nil
- // Time to call Authenticate
- case <-refreshAuthBackoffTimer:
- newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
- if err != nil {
- logger.Errorf("supervisor: Authentication failed: %s", err)
- // Permanent failure. Leave the `select` without setting the
- // channel to be non-null, so we'll never hit this case of the `select` again.
- continue
- }
- refreshAuthBackoffTimer = newTimer
- // Tunnel successfully connected
- case <-s.nextConnectedSignal:
- if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
- // No more tunnels outstanding, clear backoff timer
- backoff.SetGracePeriod()
- }
- // DNS resolution returned
- case result := <-s.resolverC:
- s.lastResolve = time.Now()
- s.resolverC = nil
- if result.err == nil {
- logger.Debug("supervisor: Service discovery refresh complete")
- } else {
- logger.Errorf("supervisor: Service discovery error: %s", result.err)
- }
- }
- }
- }
- // Returns nil if initialization succeeded, else the initialization error.
- func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
- logger := s.logger
- s.lastResolve = time.Now()
- availableAddrs := int(s.edgeIPs.AvailableAddrs())
- if s.config.HAConnections > availableAddrs {
- logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
- s.config.HAConnections = availableAddrs
- }
- go s.startFirstTunnel(ctx, connectedSignal, reconnectCh)
- select {
- case <-ctx.Done():
- <-s.tunnelErrors
- return ctx.Err()
- case tunnelError := <-s.tunnelErrors:
- return tunnelError.err
- case <-connectedSignal.Wait():
- }
- // At least one successful connection, so start the rest
- for i := 1; i < s.config.HAConnections; i++ {
- ch := signal.New(make(chan struct{}))
- go s.startTunnel(ctx, i, ch, reconnectCh)
- time.Sleep(registrationInterval)
- }
- return nil
- }
- // startTunnel starts the first tunnel connection. The resulting error will be sent on
- // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
- func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
- var (
- addr *net.TCPAddr
- err error
- )
- const firstConnIndex = 0
- defer func() {
- s.tunnelErrors <- tunnelError{index: firstConnIndex, addr: addr, err: err}
- }()
- addr, err = s.edgeIPs.GetAddr(firstConnIndex)
- if err != nil {
- return
- }
- err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
- // If the first tunnel disconnects, keep restarting it.
- edgeErrors := 0
- for s.unusedIPs() {
- if ctx.Err() != nil {
- return
- }
- switch err.(type) {
- case nil:
- return
- // try the next address if it was a dialError(network problem) or
- // dupConnRegisterTunnelError
- case connection.DialError, dupConnRegisterTunnelError:
- edgeErrors++
- default:
- return
- }
- if edgeErrors >= 2 {
- addr, err = s.edgeIPs.GetDifferentAddr(firstConnIndex)
- if err != nil {
- return
- }
- }
- err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
- }
- }
- // startTunnel starts a new tunnel connection. The resulting error will be sent on
- // s.tunnelErrors.
- func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
- var (
- addr *net.TCPAddr
- err error
- )
- defer func() {
- s.tunnelErrors <- tunnelError{index: index, addr: addr, err: err}
- }()
- addr, err = s.edgeIPs.GetDifferentAddr(index)
- if err != nil {
- return
- }
- err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
- }
- func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
- sig := make(chan struct{})
- s.tunnelsConnecting[index] = sig
- s.nextConnectedSignal = sig
- s.nextConnectedIndex = index
- return signal.New(sig)
- }
- func (s *Supervisor) waitForNextTunnel(index int) bool {
- delete(s.tunnelsConnecting, index)
- s.nextConnectedSignal = nil
- for k, v := range s.tunnelsConnecting {
- s.nextConnectedIndex = k
- s.nextConnectedSignal = v
- return true
- }
- return false
- }
- func (s *Supervisor) unusedIPs() bool {
- return s.edgeIPs.AvailableAddrs() > s.config.HAConnections
- }
- func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
- arbitraryEdgeIP, err := s.edgeIPs.GetAddrForRPC()
- if err != nil {
- return nil, err
- }
- edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP)
- if err != nil {
- return nil, err
- }
- defer edgeConn.Close()
- handler := h2mux.MuxedStreamFunc(func(*h2mux.MuxedStream) error {
- // This callback is invoked by h2mux when the edge initiates a stream.
- return nil // noop
- })
- muxerConfig := s.config.muxerConfig(handler)
- muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams)
- if err != nil {
- return nil, err
- }
- go muxer.Serve(ctx)
- defer func() {
- // If we don't wait for the muxer shutdown here, edgeConn.Close() runs before the muxer connections are done,
- // and the user sees log noise: "error writing data", "connection closed unexpectedly"
- <-muxer.Shutdown()
- }()
- tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger, openStreamTimeout)
- if err != nil {
- return nil, err
- }
- defer tunnelServer.Close()
- const arbitraryConnectionID = uint8(0)
- registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
- registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
- authResponse, err := tunnelServer.Authenticate(
- ctx,
- s.config.OriginCert,
- s.config.Hostname,
- registrationOptions,
- )
- if err != nil {
- return nil, err
- }
- return authResponse.Outcome(), nil
- }
|