h2mux_test.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. package connection
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net"
  7. "net/http"
  8. "strconv"
  9. "sync"
  10. "testing"
  11. "time"
  12. "github.com/gobwas/ws/wsutil"
  13. "github.com/rs/zerolog"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/stretchr/testify/require"
  16. "github.com/cloudflare/cloudflared/h2mux"
  17. )
  18. var (
  19. testMuxerConfig = &MuxerConfig{
  20. HeartbeatInterval: time.Second * 5,
  21. MaxHeartbeats: 5,
  22. CompressionSetting: 0,
  23. MetricsUpdateFreq: time.Second * 5,
  24. }
  25. )
  26. func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
  27. edgeConn, originConn := net.Pipe()
  28. edgeMuxChan := make(chan *h2mux.Muxer)
  29. go func() {
  30. edgeMuxConfig := h2mux.MuxerConfig{
  31. Log: &log,
  32. Handler: h2mux.MuxedStreamFunc(func(stream *h2mux.MuxedStream) error {
  33. // we only expect RPC traffic in client->edge direction, provide minimal support for mocking
  34. require.True(t, stream.IsRPCStream())
  35. return stream.WriteHeaders([]h2mux.Header{
  36. {Name: ":status", Value: "200"},
  37. })
  38. }),
  39. }
  40. edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams)
  41. require.NoError(t, err)
  42. edgeMuxChan <- edgeMux
  43. }()
  44. var connIndex = uint8(0)
  45. testObserver := NewObserver(&log, &log, false)
  46. h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver, nil)
  47. require.NoError(t, err)
  48. return h2muxConn, <-edgeMuxChan
  49. }
  50. func TestServeStreamHTTP(t *testing.T) {
  51. tests := []testRequest{
  52. {
  53. name: "ok",
  54. endpoint: "/ok",
  55. expectedStatus: http.StatusOK,
  56. expectedBody: []byte(http.StatusText(http.StatusOK)),
  57. },
  58. {
  59. name: "large_file",
  60. endpoint: "/large_file",
  61. expectedStatus: http.StatusOK,
  62. expectedBody: testLargeResp,
  63. },
  64. {
  65. name: "Bad request",
  66. endpoint: "/400",
  67. expectedStatus: http.StatusBadRequest,
  68. expectedBody: []byte(http.StatusText(http.StatusBadRequest)),
  69. },
  70. {
  71. name: "Internal server error",
  72. endpoint: "/500",
  73. expectedStatus: http.StatusInternalServerError,
  74. expectedBody: []byte(http.StatusText(http.StatusInternalServerError)),
  75. },
  76. {
  77. name: "Proxy error",
  78. endpoint: "/error",
  79. expectedStatus: http.StatusBadGateway,
  80. expectedBody: nil,
  81. isProxyError: true,
  82. },
  83. }
  84. ctx, cancel := context.WithCancel(context.Background())
  85. h2muxConn, edgeMux := newH2MuxConnection(t)
  86. var wg sync.WaitGroup
  87. wg.Add(2)
  88. go func() {
  89. defer wg.Done()
  90. _ = edgeMux.Serve(ctx)
  91. }()
  92. go func() {
  93. defer wg.Done()
  94. err := h2muxConn.serveMuxer(ctx)
  95. require.Error(t, err)
  96. }()
  97. for _, test := range tests {
  98. headers := []h2mux.Header{
  99. {
  100. Name: ":path",
  101. Value: test.endpoint,
  102. },
  103. }
  104. stream, err := edgeMux.OpenStream(ctx, headers, nil)
  105. require.NoError(t, err)
  106. require.True(t, hasHeader(stream, ":status", strconv.Itoa(test.expectedStatus)))
  107. if test.isProxyError {
  108. assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderCfd))
  109. } else {
  110. assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
  111. body := make([]byte, len(test.expectedBody))
  112. _, err = stream.Read(body)
  113. require.NoError(t, err)
  114. require.Equal(t, test.expectedBody, body)
  115. }
  116. }
  117. cancel()
  118. wg.Wait()
  119. }
  120. func TestServeStreamWS(t *testing.T) {
  121. ctx, cancel := context.WithCancel(context.Background())
  122. h2muxConn, edgeMux := newH2MuxConnection(t)
  123. var wg sync.WaitGroup
  124. wg.Add(2)
  125. go func() {
  126. defer wg.Done()
  127. edgeMux.Serve(ctx)
  128. }()
  129. go func() {
  130. defer wg.Done()
  131. err := h2muxConn.serveMuxer(ctx)
  132. require.Error(t, err)
  133. }()
  134. headers := []h2mux.Header{
  135. {
  136. Name: ":path",
  137. Value: "/ws",
  138. },
  139. {
  140. Name: "connection",
  141. Value: "upgrade",
  142. },
  143. {
  144. Name: "upgrade",
  145. Value: "websocket",
  146. },
  147. }
  148. readPipe, writePipe := io.Pipe()
  149. stream, err := edgeMux.OpenStream(ctx, headers, readPipe)
  150. require.NoError(t, err)
  151. require.True(t, hasHeader(stream, ":status", strconv.Itoa(http.StatusSwitchingProtocols)))
  152. assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
  153. data := []byte("test websocket")
  154. err = wsutil.WriteClientText(writePipe, data)
  155. require.NoError(t, err)
  156. respBody, err := wsutil.ReadServerText(stream)
  157. require.NoError(t, err)
  158. require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
  159. cancel()
  160. wg.Wait()
  161. }
  162. func TestGracefulShutdownH2Mux(t *testing.T) {
  163. ctx, cancel := context.WithCancel(context.Background())
  164. defer cancel()
  165. h2muxConn, edgeMux := newH2MuxConnection(t)
  166. shutdownC := make(chan struct{})
  167. unregisteredC := make(chan struct{})
  168. h2muxConn.gracefulShutdownC = shutdownC
  169. h2muxConn.newRPCClientFunc = func(_ context.Context, _ io.ReadWriteCloser, _ *zerolog.Logger) NamedTunnelRPCClient {
  170. return &mockNamedTunnelRPCClient{
  171. registered: nil,
  172. unregistered: unregisteredC,
  173. }
  174. }
  175. var wg sync.WaitGroup
  176. wg.Add(3)
  177. go func() {
  178. defer wg.Done()
  179. _ = edgeMux.Serve(ctx)
  180. }()
  181. go func() {
  182. defer wg.Done()
  183. _ = h2muxConn.serveMuxer(ctx)
  184. }()
  185. go func() {
  186. defer wg.Done()
  187. h2muxConn.controlLoop(ctx, &mockConnectedFuse{}, true)
  188. }()
  189. time.Sleep(100 * time.Millisecond)
  190. close(shutdownC)
  191. select {
  192. case <-unregisteredC:
  193. break // ok
  194. case <-time.Tick(time.Second):
  195. assert.Fail(t, "timed out waiting for control loop to unregister")
  196. }
  197. cancel()
  198. wg.Wait()
  199. assert.True(t, h2muxConn.stoppedGracefully)
  200. assert.Nil(t, h2muxConn.gracefulShutdownC)
  201. }
  202. func hasHeader(stream *h2mux.MuxedStream, name, val string) bool {
  203. for _, header := range stream.Headers {
  204. if header.Name == name && header.Value == val {
  205. return true
  206. }
  207. }
  208. return false
  209. }
  210. func benchmarkServeStreamHTTPSimple(b *testing.B, test testRequest) {
  211. ctx, cancel := context.WithCancel(context.Background())
  212. h2muxConn, edgeMux := newH2MuxConnection(b)
  213. var wg sync.WaitGroup
  214. wg.Add(2)
  215. go func() {
  216. defer wg.Done()
  217. edgeMux.Serve(ctx)
  218. }()
  219. go func() {
  220. defer wg.Done()
  221. err := h2muxConn.serveMuxer(ctx)
  222. require.Error(b, err)
  223. }()
  224. headers := []h2mux.Header{
  225. {
  226. Name: ":path",
  227. Value: test.endpoint,
  228. },
  229. }
  230. body := make([]byte, len(test.expectedBody))
  231. b.ResetTimer()
  232. for i := 0; i < b.N; i++ {
  233. b.StartTimer()
  234. stream, openstreamErr := edgeMux.OpenStream(ctx, headers, nil)
  235. _, readBodyErr := stream.Read(body)
  236. b.StopTimer()
  237. require.NoError(b, openstreamErr)
  238. assert.True(b, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
  239. require.True(b, hasHeader(stream, ":status", strconv.Itoa(http.StatusOK)))
  240. require.NoError(b, readBodyErr)
  241. require.Equal(b, test.expectedBody, body)
  242. }
  243. cancel()
  244. wg.Wait()
  245. }
  246. func BenchmarkServeStreamHTTPSimple(b *testing.B) {
  247. test := testRequest{
  248. name: "ok",
  249. endpoint: "/ok",
  250. expectedStatus: http.StatusOK,
  251. expectedBody: []byte(http.StatusText(http.StatusOK)),
  252. }
  253. benchmarkServeStreamHTTPSimple(b, test)
  254. }
  255. func BenchmarkServeStreamHTTPLargeFile(b *testing.B) {
  256. test := testRequest{
  257. name: "large_file",
  258. endpoint: "/large_file",
  259. expectedStatus: http.StatusOK,
  260. expectedBody: testLargeResp,
  261. }
  262. benchmarkServeStreamHTTPSimple(b, test)
  263. }