control.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. package connection
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "time"
  7. "github.com/pkg/errors"
  8. "github.com/cloudflare/cloudflared/management"
  9. "github.com/cloudflare/cloudflared/tunnelrpc"
  10. tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  11. )
  12. // registerClient derives a named tunnel rpc client that can then be used to register and unregister connections.
  13. type registerClientFunc func(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient
  14. type controlStream struct {
  15. observer *Observer
  16. connectedFuse ConnectedFuse
  17. tunnelProperties *TunnelProperties
  18. connIndex uint8
  19. edgeAddress net.IP
  20. protocol Protocol
  21. registerClientFunc registerClientFunc
  22. registerTimeout time.Duration
  23. gracefulShutdownC <-chan struct{}
  24. gracePeriod time.Duration
  25. stoppedGracefully bool
  26. }
  27. // ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown.
  28. type ControlStreamHandler interface {
  29. // ServeControlStream handles the control plane of the transport in the current goroutine calling this
  30. ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, tunnelConfigGetter TunnelConfigJSONGetter) error
  31. // IsStopped tells whether the method above has finished
  32. IsStopped() bool
  33. }
  34. type TunnelConfigJSONGetter interface {
  35. GetConfigJSON() ([]byte, error)
  36. }
  37. // NewControlStream returns a new instance of ControlStreamHandler
  38. func NewControlStream(
  39. observer *Observer,
  40. connectedFuse ConnectedFuse,
  41. tunnelProperties *TunnelProperties,
  42. connIndex uint8,
  43. edgeAddress net.IP,
  44. registerClientFunc registerClientFunc,
  45. registerTimeout time.Duration,
  46. gracefulShutdownC <-chan struct{},
  47. gracePeriod time.Duration,
  48. protocol Protocol,
  49. ) ControlStreamHandler {
  50. if registerClientFunc == nil {
  51. registerClientFunc = tunnelrpc.NewRegistrationClient
  52. }
  53. return &controlStream{
  54. observer: observer,
  55. connectedFuse: connectedFuse,
  56. tunnelProperties: tunnelProperties,
  57. registerClientFunc: registerClientFunc,
  58. registerTimeout: registerTimeout,
  59. connIndex: connIndex,
  60. edgeAddress: edgeAddress,
  61. gracefulShutdownC: gracefulShutdownC,
  62. gracePeriod: gracePeriod,
  63. protocol: protocol,
  64. }
  65. }
  66. func (c *controlStream) ServeControlStream(
  67. ctx context.Context,
  68. rw io.ReadWriteCloser,
  69. connOptions *tunnelpogs.ConnectionOptions,
  70. tunnelConfigGetter TunnelConfigJSONGetter,
  71. ) error {
  72. registrationClient := c.registerClientFunc(ctx, rw, c.registerTimeout)
  73. registrationDetails, err := registrationClient.RegisterConnection(
  74. ctx,
  75. c.tunnelProperties.Credentials.Auth(),
  76. c.tunnelProperties.Credentials.TunnelID,
  77. connOptions,
  78. c.connIndex,
  79. c.edgeAddress)
  80. if err != nil {
  81. defer registrationClient.Close()
  82. if err.Error() == DuplicateConnectionError {
  83. c.observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
  84. return errDuplicationConnection
  85. }
  86. c.observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc()
  87. return serverRegistrationErrorFromRPC(err)
  88. }
  89. c.observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
  90. c.observer.logConnected(registrationDetails.UUID, c.connIndex, registrationDetails.Location, c.edgeAddress, c.protocol)
  91. c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location, c.edgeAddress)
  92. c.connectedFuse.Connected()
  93. // if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration
  94. if c.connIndex == 0 && !registrationDetails.TunnelIsRemotelyManaged {
  95. if tunnelConfig, err := tunnelConfigGetter.GetConfigJSON(); err == nil {
  96. if err := registrationClient.SendLocalConfiguration(ctx, tunnelConfig); err != nil {
  97. c.observer.metrics.localConfigMetrics.pushesErrors.Inc()
  98. c.observer.log.Err(err).Msg("unable to send local configuration")
  99. }
  100. c.observer.metrics.localConfigMetrics.pushes.Inc()
  101. } else {
  102. c.observer.log.Err(err).Msg("failed to obtain current configuration")
  103. }
  104. }
  105. return c.waitForUnregister(ctx, registrationClient)
  106. }
  107. func (c *controlStream) waitForUnregister(ctx context.Context, registrationClient tunnelrpc.RegistrationClient) error {
  108. // wait for connection termination or start of graceful shutdown
  109. defer registrationClient.Close()
  110. var shutdownError error
  111. select {
  112. case <-ctx.Done():
  113. shutdownError = ctx.Err()
  114. break
  115. case <-c.gracefulShutdownC:
  116. c.stoppedGracefully = true
  117. }
  118. c.observer.sendUnregisteringEvent(c.connIndex)
  119. err := registrationClient.GracefulShutdown(ctx, c.gracePeriod)
  120. if err != nil {
  121. return errors.Wrap(err, "Error shutting down control stream")
  122. }
  123. c.observer.log.Info().
  124. Int(management.EventTypeKey, int(management.Cloudflared)).
  125. Uint8(LogFieldConnIndex, c.connIndex).
  126. IPAddr(LogFieldIPAddress, c.edgeAddress).
  127. Msg("Unregistered tunnel connection")
  128. return shutdownError
  129. }
  130. func (c *controlStream) IsStopped() bool {
  131. return c.stoppedGracefully
  132. }