123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514 |
- package h2mux
import (
	"context"
	"errors"
	"io"
	"strings"
	"sync"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/rs/zerolog"
	"golang.org/x/net/http2"
	"golang.org/x/net/http2/hpack"
	"golang.org/x/sync/errgroup"
)
- const (
- defaultFrameSize uint32 = 1 << 14 // Minimum frame size in http2 spec
- defaultWindowSize uint32 = (1 << 16) - 1 // Minimum window size in http2 spec
- maxWindowSize uint32 = (1 << 31) - 1 // 2^31-1 = 2147483647, max window size in http2 spec
- defaultTimeout time.Duration = 5 * time.Second
- defaultRetries uint64 = 5
- defaultWriteBufferMaxLen int = 1024 * 1024 // 1mb
- writeBufferInitialSize int = 16 * 1024 // 16KB
- SettingMuxerMagic http2.SettingID = 0x42db
- MuxerMagicOrigin uint32 = 0xa2e43c8b
- MuxerMagicEdge uint32 = 0x1088ebf9
- )
- type MuxedStreamHandler interface {
- ServeStream(*MuxedStream) error
- }
- type MuxedStreamFunc func(stream *MuxedStream) error
- func (f MuxedStreamFunc) ServeStream(stream *MuxedStream) error {
- return f(stream)
- }
- type MuxerConfig struct {
- Timeout time.Duration
- Handler MuxedStreamHandler
- IsClient bool
- // Name is used to identify this muxer instance when logging.
- Name string
- // The minimum time this connection can be idle before sending a heartbeat.
- HeartbeatInterval time.Duration
- // The minimum number of heartbeats to send before terminating the connection.
- MaxHeartbeats uint64
- // Logger to use
- Log *zerolog.Logger
- CompressionQuality CompressionSetting
- // Initial size for HTTP2 flow control windows
- DefaultWindowSize uint32
- // Largest allowable size for HTTP2 flow control windows
- MaxWindowSize uint32
- // Largest allowable capacity for the buffer of data to be sent
- StreamWriteBufferMaxLen int
- }
- type Muxer struct {
- // f is used to read and write HTTP2 frames on the wire.
- f *http2.Framer
- // config is the MuxerConfig given in Handshake.
- config MuxerConfig
- // w, r are references to the underlying connection used.
- w io.WriteCloser
- r io.ReadCloser
- // muxReader is the read process.
- muxReader *MuxReader
- // muxWriter is the write process.
- muxWriter *MuxWriter
- // muxMetricsUpdater is the process to update metrics
- muxMetricsUpdater muxMetricsUpdater
- // newStreamChan is used to create new streams on the writer thread.
- // The writer will assign the next available stream ID.
- newStreamChan chan MuxedStreamRequest
- // abortChan is used to abort the writer event loop.
- abortChan chan struct{}
- // abortOnce is used to ensure abortChan is closed once only.
- abortOnce sync.Once
- // readyList is used to signal writable streams.
- readyList *ReadyList
- // streams tracks currently-open streams.
- streams *activeStreamMap
- // explicitShutdown records whether the Muxer is closing because Shutdown was called, or due to another
- // error.
- explicitShutdown *BooleanFuse
- compressionQuality CompressionPreset
- }
- func RPCHeaders() []Header {
- return []Header{
- {Name: ":method", Value: "RPC"},
- {Name: ":scheme", Value: "capnp"},
- {Name: ":path", Value: "*"},
- }
- }
- // Handshake establishes a muxed connection with the peer.
- // After the handshake completes, it is possible to open and accept streams.
- func Handshake(
- w io.WriteCloser,
- r io.ReadCloser,
- config MuxerConfig,
- activeStreamsMetrics prometheus.Gauge,
- ) (*Muxer, error) {
- // Set default config values
- if config.Timeout == 0 {
- config.Timeout = defaultTimeout
- }
- if config.DefaultWindowSize == 0 {
- config.DefaultWindowSize = defaultWindowSize
- }
- if config.MaxWindowSize == 0 {
- config.MaxWindowSize = maxWindowSize
- }
- if config.StreamWriteBufferMaxLen == 0 {
- config.StreamWriteBufferMaxLen = defaultWriteBufferMaxLen
- }
- // Initialise connection state fields
- m := &Muxer{
- f: http2.NewFramer(w, r), // A framer that writes to w and reads from r
- config: config,
- w: w,
- r: r,
- newStreamChan: make(chan MuxedStreamRequest),
- abortChan: make(chan struct{}),
- readyList: NewReadyList(),
- streams: newActiveStreamMap(config.IsClient, activeStreamsMetrics),
- }
- m.f.ReadMetaHeaders = hpack.NewDecoder(4096, func(hpack.HeaderField) {})
- // Initialise the settings to identify this connection and confirm the other end is sane.
- handshakeSetting := http2.Setting{ID: SettingMuxerMagic, Val: MuxerMagicEdge}
- compressionSetting := http2.Setting{ID: SettingCompression, Val: config.CompressionQuality.toH2Setting()}
- if CompressionIsSupported() {
- config.Log.Debug().Msg("muxer: Compression is supported")
- m.compressionQuality = config.CompressionQuality.getPreset()
- } else {
- config.Log.Debug().Msg("muxer: Compression is not supported")
- compressionSetting = http2.Setting{ID: SettingCompression, Val: 0}
- }
- expectedMagic := MuxerMagicOrigin
- if config.IsClient {
- handshakeSetting.Val = MuxerMagicOrigin
- expectedMagic = MuxerMagicEdge
- }
- errChan := make(chan error, 2)
- // Simultaneously send our settings and verify the peer's settings.
- go func() { errChan <- m.f.WriteSettings(handshakeSetting, compressionSetting) }()
- go func() { errChan <- m.readPeerSettings(expectedMagic) }()
- err := joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
- if err != nil {
- return nil, err
- }
- // Confirm sanity by ACKing the frame and expecting an ACK for our frame.
- // Not strictly necessary, but let's pretend to be H2-like.
- go func() { errChan <- m.f.WriteSettingsAck() }()
- go func() { errChan <- m.readPeerSettingsAck() }()
- err = joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
- if err != nil {
- return nil, err
- }
- // set up reader/writer pair ready for serve
- streamErrors := NewStreamErrorMap()
- goAwayChan := make(chan http2.ErrCode, 1)
- inBoundCounter := NewAtomicCounter(0)
- outBoundCounter := NewAtomicCounter(0)
- pingTimestamp := NewPingTimestamp()
- connActive := NewSignal()
- idleDuration := config.HeartbeatInterval
- // Sanity check to enusre idelDuration is sane
- if idleDuration == 0 || idleDuration < defaultTimeout {
- idleDuration = defaultTimeout
- config.Log.Info().Msgf("muxer: Minimum idle time has been adjusted to %d", defaultTimeout)
- }
- maxRetries := config.MaxHeartbeats
- if maxRetries == 0 {
- maxRetries = defaultRetries
- config.Log.Info().Msgf("muxer: Minimum number of unacked heartbeats to send before closing the connection has been adjusted to %d", maxRetries)
- }
- compBytesBefore, compBytesAfter := NewAtomicCounter(0), NewAtomicCounter(0)
- m.muxMetricsUpdater = newMuxMetricsUpdater(
- m.abortChan,
- compBytesBefore,
- compBytesAfter,
- )
- m.explicitShutdown = NewBooleanFuse()
- m.muxReader = &MuxReader{
- f: m.f,
- handler: m.config.Handler,
- streams: m.streams,
- readyList: m.readyList,
- streamErrors: streamErrors,
- goAwayChan: goAwayChan,
- abortChan: m.abortChan,
- pingTimestamp: pingTimestamp,
- connActive: connActive,
- initialStreamWindow: m.config.DefaultWindowSize,
- streamWindowMax: m.config.MaxWindowSize,
- streamWriteBufferMaxLen: m.config.StreamWriteBufferMaxLen,
- r: m.r,
- metricsUpdater: m.muxMetricsUpdater,
- bytesRead: inBoundCounter,
- }
- m.muxWriter = &MuxWriter{
- f: m.f,
- streams: m.streams,
- streamErrors: streamErrors,
- readyStreamChan: m.readyList.ReadyChannel(),
- newStreamChan: m.newStreamChan,
- goAwayChan: goAwayChan,
- abortChan: m.abortChan,
- pingTimestamp: pingTimestamp,
- idleTimer: NewIdleTimer(idleDuration, maxRetries),
- connActiveChan: connActive.WaitChannel(),
- maxFrameSize: defaultFrameSize,
- metricsUpdater: m.muxMetricsUpdater,
- bytesWrote: outBoundCounter,
- }
- m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
- if m.compressionQuality.dictSize > 0 && m.compressionQuality.nDicts > 0 {
- nd, sz := m.compressionQuality.nDicts, m.compressionQuality.dictSize
- writeDicts, dictChan := newH2WriteDictionaries(
- nd,
- sz,
- m.compressionQuality.quality,
- compBytesBefore,
- compBytesAfter,
- )
- readDicts := newH2ReadDictionaries(nd, sz)
- m.muxReader.dictionaries = h2Dictionaries{read: &readDicts, write: writeDicts}
- m.muxWriter.useDictChan = dictChan
- }
- return m, nil
- }
- func (m *Muxer) readPeerSettings(magic uint32) error {
- frame, err := m.f.ReadFrame()
- if err != nil {
- return err
- }
- settingsFrame, ok := frame.(*http2.SettingsFrame)
- if !ok {
- return ErrBadHandshakeNotSettings
- }
- if settingsFrame.Header().Flags != 0 {
- return ErrBadHandshakeUnexpectedAck
- }
- peerMagic, ok := settingsFrame.Value(SettingMuxerMagic)
- if !ok {
- return ErrBadHandshakeNoMagic
- }
- if magic != peerMagic {
- return ErrBadHandshakeWrongMagic
- }
- peerCompression, ok := settingsFrame.Value(SettingCompression)
- if !ok {
- m.compressionQuality = compressionPresets[CompressionNone]
- return nil
- }
- ver, fmt, sz, nd := parseCompressionSettingVal(peerCompression)
- if ver != compressionVersion || fmt != compressionFormat || sz == 0 || nd == 0 {
- m.compressionQuality = compressionPresets[CompressionNone]
- return nil
- }
- // Values used for compression are the mimimum between the two peers
- if sz < m.compressionQuality.dictSize {
- m.compressionQuality.dictSize = sz
- }
- if nd < m.compressionQuality.nDicts {
- m.compressionQuality.nDicts = nd
- }
- return nil
- }
- func (m *Muxer) readPeerSettingsAck() error {
- frame, err := m.f.ReadFrame()
- if err != nil {
- return err
- }
- settingsFrame, ok := frame.(*http2.SettingsFrame)
- if !ok {
- return ErrBadHandshakeNotSettingsAck
- }
- if settingsFrame.Header().Flags != http2.FlagSettingsAck {
- return ErrBadHandshakeUnexpectedSettings
- }
- return nil
- }
// joinErrorsWithTimeout receives up to receiveCount values from errChan and
// returns the first non-nil one. It returns nil when all receiveCount values
// arrive and are nil.
//
// timeout bounds the whole join, measured from the call. If the values have
// not all arrived by then, timeoutError is returned.
func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time.Duration, timeoutError error) error {
	// One timer for the whole join. Creating a time.After per receive would
	// restart the clock each time, allowing up to receiveCount*timeout in total
	// (or forever, if values keep trickling in just under timeout).
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	for i := 0; i < receiveCount; i++ {
		select {
		case err := <-errChan:
			if err != nil {
				return err
			}
		case <-timer.C:
			return timeoutError
		}
	}
	return nil
}
- // Serve runs the event loops that comprise h2mux:
- // - MuxReader.run()
- // - MuxWriter.run()
- // - muxMetricsUpdater.run()
- // In the normal case, Shutdown() is called concurrently with Serve() to stop
- // these loops.
- func (m *Muxer) Serve(ctx context.Context) error {
- errGroup, _ := errgroup.WithContext(ctx)
- errGroup.Go(func() error {
- ch := make(chan error)
- go func() {
- err := m.muxReader.run(m.config.Log)
- m.explicitShutdown.Fuse(false)
- m.r.Close()
- m.abort()
- // don't block if parent goroutine quit early
- select {
- case ch <- err:
- default:
- }
- }()
- select {
- case err := <-ch:
- return err
- case <-ctx.Done():
- return ctx.Err()
- }
- })
- errGroup.Go(func() error {
- ch := make(chan error)
- go func() {
- err := m.muxWriter.run(m.config.Log)
- m.explicitShutdown.Fuse(false)
- m.w.Close()
- m.abort()
- // don't block if parent goroutine quit early
- select {
- case ch <- err:
- default:
- }
- }()
- select {
- case err := <-ch:
- return err
- case <-ctx.Done():
- return ctx.Err()
- }
- })
- errGroup.Go(func() error {
- ch := make(chan error)
- go func() {
- err := m.muxMetricsUpdater.run(m.config.Log)
- // don't block if parent goroutine quit early
- select {
- case ch <- err:
- default:
- }
- }()
- select {
- case err := <-ch:
- return err
- case <-ctx.Done():
- return ctx.Err()
- }
- })
- err := errGroup.Wait()
- if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) {
- return err
- }
- return nil
- }
- // Shutdown is called to initiate the "happy path" of muxer termination.
- // It blocks new streams from being created.
- // It returns a channel that is closed when the last stream has been closed.
- func (m *Muxer) Shutdown() <-chan struct{} {
- m.explicitShutdown.Fuse(true)
- return m.muxReader.Shutdown()
- }
- // IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
- // The set of expected errors change depending on whether we initiated shutdown or not.
- func isUnexpectedTunnelError(err error, expectedShutdown bool) bool {
- if err == nil {
- return false
- }
- if !expectedShutdown {
- return true
- }
- return !isConnectionClosedError(err)
- }
// isConnectionClosedError reports whether err indicates that the underlying
// connection has been closed. That means one of:
//   - io.EOF or io.ErrClosedPipe, including when wrapped;
//   - an error whose message is exactly "tls: use of closed connection";
//   - an error whose message ends in "use of closed network connection".
//
// A nil error is not a closed-connection error.
func isConnectionClosedError(err error) bool {
	if err == nil {
		return false
	}
	// errors.Is also matches these sentinels when wrapped with %w.
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
		return true
	}
	// These closed-connection errors are only recognisable by their text.
	msg := err.Error()
	return msg == "tls: use of closed connection" ||
		strings.HasSuffix(msg, "use of closed network connection")
}
- // OpenStream opens a new data stream with the given headers.
- // Called by proxy server and tunnel
- func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) {
- stream := m.NewStream(headers)
- if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, body)); err != nil {
- return nil, err
- }
- if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
- return nil, err
- }
- return stream, nil
- }
- func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) {
- stream := m.NewStream(RPCHeaders())
- if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, nil)); err != nil {
- stream.Close()
- return nil, err
- }
- if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
- stream.Close()
- return nil, err
- }
- if !IsRPCStreamResponse(stream) {
- stream.Close()
- return nil, ErrNotRPCStream
- }
- return stream, nil
- }
- func (m *Muxer) NewStream(headers []Header) *MuxedStream {
- return NewStream(m.config, headers, m.readyList, m.muxReader.dictionaries)
- }
- func (m *Muxer) MakeMuxedStreamRequest(ctx context.Context, request MuxedStreamRequest) error {
- select {
- case <-ctx.Done():
- return ErrStreamRequestTimeout
- case <-m.abortChan:
- return ErrStreamRequestConnectionClosed
- // Will be received by mux writer
- case m.newStreamChan <- request:
- return nil
- }
- }
- func (m *Muxer) CloseStreamRead(stream *MuxedStream) {
- stream.CloseRead()
- if stream.WriteClosed() {
- m.streams.Delete(stream.streamID)
- }
- }
- func (m *Muxer) AwaitResponseHeaders(ctx context.Context, stream *MuxedStream) error {
- select {
- case <-ctx.Done():
- return ErrResponseHeadersTimeout
- case <-m.abortChan:
- return ErrResponseHeadersConnectionClosed
- case <-stream.responseHeadersReceived:
- return nil
- }
- }
- func (m *Muxer) Metrics() *MuxerMetrics {
- return m.muxMetricsUpdater.metrics()
- }
- func (m *Muxer) abort() {
- m.abortOnce.Do(func() {
- close(m.abortChan)
- m.readyList.Close()
- m.streams.Abort()
- })
- }
- // Return how many retries/ticks since the connection was last marked active
- func (m *Muxer) TimerRetries() uint64 {
- return m.muxWriter.idleTimer.RetryCount()
- }
- func IsRPCStreamResponse(stream *MuxedStream) bool {
- headers := stream.Headers
- return len(headers) == 1 &&
- headers[0].Name == ":status" &&
- headers[0].Value == "200"
- }
|