1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Completed initial version of outbound flow

This commit is contained in:
Mathias Hall-Andersen 2017-06-30 14:41:08 +02:00
parent 7e185db141
commit ba3e486667
17 changed files with 491 additions and 287 deletions

View File

@ -8,7 +8,6 @@ import (
"net"
"strconv"
"strings"
"time"
)
// #include <errno.h>
@ -51,9 +50,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
send("private_key=" + device.privateKey.ToHex())
}
if device.address != nil {
send(fmt.Sprintf("listen_port=%d", device.address.Port))
}
send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
for _, peer := range device.peers {
func() {
@ -106,7 +103,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
key := parts[0]
value := parts[1]
logger.Println("Key-value pair: (", key, ",", value, ")") // TODO: Remove, leaks private key to log
switch key {
@ -118,13 +114,13 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.privateKey = NoisePrivateKey{}
device.mutex.Unlock()
} else {
device.mutex.Lock()
err := device.privateKey.FromHex(value)
device.mutex.Unlock()
var sk NoisePrivateKey
err := sk.FromHex(value)
if err != nil {
logger.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
device.SetPrivateKey(sk)
}
case "listen_port":
@ -134,12 +130,10 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logger.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
device.mutex.Lock()
if device.address == nil {
device.address = &net.UDPAddr{}
}
device.address.Port = port
device.mutex.Unlock()
device.net.mutex.Lock()
device.net.addr.Port = port
device.net.conn, err = net.ListenUDP("udp", device.net.addr)
device.net.mutex.Unlock()
case "fwmark":
logger.Println("FWMark not handled yet")
@ -200,13 +194,13 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
case "endpoint":
ip := net.ParseIP(value)
if ip == nil {
addr, err := net.ResolveUDPAddr("udp", value)
if err != nil {
logger.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalidValue}
}
peer.mutex.Lock()
// peer.endpoint = ip FIX
peer.endpoint = addr
peer.mutex.Unlock()
case "persistent_keepalive_interval":
@ -216,7 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalidValue}
}
peer.mutex.Lock()
peer.persistentKeepaliveInterval = time.Duration(secs) * time.Second
peer.persistentKeepaliveInterval = uint64(secs)
peer.mutex.Unlock()
case "replace_allowed_ips":

View File

@ -5,15 +5,15 @@ import (
)
const (
RekeyAfterMessage = (1 << 64) - (1 << 16) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5 // TODO: Exponential backoff
RejectAfterTime = time.Second * 180
RejectAfterMessage = (1 << 64) - (1 << 4) - 1
KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 2
MaxHandshakeAttempTime = time.Second * 90
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5 // TODO: Exponential backoff
RejectAfterTime = time.Second * 180
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 2
MaxHandshakeAttemptTime = time.Second * 90
)
const (

View File

@ -7,16 +7,21 @@ import (
)
type Device struct {
mtu int
fwMark uint32
address *net.UDPAddr // UDP source address
conn *net.UDPConn // UDP "connection"
mtu int
log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
fwMark uint32
net struct {
// seperate for performance reasons
mutex sync.RWMutex
addr *net.UDPAddr // UDP source address
conn *net.UDPConn // UDP "connection"
}
mutex sync.RWMutex
privateKey NoisePrivateKey
publicKey NoisePublicKey
routingTable RoutingTable
indices IndexTable
log *Logger
queue struct {
encryption chan *QueueOutboundElement // parallel work queue
}
@ -44,17 +49,29 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
}
}
func NewDevice(tun TUNDevice) *Device {
func NewDevice(tun TUNDevice, logLevel int) *Device {
device := new(Device)
device.mutex.Lock()
defer device.mutex.Unlock()
device.log = NewLogger()
device.log = NewLogger(logLevel)
device.peers = make(map[NoisePublicKey]*Peer)
device.indices.Init()
device.routingTable.Reset()
// listen
device.net.mutex.Lock()
device.net.conn, _ = net.ListenUDP("udp", device.net.addr)
addr := device.net.conn.LocalAddr()
device.net.addr, _ = net.ResolveUDPAddr(addr.Network(), addr.String())
device.net.mutex.Unlock()
// create queues
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
// start workers
for i := 0; i < runtime.NumCPU(); i += 1 {
@ -92,5 +109,11 @@ func (device *Device) RemoveAllPeers() {
peer.mutex.Lock()
delete(device.peers, key)
peer.Close()
peer.mutex.Unlock()
}
}
func (device *Device) Close() {
device.RemoveAllPeers()
close(device.queue.encryption)
}

View File

@ -24,91 +24,163 @@ func (peer *Peer) SendKeepAlive() bool {
return true
}
func (peer *Peer) RoutineHandshakeInitiator() {
var ongoing bool
var begun time.Time
var attempts uint
var timeout time.Timer
device := peer.device
work := new(QueueOutboundElement)
buffer := make([]byte, 0, 1024)
queueHandshakeInitiation := func() error {
work.mutex.Lock()
defer work.mutex.Unlock()
// create initiation
msg, err := device.CreateMessageInitiation(peer)
if err != nil {
return err
}
// create "work" element
writer := bytes.NewBuffer(buffer[:0])
binary.Write(writer, binary.LittleEndian, &msg)
work.packet = writer.Bytes()
peer.mac.AddMacs(work.packet)
peer.InsertOutbound(work)
return nil
func StoppedTimer() *time.Timer {
timer := time.NewTimer(time.Hour)
if !timer.Stop() {
<-timer.C
}
return timer
}
for {
select {
case <-peer.signal.stopInitiator:
return
/* Called when a new authenticated message has been send
*
* TODO: This might be done in a faster way
*/
func (peer *Peer) KeepKeyFreshSending() {
send := func() bool {
peer.keyPairs.mutex.RLock()
defer peer.keyPairs.mutex.RUnlock()
case <-peer.signal.newHandshake:
if ongoing {
continue
}
// create handshake
err := queueHandshakeInitiation()
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
}
// log when we began
begun = time.Now()
ongoing = true
attempts = 0
timeout.Reset(RekeyTimeout)
case <-peer.timer.sendKeepalive.C:
// active keep-alives
peer.SendKeepAlive()
case <-peer.timer.handshakeTimeout.C:
// check if we can stop trying
if time.Now().Sub(begun) > MaxHandshakeAttempTime {
peer.signal.flushNonceQueue <- true
peer.timer.sendKeepalive.Stop()
ongoing = false
continue
}
// otherwise, try again (exponental backoff)
attempts += 1
err := queueHandshakeInitiation()
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
}
peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout)
kp := peer.keyPairs.current
if kp == nil {
return false
}
if !kp.isInitiator {
return false
}
nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessages {
return true
}
return time.Now().Sub(kp.created) > RekeyAfterTime
}()
if send {
sendSignal(peer.signal.handshakeBegin)
}
}
/* Handles packets related to handshake
/* This is the state machine for handshake initiation
*
* Associated with this routine is the signal "handshakeBegin"
* The routine will read from the "handshakeBegin" channel
* at most every RekeyTimeout or with exponential backoff
*
* Implements exponential backoff for retries
*/
func (peer *Peer) RoutineHandshakeInitiator() {
work := new(QueueOutboundElement)
device := peer.device
buffer := make([]byte, 1024)
logger := device.log.Debug
timeout := time.NewTimer(time.Hour)
logger.Println("Routine, handshake initator, started for peer", peer.id)
func() {
for {
var attempts uint
var deadline time.Time
select {
case <-peer.signal.handshakeBegin:
case <-peer.signal.stop:
return
}
HandshakeLoop:
for run := true; run; {
// clear completed signal
select {
case <-peer.signal.handshakeCompleted:
case <-peer.signal.stop:
return
default:
}
// queue handshake
err := func() error {
work.mutex.Lock()
defer work.mutex.Unlock()
// create initiation
msg, err := device.CreateMessageInitiation(peer)
if err != nil {
return err
}
// marshal
writer := bytes.NewBuffer(buffer[:0])
binary.Write(writer, binary.LittleEndian, msg)
work.packet = writer.Bytes()
peer.mac.AddMacs(work.packet)
peer.InsertOutbound(work)
return nil
}()
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
break
}
if attempts == 0 {
deadline = time.Now().Add(MaxHandshakeAttemptTime)
}
// set timeout
if !timeout.Stop() {
select {
case <-timeout.C:
default:
}
}
timeout.Reset((1 << attempts) * RekeyTimeout)
attempts += 1
device.log.Debug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id)
time.Sleep(RekeyTimeout)
// wait for handshake or timeout
select {
case <-peer.signal.stop:
return
case <-peer.signal.handshakeCompleted:
break HandshakeLoop
default:
select {
case <-peer.signal.stop:
return
case <-peer.signal.handshakeCompleted:
break HandshakeLoop
case <-timeout.C:
nextTimeout := (1 << attempts) * RekeyTimeout
if deadline.Before(time.Now().Add(nextTimeout)) {
// we do not have time for another attempt
peer.signal.flushNonceQueue <- struct{}{}
if !peer.timer.sendKeepalive.Stop() {
<-peer.timer.sendKeepalive.C
}
break HandshakeLoop
}
}
}
}
}
}()
logger.Println("Routine, handshake initator, stopped for peer", peer.id)
}
/* Handles incomming packets related to handshake
*
*
*/
@ -140,33 +212,12 @@ func (device *Device) HandshakeWorker(queue chan struct {
// check for cookie
case MessageCookieReplyType:
if len(elem.msg) != MessageCookieReplySize {
continue
}
case MessageTransportType:
default:
device.log.Error.Println("Invalid message type in handshake queue")
}
}
}
func (device *Device) KeepKeyFresh(peer *Peer) {
send := func() bool {
peer.keyPairs.mutex.RLock()
defer peer.keyPairs.mutex.RUnlock()
kp := peer.keyPairs.current
if kp == nil {
return false
}
nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessage {
return true
}
return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime
}()
if send {
}
}

View File

@ -35,7 +35,7 @@ func (tun *DummyTUN) Read(d []byte) (int, error) {
func CreateDummyTUN(name string) (TUNDevice, error) {
var dummy DummyTUN
dummy.mtu = 1024
dummy.mtu = 0
dummy.packets = make(chan []byte, 100)
return &dummy, nil
}
@ -58,7 +58,7 @@ func randDevice(t *testing.T) *Device {
t.Fatal(err)
}
tun, _ := CreateDummyTUN("dummy")
device := NewDevice(tun)
device := NewDevice(tun, LogLevelError)
device.SetPrivateKey(sk)
return device
}

View File

@ -41,7 +41,7 @@ func (table *IndexTable) Init() {
table.mutex.Unlock()
}
func (table *IndexTable) ClearIndex(index uint32) {
func (table *IndexTable) Delete(index uint32) {
if index == 0 {
return
}

View File

@ -13,20 +13,27 @@ type KeyPair struct {
sendNonce uint64
isInitiator bool
created time.Time
id uint32
}
type KeyPairs struct {
mutex sync.RWMutex
current *KeyPair
previous *KeyPair
next *KeyPair // not yet "confirmed by transport"
newKeyPair chan bool // signals when "current" has been updated
mutex sync.RWMutex
current *KeyPair
previous *KeyPair
next *KeyPair // not yet "confirmed by transport"
}
func (kp *KeyPairs) Init() {
kp.mutex.Lock()
kp.newKeyPair = make(chan bool, 5)
kp.mutex.Unlock()
/* Called during recieving to confirm the handshake
* was completed correctly
*/
func (kp *KeyPairs) Used(key *KeyPair) {
if key == kp.next {
kp.mutex.Lock()
kp.previous = kp.current
kp.current = key
kp.next = nil
kp.mutex.Unlock()
}
}
func (kp *KeyPairs) Current() *KeyPair {

View File

@ -1,6 +1,8 @@
package main
import (
"io"
"io/ioutil"
"log"
"os"
)
@ -17,17 +19,30 @@ type Logger struct {
Error *log.Logger
}
func NewLogger() *Logger {
func NewLogger(level int) *Logger {
output := os.Stdout
logger := new(Logger)
logger.Debug = log.New(os.Stdout,
logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
if level >= LogLevelDebug {
return output, output, output
}
if level >= LogLevelInfo {
return output, output, ioutil.Discard
}
return output, ioutil.Discard, ioutil.Discard
}()
logger.Debug = log.New(logDebug,
"DEBUG: ",
log.Ldate|log.Ltime|log.Lshortfile,
)
logger.Info = log.New(os.Stdout,
logger.Info = log.New(logInfo,
"INFO: ",
log.Ldate|log.Ltime|log.Lshortfile,
)
logger.Error = log.New(os.Stdout,
logger.Error = log.New(logErr,
"ERROR: ",
log.Ldate|log.Ltime|log.Lshortfile,
)

View File

@ -11,6 +11,9 @@ func TestMAC1(t *testing.T) {
dev1 := randDevice(t)
dev2 := randDevice(t)
defer dev1.Close()
defer dev2.Close()
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
@ -40,6 +43,9 @@ func TestMACs(t *testing.T) {
device2 := randDevice(t)
device2.SetPrivateKey(sk2)
defer device1.Close()
defer device2.Close()
peer1 := device2.NewPeer(device1.privateKey.publicKey())
peer2 := device1.NewPeer(device2.privateKey.publicKey())

View File

@ -28,7 +28,7 @@ func main() {
return
}
device := NewDevice(tun)
device := NewDevice(tun, LogLevelDebug)
// Start configuration lister

View File

@ -6,3 +6,10 @@ func min(a uint, b uint) uint {
}
return a
}
func sendSignal(c chan struct{}) {
select {
case c <- struct{}{}:
default:
}
}

View File

@ -33,6 +33,7 @@ func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
HMAC(&prk, key, input)
HMAC(&t0, prk[:], []byte{0x1})
HMAC(&t1, prk[:], append(t0[:], 0x2))
prk = [blake2s.Size]byte{}
return
}
@ -42,6 +43,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
HMAC(&t0, prk[:], []byte{0x1})
HMAC(&t1, prk[:], append(t0[:], 0x2))
HMAC(&t2, prk[:], append(t1[:], 0x3))
prk = [blake2s.Size]byte{}
return
}

View File

@ -31,8 +31,9 @@ const (
)
const (
MessageInitiationSize = 148
MessageResponseSize = 92
MessageInitiationSize = 148
MessageResponseSize = 92
MessageCookieReplySize = 64
)
/* Type is an 8-bit field, followed by 3 nul bytes,
@ -91,16 +92,11 @@ type Handshake struct {
}
var (
InitalChainKey [blake2s.Size]byte
InitalHash [blake2s.Size]byte
ZeroNonce [chacha20poly1305.NonceSize]byte
InitialChainKey [blake2s.Size]byte
InitialHash [blake2s.Size]byte
ZeroNonce [chacha20poly1305.NonceSize]byte
)
func init() {
InitalChainKey = blake2s.Sum256([]byte(NoiseConstruction))
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
}
func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return KDF1(c[:], data)
}
@ -117,6 +113,13 @@ func (h *Handshake) mixKey(data []byte) {
h.chainKey = mixKey(h.chainKey, data)
}
/* Do basic precomputations
*/
func init() {
InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
InitialHash = mixHash(InitialChainKey, []byte(WGIdentifier))
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
@ -125,28 +128,30 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// create ephemeral key
var err error
handshake.chainKey = InitalChainKey
handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:])
handshake.hash = InitialHash
handshake.chainKey = InitialChainKey
handshake.localEphemeral, err = newPrivateKey()
if err != nil {
return nil, err
}
device.indices.ClearIndex(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer)
// assign index
var msg MessageInitiation
msg.Type = MessageInitiationType
msg.Ephemeral = handshake.localEphemeral.publicKey()
device.indices.Delete(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil {
return nil, err
}
msg.Sender = handshake.localIndex
handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
Sender: handshake.localIndex,
}
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@ -185,9 +190,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
return nil
}
hash := mixHash(InitalHash, device.publicKey[:])
hash := mixHash(InitialHash, device.publicKey[:])
hash = mixHash(hash, msg.Ephemeral[:])
chainKey := mixKey(InitalChainKey, msg.Ephemeral[:])
chainKey := mixKey(InitialChainKey, msg.Ephemeral[:])
// decrypt static key
@ -278,7 +283,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
// assign index
var err error
device.indices.ClearIndex(handshake.localIndex)
device.indices.Delete(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil {
return nil, err
@ -420,10 +425,15 @@ func (peer *Peer) NewKeyPair() *KeyPair {
return nil
}
// zero handshake
handshake.chainKey = [blake2s.Size]byte{}
handshake.localEphemeral = NoisePrivateKey{}
peer.handshake.state = HandshakeZeroed
// create AEAD instances
var keyPair KeyPair
keyPair := new(KeyPair)
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0
@ -433,30 +443,32 @@ func (peer *Peer) NewKeyPair() *KeyPair {
peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
peer: peer,
keyPair: &keyPair,
keyPair: keyPair,
handshake: nil,
})
handshake.localIndex = 0
// start timer for keypair
// rotate key pairs
kp := &peer.keyPairs
func() {
kp := &peer.keyPairs
kp.mutex.Lock()
defer kp.mutex.Unlock()
if isInitiator {
kp.previous = peer.keyPairs.current
kp.current = &keyPair
kp.newKeyPair <- true
if kp.previous != nil {
kp.previous.send = nil
kp.previous.recv = nil
peer.device.indices.Delete(kp.previous.id)
}
kp.previous = kp.current
kp.current = keyPair
sendSignal(peer.signal.newKeyPair)
} else {
kp.next = &keyPair
kp.next = keyPair
}
}()
// zero handshake
handshake.chainKey = [blake2s.Size]byte{}
handshake.localEphemeral = NoisePrivateKey{}
peer.handshake.state = HandshakeZeroed
return &keyPair
return keyPair
}

View File

@ -25,10 +25,12 @@ func TestCurveWrappers(t *testing.T) {
}
func TestNoiseHandshake(t *testing.T) {
dev1 := randDevice(t)
dev2 := randDevice(t)
defer dev1.Close()
defer dev2.Close()
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())

View File

@ -10,26 +10,29 @@ import (
const ()
type Peer struct {
id uint
mutex sync.RWMutex
endpoint *net.UDPAddr
persistentKeepaliveInterval time.Duration // 0 = disabled
persistentKeepaliveInterval uint64
keyPairs KeyPairs
handshake Handshake
device *Device
tx_bytes uint64
rx_bytes uint64
time struct {
lastSend time.Time // last send message
lastSend time.Time // last send message
lastHandshake time.Time // last completed handshake
}
signal struct {
newHandshake chan bool
flushNonceQueue chan bool // empty queued packets
stopSending chan bool // stop sending pipeline
stopInitiator chan bool // stop initiator timer
newKeyPair chan struct{} // (size 1) : a new key pair was generated
handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake")
handshakeCompleted chan struct{} // (size 1) : handshake completed
flushNonceQueue chan struct{} // (size 1) : empty queued packets
stop chan struct{} // (size 0) : close to stop all goroutines for peer
}
timer struct {
sendKeepalive time.Timer
handshakeTimeout time.Timer
sendKeepalive *time.Timer
handshakeTimeout *time.Timer
}
queue struct {
nonce chan []byte // nonce / pre-handshake queue
@ -39,25 +42,30 @@ type Peer struct {
}
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
var peer Peer
// create peer
peer := new(Peer)
peer.mutex.Lock()
defer peer.mutex.Unlock()
peer.device = device
peer.keyPairs.Init()
peer.mac.Init(pk)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.nonce = make(chan []byte, QueueOutboundSize)
peer.timer.sendKeepalive = StoppedTimer()
// assign id for debugging
device.mutex.Lock()
peer.id = device.idCounter
device.idCounter += 1
// map public key
device.mutex.Lock()
_, ok := device.peers[pk]
if ok {
panic(errors.New("bug: adding existing peer"))
}
device.peers[pk] = &peer
device.peers[pk] = peer
device.mutex.Unlock()
// precompute DH
@ -67,22 +75,24 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
handshake.remoteStatic = pk
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock()
peer.mutex.Unlock()
// start workers
// prepare signaling
peer.signal.stopSending = make(chan bool, 1)
peer.signal.stopInitiator = make(chan bool, 1)
peer.signal.newHandshake = make(chan bool, 1)
peer.signal.flushNonceQueue = make(chan bool, 1)
peer.signal.stop = make(chan struct{})
peer.signal.newKeyPair = make(chan struct{}, 1)
peer.signal.handshakeBegin = make(chan struct{}, 1)
peer.signal.handshakeCompleted = make(chan struct{}, 1)
peer.signal.flushNonceQueue = make(chan struct{}, 1)
// outbound pipeline
go peer.RoutineNonce()
go peer.RoutineHandshakeInitiator()
go peer.RoutineSequentialSender()
return &peer
return peer
}
func (peer *Peer) Close() {
peer.signal.stopSending <- true
peer.signal.stopInitiator <- true
close(peer.signal.stop)
}

View File

@ -5,6 +5,8 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"net"
"sync"
"sync/atomic"
"time"
)
/* Handles outbound flow
@ -29,6 +31,7 @@ type QueueOutboundElement struct {
packet []byte
nonce uint64
keyPair *KeyPair
peer *Peer
}
func (peer *Peer) FlushNonceQueue() {
@ -46,6 +49,7 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
for {
select {
case peer.queue.outbound <- elem:
return
default:
select {
case <-peer.queue.outbound:
@ -61,11 +65,15 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
if tun.MTU() == 0 {
// Dummy
return
}
device.log.Debug.Println("Routine, TUN Reader: started")
for {
// read packet
device.log.Debug.Println("Read")
packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
size, err := tun.Read(packet)
if err != nil {
@ -94,13 +102,16 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
default:
device.log.Debug.Println("Receieved packet with unknown IP version")
return
}
if peer == nil {
device.log.Debug.Println("No peer configured for IP")
continue
}
if peer.endpoint == nil {
device.log.Debug.Println("No known endpoint for peer", peer.id)
continue
}
// insert into nonce/pre-handshake queue
@ -131,69 +142,95 @@ func (peer *Peer) RoutineNonce() {
var packet []byte
var keyPair *KeyPair
for {
device := peer.device
logger := device.log.Debug
// wait for packet
logger.Println("Routine, nonce worker, started for peer", peer.id)
if packet == nil {
select {
case packet = <-peer.queue.nonce:
case <-peer.signal.stopSending:
close(peer.queue.outbound)
return
func() {
for {
NextPacket:
// wait for packet
if packet == nil {
select {
case packet = <-peer.queue.nonce:
case <-peer.signal.stop:
return
}
}
}
// wait for key pair
// wait for key pair
for {
select {
case <-peer.signal.newKeyPair:
default:
}
for keyPair == nil {
peer.signal.newHandshake <- true
select {
case <-peer.keyPairs.newKeyPair:
keyPair = peer.keyPairs.Current()
continue
case <-peer.signal.flushNonceQueue:
peer.FlushNonceQueue()
packet = nil
continue
case <-peer.signal.stopSending:
close(peer.queue.outbound)
return
}
}
// process current packet
if packet != nil {
// create work element
work := new(QueueOutboundElement) // TODO: profile, maybe use pool
work.keyPair = keyPair
work.packet = packet
work.nonce = keyPair.sendNonce
work.mutex.Lock()
packet = nil
keyPair.sendNonce += 1
// drop packets until there is space
func() {
for {
select {
case peer.device.queue.encryption <- work:
return
default:
drop := <-peer.device.queue.encryption
drop.packet = nil
drop.mutex.Unlock()
if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
if time.Now().Sub(keyPair.created) < RejectAfterTime {
break
}
}
}()
peer.queue.outbound <- work
sendSignal(peer.signal.handshakeBegin)
logger.Println("Waiting for key-pair, peer", peer.id)
select {
case <-peer.signal.newKeyPair:
logger.Println("Key-pair negotiated for peer", peer.id)
goto NextPacket
case <-peer.signal.flushNonceQueue:
logger.Println("Clearing queue for peer", peer.id)
peer.FlushNonceQueue()
packet = nil
goto NextPacket
case <-peer.signal.stop:
return
}
}
// process current packet
if packet != nil {
// create work element
work := new(QueueOutboundElement) // TODO: profile, maybe use pool
work.keyPair = keyPair
work.packet = packet
work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1)
work.peer = peer
work.mutex.Lock()
packet = nil
// drop packets until there is space
func() {
for {
select {
case peer.device.queue.encryption <- work:
return
default:
drop := <-peer.device.queue.encryption
drop.packet = nil
drop.mutex.Unlock()
}
}
}()
peer.queue.outbound <- work
}
}
}
}()
logger.Println("Routine, nonce worker, stopped for peer", peer.id)
}
/* Encrypts the elements in the queue
@ -227,6 +264,10 @@ func (device *Device) RoutineEncryption() {
nil,
)
work.mutex.Unlock()
// initiate new handshake
work.peer.KeepKeyFreshSending()
}
}
@ -235,21 +276,54 @@ func (device *Device) RoutineEncryption() {
* Obs. Single instance per peer.
* The routine terminates then the outbound queue is closed.
*/
func (peer *Peer) RoutineSequential() {
for work := range peer.queue.outbound {
work.mutex.Lock()
func() {
peer.mutex.RLock()
defer peer.mutex.RUnlock()
if work.packet == nil {
return
}
if peer.endpoint == nil {
return
}
peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval)
}()
work.mutex.Unlock()
func (peer *Peer) RoutineSequentialSender() {
logger := peer.device.log.Debug
logger.Println("Routine, sequential sender, started for peer", peer.id)
device := peer.device
for {
select {
case <-peer.signal.stop:
logger.Println("Routine, sequential sender, stopped for peer", peer.id)
return
case work := <-peer.queue.outbound:
work.mutex.Lock()
func() {
if work.packet == nil {
return
}
peer.mutex.RLock()
defer peer.mutex.RUnlock()
if peer.endpoint == nil {
logger.Println("No endpoint for peer:", peer.id)
return
}
device.net.mutex.RLock()
defer device.net.mutex.RUnlock()
if device.net.conn == nil {
logger.Println("No source for device")
return
}
logger.Println("Sending packet for peer", peer.id, work.packet)
_, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint)
logger.Println("SEND:", peer.endpoint, err)
atomic.AddUint64(&peer.tx_bytes, uint64(len(work.packet)))
// shift keep-alive timer
if peer.persistentKeepaliveInterval != 0 {
interval := time.Duration(peer.persistentKeepaliveInterval) * time.Second
peer.timer.sendKeepalive.Reset(interval)
}
}()
work.mutex.Unlock()
}
}
}

View File

@ -74,5 +74,6 @@ func CreateTUN(name string) (TUNDevice, error) {
return &NativeTun{
fd: fd,
name: newName,
mtu: 1500, // TODO: FIX
}, nil
}