123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261 |
- package connection
- import (
- "context"
- "io"
- "net"
- "net/http"
- "time"
- "github.com/cloudflare/cloudflared/h2mux"
- tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
- "github.com/cloudflare/cloudflared/websocket"
- "github.com/pkg/errors"
- "github.com/rs/zerolog"
- "golang.org/x/sync/errgroup"
- )
// Timeouts applied to the h2mux connection with the edge.
const (
	// openStreamTimeout bounds how long newRPCStream waits to open an RPC stream.
	openStreamTimeout = 30 * time.Second
	// muxerTimeout is used as the Timeout of every h2mux.MuxerConfig built by H2MuxerConfig.
	muxerTimeout = 5 * time.Second
)
// h2muxConnection serves a tunnel to the edge over an h2mux multiplexed
// connection. Instances are built by NewH2muxConnection.
type h2muxConnection struct {
	config      *Config
	muxerConfig *MuxerConfig
	// muxer is the multiplexer returned by the handshake in NewH2muxConnection.
	muxer *h2mux.Muxer
	// connIndexStr is connIndex rendered as a string; it is only used by
	// metrics, and prometheus requires labels to be strings.
	connIndexStr string
	connIndex    uint8
	observer     *Observer
	// gracefulShutdownC signals a graceful shutdown. controlLoop sets it to
	// nil after handling the first signal so the select case does not fire again.
	gracefulShutdownC <-chan struct{}
	// stoppedGracefully is set by controlLoop when it starts a graceful
	// shutdown; the Serve*Tunnel methods then treat errMuxerStopped as a clean exit.
	stoppedGracefully bool
	// newRPCClientFunc allows us to mock RPCs during testing
	newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
}
// MuxerConfig holds the h2mux tuning options for an edge connection.
// H2MuxerConfig converts it into an h2mux.MuxerConfig.
type MuxerConfig struct {
	// HeartbeatInterval is copied into h2mux.MuxerConfig.HeartbeatInterval.
	HeartbeatInterval time.Duration
	// MaxHeartbeats is copied into h2mux.MuxerConfig.MaxHeartbeats.
	MaxHeartbeats uint64
	// CompressionSetting is passed through as h2mux.MuxerConfig.CompressionQuality.
	CompressionSetting h2mux.CompressionSetting
	// MetricsUpdateFreq is how often controlLoop pushes muxer metrics to the
	// observer.
	MetricsUpdateFreq time.Duration
}
- func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Logger) *h2mux.MuxerConfig {
- return &h2mux.MuxerConfig{
- Timeout: muxerTimeout,
- Handler: h,
- IsClient: true,
- HeartbeatInterval: mc.HeartbeatInterval,
- MaxHeartbeats: mc.MaxHeartbeats,
- Log: log,
- CompressionQuality: mc.CompressionSetting,
- }
- }
- // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
- func NewH2muxConnection(
- config *Config,
- muxerConfig *MuxerConfig,
- edgeConn net.Conn,
- connIndex uint8,
- observer *Observer,
- gracefulShutdownC <-chan struct{},
- ) (*h2muxConnection, error, bool) {
- h := &h2muxConnection{
- config: config,
- muxerConfig: muxerConfig,
- connIndexStr: uint8ToString(connIndex),
- connIndex: connIndex,
- observer: observer,
- gracefulShutdownC: gracefulShutdownC,
- newRPCClientFunc: newRegistrationRPCClient,
- }
- // Establish a muxed connection with the edge
- // Client mux handshake with agent server
- muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig.H2MuxerConfig(h, observer.logTransport), h2mux.ActiveStreams)
- if err != nil {
- recoverable := isHandshakeErrRecoverable(err, connIndex, observer)
- return nil, err, recoverable
- }
- h.muxer = muxer
- return h, nil, false
- }
- func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
- errGroup, serveCtx := errgroup.WithContext(ctx)
- errGroup.Go(func() error {
- return h.serveMuxer(serveCtx)
- })
- errGroup.Go(func() error {
- if err := h.registerNamedTunnel(serveCtx, namedTunnel, connOptions); err != nil {
- return err
- }
- connectedFuse.Connected()
- return nil
- })
- errGroup.Go(func() error {
- h.controlLoop(serveCtx, connectedFuse, true)
- return nil
- })
- err := errGroup.Wait()
- if err == errMuxerStopped {
- if h.stoppedGracefully {
- return nil
- }
- h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown")
- }
- return err
- }
- func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
- errGroup, serveCtx := errgroup.WithContext(ctx)
- errGroup.Go(func() error {
- return h.serveMuxer(serveCtx)
- })
- errGroup.Go(func() (err error) {
- defer func() {
- if err == nil {
- connectedFuse.Connected()
- }
- }()
- if classicTunnel.UseReconnectToken && connectedFuse.IsConnected() {
- err := h.reconnectTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
- if err == nil {
- return nil
- }
- // log errors and proceed to RegisterTunnel
- h.observer.log.Err(err).
- Uint8(LogFieldConnIndex, h.connIndex).
- Msg("Couldn't reconnect connection. Re-registering it instead.")
- }
- return h.registerTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
- })
- errGroup.Go(func() error {
- h.controlLoop(serveCtx, connectedFuse, false)
- return nil
- })
- err := errGroup.Wait()
- if err == errMuxerStopped {
- if h.stoppedGracefully {
- return nil
- }
- h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown")
- }
- return err
- }
- func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
- // All routines should stop when muxer finish serving. When muxer is shutdown
- // gracefully, it doesn't return an error, so we need to return errMuxerShutdown
- // here to notify other routines to stop
- err := h.muxer.Serve(ctx)
- if err == nil {
- return errMuxerStopped
- }
- return err
- }
- func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) {
- updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq)
- var shutdownCompleted <-chan struct{}
- for {
- select {
- case <-h.gracefulShutdownC:
- if connectedFuse.IsConnected() {
- h.unregister(isNamedTunnel)
- }
- h.stoppedGracefully = true
- h.gracefulShutdownC = nil
- shutdownCompleted = h.muxer.Shutdown()
- case <-shutdownCompleted:
- return
- case <-ctx.Done():
- // UnregisterTunnel blocks until the RPC call returns
- if !h.stoppedGracefully && connectedFuse.IsConnected() {
- h.unregister(isNamedTunnel)
- }
- h.muxer.Shutdown()
- // don't wait for shutdown to finish when context is closed, this is the hard termination path
- return
- case <-updateMetricsTickC:
- h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
- }
- }
- }
- func (h *h2muxConnection) newRPCStream(ctx context.Context, rpcName rpcName) (*h2mux.MuxedStream, error) {
- openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
- defer openStreamCancel()
- stream, err := h.muxer.OpenRPCStream(openStreamCtx)
- if err != nil {
- return nil, err
- }
- return stream, nil
- }
- func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
- respWriter := &h2muxRespWriter{stream}
- req, reqErr := h.newRequest(stream)
- if reqErr != nil {
- respWriter.WriteErrorResponse()
- return reqErr
- }
- var sourceConnectionType = TypeHTTP
- if websocket.IsWebSocketUpgrade(req) {
- sourceConnectionType = TypeWebsocket
- }
- err := h.config.OriginProxy.Proxy(respWriter, req, sourceConnectionType)
- if err != nil {
- respWriter.WriteErrorResponse()
- return err
- }
- return nil
- }
- func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
- req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream})
- if err != nil {
- return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
- }
- err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
- if err != nil {
- return nil, errors.Wrap(err, "invalid request received")
- }
- return req, nil
- }
- type h2muxRespWriter struct {
- *h2mux.MuxedStream
- }
- func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
- headers := h2mux.H1ResponseToH2ResponseHeaders(status, header)
- headers = append(headers, h2mux.Header{Name: ResponseMetaHeaderField, Value: responseMetaHeaderOrigin})
- return rp.WriteHeaders(headers)
- }
- func (rp *h2muxRespWriter) WriteErrorResponse() {
- _ = rp.WriteHeaders([]h2mux.Header{
- {Name: ":status", Value: "502"},
- {Name: ResponseMetaHeaderField, Value: responseMetaHeaderCfd},
- })
- _, _ = rp.Write([]byte("502 Bad Gateway"))
- }
|