// muxreader_test.go
  1. package h2mux
  2. import (
  3. "context"
  4. "testing"
  5. "time"
  6. "github.com/stretchr/testify/assert"
  7. )
  8. var (
  9. methodHeader = Header{
  10. Name: ":method",
  11. Value: "GET",
  12. }
  13. schemeHeader = Header{
  14. Name: ":scheme",
  15. Value: "https",
  16. }
  17. pathHeader = Header{
  18. Name: ":path",
  19. Value: "/api/tunnels",
  20. }
  21. tunnelHostnameHeader = Header{
  22. Name: CloudflaredProxyTunnelHostnameHeader,
  23. Value: "tunnel.example.com",
  24. }
  25. respStatusHeader = Header{
  26. Name: ":status",
  27. Value: "200",
  28. }
  29. )
  30. type mockOriginStreamHandler struct {
  31. stream *MuxedStream
  32. }
  33. func (mosh *mockOriginStreamHandler) ServeStream(stream *MuxedStream) error {
  34. mosh.stream = stream
  35. // Echo tunnel hostname in header
  36. stream.WriteHeaders([]Header{respStatusHeader})
  37. return nil
  38. }
  39. func getCloudflaredProxyTunnelHostnameHeader(stream *MuxedStream) string {
  40. for _, header := range stream.Headers {
  41. if header.Name == CloudflaredProxyTunnelHostnameHeader {
  42. return header.Value
  43. }
  44. }
  45. return ""
  46. }
  47. func assertOpenStreamSucceed(t *testing.T, stream *MuxedStream, err error) {
  48. assert.NoError(t, err)
  49. assert.Len(t, stream.Headers, 1)
  50. assert.Equal(t, respStatusHeader, stream.Headers[0])
  51. }
  52. func TestMissingHeaders(t *testing.T) {
  53. originHandler := &mockOriginStreamHandler{}
  54. muxPair := NewDefaultMuxerPair(t, t.Name(), originHandler.ServeStream)
  55. muxPair.Serve(t)
  56. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  57. defer cancel()
  58. reqHeaders := []Header{
  59. {
  60. Name: "content-type",
  61. Value: "application/json",
  62. },
  63. }
  64. // Request doesn't contain CloudflaredProxyTunnelHostnameHeader
  65. stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil)
  66. assertOpenStreamSucceed(t, stream, err)
  67. assert.Empty(t, originHandler.stream.method)
  68. assert.Empty(t, originHandler.stream.path)
  69. assert.False(t, originHandler.stream.TunnelHostname().IsSet())
  70. }
  71. func TestReceiveHeaderData(t *testing.T) {
  72. originHandler := &mockOriginStreamHandler{}
  73. muxPair := NewDefaultMuxerPair(t, t.Name(), originHandler.ServeStream)
  74. muxPair.Serve(t)
  75. reqHeaders := []Header{
  76. methodHeader,
  77. schemeHeader,
  78. pathHeader,
  79. tunnelHostnameHeader,
  80. }
  81. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  82. defer cancel()
  83. reqHeaders = append(reqHeaders, tunnelHostnameHeader)
  84. stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil)
  85. assertOpenStreamSucceed(t, stream, err)
  86. assert.Equal(t, methodHeader.Value, originHandler.stream.method)
  87. assert.Equal(t, pathHeader.Value, originHandler.stream.path)
  88. assert.True(t, originHandler.stream.TunnelHostname().IsSet())
  89. assert.Equal(t, tunnelHostnameHeader.Value, originHandler.stream.TunnelHostname().String())
  90. }