mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 09:15:14 +01:00
device: use linked list for per-peer allowed-ip traversal
This makes the IpcGet method much faster. We also refactor the traversal API to use a callback so that we don't need to allocate at all. Avoiding allocations we do self-masking on insertion, which in turn means that split intermediate nodes require a copy of the bits. benchmark old ns/op new ns/op delta BenchmarkUAPIGet-16 3243 2659 -18.01% benchmark old allocs new allocs delta BenchmarkUAPIGet-16 35 30 -14.29% benchmark old bytes new bytes delta BenchmarkUAPIGet-16 1218 737 -39.49% This benchmark is good, though it's only for a pair of peers, each with only one allowedips. As this grows, the delta expands considerably. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
d669c78c43
commit
8cc99631d0
@ -14,15 +14,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type trieEntry struct {
|
type trieEntry struct {
|
||||||
cidr uint
|
|
||||||
child [2]*trieEntry
|
child [2]*trieEntry
|
||||||
bits net.IP
|
|
||||||
peer *Peer
|
peer *Peer
|
||||||
|
bits net.IP
|
||||||
// index of "branching" bit
|
cidr uint
|
||||||
|
|
||||||
bit_at_byte uint
|
bit_at_byte uint
|
||||||
bit_at_shift uint
|
bit_at_shift uint
|
||||||
|
nextEntryForPeer *trieEntry
|
||||||
|
pprevEntryForPeer **trieEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
func isLittleEndian() bool {
|
func isLittleEndian() bool {
|
||||||
@ -69,6 +68,31 @@ func commonBits(ip1 net.IP, ip2 net.IP) uint {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (node *trieEntry) addToPeerEntries() {
|
||||||
|
p := node.peer
|
||||||
|
first := p.firstTrieEntry
|
||||||
|
node.nextEntryForPeer = first
|
||||||
|
if first != nil {
|
||||||
|
first.pprevEntryForPeer = &node.nextEntryForPeer
|
||||||
|
}
|
||||||
|
p.firstTrieEntry = node
|
||||||
|
node.pprevEntryForPeer = &p.firstTrieEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *trieEntry) removeFromPeerEntries() {
|
||||||
|
if node.pprevEntryForPeer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next := node.nextEntryForPeer
|
||||||
|
pprev := node.pprevEntryForPeer
|
||||||
|
*pprev = next
|
||||||
|
if next != nil {
|
||||||
|
next.pprevEntryForPeer = pprev
|
||||||
|
}
|
||||||
|
node.nextEntryForPeer = nil
|
||||||
|
node.pprevEntryForPeer = nil
|
||||||
|
}
|
||||||
|
|
||||||
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return node
|
return node
|
||||||
@ -85,6 +109,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
|||||||
|
|
||||||
// remove peer & merge
|
// remove peer & merge
|
||||||
|
|
||||||
|
node.removeFromPeerEntries()
|
||||||
node.peer = nil
|
node.peer = nil
|
||||||
if node.child[0] == nil {
|
if node.child[0] == nil {
|
||||||
return node.child[1]
|
return node.child[1]
|
||||||
@ -96,18 +121,28 @@ func (node *trieEntry) choose(ip net.IP) byte {
|
|||||||
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
|
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (node *trieEntry) maskSelf() {
|
||||||
|
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
|
||||||
|
for i := 0; i < len(mask); i++ {
|
||||||
|
node.bits[i] &= mask[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
||||||
|
|
||||||
// at leaf
|
// at leaf
|
||||||
|
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return &trieEntry{
|
node := &trieEntry{
|
||||||
bits: ip,
|
bits: ip,
|
||||||
peer: peer,
|
peer: peer,
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
bit_at_byte: cidr / 8,
|
bit_at_byte: cidr / 8,
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
bit_at_shift: 7 - (cidr % 8),
|
||||||
}
|
}
|
||||||
|
node.maskSelf()
|
||||||
|
node.addToPeerEntries()
|
||||||
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
// traverse deeper
|
// traverse deeper
|
||||||
@ -115,7 +150,9 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
|||||||
common := commonBits(node.bits, ip)
|
common := commonBits(node.bits, ip)
|
||||||
if node.cidr <= cidr && common >= node.cidr {
|
if node.cidr <= cidr && common >= node.cidr {
|
||||||
if node.cidr == cidr {
|
if node.cidr == cidr {
|
||||||
|
node.removeFromPeerEntries()
|
||||||
node.peer = peer
|
node.peer = peer
|
||||||
|
node.addToPeerEntries()
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
bit := node.choose(ip)
|
bit := node.choose(ip)
|
||||||
@ -132,6 +169,8 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
|||||||
bit_at_byte: cidr / 8,
|
bit_at_byte: cidr / 8,
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
bit_at_shift: 7 - (cidr % 8),
|
||||||
}
|
}
|
||||||
|
newNode.maskSelf()
|
||||||
|
newNode.addToPeerEntries()
|
||||||
|
|
||||||
cidr = min(cidr, common)
|
cidr = min(cidr, common)
|
||||||
|
|
||||||
@ -146,12 +185,13 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
|||||||
// create new parent for node & newNode
|
// create new parent for node & newNode
|
||||||
|
|
||||||
parent := &trieEntry{
|
parent := &trieEntry{
|
||||||
bits: ip,
|
bits: append([]byte{}, ip...),
|
||||||
peer: nil,
|
peer: nil,
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
bit_at_byte: cidr / 8,
|
bit_at_byte: cidr / 8,
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
bit_at_shift: 7 - (cidr % 8),
|
||||||
}
|
}
|
||||||
|
parent.maskSelf()
|
||||||
|
|
||||||
bit := parent.choose(ip)
|
bit := parent.choose(ip)
|
||||||
parent.child[bit] = newNode
|
parent.child[bit] = newNode
|
||||||
@ -176,44 +216,21 @@ func (node *trieEntry) lookup(ip net.IP) *Peer {
|
|||||||
return found
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
|
|
||||||
if node == nil {
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
if node.peer == p {
|
|
||||||
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
|
|
||||||
results = append(results, net.IPNet{
|
|
||||||
Mask: mask,
|
|
||||||
IP: node.bits.Mask(mask),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
results = node.child[0].entriesForPeer(p, results)
|
|
||||||
results = node.child[1].entriesForPeer(p, results)
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
|
|
||||||
type AllowedIPs struct {
|
type AllowedIPs struct {
|
||||||
IPv4 *trieEntry
|
IPv4 *trieEntry
|
||||||
IPv6 *trieEntry
|
IPv6 *trieEntry
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
|
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) {
|
||||||
table.mutex.RLock()
|
table.mutex.RLock()
|
||||||
defer table.mutex.RUnlock()
|
defer table.mutex.RUnlock()
|
||||||
|
|
||||||
allowed := make([]net.IPNet, 0, 10)
|
for node := peer.firstTrieEntry; node != nil; node = node.nextEntryForPeer {
|
||||||
allowed = table.IPv4.entriesForPeer(peer, allowed)
|
if !cb(node.bits, node.cidr) {
|
||||||
allowed = table.IPv6.entriesForPeer(peer, allowed)
|
return
|
||||||
return allowed
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) Reset() {
|
|
||||||
table.mutex.Lock()
|
|
||||||
defer table.mutex.Unlock()
|
|
||||||
|
|
||||||
table.IPv4 = nil
|
|
||||||
table.IPv6 = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||||
|
@ -314,7 +314,6 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
|||||||
device.rate.underLoadUntil.Store(time.Time{})
|
device.rate.underLoadUntil.Store(time.Time{})
|
||||||
|
|
||||||
device.indexTable.Init()
|
device.indexTable.Init()
|
||||||
device.allowedips.Reset()
|
|
||||||
|
|
||||||
device.PopulatePools()
|
device.PopulatePools()
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ type Peer struct {
|
|||||||
device *Device
|
device *Device
|
||||||
endpoint conn.Endpoint
|
endpoint conn.Endpoint
|
||||||
persistentKeepaliveInterval uint32 // accessed atomically
|
persistentKeepaliveInterval uint32 // accessed atomically
|
||||||
|
firstTrieEntry *trieEntry
|
||||||
|
|
||||||
// These fields are accessed with atomic operations, which must be
|
// These fields are accessed with atomic operations, which must be
|
||||||
// 64-bit aligned even on 32-bit platforms. Go guarantees that an
|
// 64-bit aligned even on 32-bit platforms. Go guarantees that an
|
||||||
|
@ -108,9 +108,10 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||||||
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
|
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
|
||||||
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
|
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
|
||||||
|
|
||||||
for _, ip := range device.allowedips.EntriesForPeer(peer) {
|
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool {
|
||||||
sendf("allowed_ip=%s", ip.String())
|
sendf("allowed_ip=%s/%d", ip.String(), cidr)
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user