header.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package connection
  2. import (
  3. "encoding/base64"
  4. "fmt"
  5. "net/http"
  6. "strings"
  7. "github.com/pkg/errors"
  8. )
  var (
	// Internal special header names. All of them start with the
	// "cf-cloudflared-" prefix, which IsControlResponseHeader treats as a
	// control header.
	RequestUserHeaders  = "cf-cloudflared-request-headers"
	ResponseUserHeaders = "cf-cloudflared-response-headers"
	ResponseMetaHeader  = "cf-cloudflared-response-meta"

	// Canonical forms (per http.CanonicalHeaderKey) of the response header
	// names above, i.e. the keys http.Header.Set/Get use in the header map.
	CanonicalResponseUserHeaders = http.CanonicalHeaderKey(ResponseUserHeaders)
	CanonicalResponseMetaHeader  = http.CanonicalHeaderKey(ResponseMetaHeader)
)
  18. var (
  19. // pre-generate possible values for res
  20. responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
  21. responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
  22. )
  // HTTPHeader is a single header name/value pair. Unlike http.Header, an entry
// carries exactly one value; a header with several values is represented by
// several HTTPHeader entries. It is the form in which serialized headers are
// attached to the HTTP/2 request when proxying.
type HTTPHeader struct {
	Name, Value string
}
  // responseMetaHeader is the JSON payload of the response meta header. Source
// records where the response came from (this file uses "cloudflared" and
// "origin") and is encoded under the "src" key.
type responseMetaHeader struct{ Source string `json:"src"` }
  32. func mustInitRespMetaHeader(src string) string {
  33. header, err := json.Marshal(responseMetaHeader{Source: src})
  34. if err != nil {
  35. panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
  36. }
  37. return string(header)
  38. }
  var (
	// headerEncoding is the base64 variant used by SerializeHeaders and
	// DeserializeHeaders: the standard alphabet without '=' padding. Its output
	// never contains ':' or ';', so those characters can serve as delimiters
	// in the serialized form.
	headerEncoding = base64.RawStdEncoding
)
  // IsControlResponseHeader reports whether headerName starts with one of the
// reserved prefixes ":" (e.g. HTTP/2 pseudo-headers such as ":status"),
// "cf-int-" or "cf-cloudflared-". It is called in the direction of
// eyeball <- origin. Matching is case-sensitive, so a canonicalized name like
// "Cf-Int-Foo" does not match; callers should pass lowercased names.
func IsControlResponseHeader(headerName string) bool {
	for _, reserved := range [...]string{":", "cf-int-", "cf-cloudflared-"} {
		if strings.HasPrefix(headerName, reserved) {
			return true
		}
	}
	return false
}
  // IsWebsocketClientHeader reports whether headerName is one of the headers the
// client needs to complete a WebSocket upgrade: "sec-websocket-accept",
// "connection" or "upgrade". The comparison is exact and case-sensitive, so
// callers should pass lowercased names.
func IsWebsocketClientHeader(headerName string) bool {
	switch headerName {
	case "sec-websocket-accept", "connection", "upgrade":
		return true
	default:
		return false
	}
}
  52. // Serialize HTTP1.x headers by base64-encoding each header name and value,
  53. // and then joining them in the format of [key:value;]
  54. func SerializeHeaders(h1Headers http.Header) string {
  55. // compute size of the fully serialized value and largest temp buffer we will need
  56. serializedLen := 0
  57. maxTempLen := 0
  58. for headerName, headerValues := range h1Headers {
  59. for _, headerValue := range headerValues {
  60. nameLen := headerEncoding.EncodedLen(len(headerName))
  61. valueLen := headerEncoding.EncodedLen(len(headerValue))
  62. const delims = 2
  63. serializedLen += delims + nameLen + valueLen
  64. if nameLen > maxTempLen {
  65. maxTempLen = nameLen
  66. }
  67. if valueLen > maxTempLen {
  68. maxTempLen = valueLen
  69. }
  70. }
  71. }
  72. var buf strings.Builder
  73. buf.Grow(serializedLen)
  74. temp := make([]byte, maxTempLen)
  75. writeB64 := func(s string) {
  76. n := headerEncoding.EncodedLen(len(s))
  77. if n > len(temp) {
  78. temp = make([]byte, n)
  79. }
  80. headerEncoding.Encode(temp[:n], []byte(s))
  81. buf.Write(temp[:n])
  82. }
  83. for headerName, headerValues := range h1Headers {
  84. for _, headerValue := range headerValues {
  85. if buf.Len() > 0 {
  86. buf.WriteByte(';')
  87. }
  88. writeB64(headerName)
  89. buf.WriteByte(':')
  90. writeB64(headerValue)
  91. }
  92. }
  93. return buf.String()
  94. }
  95. // Deserialize headers serialized by `SerializeHeader`
  96. func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) {
  97. const unableToDeserializeErr = "Unable to deserialize headers"
  98. var deserialized []HTTPHeader
  99. for _, serializedPair := range strings.Split(serializedHeaders, ";") {
  100. if len(serializedPair) == 0 {
  101. continue
  102. }
  103. serializedHeaderParts := strings.Split(serializedPair, ":")
  104. if len(serializedHeaderParts) != 2 {
  105. return nil, errors.New(unableToDeserializeErr)
  106. }
  107. serializedName := serializedHeaderParts[0]
  108. serializedValue := serializedHeaderParts[1]
  109. deserializedName := make([]byte, headerEncoding.DecodedLen(len(serializedName)))
  110. deserializedValue := make([]byte, headerEncoding.DecodedLen(len(serializedValue)))
  111. if _, err := headerEncoding.Decode(deserializedName, []byte(serializedName)); err != nil {
  112. return nil, errors.Wrap(err, unableToDeserializeErr)
  113. }
  114. if _, err := headerEncoding.Decode(deserializedValue, []byte(serializedValue)); err != nil {
  115. return nil, errors.Wrap(err, unableToDeserializeErr)
  116. }
  117. deserialized = append(deserialized, HTTPHeader{
  118. Name: string(deserializedName),
  119. Value: string(deserializedValue),
  120. })
  121. }
  122. return deserialized, nil
  123. }