mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 09:15:14 +01:00
More odds and ends
This commit is contained in:
parent
680a57faae
commit
729773fdf3
@ -319,6 +319,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
|
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
setZero(hash[:])
|
||||||
|
setZero(chainKey[:])
|
||||||
|
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -362,7 +365,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
handshake.mixKey(ss[:])
|
handshake.mixKey(ss[:])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// add preshared key (psk)
|
// add preshared key
|
||||||
|
|
||||||
var tau [blake2s.Size]byte
|
var tau [blake2s.Size]byte
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
@ -457,7 +460,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.log.Debug.Println("failed to open")
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
mixHash(&hash, &hash, msg.Empty[:])
|
mixHash(&hash, &hash, msg.Empty[:])
|
||||||
@ -485,10 +487,10 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
return lookup.peer
|
return lookup.peer
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Derives a new key-pair from the current handshake state
|
/* Derives a new keypair from the current handshake state
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) NewKeypair() *Keypair {
|
func (peer *Peer) DeriveNewKeypair() error {
|
||||||
device := peer.device
|
device := peer.device
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
@ -517,12 +519,13 @@ func (peer *Peer) NewKeypair() *Keypair {
|
|||||||
)
|
)
|
||||||
isInitiator = false
|
isInitiator = false
|
||||||
} else {
|
} else {
|
||||||
return nil
|
return errors.New("invalid state for keypair derivation")
|
||||||
}
|
}
|
||||||
|
|
||||||
// zero handshake
|
// zero handshake
|
||||||
|
|
||||||
setZero(handshake.chainKey[:])
|
setZero(handshake.chainKey[:])
|
||||||
|
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
||||||
setZero(handshake.localEphemeral[:])
|
setZero(handshake.localEphemeral[:])
|
||||||
peer.handshake.state = HandshakeZeroed
|
peer.handshake.state = HandshakeZeroed
|
||||||
|
|
||||||
@ -576,5 +579,23 @@ func (peer *Peer) NewKeypair() *Keypair {
|
|||||||
}
|
}
|
||||||
kp.mutex.Unlock()
|
kp.mutex.Unlock()
|
||||||
|
|
||||||
return keypair
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
||||||
|
kp := &peer.keypairs
|
||||||
|
if kp.next != receivedKeypair {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
kp.mutex.Lock()
|
||||||
|
defer kp.mutex.Unlock()
|
||||||
|
if kp.next != receivedKeypair {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
old := kp.previous
|
||||||
|
kp.previous = kp.current
|
||||||
|
peer.device.DeleteKeypair(old)
|
||||||
|
kp.current = kp.next
|
||||||
|
kp.next = nil
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
@ -102,15 +102,15 @@ func TestNoiseHandshake(t *testing.T) {
|
|||||||
|
|
||||||
t.Log("deriving keys")
|
t.Log("deriving keys")
|
||||||
|
|
||||||
key1 := peer1.NewKeypair()
|
key1 := peer1.DeriveNewKeypair()
|
||||||
key2 := peer2.NewKeypair()
|
key2 := peer2.DeriveNewKeypair()
|
||||||
|
|
||||||
if key1 == nil {
|
if key1 == nil {
|
||||||
t.Fatal("failed to dervice key-pair for peer 1")
|
t.Fatal("failed to dervice keypair for peer 1")
|
||||||
}
|
}
|
||||||
|
|
||||||
if key2 == nil {
|
if key2 == nil {
|
||||||
t.Fatal("failed to dervice key-pair for peer 2")
|
t.Fatal("failed to dervice keypair for peer 2")
|
||||||
}
|
}
|
||||||
|
|
||||||
// encrypting / decryption test
|
// encrypting / decryption test
|
||||||
|
25
receive.go
25
receive.go
@ -189,7 +189,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// check key-pair expiry
|
// check keypair expiry
|
||||||
|
|
||||||
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
||||||
continue
|
continue
|
||||||
@ -475,7 +475,7 @@ func (device *Device) RoutineHandshake() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.NewKeypair() == nil {
|
if peer.DeriveNewKeypair() != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -532,9 +532,9 @@ func (device *Device) RoutineHandshake() {
|
|||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
peer.timersAnyAuthenticatedPacketReceived()
|
peer.timersAnyAuthenticatedPacketReceived()
|
||||||
|
|
||||||
// derive key-pair
|
// derive keypair
|
||||||
|
|
||||||
if peer.NewKeypair() == nil {
|
if peer.DeriveNewKeypair() != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -597,27 +597,14 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
peer.endpoint = elem.endpoint
|
peer.endpoint = elem.endpoint
|
||||||
peer.mutex.Unlock()
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
// check if using new key-pair
|
// check if using new keypair
|
||||||
|
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||||
kp := &peer.keypairs
|
|
||||||
if kp.next == elem.keypair {
|
|
||||||
kp.mutex.Lock()
|
|
||||||
if kp.next != elem.keypair {
|
|
||||||
kp.mutex.Unlock()
|
|
||||||
} else {
|
|
||||||
old := kp.previous
|
|
||||||
kp.previous = kp.current
|
|
||||||
device.DeleteKeypair(old)
|
|
||||||
kp.current = kp.next
|
|
||||||
kp.next = nil
|
|
||||||
kp.mutex.Unlock()
|
|
||||||
peer.timersHandshakeComplete()
|
peer.timersHandshakeComplete()
|
||||||
select {
|
select {
|
||||||
case peer.signals.newKeypairArrived <- struct{}{}:
|
case peer.signals.newKeypairArrived <- struct{}{}:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
peer.keepKeyFreshReceiving()
|
peer.keepKeyFreshReceiving()
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
|
6
send.go
6
send.go
@ -47,7 +47,7 @@ type QueueOutboundElement struct {
|
|||||||
buffer *[MaxMessageSize]byte // slice holding the packet data
|
buffer *[MaxMessageSize]byte // slice holding the packet data
|
||||||
packet []byte // slice of "buffer" (always!)
|
packet []byte // slice of "buffer" (always!)
|
||||||
nonce uint64 // nonce for encryption
|
nonce uint64 // nonce for encryption
|
||||||
keypair *Keypair // key-pair for encryption
|
keypair *Keypair // keypair for encryption
|
||||||
peer *Peer // related peer
|
peer *Peer // related peer
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -306,11 +306,11 @@ func (peer *Peer) RoutineNonce() {
|
|||||||
|
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
|
|
||||||
logDebug.Println(peer, ": Awaiting key-pair")
|
logDebug.Println(peer, ": Awaiting keypair")
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-peer.signals.newKeypairArrived:
|
case <-peer.signals.newKeypairArrived:
|
||||||
logDebug.Println(peer, ": Obtained awaited key-pair")
|
logDebug.Println(peer, ": Obtained awaited keypair")
|
||||||
case <-peer.signals.flushNonceQueue:
|
case <-peer.signals.flushNonceQueue:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
Loading…
Reference in New Issue
Block a user