diff --git a/device/allowedips.go b/device/allowedips.go index 143bda3..b5e40e9 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -14,15 +14,14 @@ import ( ) type trieEntry struct { - cidr uint - child [2]*trieEntry - bits net.IP - peer *Peer - - // index of "branching" bit - - bit_at_byte uint - bit_at_shift uint + child [2]*trieEntry + peer *Peer + bits net.IP + cidr uint + bit_at_byte uint + bit_at_shift uint + nextEntryForPeer *trieEntry + pprevEntryForPeer **trieEntry } func isLittleEndian() bool { @@ -69,6 +68,31 @@ func commonBits(ip1 net.IP, ip2 net.IP) uint { } } +func (node *trieEntry) addToPeerEntries() { + p := node.peer + first := p.firstTrieEntry + node.nextEntryForPeer = first + if first != nil { + first.pprevEntryForPeer = &node.nextEntryForPeer + } + p.firstTrieEntry = node + node.pprevEntryForPeer = &p.firstTrieEntry +} + +func (node *trieEntry) removeFromPeerEntries() { + if node.pprevEntryForPeer == nil { + return + } + next := node.nextEntryForPeer + pprev := node.pprevEntryForPeer + *pprev = next + if next != nil { + next.pprevEntryForPeer = pprev + } + node.nextEntryForPeer = nil + node.pprevEntryForPeer = nil +} + func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { if node == nil { return node @@ -85,6 +109,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { // remove peer & merge + node.removeFromPeerEntries() node.peer = nil if node.child[0] == nil { return node.child[1] @@ -96,18 +121,28 @@ func (node *trieEntry) choose(ip net.IP) byte { return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 } +func (node *trieEntry) maskSelf() { + mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) + for i := 0; i < len(mask); i++ { + node.bits[i] &= mask[i] + } +} + func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { // at leaf if node == nil { - return &trieEntry{ + node := &trieEntry{ bits: ip, peer: peer, cidr: cidr, bit_at_byte: cidr / 8, bit_at_shift: 7 - (cidr % 8), } + node.maskSelf() + node.addToPeerEntries() + return node } // traverse deeper @@ -115,7 +150,9 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { common := commonBits(node.bits, ip) if node.cidr <= cidr && common >= node.cidr { if node.cidr == cidr { + node.removeFromPeerEntries() node.peer = peer + node.addToPeerEntries() return node } bit := node.choose(ip) @@ -132,6 +169,8 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { bit_at_byte: cidr / 8, bit_at_shift: 7 - (cidr % 8), } + newNode.maskSelf() + newNode.addToPeerEntries() cidr = min(cidr, common) @@ -146,12 +185,13 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { // create new parent for node & newNode parent := &trieEntry{ - bits: ip, + bits: append([]byte{}, ip...), peer: nil, cidr: cidr, bit_at_byte: cidr / 8, bit_at_shift: 7 - (cidr % 8), } + parent.maskSelf() bit := parent.choose(ip) parent.child[bit] = newNode @@ -176,44 +216,21 @@ func (node *trieEntry) lookup(ip net.IP) *Peer { return found } -func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { - if node == nil { - return results - } - if node.peer == p { - mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) - results = append(results, net.IPNet{ - Mask: mask, - IP: node.bits.Mask(mask), - }) - } - results = node.child[0].entriesForPeer(p, results) - results = node.child[1].entriesForPeer(p, results) - return results -} - type AllowedIPs struct { IPv4 *trieEntry IPv6 *trieEntry mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) { table.mutex.RLock() defer table.mutex.RUnlock() - allowed := make([]net.IPNet, 0, 10) - allowed = table.IPv4.entriesForPeer(peer, allowed) - allowed = table.IPv6.entriesForPeer(peer, allowed) - return allowed -} - -func (table *AllowedIPs) Reset() { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = nil - table.IPv6 = nil + for node := peer.firstTrieEntry; node != nil; node = node.nextEntryForPeer { + if !cb(node.bits, node.cidr) { + return + } + } } func (table *AllowedIPs) RemoveByPeer(peer *Peer) { diff --git a/device/device.go b/device/device.go index ebcbd9e..47c4944 100644 --- a/device/device.go +++ b/device/device.go @@ -314,7 +314,6 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { device.rate.underLoadUntil.Store(time.Time{}) device.indexTable.Init() - device.allowedips.Reset() device.PopulatePools() diff --git a/device/peer.go b/device/peer.go index 5324ae4..a103b5d 100644 --- a/device/peer.go +++ b/device/peer.go @@ -28,6 +28,7 @@ type Peer struct { device *Device endpoint conn.Endpoint persistentKeepaliveInterval uint32 // accessed atomically + firstTrieEntry *trieEntry // These fields are accessed with atomic operations, which must be // 64-bit aligned even on 32-bit platforms. Go guarantees that an diff --git a/device/uapi.go b/device/uapi.go index 148a7a2..cbfe25e 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -108,9 +108,10 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) - for _, ip := range device.allowedips.EntriesForPeer(peer) { - sendf("allowed_ip=%s", ip.String()) - } + device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool { + sendf("allowed_ip=%s/%d", ip.String(), cidr) + return true + }) } }()