|
- package datagramsession
- import (
- "bytes"
- "context"
- "fmt"
- "io"
- "net"
- "sync"
- "testing"
- "time"
- "github.com/google/uuid"
- "github.com/rs/zerolog"
- "github.com/stretchr/testify/require"
- "golang.org/x/sync/errgroup"
- "github.com/cloudflare/cloudflared/packet"
- )
- var (
- nopLogger = zerolog.Nop()
- )
- func TestManagerServe(t *testing.T) {
- const (
- sessions = 2
- msgs = 5
- remoteUnregisterMsg = "eyeball closed connection"
- )
- requestChan := make(chan *packet.Session)
- transport := mockQUICTransport{
- sessions: make(map[uuid.UUID]chan []byte),
- }
- for i := 0; i < sessions; i++ {
- transport.sessions[uuid.New()] = make(chan []byte)
- }
- mg := NewManager(&nopLogger, transport.MuxSession, requestChan)
- ctx, cancel := context.WithCancel(context.Background())
- serveDone := make(chan struct{})
- go func(ctx context.Context) {
- mg.Serve(ctx)
- close(serveDone)
- }(ctx)
- errGroup, ctx := errgroup.WithContext(ctx)
- for sessionID, eyeballRespChan := range transport.sessions {
- // Assign loop variables to local variables
- sID := sessionID
- payload := testPayload(sID)
- expectResp := testResponse(payload)
- cfdConn, originConn := net.Pipe()
- origin := mockOrigin{
- expectMsgCount: msgs,
- expectedMsg: payload,
- expectedResp: expectResp,
- conn: originConn,
- }
- eyeball := mockEyeballSession{
- id: sID,
- expectedMsgCount: msgs,
- expectedMsg: payload,
- expectedResponse: expectResp,
- respReceiver: eyeballRespChan,
- }
- // Assign loop variables to local variables
- errGroup.Go(func() error {
- session, err := mg.RegisterSession(ctx, sID, cfdConn)
- require.NoError(t, err)
- reqErrGroup, reqCtx := errgroup.WithContext(ctx)
- reqErrGroup.Go(func() error {
- return origin.serve()
- })
- reqErrGroup.Go(func() error {
- return eyeball.serve(reqCtx, requestChan)
- })
- sessionDone := make(chan struct{})
- go func() {
- closedByRemote, err := session.Serve(ctx, time.Minute*2)
- closeSession := &errClosedSession{
- message: remoteUnregisterMsg,
- byRemote: true,
- }
- require.Equal(t, closeSession, err)
- require.True(t, closedByRemote)
- close(sessionDone)
- }()
- // Make sure eyeball and origin have received all messages before unregistering the session
- require.NoError(t, reqErrGroup.Wait())
- require.NoError(t, mg.UnregisterSession(ctx, sID, remoteUnregisterMsg, true))
- <-sessionDone
- return nil
- })
- }
- require.NoError(t, errGroup.Wait())
- cancel()
- <-serveDone
- }
- func TestTimeout(t *testing.T) {
- const (
- testTimeout = time.Millisecond * 50
- )
- mg := NewManager(&nopLogger, nil, nil)
- mg.timeout = testTimeout
- ctx := context.Background()
- sessionID := uuid.New()
- // session manager is not running, so event loop is not running and therefore calling the APIs should timeout
- session, err := mg.RegisterSession(ctx, sessionID, nil)
- require.ErrorIs(t, err, context.DeadlineExceeded)
- require.Nil(t, session)
- err = mg.UnregisterSession(ctx, sessionID, "session gone", true)
- require.ErrorIs(t, err, context.DeadlineExceeded)
- }
- func TestUnregisterSessionCloseSession(t *testing.T) {
- sessionID := uuid.New()
- payload := []byte(t.Name())
- sender := newMockTransportSender(sessionID, payload)
- mg := NewManager(&nopLogger, sender.muxSession, nil)
- ctx, cancel := context.WithCancel(context.Background())
- managerDone := make(chan struct{})
- go func() {
- err := mg.Serve(ctx)
- require.Error(t, err)
- close(managerDone)
- }()
- cfdConn, originConn := net.Pipe()
- session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
- require.NoError(t, err)
- require.NotNil(t, session)
- unregisteredChan := make(chan struct{})
- go func() {
- _, err := originConn.Write(payload)
- require.NoError(t, err)
- err = mg.UnregisterSession(ctx, sessionID, "eyeball closed session", true)
- require.NoError(t, err)
- close(unregisteredChan)
- }()
- closedByRemote, err := session.Serve(ctx, time.Minute)
- require.True(t, closedByRemote)
- require.Error(t, err)
- <-unregisteredChan
- cancel()
- <-managerDone
- }
- func TestManagerCtxDoneCloseSessions(t *testing.T) {
- sessionID := uuid.New()
- payload := []byte(t.Name())
- sender := newMockTransportSender(sessionID, payload)
- mg := NewManager(&nopLogger, sender.muxSession, nil)
- ctx, cancel := context.WithCancel(context.Background())
- var wg sync.WaitGroup
- wg.Add(1)
- go func() {
- defer wg.Done()
- err := mg.Serve(ctx)
- require.Error(t, err)
- }()
- cfdConn, originConn := net.Pipe()
- session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
- require.NoError(t, err)
- require.NotNil(t, session)
- wg.Add(1)
- go func() {
- defer wg.Done()
- _, err := originConn.Write(payload)
- require.NoError(t, err)
- cancel()
- }()
- closedByRemote, err := session.Serve(ctx, time.Minute)
- require.False(t, closedByRemote)
- require.Error(t, err)
- wg.Wait()
- }
// mockOrigin is the origin side of a session in tests: it reads a fixed
// number of identical messages from conn and answers each one with the same
// response.
type mockOrigin struct {
	// expectMsgCount is how many messages serve reads before returning.
	expectMsgCount int
	// expectedMsg is the exact content every incoming message must have.
	expectedMsg []byte
	// expectedResp is written back to conn after every accepted message.
	expectedResp []byte
	// conn is the connection serve reads from and writes to.
	conn io.ReadWriteCloser
}

// serve reads expectMsgCount messages from conn, one Read call per message,
// and writes expectedResp after each. It returns the first Read or Write
// error, or a descriptive error when a message differs from expectedMsg in
// length or content. It returns nil once every message has been answered.
func (mo *mockOrigin) serve() error {
	want := mo.expectedMsg
	// One spare byte so a message longer than expected shows up as a length mismatch.
	buf := make([]byte, len(want)+1)

	for remaining := mo.expectMsgCount; remaining > 0; remaining-- {
		n, err := mo.conn.Read(buf)
		if err != nil {
			return err
		}

		got := buf[:n]
		switch {
		case n != len(want):
			return fmt.Errorf("Expect to read %d bytes, read %d", len(want), n)
		case !bytes.Equal(got, want):
			return fmt.Errorf("Expect %v, read %v", want, got)
		}

		if _, err := mo.conn.Write(mo.expectedResp); err != nil {
			return err
		}
	}
	return nil
}
- func testPayload(sessionID uuid.UUID) []byte {
- return []byte(fmt.Sprintf("Message from %s", sessionID))
- }
// testResponse builds the response expected for msg: the text "Response to "
// followed by msg formatted with %v, i.e. the bytes rendered as a bracketed
// list of decimal numbers rather than as text.
func testResponse(msg []byte) []byte {
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "Response to %v", msg)
	return buf.Bytes()
}
- type mockQUICTransport struct {
- sessions map[uuid.UUID]chan []byte
- }
- func (me *mockQUICTransport) MuxSession(session *packet.Session) error {
- s := me.sessions[session.ID]
- s <- session.Payload
- return nil
- }
- type mockEyeballSession struct {
- id uuid.UUID
- expectedMsgCount int
- expectedMsg []byte
- expectedResponse []byte
- respReceiver <-chan []byte
- }
- func (me *mockEyeballSession) serve(ctx context.Context, requestChan chan *packet.Session) error {
- for i := 0; i < me.expectedMsgCount; i++ {
- requestChan <- &packet.Session{
- ID: me.id,
- Payload: me.expectedMsg,
- }
- resp := <-me.respReceiver
- if !bytes.Equal(resp, me.expectedResponse) {
- return fmt.Errorf("Expect %v, read %v", me.expectedResponse, resp)
- }
- fmt.Println("Resp", resp)
- }
- return nil
- }
|