rpc.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. package connection
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "time"
  7. "github.com/cloudflare/cloudflared/tunnelrpc"
  8. tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  9. "github.com/rs/zerolog"
  10. "zombiezen.com/go/capnproto2/rpc"
  11. )
  12. type tunnelServerClient struct {
  13. client tunnelpogs.TunnelServer_PogsClient
  14. transport rpc.Transport
  15. }
  16. // NewTunnelRPCClient creates and returns a new RPC client, which will communicate using a stream on the given muxer.
  17. // This method is exported for supervisor to call Authenticate RPC
  18. func NewTunnelServerClient(
  19. ctx context.Context,
  20. stream io.ReadWriteCloser,
  21. log *zerolog.Logger,
  22. ) *tunnelServerClient {
  23. transport := tunnelrpc.NewTransportLogger(log, rpc.StreamTransport(stream))
  24. conn := rpc.NewConn(
  25. transport,
  26. tunnelrpc.ConnLog(log),
  27. )
  28. registrationClient := tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
  29. return &tunnelServerClient{
  30. client: tunnelpogs.TunnelServer_PogsClient{RegistrationServer_PogsClient: registrationClient, Client: conn.Bootstrap(ctx), Conn: conn},
  31. transport: transport,
  32. }
  33. }
  34. func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) {
  35. authResp, err := tsc.client.Authenticate(ctx, classicTunnel.OriginCert, classicTunnel.Hostname, registrationOptions)
  36. if err != nil {
  37. return nil, err
  38. }
  39. return authResp.Outcome(), nil
  40. }
  41. func (tsc *tunnelServerClient) Close() {
  42. // Closing the client will also close the connection
  43. _ = tsc.client.Close()
  44. _ = tsc.transport.Close()
  45. }
  46. type NamedTunnelRPCClient interface {
  47. RegisterConnection(
  48. c context.Context,
  49. config *NamedTunnelConfig,
  50. options *tunnelpogs.ConnectionOptions,
  51. connIndex uint8,
  52. observer *Observer,
  53. ) error
  54. GracefulShutdown(ctx context.Context, gracePeriod time.Duration)
  55. Close()
  56. }
  57. type registrationServerClient struct {
  58. client tunnelpogs.RegistrationServer_PogsClient
  59. transport rpc.Transport
  60. }
  61. func newRegistrationRPCClient(
  62. ctx context.Context,
  63. stream io.ReadWriteCloser,
  64. log *zerolog.Logger,
  65. ) NamedTunnelRPCClient {
  66. transport := tunnelrpc.NewTransportLogger(log, rpc.StreamTransport(stream))
  67. conn := rpc.NewConn(
  68. transport,
  69. tunnelrpc.ConnLog(log),
  70. )
  71. return &registrationServerClient{
  72. client: tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
  73. transport: transport,
  74. }
  75. }
  76. func (rsc *registrationServerClient) RegisterConnection(
  77. ctx context.Context,
  78. config *NamedTunnelConfig,
  79. options *tunnelpogs.ConnectionOptions,
  80. connIndex uint8,
  81. observer *Observer,
  82. ) error {
  83. conn, err := rsc.client.RegisterConnection(
  84. ctx,
  85. config.Credentials.Auth(),
  86. config.Credentials.TunnelID,
  87. connIndex,
  88. options,
  89. )
  90. if err != nil {
  91. if err.Error() == DuplicateConnectionError {
  92. observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
  93. return errDuplicationConnection
  94. }
  95. observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc()
  96. return serverRegistrationErrorFromRPC(err)
  97. }
  98. observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
  99. observer.logServerInfo(connIndex, conn.Location, fmt.Sprintf("Connection %s registered", conn.UUID))
  100. observer.sendConnectedEvent(connIndex, conn.Location)
  101. return nil
  102. }
  103. func (rsc *registrationServerClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
  104. ctx, cancel := context.WithTimeout(ctx, gracePeriod)
  105. defer cancel()
  106. _ = rsc.client.UnregisterConnection(ctx)
  107. }
  108. func (rsc *registrationServerClient) Close() {
  109. // Closing the client will also close the connection
  110. _ = rsc.client.Close()
  111. // Closing the transport also closes the stream
  112. _ = rsc.transport.Close()
  113. }
  // rpcName names the RPC an operation belongs to. It is passed to newRPCStream
  // and, converted to string, used verbatim as a metric label value when
  // registration succeeds or fails.
  type rpcName string

  // Values must not carry surrounding whitespace: they end up unmodified in
  // metric labels, where " authenticate" and "authenticate" would be distinct.
  const (
  	register     rpcName = "register"
  	reconnect    rpcName = "reconnect"
  	unregister   rpcName = "unregister"
  	authenticate rpcName = "authenticate"
  )
  121. func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
  122. h.observer.sendRegisteringEvent(registrationOptions.ConnectionID)
  123. stream, err := h.newRPCStream(ctx, register)
  124. if err != nil {
  125. return err
  126. }
  127. rpcClient := NewTunnelServerClient(ctx, stream, h.observer.log)
  128. defer rpcClient.Close()
  129. _ = h.logServerInfo(ctx, rpcClient)
  130. registration := rpcClient.client.RegisterTunnel(
  131. ctx,
  132. classicTunnel.OriginCert,
  133. classicTunnel.Hostname,
  134. registrationOptions,
  135. )
  136. if registrationErr := registration.DeserializeError(); registrationErr != nil {
  137. // RegisterTunnel RPC failure
  138. return h.processRegisterTunnelError(registrationErr, register)
  139. }
  140. // Send free tunnel URL to UI
  141. h.observer.sendURL(registration.Url)
  142. credentialSetter.SetEventDigest(h.connIndex, registration.EventDigest)
  143. return h.processRegistrationSuccess(registration, register, credentialSetter, classicTunnel)
  144. }
  // CredentialManager holds the credentials that classic-tunnel registration
  // produces and reconnection consumes. Per-connection values are keyed by the
  // connection's uint8 index.
  type CredentialManager interface {
  	// ReconnectToken returns the token presented to ReconnectTunnel.
  	ReconnectToken() ([]byte, error)
  	// EventDigest returns the event digest stored for connection connID.
  	EventDigest(connID uint8) ([]byte, error)
  	// SetEventDigest stores the event digest for connection connID.
  	SetEventDigest(connID uint8, digest []byte)
  	// ConnDigest returns the connection digest stored for connection connID.
  	ConnDigest(connID uint8) ([]byte, error)
  	// SetConnDigest stores the connection digest for connection connID.
  	SetConnDigest(connID uint8, digest []byte)
  }
  152. func (h *h2muxConnection) processRegistrationSuccess(
  153. registration *tunnelpogs.TunnelRegistration,
  154. name rpcName,
  155. credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig,
  156. ) error {
  157. for _, logLine := range registration.LogLines {
  158. h.observer.log.Info().Msg(logLine)
  159. }
  160. if registration.TunnelID != "" {
  161. h.observer.metrics.tunnelsHA.AddTunnelID(h.connIndex, registration.TunnelID)
  162. h.observer.log.Info().Msgf("Each HA connection's tunnel IDs: %v", h.observer.metrics.tunnelsHA.String())
  163. }
  164. // Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set)
  165. if classicTunnel.IsTrialZone() {
  166. err := h.observer.logTrialHostname(registration)
  167. if err != nil {
  168. return err
  169. }
  170. }
  171. credentialManager.SetConnDigest(h.connIndex, registration.ConnDigest)
  172. h.observer.metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
  173. h.observer.log.Info().Msgf("Route propagating, it may take up to 1 minute for your new route to become functional")
  174. h.observer.metrics.regSuccess.WithLabelValues(string(name)).Inc()
  175. return nil
  176. }
  177. func (h *h2muxConnection) processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, name rpcName) error {
  178. if err.Error() == DuplicateConnectionError {
  179. h.observer.metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
  180. return errDuplicationConnection
  181. }
  182. h.observer.metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
  183. return ServerRegisterTunnelError{
  184. Cause: err,
  185. Permanent: err.IsPermanent(),
  186. }
  187. }
  188. func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
  189. token, err := credentialManager.ReconnectToken()
  190. if err != nil {
  191. return err
  192. }
  193. eventDigest, err := credentialManager.EventDigest(h.connIndex)
  194. if err != nil {
  195. return err
  196. }
  197. connDigest, err := credentialManager.ConnDigest(h.connIndex)
  198. if err != nil {
  199. return err
  200. }
  201. h.observer.log.Debug().Msg("initiating RPC stream to reconnect")
  202. stream, err := h.newRPCStream(ctx, register)
  203. if err != nil {
  204. return err
  205. }
  206. rpcClient := NewTunnelServerClient(ctx, stream, h.observer.log)
  207. defer rpcClient.Close()
  208. _ = h.logServerInfo(ctx, rpcClient)
  209. registration := rpcClient.client.ReconnectTunnel(
  210. ctx,
  211. token,
  212. eventDigest,
  213. connDigest,
  214. classicTunnel.Hostname,
  215. registrationOptions,
  216. )
  217. if registrationErr := registration.DeserializeError(); registrationErr != nil {
  218. // ReconnectTunnel RPC failure
  219. return h.processRegisterTunnelError(registrationErr, reconnect)
  220. }
  221. return h.processRegistrationSuccess(registration, reconnect, credentialManager, classicTunnel)
  222. }
  223. func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelServerClient) error {
  224. // Request server info without blocking tunnel registration; must use capnp library directly.
  225. serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.client.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
  226. return nil
  227. })
  228. serverInfoMessage, err := serverInfoPromise.Result().Struct()
  229. if err != nil {
  230. h.observer.log.Err(err).Msg("Failed to retrieve server information")
  231. return err
  232. }
  233. serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
  234. if err != nil {
  235. h.observer.log.Err(err).Msg("Failed to retrieve server information")
  236. return err
  237. }
  238. h.observer.logServerInfo(h.connIndex, serverInfo.LocationName, "Connection established")
  239. return nil
  240. }
  241. func (h *h2muxConnection) registerNamedTunnel(
  242. ctx context.Context,
  243. namedTunnel *NamedTunnelConfig,
  244. connOptions *tunnelpogs.ConnectionOptions,
  245. ) error {
  246. stream, err := h.newRPCStream(ctx, register)
  247. if err != nil {
  248. return err
  249. }
  250. rpcClient := h.newRPCClientFunc(ctx, stream, h.observer.log)
  251. defer rpcClient.Close()
  252. if err = rpcClient.RegisterConnection(ctx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
  253. return err
  254. }
  255. return nil
  256. }
  257. func (h *h2muxConnection) unregister(isNamedTunnel bool) {
  258. h.observer.sendUnregisteringEvent(h.connIndex)
  259. unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod)
  260. defer cancel()
  261. stream, err := h.newRPCStream(unregisterCtx, unregister)
  262. if err != nil {
  263. return
  264. }
  265. defer stream.Close()
  266. if isNamedTunnel {
  267. rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log)
  268. defer rpcClient.Close()
  269. rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod)
  270. } else {
  271. rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer.log)
  272. defer rpcClient.Close()
  273. // gracePeriod is encoded in int64 using capnproto
  274. _ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds())
  275. }
  276. h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection")
  277. }