mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
device: remove recursion from insertion and connect parent pointers
This makes the insertion algorithm a bit more efficient, while also now taking on the additional task of connecting up parent pointers. This will be handy in the following commit. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
4a57024b94
commit
b41f4cc768
@ -14,9 +14,15 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type parentIndirection struct {
|
||||||
|
parentBit **trieEntry
|
||||||
|
parentBitType uint8
|
||||||
|
}
|
||||||
|
|
||||||
type trieEntry struct {
|
type trieEntry struct {
|
||||||
peer *Peer
|
peer *Peer
|
||||||
child [2]*trieEntry
|
child [2]*trieEntry
|
||||||
|
parent parentIndirection
|
||||||
cidr uint8
|
cidr uint8
|
||||||
bitAtByte uint8
|
bitAtByte uint8
|
||||||
bitAtShift uint8
|
bitAtShift uint8
|
||||||
@ -114,43 +120,45 @@ func (node *trieEntry) maskSelf() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
|
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
|
||||||
|
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
|
||||||
|
parent = node
|
||||||
|
if parent.cidr == cidr {
|
||||||
|
exact = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bit := node.choose(ip)
|
||||||
|
node = node.child[bit]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// at leaf
|
func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
|
||||||
|
if *trie.parentBit == nil {
|
||||||
if node == nil {
|
|
||||||
node := &trieEntry{
|
node := &trieEntry{
|
||||||
bits: ip,
|
|
||||||
peer: peer,
|
peer: peer,
|
||||||
|
parent: trie,
|
||||||
|
bits: ip,
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
bitAtByte: cidr / 8,
|
bitAtByte: cidr / 8,
|
||||||
bitAtShift: 7 - (cidr % 8),
|
bitAtShift: 7 - (cidr % 8),
|
||||||
}
|
}
|
||||||
node.maskSelf()
|
node.maskSelf()
|
||||||
node.addToPeerEntries()
|
node.addToPeerEntries()
|
||||||
return node
|
*trie.parentBit = node
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
|
||||||
// traverse deeper
|
if exact {
|
||||||
|
|
||||||
common := commonBits(node.bits, ip)
|
|
||||||
if node.cidr <= cidr && common >= node.cidr {
|
|
||||||
if node.cidr == cidr {
|
|
||||||
node.removeFromPeerEntries()
|
node.removeFromPeerEntries()
|
||||||
node.peer = peer
|
node.peer = peer
|
||||||
node.addToPeerEntries()
|
node.addToPeerEntries()
|
||||||
return node
|
return
|
||||||
}
|
}
|
||||||
bit := node.choose(ip)
|
|
||||||
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
|
|
||||||
return node
|
|
||||||
}
|
|
||||||
|
|
||||||
// split node
|
|
||||||
|
|
||||||
newNode := &trieEntry{
|
newNode := &trieEntry{
|
||||||
bits: ip,
|
|
||||||
peer: peer,
|
peer: peer,
|
||||||
|
bits: ip,
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
bitAtByte: cidr / 8,
|
bitAtByte: cidr / 8,
|
||||||
bitAtShift: 7 - (cidr % 8),
|
bitAtShift: 7 - (cidr % 8),
|
||||||
@ -158,34 +166,61 @@ func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
|
|||||||
newNode.maskSelf()
|
newNode.maskSelf()
|
||||||
newNode.addToPeerEntries()
|
newNode.addToPeerEntries()
|
||||||
|
|
||||||
|
var down *trieEntry
|
||||||
|
if node == nil {
|
||||||
|
down = *trie.parentBit
|
||||||
|
} else {
|
||||||
|
bit := node.choose(ip)
|
||||||
|
down = node.child[bit]
|
||||||
|
if down == nil {
|
||||||
|
newNode.parent = parentIndirection{&node.child[bit], bit}
|
||||||
|
node.child[bit] = newNode
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
common := commonBits(down.bits, ip)
|
||||||
if common < cidr {
|
if common < cidr {
|
||||||
cidr = common
|
cidr = common
|
||||||
}
|
}
|
||||||
|
parent := node
|
||||||
// check for shorter prefix
|
|
||||||
|
|
||||||
if newNode.cidr == cidr {
|
if newNode.cidr == cidr {
|
||||||
bit := newNode.choose(node.bits)
|
bit := newNode.choose(down.bits)
|
||||||
newNode.child[bit] = node
|
down.parent = parentIndirection{&newNode.child[bit], bit}
|
||||||
return newNode
|
newNode.child[bit] = down
|
||||||
|
if parent == nil {
|
||||||
|
newNode.parent = trie
|
||||||
|
*trie.parentBit = newNode
|
||||||
|
} else {
|
||||||
|
bit := parent.choose(newNode.bits)
|
||||||
|
newNode.parent = parentIndirection{&parent.child[bit], bit}
|
||||||
|
parent.child[bit] = newNode
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// create new parent for node & newNode
|
node = &trieEntry{
|
||||||
|
bits: append([]byte{}, newNode.bits...),
|
||||||
parent := &trieEntry{
|
|
||||||
bits: append([]byte{}, ip...),
|
|
||||||
peer: nil,
|
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
bitAtByte: cidr / 8,
|
bitAtByte: cidr / 8,
|
||||||
bitAtShift: 7 - (cidr % 8),
|
bitAtShift: 7 - (cidr % 8),
|
||||||
}
|
}
|
||||||
parent.maskSelf()
|
node.maskSelf()
|
||||||
|
|
||||||
bit := parent.choose(ip)
|
bit := node.choose(down.bits)
|
||||||
parent.child[bit] = newNode
|
down.parent = parentIndirection{&node.child[bit], bit}
|
||||||
parent.child[bit^1] = node
|
node.child[bit] = down
|
||||||
|
bit = node.choose(newNode.bits)
|
||||||
return parent
|
newNode.parent = parentIndirection{&node.child[bit], bit}
|
||||||
|
node.child[bit] = newNode
|
||||||
|
if parent == nil {
|
||||||
|
node.parent = trie
|
||||||
|
*trie.parentBit = node
|
||||||
|
} else {
|
||||||
|
bit := parent.choose(node.bits)
|
||||||
|
node.parent = parentIndirection{&parent.child[bit], bit}
|
||||||
|
parent.child[bit] = node
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
||||||
@ -236,9 +271,9 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
|
|||||||
|
|
||||||
switch len(ip) {
|
switch len(ip) {
|
||||||
case net.IPv6len:
|
case net.IPv6len:
|
||||||
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
|
parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
|
||||||
case net.IPv4len:
|
case net.IPv4len:
|
||||||
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
|
parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
|
||||||
default:
|
default:
|
||||||
panic(errors.New("inserting unknown address type"))
|
panic(errors.New("inserting unknown address type"))
|
||||||
}
|
}
|
||||||
|
@ -65,9 +65,9 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTrieRandomIPv4(t *testing.T) {
|
func TestTrieRandomIPv4(t *testing.T) {
|
||||||
var trie *trieEntry
|
|
||||||
var slow SlowRouter
|
var slow SlowRouter
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
|
var allowedIPs AllowedIPs
|
||||||
|
|
||||||
rand.Seed(1)
|
rand.Seed(1)
|
||||||
|
|
||||||
@ -82,7 +82,7 @@ func TestTrieRandomIPv4(t *testing.T) {
|
|||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % NumberOfPeers
|
index := rand.Int() % NumberOfPeers
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
allowedIPs.Insert(addr[:], cidr, peers[index])
|
||||||
slow = slow.Insert(addr[:], cidr, peers[index])
|
slow = slow.Insert(addr[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ func TestTrieRandomIPv4(t *testing.T) {
|
|||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
peer1 := slow.Lookup(addr[:])
|
peer1 := slow.Lookup(addr[:])
|
||||||
peer2 := trie.lookup(addr[:])
|
peer2 := allowedIPs.LookupIPv4(addr[:])
|
||||||
if peer1 != peer2 {
|
if peer1 != peer2 {
|
||||||
t.Error("Trie did not match naive implementation, for:", addr)
|
t.Error("Trie did not match naive implementation, for:", addr)
|
||||||
}
|
}
|
||||||
@ -98,9 +98,9 @@ func TestTrieRandomIPv4(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTrieRandomIPv6(t *testing.T) {
|
func TestTrieRandomIPv6(t *testing.T) {
|
||||||
var trie *trieEntry
|
|
||||||
var slow SlowRouter
|
var slow SlowRouter
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
|
var allowedIPs AllowedIPs
|
||||||
|
|
||||||
rand.Seed(1)
|
rand.Seed(1)
|
||||||
|
|
||||||
@ -115,7 +115,7 @@ func TestTrieRandomIPv6(t *testing.T) {
|
|||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % NumberOfPeers
|
index := rand.Int() % NumberOfPeers
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
allowedIPs.Insert(addr[:], cidr, peers[index])
|
||||||
slow = slow.Insert(addr[:], cidr, peers[index])
|
slow = slow.Insert(addr[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,7 +123,7 @@ func TestTrieRandomIPv6(t *testing.T) {
|
|||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
peer1 := slow.Lookup(addr[:])
|
peer1 := slow.Lookup(addr[:])
|
||||||
peer2 := trie.lookup(addr[:])
|
peer2 := allowedIPs.LookupIPv6(addr[:])
|
||||||
if peer1 != peer2 {
|
if peer1 != peer2 {
|
||||||
t.Error("Trie did not match naive implementation, for:", addr)
|
t.Error("Trie did not match naive implementation, for:", addr)
|
||||||
}
|
}
|
||||||
|
@ -42,6 +42,7 @@ func TestCommonBits(t *testing.T) {
|
|||||||
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
|
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
|
||||||
var trie *trieEntry
|
var trie *trieEntry
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
|
root := parentIndirection{&trie, 2}
|
||||||
|
|
||||||
rand.Seed(1)
|
rand.Seed(1)
|
||||||
|
|
||||||
@ -56,7 +57,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
|
|||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % peerNumber
|
index := rand.Int() % peerNumber
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
root.insert(addr[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
@ -94,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
|
|||||||
g := &Peer{}
|
g := &Peer{}
|
||||||
h := &Peer{}
|
h := &Peer{}
|
||||||
|
|
||||||
var trie *trieEntry
|
var allowedIPs AllowedIPs
|
||||||
|
|
||||||
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
||||||
trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
|
allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
p := trie.lookup([]byte{a, b, c, d})
|
p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
|
||||||
if p != peer {
|
if p != peer {
|
||||||
t.Error("Assert EQ failed")
|
t.Error("Assert EQ failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assertNEQ := func(peer *Peer, a, b, c, d byte) {
|
assertNEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
p := trie.lookup([]byte{a, b, c, d})
|
p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
|
||||||
if p == peer {
|
if p == peer {
|
||||||
t.Error("Assert NEQ failed")
|
t.Error("Assert NEQ failed")
|
||||||
}
|
}
|
||||||
@ -150,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
|
|||||||
assertEQ(a, 192, 0, 0, 0)
|
assertEQ(a, 192, 0, 0, 0)
|
||||||
assertEQ(a, 255, 0, 0, 0)
|
assertEQ(a, 255, 0, 0, 0)
|
||||||
|
|
||||||
trie = trie.removeByPeer(a)
|
allowedIPs.RemoveByPeer(a)
|
||||||
|
|
||||||
assertNEQ(a, 1, 0, 0, 0)
|
assertNEQ(a, 1, 0, 0, 0)
|
||||||
assertNEQ(a, 64, 0, 0, 0)
|
assertNEQ(a, 64, 0, 0, 0)
|
||||||
@ -158,12 +159,12 @@ func TestTrieIPv4(t *testing.T) {
|
|||||||
assertNEQ(a, 192, 0, 0, 0)
|
assertNEQ(a, 192, 0, 0, 0)
|
||||||
assertNEQ(a, 255, 0, 0, 0)
|
assertNEQ(a, 255, 0, 0, 0)
|
||||||
|
|
||||||
trie = nil
|
allowedIPs = AllowedIPs{}
|
||||||
|
|
||||||
insert(a, 192, 168, 0, 0, 16)
|
insert(a, 192, 168, 0, 0, 16)
|
||||||
insert(a, 192, 168, 0, 0, 24)
|
insert(a, 192, 168, 0, 0, 24)
|
||||||
|
|
||||||
trie = trie.removeByPeer(a)
|
allowedIPs.RemoveByPeer(a)
|
||||||
|
|
||||||
assertNEQ(a, 192, 168, 0, 1)
|
assertNEQ(a, 192, 168, 0, 1)
|
||||||
}
|
}
|
||||||
@ -181,7 +182,7 @@ func TestTrieIPv6(t *testing.T) {
|
|||||||
g := &Peer{}
|
g := &Peer{}
|
||||||
h := &Peer{}
|
h := &Peer{}
|
||||||
|
|
||||||
var trie *trieEntry
|
var allowedIPs AllowedIPs
|
||||||
|
|
||||||
expand := func(a uint32) []byte {
|
expand := func(a uint32) []byte {
|
||||||
var out [4]byte
|
var out [4]byte
|
||||||
@ -198,7 +199,7 @@ func TestTrieIPv6(t *testing.T) {
|
|||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
addr = append(addr, expand(c)...)
|
addr = append(addr, expand(c)...)
|
||||||
addr = append(addr, expand(d)...)
|
addr = append(addr, expand(d)...)
|
||||||
trie = trie.insert(addr, cidr, peer)
|
allowedIPs.Insert(addr, cidr, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||||
@ -207,7 +208,7 @@ func TestTrieIPv6(t *testing.T) {
|
|||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
addr = append(addr, expand(c)...)
|
addr = append(addr, expand(c)...)
|
||||||
addr = append(addr, expand(d)...)
|
addr = append(addr, expand(d)...)
|
||||||
p := trie.lookup(addr)
|
p := allowedIPs.LookupIPv6(addr)
|
||||||
if p != peer {
|
if p != peer {
|
||||||
t.Error("Assert EQ failed")
|
t.Error("Assert EQ failed")
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user