session_test.go 6.6 KB


  1. package datagramsession
  2. import (
  3. "bytes"
  4. "context"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "testing"
  10. "time"
  11. "github.com/google/uuid"
  12. "github.com/rs/zerolog"
  13. "github.com/stretchr/testify/require"
  14. "golang.org/x/sync/errgroup"
  15. "github.com/cloudflare/cloudflared/packet"
  16. )
  17. // TestCloseSession makes sure a session will stop after context is done
  18. func TestSessionCtxDone(t *testing.T) {
  19. testSessionReturns(t, closeByContext, time.Minute*2)
  20. }
  21. // TestCloseSession makes sure a session will stop after close method is called
  22. func TestCloseSession(t *testing.T) {
  23. testSessionReturns(t, closeByCallingClose, time.Minute*2)
  24. }
  25. // TestCloseIdle makess sure a session will stop after there is no read/write for a period defined by closeAfterIdle
  26. func TestCloseIdle(t *testing.T) {
  27. testSessionReturns(t, closeByTimeout, time.Millisecond*100)
  28. }
  29. func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.Duration) {
  30. var (
  31. localCloseReason = &errClosedSession{
  32. message: "connection closed by origin",
  33. byRemote: false,
  34. }
  35. )
  36. sessionID := uuid.New()
  37. cfdConn, originConn := net.Pipe()
  38. payload := testPayload(sessionID)
  39. log := zerolog.Nop()
  40. mg := NewManager(&log, nil, nil)
  41. session := mg.newSession(sessionID, cfdConn)
  42. ctx, cancel := context.WithCancel(context.Background())
  43. sessionDone := make(chan struct{})
  44. go func() {
  45. closedByRemote, err := session.Serve(ctx, closeAfterIdle)
  46. switch closeBy {
  47. case closeByContext:
  48. require.Equal(t, context.Canceled, err)
  49. require.False(t, closedByRemote)
  50. case closeByCallingClose:
  51. require.Equal(t, localCloseReason, err)
  52. require.Equal(t, localCloseReason.byRemote, closedByRemote)
  53. case closeByTimeout:
  54. require.Equal(t, SessionIdleErr(closeAfterIdle), err)
  55. require.False(t, closedByRemote)
  56. }
  57. close(sessionDone)
  58. }()
  59. go func() {
  60. n, err := session.transportToDst(payload)
  61. require.NoError(t, err)
  62. require.Equal(t, len(payload), n)
  63. }()
  64. readBuffer := make([]byte, len(payload)+1)
  65. n, err := originConn.Read(readBuffer)
  66. require.NoError(t, err)
  67. require.Equal(t, len(payload), n)
  68. lastRead := time.Now()
  69. switch closeBy {
  70. case closeByContext:
  71. cancel()
  72. case closeByCallingClose:
  73. session.close(localCloseReason)
  74. }
  75. <-sessionDone
  76. if closeBy == closeByTimeout {
  77. require.True(t, time.Now().After(lastRead.Add(closeAfterIdle)))
  78. }
  79. // call cancelled again otherwise the linter will warn about possible context leak
  80. cancel()
  81. }
  // closeMethod selects how testSessionReturns ends the session under test.
type closeMethod int

const (
	// closeByContext cancels the context handed to Serve.
	closeByContext = closeMethod(iota)
	// closeByCallingClose calls session.close with a local close reason.
	closeByCallingClose
	// closeByTimeout takes no action and lets the idle timeout end the session.
	closeByTimeout
)
  88. func TestWriteToDstSessionPreventClosed(t *testing.T) {
  89. testActiveSessionNotClosed(t, false, true)
  90. }
  91. func TestReadFromDstSessionPreventClosed(t *testing.T) {
  92. testActiveSessionNotClosed(t, true, false)
  93. }
  94. func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) {
  95. const closeAfterIdle = time.Millisecond * 100
  96. const activeTime = time.Millisecond * 500
  97. sessionID := uuid.New()
  98. cfdConn, originConn := net.Pipe()
  99. payload := testPayload(sessionID)
  100. respChan := make(chan *packet.Session)
  101. sender := newMockTransportSender(sessionID, payload)
  102. mg := NewManager(&nopLogger, sender.muxSession, respChan)
  103. session := mg.newSession(sessionID, cfdConn)
  104. startTime := time.Now()
  105. activeUntil := startTime.Add(activeTime)
  106. ctx, cancel := context.WithCancel(context.Background())
  107. errGroup, ctx := errgroup.WithContext(ctx)
  108. errGroup.Go(func() error {
  109. session.Serve(ctx, closeAfterIdle)
  110. if time.Now().Before(startTime.Add(activeTime)) {
  111. return fmt.Errorf("session closed while it's still active")
  112. }
  113. return nil
  114. })
  115. if readFromDst {
  116. errGroup.Go(func() error {
  117. for {
  118. if time.Now().After(activeUntil) {
  119. return nil
  120. }
  121. if _, err := originConn.Write(payload); err != nil {
  122. return err
  123. }
  124. time.Sleep(closeAfterIdle / 2)
  125. }
  126. })
  127. }
  128. if writeToDst {
  129. errGroup.Go(func() error {
  130. readBuffer := make([]byte, len(payload))
  131. for {
  132. n, err := originConn.Read(readBuffer)
  133. if err != nil {
  134. if err == io.EOF || err == io.ErrClosedPipe {
  135. return nil
  136. }
  137. return err
  138. }
  139. if !bytes.Equal(payload, readBuffer[:n]) {
  140. return fmt.Errorf("payload %v is not equal to %v", readBuffer[:n], payload)
  141. }
  142. }
  143. })
  144. errGroup.Go(func() error {
  145. for {
  146. if time.Now().After(activeUntil) {
  147. return nil
  148. }
  149. if _, err := session.transportToDst(payload); err != nil {
  150. return err
  151. }
  152. time.Sleep(closeAfterIdle / 2)
  153. }
  154. })
  155. }
  156. require.NoError(t, errGroup.Wait())
  157. cancel()
  158. }
  159. func TestMarkActiveNotBlocking(t *testing.T) {
  160. const concurrentCalls = 50
  161. mg := NewManager(&nopLogger, nil, nil)
  162. session := mg.newSession(uuid.New(), nil)
  163. var wg sync.WaitGroup
  164. wg.Add(concurrentCalls)
  165. for i := 0; i < concurrentCalls; i++ {
  166. go func() {
  167. session.markActive()
  168. wg.Done()
  169. }()
  170. }
  171. wg.Wait()
  172. }
  173. // Some UDP application might send 0-size payload.
  174. func TestZeroBytePayload(t *testing.T) {
  175. sessionID := uuid.New()
  176. cfdConn, originConn := net.Pipe()
  177. sender := sendOnceTransportSender{
  178. baseSender: newMockTransportSender(sessionID, make([]byte, 0)),
  179. sentChan: make(chan struct{}),
  180. }
  181. mg := NewManager(&nopLogger, sender.muxSession, nil)
  182. session := mg.newSession(sessionID, cfdConn)
  183. ctx, cancel := context.WithCancel(context.Background())
  184. errGroup, ctx := errgroup.WithContext(ctx)
  185. errGroup.Go(func() error {
  186. // Read from underlying conn and send to transport
  187. closedByRemote, err := session.Serve(ctx, time.Minute*2)
  188. require.Equal(t, context.Canceled, err)
  189. require.False(t, closedByRemote)
  190. return nil
  191. })
  192. errGroup.Go(func() error {
  193. // Write to underlying connection
  194. n, err := originConn.Write([]byte{})
  195. require.NoError(t, err)
  196. require.Equal(t, 0, n)
  197. return nil
  198. })
  199. <-sender.sentChan
  200. cancel()
  201. require.NoError(t, errGroup.Wait())
  202. }
  203. type mockTransportSender struct {
  204. expectedSessionID uuid.UUID
  205. expectedPayload []byte
  206. }
  207. func newMockTransportSender(expectedSessionID uuid.UUID, expectedPayload []byte) *mockTransportSender {
  208. return &mockTransportSender{
  209. expectedSessionID: expectedSessionID,
  210. expectedPayload: expectedPayload,
  211. }
  212. }
  213. func (mts *mockTransportSender) muxSession(session *packet.Session) error {
  214. if session.ID != mts.expectedSessionID {
  215. return fmt.Errorf("Expect session %s, got %s", mts.expectedSessionID, session.ID)
  216. }
  217. if !bytes.Equal(session.Payload, mts.expectedPayload) {
  218. return fmt.Errorf("Expect %v, read %v", mts.expectedPayload, session.Payload)
  219. }
  220. return nil
  221. }
  222. type sendOnceTransportSender struct {
  223. baseSender *mockTransportSender
  224. sentChan chan struct{}
  225. }
  226. func (sots *sendOnceTransportSender) muxSession(session *packet.Session) error {
  227. defer close(sots.sentChan)
  228. return sots.baseSender.muxSession(session)
  229. }