1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

device: use atomic access for unlocked keypair.next

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2020-04-05 18:51:15 -06:00 committed by David Crawshaw
parent 1ecbb3313c
commit 682401a177
4 changed files with 26 additions and 11 deletions

View File

@ -8,7 +8,9 @@ package device
import ( import (
"crypto/cipher" "crypto/cipher"
"sync" "sync"
"sync/atomic"
"time" "time"
"unsafe"
"golang.zx2c4.com/wireguard/replay" "golang.zx2c4.com/wireguard/replay"
) )
@ -35,7 +37,15 @@ type Keypairs struct {
sync.RWMutex sync.RWMutex
current *Keypair current *Keypair
previous *Keypair previous *Keypair
next *Keypair next unsafe.Pointer // *Keypair, access via LoadNext/StoreNext
}
func (kp *Keypairs) StoreNext(next *Keypair) {
atomic.StorePointer(&kp.next, (unsafe.Pointer)(next))
}
func (kp *Keypairs) LoadNext() *Keypair {
return (*Keypair)(atomic.LoadPointer(&kp.next))
} }
func (kp *Keypairs) Current() *Keypair { func (kp *Keypairs) Current() *Keypair {

View File

@ -14,6 +14,7 @@ import (
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n" "golang.zx2c4.com/wireguard/tai64n"
) )
@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock() defer keypairs.Unlock()
previous := keypairs.previous previous := keypairs.previous
next := keypairs.next next := keypairs.LoadNext()
current := keypairs.current current := keypairs.current
if isInitiator { if isInitiator {
if next != nil { if next != nil {
keypairs.next = nil keypairs.StoreNext(nil)
keypairs.previous = next keypairs.previous = next
device.DeleteKeypair(current) device.DeleteKeypair(current)
} else { } else {
@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
keypairs.current = keypair keypairs.current = keypair
} else { } else {
keypairs.next = keypair keypairs.StoreNext(keypair)
device.DeleteKeypair(next) device.DeleteKeypair(next)
keypairs.previous = nil keypairs.previous = nil
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
@ -608,15 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs keypairs := &peer.keypairs
if keypairs.LoadNext() != receivedKeypair {
return false
}
keypairs.Lock() keypairs.Lock()
defer keypairs.Unlock() defer keypairs.Unlock()
if keypairs.next != receivedKeypair { if keypairs.LoadNext() != receivedKeypair {
return false return false
} }
old := keypairs.previous old := keypairs.previous
keypairs.previous = keypairs.current keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old) peer.device.DeleteKeypair(old)
keypairs.current = keypairs.next keypairs.current = keypairs.LoadNext()
keypairs.next = nil keypairs.StoreNext(nil)
return true return true
} }

View File

@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err) t.Fatal("failed to derive keypair for peer 2", err)
} }
key1 := peer1.keypairs.next key1 := peer1.keypairs.LoadNext()
key2 := peer2.keypairs.current key2 := peer2.keypairs.current
// encrypting / decryption test // encrypting / decryption test

View File

@ -226,10 +226,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock() keypairs.Lock()
device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current) device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next) device.DeleteKeypair(keypairs.LoadNext())
keypairs.previous = nil keypairs.previous = nil
keypairs.current = nil keypairs.current = nil
keypairs.next = nil keypairs.StoreNext(nil)
keypairs.Unlock() keypairs.Unlock()
// clear handshake state // clear handshake state
@ -257,7 +257,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs.current.sendNonce = RejectAfterMessages keypairs.current.sendNonce = RejectAfterMessages
} }
if keypairs.next != nil { if keypairs.next != nil {
keypairs.next.sendNonce = RejectAfterMessages keypairs.LoadNext().sendNonce = RejectAfterMessages
} }
keypairs.Unlock() keypairs.Unlock()
} }