mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
device: use linked list for per-peer allowed-ip traversal
This makes the IpcGet method much faster. We also refactor the traversal API to use a callback so that we don't need to allocate at all. Avoiding allocations we do self-masking on insertion, which in turn means that split intermediate nodes require a copy of the bits. benchmark old ns/op new ns/op delta BenchmarkUAPIGet-16 3243 2659 -18.01% benchmark old allocs new allocs delta BenchmarkUAPIGet-16 35 30 -14.29% benchmark old bytes new bytes delta BenchmarkUAPIGet-16 1218 737 -39.49% This benchmark is good, though it's only for a pair of peers, each with only one allowedips. As this grows, the delta expands considerably. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
d669c78c43
commit
8cc99631d0
@ -14,15 +14,14 @@ import (
|
||||
)
|
||||
|
||||
type trieEntry struct {
|
||||
cidr uint
|
||||
child [2]*trieEntry
|
||||
bits net.IP
|
||||
peer *Peer
|
||||
|
||||
// index of "branching" bit
|
||||
|
||||
bit_at_byte uint
|
||||
bit_at_shift uint
|
||||
child [2]*trieEntry
|
||||
peer *Peer
|
||||
bits net.IP
|
||||
cidr uint
|
||||
bit_at_byte uint
|
||||
bit_at_shift uint
|
||||
nextEntryForPeer *trieEntry
|
||||
pprevEntryForPeer **trieEntry
|
||||
}
|
||||
|
||||
func isLittleEndian() bool {
|
||||
@ -69,6 +68,31 @@ func commonBits(ip1 net.IP, ip2 net.IP) uint {
|
||||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) addToPeerEntries() {
|
||||
p := node.peer
|
||||
first := p.firstTrieEntry
|
||||
node.nextEntryForPeer = first
|
||||
if first != nil {
|
||||
first.pprevEntryForPeer = &node.nextEntryForPeer
|
||||
}
|
||||
p.firstTrieEntry = node
|
||||
node.pprevEntryForPeer = &p.firstTrieEntry
|
||||
}
|
||||
|
||||
func (node *trieEntry) removeFromPeerEntries() {
|
||||
if node.pprevEntryForPeer == nil {
|
||||
return
|
||||
}
|
||||
next := node.nextEntryForPeer
|
||||
pprev := node.pprevEntryForPeer
|
||||
*pprev = next
|
||||
if next != nil {
|
||||
next.pprevEntryForPeer = pprev
|
||||
}
|
||||
node.nextEntryForPeer = nil
|
||||
node.pprevEntryForPeer = nil
|
||||
}
|
||||
|
||||
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
||||
if node == nil {
|
||||
return node
|
||||
@ -85,6 +109,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
||||
|
||||
// remove peer & merge
|
||||
|
||||
node.removeFromPeerEntries()
|
||||
node.peer = nil
|
||||
if node.child[0] == nil {
|
||||
return node.child[1]
|
||||
@ -96,18 +121,28 @@ func (node *trieEntry) choose(ip net.IP) byte {
|
||||
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
|
||||
}
|
||||
|
||||
func (node *trieEntry) maskSelf() {
|
||||
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
|
||||
for i := 0; i < len(mask); i++ {
|
||||
node.bits[i] &= mask[i]
|
||||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
||||
|
||||
// at leaf
|
||||
|
||||
if node == nil {
|
||||
return &trieEntry{
|
||||
node := &trieEntry{
|
||||
bits: ip,
|
||||
peer: peer,
|
||||
cidr: cidr,
|
||||
bit_at_byte: cidr / 8,
|
||||
bit_at_shift: 7 - (cidr % 8),
|
||||
}
|
||||
node.maskSelf()
|
||||
node.addToPeerEntries()
|
||||
return node
|
||||
}
|
||||
|
||||
// traverse deeper
|
||||
@ -115,7 +150,9 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
||||
common := commonBits(node.bits, ip)
|
||||
if node.cidr <= cidr && common >= node.cidr {
|
||||
if node.cidr == cidr {
|
||||
node.removeFromPeerEntries()
|
||||
node.peer = peer
|
||||
node.addToPeerEntries()
|
||||
return node
|
||||
}
|
||||
bit := node.choose(ip)
|
||||
@ -132,6 +169,8 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
||||
bit_at_byte: cidr / 8,
|
||||
bit_at_shift: 7 - (cidr % 8),
|
||||
}
|
||||
newNode.maskSelf()
|
||||
newNode.addToPeerEntries()
|
||||
|
||||
cidr = min(cidr, common)
|
||||
|
||||
@ -146,12 +185,13 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
||||
// create new parent for node & newNode
|
||||
|
||||
parent := &trieEntry{
|
||||
bits: ip,
|
||||
bits: append([]byte{}, ip...),
|
||||
peer: nil,
|
||||
cidr: cidr,
|
||||
bit_at_byte: cidr / 8,
|
||||
bit_at_shift: 7 - (cidr % 8),
|
||||
}
|
||||
parent.maskSelf()
|
||||
|
||||
bit := parent.choose(ip)
|
||||
parent.child[bit] = newNode
|
||||
@ -176,44 +216,21 @@ func (node *trieEntry) lookup(ip net.IP) *Peer {
|
||||
return found
|
||||
}
|
||||
|
||||
func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
|
||||
if node == nil {
|
||||
return results
|
||||
}
|
||||
if node.peer == p {
|
||||
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
|
||||
results = append(results, net.IPNet{
|
||||
Mask: mask,
|
||||
IP: node.bits.Mask(mask),
|
||||
})
|
||||
}
|
||||
results = node.child[0].entriesForPeer(p, results)
|
||||
results = node.child[1].entriesForPeer(p, results)
|
||||
return results
|
||||
}
|
||||
|
||||
type AllowedIPs struct {
|
||||
IPv4 *trieEntry
|
||||
IPv6 *trieEntry
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
|
||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) {
|
||||
table.mutex.RLock()
|
||||
defer table.mutex.RUnlock()
|
||||
|
||||
allowed := make([]net.IPNet, 0, 10)
|
||||
allowed = table.IPv4.entriesForPeer(peer, allowed)
|
||||
allowed = table.IPv6.entriesForPeer(peer, allowed)
|
||||
return allowed
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Reset() {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
table.IPv4 = nil
|
||||
table.IPv6 = nil
|
||||
for node := peer.firstTrieEntry; node != nil; node = node.nextEntryForPeer {
|
||||
if !cb(node.bits, node.cidr) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
|
@ -314,7 +314,6 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
||||
device.rate.underLoadUntil.Store(time.Time{})
|
||||
|
||||
device.indexTable.Init()
|
||||
device.allowedips.Reset()
|
||||
|
||||
device.PopulatePools()
|
||||
|
||||
|
@ -28,6 +28,7 @@ type Peer struct {
|
||||
device *Device
|
||||
endpoint conn.Endpoint
|
||||
persistentKeepaliveInterval uint32 // accessed atomically
|
||||
firstTrieEntry *trieEntry
|
||||
|
||||
// These fields are accessed with atomic operations, which must be
|
||||
// 64-bit aligned even on 32-bit platforms. Go guarantees that an
|
||||
|
@ -108,9 +108,10 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
|
||||
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
|
||||
|
||||
for _, ip := range device.allowedips.EntriesForPeer(peer) {
|
||||
sendf("allowed_ip=%s", ip.String())
|
||||
}
|
||||
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool {
|
||||
sendf("allowed_ip=%s/%d", ip.String(), cidr)
|
||||
return true
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user