1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

device: change Peer.endpoint locking to reduce contention

Access to Peer.endpoint was previously synchronized by Peer.RWMutex.
This has now moved to Peer.endpoint.Mutex. Peer.SendBuffers() is now the
sole caller of Endpoint.ClearSrc(), which is signaled via a new bool,
Peer.endpoint.clearSrcOnTx. Previous Callers of Endpoint.ClearSrc() now
set this bool, primarily via peer.markEndpointSrcForClearing().
Peer.SetEndpointFromPacket() clears Peer.endpoint.clearSrcOnTx when an
updated conn.Endpoint is stored. This maintains the same event order as
before, i.e. a conn.Endpoint received after peer.endpoint.clearSrcOnTx
is set, but before the next Peer.SendBuffers() call results in the
latest conn.Endpoint source being used for the next packet transmission.

These changes result in throughput improvements for single flow,
parallel (-P n) flow, and bidirectional (--bidir) flow iperf3 TCP/UDP
tests as measured on both Linux and Windows. Latency under load improves
especially for high throughput Linux scenarios. These improvements are
likely realized on all platforms to some degree, as the changes are not
platform-specific.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jordan Whited 2023-11-20 16:49:06 -08:00 committed by Jason A. Donenfeld
parent d0bc03c707
commit 4ffa9c2032
6 changed files with 83 additions and 81 deletions

View File

@ -461,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error {
// clear cached source addresses // clear cached source addresses
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Lock() peer.markEndpointSrcForClearing()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
} }
device.peers.RUnlock() device.peers.RUnlock()
@ -515,11 +511,7 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses // clear cached source addresses
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Lock() peer.markEndpointSrcForClearing()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
} }
device.peers.RUnlock() device.peers.RUnlock()

View File

@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.net.brokenRoaming = true device.net.brokenRoaming = true
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Lock() peer.endpoint.Lock()
peer.disableRoaming = peer.endpoint != nil peer.endpoint.disableRoaming = peer.endpoint.val != nil
peer.Unlock() peer.endpoint.Unlock()
} }
device.peers.RUnlock() device.peers.RUnlock()
} }

View File

@ -17,17 +17,20 @@ import (
type Peer struct { type Peer struct {
isRunning atomic.Bool isRunning atomic.Bool
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
keypairs Keypairs keypairs Keypairs
handshake Handshake handshake Handshake
device *Device device *Device
endpoint conn.Endpoint
stopping sync.WaitGroup // routines pending stop stopping sync.WaitGroup // routines pending stop
txBytes atomic.Uint64 // bytes send to peer (endpoint) txBytes atomic.Uint64 // bytes send to peer (endpoint)
rxBytes atomic.Uint64 // bytes received from peer rxBytes atomic.Uint64 // bytes received from peer
lastHandshakeNano atomic.Int64 // nano seconds since epoch lastHandshakeNano atomic.Int64 // nano seconds since epoch
disableRoaming bool endpoint struct {
sync.Mutex
val conn.Endpoint
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
disableRoaming bool
}
timers struct { timers struct {
retransmitHandshake *Timer retransmitHandshake *Timer
@ -74,8 +77,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer // create peer
peer := new(Peer) peer := new(Peer)
peer.Lock()
defer peer.Unlock()
peer.cookieGenerator.Init(pk) peer.cookieGenerator.Init(pk)
peer.device = device peer.device = device
@ -97,7 +98,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake.mutex.Unlock() handshake.mutex.Unlock()
// reset endpoint // reset endpoint
peer.endpoint = nil peer.endpoint.Lock()
peer.endpoint.val = nil
peer.endpoint.disableRoaming = false
peer.endpoint.clearSrcOnTx = false
peer.endpoint.Unlock()
// init timers // init timers
peer.timersInit() peer.timersInit()
@ -116,14 +121,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
return nil return nil
} }
peer.RLock() peer.endpoint.Lock()
defer peer.RUnlock() endpoint := peer.endpoint.val
if endpoint == nil {
if peer.endpoint == nil { peer.endpoint.Unlock()
return errors.New("no known endpoint for peer") return errors.New("no known endpoint for peer")
} }
if peer.endpoint.clearSrcOnTx {
endpoint.ClearSrc()
peer.endpoint.clearSrcOnTx = false
}
peer.endpoint.Unlock()
err := peer.device.net.bind.Send(buffers, peer.endpoint) err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil { if err == nil {
var totalLen uint64 var totalLen uint64
for _, b := range buffers { for _, b := range buffers {
@ -267,10 +277,20 @@ func (peer *Peer) Stop() {
} }
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
if peer.disableRoaming { peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.disableRoaming {
return return
} }
peer.Lock() peer.endpoint.clearSrcOnTx = false
peer.endpoint = endpoint peer.endpoint.val = endpoint
peer.Unlock() }
func (peer *Peer) markEndpointSrcForClearing() {
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.val == nil {
return
}
peer.endpoint.clearSrcOnTx = true
} }

View File

@ -110,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
if !ok { if !ok {
break break
} }
pePtr.peer.Lock() pePtr.peer.endpoint.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint { if &pePtr.peer.endpoint.val != pePtr.endpoint {
pePtr.peer.Unlock() pePtr.peer.endpoint.Unlock()
break break
} }
if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.Unlock() pePtr.peer.endpoint.Unlock()
break break
} }
pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc() pePtr.peer.endpoint.clearSrcOnTx = true
pePtr.peer.Unlock() pePtr.peer.endpoint.Unlock()
} }
attr = attr[attrhdr.Len:] attr = attr[attrhdr.Len:]
} }
@ -134,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
device.peers.RLock() device.peers.RLock()
i := uint32(1) i := uint32(1)
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.RLock() peer.endpoint.Lock()
if peer.endpoint == nil { if peer.endpoint.val == nil {
peer.RUnlock() peer.endpoint.Unlock()
continue continue
} }
nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint) nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil { if nativeEP == nil {
peer.RUnlock() peer.endpoint.Unlock()
continue continue
} }
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
peer.RUnlock() peer.endpoint.Unlock()
break break
} }
nlmsg := struct { nlmsg := struct {
@ -188,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
reqPeerLock.Lock() reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{ reqPeer[i] = peerEndpointPtr{
peer: peer, peer: peer,
endpoint: &peer.endpoint, endpoint: &peer.endpoint.val,
} }
reqPeerLock.Unlock() reqPeerLock.Unlock()
peer.RUnlock() peer.endpoint.Unlock()
i++ i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil { if err != nil {

View File

@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.Lock() peer.markEndpointSrcForClearing()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.SendHandshakeInitiation(true) peer.SendHandshakeInitiation(true)
} }
@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) {
func expiredNewHandshake(peer *Peer) { func expiredNewHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.Lock() peer.markEndpointSrcForClearing()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }

View File

@ -99,33 +99,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
// Serialize peer state. // Serialize peer state.
// Do the work in an anonymous function so that we can use defer. peer.handshake.mutex.RLock()
func() { keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
peer.RLock() keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
defer peer.RUnlock() peer.handshake.mutex.RUnlock()
sendf("protocol_version=1")
peer.endpoint.Lock()
if peer.endpoint.val != nil {
sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
peer.endpoint.Unlock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) nano := peer.lastHandshakeNano.Load()
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) secs := nano / time.Second.Nanoseconds()
sendf("protocol_version=1") nano %= time.Second.Nanoseconds()
if peer.endpoint != nil {
sendf("endpoint=%s", peer.endpoint.DstToString())
}
nano := peer.lastHandshakeNano.Load() sendf("last_handshake_time_sec=%d", secs)
secs := nano / time.Second.Nanoseconds() sendf("last_handshake_time_nsec=%d", nano)
nano %= time.Second.Nanoseconds() sendf("tx_bytes=%d", peer.txBytes.Load())
sendf("rx_bytes=%d", peer.rxBytes.Load())
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
sendf("last_handshake_time_sec=%d", secs) device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("last_handshake_time_nsec=%d", nano) sendf("allowed_ip=%s", prefix.String())
sendf("tx_bytes=%d", peer.txBytes.Load()) return true
sendf("rx_bytes=%d", peer.rxBytes.Load()) })
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s", prefix.String())
return true
})
}()
} }
}() }()
@ -262,7 +260,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
return return
} }
if peer.created { if peer.created {
peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
} }
if peer.device.isUp() { if peer.device.isUp() {
peer.Start() peer.Start()
@ -345,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
} }
peer.Lock() peer.endpoint.Lock()
defer peer.Unlock() defer peer.endpoint.Unlock()
peer.endpoint = endpoint peer.endpoint.val = endpoint
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)