header_test.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. package connection
  2. import (
  3. "fmt"
  4. "net/http"
  5. "reflect"
  6. "sort"
  7. "testing"
  8. "github.com/stretchr/testify/assert"
  9. )
  10. func TestSerializeHeaders(t *testing.T) {
  11. request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
  12. assert.NoError(t, err)
  13. mockHeaders := http.Header{
  14. "Mock-Header-One": {"Mock header one value", "three"},
  15. "Mock-Header-Two-Long": {"Mock header two value\nlong"},
  16. ":;": {":;", ";:"},
  17. ":": {":"},
  18. ";": {";"},
  19. ";;": {";;"},
  20. "Empty values": {"", ""},
  21. "": {"Empty key"},
  22. "control\tcharacter\b\n": {"value\n\b\t"},
  23. ";\v:": {":\v;"},
  24. }
  25. for header, values := range mockHeaders {
  26. for _, value := range values {
  27. // Note that Golang's http library is opinionated;
  28. // at this point every header name will be title-cased in order to comply with the HTTP RFC
  29. // This means our proxy is not completely transparent when it comes to proxying headers
  30. request.Header.Add(header, value)
  31. }
  32. }
  33. serializedHeaders := SerializeHeaders(request.Header)
  34. // Sanity check: the headers serialized to something that's not an empty string
  35. assert.NotEqual(t, "", serializedHeaders)
  36. // Deserialize back, and ensure we get the same set of headers
  37. deserializedHeaders, err := DeserializeHeaders(serializedHeaders)
  38. assert.NoError(t, err)
  39. assert.Equal(t, 13, len(deserializedHeaders))
  40. expectedHeaders := headerToReqHeader(mockHeaders)
  41. sort.Sort(ByName(deserializedHeaders))
  42. sort.Sort(ByName(expectedHeaders))
  43. assert.True(
  44. t,
  45. reflect.DeepEqual(expectedHeaders, deserializedHeaders),
  46. fmt.Sprintf("got = %#v, want = %#v\n", deserializedHeaders, expectedHeaders),
  47. )
  48. }
  49. type ByName []HTTPHeader
  50. func (a ByName) Len() int { return len(a) }
  51. func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
  52. func (a ByName) Less(i, j int) bool {
  53. if a[i].Name == a[j].Name {
  54. return a[i].Value < a[j].Value
  55. }
  56. return a[i].Name < a[j].Name
  57. }
  58. func headerToReqHeader(headers http.Header) (reqHeaders []HTTPHeader) {
  59. for name, values := range headers {
  60. for _, value := range values {
  61. reqHeaders = append(reqHeaders, HTTPHeader{Name: name, Value: value})
  62. }
  63. }
  64. return reqHeaders
  65. }
  66. func TestSerializeNoHeaders(t *testing.T) {
  67. request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
  68. assert.NoError(t, err)
  69. serializedHeaders := SerializeHeaders(request.Header)
  70. deserializedHeaders, err := DeserializeHeaders(serializedHeaders)
  71. assert.NoError(t, err)
  72. assert.Equal(t, 0, len(deserializedHeaders))
  73. }
  74. func TestDeserializeMalformed(t *testing.T) {
  75. var err error
  76. malformedData := []string{
  77. "malformed data",
  78. "bW9jawo=", // "mock"
  79. "bW9jawo=:ZGF0YQo=:bW9jawo=", // "mock:data:mock"
  80. "::",
  81. }
  82. for _, malformedValue := range malformedData {
  83. _, err = DeserializeHeaders(malformedValue)
  84. assert.Error(t, err)
  85. }
  86. }
  87. func TestIsControlResponseHeader(t *testing.T) {
  88. controlResponseHeaders := []string{
  89. // Anything that begins with cf-int- or cf-cloudflared-
  90. "cf-int-sample-header",
  91. "cf-cloudflared-sample-header",
  92. // Any http2 pseudoheader
  93. ":sample-pseudo-header",
  94. }
  95. for _, header := range controlResponseHeaders {
  96. assert.True(t, IsControlResponseHeader(header))
  97. }
  98. }
  99. func TestIsNotControlResponseHeader(t *testing.T) {
  100. notControlResponseHeaders := []string{
  101. "mock-header",
  102. "another-sample-header",
  103. "upgrade",
  104. "connection",
  105. "cf-whatever", // On the response path, we only want to filter cf-int- and cf-cloudflared-
  106. }
  107. for _, header := range notControlResponseHeaders {
  108. assert.False(t, IsControlResponseHeader(header))
  109. }
  110. }