1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2025-09-18 20:57:50 +02:00

device: make allowedips generic

The implementation of commonBits uses a horrific unsafe.Slice trick.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2022-03-16 19:34:42 -06:00
parent 95b48cdb39
commit f3aff443a6
3 changed files with 152 additions and 119 deletions

View File

@ -16,32 +16,36 @@ import (
"unsafe" "unsafe"
) )
type parentIndirection struct { type ipArray interface {
parentBit **trieEntry [4]byte | [16]byte
}
type parentIndirection[B ipArray] struct {
parentBit **trieEntry[B]
parentBitType uint8 parentBitType uint8
} }
type trieEntry struct { type trieEntry[B ipArray] struct {
peer *Peer peer *Peer
child [2]*trieEntry child [2]*trieEntry[B]
parent parentIndirection parent parentIndirection[B]
cidr uint8 cidr uint8
bitAtByte uint8 bitAtByte uint8
bitAtShift uint8 bitAtShift uint8
bits []byte bits B
perPeerElem *list.Element perPeerElem *list.Element
} }
func commonBits(ip1, ip2 []byte) uint8 { func commonBits4(ip1, ip2 [4]byte) uint8 {
size := len(ip1) a := binary.BigEndian.Uint32(ip1[:])
if size == net.IPv4len { b := binary.BigEndian.Uint32(ip2[:])
a := binary.BigEndian.Uint32(ip1)
b := binary.BigEndian.Uint32(ip2)
x := a ^ b x := a ^ b
return uint8(bits.LeadingZeros32(x)) return uint8(bits.LeadingZeros32(x))
} else if size == net.IPv6len { }
a := binary.BigEndian.Uint64(ip1)
b := binary.BigEndian.Uint64(ip2) func commonBits16(ip1, ip2 [16]byte) uint8 {
a := binary.BigEndian.Uint64(ip1[:8])
b := binary.BigEndian.Uint64(ip2[:8])
x := a ^ b x := a ^ b
if x != 0 { if x != 0 {
return uint8(bits.LeadingZeros64(x)) return uint8(bits.LeadingZeros64(x))
@ -50,34 +54,48 @@ func commonBits(ip1, ip2 []byte) uint8 {
b = binary.BigEndian.Uint64(ip2[8:]) b = binary.BigEndian.Uint64(ip2[8:])
x = a ^ b x = a ^ b
return 64 + uint8(bits.LeadingZeros64(x)) return 64 + uint8(bits.LeadingZeros64(x))
} else {
panic("Wrong size bit string")
}
} }
func (node *trieEntry) addToPeerEntries() { func giveMeA4[B ipArray](b B) [4]byte {
return *(*[4]byte)(unsafe.Slice(&b[0], 4))
}
func giveMeA16[B ipArray](b B) [16]byte {
return *(*[16]byte)(unsafe.Slice(&b[0], 16))
}
func commonBits[B ipArray](ip1, ip2 B) uint8 {
if len(ip1) == 4 {
return commonBits4(giveMeA4(ip1), giveMeA4(ip2))
} else if len(ip1) == 16 {
return commonBits16(giveMeA16(ip1), giveMeA16(ip2))
}
panic("Wrong size bit string")
}
func (node *trieEntry[B]) addToPeerEntries() {
node.perPeerElem = node.peer.trieEntries.PushBack(node) node.perPeerElem = node.peer.trieEntries.PushBack(node)
} }
func (node *trieEntry) removeFromPeerEntries() { func (node *trieEntry[B]) removeFromPeerEntries() {
if node.perPeerElem != nil { if node.perPeerElem != nil {
node.peer.trieEntries.Remove(node.perPeerElem) node.peer.trieEntries.Remove(node.perPeerElem)
node.perPeerElem = nil node.perPeerElem = nil
} }
} }
func (node *trieEntry) choose(ip []byte) byte { func (node *trieEntry[B]) choose(ip B) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1 return (ip[node.bitAtByte] >> node.bitAtShift) & 1
} }
func (node *trieEntry) maskSelf() { func (node *trieEntry[B]) maskSelf() {
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
for i := 0; i < len(mask); i++ { for i := 0; i < len(mask); i++ {
node.bits[i] &= mask[i] node.bits[i] &= mask[i]
} }
} }
func (node *trieEntry) zeroizePointers() { func (node *trieEntry[B]) zeroizePointers() {
// Make the garbage collector's life slightly easier // Make the garbage collector's life slightly easier
node.peer = nil node.peer = nil
node.child[0] = nil node.child[0] = nil
@ -85,7 +103,7 @@ func (node *trieEntry) zeroizePointers() {
node.parent.parentBit = nil node.parent.parentBit = nil
} }
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { func (node *trieEntry[B]) nodePlacement(ip B, cidr uint8) (parent *trieEntry[B], exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node parent = node
if parent.cidr == cidr { if parent.cidr == cidr {
@ -98,9 +116,9 @@ func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry,
return return
} }
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { func (trie parentIndirection[B]) insert(ip B, cidr uint8, peer *Peer) {
if *trie.parentBit == nil { if *trie.parentBit == nil {
node := &trieEntry{ node := &trieEntry[B]{
peer: peer, peer: peer,
parent: trie, parent: trie,
bits: ip, bits: ip,
@ -121,7 +139,7 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
return return
} }
newNode := &trieEntry{ newNode := &trieEntry[B]{
peer: peer, peer: peer,
bits: ip, bits: ip,
cidr: cidr, cidr: cidr,
@ -131,14 +149,14 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
newNode.maskSelf() newNode.maskSelf()
newNode.addToPeerEntries() newNode.addToPeerEntries()
var down *trieEntry var down *trieEntry[B]
if node == nil { if node == nil {
down = *trie.parentBit down = *trie.parentBit
} else { } else {
bit := node.choose(ip) bit := node.choose(ip)
down = node.child[bit] down = node.child[bit]
if down == nil { if down == nil {
newNode.parent = parentIndirection{&node.child[bit], bit} newNode.parent = parentIndirection[B]{&node.child[bit], bit}
node.child[bit] = newNode node.child[bit] = newNode
return return
} }
@ -151,21 +169,21 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if newNode.cidr == cidr { if newNode.cidr == cidr {
bit := newNode.choose(down.bits) bit := newNode.choose(down.bits)
down.parent = parentIndirection{&newNode.child[bit], bit} down.parent = parentIndirection[B]{&newNode.child[bit], bit}
newNode.child[bit] = down newNode.child[bit] = down
if parent == nil { if parent == nil {
newNode.parent = trie newNode.parent = trie
*trie.parentBit = newNode *trie.parentBit = newNode
} else { } else {
bit := parent.choose(newNode.bits) bit := parent.choose(newNode.bits)
newNode.parent = parentIndirection{&parent.child[bit], bit} newNode.parent = parentIndirection[B]{&parent.child[bit], bit}
parent.child[bit] = newNode parent.child[bit] = newNode
} }
return return
} }
node = &trieEntry{ node = &trieEntry[B]{
bits: append([]byte{}, newNode.bits...), bits: newNode.bits,
cidr: cidr, cidr: cidr,
bitAtByte: cidr / 8, bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8), bitAtShift: 7 - (cidr % 8),
@ -173,22 +191,22 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
node.maskSelf() node.maskSelf()
bit := node.choose(down.bits) bit := node.choose(down.bits)
down.parent = parentIndirection{&node.child[bit], bit} down.parent = parentIndirection[B]{&node.child[bit], bit}
node.child[bit] = down node.child[bit] = down
bit = node.choose(newNode.bits) bit = node.choose(newNode.bits)
newNode.parent = parentIndirection{&node.child[bit], bit} newNode.parent = parentIndirection[B]{&node.child[bit], bit}
node.child[bit] = newNode node.child[bit] = newNode
if parent == nil { if parent == nil {
node.parent = trie node.parent = trie
*trie.parentBit = node *trie.parentBit = node
} else { } else {
bit := parent.choose(node.bits) bit := parent.choose(node.bits)
node.parent = parentIndirection{&parent.child[bit], bit} node.parent = parentIndirection[B]{&parent.child[bit], bit}
parent.child[bit] = node parent.child[bit] = node
} }
} }
func (node *trieEntry) lookup(ip []byte) *Peer { func (node *trieEntry[B]) lookup(ip B) *Peer {
var found *Peer var found *Peer
size := uint8(len(ip)) size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr { for node != nil && commonBits(node.bits, ip) >= node.cidr {
@ -205,8 +223,8 @@ func (node *trieEntry) lookup(ip []byte) *Peer {
} }
type AllowedIPs struct { type AllowedIPs struct {
IPv4 *trieEntry IPv4 *trieEntry[[4]byte]
IPv6 *trieEntry IPv6 *trieEntry[[16]byte]
mutex sync.RWMutex mutex sync.RWMutex
} }
@ -215,27 +233,23 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry) if node, ok := elem.Value.(*trieEntry[[4]byte]); ok {
a, _ := netip.AddrFromSlice(node.bits) if !cb(netip.PrefixFrom(netip.AddrFrom4(node.bits), int(node.cidr))) {
if !cb(netip.PrefixFrom(a, int(node.cidr))) { return
}
} else if node, ok := elem.Value.(*trieEntry[[16]byte]); ok {
if !cb(netip.PrefixFrom(netip.AddrFrom16(node.bits), int(node.cidr))) {
return return
} }
} }
} }
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) { func (node *trieEntry[B]) remove() {
table.mutex.Lock()
defer table.mutex.Unlock()
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
node := elem.Value.(*trieEntry)
node.removeFromPeerEntries() node.removeFromPeerEntries()
node.peer = nil node.peer = nil
if node.child[0] != nil && node.child[1] != nil { if node.child[0] != nil && node.child[1] != nil {
continue return
} }
bit := 0 bit := 0
if node.child[0] == nil { if node.child[0] == nil {
@ -248,12 +262,12 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
*node.parent.parentBit = child *node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers() node.zeroizePointers()
continue return
} }
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) parent := (*trieEntry[B])(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil { if parent.peer != nil {
node.zeroizePointers() node.zeroizePointers()
continue return
} }
child = parent.child[node.parent.parentBitType^1] child = parent.child[node.parent.parentBitType^1]
if child != nil { if child != nil {
@ -263,6 +277,20 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
node.zeroizePointers() node.zeroizePointers()
parent.zeroizePointers() parent.zeroizePointers()
} }
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
if node, ok := elem.Value.(*trieEntry[[4]byte]); ok {
node.remove()
} else if node, ok := elem.Value.(*trieEntry[[16]byte]); ok {
node.remove()
}
}
} }
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
@ -270,11 +298,9 @@ func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
defer table.mutex.Unlock() defer table.mutex.Unlock()
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
ip := prefix.Addr().As16() parentIndirection[[16]byte]{&table.IPv6, 2}.insert(prefix.Addr().As16(), uint8(prefix.Bits()), peer)
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else if prefix.Addr().Is4() { } else if prefix.Addr().Is4() {
ip := prefix.Addr().As4() parentIndirection[[4]byte]{&table.IPv4, 2}.insert(prefix.Addr().As4(), uint8(prefix.Bits()), peer)
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else { } else {
panic(errors.New("inserting unknown address type")) panic(errors.New("inserting unknown address type"))
} }
@ -285,9 +311,9 @@ func (table *AllowedIPs) Lookup(ip []byte) *Peer {
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
switch len(ip) { switch len(ip) {
case net.IPv6len: case net.IPv6len:
return table.IPv6.lookup(ip) return table.IPv6.lookup(*(*[16]byte)(ip))
case net.IPv4len: case net.IPv4len:
return table.IPv4.lookup(ip) return table.IPv4.lookup(*(*[4]byte)(ip))
default: default:
panic(errors.New("looking up unknown address type")) panic(errors.New("looking up unknown address type"))
} }

View File

@ -40,9 +40,18 @@ func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i] r[i], r[j] = r[j], r[i]
} }
func commonBitsSlice(addr1, addr2 []byte) uint8 {
if len(addr1) == 4 {
return commonBits4(*(*[4]byte)(addr1), *(*[4]byte)(addr2))
} else if len(addr1) == 16 {
return commonBits16(*(*[16]byte)(addr1), *(*[16]byte)(addr2))
}
return 0
}
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r { for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { if t.cidr == cidr && commonBitsSlice(t.bits, addr) >= cidr {
t.peer = peer t.peer = peer
t.bits = addr t.bits = addr
return r return r
@ -59,7 +68,7 @@ func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
func (r SlowRouter) Lookup(addr []byte) *Peer { func (r SlowRouter) Lookup(addr []byte) *Peer {
for _, t := range r { for _, t := range r {
common := commonBits(t.bits, addr) common := commonBitsSlice(t.bits, addr)
if common >= t.cidr { if common >= t.cidr {
return t.peer return t.peer
} }

View File

@ -7,28 +7,28 @@ package device
import ( import (
"math/rand" "math/rand"
"net"
"net/netip" "net/netip"
"testing" "testing"
"unsafe"
) )
type testPairCommonBits struct { type testPairCommonBits4 struct {
s1 []byte s1 [4]byte
s2 []byte s2 [4]byte
match uint8 match uint8
} }
func TestCommonBits(t *testing.T) { func TestCommonBits4(t *testing.T) {
tests := []testPairCommonBits{ tests := []testPairCommonBits4{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, {s1: [4]byte{1, 4, 53, 128}, s2: [4]byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, {s1: [4]byte{0, 4, 53, 128}, s2: [4]byte{0, 0, 0, 0}, match: 13},
{s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, {s1: [4]byte{0, 4, 53, 253}, s2: [4]byte{0, 4, 53, 252}, match: 31},
{s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, {s1: [4]byte{192, 168, 1, 1}, s2: [4]byte{192, 169, 1, 1}, match: 15},
{s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, {s1: [4]byte{65, 168, 1, 1}, s2: [4]byte{192, 169, 1, 1}, match: 0},
} }
for _, p := range tests { for _, p := range tests {
v := commonBits(p.s1, p.s2) v := commonBits4(p.s1, p.s2)
if v != p.match { if v != p.match {
t.Error( t.Error(
"For slice", p.s1, p.s2, "For slice", p.s1, p.s2,
@ -39,48 +39,46 @@ func TestCommonBits(t *testing.T) {
} }
} }
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { func benchmarkTrie[B ipArray](peerNumber, addressNumber int, b *testing.B) {
var trie *trieEntry var trie *trieEntry[B]
var peers []*Peer var peers []*Peer
root := parentIndirection{&trie, 2} root := parentIndirection[B]{&trie, 2}
rand.Seed(1) rand.Seed(1)
const AddressLength = 4
for n := 0; n < peerNumber; n++ { for n := 0; n < peerNumber; n++ {
peers = append(peers, &Peer{}) peers = append(peers, &Peer{})
} }
for n := 0; n < addressNumber; n++ { for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte var addr B
rand.Read(addr[:]) rand.Read(unsafe.Slice(&addr[0], len(addr)))
cidr := uint8(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Uint32() % uint32(len(addr)*8))
index := rand.Int() % peerNumber index := rand.Int() % peerNumber
root.insert(addr[:], cidr, peers[index]) root.insert(addr, cidr, peers[index])
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
var addr [AddressLength]byte var addr B
rand.Read(addr[:]) rand.Read(unsafe.Slice(&addr[0], len(addr)))
trie.lookup(addr[:]) trie.lookup(addr)
} }
} }
func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
benchmarkTrie(100, 1000, net.IPv4len, b) benchmarkTrie[[4]byte](100, 1000, b)
} }
func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
benchmarkTrie(10, 10, net.IPv4len, b) benchmarkTrie[[4]byte](10, 10, b)
} }
func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
benchmarkTrie(100, 1000, net.IPv6len, b) benchmarkTrie[[16]byte](100, 1000, b)
} }
func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
benchmarkTrie(10, 10, net.IPv6len, b) benchmarkTrie[[16]byte](10, 10, b)
} }
/* Test ported from kernel implementation: /* Test ported from kernel implementation: