safe_stream.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package quic
  2. import (
  3. "errors"
  4. "net"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. "github.com/quic-go/quic-go"
  9. "github.com/rs/zerolog"
  10. "github.com/rs/zerolog/log"
  11. )
  12. // The error that is throw by the writer when there is `no network activity`.
  13. var idleTimeoutError = quic.IdleTimeoutError{}
  14. type SafeStreamCloser struct {
  15. lock sync.Mutex
  16. stream quic.Stream
  17. writeTimeout time.Duration
  18. log *zerolog.Logger
  19. closing atomic.Bool
  20. }
  21. func NewSafeStreamCloser(stream quic.Stream, writeTimeout time.Duration, log *zerolog.Logger) *SafeStreamCloser {
  22. return &SafeStreamCloser{
  23. stream: stream,
  24. writeTimeout: writeTimeout,
  25. log: log,
  26. }
  27. }
  28. func (s *SafeStreamCloser) Read(p []byte) (n int, err error) {
  29. return s.stream.Read(p)
  30. }
  31. func (s *SafeStreamCloser) Write(p []byte) (n int, err error) {
  32. s.lock.Lock()
  33. defer s.lock.Unlock()
  34. if s.writeTimeout > 0 {
  35. err = s.stream.SetWriteDeadline(time.Now().Add(s.writeTimeout))
  36. if err != nil {
  37. log.Err(err).Msg("Error setting write deadline for QUIC stream")
  38. }
  39. }
  40. nBytes, err := s.stream.Write(p)
  41. if err != nil {
  42. s.handleWriteError(err)
  43. }
  44. return nBytes, err
  45. }
  46. // Handles the timeout error in case it happened, by canceling the stream write.
  47. func (s *SafeStreamCloser) handleWriteError(err error) {
  48. // If we are closing the stream we just ignore any write error.
  49. if s.closing.Load() {
  50. return
  51. }
  52. var netErr net.Error
  53. if errors.As(err, &netErr) {
  54. if netErr.Timeout() {
  55. // We don't need to log if what cause the timeout was no network activity.
  56. if !errors.Is(netErr, &idleTimeoutError) {
  57. s.log.Error().Err(netErr).Msg("Closing quic stream due to timeout while writing")
  58. }
  59. // We need to explicitly cancel the write so that it frees all buffers.
  60. s.stream.CancelWrite(0)
  61. }
  62. }
  63. }
  64. func (s *SafeStreamCloser) Close() error {
  65. // Set this stream to a closing state.
  66. s.closing.Store(true)
  67. // Make sure a possible writer does not block the lock forever. We need it, so we can close the writer
  68. // side of the stream safely.
  69. _ = s.stream.SetWriteDeadline(time.Now())
  70. // This lock is eventually acquired despite Write also acquiring it, because we set a deadline to writes.
  71. s.lock.Lock()
  72. defer s.lock.Unlock()
  73. // We have to clean up the receiving stream ourselves since the Close in the bottom does not handle that.
  74. s.stream.CancelRead(0)
  75. return s.stream.Close()
  76. }
  77. func (s *SafeStreamCloser) CloseWrite() error {
  78. s.lock.Lock()
  79. defer s.lock.Unlock()
  80. // As documented by the quic-go library, this doesn't actually close the entire stream.
  81. // It prevents further writes, which in turn will result in an EOF signal being sent the other side of stream when
  82. // reading.
  83. // We can still read from this stream.
  84. return s.stream.Close()
  85. }
  86. func (s *SafeStreamCloser) SetDeadline(deadline time.Time) error {
  87. return s.stream.SetDeadline(deadline)
  88. }