header.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package connection
  2. import (
  3. "encoding/base64"
  4. "fmt"
  5. "net/http"
  6. "strings"
  7. "github.com/pkg/errors"
  8. )
  // Reserved header names used internally by cloudflared. All of them start
  // with "cf-cloudflared-", a prefix that IsControlResponseHeader classifies as
  // a control header.
  var (
  	RequestUserHeaders  = "cf-cloudflared-request-headers"
  	ResponseUserHeaders = "cf-cloudflared-response-headers"
  	ResponseMetaHeader  = "cf-cloudflared-response-meta"
  )

  // Canonical forms (as produced by http.CanonicalHeaderKey) of the internal
  // response header names, i.e. the spelling net/http uses for keys stored in
  // an http.Header through Set or Add.
  var (
  	CanonicalResponseUserHeaders = http.CanonicalHeaderKey(ResponseUserHeaders)
  	CanonicalResponseMetaHeader  = http.CanonicalHeaderKey(ResponseMetaHeader)
  )
  18. var (
  19. // pre-generate possible values for res
  20. responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared", false)
  21. responseMetaHeaderCfdFlowRateLimited = mustInitRespMetaHeader("cloudflared", true)
  22. responseMetaHeaderOrigin = mustInitRespMetaHeader("origin", false)
  23. )
  // HTTPHeader pairs a single header name with exactly one value; a header that
  // carries several values is represented by several HTTPHeader entries.
  // It is used when serializing headers to attach them to the HTTP/2 request
  // while proxying.
  type HTTPHeader struct {
  	Name, Value string
  }
  // responseMetaHeader is the payload that mustInitRespMetaHeader marshals to
  // JSON. With encoding/json-compatible marshalers the key order of the output
  // follows the field order, so reordering the fields would change the
  // pre-generated header values.
  type responseMetaHeader struct {
  	// Source is serialized under "src"; this package uses "cloudflared" and
  	// "origin".
  	Source string `json:"src"`
  	// FlowRateLimited is left out of the JSON entirely when false (omitempty).
  	FlowRateLimited bool `json:"flow_rate_limited,omitempty"`
  }
  34. func mustInitRespMetaHeader(src string, flowRateLimited bool) string {
  35. header, err := json.Marshal(responseMetaHeader{Source: src, FlowRateLimited: flowRateLimited})
  36. if err != nil {
  37. panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
  38. }
  39. return string(header)
  40. }
  // headerEncoding encodes serialized header names and values: the standard
  // base64 alphabet without '=' padding. Neither ':' nor ';' belongs to that
  // alphabet, so both are safe to use as delimiters in the serialized form.
  var headerEncoding *base64.Encoding = base64.RawStdEncoding
  // IsControlResponseHeader is called in the direction of eyeball <- origin. It
  // reports whether headerName is a control header rather than a user header:
  // a name beginning with ":" (HTTP/2 pseudo-header style), "cf-int-" or
  // "cf-cloudflared-". Matching is case-sensitive, so callers presumably
  // lower-case names before calling it.
  func IsControlResponseHeader(headerName string) bool {
  	switch {
  	case strings.HasPrefix(headerName, ":"),
  		strings.HasPrefix(headerName, "cf-int-"),
  		strings.HasPrefix(headerName, "cf-cloudflared-"):
  		return true
  	default:
  		return false
  	}
  }
  // IsWebsocketClientHeader reports whether headerName is one of the headers the
  // client needs to complete a WebSocket upgrade: "sec-websocket-accept",
  // "connection" or "upgrade". The comparison is exact and case-sensitive, so
  // headerName must already be lower-case for a match.
  func IsWebsocketClientHeader(headerName string) bool {
  	switch headerName {
  	case "sec-websocket-accept", "connection", "upgrade":
  		return true
  	}
  	return false
  }
  54. // Serialize HTTP1.x headers by base64-encoding each header name and value,
  55. // and then joining them in the format of [key:value;]
  56. func SerializeHeaders(h1Headers http.Header) string {
  57. // compute size of the fully serialized value and largest temp buffer we will need
  58. serializedLen := 0
  59. maxTempLen := 0
  60. for headerName, headerValues := range h1Headers {
  61. for _, headerValue := range headerValues {
  62. nameLen := headerEncoding.EncodedLen(len(headerName))
  63. valueLen := headerEncoding.EncodedLen(len(headerValue))
  64. const delims = 2
  65. serializedLen += delims + nameLen + valueLen
  66. if nameLen > maxTempLen {
  67. maxTempLen = nameLen
  68. }
  69. if valueLen > maxTempLen {
  70. maxTempLen = valueLen
  71. }
  72. }
  73. }
  74. var buf strings.Builder
  75. buf.Grow(serializedLen)
  76. temp := make([]byte, maxTempLen)
  77. writeB64 := func(s string) {
  78. n := headerEncoding.EncodedLen(len(s))
  79. if n > len(temp) {
  80. temp = make([]byte, n)
  81. }
  82. headerEncoding.Encode(temp[:n], []byte(s))
  83. buf.Write(temp[:n])
  84. }
  85. for headerName, headerValues := range h1Headers {
  86. for _, headerValue := range headerValues {
  87. if buf.Len() > 0 {
  88. buf.WriteByte(';')
  89. }
  90. writeB64(headerName)
  91. buf.WriteByte(':')
  92. writeB64(headerValue)
  93. }
  94. }
  95. return buf.String()
  96. }
  97. // Deserialize headers serialized by `SerializeHeader`
  98. func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) {
  99. const unableToDeserializeErr = "Unable to deserialize headers"
  100. deserialized := make([]HTTPHeader, 0)
  101. for _, serializedPair := range strings.Split(serializedHeaders, ";") {
  102. if len(serializedPair) == 0 {
  103. continue
  104. }
  105. serializedHeaderParts := strings.Split(serializedPair, ":")
  106. if len(serializedHeaderParts) != 2 {
  107. return nil, errors.New(unableToDeserializeErr)
  108. }
  109. serializedName := serializedHeaderParts[0]
  110. serializedValue := serializedHeaderParts[1]
  111. deserializedName := make([]byte, headerEncoding.DecodedLen(len(serializedName)))
  112. deserializedValue := make([]byte, headerEncoding.DecodedLen(len(serializedValue)))
  113. if _, err := headerEncoding.Decode(deserializedName, []byte(serializedName)); err != nil {
  114. return nil, errors.Wrap(err, unableToDeserializeErr)
  115. }
  116. if _, err := headerEncoding.Decode(deserializedValue, []byte(serializedValue)); err != nil {
  117. return nil, errors.Wrap(err, unableToDeserializeErr)
  118. }
  119. deserialized = append(deserialized, HTTPHeader{
  120. Name: string(deserializedName),
  121. Value: string(deserializedValue),
  122. })
  123. }
  124. return deserialized, nil
  125. }