mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 09:15:14 +01:00
device: fix persistent_keepalive_interval data races
Co-authored-by: David Anderson <danderson@tailscale.com> Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
parent
71ef21087e
commit
e739ff71a5
@ -163,7 +163,7 @@ func deviceUpdateState(device *Device) {
|
|||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.Start()
|
peer.Start()
|
||||||
if peer.persistentKeepaliveInterval > 0 {
|
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -215,7 +215,20 @@ func TestConcurrencySafety(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
warmup.Wait()
|
warmup.Wait()
|
||||||
|
|
||||||
// coming soon: more things here...
|
// Change persistent_keepalive_interval concurrently with tunnel use.
|
||||||
|
t.Run("persistentKeepaliveInterval", func(t *testing.T) {
|
||||||
|
cfg := uapiCfg(
|
||||||
|
"public_key", "f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725",
|
||||||
|
"persistent_keepalive_interval", "1",
|
||||||
|
)
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
cfg.Seek(0, io.SeekStart)
|
||||||
|
err := pair[0].dev.IpcSetOperation(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
close(done)
|
close(done)
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ type Peer struct {
|
|||||||
handshake Handshake
|
handshake Handshake
|
||||||
device *Device
|
device *Device
|
||||||
endpoint conn.Endpoint
|
endpoint conn.Endpoint
|
||||||
persistentKeepaliveInterval uint16
|
persistentKeepaliveInterval uint32 // accessed atomically
|
||||||
disableRoaming bool
|
disableRoaming bool
|
||||||
|
|
||||||
// These fields are accessed with atomic operations, which must be
|
// These fields are accessed with atomic operations, which must be
|
||||||
|
@ -138,7 +138,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func expiredPersistentKeepalive(peer *Peer) {
|
func expiredPersistentKeepalive(peer *Peer) {
|
||||||
if peer.persistentKeepaliveInterval > 0 {
|
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -201,8 +201,9 @@ func (peer *Peer) timersSessionDerived() {
|
|||||||
|
|
||||||
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
|
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
|
||||||
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
|
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
|
||||||
if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
|
keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
|
||||||
peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
|
if keepalive > 0 && peer.timersActive() {
|
||||||
|
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
|
|||||||
send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
|
send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
|
||||||
send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
|
send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
|
||||||
send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
|
send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
|
||||||
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
|
send(fmt.Sprintf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)))
|
||||||
|
|
||||||
for _, ip := range device.allowedips.EntriesForPeer(peer) {
|
for _, ip := range device.allowedips.EntriesForPeer(peer) {
|
||||||
send("allowed_ip=" + ip.String())
|
send("allowed_ip=" + ip.String())
|
||||||
@ -333,8 +333,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
|||||||
return &IPCError{ipc.IpcErrorInvalid}
|
return &IPCError{ipc.IpcErrorInvalid}
|
||||||
}
|
}
|
||||||
|
|
||||||
old := peer.persistentKeepaliveInterval
|
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
|
||||||
peer.persistentKeepaliveInterval = uint16(secs)
|
|
||||||
|
|
||||||
// send immediate keepalive if we're turning it on and before it wasn't on
|
// send immediate keepalive if we're turning it on and before it wasn't on
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user