h2mux.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. package h2mux
import (
	"context"
	"errors"
	"io"
	"strings"
	"sync"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/rs/zerolog"
	"golang.org/x/net/http2"
	"golang.org/x/net/http2/hpack"
	"golang.org/x/sync/errgroup"
)
  14. const (
  15. defaultFrameSize uint32 = 1 << 14 // Minimum frame size in http2 spec
  16. defaultWindowSize uint32 = (1 << 16) - 1 // Minimum window size in http2 spec
  17. maxWindowSize uint32 = (1 << 31) - 1 // 2^31-1 = 2147483647, max window size in http2 spec
  18. defaultTimeout time.Duration = 5 * time.Second
  19. defaultRetries uint64 = 5
  20. defaultWriteBufferMaxLen int = 1024 * 1024 // 1mb
  21. writeBufferInitialSize int = 16 * 1024 // 16KB
  22. SettingMuxerMagic http2.SettingID = 0x42db
  23. MuxerMagicOrigin uint32 = 0xa2e43c8b
  24. MuxerMagicEdge uint32 = 0x1088ebf9
  25. )
  26. type MuxedStreamHandler interface {
  27. ServeStream(*MuxedStream) error
  28. }
  29. type MuxedStreamFunc func(stream *MuxedStream) error
  30. func (f MuxedStreamFunc) ServeStream(stream *MuxedStream) error {
  31. return f(stream)
  32. }
  33. type MuxerConfig struct {
  34. Timeout time.Duration
  35. Handler MuxedStreamHandler
  36. IsClient bool
  37. // Name is used to identify this muxer instance when logging.
  38. Name string
  39. // The minimum time this connection can be idle before sending a heartbeat.
  40. HeartbeatInterval time.Duration
  41. // The minimum number of heartbeats to send before terminating the connection.
  42. MaxHeartbeats uint64
  43. // Logger to use
  44. Log *zerolog.Logger
  45. CompressionQuality CompressionSetting
  46. // Initial size for HTTP2 flow control windows
  47. DefaultWindowSize uint32
  48. // Largest allowable size for HTTP2 flow control windows
  49. MaxWindowSize uint32
  50. // Largest allowable capacity for the buffer of data to be sent
  51. StreamWriteBufferMaxLen int
  52. }
  53. type Muxer struct {
  54. // f is used to read and write HTTP2 frames on the wire.
  55. f *http2.Framer
  56. // config is the MuxerConfig given in Handshake.
  57. config MuxerConfig
  58. // w, r are references to the underlying connection used.
  59. w io.WriteCloser
  60. r io.ReadCloser
  61. // muxReader is the read process.
  62. muxReader *MuxReader
  63. // muxWriter is the write process.
  64. muxWriter *MuxWriter
  65. // muxMetricsUpdater is the process to update metrics
  66. muxMetricsUpdater muxMetricsUpdater
  67. // newStreamChan is used to create new streams on the writer thread.
  68. // The writer will assign the next available stream ID.
  69. newStreamChan chan MuxedStreamRequest
  70. // abortChan is used to abort the writer event loop.
  71. abortChan chan struct{}
  72. // abortOnce is used to ensure abortChan is closed once only.
  73. abortOnce sync.Once
  74. // readyList is used to signal writable streams.
  75. readyList *ReadyList
  76. // streams tracks currently-open streams.
  77. streams *activeStreamMap
  78. // explicitShutdown records whether the Muxer is closing because Shutdown was called, or due to another
  79. // error.
  80. explicitShutdown *BooleanFuse
  81. compressionQuality CompressionPreset
  82. }
  83. func RPCHeaders() []Header {
  84. return []Header{
  85. {Name: ":method", Value: "RPC"},
  86. {Name: ":scheme", Value: "capnp"},
  87. {Name: ":path", Value: "*"},
  88. }
  89. }
  90. // Handshake establishes a muxed connection with the peer.
  91. // After the handshake completes, it is possible to open and accept streams.
  92. func Handshake(
  93. w io.WriteCloser,
  94. r io.ReadCloser,
  95. config MuxerConfig,
  96. activeStreamsMetrics prometheus.Gauge,
  97. ) (*Muxer, error) {
  98. // Set default config values
  99. if config.Timeout == 0 {
  100. config.Timeout = defaultTimeout
  101. }
  102. if config.DefaultWindowSize == 0 {
  103. config.DefaultWindowSize = defaultWindowSize
  104. }
  105. if config.MaxWindowSize == 0 {
  106. config.MaxWindowSize = maxWindowSize
  107. }
  108. if config.StreamWriteBufferMaxLen == 0 {
  109. config.StreamWriteBufferMaxLen = defaultWriteBufferMaxLen
  110. }
  111. // Initialise connection state fields
  112. m := &Muxer{
  113. f: http2.NewFramer(w, r), // A framer that writes to w and reads from r
  114. config: config,
  115. w: w,
  116. r: r,
  117. newStreamChan: make(chan MuxedStreamRequest),
  118. abortChan: make(chan struct{}),
  119. readyList: NewReadyList(),
  120. streams: newActiveStreamMap(config.IsClient, activeStreamsMetrics),
  121. }
  122. m.f.ReadMetaHeaders = hpack.NewDecoder(4096, func(hpack.HeaderField) {})
  123. // Initialise the settings to identify this connection and confirm the other end is sane.
  124. handshakeSetting := http2.Setting{ID: SettingMuxerMagic, Val: MuxerMagicEdge}
  125. compressionSetting := http2.Setting{ID: SettingCompression, Val: config.CompressionQuality.toH2Setting()}
  126. if CompressionIsSupported() {
  127. config.Log.Debug().Msg("muxer: Compression is supported")
  128. m.compressionQuality = config.CompressionQuality.getPreset()
  129. } else {
  130. config.Log.Debug().Msg("muxer: Compression is not supported")
  131. compressionSetting = http2.Setting{ID: SettingCompression, Val: 0}
  132. }
  133. expectedMagic := MuxerMagicOrigin
  134. if config.IsClient {
  135. handshakeSetting.Val = MuxerMagicOrigin
  136. expectedMagic = MuxerMagicEdge
  137. }
  138. errChan := make(chan error, 2)
  139. // Simultaneously send our settings and verify the peer's settings.
  140. go func() { errChan <- m.f.WriteSettings(handshakeSetting, compressionSetting) }()
  141. go func() { errChan <- m.readPeerSettings(expectedMagic) }()
  142. err := joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
  143. if err != nil {
  144. return nil, err
  145. }
  146. // Confirm sanity by ACKing the frame and expecting an ACK for our frame.
  147. // Not strictly necessary, but let's pretend to be H2-like.
  148. go func() { errChan <- m.f.WriteSettingsAck() }()
  149. go func() { errChan <- m.readPeerSettingsAck() }()
  150. err = joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
  151. if err != nil {
  152. return nil, err
  153. }
  154. // set up reader/writer pair ready for serve
  155. streamErrors := NewStreamErrorMap()
  156. goAwayChan := make(chan http2.ErrCode, 1)
  157. inBoundCounter := NewAtomicCounter(0)
  158. outBoundCounter := NewAtomicCounter(0)
  159. pingTimestamp := NewPingTimestamp()
  160. connActive := NewSignal()
  161. idleDuration := config.HeartbeatInterval
  162. // Sanity check to enusre idelDuration is sane
  163. if idleDuration == 0 || idleDuration < defaultTimeout {
  164. idleDuration = defaultTimeout
  165. config.Log.Info().Msgf("muxer: Minimum idle time has been adjusted to %d", defaultTimeout)
  166. }
  167. maxRetries := config.MaxHeartbeats
  168. if maxRetries == 0 {
  169. maxRetries = defaultRetries
  170. config.Log.Info().Msgf("muxer: Minimum number of unacked heartbeats to send before closing the connection has been adjusted to %d", maxRetries)
  171. }
  172. compBytesBefore, compBytesAfter := NewAtomicCounter(0), NewAtomicCounter(0)
  173. m.muxMetricsUpdater = newMuxMetricsUpdater(
  174. m.abortChan,
  175. compBytesBefore,
  176. compBytesAfter,
  177. )
  178. m.explicitShutdown = NewBooleanFuse()
  179. m.muxReader = &MuxReader{
  180. f: m.f,
  181. handler: m.config.Handler,
  182. streams: m.streams,
  183. readyList: m.readyList,
  184. streamErrors: streamErrors,
  185. goAwayChan: goAwayChan,
  186. abortChan: m.abortChan,
  187. pingTimestamp: pingTimestamp,
  188. connActive: connActive,
  189. initialStreamWindow: m.config.DefaultWindowSize,
  190. streamWindowMax: m.config.MaxWindowSize,
  191. streamWriteBufferMaxLen: m.config.StreamWriteBufferMaxLen,
  192. r: m.r,
  193. metricsUpdater: m.muxMetricsUpdater,
  194. bytesRead: inBoundCounter,
  195. }
  196. m.muxWriter = &MuxWriter{
  197. f: m.f,
  198. streams: m.streams,
  199. streamErrors: streamErrors,
  200. readyStreamChan: m.readyList.ReadyChannel(),
  201. newStreamChan: m.newStreamChan,
  202. goAwayChan: goAwayChan,
  203. abortChan: m.abortChan,
  204. pingTimestamp: pingTimestamp,
  205. idleTimer: NewIdleTimer(idleDuration, maxRetries),
  206. connActiveChan: connActive.WaitChannel(),
  207. maxFrameSize: defaultFrameSize,
  208. metricsUpdater: m.muxMetricsUpdater,
  209. bytesWrote: outBoundCounter,
  210. }
  211. m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
  212. if m.compressionQuality.dictSize > 0 && m.compressionQuality.nDicts > 0 {
  213. nd, sz := m.compressionQuality.nDicts, m.compressionQuality.dictSize
  214. writeDicts, dictChan := newH2WriteDictionaries(
  215. nd,
  216. sz,
  217. m.compressionQuality.quality,
  218. compBytesBefore,
  219. compBytesAfter,
  220. )
  221. readDicts := newH2ReadDictionaries(nd, sz)
  222. m.muxReader.dictionaries = h2Dictionaries{read: &readDicts, write: writeDicts}
  223. m.muxWriter.useDictChan = dictChan
  224. }
  225. return m, nil
  226. }
  227. func (m *Muxer) readPeerSettings(magic uint32) error {
  228. frame, err := m.f.ReadFrame()
  229. if err != nil {
  230. return err
  231. }
  232. settingsFrame, ok := frame.(*http2.SettingsFrame)
  233. if !ok {
  234. return ErrBadHandshakeNotSettings
  235. }
  236. if settingsFrame.Header().Flags != 0 {
  237. return ErrBadHandshakeUnexpectedAck
  238. }
  239. peerMagic, ok := settingsFrame.Value(SettingMuxerMagic)
  240. if !ok {
  241. return ErrBadHandshakeNoMagic
  242. }
  243. if magic != peerMagic {
  244. return ErrBadHandshakeWrongMagic
  245. }
  246. peerCompression, ok := settingsFrame.Value(SettingCompression)
  247. if !ok {
  248. m.compressionQuality = compressionPresets[CompressionNone]
  249. return nil
  250. }
  251. ver, fmt, sz, nd := parseCompressionSettingVal(peerCompression)
  252. if ver != compressionVersion || fmt != compressionFormat || sz == 0 || nd == 0 {
  253. m.compressionQuality = compressionPresets[CompressionNone]
  254. return nil
  255. }
  256. // Values used for compression are the mimimum between the two peers
  257. if sz < m.compressionQuality.dictSize {
  258. m.compressionQuality.dictSize = sz
  259. }
  260. if nd < m.compressionQuality.nDicts {
  261. m.compressionQuality.nDicts = nd
  262. }
  263. return nil
  264. }
  265. func (m *Muxer) readPeerSettingsAck() error {
  266. frame, err := m.f.ReadFrame()
  267. if err != nil {
  268. return err
  269. }
  270. settingsFrame, ok := frame.(*http2.SettingsFrame)
  271. if !ok {
  272. return ErrBadHandshakeNotSettingsAck
  273. }
  274. if settingsFrame.Header().Flags != http2.FlagSettingsAck {
  275. return ErrBadHandshakeUnexpectedSettings
  276. }
  277. return nil
  278. }
  279. func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time.Duration, timeoutError error) error {
  280. for i := 0; i < receiveCount; i++ {
  281. select {
  282. case err := <-errChan:
  283. if err != nil {
  284. return err
  285. }
  286. case <-time.After(timeout):
  287. return timeoutError
  288. }
  289. }
  290. return nil
  291. }
  292. // Serve runs the event loops that comprise h2mux:
  293. // - MuxReader.run()
  294. // - MuxWriter.run()
  295. // - muxMetricsUpdater.run()
  296. // In the normal case, Shutdown() is called concurrently with Serve() to stop
  297. // these loops.
  298. func (m *Muxer) Serve(ctx context.Context) error {
  299. errGroup, _ := errgroup.WithContext(ctx)
  300. errGroup.Go(func() error {
  301. ch := make(chan error)
  302. go func() {
  303. err := m.muxReader.run(m.config.Log)
  304. m.explicitShutdown.Fuse(false)
  305. m.r.Close()
  306. m.abort()
  307. // don't block if parent goroutine quit early
  308. select {
  309. case ch <- err:
  310. default:
  311. }
  312. }()
  313. select {
  314. case err := <-ch:
  315. return err
  316. case <-ctx.Done():
  317. return ctx.Err()
  318. }
  319. })
  320. errGroup.Go(func() error {
  321. ch := make(chan error)
  322. go func() {
  323. err := m.muxWriter.run(m.config.Log)
  324. m.explicitShutdown.Fuse(false)
  325. m.w.Close()
  326. m.abort()
  327. // don't block if parent goroutine quit early
  328. select {
  329. case ch <- err:
  330. default:
  331. }
  332. }()
  333. select {
  334. case err := <-ch:
  335. return err
  336. case <-ctx.Done():
  337. return ctx.Err()
  338. }
  339. })
  340. errGroup.Go(func() error {
  341. ch := make(chan error)
  342. go func() {
  343. err := m.muxMetricsUpdater.run(m.config.Log)
  344. // don't block if parent goroutine quit early
  345. select {
  346. case ch <- err:
  347. default:
  348. }
  349. }()
  350. select {
  351. case err := <-ch:
  352. return err
  353. case <-ctx.Done():
  354. return ctx.Err()
  355. }
  356. })
  357. err := errGroup.Wait()
  358. if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) {
  359. return err
  360. }
  361. return nil
  362. }
  363. // Shutdown is called to initiate the "happy path" of muxer termination.
  364. // It blocks new streams from being created.
  365. // It returns a channel that is closed when the last stream has been closed.
  366. func (m *Muxer) Shutdown() <-chan struct{} {
  367. m.explicitShutdown.Fuse(true)
  368. return m.muxReader.Shutdown()
  369. }
  370. // IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
  371. // The set of expected errors change depending on whether we initiated shutdown or not.
  372. func isUnexpectedTunnelError(err error, expectedShutdown bool) bool {
  373. if err == nil {
  374. return false
  375. }
  376. if !expectedShutdown {
  377. return true
  378. }
  379. return !isConnectionClosedError(err)
  380. }
  381. func isConnectionClosedError(err error) bool {
  382. if err == io.EOF {
  383. return true
  384. }
  385. if err == io.ErrClosedPipe {
  386. return true
  387. }
  388. if err.Error() == "tls: use of closed connection" {
  389. return true
  390. }
  391. if strings.HasSuffix(err.Error(), "use of closed network connection") {
  392. return true
  393. }
  394. return false
  395. }
  396. // OpenStream opens a new data stream with the given headers.
  397. // Called by proxy server and tunnel
  398. func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) {
  399. stream := m.NewStream(headers)
  400. if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, body)); err != nil {
  401. return nil, err
  402. }
  403. if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
  404. return nil, err
  405. }
  406. return stream, nil
  407. }
  408. func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) {
  409. stream := m.NewStream(RPCHeaders())
  410. if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, nil)); err != nil {
  411. stream.Close()
  412. return nil, err
  413. }
  414. if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
  415. stream.Close()
  416. return nil, err
  417. }
  418. if !IsRPCStreamResponse(stream) {
  419. stream.Close()
  420. return nil, ErrNotRPCStream
  421. }
  422. return stream, nil
  423. }
  424. func (m *Muxer) NewStream(headers []Header) *MuxedStream {
  425. return NewStream(m.config, headers, m.readyList, m.muxReader.dictionaries)
  426. }
  427. func (m *Muxer) MakeMuxedStreamRequest(ctx context.Context, request MuxedStreamRequest) error {
  428. select {
  429. case <-ctx.Done():
  430. return ErrStreamRequestTimeout
  431. case <-m.abortChan:
  432. return ErrStreamRequestConnectionClosed
  433. // Will be received by mux writer
  434. case m.newStreamChan <- request:
  435. return nil
  436. }
  437. }
  438. func (m *Muxer) CloseStreamRead(stream *MuxedStream) {
  439. stream.CloseRead()
  440. if stream.WriteClosed() {
  441. m.streams.Delete(stream.streamID)
  442. }
  443. }
  444. func (m *Muxer) AwaitResponseHeaders(ctx context.Context, stream *MuxedStream) error {
  445. select {
  446. case <-ctx.Done():
  447. return ErrResponseHeadersTimeout
  448. case <-m.abortChan:
  449. return ErrResponseHeadersConnectionClosed
  450. case <-stream.responseHeadersReceived:
  451. return nil
  452. }
  453. }
  454. func (m *Muxer) Metrics() *MuxerMetrics {
  455. return m.muxMetricsUpdater.metrics()
  456. }
  457. func (m *Muxer) abort() {
  458. m.abortOnce.Do(func() {
  459. close(m.abortChan)
  460. m.readyList.Close()
  461. m.streams.Abort()
  462. })
  463. }
  464. // Return how many retries/ticks since the connection was last marked active
  465. func (m *Muxer) TimerRetries() uint64 {
  466. return m.muxWriter.idleTimer.RetryCount()
  467. }
  468. func IsRPCStreamResponse(stream *MuxedStream) bool {
  469. headers := stream.Headers
  470. return len(headers) == 1 &&
  471. headers[0].Name == ":status" &&
  472. headers[0].Value == "200"
  473. }