1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00
wireguard-go/device/allowedips_rand_test.go

151 lines
3.1 KiB
Go
Raw Permalink Normal View History

2019-01-02 01:55:51 +01:00
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
2019-03-03 04:04:41 +01:00
package device
2017-06-02 17:58:46 +02:00
import (
"math/rand"
"net"
"net/netip"
2017-06-02 17:58:46 +02:00
"sort"
"testing"
)
// Parameters for TestTrieRandom, which cross-checks the AllowedIPs trie
// against the naive SlowRouter reference implementation.
const (
	NumberOfPeers        = 100   // distinct peers that random prefixes are assigned to
	NumberOfPeerRemovals = 4     // peers removed one at a time, each followed by another lookup round
	NumberOfAddresses    = 250   // random prefixes inserted per address family
	NumberOfTests        = 10000 // random lookups per address family in each round
)
type (
	// SlowNode is a single prefix entry of a SlowRouter: the address bits,
	// the prefix length in bits, and the peer that the prefix routes to.
	SlowNode struct {
		peer *Peer
		cidr uint8
		bits []byte
	}

	// SlowRouter is a deliberately naive routing table used as a reference
	// model for the AllowedIPs trie. Insert keeps its entries ordered from
	// the longest prefix to the shortest, so a linear scan in Lookup finds
	// the longest matching prefix first.
	SlowRouter []*SlowNode
)
2017-06-02 17:58:46 +02:00
// SlowRouter satisfies sort.Interface so that Insert can keep its entries
// ordered from the longest prefix to the shortest.
var _ sort.Interface = SlowRouter(nil)

// Len reports how many prefix entries the router holds.
func (r SlowRouter) Len() int { return len(r) }

// Less orders entries by descending prefix length: entry i sorts before
// entry j exactly when it has strictly more prefix bits.
func (r SlowRouter) Less(i, j int) bool {
	return r[j].cidr < r[i].cidr
}

// Swap exchanges entries i and j.
func (r SlowRouter) Swap(i, j int) {
	tmp := r[i]
	r[i] = r[j]
	r[j] = tmp
}
// commonBitsSlice returns the length, in bits, of the common leading prefix
// of addr1 and addr2, delegating to commonBits4 or commonBits16 by size.
//
// Both slices must hold an address of the same family: 4 bytes (IPv4) or
// 16 bytes (IPv6). Any other input, including slices of differing lengths,
// yields 0. Without the length check, a 4-byte addr1 paired with a 16-byte
// addr2 would silently compare only the first 4 bytes, and the reverse
// pairing would panic on the slice-to-array-pointer conversion.
func commonBitsSlice(addr1, addr2 []byte) uint8 {
	if len(addr1) != len(addr2) {
		return 0
	}
	switch len(addr1) {
	case 4:
		return commonBits4(*(*[4]byte)(addr1), *(*[4]byte)(addr2))
	case 16:
		return commonBits16(*(*[16]byte)(addr1), *(*[16]byte)(addr2))
	}
	return 0
}
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
2017-06-02 17:58:46 +02:00
for _, t := range r {
if t.cidr == cidr && commonBitsSlice(t.bits, addr) >= cidr {
2017-06-02 17:58:46 +02:00
t.peer = peer
t.bits = addr
return r
}
}
r = append(r, &SlowNode{
2017-06-02 17:58:46 +02:00
cidr: cidr,
bits: addr,
peer: peer,
})
sort.Sort(r)
return r
}
func (r SlowRouter) Lookup(addr []byte) *Peer {
for _, t := range r {
common := commonBitsSlice(t.bits, addr)
2017-06-02 17:58:46 +02:00
if common >= t.cidr {
return t.peer
}
}
return nil
}
func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
n := 0
for _, x := range r {
if x.peer != peer {
r[n] = x
n++
2017-06-02 17:58:46 +02:00
}
}
return r[:n]
2017-06-02 17:58:46 +02:00
}
func TestTrieRandom(t *testing.T) {
var slow4, slow6 SlowRouter
2017-06-02 17:58:46 +02:00
var peers []*Peer
var allowedIPs AllowedIPs
2017-06-02 17:58:46 +02:00
rand.Seed(1)
for n := 0; n < NumberOfPeers; n++ {
2017-06-02 17:58:46 +02:00
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n++ {
var addr4 [4]byte
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
2017-06-02 17:58:46 +02:00
}
var p int
for p = 0; ; p++ {
for n := 0; n < NumberOfTests; n++ {
var addr4 [4]byte
rand.Read(addr4[:])
peer1 := slow4.Lookup(addr4[:])
peer2 := allowedIPs.Lookup(addr4[:])
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
}
var addr6 [16]byte
rand.Read(addr6[:])
peer1 = slow6.Lookup(addr6[:])
peer2 = allowedIPs.Lookup(addr6[:])
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
}
}
if p >= len(peers) || p >= NumberOfPeerRemovals {
break
2017-06-02 17:58:46 +02:00
}
allowedIPs.RemoveByPeer(peers[p])
slow4 = slow4.RemoveByPeer(peers[p])
slow6 = slow6.RemoveByPeer(peers[p])
}
for ; p < len(peers); p++ {
allowedIPs.RemoveByPeer(peers[p])
}
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Failed to remove all nodes from trie by peer")
2017-06-02 17:58:46 +02:00
}
}