1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 09:15:14 +01:00

More odds and ends

This commit is contained in:
Jason A. Donenfeld 2018-05-13 19:50:58 +02:00
parent 680a57faae
commit 729773fdf3
4 changed files with 44 additions and 36 deletions

View File

@ -319,6 +319,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
handshake.mutex.Unlock() handshake.mutex.Unlock()
setZero(hash[:])
setZero(chainKey[:])
return peer return peer
} }
@ -362,7 +365,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixKey(ss[:]) handshake.mixKey(ss[:])
}() }()
// add preshared key (psk) // add preshared key
var tau [blake2s.Size]byte var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
@ -457,7 +460,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil { if err != nil {
device.log.Debug.Println("failed to open")
return false return false
} }
mixHash(&hash, &hash, msg.Empty[:]) mixHash(&hash, &hash, msg.Empty[:])
@ -485,10 +487,10 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return lookup.peer return lookup.peer
} }
/* Derives a new key-pair from the current handshake state /* Derives a new keypair from the current handshake state
* *
*/ */
func (peer *Peer) NewKeypair() *Keypair { func (peer *Peer) DeriveNewKeypair() error {
device := peer.device device := peer.device
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
@ -517,12 +519,13 @@ func (peer *Peer) NewKeypair() *Keypair {
) )
isInitiator = false isInitiator = false
} else { } else {
return nil return errors.New("invalid state for keypair derivation")
} }
// zero handshake // zero handshake
setZero(handshake.chainKey[:]) setZero(handshake.chainKey[:])
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:]) setZero(handshake.localEphemeral[:])
peer.handshake.state = HandshakeZeroed peer.handshake.state = HandshakeZeroed
@ -576,5 +579,23 @@ func (peer *Peer) NewKeypair() *Keypair {
} }
kp.mutex.Unlock() kp.mutex.Unlock()
return keypair return nil
}
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
kp := &peer.keypairs
if kp.next != receivedKeypair {
return false
}
kp.mutex.Lock()
defer kp.mutex.Unlock()
if kp.next != receivedKeypair {
return false
}
old := kp.previous
kp.previous = kp.current
peer.device.DeleteKeypair(old)
kp.current = kp.next
kp.next = nil
return true
} }

View File

@ -102,15 +102,15 @@ func TestNoiseHandshake(t *testing.T) {
t.Log("deriving keys") t.Log("deriving keys")
key1 := peer1.NewKeypair() key1 := peer1.DeriveNewKeypair()
key2 := peer2.NewKeypair() key2 := peer2.DeriveNewKeypair()
if key1 == nil { if key1 == nil {
t.Fatal("failed to dervice key-pair for peer 1") t.Fatal("failed to dervice keypair for peer 1")
} }
if key2 == nil { if key2 == nil {
t.Fatal("failed to dervice key-pair for peer 2") t.Fatal("failed to dervice keypair for peer 2")
} }
// encrypting / decryption test // encrypting / decryption test

View File

@ -189,7 +189,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
continue continue
} }
// check key-pair expiry // check keypair expiry
if keypair.created.Add(RejectAfterTime).Before(time.Now()) { if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue continue
@ -475,7 +475,7 @@ func (device *Device) RoutineHandshake() {
continue continue
} }
if peer.NewKeypair() == nil { if peer.DeriveNewKeypair() != nil {
continue continue
} }
@ -532,9 +532,9 @@ func (device *Device) RoutineHandshake() {
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived() peer.timersAnyAuthenticatedPacketReceived()
// derive key-pair // derive keypair
if peer.NewKeypair() == nil { if peer.DeriveNewKeypair() != nil {
continue continue
} }
@ -597,27 +597,14 @@ func (peer *Peer) RoutineSequentialReceiver() {
peer.endpoint = elem.endpoint peer.endpoint = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
// check if using new key-pair // check if using new keypair
if peer.ReceivedWithKeypair(elem.keypair) {
kp := &peer.keypairs
if kp.next == elem.keypair {
kp.mutex.Lock()
if kp.next != elem.keypair {
kp.mutex.Unlock()
} else {
old := kp.previous
kp.previous = kp.current
device.DeleteKeypair(old)
kp.current = kp.next
kp.next = nil
kp.mutex.Unlock()
peer.timersHandshakeComplete() peer.timersHandshakeComplete()
select { select {
case peer.signals.newKeypairArrived <- struct{}{}: case peer.signals.newKeypairArrived <- struct{}{}:
default: default:
} }
} }
}
peer.keepKeyFreshReceiving() peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()

View File

@ -47,7 +47,7 @@ type QueueOutboundElement struct {
buffer *[MaxMessageSize]byte // slice holding the packet data buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!) packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption nonce uint64 // nonce for encryption
keypair *Keypair // key-pair for encryption keypair *Keypair // keypair for encryption
peer *Peer // related peer peer *Peer // related peer
} }
@ -306,11 +306,11 @@ func (peer *Peer) RoutineNonce() {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
logDebug.Println(peer, ": Awaiting key-pair") logDebug.Println(peer, ": Awaiting keypair")
select { select {
case <-peer.signals.newKeypairArrived: case <-peer.signals.newKeypairArrived:
logDebug.Println(peer, ": Obtained awaited key-pair") logDebug.Println(peer, ": Obtained awaited keypair")
case <-peer.signals.flushNonceQueue: case <-peer.signals.flushNonceQueue:
for { for {
select { select {