queuepacketconn.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. package turbotunnel
  2. import (
  3. "net"
  4. "sync"
  5. "sync/atomic"
  6. "time"
  7. )
  // taggedPacket bundles a packet payload with the address it is associated
// with, i.e. the same (data, address) pair that PacketConn.ReadFrom hands
// back to its caller.
type taggedPacket struct {
	P    []byte   // packet payload
	Addr net.Addr // address the payload is tagged with
}
  14. // QueuePacketConn implements net.PacketConn by storing queues of packets. There
  15. // is one incoming queue (where packets are additionally tagged by the source
  16. // address of the client that sent them). There are many outgoing queues, one
  17. // for each client address that has been recently seen. The QueueIncoming method
  18. // inserts a packet into the incoming queue, to eventually be returned by
  19. // ReadFrom. WriteTo inserts a packet into an address-specific outgoing queue,
  20. // which can later by accessed through the OutgoingQueue method.
  21. type QueuePacketConn struct {
  22. clients *ClientMap
  23. localAddr net.Addr
  24. recvQueue chan taggedPacket
  25. closeOnce sync.Once
  26. closed chan struct{}
  27. mtu int
  28. // Pool of reusable mtu-sized buffers.
  29. bufPool sync.Pool
  30. // What error to return when the QueuePacketConn is closed.
  31. err atomic.Value
  32. }
  33. // NewQueuePacketConn makes a new QueuePacketConn, set to track recent clients
  34. // for at least a duration of timeout. The maximum packet size is mtu.
  35. func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration, mtu int) *QueuePacketConn {
  36. return &QueuePacketConn{
  37. clients: NewClientMap(timeout),
  38. localAddr: localAddr,
  39. recvQueue: make(chan taggedPacket, queueSize),
  40. closed: make(chan struct{}),
  41. mtu: mtu,
  42. bufPool: sync.Pool{New: func() interface{} { return make([]byte, mtu) }},
  43. }
  44. }
  45. // QueueIncoming queues an incoming packet and its source address, to be
  46. // returned in a future call to ReadFrom. If p is longer than the MTU, only its
  47. // first MTU bytes will be used.
  48. func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
  49. select {
  50. case <-c.closed:
  51. // If we're closed, silently drop it.
  52. return
  53. default:
  54. }
  55. // Copy the slice so that the caller may reuse it.
  56. buf := c.bufPool.Get().([]byte)
  57. if len(p) < cap(buf) {
  58. buf = buf[:len(p)]
  59. } else {
  60. buf = buf[:cap(buf)]
  61. }
  62. copy(buf, p)
  63. select {
  64. case c.recvQueue <- taggedPacket{buf, addr}:
  65. default:
  66. // Drop the incoming packet if the receive queue is full.
  67. c.Restore(buf)
  68. }
  69. }
  70. // OutgoingQueue returns the queue of outgoing packets corresponding to addr,
  71. // creating it if necessary. The contents of the queue will be packets that are
  72. // written to the address in question using WriteTo.
  73. func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte {
  74. return c.clients.SendQueue(addr)
  75. }
  76. // Restore adds a slice to the internal pool of packet buffers. Typically you
  77. // will call this with a slice from the OutgoingQueue channel once you are done
  78. // using it. (It is not an error to fail to do so, it will just result in more
  79. // allocations.)
  80. func (c *QueuePacketConn) Restore(p []byte) {
  81. if cap(p) >= c.mtu {
  82. c.bufPool.Put(p)
  83. }
  84. }
  85. // ReadFrom returns a packet and address previously stored by QueueIncoming.
  86. func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
  87. select {
  88. case <-c.closed:
  89. return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
  90. default:
  91. }
  92. select {
  93. case <-c.closed:
  94. return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
  95. case packet := <-c.recvQueue:
  96. n := copy(p, packet.P)
  97. c.Restore(packet.P)
  98. return n, packet.Addr, nil
  99. }
  100. }
  101. // WriteTo queues an outgoing packet for the given address. The queue can later
  102. // be retrieved using the OutgoingQueue method. If p is longer than the MTU,
  103. // only its first MTU bytes will be used.
  104. func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
  105. select {
  106. case <-c.closed:
  107. return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
  108. default:
  109. }
  110. // Copy the slice so that the caller may reuse it.
  111. buf := c.bufPool.Get().([]byte)
  112. if len(p) < cap(buf) {
  113. buf = buf[:len(p)]
  114. } else {
  115. buf = buf[:cap(buf)]
  116. }
  117. copy(buf, p)
  118. select {
  119. case c.clients.SendQueue(addr) <- buf:
  120. return len(buf), nil
  121. default:
  122. // Drop the outgoing packet if the send queue is full.
  123. c.Restore(buf)
  124. return len(p), nil
  125. }
  126. }
  127. // closeWithError unblocks pending operations and makes future operations fail
  128. // with the given error. If err is nil, it becomes errClosedPacketConn.
  129. func (c *QueuePacketConn) closeWithError(err error) error {
  130. var newlyClosed bool
  131. c.closeOnce.Do(func() {
  132. newlyClosed = true
  133. // Store the error to be returned by future PacketConn
  134. // operations.
  135. if err == nil {
  136. err = errClosedPacketConn
  137. }
  138. c.err.Store(err)
  139. close(c.closed)
  140. })
  141. if !newlyClosed {
  142. return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
  143. }
  144. return nil
  145. }
  146. // Close unblocks pending operations and makes future operations fail with a
  147. // "closed connection" error.
  148. func (c *QueuePacketConn) Close() error {
  149. return c.closeWithError(nil)
  150. }
  151. // LocalAddr returns the localAddr value that was passed to NewQueuePacketConn.
  152. func (c *QueuePacketConn) LocalAddr() net.Addr { return c.localAddr }
  153. func (c *QueuePacketConn) SetDeadline(t time.Time) error { return errNotImplemented }
  154. func (c *QueuePacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented }
  155. func (c *QueuePacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented }