1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2025-09-18 20:57:50 +02:00
wireguard-go/device/allowedips.go
Jason A. Donenfeld f3aff443a6 device: make allowedips generic
The implementation of commonBits uses a horrific unsafe.Slice trick.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-16 19:45:10 -06:00

321 lines
7.3 KiB
Go

/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"container/list"
"encoding/binary"
"errors"
"math/bits"
"net"
"net/netip"
"sync"
"unsafe"
)
type ipArray interface {
[4]byte | [16]byte
}
type parentIndirection[B ipArray] struct {
parentBit **trieEntry[B]
parentBitType uint8
}
type trieEntry[B ipArray] struct {
peer *Peer
child [2]*trieEntry[B]
parent parentIndirection[B]
cidr uint8
bitAtByte uint8
bitAtShift uint8
bits B
perPeerElem *list.Element
}
func commonBits4(ip1, ip2 [4]byte) uint8 {
a := binary.BigEndian.Uint32(ip1[:])
b := binary.BigEndian.Uint32(ip2[:])
x := a ^ b
return uint8(bits.LeadingZeros32(x))
}
func commonBits16(ip1, ip2 [16]byte) uint8 {
a := binary.BigEndian.Uint64(ip1[:8])
b := binary.BigEndian.Uint64(ip2[:8])
x := a ^ b
if x != 0 {
return uint8(bits.LeadingZeros64(x))
}
a = binary.BigEndian.Uint64(ip1[8:])
b = binary.BigEndian.Uint64(ip2[8:])
x = a ^ b
return 64 + uint8(bits.LeadingZeros64(x))
}
func giveMeA4[B ipArray](b B) [4]byte {
return *(*[4]byte)(unsafe.Slice(&b[0], 4))
}
func giveMeA16[B ipArray](b B) [16]byte {
return *(*[16]byte)(unsafe.Slice(&b[0], 16))
}
func commonBits[B ipArray](ip1, ip2 B) uint8 {
if len(ip1) == 4 {
return commonBits4(giveMeA4(ip1), giveMeA4(ip2))
} else if len(ip1) == 16 {
return commonBits16(giveMeA16(ip1), giveMeA16(ip2))
}
panic("Wrong size bit string")
}
func (node *trieEntry[B]) addToPeerEntries() {
node.perPeerElem = node.peer.trieEntries.PushBack(node)
}
func (node *trieEntry[B]) removeFromPeerEntries() {
if node.perPeerElem != nil {
node.peer.trieEntries.Remove(node.perPeerElem)
node.perPeerElem = nil
}
}
func (node *trieEntry[B]) choose(ip B) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
func (node *trieEntry[B]) maskSelf() {
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
for i := 0; i < len(mask); i++ {
node.bits[i] &= mask[i]
}
}
func (node *trieEntry[B]) zeroizePointers() {
// Make the garbage collector's life slightly easier
node.peer = nil
node.child[0] = nil
node.child[1] = nil
node.parent.parentBit = nil
}
func (node *trieEntry[B]) nodePlacement(ip B, cidr uint8) (parent *trieEntry[B], exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
exact = true
return
}
bit := node.choose(ip)
node = node.child[bit]
}
return
}
func (trie parentIndirection[B]) insert(ip B, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry[B]{
peer: peer,
parent: trie,
bits: ip,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
node.maskSelf()
node.addToPeerEntries()
*trie.parentBit = node
return
}
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
if exact {
node.removeFromPeerEntries()
node.peer = peer
node.addToPeerEntries()
return
}
newNode := &trieEntry[B]{
peer: peer,
bits: ip,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
newNode.maskSelf()
newNode.addToPeerEntries()
var down *trieEntry[B]
if node == nil {
down = *trie.parentBit
} else {
bit := node.choose(ip)
down = node.child[bit]
if down == nil {
newNode.parent = parentIndirection[B]{&node.child[bit], bit}
node.child[bit] = newNode
return
}
}
common := commonBits(down.bits, ip)
if common < cidr {
cidr = common
}
parent := node
if newNode.cidr == cidr {
bit := newNode.choose(down.bits)
down.parent = parentIndirection[B]{&newNode.child[bit], bit}
newNode.child[bit] = down
if parent == nil {
newNode.parent = trie
*trie.parentBit = newNode
} else {
bit := parent.choose(newNode.bits)
newNode.parent = parentIndirection[B]{&parent.child[bit], bit}
parent.child[bit] = newNode
}
return
}
node = &trieEntry[B]{
bits: newNode.bits,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
node.maskSelf()
bit := node.choose(down.bits)
down.parent = parentIndirection[B]{&node.child[bit], bit}
node.child[bit] = down
bit = node.choose(newNode.bits)
newNode.parent = parentIndirection[B]{&node.child[bit], bit}
node.child[bit] = newNode
if parent == nil {
node.parent = trie
*trie.parentBit = node
} else {
bit := parent.choose(node.bits)
node.parent = parentIndirection[B]{&parent.child[bit], bit}
parent.child[bit] = node
}
}
func (node *trieEntry[B]) lookup(ip B) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil {
found = node.peer
}
if node.bitAtByte == size {
break
}
bit := node.choose(ip)
node = node.child[bit]
}
return found
}
type AllowedIPs struct {
IPv4 *trieEntry[[4]byte]
IPv6 *trieEntry[[16]byte]
mutex sync.RWMutex
}
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
if node, ok := elem.Value.(*trieEntry[[4]byte]); ok {
if !cb(netip.PrefixFrom(netip.AddrFrom4(node.bits), int(node.cidr))) {
return
}
} else if node, ok := elem.Value.(*trieEntry[[16]byte]); ok {
if !cb(netip.PrefixFrom(netip.AddrFrom16(node.bits), int(node.cidr))) {
return
}
}
}
}
func (node *trieEntry[B]) remove() {
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
return
}
bit := 0
if node.child[0] == nil {
bit = 1
}
child := node.child[bit]
if child != nil {
child.parent = node.parent
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers()
return
}
parent := (*trieEntry[B])(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
node.zeroizePointers()
return
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
child.parent = parent.parent
}
*parent.parent.parentBit = child
node.zeroizePointers()
parent.zeroizePointers()
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
if node, ok := elem.Value.(*trieEntry[[4]byte]); ok {
node.remove()
} else if node, ok := elem.Value.(*trieEntry[[16]byte]); ok {
node.remove()
}
}
}
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
if prefix.Addr().Is6() {
parentIndirection[[16]byte]{&table.IPv6, 2}.insert(prefix.Addr().As16(), uint8(prefix.Bits()), peer)
} else if prefix.Addr().Is4() {
parentIndirection[[4]byte]{&table.IPv4, 2}.insert(prefix.Addr().As4(), uint8(prefix.Bits()), peer)
} else {
panic(errors.New("inserting unknown address type"))
}
}
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
switch len(ip) {
case net.IPv6len:
return table.IPv6.lookup(*(*[16]byte)(ip))
case net.IPv4len:
return table.IPv4.lookup(*(*[4]byte)(ip))
default:
panic(errors.New("looking up unknown address type"))
}
}