1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

tun: use correct IP header comparisons in tcpGRO() and tcpPacketsCanCoalesce()

tcpGRO() was using an incorrect IPv4 more fragments bit mask.

tcpPacketsCanCoalesce() was not distinguishing tcp6 from tcp4, and TTL
values were not compared. TTL values should be equal at the IP layer,
otherwise the packets should not coalesce. This tracks with the kernel.

Reviewed-by: Denton Gentry <dgentry@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jordan Whited 2023-03-24 16:23:42 -07:00 committed by Jason A. Donenfeld
parent aad7fca9c5
commit 052af4a807
2 changed files with 119 additions and 16 deletions

View File

@ -189,6 +189,16 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet
return coalesceUnavailable return coalesceUnavailable
} }
} }
if pkt[0]>>4 == 6 {
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
// cannot coalesce with unequal Traffic class values
return coalesceUnavailable
}
if pkt[7] != pktTarget[7] {
// cannot coalesce with unequal Hop limit values
return coalesceUnavailable
}
} else {
if pkt[1] != pktTarget[1] { if pkt[1] != pktTarget[1] {
// cannot coalesce with unequal ToS values // cannot coalesce with unequal ToS values
return coalesceUnavailable return coalesceUnavailable
@ -198,6 +208,11 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet
// further up the stack. // further up the stack.
return coalesceUnavailable return coalesceUnavailable
} }
if pkt[8] != pktTarget[8] {
// cannot coalesce with unequal TTL values
return coalesceUnavailable
}
}
// seq adjacency // seq adjacency
lhsLen := item.gsoSize lhsLen := item.gsoSize
lhsLen += item.numMerged * item.gsoSize lhsLen += item.numMerged * item.gsoSize
@ -366,7 +381,7 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize
} }
const ( const (
ipv4FlagMoreFragments = 0x80 ipv4FlagMoreFragments uint8 = 0x20
) )
const ( const (
@ -409,7 +424,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
return false return false
} }
if !isV6 { if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || (pkt[6]<<3 != 0 || pkt[7] != 0) { if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now // no GRO support for fragmented segments for now
return false return false
} }

View File

@ -28,19 +28,23 @@ var (
ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
) )
func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
totalLen := 40 + segmentSize totalLen := 40 + segmentSize
b := make([]byte, offset+int(totalLen), 65535) b := make([]byte, offset+int(totalLen), 65535)
ipv4H := header.IPv4(b[offset:]) ipv4H := header.IPv4(b[offset:])
srcAs4 := srcIPPort.Addr().As4() srcAs4 := srcIPPort.Addr().As4()
dstAs4 := dstIPPort.Addr().As4() dstAs4 := dstIPPort.Addr().As4()
ipv4H.Encode(&header.IPv4Fields{ ipFields := &header.IPv4Fields{
SrcAddr: tcpip.Address(srcAs4[:]), SrcAddr: tcpip.Address(srcAs4[:]),
DstAddr: tcpip.Address(dstAs4[:]), DstAddr: tcpip.Address(dstAs4[:]),
Protocol: unix.IPPROTO_TCP, Protocol: unix.IPPROTO_TCP,
TTL: 64, TTL: 64,
TotalLength: uint16(totalLen), TotalLength: uint16(totalLen),
}) }
if ipFn != nil {
ipFn(ipFields)
}
ipv4H.Encode(ipFields)
tcpH := header.TCP(b[offset+20:]) tcpH := header.TCP(b[offset+20:])
tcpH.Encode(&header.TCPFields{ tcpH.Encode(&header.TCPFields{
SrcPort: srcIPPort.Port(), SrcPort: srcIPPort.Port(),
@ -57,19 +61,27 @@ func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segm
return b return b
} }
func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
}
func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
totalLen := 60 + segmentSize totalLen := 60 + segmentSize
b := make([]byte, offset+int(totalLen), 65535) b := make([]byte, offset+int(totalLen), 65535)
ipv6H := header.IPv6(b[offset:]) ipv6H := header.IPv6(b[offset:])
srcAs16 := srcIPPort.Addr().As16() srcAs16 := srcIPPort.Addr().As16()
dstAs16 := dstIPPort.Addr().As16() dstAs16 := dstIPPort.Addr().As16()
ipv6H.Encode(&header.IPv6Fields{ ipFields := &header.IPv6Fields{
SrcAddr: tcpip.Address(srcAs16[:]), SrcAddr: tcpip.Address(srcAs16[:]),
DstAddr: tcpip.Address(dstAs16[:]), DstAddr: tcpip.Address(dstAs16[:]),
TransportProtocol: unix.IPPROTO_TCP, TransportProtocol: unix.IPPROTO_TCP,
HopLimit: 64, HopLimit: 64,
PayloadLength: uint16(segmentSize + 20), PayloadLength: uint16(segmentSize + 20),
}) }
if ipFn != nil {
ipFn(ipFields)
}
ipv6H.Encode(ipFields)
tcpH := header.TCP(b[offset+40:]) tcpH := header.TCP(b[offset+40:])
tcpH.Encode(&header.TCPFields{ tcpH.Encode(&header.TCPFields{
SrcPort: srcIPPort.Port(), SrcPort: srcIPPort.Port(),
@ -85,6 +97,10 @@ func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segm
return b return b
} }
func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
}
func Test_handleVirtioRead(t *testing.T) { func Test_handleVirtioRead(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -245,6 +261,78 @@ func Test_handleGRO(t *testing.T) {
[]int{340}, []int{340},
false, false,
}, },
{
"tcp4 unequal TTL",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.TTL++
}),
},
[]int{0, 1},
[]int{140, 140},
false,
},
{
"tcp4 unequal ToS",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.TOS++
}),
},
[]int{0, 1},
[]int{140, 140},
false,
},
{
"tcp4 unequal flags more fragments set",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.Flags = 1
}),
},
[]int{0, 1},
[]int{140, 140},
false,
},
{
"tcp4 unequal flags DF set",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
fields.Flags = 2
}),
},
[]int{0, 1},
[]int{140, 140},
false,
},
{
"tcp6 unequal hop limit",
[][]byte{
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
fields.HopLimit++
}),
},
[]int{0, 1},
[]int{160, 160},
false,
},
{
"tcp6 unequal traffic class",
[][]byte{
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
fields.TrafficClass++
}),
},
[]int{0, 1},
[]int{160, 160},
false,
},
} }
for _, tt := range tests { for _, tt := range tests {