mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
Inital implementation of trie
This commit is contained in:
parent
8ce921987f
commit
ec3d656beb
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* todo : use real error code
|
/* todo : use real error code
|
||||||
@ -18,6 +19,7 @@ const (
|
|||||||
ipcErrorInvalidPrivateKey = 3
|
ipcErrorInvalidPrivateKey = 3
|
||||||
ipcErrorInvalidPublicKey = 4
|
ipcErrorInvalidPublicKey = 4
|
||||||
ipcErrorInvalidPort = 5
|
ipcErrorInvalidPort = 5
|
||||||
|
ipcErrorInvalidIPAddress = 6
|
||||||
)
|
)
|
||||||
|
|
||||||
type IPCError struct {
|
type IPCError struct {
|
||||||
@ -104,6 +106,10 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case "replace_peers":
|
case "replace_peers":
|
||||||
|
if key == "true" {
|
||||||
|
dev.RemoveAllPeers()
|
||||||
|
}
|
||||||
|
// todo: else fail
|
||||||
|
|
||||||
default:
|
default:
|
||||||
/* Peer configuration */
|
/* Peer configuration */
|
||||||
@ -116,20 +122,27 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
|
|
||||||
case "remove":
|
case "remove":
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
|
dev.RemovePeer(peer.publicKey)
|
||||||
peer = nil
|
peer = nil
|
||||||
|
|
||||||
case "preshared_key":
|
case "preshared_key":
|
||||||
func() {
|
err := func() error {
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
defer peer.mutex.Unlock()
|
defer peer.mutex.Unlock()
|
||||||
|
return peer.presharedKey.FromHex(value)
|
||||||
}()
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return &IPCError{Code: ipcErrorInvalidPublicKey}
|
||||||
|
}
|
||||||
|
|
||||||
case "endpoint":
|
case "endpoint":
|
||||||
func() {
|
ip := net.ParseIP(value)
|
||||||
peer.mutex.Lock()
|
if ip == nil {
|
||||||
defer peer.mutex.Unlock()
|
return &IPCError{Code: ipcErrorInvalidIPAddress}
|
||||||
}()
|
}
|
||||||
|
peer.mutex.Lock()
|
||||||
|
peer.endpoint = ip
|
||||||
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
case "persistent_keepalive_interval":
|
case "persistent_keepalive_interval":
|
||||||
func() {
|
func() {
|
||||||
|
@ -5,10 +5,39 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
peers map[NoisePublicKey]*Peer
|
peers map[NoisePublicKey]*Peer
|
||||||
privateKey NoisePrivateKey
|
privateKey NoisePrivateKey
|
||||||
publicKey NoisePublicKey
|
publicKey NoisePublicKey
|
||||||
fwMark uint32
|
fwMark uint32
|
||||||
listenPort uint16
|
listenPort uint16
|
||||||
|
routingTable RoutingTable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dev *Device) RemovePeer(key NoisePublicKey) {
|
||||||
|
dev.mutex.Lock()
|
||||||
|
defer dev.mutex.Unlock()
|
||||||
|
peer, ok := dev.peers[key]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
peer.mutex.Lock()
|
||||||
|
dev.routingTable.RemovePeer(peer)
|
||||||
|
delete(dev.peers, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dev *Device) RemoveAllAllowedIps(peer *Peer) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dev *Device) RemoveAllPeers() {
|
||||||
|
dev.mutex.Lock()
|
||||||
|
defer dev.mutex.Unlock()
|
||||||
|
|
||||||
|
for key, peer := range dev.peers {
|
||||||
|
peer.mutex.Lock()
|
||||||
|
dev.routingTable.RemovePeer(peer)
|
||||||
|
delete(dev.peers, key)
|
||||||
|
peer.mutex.Unlock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
38
src/noise.go
38
src/noise.go
@ -18,34 +18,38 @@ type (
|
|||||||
NoiseNonce uint64 // padded to 12-bytes
|
NoiseNonce uint64 // padded to 12-bytes
|
||||||
)
|
)
|
||||||
|
|
||||||
func (key *NoisePrivateKey) FromHex(s string) error {
|
func loadExactHex(dst []byte, src string) error {
|
||||||
slice, err := hex.DecodeString(s)
|
slice, err := hex.DecodeString(src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(slice) != NoisePrivateKeySize {
|
if len(slice) != len(dst) {
|
||||||
return errors.New("Invalid length of hex string for curve25519 point")
|
return errors.New("Hex string does not fit the slice")
|
||||||
}
|
}
|
||||||
copy(key[:], slice)
|
copy(dst, slice)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *NoisePrivateKey) ToHex() string {
|
func (key *NoisePrivateKey) FromHex(src string) error {
|
||||||
|
return loadExactHex(key[:], src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key NoisePrivateKey) ToHex() string {
|
||||||
return hex.EncodeToString(key[:])
|
return hex.EncodeToString(key[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *NoisePublicKey) FromHex(s string) error {
|
func (key *NoisePublicKey) FromHex(src string) error {
|
||||||
slice, err := hex.DecodeString(s)
|
return loadExactHex(key[:], src)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if len(slice) != NoisePublicKeySize {
|
|
||||||
return errors.New("Invalid length of hex string for curve25519 scalar")
|
|
||||||
}
|
|
||||||
copy(key[:], slice)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *NoisePublicKey) ToHex() string {
|
func (key NoisePublicKey) ToHex() string {
|
||||||
|
return hex.EncodeToString(key[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *NoiseSymmetricKey) FromHex(src string) error {
|
||||||
|
return loadExactHex(key[:], src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key NoiseSymmetricKey) ToHex() string {
|
||||||
return hex.EncodeToString(key[:])
|
return hex.EncodeToString(key[:])
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,4 +16,5 @@ type Peer struct {
|
|||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
publicKey NoisePublicKey
|
publicKey NoisePublicKey
|
||||||
presharedKey NoiseSymmetricKey
|
presharedKey NoiseSymmetricKey
|
||||||
|
endpoint net.IP
|
||||||
}
|
}
|
||||||
|
175
src/ping-test.go
175
src/ping-test.go
@ -1,175 +0,0 @@
|
|||||||
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/binary"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/dchest/blake2s"
|
|
||||||
"github.com/titanous/noise"
|
|
||||||
"golang.org/x/net/icmp"
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ipChecksum(buf []byte) uint16 {
|
|
||||||
sum := uint32(0)
|
|
||||||
for ; len(buf) >= 2; buf = buf[2:] {
|
|
||||||
sum += uint32(buf[0])<<8 | uint32(buf[1])
|
|
||||||
}
|
|
||||||
if len(buf) > 0 {
|
|
||||||
sum += uint32(buf[0]) << 8
|
|
||||||
}
|
|
||||||
for sum > 0xffff {
|
|
||||||
sum = (sum >> 16) + (sum & 0xffff)
|
|
||||||
}
|
|
||||||
csum := ^uint16(sum)
|
|
||||||
if csum == 0 {
|
|
||||||
csum = 0xffff
|
|
||||||
}
|
|
||||||
return csum
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
ourPrivate, _ := base64.StdEncoding.DecodeString("WAmgVYXkbT2bCtdcDwolI88/iVi/aV3/PHcUBTQSYmo=")
|
|
||||||
ourPublic, _ := base64.StdEncoding.DecodeString("K5sF9yESrSBsOXPd6TcpKNgqoy1Ik3ZFKl4FolzrRyI=")
|
|
||||||
theirPublic, _ := base64.StdEncoding.DecodeString("qRCwZSKInrMAq5sepfCdaCsRJaoLe5jhtzfiw7CjbwM=")
|
|
||||||
preshared, _ := base64.StdEncoding.DecodeString("FpCyhws9cxwWoV4xELtfJvjJN+zQVRPISllRWgeopVE=")
|
|
||||||
cs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s)
|
|
||||||
hs := noise.NewHandshakeState(noise.Config{
|
|
||||||
CipherSuite: cs,
|
|
||||||
Random: rand.Reader,
|
|
||||||
Pattern: noise.HandshakeIK,
|
|
||||||
Initiator: true,
|
|
||||||
Prologue: []byte("WireGuard v1 zx2c4 Jason@zx2c4.com"),
|
|
||||||
PresharedKey: preshared,
|
|
||||||
PresharedKeyPlacement: 2,
|
|
||||||
StaticKeypair: noise.DHKey{Private: ourPrivate, Public: ourPublic},
|
|
||||||
PeerStatic: theirPublic,
|
|
||||||
})
|
|
||||||
conn, err := net.Dial("udp", "demo.wireguard.io:12913")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("error dialing udp socket: %s", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
// write handshake initiation packet
|
|
||||||
now := time.Now()
|
|
||||||
tai64n := make([]byte, 12)
|
|
||||||
binary.BigEndian.PutUint64(tai64n[:], 4611686018427387914+uint64(now.Unix()))
|
|
||||||
binary.BigEndian.PutUint32(tai64n[8:], uint32(now.UnixNano()))
|
|
||||||
initiationPacket := make([]byte, 8)
|
|
||||||
initiationPacket[0] = 1 // Type: Initiation
|
|
||||||
initiationPacket[1] = 0 // Reserved
|
|
||||||
initiationPacket[2] = 0 // Reserved
|
|
||||||
initiationPacket[3] = 0 // Reserved
|
|
||||||
binary.LittleEndian.PutUint32(initiationPacket[4:], 28) // Sender index: 28 (arbitrary)
|
|
||||||
initiationPacket, _, _ = hs.WriteMessage(initiationPacket, tai64n)
|
|
||||||
hasher, _ := blake2s.New(&blake2s.Config{Size: 32})
|
|
||||||
hasher.Write([]byte("mac1----"))
|
|
||||||
hasher.Write(theirPublic)
|
|
||||||
hasher, _ = blake2s.New(&blake2s.Config{Size: 16, Key: hasher.Sum(nil)})
|
|
||||||
hasher.Write(initiationPacket)
|
|
||||||
initiationPacket = append(initiationPacket, hasher.Sum(nil)[:16]...)
|
|
||||||
initiationPacket = append(initiationPacket, make([]byte, 16)...)
|
|
||||||
if _, err := conn.Write(initiationPacket); err != nil {
|
|
||||||
log.Fatalf("error writing initiation packet: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// read handshake response packet
|
|
||||||
responsePacket := make([]byte, 92)
|
|
||||||
n, err := conn.Read(responsePacket)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("error reading response packet: %s", err)
|
|
||||||
}
|
|
||||||
if n != len(responsePacket) {
|
|
||||||
log.Fatalf("response packet too short: want %d, got %d", len(responsePacket), n)
|
|
||||||
}
|
|
||||||
if responsePacket[0] != 2 { // Type: Response
|
|
||||||
log.Fatalf("response packet type wrong: want %d, got %d", 2, responsePacket[0])
|
|
||||||
}
|
|
||||||
if responsePacket[1] != 0 || responsePacket[2] != 0 || responsePacket[3] != 0 {
|
|
||||||
log.Fatalf("response packet has non-zero reserved fields")
|
|
||||||
}
|
|
||||||
theirIndex := binary.LittleEndian.Uint32(responsePacket[4:])
|
|
||||||
ourIndex := binary.LittleEndian.Uint32(responsePacket[8:])
|
|
||||||
if ourIndex != 28 {
|
|
||||||
log.Fatalf("response packet index wrong: want %d, got %d", 28, ourIndex)
|
|
||||||
}
|
|
||||||
payload, sendCipher, receiveCipher, err := hs.ReadMessage(nil, responsePacket[12:60])
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("error reading handshake message: %s", err)
|
|
||||||
}
|
|
||||||
if len(payload) > 0 {
|
|
||||||
log.Fatalf("unexpected payload: %x", payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
// write ICMP Echo packet
|
|
||||||
pingMessage, _ := (&icmp.Message{
|
|
||||||
Type: ipv4.ICMPTypeEcho,
|
|
||||||
Body: &icmp.Echo{
|
|
||||||
ID: 921,
|
|
||||||
Seq: 438,
|
|
||||||
Data: []byte("WireGuard"),
|
|
||||||
},
|
|
||||||
}).Marshal(nil)
|
|
||||||
pingHeader, err := (&ipv4.Header{
|
|
||||||
Version: ipv4.Version,
|
|
||||||
Len: ipv4.HeaderLen,
|
|
||||||
TotalLen: ipv4.HeaderLen + len(pingMessage),
|
|
||||||
Protocol: 1, // ICMP
|
|
||||||
TTL: 20,
|
|
||||||
Src: net.IPv4(10, 189, 129, 2),
|
|
||||||
Dst: net.IPv4(10, 189, 129, 1),
|
|
||||||
}).Marshal()
|
|
||||||
binary.BigEndian.PutUint16(pingHeader[2:], uint16(ipv4.HeaderLen+len(pingMessage))) // fix the length endianness on BSDs
|
|
||||||
pingData := append(pingHeader, pingMessage...)
|
|
||||||
binary.BigEndian.PutUint16(pingData[10:], ipChecksum(pingData))
|
|
||||||
pingPacket := make([]byte, 16)
|
|
||||||
pingPacket[0] = 4 // Type: Data
|
|
||||||
pingPacket[1] = 0 // Reserved
|
|
||||||
pingPacket[2] = 0 // Reserved
|
|
||||||
pingPacket[3] = 0 // Reserved
|
|
||||||
binary.LittleEndian.PutUint32(pingPacket[4:], theirIndex)
|
|
||||||
binary.LittleEndian.PutUint64(pingPacket[8:], 0) // Nonce
|
|
||||||
pingPacket = sendCipher.Encrypt(pingPacket, nil, pingData)
|
|
||||||
if _, err := conn.Write(pingPacket); err != nil {
|
|
||||||
log.Fatalf("error writing ping message: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// read ICMP Echo Reply packet
|
|
||||||
replyPacket := make([]byte, 128)
|
|
||||||
n, err = conn.Read(replyPacket)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("error reading ping reply message: %s", err)
|
|
||||||
}
|
|
||||||
replyPacket = replyPacket[:n]
|
|
||||||
if replyPacket[0] != 4 { // Type: Data
|
|
||||||
log.Fatalf("unexpected reply packet type: %d", replyPacket[0])
|
|
||||||
}
|
|
||||||
if replyPacket[1] != 0 || replyPacket[2] != 0 || replyPacket[3] != 0 {
|
|
||||||
log.Fatalf("reply packet has non-zero reserved fields")
|
|
||||||
}
|
|
||||||
replyPacket, err = receiveCipher.Decrypt(nil, nil, replyPacket[16:])
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("error decrypting reply packet: %s", err)
|
|
||||||
}
|
|
||||||
replyHeaderLen := int(replyPacket[0]&0x0f) << 2
|
|
||||||
replyLen := binary.BigEndian.Uint16(replyPacket[2:])
|
|
||||||
replyMessage, err := icmp.ParseMessage(1, replyPacket[replyHeaderLen:replyLen])
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("error parsing echo: %s", err)
|
|
||||||
}
|
|
||||||
echo, ok := replyMessage.Body.(*icmp.Echo)
|
|
||||||
if !ok {
|
|
||||||
log.Fatalf("unexpected reply body type %T", replyMessage.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
if echo.ID != 921 || echo.Seq != 438 || string(echo.Data) != "WireGuard" {
|
|
||||||
log.Fatalf("incorrect echo response: %#v", echo)
|
|
||||||
}
|
|
||||||
}
|
|
22
src/routing.go
Normal file
22
src/routing.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Thread-safe high level functions for cryptkey routing.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
type RoutingTable struct {
|
||||||
|
IPv4 *Trie
|
||||||
|
IPv6 *Trie
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *RoutingTable) RemovePeer(peer *Peer) {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
table.IPv4 = table.IPv4.RemovePeer(peer)
|
||||||
|
table.IPv6 = table.IPv6.RemovePeer(peer)
|
||||||
|
}
|
81
src/trie.go
81
src/trie.go
@ -1,9 +1,11 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import "fmt"
|
/* Binary trie
|
||||||
|
|
||||||
/* Syncronization must be done seperatly
|
|
||||||
*
|
*
|
||||||
|
* Syncronization done seperatly
|
||||||
|
* See: routing.go
|
||||||
|
*
|
||||||
|
* Todo: Better commenting
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type Trie struct {
|
type Trie struct {
|
||||||
@ -13,7 +15,6 @@ type Trie struct {
|
|||||||
peer *Peer
|
peer *Peer
|
||||||
|
|
||||||
// Index of "branching" bit
|
// Index of "branching" bit
|
||||||
// bit_at_shift
|
|
||||||
bit_at_byte uint
|
bit_at_byte uint
|
||||||
bit_at_shift uint
|
bit_at_shift uint
|
||||||
}
|
}
|
||||||
@ -92,7 +93,14 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
|
|||||||
return node.child[0]
|
return node.child[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (node *Trie) choose(key []byte) byte {
|
||||||
|
return (key[node.bit_at_byte] >> node.bit_at_shift) & 1
|
||||||
|
}
|
||||||
|
|
||||||
func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
|
func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
|
||||||
|
|
||||||
|
// At leaf
|
||||||
|
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return &Trie{
|
return &Trie{
|
||||||
bits: key,
|
bits: key,
|
||||||
@ -107,22 +115,17 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
|
|||||||
|
|
||||||
common := commonBits(node.bits, key)
|
common := commonBits(node.bits, key)
|
||||||
if node.cidr <= cidr && common >= node.cidr {
|
if node.cidr <= cidr && common >= node.cidr {
|
||||||
// Check if match the t.bits[:t.cidr] exactly
|
|
||||||
if node.cidr == cidr {
|
if node.cidr == cidr {
|
||||||
node.peer = peer
|
node.peer = peer
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
bit := node.choose(key)
|
||||||
// Go to child
|
|
||||||
bit := (key[node.bit_at_byte] >> node.bit_at_shift) & 1
|
|
||||||
node.child[bit] = node.child[bit].Insert(key, cidr, peer)
|
node.child[bit] = node.child[bit].Insert(key, cidr, peer)
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
// Split node
|
// Split node
|
||||||
|
|
||||||
fmt.Println("new", common)
|
|
||||||
|
|
||||||
newNode := &Trie{
|
newNode := &Trie{
|
||||||
bits: key,
|
bits: key,
|
||||||
peer: peer,
|
peer: peer,
|
||||||
@ -132,23 +135,53 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cidr = min(cidr, common)
|
cidr = min(cidr, common)
|
||||||
node.cidr = cidr
|
|
||||||
node.bit_at_byte = cidr / 8
|
|
||||||
node.bit_at_shift = 7 - (cidr % 8)
|
|
||||||
|
|
||||||
// bval := node.bits[node.bit_at_byte] >> node.bit_at_shift // todo : remember index
|
// Check for shorter prefix
|
||||||
// Work in progress
|
|
||||||
node.child[0] = newNode
|
|
||||||
node.child[1] = newNode
|
|
||||||
|
|
||||||
return node
|
if newNode.cidr == cidr {
|
||||||
}
|
bit := newNode.choose(node.bits)
|
||||||
|
newNode.child[bit] = node
|
||||||
func (t *Trie) Lookup(key []byte) *Peer {
|
return newNode
|
||||||
if t == nil {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Create new parent for node & newNode
|
||||||
|
|
||||||
|
parent := &Trie{
|
||||||
|
bits: key,
|
||||||
|
peer: nil,
|
||||||
|
cidr: cidr,
|
||||||
|
bit_at_byte: cidr / 8,
|
||||||
|
bit_at_shift: 7 - (cidr % 8),
|
||||||
|
}
|
||||||
|
|
||||||
|
bit := parent.choose(key)
|
||||||
|
parent.child[bit] = newNode
|
||||||
|
parent.child[bit^1] = node
|
||||||
|
|
||||||
|
return parent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *Trie) Lookup(key []byte) *Peer {
|
||||||
|
var found *Peer
|
||||||
|
size := uint(len(key))
|
||||||
|
for node != nil && commonBits(node.bits, key) >= node.cidr {
|
||||||
|
if node.peer != nil {
|
||||||
|
found = node.peer
|
||||||
|
}
|
||||||
|
if node.bit_at_byte == size {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
bit := node.choose(key)
|
||||||
|
node = node.child[bit]
|
||||||
|
}
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *Trie) Count() uint {
|
||||||
|
if node == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
l := node.child[0].Count()
|
||||||
|
r := node.child[1].Count()
|
||||||
|
return l + r
|
||||||
}
|
}
|
||||||
|
180
src/trie_test.go
180
src/trie_test.go
@ -4,6 +4,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/* Todo: More comprehensive
|
||||||
|
*/
|
||||||
|
|
||||||
type testPairCommonBits struct {
|
type testPairCommonBits struct {
|
||||||
s1 []byte
|
s1 []byte
|
||||||
s2 []byte
|
s2 []byte
|
||||||
@ -16,6 +19,11 @@ type testPairTrieInsert struct {
|
|||||||
peer *Peer
|
peer *Peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type testPairTrieLookup struct {
|
||||||
|
key []byte
|
||||||
|
peer *Peer
|
||||||
|
}
|
||||||
|
|
||||||
func printTrie(t *testing.T, p *Trie) {
|
func printTrie(t *testing.T, p *Trie) {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return
|
return
|
||||||
@ -41,26 +49,176 @@ func TestCommonBits(t *testing.T) {
|
|||||||
t.Error(
|
t.Error(
|
||||||
"For slice", p.s1, p.s2,
|
"For slice", p.s1, p.s2,
|
||||||
"expected match", p.match,
|
"expected match", p.match,
|
||||||
"got", v,
|
",but got", v,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrieInsertV4(t *testing.T) {
|
/* Test ported from kernel implementation:
|
||||||
|
* selftest/routingtable.h
|
||||||
|
*/
|
||||||
|
func TestTrieIPv4(t *testing.T) {
|
||||||
|
a := &Peer{}
|
||||||
|
b := &Peer{}
|
||||||
|
c := &Peer{}
|
||||||
|
d := &Peer{}
|
||||||
|
e := &Peer{}
|
||||||
|
g := &Peer{}
|
||||||
|
h := &Peer{}
|
||||||
|
|
||||||
var trie *Trie
|
var trie *Trie
|
||||||
|
|
||||||
peer1 := Peer{}
|
insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
|
||||||
peer2 := Peer{}
|
trie = trie.Insert([]byte{a, b, c, d}, cidr, peer)
|
||||||
|
|
||||||
tests := []testPairTrieInsert{
|
|
||||||
{key: []byte{192, 168, 1, 1}, cidr: 24, peer: &peer1},
|
|
||||||
{key: []byte{192, 169, 1, 1}, cidr: 24, peer: &peer2},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range tests {
|
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
trie = trie.Insert(p.key, p.cidr, p.peer)
|
p := trie.Lookup([]byte{a, b, c, d})
|
||||||
printTrie(t, trie)
|
if p != peer {
|
||||||
|
t.Error("Assert EQ failed")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assertNEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
|
p := trie.Lookup([]byte{a, b, c, d})
|
||||||
|
if p == peer {
|
||||||
|
t.Error("Assert NEQ failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
insert(a, 192, 168, 4, 0, 24)
|
||||||
|
insert(b, 192, 168, 4, 4, 32)
|
||||||
|
insert(c, 192, 168, 0, 0, 16)
|
||||||
|
insert(d, 192, 95, 5, 64, 27)
|
||||||
|
insert(c, 192, 95, 5, 65, 27) /* replaces previous entry, and maskself is required */
|
||||||
|
insert(e, 0, 0, 0, 0, 0)
|
||||||
|
insert(g, 64, 15, 112, 0, 20)
|
||||||
|
insert(h, 64, 15, 123, 211, 25) /* maskself is required */
|
||||||
|
insert(a, 10, 0, 0, 0, 25)
|
||||||
|
insert(b, 10, 0, 0, 128, 25)
|
||||||
|
insert(a, 10, 1, 0, 0, 30)
|
||||||
|
insert(b, 10, 1, 0, 4, 30)
|
||||||
|
insert(c, 10, 1, 0, 8, 29)
|
||||||
|
insert(d, 10, 1, 0, 16, 29)
|
||||||
|
|
||||||
|
assertEQ(a, 192, 168, 4, 20)
|
||||||
|
assertEQ(a, 192, 168, 4, 0)
|
||||||
|
assertEQ(b, 192, 168, 4, 4)
|
||||||
|
assertEQ(c, 192, 168, 200, 182)
|
||||||
|
assertEQ(c, 192, 95, 5, 68)
|
||||||
|
assertEQ(e, 192, 95, 5, 96)
|
||||||
|
assertEQ(g, 64, 15, 116, 26)
|
||||||
|
assertEQ(g, 64, 15, 127, 3)
|
||||||
|
|
||||||
|
insert(a, 1, 0, 0, 0, 32)
|
||||||
|
insert(a, 64, 0, 0, 0, 32)
|
||||||
|
insert(a, 128, 0, 0, 0, 32)
|
||||||
|
insert(a, 192, 0, 0, 0, 32)
|
||||||
|
insert(a, 255, 0, 0, 0, 32)
|
||||||
|
|
||||||
|
assertEQ(a, 1, 0, 0, 0)
|
||||||
|
assertEQ(a, 64, 0, 0, 0)
|
||||||
|
assertEQ(a, 128, 0, 0, 0)
|
||||||
|
assertEQ(a, 192, 0, 0, 0)
|
||||||
|
assertEQ(a, 255, 0, 0, 0)
|
||||||
|
|
||||||
|
trie = trie.RemovePeer(a)
|
||||||
|
|
||||||
|
assertNEQ(a, 1, 0, 0, 0)
|
||||||
|
assertNEQ(a, 64, 0, 0, 0)
|
||||||
|
assertNEQ(a, 128, 0, 0, 0)
|
||||||
|
assertNEQ(a, 192, 0, 0, 0)
|
||||||
|
assertNEQ(a, 255, 0, 0, 0)
|
||||||
|
|
||||||
|
trie = nil
|
||||||
|
|
||||||
|
insert(a, 192, 168, 0, 0, 16)
|
||||||
|
insert(a, 192, 168, 0, 0, 24)
|
||||||
|
|
||||||
|
trie = trie.RemovePeer(a)
|
||||||
|
|
||||||
|
assertNEQ(a, 192, 168, 0, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Test ported from kernel implementation:
|
||||||
|
* selftest/routingtable.h
|
||||||
|
*/
|
||||||
|
func TestTrieIPv6(t *testing.T) {
|
||||||
|
a := &Peer{}
|
||||||
|
b := &Peer{}
|
||||||
|
c := &Peer{}
|
||||||
|
d := &Peer{}
|
||||||
|
e := &Peer{}
|
||||||
|
f := &Peer{}
|
||||||
|
g := &Peer{}
|
||||||
|
h := &Peer{}
|
||||||
|
|
||||||
|
var trie *Trie
|
||||||
|
|
||||||
|
expand := func(a uint32) []byte {
|
||||||
|
var out [4]byte
|
||||||
|
out[0] = byte(a >> 24 & 0xff)
|
||||||
|
out[1] = byte(a >> 16 & 0xff)
|
||||||
|
out[2] = byte(a >> 8 & 0xff)
|
||||||
|
out[3] = byte(a & 0xff)
|
||||||
|
return out[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
|
||||||
|
var addr []byte
|
||||||
|
addr = append(addr, expand(a)...)
|
||||||
|
addr = append(addr, expand(b)...)
|
||||||
|
addr = append(addr, expand(c)...)
|
||||||
|
addr = append(addr, expand(d)...)
|
||||||
|
trie = trie.Insert(addr, cidr, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||||
|
var addr []byte
|
||||||
|
addr = append(addr, expand(a)...)
|
||||||
|
addr = append(addr, expand(b)...)
|
||||||
|
addr = append(addr, expand(c)...)
|
||||||
|
addr = append(addr, expand(d)...)
|
||||||
|
p := trie.Lookup(addr)
|
||||||
|
if p != peer {
|
||||||
|
t.Error("Assert EQ failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
assertNEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||||
|
var addr []byte
|
||||||
|
addr = append(addr, expand(a)...)
|
||||||
|
addr = append(addr, expand(b)...)
|
||||||
|
addr = append(addr, expand(c)...)
|
||||||
|
addr = append(addr, expand(d)...)
|
||||||
|
p := trie.Lookup(addr)
|
||||||
|
if p == peer {
|
||||||
|
t.Error("Assert NEQ failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
|
||||||
|
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
|
||||||
|
insert(e, 0, 0, 0, 0, 0)
|
||||||
|
insert(f, 0, 0, 0, 0, 0)
|
||||||
|
insert(g, 0x24046800, 0, 0, 0, 32)
|
||||||
|
insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64)
|
||||||
|
insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128)
|
||||||
|
insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||||
|
insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||||
|
|
||||||
|
assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543)
|
||||||
|
assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee)
|
||||||
|
assertEQ(f, 0x26075300, 0x60006b01, 0, 0)
|
||||||
|
assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006)
|
||||||
|
assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678)
|
||||||
|
assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678)
|
||||||
|
assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678)
|
||||||
|
assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678)
|
||||||
|
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
|
||||||
|
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
|
||||||
|
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user