diff --git a/device/allowedips.go b/device/allowedips.go index 3cac694..c36ef3a 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -16,68 +16,86 @@ import ( "unsafe" ) -type parentIndirection struct { - parentBit **trieEntry +type ipArray interface { + [4]byte | [16]byte +} + +type parentIndirection[B ipArray] struct { + parentBit **trieEntry[B] parentBitType uint8 } -type trieEntry struct { +type trieEntry[B ipArray] struct { peer *Peer - child [2]*trieEntry - parent parentIndirection + child [2]*trieEntry[B] + parent parentIndirection[B] cidr uint8 bitAtByte uint8 bitAtShift uint8 - bits []byte + bits B perPeerElem *list.Element } -func commonBits(ip1, ip2 []byte) uint8 { - size := len(ip1) - if size == net.IPv4len { - a := binary.BigEndian.Uint32(ip1) - b := binary.BigEndian.Uint32(ip2) - x := a ^ b - return uint8(bits.LeadingZeros32(x)) - } else if size == net.IPv6len { - a := binary.BigEndian.Uint64(ip1) - b := binary.BigEndian.Uint64(ip2) - x := a ^ b - if x != 0 { - return uint8(bits.LeadingZeros64(x)) - } - a = binary.BigEndian.Uint64(ip1[8:]) - b = binary.BigEndian.Uint64(ip2[8:]) - x = a ^ b - return 64 + uint8(bits.LeadingZeros64(x)) - } else { - panic("Wrong size bit string") - } +func commonBits4(ip1, ip2 [4]byte) uint8 { + a := binary.BigEndian.Uint32(ip1[:]) + b := binary.BigEndian.Uint32(ip2[:]) + x := a ^ b + return uint8(bits.LeadingZeros32(x)) } -func (node *trieEntry) addToPeerEntries() { +func commonBits16(ip1, ip2 [16]byte) uint8 { + a := binary.BigEndian.Uint64(ip1[:8]) + b := binary.BigEndian.Uint64(ip2[:8]) + x := a ^ b + if x != 0 { + return uint8(bits.LeadingZeros64(x)) + } + a = binary.BigEndian.Uint64(ip1[8:]) + b = binary.BigEndian.Uint64(ip2[8:]) + x = a ^ b + return 64 + uint8(bits.LeadingZeros64(x)) +} + +func giveMeA4[B ipArray](b B) [4]byte { + return *(*[4]byte)(unsafe.Slice(&b[0], 4)) +} + +func giveMeA16[B ipArray](b B) [16]byte { + return *(*[16]byte)(unsafe.Slice(&b[0], 16)) +} + +func commonBits[B ipArray](ip1, ip2 B) uint8 { + if len(ip1) == 4 { + return commonBits4(giveMeA4(ip1), giveMeA4(ip2)) + } else if len(ip1) == 16 { + return commonBits16(giveMeA16(ip1), giveMeA16(ip2)) + } + panic("Wrong size bit string") +} + +func (node *trieEntry[B]) addToPeerEntries() { node.perPeerElem = node.peer.trieEntries.PushBack(node) } -func (node *trieEntry) removeFromPeerEntries() { +func (node *trieEntry[B]) removeFromPeerEntries() { if node.perPeerElem != nil { node.peer.trieEntries.Remove(node.perPeerElem) node.perPeerElem = nil } } -func (node *trieEntry) choose(ip []byte) byte { +func (node *trieEntry[B]) choose(ip B) byte { return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } -func (node *trieEntry) maskSelf() { +func (node *trieEntry[B]) maskSelf() { mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) for i := 0; i < len(mask); i++ { node.bits[i] &= mask[i] } } -func (node *trieEntry) zeroizePointers() { +func (node *trieEntry[B]) zeroizePointers() { // Make the garbage collector's life slightly easier node.peer = nil node.child[0] = nil @@ -85,7 +103,7 @@ func (node *trieEntry) zeroizePointers() { node.parent.parentBit = nil } -func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { +func (node *trieEntry[B]) nodePlacement(ip B, cidr uint8) (parent *trieEntry[B], exact bool) { for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { parent = node if parent.cidr == cidr { @@ -98,9 +116,9 @@ func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, return } -func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { +func (trie parentIndirection[B]) insert(ip B, cidr uint8, peer *Peer) { if *trie.parentBit == nil { - node := &trieEntry{ + node := &trieEntry[B]{ peer: peer, parent: trie, bits: ip, @@ -121,7 +139,7 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { return } - newNode := &trieEntry{ + newNode := &trieEntry[B]{ peer: peer, bits: ip, cidr: cidr, @@ -131,14 +149,14 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { newNode.maskSelf() newNode.addToPeerEntries() - var down *trieEntry + var down *trieEntry[B] if node == nil { down = *trie.parentBit } else { bit := node.choose(ip) down = node.child[bit] if down == nil { - newNode.parent = parentIndirection{&node.child[bit], bit} + newNode.parent = parentIndirection[B]{&node.child[bit], bit} node.child[bit] = newNode return } @@ -151,21 +169,21 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { if newNode.cidr == cidr { bit := newNode.choose(down.bits) - down.parent = parentIndirection{&newNode.child[bit], bit} + down.parent = parentIndirection[B]{&newNode.child[bit], bit} newNode.child[bit] = down if parent == nil { newNode.parent = trie *trie.parentBit = newNode } else { bit := parent.choose(newNode.bits) - newNode.parent = parentIndirection{&parent.child[bit], bit} + newNode.parent = parentIndirection[B]{&parent.child[bit], bit} parent.child[bit] = newNode } return } - node = &trieEntry{ - bits: append([]byte{}, newNode.bits...), + node = &trieEntry[B]{ + bits: newNode.bits, cidr: cidr, bitAtByte: cidr / 8, bitAtShift: 7 - (cidr % 8), @@ -173,22 +191,22 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { node.maskSelf() bit := node.choose(down.bits) - down.parent = parentIndirection{&node.child[bit], bit} + down.parent = parentIndirection[B]{&node.child[bit], bit} node.child[bit] = down bit = node.choose(newNode.bits) - newNode.parent = parentIndirection{&node.child[bit], bit} + newNode.parent = parentIndirection[B]{&node.child[bit], bit} node.child[bit] = newNode if parent == nil { node.parent = trie *trie.parentBit = node } else { bit := parent.choose(node.bits) - node.parent = parentIndirection{&parent.child[bit], bit} + node.parent = parentIndirection[B]{&parent.child[bit], bit} parent.child[bit] = node } } -func (node *trieEntry) lookup(ip []byte) *Peer { +func (node *trieEntry[B]) lookup(ip B) *Peer { var found *Peer size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { @@ -205,8 +223,8 @@ func (node *trieEntry) lookup(ip []byte) *Peer { } type AllowedIPs struct { - IPv4 *trieEntry - IPv6 *trieEntry + IPv4 *trieEntry[[4]byte] + IPv6 *trieEntry[[16]byte] mutex sync.RWMutex } @@ -215,14 +233,51 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) defer table.mutex.RUnlock() for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { - node := elem.Value.(*trieEntry) - a, _ := netip.AddrFromSlice(node.bits) - if !cb(netip.PrefixFrom(a, int(node.cidr))) { - return + if node, ok := elem.Value.(*trieEntry[[4]byte]); ok { + if !cb(netip.PrefixFrom(netip.AddrFrom4(node.bits), int(node.cidr))) { + return + } + } else if node, ok := elem.Value.(*trieEntry[[16]byte]); ok { + if !cb(netip.PrefixFrom(netip.AddrFrom16(node.bits), int(node.cidr))) { + return + } } } } +func (node *trieEntry[B]) remove() { + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + return + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + node.zeroizePointers() + return + } + parent := (*trieEntry[B])(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + node.zeroizePointers() + return + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + node.zeroizePointers() + parent.zeroizePointers() +} + func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() @@ -230,38 +285,11 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { var next *list.Element for elem := peer.trieEntries.Front(); elem != nil; elem = next { next = elem.Next() - node := elem.Value.(*trieEntry) - - node.removeFromPeerEntries() - node.peer = nil - if node.child[0] != nil && node.child[1] != nil { - continue + if node, ok := elem.Value.(*trieEntry[[4]byte]); ok { + node.remove() + } else if node, ok := elem.Value.(*trieEntry[[16]byte]); ok { + node.remove() } - bit := 0 - if node.child[0] == nil { - bit = 1 - } - child := node.child[bit] - if child != nil { - child.parent = node.parent - } - *node.parent.parentBit = child - if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { - node.zeroizePointers() - continue - } - parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) - if parent.peer != nil { - node.zeroizePointers() - continue - } - child = parent.child[node.parent.parentBitType^1] - if child != nil { - child.parent = parent.parent - } - *parent.parent.parentBit = child - node.zeroizePointers() - parent.zeroizePointers() } } @@ -270,11 +298,9 @@ func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { defer table.mutex.Unlock() if prefix.Addr().Is6() { - ip := prefix.Addr().As16() - parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + parentIndirection[[16]byte]{&table.IPv6, 2}.insert(prefix.Addr().As16(), uint8(prefix.Bits()), peer) } else if prefix.Addr().Is4() { - ip := prefix.Addr().As4() - parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + parentIndirection[[4]byte]{&table.IPv4, 2}.insert(prefix.Addr().As4(), uint8(prefix.Bits()), peer) } else { panic(errors.New("inserting unknown address type")) } @@ -285,9 +311,9 @@ func (table *AllowedIPs) Lookup(ip []byte) *Peer { defer table.mutex.RUnlock() switch len(ip) { case net.IPv6len: - return table.IPv6.lookup(ip) + return table.IPv6.lookup(*(*[16]byte)(ip)) case net.IPv4len: - return table.IPv4.lookup(ip) + return table.IPv4.lookup(*(*[4]byte)(ip)) default: panic(errors.New("looking up unknown address type")) } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 0d3eecb..8c17d02 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -40,9 +40,18 @@ func (r SlowRouter) Swap(i, j int) { r[i], r[j] = r[j], r[i] } +func commonBitsSlice(addr1, addr2 []byte) uint8 { + if len(addr1) == 4 { + return commonBits4(*(*[4]byte)(addr1), *(*[4]byte)(addr2)) + } else if len(addr1) == 16 { + return commonBits16(*(*[16]byte)(addr1), *(*[16]byte)(addr2)) + } + return 0 +} + func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { for _, t := range r { - if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { + if t.cidr == cidr && commonBitsSlice(t.bits, addr) >= cidr { t.peer = peer t.bits = addr return r @@ -59,7 +68,7 @@ func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { func (r SlowRouter) Lookup(addr []byte) *Peer { for _, t := range r { - common := commonBits(t.bits, addr) + common := commonBitsSlice(t.bits, addr) if common >= t.cidr { return t.peer } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 225c788..a0d286f 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -7,28 +7,28 @@ package device import ( "math/rand" - "net" "net/netip" "testing" + "unsafe" ) -type testPairCommonBits struct { - s1 []byte - s2 []byte +type testPairCommonBits4 struct { + s1 [4]byte + s2 [4]byte match uint8 } -func TestCommonBits(t *testing.T) { - tests := []testPairCommonBits{ - {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, - {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, - {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, - {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, - {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, +func TestCommonBits4(t *testing.T) { + tests := []testPairCommonBits4{ + {s1: [4]byte{1, 4, 53, 128}, s2: [4]byte{0, 0, 0, 0}, match: 7}, + {s1: [4]byte{0, 4, 53, 128}, s2: [4]byte{0, 0, 0, 0}, match: 13}, + {s1: [4]byte{0, 4, 53, 253}, s2: [4]byte{0, 4, 53, 252}, match: 31}, + {s1: [4]byte{192, 168, 1, 1}, s2: [4]byte{192, 169, 1, 1}, match: 15}, + {s1: [4]byte{65, 168, 1, 1}, s2: [4]byte{192, 169, 1, 1}, match: 0}, } for _, p := range tests { - v := commonBits(p.s1, p.s2) + v := commonBits4(p.s1, p.s2) if v != p.match { t.Error( "For slice", p.s1, p.s2, @@ -39,48 +39,46 @@ func TestCommonBits(t *testing.T) { } } -func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { - var trie *trieEntry +func benchmarkTrie[B ipArray](peerNumber, addressNumber int, b *testing.B) { + var trie *trieEntry[B] var peers []*Peer - root := parentIndirection{&trie, 2} + root := parentIndirection[B]{&trie, 2} rand.Seed(1) - const AddressLength = 4 - for n := 0; n < peerNumber; n++ { peers = append(peers, &Peer{}) } for n := 0; n < addressNumber; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint8(rand.Uint32() % (AddressLength * 8)) + var addr B + rand.Read(unsafe.Slice(&addr[0], len(addr))) + cidr := uint8(rand.Uint32() % uint32(len(addr)*8)) index := rand.Int() % peerNumber - root.insert(addr[:], cidr, peers[index]) + root.insert(addr, cidr, peers[index]) } for n := 0; n < b.N; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - trie.lookup(addr[:]) + var addr B + rand.Read(unsafe.Slice(&addr[0], len(addr))) + trie.lookup(addr) } } func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv4len, b) + benchmarkTrie[[4]byte](100, 1000, b) } func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv4len, b) + benchmarkTrie[[4]byte](10, 10, b) } func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv6len, b) + benchmarkTrie[[16]byte](100, 1000, b) } func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv6len, b) + benchmarkTrie[[16]byte](10, 10, b) } /* Test ported from kernel implementation: