123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- package turbotunnel
- import (
- "context"
- "errors"
- "net"
- "sync"
- "sync/atomic"
- "time"
- )
// RedialPacketConn is a long-lived net.PacketConn layered over a succession of
// short-lived net.PacketConns obtained from a caller-supplied dialContext
// function. A background loop dials a connection, moves packets between it and
// the internal send and receive queues, and dials a replacement whenever the
// current connection fails a ReadFrom or WriteTo. The first dialContext error
// closes the RedialPacketConn. Its own ReadFrom and WriteTo report an error
// only once it has been closed, whether by such a dial error or by Close.
//
// The local and remote addresses are fixed at construction time and do not
// depend on the addresses of any dialed net.PacketConn.
type RedialPacketConn struct {
	// Static addresses reported to callers and used as the destination of
	// every outgoing packet.
	localAddr  net.Addr
	remoteAddr net.Addr

	// dialContext produces each successive underlying net.PacketConn.
	dialContext func(context.Context) (net.PacketConn, error)

	// recvQueue holds packets read from the underlying connection until
	// ReadFrom collects them; sendQueue holds packets accepted by WriteTo
	// until they are written to the underlying connection.
	recvQueue chan []byte
	sendQueue chan []byte

	// closed is closed, exactly once (guarded by closeOnce), when the
	// RedialPacketConn shuts down.
	closed    chan struct{}
	closeOnce sync.Once

	// err is the error that closed the RedialPacketConn: the first dial
	// error, or a generic "closed connection" error when it was closed
	// without one. It is stored before closed is closed and is reported by
	// subsequent read/write operations. Compare to the rerr and werr in
	// io.Pipe.
	err atomic.Value
}
- // NewQueuePacketConn makes a new RedialPacketConn, with the given static local
- // and remote addresses, and dialContext function.
- func NewRedialPacketConn(
- localAddr, remoteAddr net.Addr,
- dialContext func(context.Context) (net.PacketConn, error),
- ) *RedialPacketConn {
- c := &RedialPacketConn{
- localAddr: localAddr,
- remoteAddr: remoteAddr,
- dialContext: dialContext,
- recvQueue: make(chan []byte, queueSize),
- sendQueue: make(chan []byte, queueSize),
- closed: make(chan struct{}),
- err: atomic.Value{},
- }
- go c.dialLoop()
- return c
- }
- // dialLoop repeatedly calls c.dialContext and passes the resulting
- // net.PacketConn to c.exchange. It returns only when c is closed or dialContext
- // returns an error.
- func (c *RedialPacketConn) dialLoop() {
- ctx, cancel := context.WithCancel(context.Background())
- for {
- select {
- case <-c.closed:
- cancel()
- return
- default:
- }
- conn, err := c.dialContext(ctx)
- if err != nil {
- c.closeWithError(err)
- cancel()
- return
- }
- c.exchange(conn)
- conn.Close()
- }
- }
- // exchange calls ReadFrom on the given net.PacketConn and places the resulting
- // packets in the receive queue, and takes packets from the send queue and calls
- // WriteTo on them, making the current net.PacketConn active.
- func (c *RedialPacketConn) exchange(conn net.PacketConn) {
- readErrCh := make(chan error)
- writeErrCh := make(chan error)
- go func() {
- defer close(readErrCh)
- for {
- select {
- case <-c.closed:
- return
- case <-writeErrCh:
- return
- default:
- }
- var buf [1500]byte
- n, _, err := conn.ReadFrom(buf[:])
- if err != nil {
- readErrCh <- err
- return
- }
- p := make([]byte, n)
- copy(p, buf[:])
- select {
- case c.recvQueue <- p:
- default: // OK to drop packets.
- }
- }
- }()
- go func() {
- defer close(writeErrCh)
- for {
- select {
- case <-c.closed:
- return
- case <-readErrCh:
- return
- case p := <-c.sendQueue:
- _, err := conn.WriteTo(p, c.remoteAddr)
- if err != nil {
- writeErrCh <- err
- return
- }
- }
- }
- }()
- select {
- case <-readErrCh:
- case <-writeErrCh:
- }
- }
- // ReadFrom reads a packet from the currently active net.PacketConn. The
- // packet's original remote address is replaced with the RedialPacketConn's own
- // remote address.
- func (c *RedialPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
- select {
- case <-c.closed:
- return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)}
- default:
- }
- select {
- case <-c.closed:
- return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)}
- case buf := <-c.recvQueue:
- return copy(p, buf), c.remoteAddr, nil
- }
- }
- // WriteTo writes a packet to the currently active net.PacketConn. The addr
- // argument is ignored and instead replaced with the RedialPacketConn's own
- // remote address.
- func (c *RedialPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
- // addr is ignored.
- select {
- case <-c.closed:
- return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)}
- default:
- }
- buf := make([]byte, len(p))
- copy(buf, p)
- select {
- case c.sendQueue <- buf:
- return len(buf), nil
- default:
- // Drop the outgoing packet if the send queue is full.
- return len(buf), nil
- }
- }
- // closeWithError unblocks pending operations and makes future operations fail
- // with the given error. If err is nil, it becomes errClosedPacketConn.
- func (c *RedialPacketConn) closeWithError(err error) error {
- var once bool
- c.closeOnce.Do(func() {
- // Store the error to be returned by future read/write
- // operations.
- if err == nil {
- err = errors.New("operation on closed connection")
- }
- c.err.Store(err)
- close(c.closed)
- once = true
- })
- if !once {
- return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
- }
- return nil
- }
- // Close unblocks pending operations and makes future operations fail with a
- // "closed connection" error.
- func (c *RedialPacketConn) Close() error {
- return c.closeWithError(nil)
- }
- // LocalAddr returns the localAddr value that was passed to NewRedialPacketConn.
- func (c *RedialPacketConn) LocalAddr() net.Addr { return c.localAddr }
- func (c *RedialPacketConn) SetDeadline(t time.Time) error { return errNotImplemented }
- func (c *RedialPacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented }
- func (c *RedialPacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented }
|