decoder_test.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. package packet
  2. import (
  3. "net"
  4. "net/netip"
  5. "testing"
  6. "github.com/google/gopacket"
  7. "github.com/google/gopacket/layers"
  8. "github.com/stretchr/testify/require"
  9. "golang.org/x/net/icmp"
  10. "golang.org/x/net/ipv4"
  11. "golang.org/x/net/ipv6"
  12. )
  13. func TestDecodeIP(t *testing.T) {
  14. ipDecoder := NewIPDecoder()
  15. icmpDecoder := NewICMPDecoder()
  16. udps := []UDP{
  17. {
  18. IP: IP{
  19. Src: netip.MustParseAddr("172.16.0.1"),
  20. Dst: netip.MustParseAddr("10.0.0.1"),
  21. Protocol: layers.IPProtocolUDP,
  22. },
  23. SrcPort: 31678,
  24. DstPort: 53,
  25. },
  26. {
  27. IP: IP{
  28. Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"),
  29. Dst: netip.MustParseAddr("fd51:2391:697:f4ee::2"),
  30. Protocol: layers.IPProtocolUDP,
  31. },
  32. SrcPort: 52139,
  33. DstPort: 1053,
  34. },
  35. }
  36. encoder := NewEncoder()
  37. for _, udp := range udps {
  38. p, err := encoder.Encode(&udp)
  39. require.NoError(t, err)
  40. ipPacket, err := ipDecoder.Decode(p)
  41. require.NoError(t, err)
  42. assertIPLayer(t, &udp.IP, ipPacket)
  43. icmpPacket, err := icmpDecoder.Decode(p)
  44. require.Error(t, err)
  45. require.Nil(t, icmpPacket)
  46. }
  47. }
  48. func TestDecodeICMP(t *testing.T) {
  49. ipDecoder := NewIPDecoder()
  50. icmpDecoder := NewICMPDecoder()
  51. var (
  52. ipv4Packet = IP{
  53. Src: netip.MustParseAddr("172.16.0.1"),
  54. Dst: netip.MustParseAddr("10.0.0.1"),
  55. Protocol: layers.IPProtocolICMPv4,
  56. TTL: DefaultTTL,
  57. }
  58. ipv6Packet = IP{
  59. Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"),
  60. Dst: netip.MustParseAddr("fd51:2391:697:f4ee::2"),
  61. Protocol: layers.IPProtocolICMPv6,
  62. TTL: DefaultTTL,
  63. }
  64. icmpID = 100
  65. icmpSeq = 52819
  66. )
  67. tests := []struct {
  68. testCase string
  69. packet *ICMP
  70. }{
  71. {
  72. testCase: "icmpv4 time exceed",
  73. packet: &ICMP{
  74. IP: &ipv4Packet,
  75. Message: &icmp.Message{
  76. Type: ipv4.ICMPTypeTimeExceeded,
  77. Code: 0,
  78. Body: &icmp.TimeExceeded{
  79. Data: []byte("original packet"),
  80. },
  81. },
  82. },
  83. },
  84. {
  85. testCase: "icmpv4 echo",
  86. packet: &ICMP{
  87. IP: &ipv4Packet,
  88. Message: &icmp.Message{
  89. Type: ipv4.ICMPTypeEcho,
  90. Code: 0,
  91. Body: &icmp.Echo{
  92. ID: icmpID,
  93. Seq: icmpSeq,
  94. Data: []byte("icmpv4 echo"),
  95. },
  96. },
  97. },
  98. },
  99. {
  100. testCase: "icmpv6 destination unreachable",
  101. packet: &ICMP{
  102. IP: &ipv6Packet,
  103. Message: &icmp.Message{
  104. Type: ipv6.ICMPTypeDestinationUnreachable,
  105. Code: 4,
  106. Body: &icmp.DstUnreach{
  107. Data: []byte("original packet"),
  108. },
  109. },
  110. },
  111. },
  112. {
  113. testCase: "icmpv6 echo",
  114. packet: &ICMP{
  115. IP: &ipv6Packet,
  116. Message: &icmp.Message{
  117. Type: ipv6.ICMPTypeEchoRequest,
  118. Code: 0,
  119. Body: &icmp.Echo{
  120. ID: icmpID,
  121. Seq: icmpSeq,
  122. Data: []byte("icmpv6 echo"),
  123. },
  124. },
  125. },
  126. },
  127. }
  128. encoder := NewEncoder()
  129. for _, test := range tests {
  130. p, err := encoder.Encode(test.packet)
  131. require.NoError(t, err)
  132. ipPacket, err := ipDecoder.Decode(p)
  133. require.NoError(t, err)
  134. if ipPacket.Src.Is4() {
  135. assertIPLayer(t, &ipv4Packet, ipPacket)
  136. } else {
  137. assertIPLayer(t, &ipv6Packet, ipPacket)
  138. }
  139. icmpPacket, err := icmpDecoder.Decode(p)
  140. require.NoError(t, err)
  141. require.Equal(t, ipPacket, icmpPacket.IP)
  142. require.Equal(t, test.packet.Type, icmpPacket.Type)
  143. require.Equal(t, test.packet.Code, icmpPacket.Code)
  144. assertICMPChecksum(t, icmpPacket)
  145. require.Equal(t, test.packet.Body, icmpPacket.Body)
  146. expectedBody, err := test.packet.Body.Marshal(test.packet.Type.Protocol())
  147. require.NoError(t, err)
  148. decodedBody, err := icmpPacket.Body.Marshal(test.packet.Type.Protocol())
  149. require.NoError(t, err)
  150. require.Equal(t, expectedBody, decodedBody)
  151. }
  152. }
  153. // TestDecodeBadPackets makes sure decoders don't decode invalid packets
  154. func TestDecodeBadPackets(t *testing.T) {
  155. var (
  156. srcIPv4 = net.ParseIP("172.16.0.1")
  157. dstIPv4 = net.ParseIP("10.0.0.1")
  158. )
  159. ipLayer := layers.IPv4{
  160. Version: 10,
  161. SrcIP: srcIPv4,
  162. DstIP: dstIPv4,
  163. Protocol: layers.IPProtocolICMPv4,
  164. TTL: DefaultTTL,
  165. }
  166. icmpLayer := layers.ICMPv4{
  167. TypeCode: layers.CreateICMPv4TypeCode(uint8(ipv4.ICMPTypeEcho), 0),
  168. Id: 100,
  169. Seq: 52819,
  170. }
  171. wrongIPVersion, err := createPacket(&ipLayer, &icmpLayer, nil, nil)
  172. require.NoError(t, err)
  173. tests := []struct {
  174. testCase string
  175. packet []byte
  176. }{
  177. {
  178. testCase: "unknown IP version",
  179. packet: wrongIPVersion,
  180. },
  181. {
  182. testCase: "invalid packet",
  183. packet: []byte("not a packet"),
  184. },
  185. {
  186. testCase: "zero length packet",
  187. packet: []byte{},
  188. },
  189. }
  190. ipDecoder := NewIPDecoder()
  191. icmpDecoder := NewICMPDecoder()
  192. for _, test := range tests {
  193. ipPacket, err := ipDecoder.Decode(RawPacket{Data: test.packet})
  194. require.Error(t, err)
  195. require.Nil(t, ipPacket)
  196. icmpPacket, err := icmpDecoder.Decode(RawPacket{Data: test.packet})
  197. require.Error(t, err)
  198. require.Nil(t, icmpPacket)
  199. }
  200. }
  201. func createPacket(ipLayer, secondLayer, thirdLayer gopacket.SerializableLayer, body []byte) ([]byte, error) {
  202. payload := gopacket.Payload(body)
  203. packet := gopacket.NewSerializeBuffer()
  204. var err error
  205. if thirdLayer != nil {
  206. err = gopacket.SerializeLayers(packet, serializeOpts, ipLayer, secondLayer, thirdLayer, payload)
  207. } else {
  208. err = gopacket.SerializeLayers(packet, serializeOpts, ipLayer, secondLayer, payload)
  209. }
  210. if err != nil {
  211. return nil, err
  212. }
  213. return packet.Bytes(), nil
  214. }
  215. func assertIPLayer(t *testing.T, expected, actual *IP) {
  216. require.Equal(t, expected.Src, actual.Src)
  217. require.Equal(t, expected.Dst, actual.Dst)
  218. require.Equal(t, expected.Protocol, actual.Protocol)
  219. require.Equal(t, expected.TTL, actual.TTL)
  220. }
  221. type UDP struct {
  222. IP
  223. SrcPort, DstPort layers.UDPPort
  224. }
  225. func (u *UDP) EncodeLayers() ([]gopacket.SerializableLayer, error) {
  226. ipLayers, err := u.IP.EncodeLayers()
  227. if err != nil {
  228. return nil, err
  229. }
  230. udpLayer := layers.UDP{
  231. SrcPort: u.SrcPort,
  232. DstPort: u.DstPort,
  233. }
  234. udpLayer.SetNetworkLayerForChecksum(ipLayers[0].(gopacket.NetworkLayer))
  235. return append(ipLayers, &udpLayer), nil
  236. }
  237. func FuzzIPDecoder(f *testing.F) {
  238. f.Fuzz(func(t *testing.T, data []byte) {
  239. ipDecoder := NewIPDecoder()
  240. ipDecoder.Decode(RawPacket{Data: data})
  241. })
  242. }
  243. func FuzzICMPDecoder(f *testing.F) {
  244. f.Fuzz(func(t *testing.T, data []byte) {
  245. icmpDecoder := NewICMPDecoder()
  246. icmpDecoder.Decode(RawPacket{Data: data})
  247. })
  248. }