h2_dictionaries.go

package h2mux

import (
	"bytes"
	"io"
	"strings"
	"sync"

	"golang.org/x/net/http2"
)

/* This is an implementation of https://github.com/vkrasnov/h2-compression-dictionaries
but modified for tunnels in a few key ways:
Since tunnels is a server-to-server service, some aspects of the spec would cause
unnecessary head-of-line blocking on the CPU and on the network, hence this implementation
allows for parallel compression on the "client", and buffering on the "server" to solve
this problem. */

// Assign temporary values
const SettingCompression http2.SettingID = 0xff20

const (
	FrameSetCompressionContext http2.FrameType = 0xf0
	FrameUseDictionary         http2.FrameType = 0xf1
	FrameSetDictionary         http2.FrameType = 0xf2
)

const (
	FlagSetDictionaryAppend http2.Flags = 0x1
	FlagSetDictionaryOffset http2.Flags = 0x2
)

const compressionVersion = uint8(1)
const compressionFormat = uint8(2)

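// CompressionSetting selects one of the predefined compression presets
// advertised through the custom SETTINGS value above.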
type CompressionSetting uint

const (
	CompressionNone CompressionSetting = iota
	CompressionLow
	CompressionMedium
	CompressionMax
)

type CompressionPreset struct {
	nDicts, dictSize, quality uint8
}

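// compressor is the interface the brotli encoder must satisfy to be used by
// the dictionary writer.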
type compressor interface {
	Write([]byte) (int, error)
	Flush() error
	SetDictionary([]byte)
	Close() error
}

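// decompressor is the matching interface for the brotli decoder used on the
// read path.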
type decompressor interface {
	Read([]byte) (int, error)
	SetDictionary([]byte)
	Close() error
}

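// compressionPresets maps each setting to {number of dictionaries, log2 of the
// dictionary size, brotli quality}.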
var compressionPresets = map[CompressionSetting]CompressionPreset{
	CompressionNone:   {0, 0, 0},
	CompressionLow:    {32, 17, 5},
	CompressionMedium: {64, 18, 6},
	CompressionMax:    {255, 19, 9},
}

func compressionSettingVal(version, fmt, sz, nd uint8) uint32 {
	// Currently the compression settings include:
	// * version: only 1 is supported
	// * fmt: only 2 for brotli is supported
	// * sz: log2 of the maximal allowed dictionary size
	// * nd: max allowed number of dictionaries
	return uint32(version)<<24 + uint32(fmt)<<16 + uint32(sz)<<8 + uint32(nd)
}

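// parseCompressionSettingVal is the inverse of compressionSettingVal: it
// unpacks the four bytes of the setting value.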
func parseCompressionSettingVal(setting uint32) (version, fmt, sz, nd uint8) {
	version = uint8(setting >> 24)
	fmt = uint8(setting >> 16)
	sz = uint8(setting >> 8)
	nd = uint8(setting)
	return
}

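// toH2Setting packs the preset for this setting into a 32-bit SETTINGS value.
// For example, CompressionMedium {64 dictionaries, 2^18-byte maximum size,
// quality 6} packs to 0x01021240: version 1, format 2 (brotli), sz 18, nd 64.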
func (c CompressionSetting) toH2Setting() uint32 {
	p, ok := compressionPresets[c]
	if !ok {
		return 0
	}
	return compressionSettingVal(compressionVersion, compressionFormat, p.dictSize, p.nDicts)
}

func (c CompressionSetting) getPreset() CompressionPreset {
	return compressionPresets[c]
}

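// dictUpdate is a queued change to a read dictionary: either new data to fold
// into the dictionary, or (when isUse is set) a reader waiting to start using
// it once earlier updates have been applied.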
type dictUpdate struct {
	reader     *h2DictionaryReader
	dictionary *h2ReadDictionary
	buff       []byte
	isReady    bool
	isUse      bool
	s          setDictRequest
}

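// h2ReadDictionary is the current content of a single receive-side dictionary
// plus the queue of updates that have not been applied to it yet.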
type h2ReadDictionary struct {
	dictionary []byte
	queue      []*dictUpdate
	maxSize    int
}

type h2ReadDictionaries struct {
	d       []h2ReadDictionary
	maxSize int
}

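// h2DictionaryReader decompresses data arriving on a stream using one of the
// read dictionaries and feeds the output into the stream's shared buffer.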
type h2DictionaryReader struct {
	*SharedBuffer                 // Propagate the decompressed output into the original buffer
	decompBuffer   *bytes.Buffer  // Intermediate buffer for the brotli decompressor
	dictionary     []byte         // The content of the dictionary being used by this reader
	internalBuffer []byte
	s, e           int           // Start and end of the buffer
	decomp         decompressor  // The brotli decompressor
	isClosed       bool          // Indicates that Close was called for this reader
	queue          []*dictUpdate // List of dictionaries to update, when the data is available
}

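// h2WriteDictionary is the raw content of a send-side dictionary.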
type h2WriteDictionary []byte

type setDictRequest struct {
	streamID         uint32
	dictID           uint8
	dictSZ           uint64
	truncate, offset uint64
	P, E, D          bool
}

type useDictRequest struct {
	dictID   uint8
	streamID uint32
	setDict  []setDictRequest
}

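// h2WriteDictionaries tracks the send-side dictionaries, assigns dictionary
// IDs to paths and content types, and emits USE_DICTIONARY/SET_DICTIONARY
// requests over dictChan.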
type h2WriteDictionaries struct {
	dictLock        sync.Mutex
	dictChan        chan useDictRequest
	dictionaries    []h2WriteDictionary
	nextAvail       int              // next unused dictionary slot
	maxAvail        int              // max ID, defined by SETTINGS
	maxSize         int              // max size, defined by SETTINGS
	typeToDict      map[string]uint8 // map from content type to dictionary that encodes it
	pathToDict      map[string]uint8 // map from path to dictionary that encodes it
	quality         int
	window          int
	compIn, compOut *AtomicCounter
}

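// h2DictWriter compresses a stream's output into its write buffer using the
// dictionary selected for that stream.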
type h2DictWriter struct {
	*bytes.Buffer
	comp        compressor
	dicts       *h2WriteDictionaries
	writerLock  sync.Mutex
	streamID    uint32
	path        string
	contentType string
}

type h2Dictionaries struct {
	write *h2WriteDictionaries
	read  *h2ReadDictionaries
}

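// update records a copy of the new dictionary data and marks this queue entry
// as ready to be applied.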
func (o *dictUpdate) update(buff []byte) {
	o.buff = make([]byte, len(buff))
	copy(o.buff, buff)
	o.isReady = true
}

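// update drains the queue in order, stopping at the first entry whose data is
// not ready. Data entries are folded into the dictionary; reader entries get a
// copy of the dictionary and a fresh decompressor.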
func (d *h2ReadDictionary) update() {
	for len(d.queue) > 0 {
		o := d.queue[0]
		if !o.isReady {
			break
		}
		if o.isUse {
			reader := o.reader
			reader.dictionary = make([]byte, len(d.dictionary))
			copy(reader.dictionary, d.dictionary)
			reader.decomp = newDecompressor(reader.decompBuffer)
			if len(reader.dictionary) > 0 {
				reader.decomp.SetDictionary(reader.dictionary)
			}
			reader.Write([]byte{})
		} else {
			d.dictionary = adjustDictionary(d.dictionary, o.buff, o.s, d.maxSize)
		}
		d.queue = d.queue[1:]
	}
}

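// newH2ReadDictionaries allocates nd read dictionaries, each capped at 2^sz bytes.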
func newH2ReadDictionaries(nd, sz uint8) h2ReadDictionaries {
	d := make([]h2ReadDictionary, int(nd))
	for i := range d {
		d[i].maxSize = 1 << uint(sz)
	}
	return h2ReadDictionaries{d: d, maxSize: 1 << uint(sz)}
}

func (dicts *h2ReadDictionaries) getDictByID(dictID uint8) (*h2ReadDictionary, error) {
	if int(dictID) > len(dicts.d) {
		return nil, MuxerStreamError{"dictID too big", http2.ErrCodeProtocol}
	}
	return &dicts.d[dictID], nil
}

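// newReader returns a decompressing reader bound to the given dictionary ID.
// If updates to that dictionary are still pending, the reader is queued and
// only activated once they have been applied.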
func (dicts *h2ReadDictionaries) newReader(b *SharedBuffer, dictID uint8) *h2DictionaryReader {
	if int(dictID) > len(dicts.d) {
		return nil
	}

	dictionary := &dicts.d[dictID]
	reader := &h2DictionaryReader{SharedBuffer: b, decompBuffer: &bytes.Buffer{}, internalBuffer: make([]byte, dicts.maxSize)}

	if len(dictionary.queue) == 0 {
		reader.dictionary = make([]byte, len(dictionary.dictionary))
		copy(reader.dictionary, dictionary.dictionary)
		reader.decomp = newDecompressor(reader.decompBuffer)
		if len(reader.dictionary) > 0 {
			reader.decomp.SetDictionary(reader.dictionary)
		}
	} else {
		dictionary.queue = append(dictionary.queue, &dictUpdate{isUse: true, isReady: true, reader: reader})
	}

	return reader
}

func (r *h2DictionaryReader) updateWaitingDictionaries() {
	// Update all the waiting dictionaries
	for _, o := range r.queue {
		if o.isReady {
			continue
		}
		if r.isClosed || uint64(r.e) >= o.s.dictSZ {
			o.update(r.internalBuffer[:r.e])
			if o == o.dictionary.queue[0] {
				defer o.dictionary.update()
			}
		}
	}
}

// Write actually happens when reading from the network; this is therefore the stage where we decompress the buffer
func (r *h2DictionaryReader) Write(p []byte) (n int, err error) {
	// Every write goes into the brotli buffer first
	n, err = r.decompBuffer.Write(p)
	if err != nil {
		return
	}

	// If the decompressor is not yet set up, the reader is still queued behind
	// pending dictionary updates; keep buffering until it is activated.
	if r.decomp == nil {
		return
	}

	// Drain the decompressor into the internal buffer and propagate the output
	// to the shared buffer.
	for {
		m, err := r.decomp.Read(r.internalBuffer[r.e:])
		if err != nil && err != io.EOF {
			r.SharedBuffer.Close()
			r.decomp.Close()
			return n, err
		}
		r.SharedBuffer.Write(r.internalBuffer[r.e : r.e+m])
		r.e += m
		if m == 0 {
			break
		}
		// The internal buffer is full; let waiting dictionaries consume it before reuse.
		if r.e == len(r.internalBuffer) {
			r.updateWaitingDictionaries()
			r.e = 0
		}
	}

	r.updateWaitingDictionaries()

	if r.isClosed {
		r.SharedBuffer.Close()
		r.decomp.Close()
	}

	return
}

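// Close marks the reader closed and performs a final empty Write to drain any
// remaining decompressed data.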
func (r *h2DictionaryReader) Close() error {
	if r.isClosed {
		return nil
	}
	r.isClosed = true
	r.Write([]byte{})
	return nil
}

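// compressibleTypes lists the content types worth compressing; responses with
// other types (except text/*) bypass the dictionary writer.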
var compressibleTypes = map[string]bool{
	"application/atom+xml":                true,
	"application/javascript":              true,
	"application/json":                    true,
	"application/ld+json":                 true,
	"application/manifest+json":           true,
	"application/rss+xml":                 true,
	"application/vnd.geo+json":            true,
	"application/vnd.ms-fontobject":       true,
	"application/x-font-ttf":              true,
	"application/x-yaml":                  true,
	"application/x-web-app-manifest+json": true,
	"application/xhtml+xml":               true,
	"application/xml":                     true,
	"font/opentype":                       true,
	"image/bmp":                           true,
	"image/svg+xml":                       true,
	"image/x-icon":                        true,
	"text/cache-manifest":                 true,
	"text/css":                            true,
	"text/html":                           true,
	"text/plain":                          true,
	"text/vcard":                          true,
	"text/vnd.rim.location.xloc":          true,
	"text/vtt":                            true,
	"text/x-component":                    true,
	"text/x-cross-domain-policy":          true,
	"text/x-yaml":                         true,
}

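// getContentType returns the lower-cased media type from the Content-Type
// header, with any parameters (e.g. charset) stripped.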
func getContentType(headers []Header) string {
	for _, h := range headers {
		if strings.ToLower(h.Name) == "content-type" {
			val := strings.ToLower(h.Value)
			sep := strings.IndexRune(val, ';')
			if sep != -1 {
				return val[:sep]
			}
			return val
		}
	}
	return ""
}

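// newH2WriteDictionaries creates the send-side dictionary state and returns it
// together with the channel on which useDictRequests are delivered.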
func newH2WriteDictionaries(nd, sz, quality uint8, compIn, compOut *AtomicCounter) (*h2WriteDictionaries, chan useDictRequest) {
	useDictChan := make(chan useDictRequest)
	return &h2WriteDictionaries{
		dictionaries: make([]h2WriteDictionary, nd),
		nextAvail:    0,
		maxAvail:     int(nd),
		maxSize:      1 << uint(sz),
		dictChan:     useDictChan,
		typeToDict:   make(map[string]uint8),
		pathToDict:   make(map[string]uint8),
		quality:      int(quality),
		window:       1 << uint(sz+1),
		compIn:       compIn,
		compOut:      compOut,
	}, useDictChan
}

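// adjustDictionary appends the first set.dictSZ bytes of newData to the
// dictionary and trims it from the front so it never exceeds maxSize.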
func adjustDictionary(currentDictionary, newData []byte, set setDictRequest, maxSize int) []byte {
	currentDictionary = append(currentDictionary, newData[:set.dictSZ]...)
	if len(currentDictionary) > maxSize {
		currentDictionary = currentDictionary[len(currentDictionary)-maxSize:]
	}
	return currentDictionary
}

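// getNextDictID hands out the next unused dictionary slot, if one remains.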
func (h2d *h2WriteDictionaries) getNextDictID() (dictID uint8, ok bool) {
	if h2d.nextAvail < h2d.maxAvail {
		dictID, ok = uint8(h2d.nextAvail), true
		h2d.nextAvail++
		return
	}
	return 0, false
}

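// getGenericDictID returns the last dictionary slot, which serves as the
// shared overflow dictionary.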
func (h2d *h2WriteDictionaries) getGenericDictID() (dictID uint8, ok bool) {
	if h2d.maxAvail == 0 {
		return 0, false
	}
	return uint8(h2d.maxAvail - 1), true
}

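// getDictWriter wraps the stream's write buffer in an h2DictWriter when the
// response is a compressible type on a GET or POST stream; otherwise it
// returns nil.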
func (h2d *h2WriteDictionaries) getDictWriter(s *MuxedStream, headers []Header) *h2DictWriter {
	w := s.writeBuffer
	if w == nil {
		return nil
	}

	if s.method != "GET" && s.method != "POST" {
		return nil
	}

	s.contentType = getContentType(headers)
	if _, ok := compressibleTypes[s.contentType]; !ok && !strings.HasPrefix(s.contentType, "text") {
		return nil
	}

	return &h2DictWriter{
		Buffer:      w.(*bytes.Buffer),
		path:        s.path,
		contentType: s.contentType,
		streamID:    s.streamID,
		dicts:       h2d,
	}
}

func assignDictToStream(s *MuxedStream, p []byte) bool {
	// On first write to stream:
	// * assign the right dictionary
	// * update relevant dictionaries
	// * send the required USE_DICT and SET_DICT frames

	h2d := s.dictionaries.write
	if h2d == nil {
		return false
	}

	w, ok := s.writeBuffer.(*h2DictWriter)
	if !ok || w.comp != nil {
		return false
	}

	h2d.dictLock.Lock()

	if w.comp != nil {
		// Check again with the lock held; in theory the interface allows for unordered writes
		h2d.dictLock.Unlock()
		return false
	}

	// The logic of dictionary generation is below

	// Is there a dictionary for the exact path or content-type?
	var useID uint8
	pathID, pathFound := h2d.pathToDict[w.path]
	typeID, typeFound := h2d.typeToDict[w.contentType]

	if pathFound {
		// Use dictionary for path as top priority
		useID = pathID
		if !typeFound { // Shouldn't really happen, unless type changes between requests
			typeID, typeFound = h2d.getNextDictID()
			if typeFound {
				h2d.typeToDict[w.contentType] = typeID
			}
		}
	} else if typeFound {
		// Use dictionary for same content type as second priority
		useID = typeID
		pathID, pathFound = h2d.getNextDictID()
		if pathFound { // If a slot is available, generate new dictionary for path
			h2d.pathToDict[w.path] = pathID
		}
	} else {
		// Use the overflow dictionary as last resort
		// If slots are available, generate new dictionaries for path and content-type
		useID, _ = h2d.getGenericDictID()
		pathID, pathFound = h2d.getNextDictID()
		if pathFound {
			h2d.pathToDict[w.path] = pathID
		}
		typeID, typeFound = h2d.getNextDictID()
		if typeFound {
			h2d.typeToDict[w.contentType] = typeID
		}
	}

	useLen := h2d.maxSize
	if len(p) < useLen {
		useLen = len(p)
	}

	// Update all the dictionaries using the new data
	setDicts := make([]setDictRequest, 0, 3)
	setDict := setDictRequest{
		streamID: w.streamID,
		dictID:   useID,
		dictSZ:   uint64(useLen),
	}
	setDicts = append(setDicts, setDict)
	if pathID != useID {
		setDict.dictID = pathID
		setDicts = append(setDicts, setDict)
	}
	if typeID != useID {
		setDict.dictID = typeID
		setDicts = append(setDicts, setDict)
	}

	h2d.dictChan <- useDictRequest{streamID: w.streamID, dictID: uint8(useID), setDict: setDicts}

	dict := h2d.dictionaries[useID]

	// Brotli requires the dictionary to be immutable
	copyDict := make([]byte, len(dict))
	copy(copyDict, dict)

	for _, set := range setDicts {
		h2d.dictionaries[set.dictID] = adjustDictionary(h2d.dictionaries[set.dictID], p, set, h2d.maxSize)
	}

	w.comp = newCompressor(w.Buffer, h2d.quality, h2d.window)

	s.writeLock.Lock()
	h2d.dictLock.Unlock()

	if len(copyDict) > 0 {
		w.comp.SetDictionary(copyDict)
	}

	return true
}

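// Write compresses p into the underlying buffer when a compressor has been
// assigned, updating the pre- and post-compression byte counters; otherwise it
// writes p through unmodified.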
func (w *h2DictWriter) Write(p []byte) (n int, err error) {
	bufLen := w.Buffer.Len()
	if w.comp != nil {
		n, err = w.comp.Write(p)
		if err != nil {
			return
		}
		err = w.comp.Flush()
		w.dicts.compIn.IncrementBy(uint64(n))
		w.dicts.compOut.IncrementBy(uint64(w.Buffer.Len() - bufLen))
		return
	}
	return w.Buffer.Write(p)
}

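// Close closes the brotli compressor, if one was assigned to this writer.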
func (w *h2DictWriter) Close() error {
	if w.comp != nil {
		return w.comp.Close()
	}
	return nil
}

// http2ReadVarInt reads an HPACK-style variable-length integer with an n-bit
// prefix from p. Adapted from http2/hpack.
func http2ReadVarInt(n byte, p []byte) (remain []byte, v uint64, err error) {
	if n < 1 || n > 8 {
		panic("bad n")
	}
	if len(p) == 0 {
		return nil, 0, MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
	}
	v = uint64(p[0])
	if n < 8 {
		v &= (1 << uint64(n)) - 1
	}
	if v < (1<<uint64(n))-1 {
		return p[1:], v, nil
	}

	origP := p
	p = p[1:]
	var m uint64
	for len(p) > 0 {
		b := p[0]
		p = p[1:]
		v += uint64(b&127) << m
		if b&128 == 0 {
			return p, v, nil
		}
		m += 7
		if m >= 63 {
			return origP, 0, MuxerStreamError{"invalid integer", http2.ErrCodeProtocol}
		}
	}
	return nil, 0, MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
}

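// appendVarInt appends i to dst using the same n-bit-prefix variable-length
// integer encoding.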
func appendVarInt(dst []byte, n byte, i uint64) []byte {
	k := uint64((1 << n) - 1)
	if i < k {
		return append(dst, byte(i))
	}
	dst = append(dst, byte(k))
	i -= k
	for ; i >= 128; i >>= 7 {
		dst = append(dst, byte(0x80|(i&0x7f)))
	}
	return append(dst, byte(i))
}