funnel_test.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package packet
  2. import (
  3. "fmt"
  4. "net/netip"
  5. "testing"
  6. "time"
  7. "github.com/stretchr/testify/require"
  8. )
  9. type mockFunnelUniPipe struct {
  10. uniPipe chan RawPacket
  11. }
  12. func (mfui *mockFunnelUniPipe) SendPacket(dst netip.Addr, pk RawPacket) error {
  13. mfui.uniPipe <- pk
  14. return nil
  15. }
  16. func (mfui *mockFunnelUniPipe) Close() error {
  17. return nil
  18. }
  19. func TestFunnelRegistration(t *testing.T) {
  20. id := testFunnelID{"id1"}
  21. funnelErr := fmt.Errorf("expected error")
  22. newFunnelFuncErr := func() (Funnel, error) { return nil, funnelErr }
  23. newFunnelFuncUncalled := func() (Funnel, error) {
  24. require.FailNow(t, "a new funnel should not be created")
  25. panic("unreached")
  26. }
  27. funnel1, newFunnelFunc1 := newFunnelAndFunc("funnel1")
  28. funnel2, newFunnelFunc2 := newFunnelAndFunc("funnel2")
  29. ft := NewFunnelTracker()
  30. // Register funnel1
  31. funnel, new, err := ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFunc1)
  32. require.NoError(t, err)
  33. require.True(t, new)
  34. require.Equal(t, funnel1, funnel)
  35. // Register funnel, no replace
  36. funnel, new, err = ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFuncUncalled)
  37. require.NoError(t, err)
  38. require.False(t, new)
  39. require.Equal(t, funnel1, funnel)
  40. // Register funnel2, replace
  41. funnel, new, err = ft.GetOrRegister(id, shouldReplaceTrue, newFunnelFunc2)
  42. require.NoError(t, err)
  43. require.True(t, new)
  44. require.Equal(t, funnel2, funnel)
  45. require.True(t, funnel1.closed)
  46. // Register funnel error, replace
  47. funnel, new, err = ft.GetOrRegister(id, shouldReplaceTrue, newFunnelFuncErr)
  48. require.ErrorIs(t, err, funnelErr)
  49. require.False(t, new)
  50. require.Nil(t, funnel)
  51. require.True(t, funnel2.closed)
  52. }
  53. func TestFunnelUnregister(t *testing.T) {
  54. id := testFunnelID{"id1"}
  55. funnel1, newFunnelFunc1 := newFunnelAndFunc("funnel1")
  56. funnel2, newFunnelFunc2 := newFunnelAndFunc("funnel2")
  57. funnel3, newFunnelFunc3 := newFunnelAndFunc("funnel3")
  58. ft := NewFunnelTracker()
  59. // Register & unregister
  60. _, _, err := ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFunc1)
  61. require.NoError(t, err)
  62. require.True(t, ft.Unregister(id, funnel1))
  63. require.True(t, funnel1.closed)
  64. require.True(t, ft.Unregister(id, funnel1))
  65. // Register, replace, and unregister
  66. _, _, err = ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFunc2)
  67. require.NoError(t, err)
  68. _, _, err = ft.GetOrRegister(id, shouldReplaceTrue, newFunnelFunc3)
  69. require.NoError(t, err)
  70. require.True(t, funnel2.closed)
  71. require.False(t, ft.Unregister(id, funnel2))
  72. require.True(t, ft.Unregister(id, funnel3))
  73. require.True(t, funnel3.closed)
  74. }
  75. func shouldReplaceFalse(_ Funnel) bool {
  76. return false
  77. }
  78. func shouldReplaceTrue(_ Funnel) bool {
  79. return true
  80. }
  81. func newFunnelAndFunc(id string) (*testFunnel, func() (Funnel, error)) {
  82. funnel := newTestFunnel(id)
  83. funnelFunc := func() (Funnel, error) {
  84. return funnel, nil
  85. }
  86. return funnel, funnelFunc
  87. }
  // testFunnelID is a minimal funnel identifier for tests, backed by a string.
  type testFunnelID struct {
  	id string
  }

  // Type returns the constant type tag "testFunnelID".
  func (fid testFunnelID) Type() string { return "testFunnelID" }

  // String returns the underlying identifier string.
  func (fid testFunnelID) String() string { return fid.id }
  97. type testFunnel struct {
  98. id string
  99. closed bool
  100. }
  101. func newTestFunnel(id string) *testFunnel {
  102. return &testFunnel{
  103. id,
  104. false,
  105. }
  106. }
  107. func (tf *testFunnel) Close() error {
  108. tf.closed = true
  109. return nil
  110. }
  111. func (tf *testFunnel) Equal(other Funnel) bool {
  112. return tf.id == other.(*testFunnel).id
  113. }
  114. func (tf *testFunnel) LastActive() time.Time {
  115. return time.Now()
  116. }
  117. func (tf *testFunnel) UpdateLastActive() {}