diff --git a/src/keypair.go b/src/keypair.go index b24dbe4..b5f46df 100644 --- a/src/keypair.go +++ b/src/keypair.go @@ -7,13 +7,14 @@ import ( ) type KeyPair struct { - receive cipher.AEAD - send cipher.AEAD - sendNonce uint64 - isInitiator bool - created time.Time - localIndex uint32 - remoteIndex uint32 + receive cipher.AEAD + replayFilter ReplayFilter + send cipher.AEAD + sendNonce uint64 + isInitiator bool + created time.Time + localIndex uint32 + remoteIndex uint32 } type KeyPairs struct { diff --git a/src/misc.go b/src/misc.go index 75561b2..fc75c0d 100644 --- a/src/misc.go +++ b/src/misc.go @@ -19,6 +19,13 @@ func min(a uint, b uint) uint { return a } +func minUint64(a uint64, b uint64) uint64 { + if a > b { + return b + } + return a +} + func signalSend(c chan struct{}) { select { case c <- struct{}{}: diff --git a/src/noise_protocol.go b/src/noise_protocol.go index a90fe4c..bfa3797 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -415,6 +415,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { return lookup.peer } +/* Derives a new key-pair from the current handshake state + * + */ func (peer *Peer) NewKeyPair() *KeyPair { handshake := &peer.handshake handshake.mutex.Lock() @@ -445,10 +448,11 @@ func (peer *Peer) NewKeyPair() *KeyPair { // create AEAD instances keyPair := new(KeyPair) + keyPair.created = time.Now() keyPair.send, _ = chacha20poly1305.New(sendKey[:]) keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) keyPair.sendNonce = 0 - keyPair.created = time.Now() + keyPair.replayFilter.Init() keyPair.isInitiator = isInitiator keyPair.localIndex = peer.handshake.localIndex keyPair.remoteIndex = peer.handshake.remoteIndex @@ -462,8 +466,6 @@ func (peer *Peer) NewKeyPair() *KeyPair { }) handshake.localIndex = 0 - // TODO: start timer for keypair (clearing) - // rotate key pairs kp := &peer.keyPairs diff --git a/src/receive.go b/src/receive.go index e780c66..6530c47 100644 --- a/src/receive.go +++ b/src/receive.go @@ -432,6 +432,10 @@ func (peer *Peer) RoutineSequentialReceiver() { // check for replay + if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { + return + } + // time (passive) keep-alive peer.TimerStartKeepalive() diff --git a/src/replay.go b/src/replay.go new file mode 100644 index 0000000..49c7e08 --- /dev/null +++ b/src/replay.go @@ -0,0 +1,71 @@ +package main + +/* Implementation of RFC6479 + * https://tools.ietf.org/html/rfc6479 + * + * The implementation is not safe for concurrent use! + */ + +const ( + // See: https://golang.org/src/math/big/arith.go + _Wordm = ^uintptr(0) + _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1 + _WordSize = 1 << _WordLogSize +) + +const ( + CounterRedundantBitsLog = _WordLogSize + 3 + CounterRedundantBits = _WordSize * 8 + CounterBitsTotal = 2048 + CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits) +) + +const ( + BacktrackWords = CounterBitsTotal / _WordSize +) + +type ReplayFilter struct { + counter uint64 + backtrack [BacktrackWords]uintptr +} + +func (filter *ReplayFilter) Init() { + filter.counter = 0 + filter.backtrack[0] = 0 +} + +func (filter *ReplayFilter) ValidateCounter(counter uint64) bool { + if counter >= RejectAfterMessages { + return false + } + + indexWord := counter >> CounterRedundantBitsLog + + if counter > filter.counter { + + // move window forward + + current := filter.counter >> CounterRedundantBitsLog + diff := minUint64(indexWord-current, BacktrackWords) + for i := uint64(1); i <= diff; i++ { + filter.backtrack[(current+i)%BacktrackWords] = 0 + } + filter.counter = counter + + } else if filter.counter-counter > CounterWindowSize { + + // behind current window + + return false + } + + indexWord %= BacktrackWords + indexBit := counter & uint64(CounterRedundantBits-1) + + // check and set bit + + oldValue := filter.backtrack[indexWord] + newValue := oldValue | (1 << indexBit) + filter.backtrack[indexWord] = newValue + return oldValue != newValue +} diff --git a/src/replay_test.go b/src/replay_test.go new file mode 100644 index 0000000..e75c5c1 --- /dev/null +++ b/src/replay_test.go @@ -0,0 +1,114 @@ +package main + +import ( + "testing" +) + +/* Ported from the linux kernel implementation + * + * + */ + +/* Copyright (C) 2015-2017 Jason A. Donenfeld . All Rights Reserved. */ + +func TestReplay(t *testing.T) { + var filter ReplayFilter + + T_LIM := CounterWindowSize + 1 + + testNumber := 0 + T := func(n uint64, v bool) { + testNumber++ + if filter.ValidateCounter(n) != v { + t.Fatal("Test", testNumber, "failed", n, v) + } + } + + filter.Init() + + /* 1 */ T(0, true) + /* 2 */ T(1, true) + /* 3 */ T(1, false) + /* 4 */ T(9, true) + /* 5 */ T(8, true) + /* 6 */ T(7, true) + /* 7 */ T(7, false) + /* 8 */ T(T_LIM, true) + /* 9 */ T(T_LIM-1, true) + /* 10 */ T(T_LIM-1, false) + /* 11 */ T(T_LIM-2, true) + /* 12 */ T(2, true) + /* 13 */ T(2, false) + /* 14 */ T(T_LIM+16, true) + /* 15 */ T(3, false) + /* 16 */ T(T_LIM+16, false) + /* 17 */ T(T_LIM*4, true) + /* 18 */ T(T_LIM*4-(T_LIM-1), true) + /* 19 */ T(10, false) + /* 20 */ T(T_LIM*4-T_LIM, false) + /* 21 */ T(T_LIM*4-(T_LIM+1), false) + /* 22 */ T(T_LIM*4-(T_LIM-2), true) + /* 23 */ T(T_LIM*4+1-T_LIM, false) + /* 24 */ T(0, false) + /* 25 */ T(RejectAfterMessages, false) + /* 26 */ T(RejectAfterMessages-1, true) + /* 27 */ T(RejectAfterMessages, false) + /* 28 */ T(RejectAfterMessages-1, false) + /* 29 */ T(RejectAfterMessages-2, true) + /* 30 */ T(RejectAfterMessages+1, false) + /* 31 */ T(RejectAfterMessages+2, false) + /* 32 */ T(RejectAfterMessages-2, false) + /* 33 */ T(RejectAfterMessages-3, true) + /* 34 */ T(0, false) + + t.Log("Bulk test 1") + filter.Init() + testNumber = 0 + for i := uint64(1); i <= CounterWindowSize; i++ { + T(i, true) + } + T(0, true) + T(0, false) + + t.Log("Bulk test 2") + filter.Init() + testNumber = 0 + for i := uint64(2); i <= CounterWindowSize+1; i++ { + T(i, true) + } + T(1, true) + T(0, false) + + t.Log("Bulk test 3") + filter.Init() + testNumber = 0 + for i := CounterWindowSize + 1; i > 0; i-- { + T(i, true) + } + + t.Log("Bulk test 4") + filter.Init() + testNumber = 0 + for i := CounterWindowSize + 2; i > 1; i-- { + T(i, true) + } + T(0, false) + + t.Log("Bulk test 5") + filter.Init() + testNumber = 0 + for i := CounterWindowSize; i > 0; i-- { + T(i, true) + } + T(CounterWindowSize+1, true) + T(0, false) + + t.Log("Bulk test 6") + filter.Init() + testNumber = 0 + for i := CounterWindowSize; i > 0; i-- { + T(i, true) + } + T(0, true) + T(CounterWindowSize+1, true) +} diff --git a/src/timers.go b/src/timers.go index 26926c2..70e0766 100644 --- a/src/timers.go +++ b/src/timers.go @@ -12,22 +12,15 @@ import ( * */ func (peer *Peer) KeepKeyFreshSending() { - send := func() bool { - peer.keyPairs.mutex.RLock() - defer peer.keyPairs.mutex.RUnlock() - - kp := peer.keyPairs.current - if kp == nil { - return false - } - - if !kp.isInitiator { - return false - } - - nonce := atomic.LoadUint64(&kp.sendNonce) - return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime - }() + kp := peer.keyPairs.Current() + if kp == nil { + return + } + if !kp.isInitiator { + return + } + nonce := atomic.LoadUint64(&kp.sendNonce) + send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime if send { signalSend(peer.signal.handshakeBegin) } @@ -37,22 +30,15 @@ func (peer *Peer) KeepKeyFreshSending() { * */ func (peer *Peer) KeepKeyFreshReceiving() { - send := func() bool { - peer.keyPairs.mutex.RLock() - defer peer.keyPairs.mutex.RUnlock() - - kp := peer.keyPairs.current - if kp == nil { - return false - } - - if !kp.isInitiator { - return false - } - - nonce := atomic.LoadUint64(&kp.sendNonce) - return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving - }() + kp := peer.keyPairs.Current() + if kp == nil { + return + } + if !kp.isInitiator { + return + } + nonce := atomic.LoadUint64(&kp.sendNonce) + send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving if send { signalSend(peer.signal.handshakeBegin) }