123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- package packet
- import (
- "net"
- "net/netip"
- "testing"
- "github.com/google/gopacket"
- "github.com/google/gopacket/layers"
- "github.com/stretchr/testify/require"
- "golang.org/x/net/icmp"
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
- )
- func TestDecodeIP(t *testing.T) {
- ipDecoder := NewIPDecoder()
- icmpDecoder := NewICMPDecoder()
- udps := []UDP{
- {
- IP: IP{
- Src: netip.MustParseAddr("172.16.0.1"),
- Dst: netip.MustParseAddr("10.0.0.1"),
- Protocol: layers.IPProtocolUDP,
- },
- SrcPort: 31678,
- DstPort: 53,
- },
- {
- IP: IP{
- Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"),
- Dst: netip.MustParseAddr("fd51:2391:697:f4ee::2"),
- Protocol: layers.IPProtocolUDP,
- },
- SrcPort: 52139,
- DstPort: 1053,
- },
- }
- encoder := NewEncoder()
- for _, udp := range udps {
- p, err := encoder.Encode(&udp)
- require.NoError(t, err)
- ipPacket, err := ipDecoder.Decode(p)
- require.NoError(t, err)
- assertIPLayer(t, &udp.IP, ipPacket)
- icmpPacket, err := icmpDecoder.Decode(p)
- require.Error(t, err)
- require.Nil(t, icmpPacket)
- }
- }
- func TestDecodeICMP(t *testing.T) {
- ipDecoder := NewIPDecoder()
- icmpDecoder := NewICMPDecoder()
- var (
- ipv4Packet = IP{
- Src: netip.MustParseAddr("172.16.0.1"),
- Dst: netip.MustParseAddr("10.0.0.1"),
- Protocol: layers.IPProtocolICMPv4,
- TTL: DefaultTTL,
- }
- ipv6Packet = IP{
- Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"),
- Dst: netip.MustParseAddr("fd51:2391:697:f4ee::2"),
- Protocol: layers.IPProtocolICMPv6,
- TTL: DefaultTTL,
- }
- icmpID = 100
- icmpSeq = 52819
- )
- tests := []struct {
- testCase string
- packet *ICMP
- }{
- {
- testCase: "icmpv4 time exceed",
- packet: &ICMP{
- IP: &ipv4Packet,
- Message: &icmp.Message{
- Type: ipv4.ICMPTypeTimeExceeded,
- Code: 0,
- Body: &icmp.TimeExceeded{
- Data: []byte("original packet"),
- },
- },
- },
- },
- {
- testCase: "icmpv4 echo",
- packet: &ICMP{
- IP: &ipv4Packet,
- Message: &icmp.Message{
- Type: ipv4.ICMPTypeEcho,
- Code: 0,
- Body: &icmp.Echo{
- ID: icmpID,
- Seq: icmpSeq,
- Data: []byte("icmpv4 echo"),
- },
- },
- },
- },
- {
- testCase: "icmpv6 destination unreachable",
- packet: &ICMP{
- IP: &ipv6Packet,
- Message: &icmp.Message{
- Type: ipv6.ICMPTypeDestinationUnreachable,
- Code: 4,
- Body: &icmp.DstUnreach{
- Data: []byte("original packet"),
- },
- },
- },
- },
- {
- testCase: "icmpv6 echo",
- packet: &ICMP{
- IP: &ipv6Packet,
- Message: &icmp.Message{
- Type: ipv6.ICMPTypeEchoRequest,
- Code: 0,
- Body: &icmp.Echo{
- ID: icmpID,
- Seq: icmpSeq,
- Data: []byte("icmpv6 echo"),
- },
- },
- },
- },
- }
- encoder := NewEncoder()
- for _, test := range tests {
- p, err := encoder.Encode(test.packet)
- require.NoError(t, err)
- ipPacket, err := ipDecoder.Decode(p)
- require.NoError(t, err)
- if ipPacket.Src.Is4() {
- assertIPLayer(t, &ipv4Packet, ipPacket)
- } else {
- assertIPLayer(t, &ipv6Packet, ipPacket)
- }
- icmpPacket, err := icmpDecoder.Decode(p)
- require.NoError(t, err)
- require.Equal(t, ipPacket, icmpPacket.IP)
- require.Equal(t, test.packet.Type, icmpPacket.Type)
- require.Equal(t, test.packet.Code, icmpPacket.Code)
- assertICMPChecksum(t, icmpPacket)
- require.Equal(t, test.packet.Body, icmpPacket.Body)
- expectedBody, err := test.packet.Body.Marshal(test.packet.Type.Protocol())
- require.NoError(t, err)
- decodedBody, err := icmpPacket.Body.Marshal(test.packet.Type.Protocol())
- require.NoError(t, err)
- require.Equal(t, expectedBody, decodedBody)
- }
- }
- // TestDecodeBadPackets makes sure decoders don't decode invalid packets
- func TestDecodeBadPackets(t *testing.T) {
- var (
- srcIPv4 = net.ParseIP("172.16.0.1")
- dstIPv4 = net.ParseIP("10.0.0.1")
- )
- ipLayer := layers.IPv4{
- Version: 10,
- SrcIP: srcIPv4,
- DstIP: dstIPv4,
- Protocol: layers.IPProtocolICMPv4,
- TTL: DefaultTTL,
- }
- icmpLayer := layers.ICMPv4{
- TypeCode: layers.CreateICMPv4TypeCode(uint8(ipv4.ICMPTypeEcho), 0),
- Id: 100,
- Seq: 52819,
- }
- wrongIPVersion, err := createPacket(&ipLayer, &icmpLayer, nil, nil)
- require.NoError(t, err)
- tests := []struct {
- testCase string
- packet []byte
- }{
- {
- testCase: "unknown IP version",
- packet: wrongIPVersion,
- },
- {
- testCase: "invalid packet",
- packet: []byte("not a packet"),
- },
- {
- testCase: "zero length packet",
- packet: []byte{},
- },
- }
- ipDecoder := NewIPDecoder()
- icmpDecoder := NewICMPDecoder()
- for _, test := range tests {
- ipPacket, err := ipDecoder.Decode(RawPacket{Data: test.packet})
- require.Error(t, err)
- require.Nil(t, ipPacket)
- icmpPacket, err := icmpDecoder.Decode(RawPacket{Data: test.packet})
- require.Error(t, err)
- require.Nil(t, icmpPacket)
- }
- }
- func createPacket(ipLayer, secondLayer, thirdLayer gopacket.SerializableLayer, body []byte) ([]byte, error) {
- payload := gopacket.Payload(body)
- packet := gopacket.NewSerializeBuffer()
- var err error
- if thirdLayer != nil {
- err = gopacket.SerializeLayers(packet, serializeOpts, ipLayer, secondLayer, thirdLayer, payload)
- } else {
- err = gopacket.SerializeLayers(packet, serializeOpts, ipLayer, secondLayer, payload)
- }
- if err != nil {
- return nil, err
- }
- return packet.Bytes(), nil
- }
- func assertIPLayer(t *testing.T, expected, actual *IP) {
- require.Equal(t, expected.Src, actual.Src)
- require.Equal(t, expected.Dst, actual.Dst)
- require.Equal(t, expected.Protocol, actual.Protocol)
- require.Equal(t, expected.TTL, actual.TTL)
- }
- type UDP struct {
- IP
- SrcPort, DstPort layers.UDPPort
- }
- func (u *UDP) EncodeLayers() ([]gopacket.SerializableLayer, error) {
- ipLayers, err := u.IP.EncodeLayers()
- if err != nil {
- return nil, err
- }
- udpLayer := layers.UDP{
- SrcPort: u.SrcPort,
- DstPort: u.DstPort,
- }
- udpLayer.SetNetworkLayerForChecksum(ipLayers[0].(gopacket.NetworkLayer))
- return append(ipLayers, &udpLayer), nil
- }
- func FuzzIPDecoder(f *testing.F) {
- f.Fuzz(func(t *testing.T, data []byte) {
- ipDecoder := NewIPDecoder()
- ipDecoder.Decode(RawPacket{Data: data})
- })
- }
- func FuzzICMPDecoder(f *testing.F) {
- f.Fuzz(func(t *testing.T, data []byte) {
- icmpDecoder := NewICMPDecoder()
- icmpDecoder.Decode(RawPacket{Data: data})
- })
- }
|