packet_router_test.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package ingress
  2. import (
  3. "bytes"
  4. "context"
  5. "fmt"
  6. "net/netip"
  7. "sync/atomic"
  8. "testing"
  9. "github.com/google/gopacket/layers"
  10. "github.com/stretchr/testify/require"
  11. "golang.org/x/net/icmp"
  12. "golang.org/x/net/ipv4"
  13. "golang.org/x/net/ipv6"
  14. "github.com/cloudflare/cloudflared/packet"
  15. quicpogs "github.com/cloudflare/cloudflared/quic"
  16. )
  17. var (
  18. defaultRouter = &icmpRouter{
  19. ipv4Proxy: nil,
  20. ipv4Src: netip.MustParseAddr("172.16.0.1"),
  21. ipv6Proxy: nil,
  22. ipv6Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"),
  23. }
  24. )
  25. func TestRouterReturnTTLExceed(t *testing.T) {
  26. muxer := newMockMuxer(0)
  27. router := NewPacketRouter(defaultRouter, muxer, 0, &noopLogger)
  28. ctx, cancel := context.WithCancel(context.Background())
  29. routerStopped := make(chan struct{})
  30. go func() {
  31. router.Serve(ctx)
  32. close(routerStopped)
  33. }()
  34. pk := packet.ICMP{
  35. IP: &packet.IP{
  36. Src: netip.MustParseAddr("192.168.1.1"),
  37. Dst: netip.MustParseAddr("10.0.0.1"),
  38. Protocol: layers.IPProtocolICMPv4,
  39. TTL: 1,
  40. },
  41. Message: &icmp.Message{
  42. Type: ipv4.ICMPTypeEcho,
  43. Code: 0,
  44. Body: &icmp.Echo{
  45. ID: 12481,
  46. Seq: 8036,
  47. Data: []byte("TTL exceed"),
  48. },
  49. },
  50. }
  51. assertTTLExceed(t, &pk, defaultRouter.ipv4Src, muxer)
  52. pk = packet.ICMP{
  53. IP: &packet.IP{
  54. Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"),
  55. Dst: netip.MustParseAddr("fd51:2391:697:f4ee::2"),
  56. Protocol: layers.IPProtocolICMPv6,
  57. TTL: 1,
  58. },
  59. Message: &icmp.Message{
  60. Type: ipv6.ICMPTypeEchoRequest,
  61. Code: 0,
  62. Body: &icmp.Echo{
  63. ID: 42583,
  64. Seq: 7039,
  65. Data: []byte("TTL exceed"),
  66. },
  67. },
  68. }
  69. assertTTLExceed(t, &pk, defaultRouter.ipv6Src, muxer)
  70. cancel()
  71. <-routerStopped
  72. }
  73. func assertTTLExceed(t *testing.T, originalPacket *packet.ICMP, expectedSrc netip.Addr, muxer *mockMuxer) {
  74. encoder := packet.NewEncoder()
  75. rawPacket, err := encoder.Encode(originalPacket)
  76. require.NoError(t, err)
  77. muxer.edgeToCfd <- quicpogs.RawPacket(rawPacket)
  78. resp := <-muxer.cfdToEdge
  79. decoder := packet.NewICMPDecoder()
  80. decoded, err := decoder.Decode(packet.RawPacket(resp.(quicpogs.RawPacket)))
  81. require.NoError(t, err)
  82. require.Equal(t, expectedSrc, decoded.Src)
  83. require.Equal(t, originalPacket.Src, decoded.Dst)
  84. require.Equal(t, originalPacket.Protocol, decoded.Protocol)
  85. require.Equal(t, packet.DefaultTTL, decoded.TTL)
  86. if originalPacket.Dst.Is4() {
  87. require.Equal(t, ipv4.ICMPTypeTimeExceeded, decoded.Type)
  88. } else {
  89. require.Equal(t, ipv6.ICMPTypeTimeExceeded, decoded.Type)
  90. }
  91. require.Equal(t, 0, decoded.Code)
  92. timeExceed, ok := decoded.Body.(*icmp.TimeExceeded)
  93. require.True(t, ok)
  94. require.True(t, bytes.Equal(rawPacket.Data, timeExceed.Data))
  95. }
  96. type mockMuxer struct {
  97. cfdToEdge chan quicpogs.Packet
  98. edgeToCfd chan quicpogs.Packet
  99. }
  100. func newMockMuxer(capacity int) *mockMuxer {
  101. return &mockMuxer{
  102. cfdToEdge: make(chan quicpogs.Packet, capacity),
  103. edgeToCfd: make(chan quicpogs.Packet, capacity),
  104. }
  105. }
  106. // Copy packet, because icmpProxy expects the encoder buffer to be reusable after the packet is sent
  107. func (mm *mockMuxer) SendPacket(pk quicpogs.Packet) error {
  108. payload := pk.Payload()
  109. copiedPayload := make([]byte, len(payload))
  110. copy(copiedPayload, payload)
  111. metadata := pk.Metadata()
  112. copiedMetadata := make([]byte, len(metadata))
  113. copy(copiedMetadata, metadata)
  114. var copiedPacket quicpogs.Packet
  115. switch pk.Type() {
  116. case quicpogs.DatagramTypeIP:
  117. copiedPacket = quicpogs.RawPacket(packet.RawPacket{
  118. Data: copiedPayload,
  119. })
  120. case quicpogs.DatagramTypeIPWithTrace:
  121. copiedPacket = &quicpogs.TracedPacket{
  122. Packet: packet.RawPacket{
  123. Data: copiedPayload,
  124. },
  125. TracingIdentity: copiedMetadata,
  126. }
  127. case quicpogs.DatagramTypeTracingSpan:
  128. copiedPacket = &quicpogs.TracingSpanPacket{
  129. Spans: copiedPayload,
  130. TracingIdentity: copiedMetadata,
  131. }
  132. default:
  133. return fmt.Errorf("unexpected metadata type %d", pk.Type())
  134. }
  135. mm.cfdToEdge <- copiedPacket
  136. return nil
  137. }
  138. func (mm *mockMuxer) ReceivePacket(ctx context.Context) (quicpogs.Packet, error) {
  139. select {
  140. case <-ctx.Done():
  141. return nil, ctx.Err()
  142. case pk := <-mm.edgeToCfd:
  143. return pk, nil
  144. }
  145. }
  // routerEnabledChecker is an on/off flag read and written by these methods
  // only through sync/atomic loads and stores. The flag is stored as a uint32:
  // 0 means disabled, and any non-zero value means enabled. The zero value is
  // therefore disabled.
  type routerEnabledChecker struct {
      enabled uint32
  }

  // isEnabled reports whether the flag currently holds a non-zero value.
  func (rec *routerEnabledChecker) isEnabled() bool {
      return atomic.LoadUint32(&rec.enabled) != 0
  }

  // set atomically stores 1 when enabled is true and 0 otherwise.
  func (rec *routerEnabledChecker) set(enabled bool) {
      var value uint32
      if enabled {
          value = 1
      }
      atomic.StoreUint32(&rec.enabled, value)
  }