1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Terminate on interface deletion

Program now terminates when the interface is removed
Increases the number of os threads (relevant for Go <1.5, not tested)
More consistent commenting
Improved logging (additional peer information)
This commit is contained in:
Mathias Hall-Andersen 2017-07-13 14:32:40 +02:00
parent 8393cbff52
commit 93e3848ea7
9 changed files with 132 additions and 97 deletions

View File

@ -29,6 +29,6 @@ const (
QueueInboundSize = 1024 QueueInboundSize = 1024
QueueHandshakeSize = 1024 QueueHandshakeSize = 1024
QueueHandshakeBusySize = QueueHandshakeSize / 8 QueueHandshakeBusySize = QueueHandshakeSize / 8
MinMessageSize = MessageTransportSize // keep-alive MinMessageSize = MessageTransportSize // size of keep-alive
MaxMessageSize = 4096 // TODO: make depend on the MTU? MaxMessageSize = (1 << 16) - 1
) )

View File

@ -98,9 +98,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
} }
go device.RoutineBusyMonitor() go device.RoutineBusyMonitor()
go device.RoutineWriteToTUN(tun)
go device.RoutineReadFromTUN(tun) go device.RoutineReadFromTUN(tun)
go device.RoutineReceiveIncomming() go device.RoutineReceiveIncomming()
go device.RoutineWriteToTUN(tun)
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
return device return device
@ -141,5 +141,8 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() { func (device *Device) Close() {
device.RemoveAllPeers() device.RemoveAllPeers()
close(device.signal.stop) close(device.signal.stop)
close(device.queue.encryption) }
func (device *Device) Wait() {
<-device.signal.stop
} }

View File

@ -5,17 +5,13 @@ import (
) )
const ( const (
IPv4version = 4
IPv4offsetTotalLength = 2 IPv4offsetTotalLength = 2
IPv4offsetSrc = 12 IPv4offsetSrc = 12
IPv4offsetDst = IPv4offsetSrc + net.IPv4len IPv4offsetDst = IPv4offsetSrc + net.IPv4len
IPv4headerSize = 20
) )
const ( const (
IPv6version = 6
IPv6offsetPayloadLength = 4 IPv6offsetPayloadLength = 4
IPv6offsetSrc = 8 IPv6offsetSrc = 8
IPv6offsetDst = IPv6offsetSrc + net.IPv6len IPv6offsetDst = IPv6offsetSrc + net.IPv6len
IPv6headerSize = 40
) )

View File

@ -5,6 +5,7 @@ import (
"log" "log"
"net" "net"
"os" "os"
"runtime"
) )
/* TODO: Fix logging /* TODO: Fix logging
@ -18,6 +19,10 @@ func main() {
} }
deviceName := os.Args[1] deviceName := os.Args[1]
// increase number of go workers (for Go <1.5)
runtime.GOMAXPROCS(runtime.NumCPU())
// open TUN device // open TUN device
tun, err := CreateTUN(deviceName) tun, err := CreateTUN(deviceName)
@ -31,17 +36,21 @@ func main() {
// start configuration lister // start configuration lister
socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName) go func() {
l, err := net.Listen("unix", socketPath) socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
if err != nil { l, err := net.Listen("unix", socketPath)
log.Fatal("listen error:", err)
}
for {
conn, err := l.Accept()
if err != nil { if err != nil {
log.Fatal("accept error:", err) log.Fatal("listen error:", err)
} }
go ipcHandle(device, conn)
} for {
conn, err := l.Accept()
if err != nil {
log.Fatal("accept error:", err)
}
go ipcHandle(device, conn)
}
}()
device.Wait()
} }

View File

@ -1,7 +1,9 @@
package main package main
import ( import (
"encoding/base64"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"time" "time"
@ -38,9 +40,9 @@ type Peer struct {
/* Both keep-alive timers acts as one (see timers.go) /* Both keep-alive timers acts as one (see timers.go)
* They are kept seperate to simplify the implementation. * They are kept seperate to simplify the implementation.
*/ */
keepalivePersistent *time.Timer // set for persistent keepalives keepalivePersistent *time.Timer // set for persistent keepalives
keepaliveAcknowledgement *time.Timer // set upon recieving messages keepalivePassive *time.Timer // set upon recieving messages
zeroAllKeys *time.Timer // zero all key material after RejectAfterTime*3 zeroAllKeys *time.Timer // zero all key material after RejectAfterTime*3
} }
queue struct { queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue nonce chan *QueueOutboundElement // nonce / pre-handshake queue
@ -63,8 +65,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
peer.mac.Init(pk) peer.mac.Init(pk)
peer.device = device peer.device = device
peer.timer.keepalivePassive = NewStoppedTimer()
peer.timer.keepalivePersistent = NewStoppedTimer() peer.timer.keepalivePersistent = NewStoppedTimer()
peer.timer.keepaliveAcknowledgement = NewStoppedTimer()
peer.timer.zeroAllKeys = NewStoppedTimer() peer.timer.zeroAllKeys = NewStoppedTimer()
peer.flags.keepaliveWaiting = AtomicFalse peer.flags.keepaliveWaiting = AtomicFalse
@ -115,6 +117,15 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
return peer return peer
} }
func (peer *Peer) String() string {
return fmt.Sprintf(
"peer(%d %s %s)",
peer.id,
peer.endpoint.String(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}
func (peer *Peer) Close() { func (peer *Peer) Close() {
close(peer.signal.stop) close(peer.signal.stop)
} }

View File

@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -362,7 +364,7 @@ func (device *Device) RoutineHandshake() {
return return
} }
logDebug.Println("Creating response...") logDebug.Println("Creating response message for", peer.String())
outElem := device.NewOutboundElement() outElem := device.NewOutboundElement()
writer := bytes.NewBuffer(outElem.data[:0]) writer := bytes.NewBuffer(outElem.data[:0])
@ -416,6 +418,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
var elem *QueueInboundElement var elem *QueueInboundElement
device := peer.device device := peer.device
logInfo := device.log.Info
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, sequential receiver, started for peer", peer.id) logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
@ -450,7 +454,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
peer.KeepKeyFreshReceiving() peer.KeepKeyFreshReceiving()
// check if confirming handshake // check if using new key-pair
kp := &peer.keyPairs kp := &peer.keyPairs
kp.mutex.Lock() kp.mutex.Lock()
@ -465,17 +469,18 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check for keep-alive // check for keep-alive
if len(elem.packet) == 0 { if len(elem.packet) == 0 {
logDebug.Println("Received keep-alive from", peer.String())
return return
} }
// verify source and strip padding // verify source and strip padding
switch elem.packet[0] >> 4 { switch elem.packet[0] >> 4 {
case IPv4version: case ipv4.Version:
// strip padding // strip padding
if len(elem.packet) < IPv4headerSize { if len(elem.packet) < ipv4.HeaderLen {
return return
} }
@ -487,31 +492,33 @@ func (peer *Peer) RoutineSequentialReceiver() {
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
if device.routingTable.LookupIPv4(dst) != peer { if device.routingTable.LookupIPv4(dst) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
return return
} }
case IPv6version: case ipv6.Version:
// strip padding // strip padding
if len(elem.packet) < IPv6headerSize { if len(elem.packet) < ipv6.HeaderLen {
return return
} }
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field) length := binary.BigEndian.Uint16(field)
length += IPv6headerSize length += ipv6.HeaderLen
elem.packet = elem.packet[:length] elem.packet = elem.packet[:length]
// verify IPv6 source // verify IPv6 source
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
if device.routingTable.LookupIPv6(dst) != peer { if device.routingTable.LookupIPv6(dst) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
return return
} }
default: default:
logDebug.Println("Receieved packet with unknown IP version") logInfo.Println("Packet with invalid IP version from", peer.String())
return return
} }
@ -522,6 +529,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
} }
func (device *Device) RoutineWriteToTUN(tun TUNDevice) { func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
logError := device.log.Error logError := device.log.Error
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, sequential tun writer, started") logDebug.Println("Routine, sequential tun writer, started")

View File

@ -3,6 +3,8 @@ package main
import ( import (
"encoding/binary" "encoding/binary"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -21,28 +23,26 @@ import (
* The functions in this file occure (roughly) in the order packets are processed. * The functions in this file occure (roughly) in the order packets are processed.
*/ */
/* A work unit /* The sequential consumers will attempt to take the lock,
* * workers release lock when they have completed work (encryption) on the packet.
* The sequential consumers will attempt to take the lock,
* workers release lock when they have completed work on the packet.
* *
* If the element is inserted into the "encryption queue", * If the element is inserted into the "encryption queue",
* the content is preceeded by enough "junk" to contain the header * the content is preceeded by enough "junk" to contain the transport header
* (to allow the construction of transport messages in-place) * (to allow the construction of transport messages in-place)
*/ */
type QueueOutboundElement struct { type QueueOutboundElement struct {
dropped int32 dropped int32
mutex sync.Mutex mutex sync.Mutex
data [MaxMessageSize]byte data [MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "data" (always!) packet []byte // slice of "data" (always!)
nonce uint64 // nonce for encryption nonce uint64 // nonce for encryption
keyPair *KeyPair // key-pair for encryption keyPair *KeyPair // key-pair for encryption
peer *Peer // related peer peer *Peer // related peer
} }
func (peer *Peer) FlushNonceQueue() { func (peer *Peer) FlushNonceQueue() {
elems := len(peer.queue.nonce) elems := len(peer.queue.nonce)
for i := 0; i < elems; i += 1 { for i := 0; i < elems; i++ {
select { select {
case <-peer.queue.nonce: case <-peer.queue.nonce:
default: default:
@ -111,14 +111,18 @@ func addToEncryptionQueue(
* Obs. Single instance per TUN device * Obs. Single instance per TUN device
*/ */
func (device *Device) RoutineReadFromTUN(tun TUNDevice) { func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
if tun == nil { if tun == nil {
// dummy
return return
} }
elem := device.NewOutboundElement() elem := device.NewOutboundElement()
device.log.Debug.Println("Routine, TUN Reader: started") logDebug := device.log.Debug
logError := device.log.Error
logDebug.Println("Routine, TUN Reader: started")
for { for {
// read packet // read packet
@ -129,12 +133,17 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
elem.packet = elem.data[MessageTransportHeaderSize:] elem.packet = elem.data[MessageTransportHeaderSize:]
size, err := tun.Read(elem.packet) size, err := tun.Read(elem.packet)
if err != nil { if err != nil {
device.log.Error.Println("Failed to read packet from TUN device:", err)
continue // stop process
logError.Println("Failed to read packet from TUN device:", err)
device.Close()
return
} }
elem.packet = elem.packet[:size] elem.packet = elem.packet[:size]
if len(elem.packet) < IPv4headerSize { if len(elem.packet) < ipv4.HeaderLen {
device.log.Error.Println("Packet too short, length:", size) logError.Println("Packet too short, length:", size)
continue continue
} }
@ -142,23 +151,24 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
var peer *Peer var peer *Peer
switch elem.packet[0] >> 4 { switch elem.packet[0] >> 4 {
case IPv4version: case ipv4.Version:
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst) peer = device.routingTable.LookupIPv4(dst)
case IPv6version: case ipv6.Version:
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst) peer = device.routingTable.LookupIPv6(dst)
default: default:
device.log.Debug.Println("Receieved packet with unknown IP version") logDebug.Println("Receieved packet with unknown IP version")
} }
if peer == nil { if peer == nil {
continue continue
} }
if peer.endpoint == nil { if peer.endpoint == nil {
device.log.Debug.Println("No known endpoint for peer", peer.id) logDebug.Println("No known endpoint for peer", peer.String())
continue continue
} }
@ -184,7 +194,7 @@ func (peer *Peer) RoutineNonce() {
device := peer.device device := peer.device
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, nonce worker, started for peer", peer.id) logDebug.Println("Routine, nonce worker, started for peer", peer.String())
func() { func() {
@ -216,15 +226,15 @@ func (peer *Peer) RoutineNonce() {
} }
} }
signalSend(peer.signal.handshakeBegin) signalSend(peer.signal.handshakeBegin)
logDebug.Println("Waiting for key-pair, peer", peer.id) logDebug.Println("Awaiting key-pair for", peer.String())
select { select {
case <-peer.signal.newKeyPair: case <-peer.signal.newKeyPair:
logDebug.Println("Key-pair negotiated for peer", peer.id) logDebug.Println("Key-pair negotiated for", peer.String())
goto NextPacket goto NextPacket
case <-peer.signal.flushNonceQueue: case <-peer.signal.flushNonceQueue:
logDebug.Println("Clearing queue for peer", peer.id) logDebug.Println("Clearing queue for", peer.String())
peer.FlushNonceQueue() peer.FlushNonceQueue()
elem = nil elem = nil
goto NextPacket goto NextPacket
@ -313,13 +323,14 @@ func (peer *Peer) RoutineSequentialSender() {
device := peer.device device := peer.device
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, sequential sender, started for peer", peer.id) logDebug.Println("Routine, sequential sender, started for", peer.String())
for { for {
select { select {
case <-peer.signal.stop: case <-peer.signal.stop:
logDebug.Println("Routine, sequential sender, stopped for peer", peer.id) logDebug.Println("Routine, sequential sender, stopped for", peer.String())
return return
case work := <-peer.queue.outbound: case work := <-peer.queue.outbound:
work.mutex.Lock() work.mutex.Lock()
if work.IsDropped() { if work.IsDropped() {
@ -334,7 +345,7 @@ func (peer *Peer) RoutineSequentialSender() {
defer peer.mutex.RUnlock() defer peer.mutex.RUnlock()
if peer.endpoint == nil { if peer.endpoint == nil {
logDebug.Println("No endpoint for peer:", peer.id) logDebug.Println("No endpoint for", peer.String())
return return
} }
@ -352,7 +363,7 @@ func (peer *Peer) RoutineSequentialSender() {
} }
atomic.AddUint64(&peer.txBytes, uint64(len(work.packet))) atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
// reset keep-alive (passive keep-alives / acknowledgements) // reset keep-alive
peer.TimerResetKeepalive() peer.TimerResetKeepalive()
}() }()

View File

@ -50,7 +50,7 @@ func (peer *Peer) KeepKeyFreshReceiving() {
* - First transport message under the "next" key * - First transport message under the "next" key
*/ */
func (peer *Peer) EventHandshakeComplete() { func (peer *Peer) EventHandshakeComplete() {
peer.device.log.Debug.Println("Handshake completed") peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
signalSend(peer.signal.handshakeCompleted) signalSend(peer.signal.handshakeCompleted)
} }
@ -112,7 +112,7 @@ func (peer *Peer) TimerResetKeepalive() {
// stop acknowledgement timer // stop acknowledgement timer
timerStop(peer.timer.keepaliveAcknowledgement) timerStop(peer.timer.keepalivePassive)
atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse) atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse)
} }
@ -140,7 +140,7 @@ func (peer *Peer) RoutineTimerHandler() {
device := peer.device device := peer.device
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.id) logDebug.Println("Routine, timer handler, started for peer", peer.String())
for { for {
select { select {
@ -152,14 +152,14 @@ func (peer *Peer) RoutineTimerHandler() {
case <-peer.timer.keepalivePersistent.C: case <-peer.timer.keepalivePersistent.C:
logDebug.Println("Sending persistent keep-alive to peer", peer.id) logDebug.Println("Sending persistent keep-alive to", peer.String())
peer.SendKeepAlive() peer.SendKeepAlive()
peer.TimerResetKeepalive() peer.TimerResetKeepalive()
case <-peer.timer.keepaliveAcknowledgement.C: case <-peer.timer.keepalivePassive.C:
logDebug.Println("Sending passive persistent keep-alive to peer", peer.id) logDebug.Println("Sending passive persistent keep-alive to", peer.String())
peer.SendKeepAlive() peer.SendKeepAlive()
peer.TimerResetKeepalive() peer.TimerResetKeepalive()
@ -168,7 +168,7 @@ func (peer *Peer) RoutineTimerHandler() {
case <-peer.timer.zeroAllKeys.C: case <-peer.timer.zeroAllKeys.C:
logDebug.Println("Clearing all key material for peer", peer.id) logDebug.Println("Clearing all key material for", peer.String())
// zero out key pairs // zero out key pairs
@ -208,14 +208,12 @@ func (peer *Peer) RoutineHandshakeInitiator() {
var elem *QueueOutboundElement var elem *QueueOutboundElement
logInfo := device.log.Info
logError := device.log.Error logError := device.log.Error
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, handshake initator, started for peer", peer.id) logDebug.Println("Routine, handshake initator, started for", peer.String())
for run := true; run; { for {
var err error
var attempts uint
var deadline time.Time
// wait for signal // wait for signal
@ -227,15 +225,17 @@ func (peer *Peer) RoutineHandshakeInitiator() {
// wait for handshake // wait for handshake
run = func() bool { func() {
for { var err error
var deadline time.Time
for attempts := uint(1); ; attempts++ {
// clear completed signal // clear completed signal
select { select {
case <-peer.signal.handshakeCompleted: case <-peer.signal.handshakeCompleted:
case <-peer.signal.stop: case <-peer.signal.stop:
return false return
default: default:
} }
@ -246,43 +246,39 @@ func (peer *Peer) RoutineHandshakeInitiator() {
} }
elem, err = peer.BeginHandshakeInitiation() elem, err = peer.BeginHandshakeInitiation()
if err != nil { if err != nil {
logError.Println("Failed to create initiation message:", err) logError.Println("Failed to create initiation message", err, "for", peer.String())
break return
} }
// set timeout // set timeout
attempts += 1
if attempts == 1 { if attempts == 1 {
deadline = time.Now().Add(MaxHandshakeAttemptTime) deadline = time.Now().Add(MaxHandshakeAttemptTime)
} }
timeout := time.NewTimer(RekeyTimeout) timeout := time.NewTimer(RekeyTimeout)
logDebug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id) logDebug.Println("Handshake initiation attempt", attempts, "queued for", peer.String())
// wait for handshake or timeout // wait for handshake or timeout
select { select {
case <-peer.signal.stop: case <-peer.signal.stop:
return true return
case <-peer.signal.handshakeCompleted: case <-peer.signal.handshakeCompleted:
<-timeout.C <-timeout.C
return true return
case <-timeout.C: case <-timeout.C:
logDebug.Println("Timeout")
// check if sufficient time for retry
if deadline.Before(time.Now().Add(RekeyTimeout)) { if deadline.Before(time.Now().Add(RekeyTimeout)) {
logInfo.Println("Handshake negotiation timed out for", peer.String())
signalSend(peer.signal.flushNonceQueue) signalSend(peer.signal.flushNonceQueue)
timerStop(peer.timer.keepalivePersistent) timerStop(peer.timer.keepalivePersistent)
timerStop(peer.timer.keepaliveAcknowledgement) timerStop(peer.timer.keepalivePassive)
return true return
} }
} }
} }
return true
}() }()
signalClear(peer.signal.handshakeBegin) signalClear(peer.signal.handshakeBegin)

View File

@ -23,7 +23,8 @@ type Trie struct {
bits []byte bits []byte
peer *Peer peer *Peer
// Index of "branching" bit // index of "branching" bit
bit_at_byte uint bit_at_byte uint
bit_at_shift uint bit_at_shift uint
} }
@ -36,7 +37,7 @@ type Trie struct {
func commonBits(ip1 net.IP, ip2 net.IP) uint { func commonBits(ip1 net.IP, ip2 net.IP) uint {
var i uint var i uint
size := uint(len(ip1)) size := uint(len(ip1))
for i = 0; i < size; i += 1 { for i = 0; i < size; i++ {
v := ip1[i] ^ ip2[i] v := ip1[i] ^ ip2[i]
if v != 0 { if v != 0 {
v >>= 1 v >>= 1
@ -84,7 +85,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
return node return node
} }
// Walk recursivly // walk recursivly
node.child[0] = node.child[0].RemovePeer(p) node.child[0] = node.child[0].RemovePeer(p)
node.child[1] = node.child[1].RemovePeer(p) node.child[1] = node.child[1].RemovePeer(p)
@ -93,7 +94,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
return node return node
} }
// Remove peer & merge // remove peer & merge
node.peer = nil node.peer = nil
if node.child[0] == nil { if node.child[0] == nil {
@ -108,7 +109,7 @@ func (node *Trie) choose(ip net.IP) byte {
func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
// At leaf // at leaf
if node == nil { if node == nil {
return &Trie{ return &Trie{
@ -120,7 +121,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
} }
} }
// Traverse deeper // traverse deeper
common := commonBits(node.bits, ip) common := commonBits(node.bits, ip)
if node.cidr <= cidr && common >= node.cidr { if node.cidr <= cidr && common >= node.cidr {
@ -133,7 +134,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
return node return node
} }
// Split node // split node
newNode := &Trie{ newNode := &Trie{
bits: ip, bits: ip,
@ -145,7 +146,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
cidr = min(cidr, common) cidr = min(cidr, common)
// Check for shorter prefix // check for shorter prefix
if newNode.cidr == cidr { if newNode.cidr == cidr {
bit := newNode.choose(node.bits) bit := newNode.choose(node.bits)
@ -153,7 +154,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
return newNode return newNode
} }
// Create new parent for node & newNode // create new parent for node & newNode
parent := &Trie{ parent := &Trie{
bits: ip, bits: ip,