connection.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. package websocket
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "sync"
  9. "time"
  10. gobwas "github.com/gobwas/ws"
  11. "github.com/gobwas/ws/wsutil"
  12. "github.com/gorilla/websocket"
  13. "github.com/rs/zerolog"
  14. )
  // PingPeriodContext is the type of the context key used to override the ping
// interval. Using a dedicated key type, as the context package recommends,
// keeps it from colliding with context keys defined by other packages.
type PingPeriodContext string

const (
	// defaultPongWait is the time allowed to read the next pong message from the peer.
	defaultPongWait = 60 * time.Second
	// defaultPingPeriod is how often pings are sent to the peer. It must be
	// shorter than defaultPongWait, so it is fixed at 90% of that value.
	defaultPingPeriod = defaultPongWait * 9 / 10

	// PingPeriodContextKey is the context key under which a time.Duration
	// overriding defaultPingPeriod may be stored.
	PingPeriodContextKey = PingPeriodContext("pingPeriod")
)
  23. // GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter
  24. // This is still used by access carrier
  25. type GorillaConn struct {
  26. *websocket.Conn
  27. log *zerolog.Logger
  28. readBuf bytes.Buffer
  29. }
  30. // Read will read messages from the websocket connection
  31. func (c *GorillaConn) Read(p []byte) (int, error) {
  32. // Intermediate buffer may contain unread bytes from the last read, start there before blocking on a new frame
  33. if c.readBuf.Len() > 0 {
  34. return c.readBuf.Read(p)
  35. }
  36. _, message, err := c.Conn.ReadMessage()
  37. if err != nil {
  38. return 0, err
  39. }
  40. copied := copy(p, message)
  41. // Write unread bytes to readBuf; if everything was read this is a no-op
  42. // Write returns a nil error always and grows the buffer; everything is always written or panic
  43. c.readBuf.Write(message[copied:])
  44. return copied, nil
  45. }
  46. // Write will write messages to the websocket connection
  47. func (c *GorillaConn) Write(p []byte) (int, error) {
  48. if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
  49. return 0, err
  50. }
  51. return len(p), nil
  52. }
  53. // SetDeadline sets both read and write deadlines, as per net.Conn interface docs:
  54. // "It is equivalent to calling both SetReadDeadline and SetWriteDeadline."
  55. // Note there is no synchronization here, but the gorilla implementation isn't thread safe anyway
  56. func (c *GorillaConn) SetDeadline(t time.Time) error {
  57. if err := c.Conn.SetReadDeadline(t); err != nil {
  58. return fmt.Errorf("error setting read deadline: %w", err)
  59. }
  60. if err := c.Conn.SetWriteDeadline(t); err != nil {
  61. return fmt.Errorf("error setting write deadline: %w", err)
  62. }
  63. return nil
  64. }
  65. type Conn struct {
  66. rw io.ReadWriter
  67. log *zerolog.Logger
  68. // writeLock makes sure
  69. // 1. Only one write at a time. The pinger and Stream function can both call write.
  70. // 2. Close only returns after in progress Write is finished, and no more Write will succeed after calling Close.
  71. writeLock sync.Mutex
  72. done bool
  73. }
  74. func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
  75. c := &Conn{
  76. rw: rw,
  77. log: log,
  78. }
  79. go c.pinger(ctx)
  80. return c
  81. }
  82. // Read will read messages from the websocket connection
  83. func (c *Conn) Read(reader []byte) (int, error) {
  84. data, err := wsutil.ReadClientBinary(c.rw)
  85. if err != nil {
  86. return 0, err
  87. }
  88. return copy(reader, data), nil
  89. }
  90. // Write will write messages to the websocket connection.
  91. // It will not write to the connection after Close is called to fix TUN-5184
  92. func (c *Conn) Write(p []byte) (int, error) {
  93. c.writeLock.Lock()
  94. defer c.writeLock.Unlock()
  95. if c.done {
  96. return 0, errors.New("write to closed websocket connection")
  97. }
  98. if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
  99. return 0, err
  100. }
  101. return len(p), nil
  102. }
  103. func (c *Conn) pinger(ctx context.Context) {
  104. pongMessge := wsutil.Message{
  105. OpCode: gobwas.OpPong,
  106. Payload: []byte{},
  107. }
  108. ticker := time.NewTicker(c.pingPeriod(ctx))
  109. defer ticker.Stop()
  110. for {
  111. select {
  112. case <-ticker.C:
  113. done, err := c.ping()
  114. if done {
  115. return
  116. }
  117. if err != nil {
  118. c.log.Debug().Err(err).Msgf("failed to write ping message")
  119. }
  120. if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
  121. c.log.Debug().Err(err).Msgf("failed to write pong message")
  122. }
  123. case <-ctx.Done():
  124. return
  125. }
  126. }
  127. }
  128. func (c *Conn) ping() (bool, error) {
  129. c.writeLock.Lock()
  130. defer c.writeLock.Unlock()
  131. if c.done {
  132. return true, nil
  133. }
  134. return false, wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{})
  135. }
  136. func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
  137. if val := ctx.Value(PingPeriodContextKey); val != nil {
  138. if period, ok := val.(time.Duration); ok {
  139. return period
  140. }
  141. }
  142. return defaultPingPeriod
  143. }
  144. // Close waits for the current write to finish. Further writes will return error
  145. func (c *Conn) Close() {
  146. c.writeLock.Lock()
  147. defer c.writeLock.Unlock()
  148. c.done = true
  149. }