
device: use linked list for per-peer allowed-ip traversal

This makes the IpcGet method much faster.

We also refactor the traversal API to use a callback so that we don't
need to allocate at all. To keep that traversal allocation-free, we
self-mask entries on insertion, which in turn means that split
intermediate nodes require a copy of the bits.
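
For illustration, the shape of the new traversal is roughly this (a
simplified, self-contained sketch; the type and field names are
stand-ins, not the actual wireguard-go code):

package main

import "fmt"

// entry stands in for trieEntry: each node carries an intrusive pointer
// threading together all entries that belong to one peer.
type entry struct {
    bits []byte
    cidr uint
    next *entry
}

type peer struct{ first *entry }

// entriesForPeer walks only the peer's own list and hands each entry to
// the callback; returning false stops the walk early. No intermediate
// slice is ever allocated.
func entriesForPeer(p *peer, cb func(bits []byte, cidr uint) bool) {
    for e := p.first; e != nil; e = e.next {
        if !cb(e.bits, e.cidr) {
            return
        }
    }
}

func main() {
    p := &peer{}
    for _, e := range []*entry{
        {bits: []byte{10, 0, 0, 0}, cidr: 8},
        {bits: []byte{192, 168, 4, 0}, cidr: 24},
    } {
        e.next = p.first // push onto the head of the intrusive list
        p.first = e
    }
    entriesForPeer(p, func(bits []byte, cidr uint) bool {
        fmt.Printf("allowed_ip=%d.%d.%d.%d/%d\n", bits[0], bits[1], bits[2], bits[3], cidr)
        return true // keep iterating
    })
}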

benchmark               old ns/op     new ns/op     delta
BenchmarkUAPIGet-16     3243          2659          -18.01%

benchmark               old allocs     new allocs     delta
BenchmarkUAPIGet-16     35             30             -14.29%

benchmark               old bytes     new bytes     delta
BenchmarkUAPIGet-16     1218          737           -39.49%

This benchmark is good, though it's only for a pair of peers, each with
only one allowed IP. As the number of peers and allowed IPs grows, the
delta widens considerably: the old code walked the entire trie for every
peer, while the linked list visits only the entries belonging to the
peer being queried.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Author: Jason A. Donenfeld
Date:   2021-01-26 23:44:37 +01:00
parent d669c78c43
commit 8cc99631d0
4 changed files with 62 additions and 44 deletions

device/allowedips.go

@@ -14,15 +14,14 @@ import (
 )
 
 type trieEntry struct {
-    cidr  uint
     child [2]*trieEntry
-    bits  net.IP
     peer  *Peer
-
-    // index of "branching" bit
+    bits  net.IP
+    cidr  uint
     bit_at_byte  uint
     bit_at_shift uint
+    nextEntryForPeer  *trieEntry
+    pprevEntryForPeer **trieEntry
 }
 
 func isLittleEndian() bool {
@@ -69,6 +68,31 @@ func commonBits(ip1 net.IP, ip2 net.IP) uint {
     }
 }
 
+func (node *trieEntry) addToPeerEntries() {
+    p := node.peer
+    first := p.firstTrieEntry
+    node.nextEntryForPeer = first
+    if first != nil {
+        first.pprevEntryForPeer = &node.nextEntryForPeer
+    }
+    p.firstTrieEntry = node
+    node.pprevEntryForPeer = &p.firstTrieEntry
+}
+
+func (node *trieEntry) removeFromPeerEntries() {
+    if node.pprevEntryForPeer == nil {
+        return
+    }
+    next := node.nextEntryForPeer
+    pprev := node.pprevEntryForPeer
+    *pprev = next
+    if next != nil {
+        next.pprevEntryForPeer = pprev
+    }
+    node.nextEntryForPeer = nil
+    node.pprevEntryForPeer = nil
+}
+
 func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
     if node == nil {
         return node
@@ -85,6 +109,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
     // remove peer & merge
+    node.removeFromPeerEntries()
     node.peer = nil
     if node.child[0] == nil {
         return node.child[1]
@@ -96,18 +121,28 @@ func (node *trieEntry) choose(ip net.IP) byte {
     return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
 }
 
+func (node *trieEntry) maskSelf() {
+    mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
+    for i := 0; i < len(mask); i++ {
+        node.bits[i] &= mask[i]
+    }
+}
+
 func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
     // at leaf
     if node == nil {
-        return &trieEntry{
+        node := &trieEntry{
             bits:         ip,
             peer:         peer,
             cidr:         cidr,
             bit_at_byte:  cidr / 8,
             bit_at_shift: 7 - (cidr % 8),
         }
+        node.maskSelf()
+        node.addToPeerEntries()
+        return node
     }
 
     // traverse deeper
@@ -115,7 +150,9 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
     common := commonBits(node.bits, ip)
     if node.cidr <= cidr && common >= node.cidr {
         if node.cidr == cidr {
+            node.removeFromPeerEntries()
             node.peer = peer
+            node.addToPeerEntries()
             return node
         }
         bit := node.choose(ip)
@@ -132,6 +169,8 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
         bit_at_byte:  cidr / 8,
         bit_at_shift: 7 - (cidr % 8),
     }
+    newNode.maskSelf()
+    newNode.addToPeerEntries()
 
     cidr = min(cidr, common)
@@ -146,12 +185,13 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
     // create new parent for node & newNode
     parent := &trieEntry{
-        bits:         ip,
+        bits:         append([]byte{}, ip...),
         peer:         nil,
         cidr:         cidr,
         bit_at_byte:  cidr / 8,
         bit_at_shift: 7 - (cidr % 8),
     }
+    parent.maskSelf()
 
     bit := parent.choose(ip)
     parent.child[bit] = newNode
@@ -176,44 +216,21 @@ func (node *trieEntry) lookup(ip net.IP) *Peer {
     return found
 }
 
-func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
-    if node == nil {
-        return results
-    }
-    if node.peer == p {
-        mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
-        results = append(results, net.IPNet{
-            Mask: mask,
-            IP:   node.bits.Mask(mask),
-        })
-    }
-    results = node.child[0].entriesForPeer(p, results)
-    results = node.child[1].entriesForPeer(p, results)
-    return results
-}
-
 type AllowedIPs struct {
     IPv4  *trieEntry
     IPv6  *trieEntry
     mutex sync.RWMutex
 }
 
-func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) {
     table.mutex.RLock()
     defer table.mutex.RUnlock()
 
-    allowed := make([]net.IPNet, 0, 10)
-    allowed = table.IPv4.entriesForPeer(peer, allowed)
-    allowed = table.IPv6.entriesForPeer(peer, allowed)
-    return allowed
-}
-
-func (table *AllowedIPs) Reset() {
-    table.mutex.Lock()
-    defer table.mutex.Unlock()
-    table.IPv4 = nil
-    table.IPv6 = nil
+    for node := peer.firstTrieEntry; node != nil; node = node.nextEntryForPeer {
+        if !cb(node.bits, node.cidr) {
+            return
+        }
+    }
 }
 
 func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
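
The pprevEntryForPeer field holds the address of whichever pointer
currently points at the node: either the peer's firstTrieEntry head or
the previous node's nextEntryForPeer. That makes unlinking O(1) with no
special case for the head node, the same trick as the Linux kernel's
hlist. A standalone sketch of the pattern, under illustrative names
rather than the wireguard-go types:

package main

import "fmt"

type node struct {
    val   int
    next  *node
    pprev **node // address of the pointer that points at this node
}

// push links n at the head of the list rooted at *head.
func push(head **node, n *node) {
    n.next = *head
    if *head != nil {
        (*head).pprev = &n.next
    }
    *head = n
    n.pprev = head
}

// remove unlinks n in O(1); head and interior nodes take the same path
// because pprev always addresses the incoming pointer.
func remove(n *node) {
    if n.pprev == nil {
        return // not linked
    }
    *n.pprev = n.next
    if n.next != nil {
        n.next.pprev = n.pprev
    }
    n.next, n.pprev = nil, nil
}

func main() {
    var head *node
    a, b, c := &node{val: 1}, &node{val: 2}, &node{val: 3}
    push(&head, a)
    push(&head, b)
    push(&head, c)
    remove(b) // interior removal: no head special case needed
    for n := head; n != nil; n = n.next {
        fmt.Println(n.val) // prints 3, then 1
    }
}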

device/device.go

@@ -314,7 +314,6 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
     device.rate.underLoadUntil.Store(time.Time{})
 
     device.indexTable.Init()
-    device.allowedips.Reset()
     device.PopulatePools()

device/peer.go

@@ -28,6 +28,7 @@ type Peer struct {
     device                      *Device
     endpoint                    conn.Endpoint
     persistentKeepaliveInterval uint32 // accessed atomically
+    firstTrieEntry              *trieEntry
 
     // These fields are accessed with atomic operations, which must be
     // 64-bit aligned even on 32-bit platforms. Go guarantees that an

device/uapi.go

@@ -108,9 +108,10 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
             sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
             sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
 
-            for _, ip := range device.allowedips.EntriesForPeer(peer) {
-                sendf("allowed_ip=%s", ip.String())
-            }
+            device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool {
+                sendf("allowed_ip=%s/%d", ip.String(), cidr)
+                return true
+            })
         }
     }()
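
Self-masking on insertion is what makes the direct "%s/%d" emission
above safe: host bits are zeroed once, when the entry is stored, so
node.bits can be printed as-is without constructing a net.IPNet on
every read. A minimal standalone sketch of that masking step, mirroring
the maskSelf logic in the diff:

package main

import (
    "fmt"
    "net"
)

// maskSelf zeroes the host bits of ip in place. Once entries are stored
// pre-masked, they can be emitted directly as "ip/cidr" with no
// per-read masking or allocation.
func maskSelf(ip net.IP, cidr uint) {
    mask := net.CIDRMask(int(cidr), len(ip)*8)
    for i := range ip {
        ip[i] &= mask[i]
    }
}

func main() {
    ip := net.IP{192, 168, 4, 77}
    maskSelf(ip, 24)
    fmt.Printf("allowed_ip=%s/%d\n", ip, 24) // allowed_ip=192.168.4.0/24
}

Masking in place is also why the split-parent case above copies ip with
append([]byte{}, ip...): the parent and the inserted leaf would
otherwise share a backing array, and masking the parent to its shorter
prefix would clobber the leaf's stored bits.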