diff --git a/src/keypair.go b/src/keypair.go index 644d040..7e5297b 100644 --- a/src/keypair.go +++ b/src/keypair.go @@ -2,38 +2,20 @@ package main import ( "crypto/cipher" - "golang.org/x/crypto/chacha20poly1305" - "reflect" "sync" "time" ) -type safeAEAD struct { - mutex sync.RWMutex - aead cipher.AEAD -} - -func (con *safeAEAD) clear() { - // TODO: improve handling of key material - con.mutex.Lock() - if con.aead != nil { - val := reflect.ValueOf(con.aead) - elm := val.Elem() - typ := elm.Type() - elm.Set(reflect.Zero(typ)) - con.aead = nil - } - con.mutex.Unlock() -} - -func (con *safeAEAD) setKey(key *[chacha20poly1305.KeySize]byte) { - // TODO: improve handling of key material - con.aead, _ = chacha20poly1305.New(key[:]) -} +/* Due to limitations in Go and /x/crypto there is currently + * no way to ensure that key material is securely ereased in memory. + * + * Since this may harm the forward secrecy property, + * we plan to resolve this issue; whenever Go allows us to do so. + */ type KeyPair struct { - send safeAEAD - receive safeAEAD + send cipher.AEAD + receive cipher.AEAD replayFilter ReplayFilter sendNonce uint64 isInitiator bool @@ -56,7 +38,5 @@ func (kp *KeyPairs) Current() *KeyPair { } func (device *Device) DeleteKeyPair(key *KeyPair) { - key.send.clear() - key.receive.clear() device.indices.Delete(key.localIndex) } diff --git a/src/noise_protocol.go b/src/noise_protocol.go index a50e3dc..9e5fdd8 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -502,8 +502,8 @@ func (peer *Peer) NewKeyPair() *KeyPair { // create AEAD instances keyPair := new(KeyPair) - keyPair.send.setKey(&sendKey) - keyPair.receive.setKey(&recvKey) + keyPair.send, _ = chacha20poly1305.New(sendKey[:]) + keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) setZero(sendKey[:]) setZero(recvKey[:]) @@ -530,30 +530,29 @@ func (peer *Peer) NewKeyPair() *KeyPair { // rotate key pairs kp := &peer.keyPairs - func() { - kp.mutex.Lock() - defer kp.mutex.Unlock() - // TODO: Adapt kernel behavior noise.c:161 - if isInitiator { - if kp.previous != nil { - device.DeleteKeyPair(kp.previous) - kp.previous = nil - } + kp.mutex.Lock() - if kp.next != nil { - kp.previous = kp.next - kp.next = keyPair - } else { - kp.previous = kp.current - kp.current = keyPair - signalSend(peer.signal.newKeyPair) // TODO: This more places (after confirming the key) - } - - } else { - kp.next = keyPair - kp.previous = nil // TODO: Discuss why + // TODO: Adapt kernel behavior noise.c:161 + if isInitiator { + if kp.previous != nil { + device.DeleteKeyPair(kp.previous) + kp.previous = nil } - }() + + if kp.next != nil { + kp.previous = kp.next + kp.next = keyPair + } else { + kp.previous = kp.current + kp.current = keyPair + signalSend(peer.signal.newKeyPair) // TODO: This more places (after confirming the key) + } + + } else { + kp.next = keyPair + kp.previous = nil + } + kp.mutex.Unlock() return keyPair } diff --git a/src/peer.go b/src/peer.go index a4feb2f..6fea829 100644 --- a/src/peer.go +++ b/src/peer.go @@ -39,6 +39,8 @@ type Peer struct { stop chan struct{} // (size 0) : close to stop all goroutines for peer } timer struct { + // state related to WireGuard timers + keepalivePersistent *time.Timer // set for persistent keepalives keepalivePassive *time.Timer // set upon recieving messages newHandshake *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout) @@ -49,7 +51,8 @@ type Peer struct { pendingNewHandshake bool pendingZeroAllKeys bool - needAnotherKeepalive bool + needAnotherKeepalive bool + sendLastMinuteHandshake bool } queue struct { nonce chan *QueueOutboundElement // nonce / pre-handshake queue diff --git a/src/receive.go b/src/receive.go index 09fca77..52c2718 100644 --- a/src/receive.go +++ b/src/receive.go @@ -247,28 +247,20 @@ func (device *Device) RoutineDecryption() { counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] content := elem.packet[MessageTransportOffsetContent:] - // decrypt with key-pair + // decrypt and release to consumer + var err error copy(nonce[4:], counter) elem.counter = binary.LittleEndian.Uint64(counter) - elem.keyPair.receive.mutex.RLock() - if elem.keyPair.receive.aead == nil { - // very unlikely (the key was deleted during queuing) + elem.packet, err = elem.keyPair.receive.Open( + elem.buffer[:0], + nonce[:], + content, + nil, + ) + if err != nil { elem.Drop() - } else { - var err error - elem.packet, err = elem.keyPair.receive.aead.Open( - elem.buffer[:0], - nonce[:], - content, - nil, - ) - if err != nil { - elem.Drop() - } } - - elem.keyPair.receive.mutex.RUnlock() elem.mutex.Unlock() } } @@ -433,8 +425,6 @@ func (device *Device) RoutineHandshake() { case MessageResponseType: - logDebug.Println("Process response") - // unmarshal var msg MessageResponse @@ -457,6 +447,8 @@ func (device *Device) RoutineHandshake() { continue } + logDebug.Println("Received handshake initation from", peer) + peer.TimerEphemeralKeyCreated() // update timers diff --git a/src/send.go b/src/send.go index e9dfb54..5c88ead 100644 --- a/src/send.go +++ b/src/send.go @@ -303,27 +303,16 @@ func (device *Device) RoutineEncryption() { } } - // encrypt content (append to header) + // encrypt content and release to consumer binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.keyPair.send.mutex.RLock() - if elem.keyPair.send.aead == nil { - // very unlikely (the key was deleted during queuing) - elem.Drop() - } else { - elem.packet = elem.keyPair.send.aead.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - } + elem.packet = elem.keyPair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) elem.mutex.Unlock() - elem.keyPair.send.mutex.RUnlock() - - // refresh key if necessary - - elem.peer.KeepKeyFreshSending() } } } diff --git a/src/tests/netns.sh b/src/tests/netns.sh index 9c10d36..043da3e 100755 --- a/src/tests/netns.sh +++ b/src/tests/netns.sh @@ -28,7 +28,7 @@ netns0="wg-test-$$-0" netns1="wg-test-$$-1" netns2="wg-test-$$-2" program="../wireguard-go" -export LOG_LEVEL="debug" +export LOG_LEVEL="error" pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } pp() { pretty "" "$*"; "$@"; } diff --git a/src/timers.go b/src/timers.go index ad8866f..99695ba 100644 --- a/src/timers.go +++ b/src/timers.go @@ -27,9 +27,12 @@ func (peer *Peer) KeepKeyFreshSending() { /* Called when a new authenticated message has been recevied * + * NOTE: Not thread safe (called by sequential receiver) */ func (peer *Peer) KeepKeyFreshReceiving() { - // TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete) + if peer.timer.sendLastMinuteHandshake { + return + } kp := peer.keyPairs.Current() if kp == nil { return @@ -40,7 +43,9 @@ func (peer *Peer) KeepKeyFreshReceiving() { nonce := atomic.LoadUint64(&kp.sendNonce) send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving if send { + // do a last minute attempt at initiating a new handshake signalSend(peer.signal.handshakeBegin) + peer.timer.sendLastMinuteHandshake = true } } @@ -311,6 +316,7 @@ func (peer *Peer) RoutineHandshakeInitiator() { case <-peer.signal.handshakeCompleted: <-timeout.C + peer.timer.sendLastMinuteHandshake = false break AttemptHandshakes case <-peer.signal.handshakeReset: