muxwriter.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. package h2mux
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "io"
  6. "time"
  7. "github.com/cloudflare/cloudflared/logger"
  8. "golang.org/x/net/http2"
  9. "golang.org/x/net/http2/hpack"
  10. )
  11. type MuxWriter struct {
  12. // f is used to write HTTP2 frames.
  13. f *http2.Framer
  14. // streams tracks currently-open streams.
  15. streams *activeStreamMap
  16. // streamErrors receives stream errors raised by the MuxReader.
  17. streamErrors *StreamErrorMap
  18. // readyStreamChan is used to multiplex writable streams onto the single connection.
  19. // When a stream becomes writable its ID is sent on this channel.
  20. readyStreamChan <-chan uint32
  21. // newStreamChan is used to create new streams with a given set of headers.
  22. newStreamChan <-chan MuxedStreamRequest
  23. // goAwayChan is used to send a single GOAWAY message to the peer. The element received
  24. // is the HTTP/2 error code to send.
  25. goAwayChan <-chan http2.ErrCode
  26. // abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
  27. abortChan <-chan struct{}
  28. // pingTimestamp is an atomic value containing the latest received ping timestamp.
  29. pingTimestamp *PingTimestamp
  30. // A timer used to measure idle connection time. Reset after sending data.
  31. idleTimer *IdleTimer
  32. // connActiveChan receives a signal that the connection received some (read) activity.
  33. connActiveChan <-chan struct{}
  34. // Maximum size of all frames that can be sent on this connection.
  35. maxFrameSize uint32
  36. // headerEncoder is the stateful header encoder for this connection
  37. headerEncoder *hpack.Encoder
  38. // headerBuffer is the temporary buffer used by headerEncoder.
  39. headerBuffer bytes.Buffer
  40. // metricsUpdater is used to report metrics
  41. metricsUpdater muxMetricsUpdater
  42. // bytesWrote is the amount of bytes written to data frames since the last time we called metricsUpdater.updateOutBoundBytes()
  43. bytesWrote *AtomicCounter
  44. useDictChan <-chan useDictRequest
  45. }
  46. type MuxedStreamRequest struct {
  47. stream *MuxedStream
  48. body io.Reader
  49. }
  50. func NewMuxedStreamRequest(stream *MuxedStream, body io.Reader) MuxedStreamRequest {
  51. return MuxedStreamRequest{
  52. stream: stream,
  53. body: body,
  54. }
  55. }
  56. func (r *MuxedStreamRequest) flushBody() {
  57. io.Copy(r.stream, r.body)
  58. r.stream.CloseWrite()
  59. }
  // tsToPingData packs a timestamp into the 8-byte opaque payload of a PING
// frame, encoding it as a little-endian uint64.
func tsToPingData(ts int64) [8]byte {
	var data [8]byte
	binary.LittleEndian.PutUint64(data[:], uint64(ts))
	return data
}
  65. func (w *MuxWriter) run(logger logger.Service) error {
  66. defer logger.Debug("mux - write: event loop finished")
  67. // routine to periodically communicate bytesWrote
  68. go func() {
  69. tickC := time.Tick(updateFreq)
  70. for {
  71. select {
  72. case <-w.abortChan:
  73. return
  74. case <-tickC:
  75. w.metricsUpdater.updateOutBoundBytes(w.bytesWrote.Count())
  76. }
  77. }
  78. }()
  79. for {
  80. select {
  81. case <-w.abortChan:
  82. logger.Debug("mux - write: aborting writer thread")
  83. return nil
  84. case errCode := <-w.goAwayChan:
  85. logger.Debugf("mux - write: sending GOAWAY code %v", errCode)
  86. err := w.f.WriteGoAway(w.streams.LastPeerStreamID(), errCode, []byte{})
  87. if err != nil {
  88. return err
  89. }
  90. w.idleTimer.MarkActive()
  91. case <-w.pingTimestamp.GetUpdateChan():
  92. logger.Debug("mux - write: sending PING ACK")
  93. err := w.f.WritePing(true, tsToPingData(w.pingTimestamp.Get()))
  94. if err != nil {
  95. return err
  96. }
  97. w.idleTimer.MarkActive()
  98. case <-w.idleTimer.C:
  99. if !w.idleTimer.Retry() {
  100. return ErrConnectionDropped
  101. }
  102. logger.Debug("mux - write: sending PING")
  103. err := w.f.WritePing(false, tsToPingData(time.Now().UnixNano()))
  104. if err != nil {
  105. return err
  106. }
  107. w.idleTimer.ResetTimer()
  108. case <-w.connActiveChan:
  109. w.idleTimer.MarkActive()
  110. case <-w.streamErrors.GetSignalChan():
  111. for streamID, errCode := range w.streamErrors.GetErrors() {
  112. logger.Debugf("mux - write: resetting stream with code: %v streamID: %d", errCode, streamID)
  113. err := w.f.WriteRSTStream(streamID, errCode)
  114. if err != nil {
  115. return err
  116. }
  117. }
  118. w.idleTimer.MarkActive()
  119. case streamRequest := <-w.newStreamChan:
  120. streamID := w.streams.AcquireLocalID()
  121. streamRequest.stream.streamID = streamID
  122. if !w.streams.Set(streamRequest.stream) {
  123. // Race between OpenStream and Shutdown, and Shutdown won. Let Shutdown (and the eventual abort) take
  124. // care of this stream. Ideally we'd pass the error directly to the stream object somehow so the
  125. // caller can be unblocked sooner, but the value of that optimisation is minimal for most of the
  126. // reasons why you'd call Shutdown anyway.
  127. continue
  128. }
  129. if streamRequest.body != nil {
  130. go streamRequest.flushBody()
  131. }
  132. err := w.writeStreamData(streamRequest.stream, logger)
  133. if err != nil {
  134. return err
  135. }
  136. w.idleTimer.MarkActive()
  137. case streamID := <-w.readyStreamChan:
  138. stream, ok := w.streams.Get(streamID)
  139. if !ok {
  140. continue
  141. }
  142. err := w.writeStreamData(stream, logger)
  143. if err != nil {
  144. return err
  145. }
  146. w.idleTimer.MarkActive()
  147. case useDict := <-w.useDictChan:
  148. err := w.writeUseDictionary(useDict)
  149. if err != nil {
  150. logger.Errorf("mux - write: error writing use dictionary: %s", err)
  151. return err
  152. }
  153. w.idleTimer.MarkActive()
  154. }
  155. }
  156. }
  157. func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger logger.Service) error {
  158. logger.Debugf("mux - write: writable: streamID: %d", stream.streamID)
  159. chunk := stream.getChunk()
  160. w.metricsUpdater.updateReceiveWindow(stream.getReceiveWindow())
  161. w.metricsUpdater.updateSendWindow(stream.getSendWindow())
  162. if chunk.sendHeadersFrame() {
  163. err := w.writeHeaders(chunk.streamID, chunk.headers)
  164. if err != nil {
  165. logger.Errorf("mux - write: error writing headers: %s: streamID: %d", err, stream.streamID)
  166. return err
  167. }
  168. logger.Debugf("mux - write: output headers: streamID: %d", stream.streamID)
  169. }
  170. if chunk.sendWindowUpdateFrame() {
  171. // Send a WINDOW_UPDATE frame to update our receive window.
  172. // If the Stream ID is zero, the window update applies to the connection as a whole
  173. // RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST
  174. // always account for its contribution against the connection flow-control
  175. // window, unless the receiver treats this as a connection error"
  176. err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
  177. if err != nil {
  178. logger.Errorf("mux - write: error writing window update: %s: streamID: %d", err, stream.streamID)
  179. return err
  180. }
  181. logger.Debugf("mux - write: increment receive window by %d streamID: %d", chunk.windowUpdate, stream.streamID)
  182. }
  183. for chunk.sendDataFrame() {
  184. payload, sentEOF := chunk.nextDataFrame(int(w.maxFrameSize))
  185. err := w.f.WriteData(chunk.streamID, sentEOF, payload)
  186. if err != nil {
  187. logger.Errorf("mux - write: error writing data: %s: streamID: %d", err, stream.streamID)
  188. return err
  189. }
  190. // update the amount of data wrote
  191. w.bytesWrote.IncrementBy(uint64(len(payload)))
  192. logger.Debugf("mux - write: output data: %d: streamID: %d", len(payload), stream.streamID)
  193. if sentEOF {
  194. if stream.readBuffer.Closed() {
  195. // transition into closed state
  196. if !stream.gotReceiveEOF() {
  197. // the peer may send data that we no longer want to receive. Force them into the
  198. // closed state.
  199. logger.Debugf("mux - write: resetting stream: streamID: %d", stream.streamID)
  200. w.f.WriteRSTStream(chunk.streamID, http2.ErrCodeNo)
  201. } else {
  202. // Half-open stream transitioned into closed
  203. logger.Debugf("mux - write: closing stream: streamID: %d", stream.streamID)
  204. }
  205. w.streams.Delete(chunk.streamID)
  206. } else {
  207. logger.Debugf("mux - write: closing stream write side: streamID: %d", stream.streamID)
  208. }
  209. }
  210. }
  211. return nil
  212. }
  213. func (w *MuxWriter) encodeHeaders(headers []Header) ([]byte, error) {
  214. w.headerBuffer.Reset()
  215. for _, header := range headers {
  216. err := w.headerEncoder.WriteField(hpack.HeaderField{
  217. Name: header.Name,
  218. Value: header.Value,
  219. })
  220. if err != nil {
  221. return nil, err
  222. }
  223. }
  224. return w.headerBuffer.Bytes(), nil
  225. }
  226. // writeHeaders writes a block of encoded headers, splitting it into multiple frames if necessary.
  227. func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
  228. encodedHeaders, err := w.encodeHeaders(headers)
  229. if err != nil || len(encodedHeaders) == 0 {
  230. return err
  231. }
  232. blockSize := int(w.maxFrameSize)
  233. // CONTINUATION is unnecessary; the headers fit within the blockSize
  234. if len(encodedHeaders) < blockSize {
  235. return w.f.WriteHeaders(http2.HeadersFrameParam{
  236. StreamID: streamID,
  237. EndHeaders: true,
  238. BlockFragment: encodedHeaders,
  239. })
  240. }
  241. choppedHeaders := chopEncodedHeaders(encodedHeaders, blockSize)
  242. // len(choppedHeaders) is at least 2
  243. if err := w.f.WriteHeaders(http2.HeadersFrameParam{StreamID: streamID, EndHeaders: false, BlockFragment: choppedHeaders[0]}); err != nil {
  244. return err
  245. }
  246. for i := 1; i < len(choppedHeaders)-1; i++ {
  247. if err := w.f.WriteContinuation(streamID, false, choppedHeaders[i]); err != nil {
  248. return err
  249. }
  250. }
  251. if err := w.f.WriteContinuation(streamID, true, choppedHeaders[len(choppedHeaders)-1]); err != nil {
  252. return err
  253. }
  254. return nil
  255. }
  // chopEncodedHeaders splits headers into consecutive chunks of at most
// chunkSize bytes: ceil(len(headers)/chunkSize) chunks, each of length
// chunkSize except possibly the last. The chunks alias headers rather than
// copying it.
//
// It returns nil for empty input. A non-positive chunkSize means "do not
// split" and yields the whole input as one chunk; the stepping loop could
// never advance with such a size.
func chopEncodedHeaders(headers []byte, chunkSize int) [][]byte {
	if len(headers) == 0 {
		return nil
	}
	if chunkSize <= 0 {
		return [][]byte{headers}
	}
	// Ceiling division without computing len+chunkSize-1, which could overflow.
	count := len(headers) / chunkSize
	if len(headers)%chunkSize != 0 {
		count++
	}
	divided := make([][]byte, 0, count)
	for len(headers) > chunkSize {
		divided = append(divided, headers[:chunkSize])
		headers = headers[chunkSize:]
	}
	return append(divided, headers)
}
  268. func (w *MuxWriter) writeUseDictionary(dictRequest useDictRequest) error {
  269. err := w.f.WriteRawFrame(FrameUseDictionary, 0, dictRequest.streamID, []byte{byte(dictRequest.dictID)})
  270. if err != nil {
  271. return err
  272. }
  273. payload := make([]byte, 0, 64)
  274. for _, set := range dictRequest.setDict {
  275. payload = append(payload, byte(set.dictID))
  276. payload = appendVarInt(payload, 7, uint64(set.dictSZ))
  277. payload = append(payload, 0x80) // E = 1, D = 0, Truncate = 0
  278. }
  279. err = w.f.WriteRawFrame(FrameSetDictionary, FlagSetDictionaryAppend, dictRequest.streamID, payload)
  280. return err
  281. }