123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238 |
- package pogs
- import (
- "context"
- "errors"
- "net"
- "time"
- "github.com/google/uuid"
- "zombiezen.com/go/capnproto2/pogs"
- "zombiezen.com/go/capnproto2/server"
- "github.com/cloudflare/cloudflared/tunnelrpc"
- )
- type RegistrationServer interface {
- RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error)
- UnregisterConnection(ctx context.Context)
- }
type (
	// ClientInfo describes the registering client: an opaque ID, the
	// feature names it supports, its version string and its architecture.
	// It travels inside ConnectionOptions as the Client field.
	ClientInfo struct {
		ClientID []byte `capnp:"clientId"` // must be a slice for capnp compatibility
		Features []string
		Version  string
		Arch     string
	}

	// ConnectionOptions is the Go-side mirror of tunnelrpc.ConnectionOptions.
	// It is converted with pogs.Insert/pogs.Extract, which map fields by name
	// (or by the capnp struct tag when one is present).
	ConnectionOptions struct {
		Client              ClientInfo
		OriginLocalIP       net.IP `capnp:"originLocalIp"`
		ReplaceExisting     bool
		CompressionQuality  uint8
		NumPreviousAttempts uint8
	}

	// TunnelAuth holds the credentials sent with RegisterConnection: the
	// account tag and the tunnel secret.
	TunnelAuth struct {
		AccountTag   string
		TunnelSecret []byte
	}
)
- func (p *ConnectionOptions) MarshalCapnproto(s tunnelrpc.ConnectionOptions) error {
- return pogs.Insert(tunnelrpc.ConnectionOptions_TypeID, s.Struct, p)
- }
- func (p *ConnectionOptions) UnmarshalCapnproto(s tunnelrpc.ConnectionOptions) error {
- return pogs.Extract(p, tunnelrpc.ConnectionOptions_TypeID, s.Struct)
- }
- func (a *TunnelAuth) MarshalCapnproto(s tunnelrpc.TunnelAuth) error {
- return pogs.Insert(tunnelrpc.TunnelAuth_TypeID, s.Struct, a)
- }
- func (a *TunnelAuth) UnmarshalCapnproto(s tunnelrpc.TunnelAuth) error {
- return pogs.Extract(a, tunnelrpc.TunnelAuth_TypeID, s.Struct)
- }
- type ConnectionDetails struct {
- UUID uuid.UUID
- Location string
- }
- func (details *ConnectionDetails) MarshalCapnproto(s tunnelrpc.ConnectionDetails) error {
- if err := s.SetUuid(details.UUID[:]); err != nil {
- return err
- }
- if err := s.SetLocationName(details.Location); err != nil {
- return err
- }
- return nil
- }
- func (details *ConnectionDetails) UnmarshalCapnproto(s tunnelrpc.ConnectionDetails) error {
- uuidBytes, err := s.Uuid()
- if err != nil {
- return err
- }
- details.UUID, err = uuid.FromBytes(uuidBytes)
- if err != nil {
- return err
- }
- details.Location, err = s.LocationName()
- if err != nil {
- return err
- }
- return err
- }
- func MarshalError(s tunnelrpc.ConnectionError, err error) error {
- if err := s.SetCause(err.Error()); err != nil {
- return err
- }
- if retryableErr, ok := err.(*RetryableError); ok {
- s.SetShouldRetry(true)
- s.SetRetryAfter(int64(retryableErr.Delay))
- }
- return nil
- }
- func (i TunnelServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer_registerConnection) error {
- server.Ack(p.Options)
- auth, err := p.Params.Auth()
- if err != nil {
- return err
- }
- var pogsAuth TunnelAuth
- err = pogsAuth.UnmarshalCapnproto(auth)
- if err != nil {
- return err
- }
- uuidBytes, err := p.Params.TunnelId()
- if err != nil {
- return err
- }
- tunnelID, err := uuid.FromBytes(uuidBytes)
- if err != nil {
- return err
- }
- connIndex := p.Params.ConnIndex()
- options, err := p.Params.Options()
- if err != nil {
- return err
- }
- var pogsOptions ConnectionOptions
- err = pogsOptions.UnmarshalCapnproto(options)
- if err != nil {
- return err
- }
- connDetails, callError := i.impl.RegisterConnection(p.Ctx, pogsAuth, tunnelID, connIndex, &pogsOptions)
- resp, err := p.Results.NewResult()
- if err != nil {
- return err
- }
- if callError != nil {
- if connError, err := resp.Result().NewError(); err != nil {
- return err
- } else {
- return MarshalError(connError, callError)
- }
- }
- if details, err := resp.Result().NewConnectionDetails(); err != nil {
- return err
- } else {
- return connDetails.MarshalCapnproto(details)
- }
- }
- func (i TunnelServer_PogsImpl) UnregisterConnection(p tunnelrpc.RegistrationServer_unregisterConnection) error {
- server.Ack(p.Options)
- i.impl.UnregisterConnection(p.Ctx)
- return nil
- }
- func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
- client := tunnelrpc.TunnelServer{Client: c.Client}
- promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
- tunnelAuth, err := p.NewAuth()
- if err != nil {
- return err
- }
- if err = auth.MarshalCapnproto(tunnelAuth); err != nil {
- return err
- }
- err = p.SetAuth(tunnelAuth)
- if err != nil {
- return err
- }
- err = p.SetTunnelId(tunnelID[:])
- if err != nil {
- return err
- }
- p.SetConnIndex(connIndex)
- connectionOptions, err := p.NewOptions()
- if err != nil {
- return err
- }
- err = options.MarshalCapnproto(connectionOptions)
- if err != nil {
- return err
- }
- return nil
- })
- response, err := promise.Result().Struct()
- if err != nil {
- return nil, wrapRPCError(err)
- }
- result := response.Result()
- switch result.Which() {
- case tunnelrpc.ConnectionResponse_result_Which_error:
- resultError, err := result.Error()
- if err != nil {
- return nil, wrapRPCError(err)
- }
- cause, err := resultError.Cause()
- if err != nil {
- return nil, wrapRPCError(err)
- }
- err = errors.New(cause)
- if resultError.ShouldRetry() {
- err = RetryErrorAfter(err, time.Duration(resultError.RetryAfter()))
- }
- return nil, err
- case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
- connDetails, err := result.ConnectionDetails()
- if err != nil {
- return nil, wrapRPCError(err)
- }
- details := new(ConnectionDetails)
- if err = details.UnmarshalCapnproto(connDetails); err != nil {
- return nil, wrapRPCError(err)
- }
- return details, nil
- }
- return nil, newRPCError("unknown result which %d", result.Which())
- }
- func (c TunnelServer_PogsClient) UnregisterConnection(ctx context.Context) error {
- client := tunnelrpc.TunnelServer{Client: c.Client}
- promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
- return nil
- })
- _, err := promise.Struct()
- if err != nil {
- return wrapRPCError(err)
- }
- return nil
- }
|