websocket_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package carrier
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "fmt"
  7. "math/rand"
  8. "testing"
  9. "time"
  10. gws "github.com/gorilla/websocket"
  11. "github.com/rs/zerolog"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. "golang.org/x/net/websocket"
  15. "github.com/cloudflare/cloudflared/hello"
  16. "github.com/cloudflare/cloudflared/tlsconfig"
  17. cfwebsocket "github.com/cloudflare/cloudflared/websocket"
  18. )
  19. func websocketClientTLSConfig(t *testing.T) *tls.Config {
  20. certPool := x509.NewCertPool()
  21. helloCert, err := tlsconfig.GetHelloCertificateX509()
  22. assert.NoError(t, err)
  23. certPool.AddCert(helloCert)
  24. assert.NotNil(t, certPool)
  25. return &tls.Config{RootCAs: certPool}
  26. }
  27. func TestWebsocketHeaders(t *testing.T) {
  28. req := testRequest(t, "http://example.com", nil)
  29. wsHeaders := websocketHeaders(req)
  30. for _, header := range stripWebsocketHeaders {
  31. assert.Empty(t, wsHeaders[header])
  32. }
  33. assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
  34. }
  35. func TestServe(t *testing.T) {
  36. log := zerolog.Nop()
  37. shutdownC := make(chan struct{})
  38. errC := make(chan error)
  39. listener, err := hello.CreateTLSListener("localhost:1111")
  40. assert.NoError(t, err)
  41. defer listener.Close()
  42. go func() {
  43. errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
  44. }()
  45. req := testRequest(t, "https://localhost:1111/ws", nil)
  46. tlsConfig := websocketClientTLSConfig(t)
  47. assert.NotNil(t, tlsConfig)
  48. d := gws.Dialer{TLSClientConfig: tlsConfig}
  49. conn, resp, err := clientConnect(req, &d)
  50. assert.NoError(t, err)
  51. assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
  52. for i := 0; i < 1000; i++ {
  53. messageSize := rand.Int()%2048 + 1
  54. clientMessage := make([]byte, messageSize)
  55. // rand.Read always returns len(clientMessage) and a nil error
  56. rand.Read(clientMessage)
  57. err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
  58. assert.NoError(t, err)
  59. messageType, message, err := conn.ReadMessage()
  60. assert.NoError(t, err)
  61. assert.Equal(t, websocket.BinaryFrame, messageType)
  62. assert.Equal(t, clientMessage, message)
  63. }
  64. _ = conn.Close()
  65. close(shutdownC)
  66. <-errC
  67. }
  68. func TestWebsocketWrapper(t *testing.T) {
  69. listener, err := hello.CreateTLSListener("localhost:0")
  70. require.NoError(t, err)
  71. serverErrorChan := make(chan error)
  72. helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
  73. defer func() { <-serverErrorChan }()
  74. defer cancelHelloSvr()
  75. go func() {
  76. log := zerolog.Nop()
  77. serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
  78. }()
  79. tlsConfig := websocketClientTLSConfig(t)
  80. d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
  81. testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
  82. req := testRequest(t, testAddr, nil)
  83. conn, resp, err := clientConnect(req, &d)
  84. require.NoError(t, err)
  85. assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
  86. // Websocket now connected to test server so lets check our wrapper
  87. wrapper := cfwebsocket.GorillaConn{Conn: conn}
  88. buf := make([]byte, 100)
  89. wrapper.Write([]byte("abc"))
  90. n, err := wrapper.Read(buf)
  91. require.NoError(t, err)
  92. require.Equal(t, n, 3)
  93. require.Equal(t, "abc", string(buf[:n]))
  94. // Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
  95. wrapper.Write([]byte("abc"))
  96. buf = buf[:1]
  97. n, err = wrapper.Read(buf)
  98. require.NoError(t, err)
  99. require.Equal(t, n, 1)
  100. require.Equal(t, "a", string(buf[:n]))
  101. buf = buf[:cap(buf)]
  102. n, err = wrapper.Read(buf)
  103. require.NoError(t, err)
  104. require.Equal(t, n, 2)
  105. require.Equal(t, "bc", string(buf[:n]))
  106. }