connectionrpc.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. package pogs
  2. import (
  3. "context"
  4. "errors"
  5. "net"
  6. "time"
  7. "github.com/google/uuid"
  8. "zombiezen.com/go/capnproto2/pogs"
  9. "zombiezen.com/go/capnproto2/server"
  10. "github.com/cloudflare/cloudflared/tunnelrpc"
  11. )
  12. type RegistrationServer interface {
  13. RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error)
  14. UnregisterConnection(ctx context.Context)
  15. }
  16. type ClientInfo struct {
  17. ClientID []byte `capnp:"clientId"` // must be a slice for capnp compatibility
  18. Features []string
  19. Version string
  20. Arch string
  21. }
  22. type ConnectionOptions struct {
  23. Client ClientInfo
  24. OriginLocalIP net.IP `capnp:"originLocalIp"`
  25. ReplaceExisting bool
  26. CompressionQuality uint8
  27. NumPreviousAttempts uint8
  28. }
  29. type TunnelAuth struct {
  30. AccountTag string
  31. TunnelSecret []byte
  32. }
  33. func (p *ConnectionOptions) MarshalCapnproto(s tunnelrpc.ConnectionOptions) error {
  34. return pogs.Insert(tunnelrpc.ConnectionOptions_TypeID, s.Struct, p)
  35. }
  36. func (p *ConnectionOptions) UnmarshalCapnproto(s tunnelrpc.ConnectionOptions) error {
  37. return pogs.Extract(p, tunnelrpc.ConnectionOptions_TypeID, s.Struct)
  38. }
  39. func (a *TunnelAuth) MarshalCapnproto(s tunnelrpc.TunnelAuth) error {
  40. return pogs.Insert(tunnelrpc.TunnelAuth_TypeID, s.Struct, a)
  41. }
  42. func (a *TunnelAuth) UnmarshalCapnproto(s tunnelrpc.TunnelAuth) error {
  43. return pogs.Extract(a, tunnelrpc.TunnelAuth_TypeID, s.Struct)
  44. }
  45. type ConnectionDetails struct {
  46. UUID uuid.UUID
  47. Location string
  48. }
  49. func (details *ConnectionDetails) MarshalCapnproto(s tunnelrpc.ConnectionDetails) error {
  50. if err := s.SetUuid(details.UUID[:]); err != nil {
  51. return err
  52. }
  53. if err := s.SetLocationName(details.Location); err != nil {
  54. return err
  55. }
  56. return nil
  57. }
  58. func (details *ConnectionDetails) UnmarshalCapnproto(s tunnelrpc.ConnectionDetails) error {
  59. uuidBytes, err := s.Uuid()
  60. if err != nil {
  61. return err
  62. }
  63. details.UUID, err = uuid.FromBytes(uuidBytes)
  64. if err != nil {
  65. return err
  66. }
  67. details.Location, err = s.LocationName()
  68. if err != nil {
  69. return err
  70. }
  71. return err
  72. }
  73. func MarshalError(s tunnelrpc.ConnectionError, err error) error {
  74. if err := s.SetCause(err.Error()); err != nil {
  75. return err
  76. }
  77. if retryableErr, ok := err.(*RetryableError); ok {
  78. s.SetShouldRetry(true)
  79. s.SetRetryAfter(int64(retryableErr.Delay))
  80. }
  81. return nil
  82. }
  // RegisterConnection is the capnp server entry point for registerConnection.
  //
  // It decodes the auth, tunnel ID, connection index and options from the call
  // parameters, then passes them to the wrapped implementation (i.impl). The
  // outcome is encoded into the result union: a ConnectionError (via
  // MarshalError) when the implementation fails, otherwise the returned
  // ConnectionDetails.
  //
  // A non-nil return means the RPC itself could not be decoded or encoded, or
  // the implementation returned neither details nor an error.
  func (i TunnelServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer_registerConnection) error {
  	server.Ack(p.Options)

  	auth, err := p.Params.Auth()
  	if err != nil {
  		return err
  	}
  	var pogsAuth TunnelAuth
  	if err := pogsAuth.UnmarshalCapnproto(auth); err != nil {
  		return err
  	}

  	uuidBytes, err := p.Params.TunnelId()
  	if err != nil {
  		return err
  	}
  	tunnelID, err := uuid.FromBytes(uuidBytes)
  	if err != nil {
  		return err
  	}

  	connIndex := p.Params.ConnIndex()

  	options, err := p.Params.Options()
  	if err != nil {
  		return err
  	}
  	var pogsOptions ConnectionOptions
  	if err := pogsOptions.UnmarshalCapnproto(options); err != nil {
  		return err
  	}

  	connDetails, callError := i.impl.RegisterConnection(p.Ctx, pogsAuth, tunnelID, connIndex, &pogsOptions)

  	resp, err := p.Results.NewResult()
  	if err != nil {
  		return err
  	}

  	if callError != nil {
  		connError, err := resp.Result().NewError()
  		if err != nil {
  			return err
  		}
  		return MarshalError(connError, callError)
  	}

  	// An implementation returning (nil, nil) would otherwise make
  	// connDetails.MarshalCapnproto dereference a nil pointer and panic.
  	if connDetails == nil {
  		return errors.New("RegisterConnection returned neither connection details nor an error")
  	}

  	details, err := resp.Result().NewConnectionDetails()
  	if err != nil {
  		return err
  	}
  	return connDetails.MarshalCapnproto(details)
  }
  130. func (i TunnelServer_PogsImpl) UnregisterConnection(p tunnelrpc.RegistrationServer_unregisterConnection) error {
  131. server.Ack(p.Options)
  132. i.impl.UnregisterConnection(p.Ctx)
  133. return nil
  134. }
  // RegisterConnection sends a registerConnection RPC and decodes the reply.
  //
  // On success it returns the server's ConnectionDetails. When the server
  // reports a ConnectionError, its cause is returned as a plain error. That
  // error is wrapped with RetryErrorAfter when the server marked it retryable,
  // using the retryAfter value reinterpreted as a time.Duration. Transport and
  // decoding failures are wrapped with wrapRPCError. An unrecognised result
  // variant is reported via newRPCError.
  func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
  	client := tunnelrpc.TunnelServer{Client: c.Client}
  	promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
  		// NewAuth allocates the auth struct and links it into p, just as
  		// NewOptions does below, so no separate SetAuth call is needed.
  		tunnelAuth, err := p.NewAuth()
  		if err != nil {
  			return err
  		}
  		if err := auth.MarshalCapnproto(tunnelAuth); err != nil {
  			return err
  		}
  		if err := p.SetTunnelId(tunnelID[:]); err != nil {
  			return err
  		}
  		p.SetConnIndex(connIndex)
  		connectionOptions, err := p.NewOptions()
  		if err != nil {
  			return err
  		}
  		return options.MarshalCapnproto(connectionOptions)
  	})

  	response, err := promise.Result().Struct()
  	if err != nil {
  		return nil, wrapRPCError(err)
  	}
  	result := response.Result()
  	switch result.Which() {
  	case tunnelrpc.ConnectionResponse_result_Which_error:
  		resultError, err := result.Error()
  		if err != nil {
  			return nil, wrapRPCError(err)
  		}
  		cause, err := resultError.Cause()
  		if err != nil {
  			return nil, wrapRPCError(err)
  		}
  		err = errors.New(cause)
  		if resultError.ShouldRetry() {
  			err = RetryErrorAfter(err, time.Duration(resultError.RetryAfter()))
  		}
  		return nil, err
  	case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
  		connDetails, err := result.ConnectionDetails()
  		if err != nil {
  			return nil, wrapRPCError(err)
  		}
  		details := new(ConnectionDetails)
  		if err := details.UnmarshalCapnproto(connDetails); err != nil {
  			return nil, wrapRPCError(err)
  		}
  		return details, nil
  	}
  	return nil, newRPCError("unknown result which %d", result.Which())
  }
  // UnregisterConnection sends an unregisterConnection RPC and waits for its
  // reply. The request carries no parameters. Any failure is wrapped with
  // wrapRPCError; on success it returns nil.
  func (c TunnelServer_PogsClient) UnregisterConnection(ctx context.Context) error {
  	// The request has no fields to fill in.
  	noParams := func(tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
  		return nil
  	}
  	rpcClient := tunnelrpc.TunnelServer{Client: c.Client}
  	if _, err := rpcClient.UnregisterConnection(ctx, noParams).Struct(); err != nil {
  		return wrapRPCError(err)
  	}
  	return nil
  }