mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
device: reduce size of trie struct
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
64cb82f2b3
commit
4a57024b94
@ -15,12 +15,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type trieEntry struct {
|
type trieEntry struct {
|
||||||
child [2]*trieEntry
|
|
||||||
peer *Peer
|
peer *Peer
|
||||||
|
child [2]*trieEntry
|
||||||
|
cidr uint8
|
||||||
|
bitAtByte uint8
|
||||||
|
bitAtShift uint8
|
||||||
bits net.IP
|
bits net.IP
|
||||||
cidr uint
|
|
||||||
bit_at_byte uint
|
|
||||||
bit_at_shift uint
|
|
||||||
perPeerElem *list.Element
|
perPeerElem *list.Element
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,24 +45,24 @@ func swapU64(i uint64) uint64 {
|
|||||||
return bits.ReverseBytes64(i)
|
return bits.ReverseBytes64(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func commonBits(ip1 net.IP, ip2 net.IP) uint {
|
func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
|
||||||
size := len(ip1)
|
size := len(ip1)
|
||||||
if size == net.IPv4len {
|
if size == net.IPv4len {
|
||||||
a := (*uint32)(unsafe.Pointer(&ip1[0]))
|
a := (*uint32)(unsafe.Pointer(&ip1[0]))
|
||||||
b := (*uint32)(unsafe.Pointer(&ip2[0]))
|
b := (*uint32)(unsafe.Pointer(&ip2[0]))
|
||||||
x := *a ^ *b
|
x := *a ^ *b
|
||||||
return uint(bits.LeadingZeros32(swapU32(x)))
|
return uint8(bits.LeadingZeros32(swapU32(x)))
|
||||||
} else if size == net.IPv6len {
|
} else if size == net.IPv6len {
|
||||||
a := (*uint64)(unsafe.Pointer(&ip1[0]))
|
a := (*uint64)(unsafe.Pointer(&ip1[0]))
|
||||||
b := (*uint64)(unsafe.Pointer(&ip2[0]))
|
b := (*uint64)(unsafe.Pointer(&ip2[0]))
|
||||||
x := *a ^ *b
|
x := *a ^ *b
|
||||||
if x != 0 {
|
if x != 0 {
|
||||||
return uint(bits.LeadingZeros64(swapU64(x)))
|
return uint8(bits.LeadingZeros64(swapU64(x)))
|
||||||
}
|
}
|
||||||
a = (*uint64)(unsafe.Pointer(&ip1[8]))
|
a = (*uint64)(unsafe.Pointer(&ip1[8]))
|
||||||
b = (*uint64)(unsafe.Pointer(&ip2[8]))
|
b = (*uint64)(unsafe.Pointer(&ip2[8]))
|
||||||
x = *a ^ *b
|
x = *a ^ *b
|
||||||
return 64 + uint(bits.LeadingZeros64(swapU64(x)))
|
return 64 + uint8(bits.LeadingZeros64(swapU64(x)))
|
||||||
} else {
|
} else {
|
||||||
panic("Wrong size bit string")
|
panic("Wrong size bit string")
|
||||||
}
|
}
|
||||||
@ -104,7 +104,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) choose(ip net.IP) byte {
|
func (node *trieEntry) choose(ip net.IP) byte {
|
||||||
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
|
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) maskSelf() {
|
func (node *trieEntry) maskSelf() {
|
||||||
@ -114,7 +114,7 @@ func (node *trieEntry) maskSelf() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
|
||||||
|
|
||||||
// at leaf
|
// at leaf
|
||||||
|
|
||||||
@ -123,8 +123,8 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
|||||||
bits: ip,
|
bits: ip,
|
||||||
peer: peer,
|
peer: peer,
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
bit_at_byte: cidr / 8,
|
bitAtByte: cidr / 8,
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
bitAtShift: 7 - (cidr % 8),
|
||||||
}
|
}
|
||||||
node.maskSelf()
|
node.maskSelf()
|
||||||
node.addToPeerEntries()
|
node.addToPeerEntries()
|
||||||
@ -152,13 +152,15 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
|||||||
bits: ip,
|
bits: ip,
|
||||||
peer: peer,
|
peer: peer,
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
bit_at_byte: cidr / 8,
|
bitAtByte: cidr / 8,
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
bitAtShift: 7 - (cidr % 8),
|
||||||
}
|
}
|
||||||
newNode.maskSelf()
|
newNode.maskSelf()
|
||||||
newNode.addToPeerEntries()
|
newNode.addToPeerEntries()
|
||||||
|
|
||||||
cidr = min(cidr, common)
|
if common < cidr {
|
||||||
|
cidr = common
|
||||||
|
}
|
||||||
|
|
||||||
// check for shorter prefix
|
// check for shorter prefix
|
||||||
|
|
||||||
@ -174,8 +176,8 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
|||||||
bits: append([]byte{}, ip...),
|
bits: append([]byte{}, ip...),
|
||||||
peer: nil,
|
peer: nil,
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
bit_at_byte: cidr / 8,
|
bitAtByte: cidr / 8,
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
bitAtShift: 7 - (cidr % 8),
|
||||||
}
|
}
|
||||||
parent.maskSelf()
|
parent.maskSelf()
|
||||||
|
|
||||||
@ -188,12 +190,12 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
|||||||
|
|
||||||
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
||||||
var found *Peer
|
var found *Peer
|
||||||
size := uint(len(ip))
|
size := uint8(len(ip))
|
||||||
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
||||||
if node.peer != nil {
|
if node.peer != nil {
|
||||||
found = node.peer
|
found = node.peer
|
||||||
}
|
}
|
||||||
if node.bit_at_byte == size {
|
if node.bitAtByte == size {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
bit := node.choose(ip)
|
bit := node.choose(ip)
|
||||||
@ -208,7 +210,7 @@ type AllowedIPs struct {
|
|||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) {
|
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
|
||||||
table.mutex.RLock()
|
table.mutex.RLock()
|
||||||
defer table.mutex.RUnlock()
|
defer table.mutex.RUnlock()
|
||||||
|
|
||||||
@ -228,7 +230,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
|||||||
table.IPv6 = table.IPv6.removeByPeer(peer)
|
table.IPv6 = table.IPv6.removeByPeer(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
|
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
defer table.mutex.Unlock()
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ const (
|
|||||||
|
|
||||||
type SlowNode struct {
|
type SlowNode struct {
|
||||||
peer *Peer
|
peer *Peer
|
||||||
cidr uint
|
cidr uint8
|
||||||
bits []byte
|
bits []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,7 +37,7 @@ func (r SlowRouter) Swap(i, j int) {
|
|||||||
r[i], r[j] = r[j], r[i]
|
r[i], r[j] = r[j], r[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
|
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
|
||||||
for _, t := range r {
|
for _, t := range r {
|
||||||
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
|
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
|
||||||
t.peer = peer
|
t.peer = peer
|
||||||
@ -80,7 +80,7 @@ func TestTrieRandomIPv4(t *testing.T) {
|
|||||||
for n := 0; n < NumberOfAddresses; n++ {
|
for n := 0; n < NumberOfAddresses; n++ {
|
||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % NumberOfPeers
|
index := rand.Int() % NumberOfPeers
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
trie = trie.insert(addr[:], cidr, peers[index])
|
||||||
slow = slow.Insert(addr[:], cidr, peers[index])
|
slow = slow.Insert(addr[:], cidr, peers[index])
|
||||||
@ -113,7 +113,7 @@ func TestTrieRandomIPv6(t *testing.T) {
|
|||||||
for n := 0; n < NumberOfAddresses; n++ {
|
for n := 0; n < NumberOfAddresses; n++ {
|
||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % NumberOfPeers
|
index := rand.Int() % NumberOfPeers
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
trie = trie.insert(addr[:], cidr, peers[index])
|
||||||
slow = slow.Insert(addr[:], cidr, peers[index])
|
slow = slow.Insert(addr[:], cidr, peers[index])
|
||||||
|
@ -11,13 +11,10 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Todo: More comprehensive
|
|
||||||
*/
|
|
||||||
|
|
||||||
type testPairCommonBits struct {
|
type testPairCommonBits struct {
|
||||||
s1 []byte
|
s1 []byte
|
||||||
s2 []byte
|
s2 []byte
|
||||||
match uint
|
match uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCommonBits(t *testing.T) {
|
func TestCommonBits(t *testing.T) {
|
||||||
@ -57,7 +54,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
|
|||||||
for n := 0; n < addressNumber; n++ {
|
for n := 0; n < addressNumber; n++ {
|
||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % peerNumber
|
index := rand.Int() % peerNumber
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
trie = trie.insert(addr[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
@ -99,7 +96,7 @@ func TestTrieIPv4(t *testing.T) {
|
|||||||
|
|
||||||
var trie *trieEntry
|
var trie *trieEntry
|
||||||
|
|
||||||
insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
|
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
||||||
trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
|
trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,7 +192,7 @@ func TestTrieIPv6(t *testing.T) {
|
|||||||
return out[:]
|
return out[:]
|
||||||
}
|
}
|
||||||
|
|
||||||
insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
|
insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
|
||||||
var addr []byte
|
var addr []byte
|
||||||
addr = append(addr, expand(a)...)
|
addr = append(addr, expand(a)...)
|
||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
|
@ -39,10 +39,3 @@ func (a *AtomicBool) Set(val bool) {
|
|||||||
}
|
}
|
||||||
atomic.StoreInt32(&a.int32, flag)
|
atomic.StoreInt32(&a.int32, flag)
|
||||||
}
|
}
|
||||||
|
|
||||||
func min(a, b uint) uint {
|
|
||||||
if a > b {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
@ -121,7 +121,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||||||
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
|
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
|
||||||
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
|
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
|
||||||
|
|
||||||
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool {
|
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
|
||||||
sendf("allowed_ip=%s/%d", ip.String(), cidr)
|
sendf("allowed_ip=%s/%d", ip.String(), cidr)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
@ -379,7 +379,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
ones, _ := network.Mask.Size()
|
ones, _ := network.Mask.Size()
|
||||||
device.allowedips.Insert(network.IP, uint(ones), peer.Peer)
|
device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
|
||||||
|
|
||||||
case "protocol_version":
|
case "protocol_version":
|
||||||
if value != "1" {
|
if value != "1" {
|
||||||
|
Loading…
Reference in New Issue
Block a user