1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

device: remove recursion from insertion and connect parent pointers

This makes the insertion algorithm a bit more efficient, while also now
taking on the additional task of connecting up parent pointers. This
will be handy in the following commit.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-06-03 14:50:28 +02:00
parent 4a57024b94
commit b41f4cc768
3 changed files with 94 additions and 58 deletions

View File

@ -14,9 +14,15 @@ import (
"unsafe" "unsafe"
) )
type parentIndirection struct {
parentBit **trieEntry
parentBitType uint8
}
type trieEntry struct { type trieEntry struct {
peer *Peer peer *Peer
child [2]*trieEntry child [2]*trieEntry
parent parentIndirection
cidr uint8 cidr uint8
bitAtByte uint8 bitAtByte uint8
bitAtShift uint8 bitAtShift uint8
@ -114,43 +120,45 @@ func (node *trieEntry) maskSelf() {
} }
} }
func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry { func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
exact = true
return
}
bit := node.choose(ip)
node = node.child[bit]
}
return
}
// at leaf func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
if node == nil {
node := &trieEntry{ node := &trieEntry{
bits: ip,
peer: peer, peer: peer,
parent: trie,
bits: ip,
cidr: cidr, cidr: cidr,
bitAtByte: cidr / 8, bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8), bitAtShift: 7 - (cidr % 8),
} }
node.maskSelf() node.maskSelf()
node.addToPeerEntries() node.addToPeerEntries()
return node *trie.parentBit = node
return
} }
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
// traverse deeper if exact {
common := commonBits(node.bits, ip)
if node.cidr <= cidr && common >= node.cidr {
if node.cidr == cidr {
node.removeFromPeerEntries() node.removeFromPeerEntries()
node.peer = peer node.peer = peer
node.addToPeerEntries() node.addToPeerEntries()
return node return
} }
bit := node.choose(ip)
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
}
// split node
newNode := &trieEntry{ newNode := &trieEntry{
bits: ip,
peer: peer, peer: peer,
bits: ip,
cidr: cidr, cidr: cidr,
bitAtByte: cidr / 8, bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8), bitAtShift: 7 - (cidr % 8),
@ -158,34 +166,61 @@ func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
newNode.maskSelf() newNode.maskSelf()
newNode.addToPeerEntries() newNode.addToPeerEntries()
var down *trieEntry
if node == nil {
down = *trie.parentBit
} else {
bit := node.choose(ip)
down = node.child[bit]
if down == nil {
newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
return
}
}
common := commonBits(down.bits, ip)
if common < cidr { if common < cidr {
cidr = common cidr = common
} }
parent := node
// check for shorter prefix
if newNode.cidr == cidr { if newNode.cidr == cidr {
bit := newNode.choose(node.bits) bit := newNode.choose(down.bits)
newNode.child[bit] = node down.parent = parentIndirection{&newNode.child[bit], bit}
return newNode newNode.child[bit] = down
if parent == nil {
newNode.parent = trie
*trie.parentBit = newNode
} else {
bit := parent.choose(newNode.bits)
newNode.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = newNode
}
return
} }
// create new parent for node & newNode node = &trieEntry{
bits: append([]byte{}, newNode.bits...),
parent := &trieEntry{
bits: append([]byte{}, ip...),
peer: nil,
cidr: cidr, cidr: cidr,
bitAtByte: cidr / 8, bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8), bitAtShift: 7 - (cidr % 8),
} }
parent.maskSelf() node.maskSelf()
bit := parent.choose(ip) bit := node.choose(down.bits)
parent.child[bit] = newNode down.parent = parentIndirection{&node.child[bit], bit}
parent.child[bit^1] = node node.child[bit] = down
bit = node.choose(newNode.bits)
return parent newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
if parent == nil {
node.parent = trie
*trie.parentBit = node
} else {
bit := parent.choose(node.bits)
node.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = node
}
} }
func (node *trieEntry) lookup(ip net.IP) *Peer { func (node *trieEntry) lookup(ip net.IP) *Peer {
@ -236,9 +271,9 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
switch len(ip) { switch len(ip) {
case net.IPv6len: case net.IPv6len:
table.IPv6 = table.IPv6.insert(ip, cidr, peer) parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
case net.IPv4len: case net.IPv4len:
table.IPv4 = table.IPv4.insert(ip, cidr, peer) parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
default: default:
panic(errors.New("inserting unknown address type")) panic(errors.New("inserting unknown address type"))
} }

View File

@ -65,9 +65,9 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
} }
func TestTrieRandomIPv4(t *testing.T) { func TestTrieRandomIPv4(t *testing.T) {
var trie *trieEntry
var slow SlowRouter var slow SlowRouter
var peers []*Peer var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1) rand.Seed(1)
@ -82,7 +82,7 @@ func TestTrieRandomIPv4(t *testing.T) {
rand.Read(addr[:]) rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index]) allowedIPs.Insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index]) slow = slow.Insert(addr[:], cidr, peers[index])
} }
@ -90,7 +90,7 @@ func TestTrieRandomIPv4(t *testing.T) {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
peer1 := slow.Lookup(addr[:]) peer1 := slow.Lookup(addr[:])
peer2 := trie.lookup(addr[:]) peer2 := allowedIPs.LookupIPv4(addr[:])
if peer1 != peer2 { if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr) t.Error("Trie did not match naive implementation, for:", addr)
} }
@ -98,9 +98,9 @@ func TestTrieRandomIPv4(t *testing.T) {
} }
func TestTrieRandomIPv6(t *testing.T) { func TestTrieRandomIPv6(t *testing.T) {
var trie *trieEntry
var slow SlowRouter var slow SlowRouter
var peers []*Peer var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1) rand.Seed(1)
@ -115,7 +115,7 @@ func TestTrieRandomIPv6(t *testing.T) {
rand.Read(addr[:]) rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index]) allowedIPs.Insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index]) slow = slow.Insert(addr[:], cidr, peers[index])
} }
@ -123,7 +123,7 @@ func TestTrieRandomIPv6(t *testing.T) {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
peer1 := slow.Lookup(addr[:]) peer1 := slow.Lookup(addr[:])
peer2 := trie.lookup(addr[:]) peer2 := allowedIPs.LookupIPv6(addr[:])
if peer1 != peer2 { if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr) t.Error("Trie did not match naive implementation, for:", addr)
} }

View File

@ -42,6 +42,7 @@ func TestCommonBits(t *testing.T) {
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
var trie *trieEntry var trie *trieEntry
var peers []*Peer var peers []*Peer
root := parentIndirection{&trie, 2}
rand.Seed(1) rand.Seed(1)
@ -56,7 +57,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
rand.Read(addr[:]) rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber index := rand.Int() % peerNumber
trie = trie.insert(addr[:], cidr, peers[index]) root.insert(addr[:], cidr, peers[index])
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
@ -94,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var trie *trieEntry var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
trie = trie.insert([]byte{a, b, c, d}, cidr, peer) allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
} }
assertEQ := func(peer *Peer, a, b, c, d byte) { assertEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d}) p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }
} }
assertNEQ := func(peer *Peer, a, b, c, d byte) { assertNEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d}) p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
if p == peer { if p == peer {
t.Error("Assert NEQ failed") t.Error("Assert NEQ failed")
} }
@ -150,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0) assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0) assertEQ(a, 255, 0, 0, 0)
trie = trie.removeByPeer(a) allowedIPs.RemoveByPeer(a)
assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0)
@ -158,12 +159,12 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0)
trie = nil allowedIPs = AllowedIPs{}
insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24) insert(a, 192, 168, 0, 0, 24)
trie = trie.removeByPeer(a) allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1) assertNEQ(a, 192, 168, 0, 1)
} }
@ -181,7 +182,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var trie *trieEntry var allowedIPs AllowedIPs
expand := func(a uint32) []byte { expand := func(a uint32) []byte {
var out [4]byte var out [4]byte
@ -198,7 +199,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
trie = trie.insert(addr, cidr, peer) allowedIPs.Insert(addr, cidr, peer)
} }
assertEQ := func(peer *Peer, a, b, c, d uint32) { assertEQ := func(peer *Peer, a, b, c, d uint32) {
@ -207,7 +208,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
p := trie.lookup(addr) p := allowedIPs.LookupIPv6(addr)
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }