123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373 |
- package websocketconn
- import (
- "bytes"
- "fmt"
- "io"
- "net"
- "net/http"
- "net/url"
- "sync"
- "testing"
- "time"
- "github.com/gorilla/websocket"
- )
- // Returns a (server, client) pair of websocketconn.Conns.
- func connPair() (*Conn, *Conn, error) {
- // Will be assigned inside server.Handler.
- var serverConn *Conn
- // Start up a web server to receive the request.
- ln, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- return nil, nil, err
- }
- defer ln.Close()
- errCh := make(chan error)
- server := http.Server{
- Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
- upgrader := websocket.Upgrader{
- CheckOrigin: func(*http.Request) bool { return true },
- }
- ws, err := upgrader.Upgrade(rw, req, nil)
- if err != nil {
- errCh <- err
- return
- }
- serverConn = New(ws)
- close(errCh)
- }),
- }
- defer server.Close()
- go func() {
- err := server.Serve(ln)
- if err != nil && err != http.ErrServerClosed {
- errCh <- err
- }
- }()
- // Make a request to the web server.
- urlStr := (&url.URL{Scheme: "ws", Host: ln.Addr().String()}).String()
- ws, _, err := (&websocket.Dialer{}).Dial(urlStr, nil)
- if err != nil {
- return nil, nil, err
- }
- clientConn := New(ws)
- // The server is finished when errCh is written to or closed.
- err = <-errCh
- if err != nil {
- return nil, nil, err
- }
- return serverConn, clientConn, nil
- }
- // Test that you can write in chunks and read the result concatenated.
- func TestWrite(t *testing.T) {
- tests := [][][]byte{
- {},
- {[]byte("foo")},
- {[]byte("foo"), []byte("bar")},
- {{}, []byte("foo"), {}, {}, []byte("bar")},
- }
- for _, test := range tests {
- s, c, err := connPair()
- if err != nil {
- t.Fatal(err)
- }
- // This is a little awkward because we need to read to and write
- // from both ends of the Conn, and we need to do it in separate
- // goroutines because otherwise a Write may block waiting for
- // someone to Read it. Here we set up a loop in a separate
- // goroutine, reading from the Conn s and writing to the dataCh
- // and errCh channels, whose ultimate effect in the select loop
- // below is like
- // data, err := io.ReadAll(s)
- dataCh := make(chan []byte)
- errCh := make(chan error)
- go func() {
- for {
- var buf [1024]byte
- n, err := s.Read(buf[:])
- if err != nil {
- errCh <- err
- return
- }
- p := make([]byte, n)
- copy(p, buf[:])
- dataCh <- p
- }
- }()
- // Write the data to the client side of the Conn, one chunk at a
- // time.
- for i, chunk := range test {
- n, err := c.Write(chunk)
- if err != nil || n != len(chunk) {
- t.Fatalf("%+q Write chunk %d: got (%d, %v), expected (%d, %v)",
- test, i, n, err, len(chunk), nil)
- }
- }
- // We cannot immediately c.Close here, because that closes the
- // connection right away, without waiting for buffered data to
- // be sent.
- // Pull data and err from the server goroutine above.
- var data []byte
- err = nil
- loop:
- for {
- select {
- case p := <-dataCh:
- data = append(data, p...)
- case err = <-errCh:
- break loop
- case <-time.After(100 * time.Millisecond):
- break loop
- }
- }
- s.Close()
- c.Close()
- // Now data and err contain the result of reading everything
- // from s.
- expected := bytes.Join(test, []byte{})
- if err != nil || !bytes.Equal(data, expected) {
- t.Fatalf("%+q ReadAll: got (%+q, %v), expected (%+q, %v)",
- test, data, err, expected, nil)
- }
- }
- }
- // Test that multiple goroutines may call Read on a Conn simultaneously. Run
- // this with
- //
- // go test -race
- func TestConcurrentRead(t *testing.T) {
- s, c, err := connPair()
- if err != nil {
- t.Fatal(err)
- }
- defer s.Close()
- // Set up multiple threads reading from the same conn.
- errCh := make(chan error, 2)
- var wg sync.WaitGroup
- wg.Add(2)
- for i := 0; i < 2; i++ {
- go func() {
- defer wg.Done()
- _, err := io.Copy(io.Discard, s)
- if err != nil {
- errCh <- err
- }
- }()
- }
- // Write a bunch of data to the other end.
- for i := 0; i < 2000; i++ {
- _, err := fmt.Fprintf(c, "%d", i)
- if err != nil {
- c.Close()
- t.Fatalf("Write: %v", err)
- }
- }
- c.Close()
- wg.Wait()
- close(errCh)
- err = <-errCh
- if err != nil {
- t.Fatalf("Read: %v", err)
- }
- }
- // Test that multiple goroutines may call Write on a Conn simultaneously. Run
- // this with
- //
- // go test -race
- func TestConcurrentWrite(t *testing.T) {
- s, c, err := connPair()
- if err != nil {
- t.Fatal(err)
- }
- // Set up multiple threads writing to the same conn.
- errCh := make(chan error, 3)
- var wg sync.WaitGroup
- wg.Add(2)
- for i := 0; i < 2; i++ {
- go func() {
- defer wg.Done()
- for j := 0; j < 1000; j++ {
- _, err := fmt.Fprintf(s, "%d", j)
- if err != nil {
- errCh <- err
- break
- }
- }
- }()
- }
- go func() {
- wg.Wait()
- err := s.Close()
- if err != nil {
- errCh <- err
- }
- close(errCh)
- }()
- // Read from the other end.
- _, err = io.Copy(io.Discard, c)
- c.Close()
- if err != nil {
- t.Fatalf("Read: %v", err)
- }
- err = <-errCh
- if err != nil {
- t.Fatalf("Write: %v", err)
- }
- }
- // Test that Read and Write methods return errors after Close.
- func TestClose(t *testing.T) {
- s, c, err := connPair()
- if err != nil {
- t.Fatal(err)
- }
- defer c.Close()
- err = s.Close()
- if err != nil {
- t.Fatal(err)
- }
- var buf [10]byte
- n, err := s.Read(buf[:])
- if n != 0 || err == nil {
- t.Fatalf("Read after Close returned (%v, %v), expected (%v, non-nil)", n, err, 0)
- }
- _, err = s.Write([]byte{1, 2, 3})
- // Here we break the abstraction a little and look for a specific error,
- // io.ErrClosedPipe. This is because we know the Conn uses an io.Pipe
- // internally.
- if err != io.ErrClosedPipe {
- t.Fatalf("Write after Close returned %v, expected %v", err, io.ErrClosedPipe)
- }
- }
- // Benchmark creating a server websocket.Conn (without the websocketconn.Conn
- // wrapper) for different read/write buffer sizes.
- func BenchmarkUpgradeBufferSize(b *testing.B) {
- // Buffer size of 0 would mean the default of 4096:
- // https://github.com/gorilla/websocket/blob/v1.5.0/conn.go#L37
- // But a size of zero also has the effect of causing reuse of the HTTP
- // server's buffers. So we test 4096 separately from 0.
- // https://github.com/gorilla/websocket/blob/v1.5.0/server.go#L32
- for _, bufSize := range []int{0, 128, 1024, 2048, 4096, 8192} {
- upgrader := websocket.Upgrader{
- CheckOrigin: func(*http.Request) bool { return true },
- ReadBufferSize: bufSize,
- WriteBufferSize: bufSize,
- }
- b.Run(fmt.Sprintf("%d", bufSize), func(b *testing.B) {
- // Start up a web server to receive the request.
- ln, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- b.Fatal(err)
- }
- defer ln.Close()
- wsCh := make(chan *websocket.Conn)
- errCh := make(chan error)
- server := http.Server{
- Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
- ws, err := upgrader.Upgrade(rw, req, nil)
- if err != nil {
- errCh <- err
- return
- }
- wsCh <- ws
- }),
- }
- defer server.Close()
- go func() {
- err := server.Serve(ln)
- if err != nil && err != http.ErrServerClosed {
- errCh <- err
- }
- }()
- // Make a request to the web server.
- dialer := &websocket.Dialer{
- ReadBufferSize: bufSize,
- WriteBufferSize: bufSize,
- }
- urlStr := (&url.URL{Scheme: "ws", Host: ln.Addr().String()}).String()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- ws, _, err := dialer.Dial(urlStr, nil)
- if err != nil {
- b.Fatal(err)
- }
- ws.Close()
- select {
- case <-wsCh:
- case err := <-errCh:
- b.Fatal(err)
- }
- }
- b.StopTimer()
- })
- }
- }
- // Benchmark read/write in the client←server and server←client directions, with
- // messages of different sizes. Run with -benchmem to see memory allocations.
- func BenchmarkReadWrite(b *testing.B) {
- trial := func(b *testing.B, readConn, writeConn *Conn, msgSize int) {
- go func() {
- io.Copy(io.Discard, readConn)
- }()
- data := make([]byte, msgSize)
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- n, err := writeConn.Write(data[:])
- b.SetBytes(int64(n))
- if err != nil {
- b.Fatal(err)
- }
- }
- }
- for _, msgSize := range []int{150, 3000} {
- s, c, err := connPair()
- if err != nil {
- b.Fatal(err)
- }
- b.Run(fmt.Sprintf("c←s %d", msgSize), func(b *testing.B) {
- trial(b, c, s, msgSize)
- })
- b.Run(fmt.Sprintf("s←c %d", msgSize), func(b *testing.B) {
- trial(b, s, c, msgSize)
- })
- err = s.Close()
- if err != nil {
- b.Fatal(err)
- }
- err = c.Close()
- if err != nil {
- b.Fatal(err)
- }
- }
- }
|