stream.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package stream
  2. import (
  3. "encoding/hex"
  4. "fmt"
  5. "io"
  6. "runtime/debug"
  7. "sync/atomic"
  8. "time"
  9. "github.com/getsentry/sentry-go"
  10. "github.com/pkg/errors"
  11. "github.com/rs/zerolog"
  12. "github.com/cloudflare/cloudflared/cfio"
  13. )
  // Stream interfaces describe a byte stream whose write side can be closed
  // independently of its read side.
  type (
  	// Stream is a bidirectional stream: readable, writable, and half-closable.
  	Stream interface {
  		Reader
  		WriterCloser
  	}

  	// Reader is the read half of a Stream.
  	Reader interface {
  		io.Reader
  	}

  	// WriterCloser is the write half of a Stream: it can be written to and its
  	// write side can be closed with CloseWrite.
  	WriterCloser interface {
  		io.Writer
  		WriteCloser
  	}

  	// WriteCloser closes only the write side of a stream. Note that, unlike
  	// io.WriteCloser, it declares no Write method of its own.
  	WriteCloser interface {
  		CloseWrite() error
  	}
  )
  // nopCloseWriterAdapter lets a plain io.ReadWriter be used where a Stream is
  // expected by supplying a CloseWrite that does nothing.
  type nopCloseWriterAdapter struct {
  	io.ReadWriter
  }

  // NopCloseWriterAdapter wraps stream so that it satisfies Stream. Reads and
  // writes go straight to stream; CloseWrite on the result is a no-op that
  // always reports success and never touches the wrapped stream.
  func NopCloseWriterAdapter(stream io.ReadWriter) *nopCloseWriterAdapter {
  	adapter := nopCloseWriterAdapter{ReadWriter: stream}
  	return &adapter
  }

  // CloseWrite implements WriteCloser and intentionally does nothing.
  func (*nopCloseWriterAdapter) CloseWrite() error { return nil }
  // bidirectionalStreamStatus tracks completion of the two copy directions
  // started by PipeBidirectional.
  type bidirectionalStreamStatus struct {
  	// doneChan receives one value each time a direction finishes.
  	doneChan chan struct{}
  	// anyDone is atomically set to 1 once any direction has finished.
  	anyDone uint32
  }

  // newBiStreamStatus returns a status with nothing finished yet. The channel is
  // buffered for two values so that both directions can report completion
  // without blocking, even if nobody is receiving any more.
  func newBiStreamStatus() *bidirectionalStreamStatus {
  	status := new(bidirectionalStreamStatus)
  	status.doneChan = make(chan struct{}, 2)
  	return status
  }

  // markUniStreamDone records that one direction finished: it raises the anyDone
  // flag first and then queues a completion signal for wait.
  func (s *bidirectionalStreamStatus) markUniStreamDone() {
  	atomic.StoreUint32(&s.anyDone, 1)
  	var signal struct{}
  	s.doneChan <- signal
  }

  // wait blocks until the first direction finishes. When maxWaitForSecondStream
  // is positive it then gives the second direction that long to finish too,
  // returning a timeout error if it does not; otherwise it returns right away.
  func (s *bidirectionalStreamStatus) wait(maxWaitForSecondStream time.Duration) error {
  	<-s.doneChan
  	if maxWaitForSecondStream <= 0 {
  		return nil
  	}

  	deadline := time.NewTimer(maxWaitForSecondStream)
  	defer deadline.Stop()
  	select {
  	case <-s.doneChan:
  		return nil
  	case <-deadline.C:
  		return fmt.Errorf("timeout waiting for second stream to finish")
  	}
  }

  // isAnyDone reports whether at least one direction has called
  // markUniStreamDone.
  func (s *bidirectionalStreamStatus) isAnyDone() bool {
  	return atomic.LoadUint32(&s.anyDone) != 0
  }
  69. // Pipe copies copy data to & from provided io.ReadWriters.
  70. func Pipe(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) {
  71. PipeBidirectional(NopCloseWriterAdapter(tunnelConn), NopCloseWriterAdapter(originConn), 0, log)
  72. }
  73. // PipeBidirectional copies data two BidirectionStreams. It is a special case of Pipe where it receives a concept that allows for Read and Write side to be closed independently.
  74. // The main difference is that when piping data from a reader to a writer, if EOF is read, then this implementation propagates the EOF signal to the destination/writer by closing the write side of the
  75. // Bidirectional Stream.
  76. // Finally, depending on once EOF is ready from one of the provided streams, the other direction of streaming data will have a configured time period to also finish, otherwise,
  77. // the method will return immediately with a timeout error. It is however, the responsability of the caller to close the associated streams in both ends in order to free all the resources/go-routines.
  78. func PipeBidirectional(downstream, upstream Stream, maxWaitForSecondStream time.Duration, log *zerolog.Logger) error {
  79. status := newBiStreamStatus()
  80. go unidirectionalStream(downstream, upstream, "upstream->downstream", status, log)
  81. go unidirectionalStream(upstream, downstream, "downstream->upstream", status, log)
  82. if err := status.wait(maxWaitForSecondStream); err != nil {
  83. return errors.Wrap(err, "unable to wait for both streams while proxying")
  84. }
  85. return nil
  86. }
  87. func unidirectionalStream(dst WriterCloser, src Reader, dir string, status *bidirectionalStreamStatus, log *zerolog.Logger) {
  88. defer func() {
  89. // The bidirectional streaming spawns 2 goroutines to stream each direction.
  90. // If any ends, the callstack returns, meaning the Tunnel request/stream (depending on http2 vs quic) will
  91. // close. In such case, if the other direction did not stop (due to application level stopping, e.g., if a
  92. // server/origin listens forever until closure), it may read/write from the underlying ReadWriter (backed by
  93. // the Edge<->cloudflared transport) in an unexpected state.
  94. // Because of this, we set this recover() logic.
  95. if err := recover(); err != nil {
  96. if status.isAnyDone() {
  97. // We handle such unexpected errors only when we detect that one side of the streaming is done.
  98. log.Debug().Msgf("recovered from panic in stream.Pipe for %s, error %s, %s", dir, err, debug.Stack())
  99. } else {
  100. // Otherwise, this is unexpected, but we prevent the program from crashing anyway.
  101. log.Warn().Msgf("recovered from panic in stream.Pipe for %s, error %s, %s", dir, err, debug.Stack())
  102. sentry.CurrentHub().Recover(err)
  103. sentry.Flush(time.Second * 5)
  104. }
  105. }
  106. }()
  107. defer dst.CloseWrite()
  108. _, err := copyData(dst, src, dir)
  109. if err != nil {
  110. log.Debug().Msgf("%s copy: %v", dir, err)
  111. }
  112. status.markUniStreamDone()
  113. }
  const (
  	// debugCopy switches copyData to a slower copy loop that prints a
  	// timestamped hex dump of every chunk copied between origin and tunnel.
  	// Keep it false outside local debugging.
  	debugCopy = false
  )
  116. func copyData(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
  117. if debugCopy {
  118. // copyBuffer is based on stdio Copy implementation but shows copied data
  119. copyBuffer := func(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
  120. var buf []byte
  121. size := 32 * 1024
  122. buf = make([]byte, size)
  123. for {
  124. t := time.Now()
  125. nr, er := src.Read(buf)
  126. if nr > 0 {
  127. fmt.Println(dir, t.UnixNano(), "\n"+hex.Dump(buf[0:nr]))
  128. nw, ew := dst.Write(buf[0:nr])
  129. if nw < 0 || nr < nw {
  130. nw = 0
  131. if ew == nil {
  132. ew = errors.New("invalid write")
  133. }
  134. }
  135. written += int64(nw)
  136. if ew != nil {
  137. err = ew
  138. break
  139. }
  140. if nr != nw {
  141. err = io.ErrShortWrite
  142. break
  143. }
  144. }
  145. if er != nil {
  146. if er != io.EOF {
  147. err = er
  148. }
  149. break
  150. }
  151. }
  152. return written, err
  153. }
  154. return copyBuffer(dst, src, dir)
  155. } else {
  156. return cfio.Copy(dst, src)
  157. }
  158. }