1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

replay: clean up internals and better documentation

Signed-off-by: Riobard Zhan <me@riobard.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Riobard Zhan 2020-09-10 01:55:24 +08:00 committed by Jason A. Donenfeld
parent c8fe925020
commit 22af3890f6
2 changed files with 50 additions and 71 deletions

View File

@ -3,81 +3,60 @@
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/ */
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
package replay package replay
/* Implementation of RFC6479 type block uint64
* https://tools.ietf.org/html/rfc6479
*
* The implementation is not safe for concurrent use!
*/
const ( const (
// See: https://golang.org/src/math/big/arith.go blockBitLog = 6 // 1<<6 == 64 bits
_Wordm = ^uintptr(0) blockBits = 1 << blockBitLog // must be power of 2
_WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1 ringBlocks = 1 << 7 // must be power of 2
_WordSize = 1 << _WordLogSize windowSize = (ringBlocks - 1) * blockBits
blockMask = ringBlocks - 1
bitMask = blockBits - 1
) )
const ( // A ReplayFilter rejects replayed messages by checking if message counter value is
CounterRedundantBitsLog = _WordLogSize + 3 // within a sliding window of previously received messages.
CounterRedundantBits = _WordSize * 8 // The zero value for ReplayFilter is an empty filter ready to use.
CounterBitsTotal = 8192 // Filters are unsafe for concurrent use.
CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
)
const (
BacktrackWords = CounterBitsTotal / 8 / _WordSize
)
func minUint64(a uint64, b uint64) uint64 {
if a > b {
return b
}
return a
}
type ReplayFilter struct { type ReplayFilter struct {
counter uint64 last uint64
backtrack [BacktrackWords]uintptr ring [ringBlocks]block
} }
func (filter *ReplayFilter) Init() { // Init resets the filter to empty state.
filter.counter = 0 func (f *ReplayFilter) Init() {
filter.backtrack[0] = 0 f.last = 0
f.ring[0] = 0
} }
func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool { // ValidateCounter checks if the counter should be accepted.
// Overlimit counters (>= limit) are always rejected.
func (f *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
if counter >= limit { if counter >= limit {
return false return false
} }
indexBlock := counter >> blockBitLog
indexWord := counter >> CounterRedundantBitsLog if counter > f.last { // move window forward
current := f.last >> blockBitLog
if counter > filter.counter { diff := indexBlock - current
if diff > ringBlocks {
// move window forward diff = ringBlocks // cap diff to clear the whole ring
current := filter.counter >> CounterRedundantBitsLog
diff := minUint64(indexWord-current, BacktrackWords)
for i := uint64(1); i <= diff; i++ {
filter.backtrack[(current+i)%BacktrackWords] = 0
} }
filter.counter = counter for i := current + 1; i <= current+diff; i++ {
f.ring[i&blockMask] = 0
} else if filter.counter-counter > CounterWindowSize { }
f.last = counter
// behind current window } else if f.last-counter > windowSize { // behind current window
return false return false
} }
indexWord %= BacktrackWords
indexBit := counter & uint64(CounterRedundantBits-1)
// check and set bit // check and set bit
indexBlock &= blockMask
oldValue := filter.backtrack[indexWord] indexBit := counter & bitMask
newValue := oldValue | (1 << indexBit) old := f.ring[indexBlock]
filter.backtrack[indexWord] = newValue new := old | 1<<indexBit
return oldValue != newValue f.ring[indexBlock] = new
return old != new
} }

View File

@ -19,13 +19,13 @@ const RejectAfterMessages = (1 << 64) - (1 << 4) - 1
func TestReplay(t *testing.T) { func TestReplay(t *testing.T) {
var filter ReplayFilter var filter ReplayFilter
T_LIM := CounterWindowSize + 1 const T_LIM = windowSize + 1
testNumber := 0 testNumber := 0
T := func(n uint64, v bool) { T := func(n uint64, expected bool) {
testNumber++ testNumber++
if filter.ValidateCounter(n, RejectAfterMessages) != v { if filter.ValidateCounter(n, RejectAfterMessages) != expected {
t.Fatal("Test", testNumber, "failed", n, v) t.Fatal("Test", testNumber, "failed", n, expected)
} }
} }
@ -69,7 +69,7 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 1") t.Log("Bulk test 1")
filter.Init() filter.Init()
testNumber = 0 testNumber = 0
for i := uint64(1); i <= CounterWindowSize; i++ { for i := uint64(1); i <= windowSize; i++ {
T(i, true) T(i, true)
} }
T(0, true) T(0, true)
@ -78,7 +78,7 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 2") t.Log("Bulk test 2")
filter.Init() filter.Init()
testNumber = 0 testNumber = 0
for i := uint64(2); i <= CounterWindowSize+1; i++ { for i := uint64(2); i <= windowSize+1; i++ {
T(i, true) T(i, true)
} }
T(1, true) T(1, true)
@ -87,14 +87,14 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 3") t.Log("Bulk test 3")
filter.Init() filter.Init()
testNumber = 0 testNumber = 0
for i := CounterWindowSize + 1; i > 0; i-- { for i := uint64(windowSize + 1); i > 0; i-- {
T(i, true) T(i, true)
} }
t.Log("Bulk test 4") t.Log("Bulk test 4")
filter.Init() filter.Init()
testNumber = 0 testNumber = 0
for i := CounterWindowSize + 2; i > 1; i-- { for i := uint64(windowSize + 2); i > 1; i-- {
T(i, true) T(i, true)
} }
T(0, false) T(0, false)
@ -102,18 +102,18 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 5") t.Log("Bulk test 5")
filter.Init() filter.Init()
testNumber = 0 testNumber = 0
for i := CounterWindowSize; i > 0; i-- { for i := uint64(windowSize); i > 0; i-- {
T(i, true) T(i, true)
} }
T(CounterWindowSize+1, true) T(windowSize+1, true)
T(0, false) T(0, false)
t.Log("Bulk test 6") t.Log("Bulk test 6")
filter.Init() filter.Init()
testNumber = 0 testNumber = 0
for i := CounterWindowSize; i > 0; i-- { for i := uint64(windowSize); i > 0; i-- {
T(i, true) T(i, true)
} }
T(0, true) T(0, true)
T(CounterWindowSize+1, true) T(windowSize+1, true)
} }