mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
device: use atomic access for unlocked keypair.next
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
1ecbb3313c
commit
682401a177
@ -8,7 +8,9 @@ package device
|
|||||||
import (
|
import (
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/replay"
|
"golang.zx2c4.com/wireguard/replay"
|
||||||
)
|
)
|
||||||
@ -35,7 +37,15 @@ type Keypairs struct {
|
|||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
current *Keypair
|
current *Keypair
|
||||||
previous *Keypair
|
previous *Keypair
|
||||||
next *Keypair
|
next unsafe.Pointer // *Keypair, access via LoadNext/StoreNext
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kp *Keypairs) StoreNext(next *Keypair) {
|
||||||
|
atomic.StorePointer(&kp.next, (unsafe.Pointer)(next))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kp *Keypairs) LoadNext() *Keypair {
|
||||||
|
return (*Keypair)(atomic.LoadPointer(&kp.next))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kp *Keypairs) Current() *Keypair {
|
func (kp *Keypairs) Current() *Keypair {
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/crypto/poly1305"
|
"golang.org/x/crypto/poly1305"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tai64n"
|
"golang.zx2c4.com/wireguard/tai64n"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
|
|
||||||
previous := keypairs.previous
|
previous := keypairs.previous
|
||||||
next := keypairs.next
|
next := keypairs.LoadNext()
|
||||||
current := keypairs.current
|
current := keypairs.current
|
||||||
|
|
||||||
if isInitiator {
|
if isInitiator {
|
||||||
if next != nil {
|
if next != nil {
|
||||||
keypairs.next = nil
|
keypairs.StoreNext(nil)
|
||||||
keypairs.previous = next
|
keypairs.previous = next
|
||||||
device.DeleteKeypair(current)
|
device.DeleteKeypair(current)
|
||||||
} else {
|
} else {
|
||||||
@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
keypairs.current = keypair
|
keypairs.current = keypair
|
||||||
} else {
|
} else {
|
||||||
keypairs.next = keypair
|
keypairs.StoreNext(keypair)
|
||||||
device.DeleteKeypair(next)
|
device.DeleteKeypair(next)
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
@ -608,15 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
|
|
||||||
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
||||||
keypairs := &peer.keypairs
|
keypairs := &peer.keypairs
|
||||||
|
|
||||||
|
if keypairs.LoadNext() != receivedKeypair {
|
||||||
|
return false
|
||||||
|
}
|
||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
if keypairs.next != receivedKeypair {
|
if keypairs.LoadNext() != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
old := keypairs.previous
|
old := keypairs.previous
|
||||||
keypairs.previous = keypairs.current
|
keypairs.previous = keypairs.current
|
||||||
peer.device.DeleteKeypair(old)
|
peer.device.DeleteKeypair(old)
|
||||||
keypairs.current = keypairs.next
|
keypairs.current = keypairs.LoadNext()
|
||||||
keypairs.next = nil
|
keypairs.StoreNext(nil)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
|
|||||||
t.Fatal("failed to derive keypair for peer 2", err)
|
t.Fatal("failed to derive keypair for peer 2", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key1 := peer1.keypairs.next
|
key1 := peer1.keypairs.LoadNext()
|
||||||
key2 := peer2.keypairs.current
|
key2 := peer2.keypairs.current
|
||||||
|
|
||||||
// encrypting / decryption test
|
// encrypting / decryption test
|
||||||
|
@ -226,10 +226,10 @@ func (peer *Peer) ZeroAndFlushAll() {
|
|||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
device.DeleteKeypair(keypairs.previous)
|
device.DeleteKeypair(keypairs.previous)
|
||||||
device.DeleteKeypair(keypairs.current)
|
device.DeleteKeypair(keypairs.current)
|
||||||
device.DeleteKeypair(keypairs.next)
|
device.DeleteKeypair(keypairs.LoadNext())
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
keypairs.current = nil
|
keypairs.current = nil
|
||||||
keypairs.next = nil
|
keypairs.StoreNext(nil)
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
|
|
||||||
// clear handshake state
|
// clear handshake state
|
||||||
@ -257,7 +257,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
|
|||||||
keypairs.current.sendNonce = RejectAfterMessages
|
keypairs.current.sendNonce = RejectAfterMessages
|
||||||
}
|
}
|
||||||
if keypairs.next != nil {
|
if keypairs.next != nil {
|
||||||
keypairs.next.sendNonce = RejectAfterMessages
|
keypairs.LoadNext().sendNonce = RejectAfterMessages
|
||||||
}
|
}
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user