diff --git a/device/device.go b/device/device.go index ab5e4b0..569c5a8 100644 --- a/device/device.go +++ b/device/device.go @@ -201,7 +201,6 @@ func (device *Device) IsUnderLoad() bool { } func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { - // lock required resources device.staticIdentity.Lock() @@ -214,9 +213,10 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { device.peers.Lock() defer device.peers.Unlock() + lockedPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { peer.handshake.mutex.RLock() - defer peer.handshake.mutex.RUnlock() + lockedPeers = append(lockedPeers, peer) } // remove peers with matching public keys @@ -238,8 +238,8 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { rmKey := device.staticIdentity.privateKey.IsZero() + expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) for key, peer := range device.peers.keyMap { - handshake := &peer.handshake if rmKey { @@ -251,10 +251,17 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { if isZero(handshake.precomputedStaticStatic[:]) { unsafeRemovePeer(device, peer, key) } else { - peer.ExpireCurrentKeypairs() + expiredPeers = append(expiredPeers, peer) } } + for _, peer := range lockedPeers { + peer.handshake.mutex.RUnlock() + } + for _, peer := range expiredPeers { + peer.ExpireCurrentKeypairs() + } + return nil }