mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
9e2f386022
Implement TCP offloading via TSO and GRO for the Linux tun.Device, which is made possible by virtio extensions in the kernel's TUN driver. Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind. conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All platforms now fall under conn.StdNetBind, except for Windows, which remains in conn.WinRingBind, which still needs to be adjusted to handle multiple packets. Also refactor sticky sockets support to eventually be applicable on platforms other than just Linux. However Linux remains the sole platform that fully implements it for now. Co-authored-by: James Tucker <james@tailscale.com> Signed-off-by: James Tucker <james@tailscale.com> Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
274 lines
8.4 KiB
Go
274 lines
8.4 KiB
Go
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */
|
|
|
|
package tun
|
|
|
|
import (
|
|
"net/netip"
|
|
"testing"
|
|
|
|
"golang.org/x/sys/unix"
|
|
"golang.zx2c4.com/wireguard/conn"
|
|
"gvisor.dev/gvisor/pkg/tcpip"
|
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
)
|
|
|
|
const (
|
|
offset = virtioNetHdrLen
|
|
)
|
|
|
|
// IPv4 endpoints used as flow sources/destinations by the tests. Addresses
// are taken from the 192.0.2.0/24 documentation range; all use port 1.
var (
	ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
	ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
	ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
)

// IPv6 endpoints used as flow sources/destinations by the tests. Addresses
// are taken from the 2001:db8::/32 documentation prefix; all use port 1.
var (
	ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
	ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
	ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
)
|
|
|
|
func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
|
|
totalLen := 40 + segmentSize
|
|
b := make([]byte, offset+int(totalLen), 65535)
|
|
ipv4H := header.IPv4(b[offset:])
|
|
srcAs4 := srcIPPort.Addr().As4()
|
|
dstAs4 := dstIPPort.Addr().As4()
|
|
ipv4H.Encode(&header.IPv4Fields{
|
|
SrcAddr: tcpip.Address(srcAs4[:]),
|
|
DstAddr: tcpip.Address(dstAs4[:]),
|
|
Protocol: unix.IPPROTO_TCP,
|
|
TTL: 64,
|
|
TotalLength: uint16(totalLen),
|
|
})
|
|
tcpH := header.TCP(b[offset+20:])
|
|
tcpH.Encode(&header.TCPFields{
|
|
SrcPort: srcIPPort.Port(),
|
|
DstPort: dstIPPort.Port(),
|
|
SeqNum: seq,
|
|
AckNum: 1,
|
|
DataOffset: 20,
|
|
Flags: flags,
|
|
WindowSize: 3000,
|
|
})
|
|
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
|
|
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
|
|
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
|
|
return b
|
|
}
|
|
|
|
func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
|
|
totalLen := 60 + segmentSize
|
|
b := make([]byte, offset+int(totalLen), 65535)
|
|
ipv6H := header.IPv6(b[offset:])
|
|
srcAs16 := srcIPPort.Addr().As16()
|
|
dstAs16 := dstIPPort.Addr().As16()
|
|
ipv6H.Encode(&header.IPv6Fields{
|
|
SrcAddr: tcpip.Address(srcAs16[:]),
|
|
DstAddr: tcpip.Address(dstAs16[:]),
|
|
TransportProtocol: unix.IPPROTO_TCP,
|
|
HopLimit: 64,
|
|
PayloadLength: uint16(segmentSize + 20),
|
|
})
|
|
tcpH := header.TCP(b[offset+40:])
|
|
tcpH.Encode(&header.TCPFields{
|
|
SrcPort: srcIPPort.Port(),
|
|
DstPort: dstIPPort.Port(),
|
|
SeqNum: seq,
|
|
AckNum: 1,
|
|
DataOffset: 20,
|
|
Flags: flags,
|
|
WindowSize: 3000,
|
|
})
|
|
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
|
|
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
|
|
return b
|
|
}
|
|
|
|
func Test_handleVirtioRead(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
hdr virtioNetHdr
|
|
pktIn []byte
|
|
wantLens []int
|
|
wantErr bool
|
|
}{
|
|
{
|
|
"tcp4",
|
|
virtioNetHdr{
|
|
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
|
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
|
|
gsoSize: 100,
|
|
hdrLen: 40,
|
|
csumStart: 20,
|
|
csumOffset: 16,
|
|
},
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
|
|
[]int{140, 140},
|
|
false,
|
|
},
|
|
{
|
|
"tcp6",
|
|
virtioNetHdr{
|
|
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
|
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
|
|
gsoSize: 100,
|
|
hdrLen: 60,
|
|
csumStart: 40,
|
|
csumOffset: 16,
|
|
},
|
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
|
|
[]int{160, 160},
|
|
false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
out := make([][]byte, conn.DefaultBatchSize)
|
|
sizes := make([]int, conn.DefaultBatchSize)
|
|
for i := range out {
|
|
out[i] = make([]byte, 65535)
|
|
}
|
|
tt.hdr.encode(tt.pktIn)
|
|
n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
|
|
if err != nil {
|
|
if tt.wantErr {
|
|
return
|
|
}
|
|
t.Fatalf("got err: %v", err)
|
|
}
|
|
if n != len(tt.wantLens) {
|
|
t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
|
|
}
|
|
for i := range tt.wantLens {
|
|
if tt.wantLens[i] != sizes[i] {
|
|
t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func flipTCP4Checksum(b []byte) []byte {
|
|
at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
|
|
b[at] ^= 0xFF
|
|
b[at+1] ^= 0xFF
|
|
return b
|
|
}
|
|
|
|
func Fuzz_handleGRO(f *testing.F) {
|
|
pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
|
|
pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
|
|
pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
|
|
pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
|
|
pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
|
|
pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
|
|
f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset)
|
|
f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) {
|
|
pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5}
|
|
toWrite := make([]int, 0, len(pkts))
|
|
handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
|
|
if len(toWrite) > len(pkts) {
|
|
t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
|
|
}
|
|
seenWriteI := make(map[int]bool)
|
|
for _, writeI := range toWrite {
|
|
if writeI < 0 || writeI > len(pkts)-1 {
|
|
t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
|
|
}
|
|
if seenWriteI[writeI] {
|
|
t.Errorf("duplicate toWrite value: %d", writeI)
|
|
}
|
|
seenWriteI[writeI] = true
|
|
}
|
|
})
|
|
}
|
|
|
|
func Test_handleGRO(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
pktsIn [][]byte
|
|
wantToWrite []int
|
|
wantLens []int
|
|
wantErr bool
|
|
}{
|
|
{
|
|
"multiple flows",
|
|
[][]byte{
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1
|
|
tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2
|
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
|
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1
|
|
tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2
|
|
},
|
|
[]int{0, 2, 3, 5},
|
|
[]int{240, 140, 260, 160},
|
|
false,
|
|
},
|
|
{
|
|
"PSH interleaved",
|
|
[][]byte{
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
|
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
|
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
|
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
|
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
|
|
},
|
|
[]int{0, 2, 4, 6},
|
|
[]int{240, 240, 260, 260},
|
|
false,
|
|
},
|
|
{
|
|
"coalesceItemInvalidCSum",
|
|
[][]byte{
|
|
flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
|
|
},
|
|
[]int{0, 1},
|
|
[]int{140, 240},
|
|
false,
|
|
},
|
|
{
|
|
"out of order",
|
|
[][]byte{
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
|
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
|
|
},
|
|
[]int{0},
|
|
[]int{340},
|
|
false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
toWrite := make([]int, 0, len(tt.pktsIn))
|
|
err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
|
|
if err != nil {
|
|
if tt.wantErr {
|
|
return
|
|
}
|
|
t.Fatalf("got err: %v", err)
|
|
}
|
|
if len(toWrite) != len(tt.wantToWrite) {
|
|
t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
|
|
}
|
|
for i, pktI := range tt.wantToWrite {
|
|
if tt.wantToWrite[i] != toWrite[i] {
|
|
t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
|
|
}
|
|
if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
|
|
t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|