1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00
wireguard-go/tun/tcp_offload_linux_test.go
Jordan Whited 9e2f386022 conn, device, tun: implement vectorized I/O on Linux
Implement TCP offloading via TSO and GRO for the Linux tun.Device, which
is made possible by virtio extensions in the kernel's TUN driver.

Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind.
conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All
platforms now fall under conn.StdNetBind, except for Windows, which
remains on conn.WinRingBind; that bind still needs to be adjusted to
handle multiple packets.

Also refactor sticky sockets support to eventually be applicable on
platforms other than just Linux. However, Linux remains the sole platform
that fully implements it for now.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:17 +01:00

274 lines
8.4 KiB
Go

/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"net/netip"
"testing"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// offset is the number of bytes of headroom every test packet reserves ahead
// of its IP header. It equals virtioNetHdrLen so a virtioNetHdr can be encoded
// in front of the packet, and the same buffers can be handed directly to
// handleVirtioRead and handleGRO.
const offset = virtioNetHdrLen
// Fixed endpoints used to build test flows. The addresses come from the
// documentation prefixes (192.0.2.0/24 and 2001:db8::/32). Every endpoint
// uses port 1, so flows differ only by address: A->B and A->C are distinct
// flows per address family.
var (
	ip4PortA = netip.AddrPortFrom(netip.MustParseAddr("192.0.2.1"), 1)
	ip4PortB = netip.AddrPortFrom(netip.MustParseAddr("192.0.2.2"), 1)
	ip4PortC = netip.AddrPortFrom(netip.MustParseAddr("192.0.2.3"), 1)

	ip6PortA = netip.AddrPortFrom(netip.MustParseAddr("2001:db8::1"), 1)
	ip6PortB = netip.AddrPortFrom(netip.MustParseAddr("2001:db8::2"), 1)
	ip6PortC = netip.AddrPortFrom(netip.MustParseAddr("2001:db8::3"), 1)
)
// tcp4Packet returns a buffer that starts with offset bytes of zeroed headroom,
// followed by an IPv4/TCP packet from srcIPPort to dstIPPort. The packet has
// the given TCP flags and sequence number, and segmentSize bytes of zero
// payload. Both the IPv4 header checksum and the TCP checksum are filled in.
// The buffer's capacity is 65535 so callers may grow it in place.
func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
	const (
		ipHdrLen  = 20 // IPv4 header, no options
		tcpHdrLen = 20 // TCP header, no options
	)
	pktLen := ipHdrLen + tcpHdrLen + segmentSize
	buf := make([]byte, offset+int(pktLen), 65535)

	src, dst := srcIPPort.Addr().As4(), dstIPPort.Addr().As4()
	ip := header.IPv4(buf[offset:])
	ip.Encode(&header.IPv4Fields{
		SrcAddr:     tcpip.Address(src[:]),
		DstAddr:     tcpip.Address(dst[:]),
		Protocol:    unix.IPPROTO_TCP,
		TTL:         64,
		TotalLength: uint16(pktLen),
	})

	tcp := header.TCP(buf[offset+ipHdrLen:])
	tcp.Encode(&header.TCPFields{
		SrcPort:    srcIPPort.Port(),
		DstPort:    dstIPPort.Port(),
		SeqNum:     seq,
		AckNum:     1,
		DataOffset: tcpHdrLen,
		Flags:      flags,
		WindowSize: 3000,
	})

	ip.SetChecksum(^ip.CalculateChecksum())

	// The payload is freshly zeroed by make, so it adds nothing to the
	// one's-complement sum; seeding with the pseudo-header checksum is enough.
	partial := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ip.SourceAddress(), ip.DestinationAddress(), uint16(tcpHdrLen+segmentSize))
	tcp.SetChecksum(^tcp.CalculateChecksum(partial))
	return buf
}
// tcp6Packet returns a buffer that starts with offset bytes of zeroed headroom,
// followed by an IPv6/TCP packet from srcIPPort to dstIPPort. The packet has
// the given TCP flags and sequence number, and segmentSize bytes of zero
// payload. IPv6 has no header checksum, so only the TCP checksum is filled in.
// The buffer's capacity is 65535 so callers may grow it in place.
func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
	const (
		ipHdrLen  = 40 // fixed IPv6 header, no extension headers
		tcpHdrLen = 20 // TCP header, no options
	)
	pktLen := ipHdrLen + tcpHdrLen + segmentSize
	buf := make([]byte, offset+int(pktLen), 65535)

	src, dst := srcIPPort.Addr().As16(), dstIPPort.Addr().As16()
	ip := header.IPv6(buf[offset:])
	ip.Encode(&header.IPv6Fields{
		SrcAddr:           tcpip.Address(src[:]),
		DstAddr:           tcpip.Address(dst[:]),
		TransportProtocol: unix.IPPROTO_TCP,
		HopLimit:          64,
		PayloadLength:     uint16(tcpHdrLen + segmentSize),
	})

	tcp := header.TCP(buf[offset+ipHdrLen:])
	tcp.Encode(&header.TCPFields{
		SrcPort:    srcIPPort.Port(),
		DstPort:    dstIPPort.Port(),
		SeqNum:     seq,
		AckNum:     1,
		DataOffset: tcpHdrLen,
		Flags:      flags,
		WindowSize: 3000,
	})

	// The payload is freshly zeroed by make, so it adds nothing to the
	// one's-complement sum; seeding with the pseudo-header checksum is enough.
	partial := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ip.SourceAddress(), ip.DestinationAddress(), uint16(tcpHdrLen+segmentSize))
	tcp.SetChecksum(^tcp.CalculateChecksum(partial))
	return buf
}
// Test_handleVirtioRead encodes a GSO virtioNetHdr into the headroom of a
// single oversized TCP packet, runs handleVirtioRead on it, and checks how
// many packets were produced and the size of each one.
//
// A case with wantErr set passes only if handleVirtioRead returns an error.
func Test_handleVirtioRead(t *testing.T) {
	tests := []struct {
		name     string
		hdr      virtioNetHdr
		pktIn    []byte
		wantLens []int
		wantErr  bool
	}{
		{
			"tcp4",
			virtioNetHdr{
				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
				gsoSize:    100,
				hdrLen:     40,
				csumStart:  20,
				csumOffset: 16,
			},
			tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
			// 200 payload bytes at gsoSize 100 -> two segments of hdrLen(40)+100.
			[]int{140, 140},
			false,
		},
		{
			"tcp6",
			virtioNetHdr{
				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
				gsoSize:    100,
				hdrLen:     60,
				csumStart:  40,
				csumOffset: 16,
			},
			tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
			// 200 payload bytes at gsoSize 100 -> two segments of hdrLen(60)+100.
			[]int{160, 160},
			false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			out := make([][]byte, conn.DefaultBatchSize)
			sizes := make([]int, conn.DefaultBatchSize)
			for i := range out {
				out[i] = make([]byte, 65535)
			}
			// Write the virtio header into the headroom reserved by tcp{4,6}Packet.
			tt.hdr.encode(tt.pktIn)
			n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
			if err != nil {
				if tt.wantErr {
					return
				}
				t.Fatalf("got err: %v", err)
			}
			// Without this check, a case expecting an error would pass when none occurred.
			if tt.wantErr {
				t.Fatal("got nil err, wanted an error")
			}
			if n != len(tt.wantLens) {
				t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
			}
			for i := range tt.wantLens {
				if tt.wantLens[i] != sizes[i] {
					t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
				}
			}
		})
	}
}
// flipTCP4Checksum inverts both bytes of the TCP checksum in a buffer built by
// tcp4Packet, which makes the checksum invalid. It mutates b in place and
// returns it so calls can be nested inside a test table.
func flipTCP4Checksum(b []byte) []byte {
	// Layout: virtio net header, then a 20-byte IPv4 header, then the TCP
	// header, whose checksum field sits 16 bytes in.
	const ipv4HdrLen, tcpCsumOffset = 20, 16
	csumAt := virtioNetHdrLen + ipv4HdrLen + tcpCsumOffset
	for i := csumAt; i < csumAt+2; i++ {
		b[i] ^= 0xFF
	}
	return b
}
// Fuzz_handleGRO feeds handleGRO six arbitrary packets and an arbitrary
// offset. The corpus is seeded with two IPv4 flows and two IPv6 flows. The
// target checks only structural invariants of toWrite: it has no more entries
// than there are packets, every entry is a valid index into the packet slice,
// and no index appears twice.
func Fuzz_handleGRO(f *testing.F) {
	f.Add(
		tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
		tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101),
		tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201),
		tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
		tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101),
		tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201),
		offset,
	)
	f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) {
		pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5}
		toWrite := make([]int, 0, len(pkts))

		// The returned error is discarded; this target only checks invariants on toWrite.
		_ = handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)

		if got, limit := len(toWrite), len(pkts); got > limit {
			t.Errorf("len(toWrite): %d > len(pkts): %d", got, limit)
		}
		seen := make(map[int]bool, len(toWrite))
		for _, idx := range toWrite {
			if idx < 0 || idx >= len(pkts) {
				t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", idx, len(pkts))
			}
			if seen[idx] {
				t.Errorf("duplicate toWrite value: %d", idx)
			}
			seen[idx] = true
		}
	})
}
// Test_handleGRO runs handleGRO over a batch of packets. It checks which packet
// indices handleGRO reports in toWrite, and the length (past the offset
// headroom) of each reported packet. The expected lengths show that coalesced
// packets grow in place: for example, two 140-byte IPv4 segments become one
// 240-byte packet.
//
// A case with wantErr set passes only if handleGRO returns an error.
func Test_handleGRO(t *testing.T) {
	tests := []struct {
		name        string
		pktsIn      [][]byte
		wantToWrite []int
		wantLens    []int
		wantErr     bool
	}{
		{
			"multiple flows",
			[][]byte{
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // v4 flow 1
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1
				tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),   // v6 flow 1
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1
				tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2
			},
			[]int{0, 2, 3, 5},
			[]int{240, 140, 260, 160},
			false,
		},
		{
			"PSH interleaved",
			[][]byte{
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),                     // v4 flow 1
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                   // v4 flow 1
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301),                   // v4 flow 1
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),                     // v6 flow 1
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201),                   // v6 flow 1
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301),                   // v6 flow 1
			},
			[]int{0, 2, 4, 6},
			[]int{240, 240, 260, 260},
			false,
		},
		{
			"coalesceItemInvalidCSum",
			[][]byte{
				flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101),                 // v4 flow 1 seq 101 len 100
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                 // v4 flow 1 seq 201 len 100
			},
			[]int{0, 1},
			[]int{140, 240},
			false,
		},
		{
			"out of order",
			[][]byte{
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // v4 flow 1 seq 1 len 100
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
			},
			[]int{0},
			[]int{340},
			false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Validate the table itself so a malformed case reports clearly
			// instead of panicking on an out-of-range wantLens index.
			if len(tt.wantLens) != len(tt.wantToWrite) {
				t.Fatalf("bad test case: %d wantLens for %d wantToWrite", len(tt.wantLens), len(tt.wantToWrite))
			}
			toWrite := make([]int, 0, len(tt.pktsIn))
			err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
			if err != nil {
				if tt.wantErr {
					return
				}
				t.Fatalf("got err: %v", err)
			}
			// Without this check, a case expecting an error would pass when none occurred.
			if tt.wantErr {
				t.Fatal("got nil err, wanted an error")
			}
			if len(toWrite) != len(tt.wantToWrite) {
				t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
			}
			for i, pktI := range tt.wantToWrite {
				if pktI != toWrite[i] {
					t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, pktI, toWrite[i])
				}
				// Coalescing rewrites packets in place, so the input slice's
				// length reflects the merged packet.
				if gotLen := len(tt.pktsIn[pktI][offset:]); tt.wantLens[i] != gotLen {
					t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, gotLen)
				}
			}
		})
	}
}