mirror of https://git.zx2c4.com/wireguard-go
4201e08f1d
After reducing UDP stack traversal overhead via GSO and GRO, runtime.chanrecv()
began to account for a high percentage (20% in one environment) of perf samples
during a throughput benchmark. The individual packet channel ops with the crypto
goroutines were the primary contributor to this overhead. Updating these channels
to pass vectors, which the device package already handles at its ends, reduced
this overhead substantially and improved throughput.

The iperf3 results below demonstrate the effect of this commit between two Linux
computers with i5-12400 CPUs. There is roughly 13us of round trip latency between
them.

The first result is with UDP GSO and GRO, and with single element channels.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  12.3 GBytes  10.6 Gbits/sec  232   sender
[  5]   0.00-10.04  sec  12.3 GBytes  10.6 Gbits/sec        receiver

The second result is with channels updated to pass a slice of elements.

Starting Test: protocol: TCP, 1 streams, 131072 byte blocks
[ ID] Interval           Transfer     Bitrate         Retr  Cwnd
[  5]   0.00-10.00  sec  13.2 GBytes  11.3 Gbits/sec  182   3.15 MBytes
- - - - - - - - - - - - - - - - - - - - - - - - -
Test Complete. Summary Results:
[ ID] Interval           Transfer     Bitrate         Retr
[  5]   0.00-10.00  sec  13.2 GBytes  11.3 Gbits/sec  182   sender
[  5]   0.00-10.04  sec  13.2 GBytes  11.3 Gbits/sec        receiver

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
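For illustration, here is a minimal, self-contained sketch of the pattern the commit
describes (this is not code from this repository; the element type and function names
are made up): instead of one channel operation per packet, the producer hands the
consumer a whole batch per operation, so far fewer runtime.chanrecv() calls are needed
on the consumer side.

package main

import "fmt"

type element struct{ packet []byte }

// perPacket performs one channel operation per element.
func perPacket(elems []*element) {
	c := make(chan *element, len(elems))
	for _, e := range elems {
		c <- e // one send (and later one receive) per packet
	}
	close(c)
	for e := range c {
		_ = e
	}
}

// perBatch performs a single channel operation for the whole slice.
func perBatch(elems []*element) {
	c := make(chan []*element, 1)
	c <- elems // one send/receive pair for the entire batch
	close(c)
	for batch := range c {
		for _, e := range batch {
			_ = e
		}
	}
}

func main() {
	elems := make([]*element, 128)
	for i := range elems {
		elems[i] = &element{packet: make([]byte, 1500)}
	}
	perPacket(elems)
	perBatch(elems)
	fmt.Println("processed", len(elems), "elements via both paths")
}

In the file below, the batched form appears as channels carrying *[]*QueueInboundElement,
with the slices themselves recycled through GetInboundElementsSlice/PutInboundElementsSlice.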
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"bytes"
	"encoding/binary"
	"errors"
	"net"
	"sync"
	"time"

	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
	"golang.zx2c4.com/wireguard/conn"
)

type QueueHandshakeElement struct {
	msgType  uint32
	packet   []byte
	endpoint conn.Endpoint
	buffer   *[MaxMessageSize]byte
}

type QueueInboundElement struct {
	sync.Mutex
	buffer   *[MaxMessageSize]byte
	packet   []byte
	counter  uint64
	keypair  *Keypair
	endpoint conn.Endpoint
}

// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
// It also reduces the possible collateral damage from use-after-free bugs.
func (elem *QueueInboundElement) clearPointers() {
	elem.buffer = nil
	elem.packet = nil
	elem.keypair = nil
	elem.endpoint = nil
}

/* Called when a new authenticated message has been received
 *
 * NOTE: Not thread safe, but called by sequential receiver!
 */
func (peer *Peer) keepKeyFreshReceiving() {
	if peer.timers.sentLastMinuteHandshake.Load() {
		return
	}
	keypair := peer.keypairs.Current()
	if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
		peer.timers.sentLastMinuteHandshake.Store(true)
		peer.SendHandshakeInitiation(false)
	}
}

/* Receives incoming datagrams for the device
 *
 * Every time the bind is updated a new routine is started for
 * IPv4 and IPv6 (separately)
 */
func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
	recvName := recv.PrettyName()
	defer func() {
		device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
		device.queue.decryption.wg.Done()
		device.queue.handshake.wg.Done()
		device.net.stopping.Done()
	}()

	device.log.Verbosef("Routine: receive incoming %s - started", recvName)

	// receive datagrams until conn is closed

	var (
		bufsArrs    = make([]*[MaxMessageSize]byte, maxBatchSize)
		bufs        = make([][]byte, maxBatchSize)
		err         error
		sizes       = make([]int, maxBatchSize)
		count       int
		endpoints   = make([]conn.Endpoint, maxBatchSize)
		deathSpiral int
		elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
	)

	for i := range bufsArrs {
		bufsArrs[i] = device.GetMessageBuffer()
		bufs[i] = bufsArrs[i][:]
	}

	defer func() {
		for i := 0; i < maxBatchSize; i++ {
			if bufsArrs[i] != nil {
				device.PutMessageBuffer(bufsArrs[i])
			}
		}
	}()

	for {
		count, err = recv(bufs, sizes, endpoints)
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return
			}
			device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
			if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
				return
			}
			if deathSpiral < 10 {
				deathSpiral++
				time.Sleep(time.Second / 3)
				continue
			}
			return
		}
		deathSpiral = 0

		// handle each packet in the batch
		for i, size := range sizes[:count] {
			if size < MinMessageSize {
				continue
			}

			// check size of packet

			packet := bufsArrs[i][:size]
			msgType := binary.LittleEndian.Uint32(packet[:4])

			switch msgType {

			// check if transport

			case MessageTransportType:

				// check size

				if len(packet) < MessageTransportSize {
					continue
				}

				// lookup key pair

				receiver := binary.LittleEndian.Uint32(
					packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
				)
				value := device.indexTable.Lookup(receiver)
				keypair := value.keypair
				if keypair == nil {
					continue
				}

				// check keypair expiry

				if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
					continue
				}

				// create work element
				peer := value.peer
				elem := device.GetInboundElement()
				elem.packet = packet
				elem.buffer = bufsArrs[i]
				elem.keypair = keypair
				elem.endpoint = endpoints[i]
				elem.counter = 0
				elem.Mutex = sync.Mutex{}
				elem.Lock()

				elemsForPeer, ok := elemsByPeer[peer]
				if !ok {
					elemsForPeer = device.GetInboundElementsSlice()
					elemsByPeer[peer] = elemsForPeer
				}
				*elemsForPeer = append(*elemsForPeer, elem)
				bufsArrs[i] = device.GetMessageBuffer()
				bufs[i] = bufsArrs[i][:]
				continue

			// otherwise it is a fixed size & handshake related packet

			case MessageInitiationType:
				if len(packet) != MessageInitiationSize {
					continue
				}

			case MessageResponseType:
				if len(packet) != MessageResponseSize {
					continue
				}

			case MessageCookieReplyType:
				if len(packet) != MessageCookieReplySize {
					continue
				}

			default:
				device.log.Verbosef("Received message with unknown type")
				continue
			}

			select {
			case device.queue.handshake.c <- QueueHandshakeElement{
				msgType:  msgType,
				buffer:   bufsArrs[i],
				packet:   packet,
				endpoint: endpoints[i],
			}:
				bufsArrs[i] = device.GetMessageBuffer()
				bufs[i] = bufsArrs[i][:]
			default:
			}
		}
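		// Hand each peer's batch to its consumers: the same slice is sent to
		// the peer's inbound queue (preserving per-peer ordering) and to the
		// device-wide decryption queue (decrypted in parallel by the workers).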
		for peer, elems := range elemsByPeer {
			if peer.isRunning.Load() {
				peer.queue.inbound.c <- elems
				device.queue.decryption.c <- elems
			} else {
				for _, elem := range *elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutInboundElement(elem)
				}
				device.PutInboundElementsSlice(elems)
			}
			delete(elemsByPeer, peer)
		}
	}
}

func (device *Device) RoutineDecryption(id int) {
	var nonce [chacha20poly1305.NonceSize]byte

	defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
	device.log.Verbosef("Routine: decryption worker %d - started", id)

	for elems := range device.queue.decryption.c {
		for _, elem := range *elems {
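			// The element was locked by RoutineReceiveIncoming before it was
			// queued; decrypting in place and unlocking below lets the peer's
			// sequential receiver, which blocks on Lock(), consume it in order.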
			// split message into fields
			counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
			content := elem.packet[MessageTransportOffsetContent:]

			// decrypt and release to consumer
			var err error
			elem.counter = binary.LittleEndian.Uint64(counter)
			// copy counter to nonce
			binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
			elem.packet, err = elem.keypair.receive.Open(
				content[:0],
				nonce[:],
				content,
				nil,
			)
			if err != nil {
				elem.packet = nil
			}
			elem.Unlock()
		}
	}
}

/* Handles incoming packets related to handshake
 */
func (device *Device) RoutineHandshake(id int) {
	defer func() {
		device.log.Verbosef("Routine: handshake worker %d - stopped", id)
		device.queue.encryption.wg.Done()
	}()
	device.log.Verbosef("Routine: handshake worker %d - started", id)

	for elem := range device.queue.handshake.c {

		// handle cookie fields and ratelimiting

		switch elem.msgType {

		case MessageCookieReplyType:

			// unmarshal packet

			var reply MessageCookieReply
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &reply)
			if err != nil {
				device.log.Verbosef("Failed to decode cookie reply")
				goto skip
			}

			// lookup peer from index

			entry := device.indexTable.Lookup(reply.Receiver)

			if entry.peer == nil {
				goto skip
			}

			// consume reply

			if peer := entry.peer; peer.isRunning.Load() {
				device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
				if !peer.cookieGenerator.ConsumeReply(&reply) {
					device.log.Verbosef("Could not decrypt invalid cookie response")
				}
			}

			goto skip

		case MessageInitiationType, MessageResponseType:

			// check mac fields and maybe ratelimit

			if !device.cookieChecker.CheckMAC1(elem.packet) {
				device.log.Verbosef("Received packet with invalid mac1")
				goto skip
			}

			// endpoints destination address is the source of the datagram

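			// Under load, only proceed if the message carries a valid MAC2,
			// i.e. the sender has round-tripped a cookie issued for its source
			// address; otherwise answer with a cookie reply. Valid messages
			// are still subject to a per-source-IP rate limit.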
			if device.IsUnderLoad() {

				// verify MAC2 field

				if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
					device.SendHandshakeCookie(&elem)
					goto skip
				}

				// check ratelimiter

				if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
					goto skip
				}
			}

		default:
			device.log.Errorf("Invalid packet ended up in the handshake queue")
			goto skip
		}

		// handle handshake initiation/response content

		switch elem.msgType {
		case MessageInitiationType:

			// unmarshal

			var msg MessageInitiation
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode initiation message")
				goto skip
			}

			// consume initiation

			peer := device.ConsumeMessageInitiation(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake initiation", peer)
			peer.rxBytes.Add(uint64(len(elem.packet)))

			peer.SendHandshakeResponse()

		case MessageResponseType:

			// unmarshal

			var msg MessageResponse
			reader := bytes.NewReader(elem.packet)
			err := binary.Read(reader, binary.LittleEndian, &msg)
			if err != nil {
				device.log.Errorf("Failed to decode response message")
				goto skip
			}

			// consume response

			peer := device.ConsumeMessageResponse(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake response", peer)
			peer.rxBytes.Add(uint64(len(elem.packet)))

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// derive keypair

			err = peer.BeginSymmetricSession()

			if err != nil {
				device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
				goto skip
			}

			peer.timersSessionDerived()
			peer.timersHandshakeComplete()
			peer.SendKeepalive()
		}
	skip:
		device.PutMessageBuffer(elem.buffer)
	}
}

func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
	device := peer.device
	defer func() {
		device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
		peer.stopping.Done()
	}()
	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)

	bufs := make([][]byte, 0, maxBatchSize)

	for elems := range peer.queue.inbound.c {
		if elems == nil {
			return
		}
		for _, elem := range *elems {
			elem.Lock()
			if elem.packet == nil {
				// decryption failed
				continue
			}

			if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
				continue
			}

			peer.SetEndpointFromPacket(elem.endpoint)
			if peer.ReceivedWithKeypair(elem.keypair) {
				peer.timersHandshakeComplete()
				peer.SendStagedPackets()
			}
			peer.keepKeyFreshReceiving()
			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()
			peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))

			if len(elem.packet) == 0 {
				device.log.Verbosef("%v - Receiving keepalive packet", peer)
				continue
			}
			peer.timersDataReceived()

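			// Validate the decapsulated packet before handing it to the TUN
			// device: determine the IP version from the first nibble, truncate
			// to the length declared by the IP header, and enforce cryptokey
			// routing by requiring the source address to belong to this peer.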
			switch elem.packet[0] >> 4 {
			case 4:
				if len(elem.packet) < ipv4.HeaderLen {
					continue
				}
				field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
				length := binary.BigEndian.Uint16(field)
				if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
					continue
				}
				elem.packet = elem.packet[:length]
				src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
				if device.allowedips.Lookup(src) != peer {
					device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
					continue
				}

			case 6:
				if len(elem.packet) < ipv6.HeaderLen {
					continue
				}
				field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
				length := binary.BigEndian.Uint16(field)
				length += ipv6.HeaderLen
				if int(length) > len(elem.packet) {
					continue
				}
				elem.packet = elem.packet[:length]
				src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
				if device.allowedips.Lookup(src) != peer {
					device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
					continue
				}

			default:
				device.log.Verbosef("Packet with invalid IP version from %v", peer)
				continue
			}

			bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
		}
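		// Flush the batch to the TUN device in one call; each buffer still has
		// the transport header in front, so its offset is passed to Write.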
		if len(bufs) > 0 {
			_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
			if err != nil && !device.isClosed() {
				device.log.Errorf("Failed to write packets to TUN device: %v", err)
			}
		}
		for _, elem := range *elems {
			device.PutMessageBuffer(elem.buffer)
			device.PutInboundElement(elem)
		}
		bufs = bufs[:0]
		device.PutInboundElementsSlice(elems)
	}
}