1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00
wireguard-go/allowedips.go

262 lines
5.1 KiB
Go
Raw Normal View History

/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
package main
import (
"errors"
2018-05-14 15:49:20 +02:00
"math/bits"
"net"
2018-05-13 19:33:41 +02:00
"sync"
2018-05-14 15:49:20 +02:00
"unsafe"
)
2018-05-13 19:33:41 +02:00
type trieEntry struct {
cidr uint
2018-05-13 19:33:41 +02:00
child [2]*trieEntry
bits []byte
peer *Peer
// index of "branching" bit
bit_at_byte uint
bit_at_shift uint
}
// isLittleEndian reports whether the host stores the least significant byte
// of a multi-byte integer at the lowest address. It does so by reading the
// first in-memory byte of the 32-bit value 1.
func isLittleEndian() bool {
	var probe uint32 = 1
	lowestAddressed := *(*byte)(unsafe.Pointer(&probe))
	return lowestAddressed == 1
}
// swapU32 converts i between host byte order and big-endian byte order.
// On a big-endian host it returns i unchanged; on a little-endian host it
// reverses the four bytes. Applying it twice gives back the original value.
func swapU32(i uint32) uint32 {
	if isLittleEndian() {
		return bits.ReverseBytes32(i)
	}
	return i
}
// swapU64 converts i between host byte order and big-endian byte order.
// On a big-endian host it returns i unchanged; on a little-endian host it
// reverses the eight bytes. Applying it twice gives back the original value.
func swapU64(i uint64) uint64 {
	if isLittleEndian() {
		return bits.ReverseBytes64(i)
	}
	return i
}
// commonBits returns the number of leading bits that ip1 and ip2 share,
// counting from the most significant bit of the first byte. len(ip1) decides
// how many bytes are compared; it must be net.IPv4len (4) or net.IPv6len
// (16), otherwise the function panics. ip2 must be at least as long as ip1,
// otherwise the function panics with an index-out-of-range error.
//
// The bytes are compared one at a time rather than loaded through unsafe
// *uint32/*uint64 pointers. A net.IP slice carries no alignment guarantee, so
// such loads can fault on strict-alignment CPUs and are rejected by the
// -d=checkptr instrumentation. They also read past the end of a short ip2.
func commonBits(ip1 net.IP, ip2 net.IP) uint {
	size := len(ip1)
	if size != net.IPv4len && size != net.IPv6len {
		panic("Wrong size bit string")
	}
	// Fail loudly (instead of reading out of bounds) when ip2 is too short.
	_ = ip2[size-1]
	for i, b := range ip1 {
		if diff := b ^ ip2[i]; diff != 0 {
			// Full matching bytes, plus the matching high bits of this one.
			return uint(i*8 + bits.LeadingZeros8(diff))
		}
	}
	return uint(size * 8)
}
2018-05-13 19:33:41 +02:00
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
if node == nil {
return node
}
2017-12-01 23:37:26 +01:00
// walk recursively
2018-05-13 19:33:41 +02:00
node.child[0] = node.child[0].removeByPeer(p)
node.child[1] = node.child[1].removeByPeer(p)
if node.peer != p {
return node
}
// remove peer & merge
node.peer = nil
if node.child[0] == nil {
return node.child[1]
}
return node.child[0]
}
2018-05-13 19:33:41 +02:00
func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
2017-06-01 21:31:30 +02:00
}
2018-05-13 19:33:41 +02:00
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
2017-06-01 21:31:30 +02:00
// at leaf
2017-06-01 21:31:30 +02:00
if node == nil {
2018-05-13 19:33:41 +02:00
return &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
}
// traverse deeper
common := commonBits(node.bits, ip)
if node.cidr <= cidr && common >= node.cidr {
if node.cidr == cidr {
node.peer = peer
return node
}
bit := node.choose(ip)
2018-05-13 19:33:41 +02:00
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
}
// split node
2018-05-13 19:33:41 +02:00
newNode := &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
cidr = min(cidr, common)
// check for shorter prefix
2017-06-01 21:31:30 +02:00
if newNode.cidr == cidr {
bit := newNode.choose(node.bits)
newNode.child[bit] = node
return newNode
}
// create new parent for node & newNode
2018-05-13 19:33:41 +02:00
parent := &trieEntry{
bits: ip,
2017-06-01 21:31:30 +02:00
peer: nil,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
bit := parent.choose(ip)
2017-06-01 21:31:30 +02:00
parent.child[bit] = newNode
parent.child[bit^1] = node
return parent
}
2018-05-13 19:33:41 +02:00
func (node *trieEntry) lookup(ip net.IP) *Peer {
2017-06-01 21:31:30 +02:00
var found *Peer
size := uint(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
2017-06-01 21:31:30 +02:00
if node.peer != nil {
found = node.peer
}
if node.bit_at_byte == size {
break
}
bit := node.choose(ip)
2017-06-01 21:31:30 +02:00
node = node.child[bit]
}
return found
}
2018-05-13 19:33:41 +02:00
func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
if node == nil {
return results
}
if node.peer == p {
var mask net.IPNet
mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
if len(node.bits) == net.IPv4len {
mask.IP = net.IPv4(
node.bits[0],
node.bits[1],
node.bits[2],
node.bits[3],
)
} else if len(node.bits) == net.IPv6len {
mask.IP = node.bits
} else {
2018-05-13 19:33:41 +02:00
panic(errors.New("unexpected address length"))
}
results = append(results, mask)
}
2018-05-13 19:33:41 +02:00
results = node.child[0].entriesForPeer(p, results)
results = node.child[1].entriesForPeer(p, results)
return results
}
2018-05-13 19:33:41 +02:00
type AllowedIPs struct {
IPv4 *trieEntry
IPv6 *trieEntry
mutex sync.RWMutex
}
// EntriesForPeer returns every prefix currently assigned to peer: all IPv4
// entries first, then all IPv6 entries. A new slice is allocated on every
// call; it is empty (not nil) when the peer owns nothing.
func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
	table.mutex.RLock()
	defer table.mutex.RUnlock()

	const initialCapacity = 10
	entries := make([]net.IPNet, 0, initialCapacity)
	for _, root := range [...]*trieEntry{table.IPv4, table.IPv6} {
		entries = root.entriesForPeer(peer, entries)
	}
	return entries
}
// Reset drops every entry from both tries, leaving the table empty. It holds
// the write lock while doing so.
func (table *AllowedIPs) Reset() {
	table.mutex.Lock()
	defer table.mutex.Unlock()
	table.IPv4, table.IPv6 = nil, nil
}
// RemoveByPeer deletes every prefix owned by peer from both the IPv4 and the
// IPv6 trie, holding the write lock throughout.
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()
	for _, root := range []**trieEntry{&table.IPv4, &table.IPv6} {
		*root = (*root).removeByPeer(peer)
	}
}
// Insert assigns the prefix ip/cidr to peer, replacing any previous owner of
// exactly that prefix. The trie is selected by len(ip): 4 bytes goes to the
// IPv4 trie, 16 bytes to the IPv6 trie, and any other length panics (the
// write lock is still released via defer). An IPv4 address in 16-byte form
// therefore lands in the IPv6 trie, so callers should pass ip.To4() for IPv4.
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()

	var root **trieEntry
	switch len(ip) {
	case net.IPv4len:
		root = &table.IPv4
	case net.IPv6len:
		root = &table.IPv6
	default:
		panic(errors.New("inserting unknown address type"))
	}
	*root = (*root).insert(ip, cidr, peer)
}
// LookupIPv4 returns the peer owning the most specific IPv4 prefix that
// contains address, or nil if none does. address should be 4 bytes long,
// matching the keys Insert places in the IPv4 trie. It holds the read lock,
// so it can run alongside other lookups but not alongside modifications.
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
	lock := &table.mutex
	lock.RLock()
	defer lock.RUnlock()
	root := table.IPv4
	return root.lookup(address)
}
// LookupIPv6 returns the peer owning the most specific IPv6 prefix that
// contains address, or nil if none does. address should be 16 bytes long,
// matching the keys Insert places in the IPv6 trie. It holds the read lock,
// so it can run alongside other lookups but not alongside modifications.
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
	lock := &table.mutex
	lock.RLock()
	defer lock.RUnlock()
	root := table.IPv6
	return root.lookup(address)
}