mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
753 lines
21 KiB
Go
753 lines
21 KiB
Go
|
/* SPDX-License-Identifier: MIT
|
||
|
*
|
||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||
|
*/
|
||
|
|
||
|
package tun
|
||
|
|
||
|
import (
|
||
|
"net/netip"
|
||
|
"testing"
|
||
|
|
||
|
"golang.org/x/sys/unix"
|
||
|
"golang.zx2c4.com/wireguard/conn"
|
||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||
|
)
|
||
|
|
||
|
const (
|
||
|
offset = virtioNetHdrLen
|
||
|
)
|
||
|
|
||
|
// Test endpoints. All share port 1; distinct addresses (from the IPv4 and IPv6
// documentation ranges 192.0.2.0/24 and 2001:db8::/32) let tests build
// separate flows.
var (
	ip4PortA = netip.AddrPortFrom(netip.MustParseAddr("192.0.2.1"), 1)
	ip4PortB = netip.AddrPortFrom(netip.MustParseAddr("192.0.2.2"), 1)
	ip4PortC = netip.AddrPortFrom(netip.MustParseAddr("192.0.2.3"), 1)

	ip6PortA = netip.AddrPortFrom(netip.MustParseAddr("2001:db8::1"), 1)
	ip6PortB = netip.AddrPortFrom(netip.MustParseAddr("2001:db8::2"), 1)
	ip6PortC = netip.AddrPortFrom(netip.MustParseAddr("2001:db8::3"), 1)
)
|
||
|
|
||
|
func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte {
|
||
|
totalLen := 28 + payloadLen
|
||
|
b := make([]byte, offset+int(totalLen), 65535)
|
||
|
ipv4H := header.IPv4(b[offset:])
|
||
|
srcAs4 := srcIPPort.Addr().As4()
|
||
|
dstAs4 := dstIPPort.Addr().As4()
|
||
|
ipFields := &header.IPv4Fields{
|
||
|
SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
|
||
|
DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
|
||
|
Protocol: unix.IPPROTO_UDP,
|
||
|
TTL: 64,
|
||
|
TotalLength: uint16(totalLen),
|
||
|
}
|
||
|
if ipFn != nil {
|
||
|
ipFn(ipFields)
|
||
|
}
|
||
|
ipv4H.Encode(ipFields)
|
||
|
udpH := header.UDP(b[offset+20:])
|
||
|
udpH.Encode(&header.UDPFields{
|
||
|
SrcPort: srcIPPort.Port(),
|
||
|
DstPort: dstIPPort.Port(),
|
||
|
Length: uint16(payloadLen + udphLen),
|
||
|
})
|
||
|
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
|
||
|
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
|
||
|
udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
|
||
|
return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
|
||
|
}
|
||
|
|
||
|
func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte {
|
||
|
totalLen := 48 + payloadLen
|
||
|
b := make([]byte, offset+int(totalLen), 65535)
|
||
|
ipv6H := header.IPv6(b[offset:])
|
||
|
srcAs16 := srcIPPort.Addr().As16()
|
||
|
dstAs16 := dstIPPort.Addr().As16()
|
||
|
ipFields := &header.IPv6Fields{
|
||
|
SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
|
||
|
DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
|
||
|
TransportProtocol: unix.IPPROTO_UDP,
|
||
|
HopLimit: 64,
|
||
|
PayloadLength: uint16(payloadLen + udphLen),
|
||
|
}
|
||
|
if ipFn != nil {
|
||
|
ipFn(ipFields)
|
||
|
}
|
||
|
ipv6H.Encode(ipFields)
|
||
|
udpH := header.UDP(b[offset+40:])
|
||
|
udpH.Encode(&header.UDPFields{
|
||
|
SrcPort: srcIPPort.Port(),
|
||
|
DstPort: dstIPPort.Port(),
|
||
|
Length: uint16(payloadLen + udphLen),
|
||
|
})
|
||
|
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
|
||
|
udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
|
||
|
return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
|
||
|
}
|
||
|
|
||
|
func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
|
||
|
totalLen := 40 + segmentSize
|
||
|
b := make([]byte, offset+int(totalLen), 65535)
|
||
|
ipv4H := header.IPv4(b[offset:])
|
||
|
srcAs4 := srcIPPort.Addr().As4()
|
||
|
dstAs4 := dstIPPort.Addr().As4()
|
||
|
ipFields := &header.IPv4Fields{
|
||
|
SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
|
||
|
DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
|
||
|
Protocol: unix.IPPROTO_TCP,
|
||
|
TTL: 64,
|
||
|
TotalLength: uint16(totalLen),
|
||
|
}
|
||
|
if ipFn != nil {
|
||
|
ipFn(ipFields)
|
||
|
}
|
||
|
ipv4H.Encode(ipFields)
|
||
|
tcpH := header.TCP(b[offset+20:])
|
||
|
tcpH.Encode(&header.TCPFields{
|
||
|
SrcPort: srcIPPort.Port(),
|
||
|
DstPort: dstIPPort.Port(),
|
||
|
SeqNum: seq,
|
||
|
AckNum: 1,
|
||
|
DataOffset: 20,
|
||
|
Flags: flags,
|
||
|
WindowSize: 3000,
|
||
|
})
|
||
|
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
|
||
|
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
|
||
|
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
|
||
|
return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
|
||
|
}
|
||
|
|
||
|
func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
|
||
|
totalLen := 60 + segmentSize
|
||
|
b := make([]byte, offset+int(totalLen), 65535)
|
||
|
ipv6H := header.IPv6(b[offset:])
|
||
|
srcAs16 := srcIPPort.Addr().As16()
|
||
|
dstAs16 := dstIPPort.Addr().As16()
|
||
|
ipFields := &header.IPv6Fields{
|
||
|
SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
|
||
|
DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
|
||
|
TransportProtocol: unix.IPPROTO_TCP,
|
||
|
HopLimit: 64,
|
||
|
PayloadLength: uint16(segmentSize + 20),
|
||
|
}
|
||
|
if ipFn != nil {
|
||
|
ipFn(ipFields)
|
||
|
}
|
||
|
ipv6H.Encode(ipFields)
|
||
|
tcpH := header.TCP(b[offset+40:])
|
||
|
tcpH.Encode(&header.TCPFields{
|
||
|
SrcPort: srcIPPort.Port(),
|
||
|
DstPort: dstIPPort.Port(),
|
||
|
SeqNum: seq,
|
||
|
AckNum: 1,
|
||
|
DataOffset: 20,
|
||
|
Flags: flags,
|
||
|
WindowSize: 3000,
|
||
|
})
|
||
|
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
|
||
|
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
|
||
|
return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
|
||
|
}
|
||
|
|
||
|
func Test_handleVirtioRead(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
hdr virtioNetHdr
|
||
|
pktIn []byte
|
||
|
wantLens []int
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{
|
||
|
"tcp4",
|
||
|
virtioNetHdr{
|
||
|
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||
|
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
|
||
|
gsoSize: 100,
|
||
|
hdrLen: 40,
|
||
|
csumStart: 20,
|
||
|
csumOffset: 16,
|
||
|
},
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
|
||
|
[]int{140, 140},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"tcp6",
|
||
|
virtioNetHdr{
|
||
|
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||
|
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
|
||
|
gsoSize: 100,
|
||
|
hdrLen: 60,
|
||
|
csumStart: 40,
|
||
|
csumOffset: 16,
|
||
|
},
|
||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
|
||
|
[]int{160, 160},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"udp4",
|
||
|
virtioNetHdr{
|
||
|
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||
|
gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
|
||
|
gsoSize: 100,
|
||
|
hdrLen: 28,
|
||
|
csumStart: 20,
|
||
|
csumOffset: 6,
|
||
|
},
|
||
|
udp4Packet(ip4PortA, ip4PortB, 200),
|
||
|
[]int{128, 128},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"udp6",
|
||
|
virtioNetHdr{
|
||
|
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
||
|
gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
|
||
|
gsoSize: 100,
|
||
|
hdrLen: 48,
|
||
|
csumStart: 40,
|
||
|
csumOffset: 6,
|
||
|
},
|
||
|
udp6Packet(ip6PortA, ip6PortB, 200),
|
||
|
[]int{148, 148},
|
||
|
false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
out := make([][]byte, conn.IdealBatchSize)
|
||
|
sizes := make([]int, conn.IdealBatchSize)
|
||
|
for i := range out {
|
||
|
out[i] = make([]byte, 65535)
|
||
|
}
|
||
|
tt.hdr.encode(tt.pktIn)
|
||
|
n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
|
||
|
if err != nil {
|
||
|
if tt.wantErr {
|
||
|
return
|
||
|
}
|
||
|
t.Fatalf("got err: %v", err)
|
||
|
}
|
||
|
if n != len(tt.wantLens) {
|
||
|
t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
|
||
|
}
|
||
|
for i := range tt.wantLens {
|
||
|
if tt.wantLens[i] != sizes[i] {
|
||
|
t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func flipTCP4Checksum(b []byte) []byte {
|
||
|
at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
|
||
|
b[at] ^= 0xFF
|
||
|
b[at+1] ^= 0xFF
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
func flipUDP4Checksum(b []byte) []byte {
|
||
|
at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
|
||
|
b[at] ^= 0xFF
|
||
|
b[at+1] ^= 0xFF
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
func Fuzz_handleGRO(f *testing.F) {
|
||
|
pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
|
||
|
pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
|
||
|
pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
|
||
|
pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
|
||
|
pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
|
||
|
pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
|
||
|
pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
|
||
|
pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
|
||
|
pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
|
||
|
pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
|
||
|
pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
|
||
|
pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
|
||
|
f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset)
|
||
|
f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
|
||
|
pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
|
||
|
toWrite := make([]int, 0, len(pkts))
|
||
|
handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite)
|
||
|
if len(toWrite) > len(pkts) {
|
||
|
t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
|
||
|
}
|
||
|
seenWriteI := make(map[int]bool)
|
||
|
for _, writeI := range toWrite {
|
||
|
if writeI < 0 || writeI > len(pkts)-1 {
|
||
|
t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
|
||
|
}
|
||
|
if seenWriteI[writeI] {
|
||
|
t.Errorf("duplicate toWrite value: %d", writeI)
|
||
|
}
|
||
|
seenWriteI[writeI] = true
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func Test_handleGRO(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
pktsIn [][]byte
|
||
|
canUDPGRO bool
|
||
|
wantToWrite []int
|
||
|
wantLens []int
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{
|
||
|
"multiple protocols and flows",
|
||
|
[][]byte{
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
|
||
|
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
|
||
|
udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
|
||
|
tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
|
||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
|
||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
|
||
|
tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
|
||
|
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
|
||
|
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
|
||
|
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
|
||
|
},
|
||
|
true,
|
||
|
[]int{0, 1, 2, 4, 5, 7, 9},
|
||
|
[]int{240, 228, 128, 140, 260, 160, 248},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"multiple protocols and flows no UDP GRO",
|
||
|
[][]byte{
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
|
||
|
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
|
||
|
udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
|
||
|
tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
|
||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
|
||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
|
||
|
tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
|
||
|
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
|
||
|
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
|
||
|
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
|
||
|
},
|
||
|
false,
|
||
|
[]int{0, 1, 2, 4, 5, 7, 8, 9, 10},
|
||
|
[]int{240, 128, 128, 140, 260, 160, 128, 148, 148},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"PSH interleaved",
|
||
|
[][]byte{
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
|
||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
|
||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
|
||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
|
||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
|
||
|
},
|
||
|
true,
|
||
|
[]int{0, 2, 4, 6},
|
||
|
[]int{240, 240, 260, 260},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"coalesceItemInvalidCSum",
|
||
|
[][]byte{
|
||
|
flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
|
||
|
flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
|
||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||
|
},
|
||
|
true,
|
||
|
[]int{0, 1, 3, 4},
|
||
|
[]int{140, 240, 128, 228},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"out of order",
|
||
|
[][]byte{
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
|
||
|
},
|
||
|
true,
|
||
|
[]int{0},
|
||
|
[]int{340},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"unequal TTL",
|
||
|
[][]byte{
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
||
|
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
||
|
fields.TTL++
|
||
|
}),
|
||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||
|
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
|
||
|
fields.TTL++
|
||
|
}),
|
||
|
},
|
||
|
true,
|
||
|
[]int{0, 1, 2, 3},
|
||
|
[]int{140, 140, 128, 128},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"unequal ToS",
|
||
|
[][]byte{
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
||
|
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
||
|
fields.TOS++
|
||
|
}),
|
||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||
|
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
|
||
|
fields.TOS++
|
||
|
}),
|
||
|
},
|
||
|
true,
|
||
|
[]int{0, 1, 2, 3},
|
||
|
[]int{140, 140, 128, 128},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"unequal flags more fragments set",
|
||
|
[][]byte{
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
||
|
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
||
|
fields.Flags = 1
|
||
|
}),
|
||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||
|
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
|
||
|
fields.Flags = 1
|
||
|
}),
|
||
|
},
|
||
|
true,
|
||
|
[]int{0, 1, 2, 3},
|
||
|
[]int{140, 140, 128, 128},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"unequal flags DF set",
|
||
|
[][]byte{
|
||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
||
|
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
||
|
fields.Flags = 2
|
||
|
}),
|
||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||
|
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
|
||
|
fields.Flags = 2
|
||
|
}),
|
||
|
},
|
||
|
true,
|
||
|
[]int{0, 1, 2, 3},
|
||
|
[]int{140, 140, 128, 128},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"ipv6 unequal hop limit",
|
||
|
[][]byte{
|
||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
|
||
|
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
|
||
|
fields.HopLimit++
|
||
|
}),
|
||
|
udp6Packet(ip6PortA, ip6PortB, 100),
|
||
|
udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
|
||
|
fields.HopLimit++
|
||
|
}),
|
||
|
},
|
||
|
true,
|
||
|
[]int{0, 1, 2, 3},
|
||
|
[]int{160, 160, 148, 148},
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"ipv6 unequal traffic class",
|
||
|
[][]byte{
|
||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
|
||
|
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
|
||
|
fields.TrafficClass++
|
||
|
}),
|
||
|
udp6Packet(ip6PortA, ip6PortB, 100),
|
||
|
udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
|
||
|
fields.TrafficClass++
|
||
|
}),
|
||
|
},
|
||
|
true,
|
||
|
[]int{0, 1, 2, 3},
|
||
|
[]int{160, 160, 148, 148},
|
||
|
false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
toWrite := make([]int, 0, len(tt.pktsIn))
|
||
|
err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
|
||
|
if err != nil {
|
||
|
if tt.wantErr {
|
||
|
return
|
||
|
}
|
||
|
t.Fatalf("got err: %v", err)
|
||
|
}
|
||
|
if len(toWrite) != len(tt.wantToWrite) {
|
||
|
t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
|
||
|
}
|
||
|
for i, pktI := range tt.wantToWrite {
|
||
|
if tt.wantToWrite[i] != toWrite[i] {
|
||
|
t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
|
||
|
}
|
||
|
if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
|
||
|
t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func Test_packetIsGROCandidate(t *testing.T) {
|
||
|
tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
|
||
|
tcp4TooShort := tcp4[:39]
|
||
|
ip4InvalidHeaderLen := make([]byte, len(tcp4))
|
||
|
copy(ip4InvalidHeaderLen, tcp4)
|
||
|
ip4InvalidHeaderLen[0] = 0x46
|
||
|
ip4InvalidProtocol := make([]byte, len(tcp4))
|
||
|
copy(ip4InvalidProtocol, tcp4)
|
||
|
ip4InvalidProtocol[9] = unix.IPPROTO_GRE
|
||
|
|
||
|
tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
|
||
|
tcp6TooShort := tcp6[:59]
|
||
|
ip6InvalidProtocol := make([]byte, len(tcp6))
|
||
|
copy(ip6InvalidProtocol, tcp6)
|
||
|
ip6InvalidProtocol[6] = unix.IPPROTO_GRE
|
||
|
|
||
|
udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
|
||
|
udp4TooShort := udp4[:27]
|
||
|
|
||
|
udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
|
||
|
udp6TooShort := udp6[:47]
|
||
|
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
b []byte
|
||
|
canUDPGRO bool
|
||
|
want groCandidateType
|
||
|
}{
|
||
|
{
|
||
|
"tcp4",
|
||
|
tcp4,
|
||
|
true,
|
||
|
tcp4GROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"tcp6",
|
||
|
tcp6,
|
||
|
true,
|
||
|
tcp6GROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"udp4",
|
||
|
udp4,
|
||
|
true,
|
||
|
udp4GROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"udp4 no support",
|
||
|
udp4,
|
||
|
false,
|
||
|
notGROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"udp6",
|
||
|
udp6,
|
||
|
true,
|
||
|
udp6GROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"udp6 no support",
|
||
|
udp6,
|
||
|
false,
|
||
|
notGROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"udp4 too short",
|
||
|
udp4TooShort,
|
||
|
true,
|
||
|
notGROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"udp6 too short",
|
||
|
udp6TooShort,
|
||
|
true,
|
||
|
notGROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"tcp4 too short",
|
||
|
tcp4TooShort,
|
||
|
true,
|
||
|
notGROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"tcp6 too short",
|
||
|
tcp6TooShort,
|
||
|
true,
|
||
|
notGROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"invalid IP version",
|
||
|
[]byte{0x00},
|
||
|
true,
|
||
|
notGROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"invalid IP header len",
|
||
|
ip4InvalidHeaderLen,
|
||
|
true,
|
||
|
notGROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"ip4 invalid protocol",
|
||
|
ip4InvalidProtocol,
|
||
|
true,
|
||
|
notGROCandidate,
|
||
|
},
|
||
|
{
|
||
|
"ip6 invalid protocol",
|
||
|
ip6InvalidProtocol,
|
||
|
true,
|
||
|
notGROCandidate,
|
||
|
},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want {
|
||
|
t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func Test_udpPacketsCanCoalesce(t *testing.T) {
|
||
|
udp4a := udp4Packet(ip4PortA, ip4PortB, 100)
|
||
|
udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
|
||
|
udp4c := udp4Packet(ip4PortA, ip4PortB, 110)
|
||
|
|
||
|
type args struct {
|
||
|
pkt []byte
|
||
|
iphLen uint8
|
||
|
gsoSize uint16
|
||
|
item udpGROItem
|
||
|
bufs [][]byte
|
||
|
bufsOffset int
|
||
|
}
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
args args
|
||
|
want canCoalesce
|
||
|
}{
|
||
|
{
|
||
|
"coalesceAppend equal gso",
|
||
|
args{
|
||
|
pkt: udp4a[offset:],
|
||
|
iphLen: 20,
|
||
|
gsoSize: 100,
|
||
|
item: udpGROItem{
|
||
|
gsoSize: 100,
|
||
|
iphLen: 20,
|
||
|
},
|
||
|
bufs: [][]byte{
|
||
|
udp4a,
|
||
|
udp4b,
|
||
|
},
|
||
|
bufsOffset: offset,
|
||
|
},
|
||
|
coalesceAppend,
|
||
|
},
|
||
|
{
|
||
|
"coalesceAppend smaller gso",
|
||
|
args{
|
||
|
pkt: udp4a[offset : len(udp4a)-90],
|
||
|
iphLen: 20,
|
||
|
gsoSize: 10,
|
||
|
item: udpGROItem{
|
||
|
gsoSize: 100,
|
||
|
iphLen: 20,
|
||
|
},
|
||
|
bufs: [][]byte{
|
||
|
udp4a,
|
||
|
udp4b,
|
||
|
},
|
||
|
bufsOffset: offset,
|
||
|
},
|
||
|
coalesceAppend,
|
||
|
},
|
||
|
{
|
||
|
"coalesceUnavailable smaller gso previously appended",
|
||
|
args{
|
||
|
pkt: udp4a[offset:],
|
||
|
iphLen: 20,
|
||
|
gsoSize: 100,
|
||
|
item: udpGROItem{
|
||
|
gsoSize: 100,
|
||
|
iphLen: 20,
|
||
|
},
|
||
|
bufs: [][]byte{
|
||
|
udp4c,
|
||
|
udp4b,
|
||
|
},
|
||
|
bufsOffset: offset,
|
||
|
},
|
||
|
coalesceUnavailable,
|
||
|
},
|
||
|
{
|
||
|
"coalesceUnavailable larger following smaller",
|
||
|
args{
|
||
|
pkt: udp4c[offset:],
|
||
|
iphLen: 20,
|
||
|
gsoSize: 110,
|
||
|
item: udpGROItem{
|
||
|
gsoSize: 100,
|
||
|
iphLen: 20,
|
||
|
},
|
||
|
bufs: [][]byte{
|
||
|
udp4a,
|
||
|
udp4c,
|
||
|
},
|
||
|
bufsOffset: offset,
|
||
|
},
|
||
|
coalesceUnavailable,
|
||
|
},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want {
|
||
|
t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|