123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- package carrier
- import (
- "context"
- "crypto/tls"
- "crypto/x509"
- "fmt"
- "math/rand"
- "testing"
- "time"
- gws "github.com/gorilla/websocket"
- "github.com/rs/zerolog"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "golang.org/x/net/websocket"
- "github.com/cloudflare/cloudflared/hello"
- "github.com/cloudflare/cloudflared/tlsconfig"
- cfwebsocket "github.com/cloudflare/cloudflared/websocket"
- )
- func websocketClientTLSConfig(t *testing.T) *tls.Config {
- certPool := x509.NewCertPool()
- helloCert, err := tlsconfig.GetHelloCertificateX509()
- assert.NoError(t, err)
- certPool.AddCert(helloCert)
- assert.NotNil(t, certPool)
- return &tls.Config{RootCAs: certPool}
- }
- func TestWebsocketHeaders(t *testing.T) {
- req := testRequest(t, "http://example.com", nil)
- wsHeaders := websocketHeaders(req)
- for _, header := range stripWebsocketHeaders {
- assert.Empty(t, wsHeaders[header])
- }
- assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
- }
- func TestServe(t *testing.T) {
- log := zerolog.Nop()
- shutdownC := make(chan struct{})
- errC := make(chan error)
- listener, err := hello.CreateTLSListener("localhost:1111")
- assert.NoError(t, err)
- defer listener.Close()
- go func() {
- errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
- }()
- req := testRequest(t, "https://localhost:1111/ws", nil)
- tlsConfig := websocketClientTLSConfig(t)
- assert.NotNil(t, tlsConfig)
- d := gws.Dialer{TLSClientConfig: tlsConfig}
- conn, resp, err := clientConnect(req, &d)
- assert.NoError(t, err)
- assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
- for i := 0; i < 1000; i++ {
- messageSize := rand.Int()%2048 + 1
- clientMessage := make([]byte, messageSize)
- // rand.Read always returns len(clientMessage) and a nil error
- rand.Read(clientMessage)
- err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
- assert.NoError(t, err)
- messageType, message, err := conn.ReadMessage()
- assert.NoError(t, err)
- assert.Equal(t, websocket.BinaryFrame, messageType)
- assert.Equal(t, clientMessage, message)
- }
- _ = conn.Close()
- close(shutdownC)
- <-errC
- }
- func TestWebsocketWrapper(t *testing.T) {
- listener, err := hello.CreateTLSListener("localhost:0")
- require.NoError(t, err)
- serverErrorChan := make(chan error)
- helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
- defer func() { <-serverErrorChan }()
- defer cancelHelloSvr()
- go func() {
- log := zerolog.Nop()
- serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
- }()
- tlsConfig := websocketClientTLSConfig(t)
- d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
- testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
- req := testRequest(t, testAddr, nil)
- conn, resp, err := clientConnect(req, &d)
- require.NoError(t, err)
- assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
- // Websocket now connected to test server so lets check our wrapper
- wrapper := cfwebsocket.GorillaConn{Conn: conn}
- buf := make([]byte, 100)
- wrapper.Write([]byte("abc"))
- n, err := wrapper.Read(buf)
- require.NoError(t, err)
- require.Equal(t, n, 3)
- require.Equal(t, "abc", string(buf[:n]))
- // Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
- wrapper.Write([]byte("abc"))
- buf = buf[:1]
- n, err = wrapper.Read(buf)
- require.NoError(t, err)
- require.Equal(t, n, 1)
- require.Equal(t, "a", string(buf[:n]))
- buf = buf[:cap(buf)]
- n, err = wrapper.Read(buf)
- require.NoError(t, err)
- require.Equal(t, n, 2)
- require.Equal(t, "bc", string(buf[:n]))
- }
|