redialpacketconn.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. package turbotunnel
  2. import (
  3. "context"
  4. "errors"
  5. "net"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. )
  10. // RedialPacketConn implements a long-lived net.PacketConn atop a sequence of
  11. // other, transient net.PacketConns. RedialPacketConn creates a new
  12. // net.PacketConn by calling a provided dialContext function. Whenever the
  13. // net.PacketConn experiences a ReadFrom or WriteTo error, RedialPacketConn
  14. // calls the dialContext function again and starts sending and receiving packets
  15. // on the new net.PacketConn. RedialPacketConn's own ReadFrom and WriteTo
  16. // methods return an error only when the dialContext function returns an error.
  17. //
  18. // RedialPacketConn uses static local and remote addresses that are independent
  19. // of those of any dialed net.PacketConn.
  20. type RedialPacketConn struct {
  21. localAddr net.Addr
  22. remoteAddr net.Addr
  23. dialContext func(context.Context) (net.PacketConn, error)
  24. recvQueue chan []byte
  25. sendQueue chan []byte
  26. closed chan struct{}
  27. closeOnce sync.Once
  28. // The first dial error, which causes the clientPacketConn to be
  29. // closed and is returned from future read/write operations. Compare to
  30. // the rerr and werr in io.Pipe.
  31. err atomic.Value
  32. }
  33. // NewQueuePacketConn makes a new RedialPacketConn, with the given static local
  34. // and remote addresses, and dialContext function.
  35. func NewRedialPacketConn(
  36. localAddr, remoteAddr net.Addr,
  37. dialContext func(context.Context) (net.PacketConn, error),
  38. ) *RedialPacketConn {
  39. c := &RedialPacketConn{
  40. localAddr: localAddr,
  41. remoteAddr: remoteAddr,
  42. dialContext: dialContext,
  43. recvQueue: make(chan []byte, queueSize),
  44. sendQueue: make(chan []byte, queueSize),
  45. closed: make(chan struct{}),
  46. err: atomic.Value{},
  47. }
  48. go c.dialLoop()
  49. return c
  50. }
  51. // dialLoop repeatedly calls c.dialContext and passes the resulting
  52. // net.PacketConn to c.exchange. It returns only when c is closed or dialContext
  53. // returns an error.
  54. func (c *RedialPacketConn) dialLoop() {
  55. ctx, cancel := context.WithCancel(context.Background())
  56. for {
  57. select {
  58. case <-c.closed:
  59. cancel()
  60. return
  61. default:
  62. }
  63. conn, err := c.dialContext(ctx)
  64. if err != nil {
  65. c.closeWithError(err)
  66. cancel()
  67. return
  68. }
  69. c.exchange(conn)
  70. conn.Close()
  71. }
  72. }
  73. // exchange calls ReadFrom on the given net.PacketConn and places the resulting
  74. // packets in the receive queue, and takes packets from the send queue and calls
  75. // WriteTo on them, making the current net.PacketConn active.
  76. func (c *RedialPacketConn) exchange(conn net.PacketConn) {
  77. readErrCh := make(chan error)
  78. writeErrCh := make(chan error)
  79. go func() {
  80. defer close(readErrCh)
  81. for {
  82. select {
  83. case <-c.closed:
  84. return
  85. case <-writeErrCh:
  86. return
  87. default:
  88. }
  89. var buf [1500]byte
  90. n, _, err := conn.ReadFrom(buf[:])
  91. if err != nil {
  92. readErrCh <- err
  93. return
  94. }
  95. p := make([]byte, n)
  96. copy(p, buf[:])
  97. select {
  98. case c.recvQueue <- p:
  99. default: // OK to drop packets.
  100. }
  101. }
  102. }()
  103. go func() {
  104. defer close(writeErrCh)
  105. for {
  106. select {
  107. case <-c.closed:
  108. return
  109. case <-readErrCh:
  110. return
  111. case p := <-c.sendQueue:
  112. _, err := conn.WriteTo(p, c.remoteAddr)
  113. if err != nil {
  114. writeErrCh <- err
  115. return
  116. }
  117. }
  118. }
  119. }()
  120. select {
  121. case <-readErrCh:
  122. case <-writeErrCh:
  123. }
  124. }
  125. // ReadFrom reads a packet from the currently active net.PacketConn. The
  126. // packet's original remote address is replaced with the RedialPacketConn's own
  127. // remote address.
  128. func (c *RedialPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
  129. select {
  130. case <-c.closed:
  131. return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)}
  132. default:
  133. }
  134. select {
  135. case <-c.closed:
  136. return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)}
  137. case buf := <-c.recvQueue:
  138. return copy(p, buf), c.remoteAddr, nil
  139. }
  140. }
  141. // WriteTo writes a packet to the currently active net.PacketConn. The addr
  142. // argument is ignored and instead replaced with the RedialPacketConn's own
  143. // remote address.
  144. func (c *RedialPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
  145. // addr is ignored.
  146. select {
  147. case <-c.closed:
  148. return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)}
  149. default:
  150. }
  151. buf := make([]byte, len(p))
  152. copy(buf, p)
  153. select {
  154. case c.sendQueue <- buf:
  155. return len(buf), nil
  156. default:
  157. // Drop the outgoing packet if the send queue is full.
  158. return len(buf), nil
  159. }
  160. }
  161. // closeWithError unblocks pending operations and makes future operations fail
  162. // with the given error. If err is nil, it becomes errClosedPacketConn.
  163. func (c *RedialPacketConn) closeWithError(err error) error {
  164. var once bool
  165. c.closeOnce.Do(func() {
  166. // Store the error to be returned by future read/write
  167. // operations.
  168. if err == nil {
  169. err = errors.New("operation on closed connection")
  170. }
  171. c.err.Store(err)
  172. close(c.closed)
  173. once = true
  174. })
  175. if !once {
  176. return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
  177. }
  178. return nil
  179. }
  180. // Close unblocks pending operations and makes future operations fail with a
  181. // "closed connection" error.
  182. func (c *RedialPacketConn) Close() error {
  183. return c.closeWithError(nil)
  184. }
  185. // LocalAddr returns the localAddr value that was passed to NewRedialPacketConn.
  186. func (c *RedialPacketConn) LocalAddr() net.Addr { return c.localAddr }
  187. func (c *RedialPacketConn) SetDeadline(t time.Time) error { return errNotImplemented }
  188. func (c *RedialPacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented }
  189. func (c *RedialPacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented }