mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
Added replay protection
This commit is contained in:
parent
4ad62aaa6a
commit
44c9896883
@ -7,13 +7,14 @@ import (
|
||||
)
|
||||
|
||||
type KeyPair struct {
|
||||
receive cipher.AEAD
|
||||
send cipher.AEAD
|
||||
sendNonce uint64
|
||||
isInitiator bool
|
||||
created time.Time
|
||||
localIndex uint32
|
||||
remoteIndex uint32
|
||||
receive cipher.AEAD
|
||||
replayFilter ReplayFilter
|
||||
send cipher.AEAD
|
||||
sendNonce uint64
|
||||
isInitiator bool
|
||||
created time.Time
|
||||
localIndex uint32
|
||||
remoteIndex uint32
|
||||
}
|
||||
|
||||
type KeyPairs struct {
|
||||
|
@ -19,6 +19,13 @@ func min(a uint, b uint) uint {
|
||||
return a
|
||||
}
|
||||
|
||||
func minUint64(a uint64, b uint64) uint64 {
|
||||
if a > b {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func signalSend(c chan struct{}) {
|
||||
select {
|
||||
case c <- struct{}{}:
|
||||
|
@ -415,6 +415,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
return lookup.peer
|
||||
}
|
||||
|
||||
/* Derives a new key-pair from the current handshake state
|
||||
*
|
||||
*/
|
||||
func (peer *Peer) NewKeyPair() *KeyPair {
|
||||
handshake := &peer.handshake
|
||||
handshake.mutex.Lock()
|
||||
@ -445,10 +448,11 @@ func (peer *Peer) NewKeyPair() *KeyPair {
|
||||
// create AEAD instances
|
||||
|
||||
keyPair := new(KeyPair)
|
||||
keyPair.created = time.Now()
|
||||
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
|
||||
keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
|
||||
keyPair.sendNonce = 0
|
||||
keyPair.created = time.Now()
|
||||
keyPair.replayFilter.Init()
|
||||
keyPair.isInitiator = isInitiator
|
||||
keyPair.localIndex = peer.handshake.localIndex
|
||||
keyPair.remoteIndex = peer.handshake.remoteIndex
|
||||
@ -462,8 +466,6 @@ func (peer *Peer) NewKeyPair() *KeyPair {
|
||||
})
|
||||
handshake.localIndex = 0
|
||||
|
||||
// TODO: start timer for keypair (clearing)
|
||||
|
||||
// rotate key pairs
|
||||
|
||||
kp := &peer.keyPairs
|
||||
|
@ -432,6 +432,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
|
||||
// check for replay
|
||||
|
||||
if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
|
||||
return
|
||||
}
|
||||
|
||||
// time (passive) keep-alive
|
||||
|
||||
peer.TimerStartKeepalive()
|
||||
|
71
src/replay.go
Normal file
71
src/replay.go
Normal file
@ -0,0 +1,71 @@
|
||||
package main
|
||||
|
||||
/* Implementation of RFC6479
|
||||
* https://tools.ietf.org/html/rfc6479
|
||||
*
|
||||
* The implementation is not safe for concurrent use!
|
||||
*/
|
||||
|
||||
const (
|
||||
// See: https://golang.org/src/math/big/arith.go
|
||||
_Wordm = ^uintptr(0)
|
||||
_WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
|
||||
_WordSize = 1 << _WordLogSize
|
||||
)
|
||||
|
||||
const (
|
||||
CounterRedundantBitsLog = _WordLogSize + 3
|
||||
CounterRedundantBits = _WordSize * 8
|
||||
CounterBitsTotal = 2048
|
||||
CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
|
||||
)
|
||||
|
||||
const (
|
||||
BacktrackWords = CounterBitsTotal / _WordSize
|
||||
)
|
||||
|
||||
type ReplayFilter struct {
|
||||
counter uint64
|
||||
backtrack [BacktrackWords]uintptr
|
||||
}
|
||||
|
||||
func (filter *ReplayFilter) Init() {
|
||||
filter.counter = 0
|
||||
filter.backtrack[0] = 0
|
||||
}
|
||||
|
||||
func (filter *ReplayFilter) ValidateCounter(counter uint64) bool {
|
||||
if counter >= RejectAfterMessages {
|
||||
return false
|
||||
}
|
||||
|
||||
indexWord := counter >> CounterRedundantBitsLog
|
||||
|
||||
if counter > filter.counter {
|
||||
|
||||
// move window forward
|
||||
|
||||
current := filter.counter >> CounterRedundantBitsLog
|
||||
diff := minUint64(indexWord-current, BacktrackWords)
|
||||
for i := uint64(1); i <= diff; i++ {
|
||||
filter.backtrack[(current+i)%BacktrackWords] = 0
|
||||
}
|
||||
filter.counter = counter
|
||||
|
||||
} else if filter.counter-counter > CounterWindowSize {
|
||||
|
||||
// behind current window
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
indexWord %= BacktrackWords
|
||||
indexBit := counter & uint64(CounterRedundantBits-1)
|
||||
|
||||
// check and set bit
|
||||
|
||||
oldValue := filter.backtrack[indexWord]
|
||||
newValue := oldValue | (1 << indexBit)
|
||||
filter.backtrack[indexWord] = newValue
|
||||
return oldValue != newValue
|
||||
}
|
114
src/replay_test.go
Normal file
114
src/replay_test.go
Normal file
@ -0,0 +1,114 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
/* Ported from the linux kernel implementation
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
|
||||
|
||||
func TestReplay(t *testing.T) {
|
||||
var filter ReplayFilter
|
||||
|
||||
T_LIM := CounterWindowSize + 1
|
||||
|
||||
testNumber := 0
|
||||
T := func(n uint64, v bool) {
|
||||
testNumber++
|
||||
if filter.ValidateCounter(n) != v {
|
||||
t.Fatal("Test", testNumber, "failed", n, v)
|
||||
}
|
||||
}
|
||||
|
||||
filter.Init()
|
||||
|
||||
/* 1 */ T(0, true)
|
||||
/* 2 */ T(1, true)
|
||||
/* 3 */ T(1, false)
|
||||
/* 4 */ T(9, true)
|
||||
/* 5 */ T(8, true)
|
||||
/* 6 */ T(7, true)
|
||||
/* 7 */ T(7, false)
|
||||
/* 8 */ T(T_LIM, true)
|
||||
/* 9 */ T(T_LIM-1, true)
|
||||
/* 10 */ T(T_LIM-1, false)
|
||||
/* 11 */ T(T_LIM-2, true)
|
||||
/* 12 */ T(2, true)
|
||||
/* 13 */ T(2, false)
|
||||
/* 14 */ T(T_LIM+16, true)
|
||||
/* 15 */ T(3, false)
|
||||
/* 16 */ T(T_LIM+16, false)
|
||||
/* 17 */ T(T_LIM*4, true)
|
||||
/* 18 */ T(T_LIM*4-(T_LIM-1), true)
|
||||
/* 19 */ T(10, false)
|
||||
/* 20 */ T(T_LIM*4-T_LIM, false)
|
||||
/* 21 */ T(T_LIM*4-(T_LIM+1), false)
|
||||
/* 22 */ T(T_LIM*4-(T_LIM-2), true)
|
||||
/* 23 */ T(T_LIM*4+1-T_LIM, false)
|
||||
/* 24 */ T(0, false)
|
||||
/* 25 */ T(RejectAfterMessages, false)
|
||||
/* 26 */ T(RejectAfterMessages-1, true)
|
||||
/* 27 */ T(RejectAfterMessages, false)
|
||||
/* 28 */ T(RejectAfterMessages-1, false)
|
||||
/* 29 */ T(RejectAfterMessages-2, true)
|
||||
/* 30 */ T(RejectAfterMessages+1, false)
|
||||
/* 31 */ T(RejectAfterMessages+2, false)
|
||||
/* 32 */ T(RejectAfterMessages-2, false)
|
||||
/* 33 */ T(RejectAfterMessages-3, true)
|
||||
/* 34 */ T(0, false)
|
||||
|
||||
t.Log("Bulk test 1")
|
||||
filter.Init()
|
||||
testNumber = 0
|
||||
for i := uint64(1); i <= CounterWindowSize; i++ {
|
||||
T(i, true)
|
||||
}
|
||||
T(0, true)
|
||||
T(0, false)
|
||||
|
||||
t.Log("Bulk test 2")
|
||||
filter.Init()
|
||||
testNumber = 0
|
||||
for i := uint64(2); i <= CounterWindowSize+1; i++ {
|
||||
T(i, true)
|
||||
}
|
||||
T(1, true)
|
||||
T(0, false)
|
||||
|
||||
t.Log("Bulk test 3")
|
||||
filter.Init()
|
||||
testNumber = 0
|
||||
for i := CounterWindowSize + 1; i > 0; i-- {
|
||||
T(i, true)
|
||||
}
|
||||
|
||||
t.Log("Bulk test 4")
|
||||
filter.Init()
|
||||
testNumber = 0
|
||||
for i := CounterWindowSize + 2; i > 1; i-- {
|
||||
T(i, true)
|
||||
}
|
||||
T(0, false)
|
||||
|
||||
t.Log("Bulk test 5")
|
||||
filter.Init()
|
||||
testNumber = 0
|
||||
for i := CounterWindowSize; i > 0; i-- {
|
||||
T(i, true)
|
||||
}
|
||||
T(CounterWindowSize+1, true)
|
||||
T(0, false)
|
||||
|
||||
t.Log("Bulk test 6")
|
||||
filter.Init()
|
||||
testNumber = 0
|
||||
for i := CounterWindowSize; i > 0; i-- {
|
||||
T(i, true)
|
||||
}
|
||||
T(0, true)
|
||||
T(CounterWindowSize+1, true)
|
||||
}
|
@ -12,22 +12,15 @@ import (
|
||||
*
|
||||
*/
|
||||
func (peer *Peer) KeepKeyFreshSending() {
|
||||
send := func() bool {
|
||||
peer.keyPairs.mutex.RLock()
|
||||
defer peer.keyPairs.mutex.RUnlock()
|
||||
|
||||
kp := peer.keyPairs.current
|
||||
if kp == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !kp.isInitiator {
|
||||
return false
|
||||
}
|
||||
|
||||
nonce := atomic.LoadUint64(&kp.sendNonce)
|
||||
return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
|
||||
}()
|
||||
kp := peer.keyPairs.Current()
|
||||
if kp == nil {
|
||||
return
|
||||
}
|
||||
if !kp.isInitiator {
|
||||
return
|
||||
}
|
||||
nonce := atomic.LoadUint64(&kp.sendNonce)
|
||||
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
|
||||
if send {
|
||||
signalSend(peer.signal.handshakeBegin)
|
||||
}
|
||||
@ -37,22 +30,15 @@ func (peer *Peer) KeepKeyFreshSending() {
|
||||
*
|
||||
*/
|
||||
func (peer *Peer) KeepKeyFreshReceiving() {
|
||||
send := func() bool {
|
||||
peer.keyPairs.mutex.RLock()
|
||||
defer peer.keyPairs.mutex.RUnlock()
|
||||
|
||||
kp := peer.keyPairs.current
|
||||
if kp == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !kp.isInitiator {
|
||||
return false
|
||||
}
|
||||
|
||||
nonce := atomic.LoadUint64(&kp.sendNonce)
|
||||
return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
|
||||
}()
|
||||
kp := peer.keyPairs.Current()
|
||||
if kp == nil {
|
||||
return
|
||||
}
|
||||
if !kp.isInitiator {
|
||||
return
|
||||
}
|
||||
nonce := atomic.LoadUint64(&kp.sendNonce)
|
||||
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
|
||||
if send {
|
||||
signalSend(peer.signal.handshakeBegin)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user