1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

device: reduce size of trie struct

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-06-03 13:51:03 +02:00
parent 64cb82f2b3
commit 4a57024b94
5 changed files with 45 additions and 53 deletions

View File

@ -15,13 +15,13 @@ import (
) )
type trieEntry struct { type trieEntry struct {
child [2]*trieEntry peer *Peer
peer *Peer child [2]*trieEntry
bits net.IP cidr uint8
cidr uint bitAtByte uint8
bit_at_byte uint bitAtShift uint8
bit_at_shift uint bits net.IP
perPeerElem *list.Element perPeerElem *list.Element
} }
func isLittleEndian() bool { func isLittleEndian() bool {
@ -45,24 +45,24 @@ func swapU64(i uint64) uint64 {
return bits.ReverseBytes64(i) return bits.ReverseBytes64(i)
} }
func commonBits(ip1 net.IP, ip2 net.IP) uint { func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
size := len(ip1) size := len(ip1)
if size == net.IPv4len { if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0])) a := (*uint32)(unsafe.Pointer(&ip1[0]))
b := (*uint32)(unsafe.Pointer(&ip2[0])) b := (*uint32)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b x := *a ^ *b
return uint(bits.LeadingZeros32(swapU32(x))) return uint8(bits.LeadingZeros32(swapU32(x)))
} else if size == net.IPv6len { } else if size == net.IPv6len {
a := (*uint64)(unsafe.Pointer(&ip1[0])) a := (*uint64)(unsafe.Pointer(&ip1[0]))
b := (*uint64)(unsafe.Pointer(&ip2[0])) b := (*uint64)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b x := *a ^ *b
if x != 0 { if x != 0 {
return uint(bits.LeadingZeros64(swapU64(x))) return uint8(bits.LeadingZeros64(swapU64(x)))
} }
a = (*uint64)(unsafe.Pointer(&ip1[8])) a = (*uint64)(unsafe.Pointer(&ip1[8]))
b = (*uint64)(unsafe.Pointer(&ip2[8])) b = (*uint64)(unsafe.Pointer(&ip2[8]))
x = *a ^ *b x = *a ^ *b
return 64 + uint(bits.LeadingZeros64(swapU64(x))) return 64 + uint8(bits.LeadingZeros64(swapU64(x)))
} else { } else {
panic("Wrong size bit string") panic("Wrong size bit string")
} }
@ -104,7 +104,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
} }
func (node *trieEntry) choose(ip net.IP) byte { func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 return (ip[node.bitAtByte] >> node.bitAtShift) & 1
} }
func (node *trieEntry) maskSelf() { func (node *trieEntry) maskSelf() {
@ -114,17 +114,17 @@ func (node *trieEntry) maskSelf() {
} }
} }
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
// at leaf // at leaf
if node == nil { if node == nil {
node := &trieEntry{ node := &trieEntry{
bits: ip, bits: ip,
peer: peer, peer: peer,
cidr: cidr, cidr: cidr,
bit_at_byte: cidr / 8, bitAtByte: cidr / 8,
bit_at_shift: 7 - (cidr % 8), bitAtShift: 7 - (cidr % 8),
} }
node.maskSelf() node.maskSelf()
node.addToPeerEntries() node.addToPeerEntries()
@ -149,16 +149,18 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
// split node // split node
newNode := &trieEntry{ newNode := &trieEntry{
bits: ip, bits: ip,
peer: peer, peer: peer,
cidr: cidr, cidr: cidr,
bit_at_byte: cidr / 8, bitAtByte: cidr / 8,
bit_at_shift: 7 - (cidr % 8), bitAtShift: 7 - (cidr % 8),
} }
newNode.maskSelf() newNode.maskSelf()
newNode.addToPeerEntries() newNode.addToPeerEntries()
cidr = min(cidr, common) if common < cidr {
cidr = common
}
// check for shorter prefix // check for shorter prefix
@ -171,11 +173,11 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
// create new parent for node & newNode // create new parent for node & newNode
parent := &trieEntry{ parent := &trieEntry{
bits: append([]byte{}, ip...), bits: append([]byte{}, ip...),
peer: nil, peer: nil,
cidr: cidr, cidr: cidr,
bit_at_byte: cidr / 8, bitAtByte: cidr / 8,
bit_at_shift: 7 - (cidr % 8), bitAtShift: 7 - (cidr % 8),
} }
parent.maskSelf() parent.maskSelf()
@ -188,12 +190,12 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
func (node *trieEntry) lookup(ip net.IP) *Peer { func (node *trieEntry) lookup(ip net.IP) *Peer {
var found *Peer var found *Peer
size := uint(len(ip)) size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr { for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil { if node.peer != nil {
found = node.peer found = node.peer
} }
if node.bit_at_byte == size { if node.bitAtByte == size {
break break
} }
bit := node.choose(ip) bit := node.choose(ip)
@ -208,7 +210,7 @@ type AllowedIPs struct {
mutex sync.RWMutex mutex sync.RWMutex
} }
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) { func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
@ -228,7 +230,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.IPv6 = table.IPv6.removeByPeer(peer) table.IPv6 = table.IPv6.removeByPeer(peer)
} }
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()

View File

@ -19,7 +19,7 @@ const (
type SlowNode struct { type SlowNode struct {
peer *Peer peer *Peer
cidr uint cidr uint8
bits []byte bits []byte
} }
@ -37,7 +37,7 @@ func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i] r[i], r[j] = r[j], r[i]
} }
func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r { for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer t.peer = peer
@ -80,7 +80,7 @@ func TestTrieRandomIPv4(t *testing.T) {
for n := 0; n < NumberOfAddresses; n++ { for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index]) trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index]) slow = slow.Insert(addr[:], cidr, peers[index])
@ -113,7 +113,7 @@ func TestTrieRandomIPv6(t *testing.T) {
for n := 0; n < NumberOfAddresses; n++ { for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index]) trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index]) slow = slow.Insert(addr[:], cidr, peers[index])

View File

@ -11,13 +11,10 @@ import (
"testing" "testing"
) )
/* Todo: More comprehensive
*/
type testPairCommonBits struct { type testPairCommonBits struct {
s1 []byte s1 []byte
s2 []byte s2 []byte
match uint match uint8
} }
func TestCommonBits(t *testing.T) { func TestCommonBits(t *testing.T) {
@ -57,7 +54,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
for n := 0; n < addressNumber; n++ { for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber index := rand.Int() % peerNumber
trie = trie.insert(addr[:], cidr, peers[index]) trie = trie.insert(addr[:], cidr, peers[index])
} }
@ -99,7 +96,7 @@ func TestTrieIPv4(t *testing.T) {
var trie *trieEntry var trie *trieEntry
insert := func(peer *Peer, a, b, c, d byte, cidr uint) { insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
trie = trie.insert([]byte{a, b, c, d}, cidr, peer) trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
} }
@ -195,7 +192,7 @@ func TestTrieIPv6(t *testing.T) {
return out[:] return out[:]
} }
insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte var addr []byte
addr = append(addr, expand(a)...) addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)

View File

@ -39,10 +39,3 @@ func (a *AtomicBool) Set(val bool) {
} }
atomic.StoreInt32(&a.int32, flag) atomic.StoreInt32(&a.int32, flag)
} }
func min(a, b uint) uint {
if a > b {
return b
}
return a
}

View File

@ -121,7 +121,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool { device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
sendf("allowed_ip=%s/%d", ip.String(), cidr) sendf("allowed_ip=%s/%d", ip.String(), cidr)
return true return true
}) })
@ -379,7 +379,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
return nil return nil
} }
ones, _ := network.Mask.Size() ones, _ := network.Mask.Size()
device.allowedips.Insert(network.IP, uint(ones), peer.Peer) device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
case "protocol_version": case "protocol_version":
if value != "1" { if value != "1" {