diff --git a/device/keypair.go b/device/keypair.go index 9c78fa9..63fe506 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -8,7 +8,9 @@ package device import ( "crypto/cipher" "sync" + "sync/atomic" "time" + "unsafe" "golang.zx2c4.com/wireguard/replay" ) @@ -35,7 +37,15 @@ type Keypairs struct { sync.RWMutex current *Keypair previous *Keypair - next *Keypair + next unsafe.Pointer // *Keypair, access via LoadNext/StoreNext +} + +func (kp *Keypairs) StoreNext(next *Keypair) { + atomic.StorePointer(&kp.next, (unsafe.Pointer)(next)) +} + +func (kp *Keypairs) LoadNext() *Keypair { + return (*Keypair)(atomic.LoadPointer(&kp.next)) } func (kp *Keypairs) Current() *Keypair { diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 03b872b..c852ac6 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -14,6 +14,7 @@ import ( "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" + "golang.zx2c4.com/wireguard/tai64n" ) @@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error { defer keypairs.Unlock() previous := keypairs.previous - next := keypairs.next + next := keypairs.LoadNext() current := keypairs.current if isInitiator { if next != nil { - keypairs.next = nil + keypairs.StoreNext(nil) keypairs.previous = next device.DeleteKeypair(current) } else { @@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error { device.DeleteKeypair(previous) keypairs.current = keypair } else { - keypairs.next = keypair + keypairs.StoreNext(keypair) device.DeleteKeypair(next) keypairs.previous = nil device.DeleteKeypair(previous) @@ -608,15 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { keypairs := &peer.keypairs + + if keypairs.LoadNext() != receivedKeypair { + return false + } keypairs.Lock() defer keypairs.Unlock() - if keypairs.next != receivedKeypair { + if keypairs.LoadNext() != receivedKeypair { return false } old := keypairs.previous keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) - keypairs.current = keypairs.next - keypairs.next = nil + keypairs.current = keypairs.LoadNext() + keypairs.StoreNext(nil) return true } diff --git a/device/noise_test.go b/device/noise_test.go index 6ba3f2e..6ee5f7b 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) { t.Fatal("failed to derive keypair for peer 2", err) } - key1 := peer1.keypairs.next + key1 := peer1.keypairs.LoadNext() key2 := peer2.keypairs.current // encrypting / decryption test diff --git a/device/peer.go b/device/peer.go index cb348d5..94182e7 100644 --- a/device/peer.go +++ b/device/peer.go @@ -226,10 +226,10 @@ func (peer *Peer) ZeroAndFlushAll() { keypairs.Lock() device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.current) - device.DeleteKeypair(keypairs.next) + device.DeleteKeypair(keypairs.LoadNext()) keypairs.previous = nil keypairs.current = nil - keypairs.next = nil + keypairs.StoreNext(nil) keypairs.Unlock() // clear handshake state @@ -257,7 +257,7 @@ func (peer *Peer) ExpireCurrentKeypairs() { keypairs.current.sendNonce = RejectAfterMessages } if keypairs.next != nil { - keypairs.next.sendNonce = RejectAfterMessages + keypairs.LoadNext().sendNonce = RejectAfterMessages } keypairs.Unlock() }