From ba3e486667987f16290ac85dc35b53cb9702d662 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 30 Jun 2017 14:41:08 +0200 Subject: [PATCH] Completed initial version of outbound flow --- src/config.go | 30 ++--- src/constants.go | 18 +-- src/device.go | 37 ++++-- src/handshake.go | 259 +++++++++++++++++++++++++----------------- src/helper_test.go | 4 +- src/index.go | 2 +- src/keypair.go | 25 ++-- src/logger.go | 23 +++- src/macs_test.go | 6 + src/main.go | 2 +- src/misc.go | 7 ++ src/noise_helpers.go | 2 + src/noise_protocol.go | 86 ++++++++------ src/noise_test.go | 4 +- src/peer.go | 54 +++++---- src/send.go | 218 +++++++++++++++++++++++------------ src/tun_linux.go | 1 + 17 files changed, 491 insertions(+), 287 deletions(-) diff --git a/src/config.go b/src/config.go index 2f8dc76..8281581 100644 --- a/src/config.go +++ b/src/config.go @@ -8,7 +8,6 @@ import ( "net" "strconv" "strings" - "time" ) // #include @@ -51,9 +50,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { send("private_key=" + device.privateKey.ToHex()) } - if device.address != nil { - send(fmt.Sprintf("listen_port=%d", device.address.Port)) - } + send(fmt.Sprintf("listen_port=%d", device.net.addr.Port)) for _, peer := range device.peers { func() { @@ -106,7 +103,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } key := parts[0] value := parts[1] - logger.Println("Key-value pair: (", key, ",", value, ")") // TODO: Remove, leaks private key to log switch key { @@ -118,13 +114,13 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { device.privateKey = NoisePrivateKey{} device.mutex.Unlock() } else { - device.mutex.Lock() - err := device.privateKey.FromHex(value) - device.mutex.Unlock() + var sk NoisePrivateKey + err := sk.FromHex(value) if err != nil { logger.Println("Failed to set private_key:", err) return &IPCError{Code: ipcErrorInvalidValue} } + device.SetPrivateKey(sk) } case "listen_port": @@ -134,12 +130,10 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logger.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorInvalidValue} } - device.mutex.Lock() - if device.address == nil { - device.address = &net.UDPAddr{} - } - device.address.Port = port - device.mutex.Unlock() + device.net.mutex.Lock() + device.net.addr.Port = port + device.net.conn, err = net.ListenUDP("udp", device.net.addr) + device.net.mutex.Unlock() case "fwmark": logger.Println("FWMark not handled yet") @@ -200,13 +194,13 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } case "endpoint": - ip := net.ParseIP(value) - if ip == nil { + addr, err := net.ResolveUDPAddr("udp", value) + if err != nil { logger.Println("Failed to set endpoint:", value) return &IPCError{Code: ipcErrorInvalidValue} } peer.mutex.Lock() - // peer.endpoint = ip FIX + peer.endpoint = addr peer.mutex.Unlock() case "persistent_keepalive_interval": @@ -216,7 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorInvalidValue} } peer.mutex.Lock() - peer.persistentKeepaliveInterval = time.Duration(secs) * time.Second + peer.persistentKeepaliveInterval = uint64(secs) peer.mutex.Unlock() case "replace_allowed_ips": diff --git a/src/constants.go b/src/constants.go index e8cdd63..34217d2 100644 --- a/src/constants.go +++ b/src/constants.go @@ -5,15 +5,15 @@ import ( ) const ( - RekeyAfterMessage = (1 << 64) - (1 << 16) - 1 - RekeyAfterTime = time.Second * 120 - RekeyAttemptTime = time.Second * 90 - RekeyTimeout = time.Second * 5 // TODO: Exponential backoff - RejectAfterTime = time.Second * 180 - RejectAfterMessage = (1 << 64) - (1 << 4) - 1 - KeepaliveTimeout = time.Second * 10 - CookieRefreshTime = time.Second * 2 - MaxHandshakeAttempTime = time.Second * 90 + RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 + RekeyAfterTime = time.Second * 120 + RekeyAttemptTime = time.Second * 90 + RekeyTimeout = time.Second * 5 // TODO: Exponential backoff + RejectAfterTime = time.Second * 180 + RejectAfterMessages = (1 << 64) - (1 << 4) - 1 + KeepaliveTimeout = time.Second * 10 + CookieRefreshTime = time.Second * 2 + MaxHandshakeAttemptTime = time.Second * 90 ) const ( diff --git a/src/device.go b/src/device.go index 52ac6a4..a33e923 100644 --- a/src/device.go +++ b/src/device.go @@ -7,16 +7,21 @@ import ( ) type Device struct { - mtu int - fwMark uint32 - address *net.UDPAddr // UDP source address - conn *net.UDPConn // UDP "connection" + mtu int + log *Logger // collection of loggers for levels + idCounter uint // for assigning debug ids to peers + fwMark uint32 + net struct { + // seperate for performance reasons + mutex sync.RWMutex + addr *net.UDPAddr // UDP source address + conn *net.UDPConn // UDP "connection" + } mutex sync.RWMutex privateKey NoisePrivateKey publicKey NoisePublicKey routingTable RoutingTable indices IndexTable - log *Logger queue struct { encryption chan *QueueOutboundElement // parallel work queue } @@ -44,17 +49,29 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) { } } -func NewDevice(tun TUNDevice) *Device { +func NewDevice(tun TUNDevice, logLevel int) *Device { device := new(Device) device.mutex.Lock() defer device.mutex.Unlock() - device.log = NewLogger() + device.log = NewLogger(logLevel) device.peers = make(map[NoisePublicKey]*Peer) device.indices.Init() device.routingTable.Reset() + // listen + + device.net.mutex.Lock() + device.net.conn, _ = net.ListenUDP("udp", device.net.addr) + addr := device.net.conn.LocalAddr() + device.net.addr, _ = net.ResolveUDPAddr(addr.Network(), addr.String()) + device.net.mutex.Unlock() + + // create queues + + device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) + // start workers for i := 0; i < runtime.NumCPU(); i += 1 { @@ -92,5 +109,11 @@ func (device *Device) RemoveAllPeers() { peer.mutex.Lock() delete(device.peers, key) peer.Close() + peer.mutex.Unlock() } } + +func (device *Device) Close() { + device.RemoveAllPeers() + close(device.queue.encryption) +} diff --git a/src/handshake.go b/src/handshake.go index 238c339..8f8e2f9 100644 --- a/src/handshake.go +++ b/src/handshake.go @@ -24,91 +24,163 @@ func (peer *Peer) SendKeepAlive() bool { return true } -func (peer *Peer) RoutineHandshakeInitiator() { - var ongoing bool - var begun time.Time - var attempts uint - var timeout time.Timer - - device := peer.device - work := new(QueueOutboundElement) - buffer := make([]byte, 0, 1024) - - queueHandshakeInitiation := func() error { - work.mutex.Lock() - defer work.mutex.Unlock() - - // create initiation - - msg, err := device.CreateMessageInitiation(peer) - if err != nil { - return err - } - - // create "work" element - - writer := bytes.NewBuffer(buffer[:0]) - binary.Write(writer, binary.LittleEndian, &msg) - work.packet = writer.Bytes() - peer.mac.AddMacs(work.packet) - peer.InsertOutbound(work) - return nil +func StoppedTimer() *time.Timer { + timer := time.NewTimer(time.Hour) + if !timer.Stop() { + <-timer.C } + return timer +} - for { - select { - case <-peer.signal.stopInitiator: - return +/* Called when a new authenticated message has been send + * + * TODO: This might be done in a faster way + */ +func (peer *Peer) KeepKeyFreshSending() { + send := func() bool { + peer.keyPairs.mutex.RLock() + defer peer.keyPairs.mutex.RUnlock() - case <-peer.signal.newHandshake: - if ongoing { - continue - } - - // create handshake - - err := queueHandshakeInitiation() - if err != nil { - device.log.Error.Println("Failed to create initiation message:", err) - } - - // log when we began - - begun = time.Now() - ongoing = true - attempts = 0 - timeout.Reset(RekeyTimeout) - - case <-peer.timer.sendKeepalive.C: - - // active keep-alives - - peer.SendKeepAlive() - - case <-peer.timer.handshakeTimeout.C: - - // check if we can stop trying - - if time.Now().Sub(begun) > MaxHandshakeAttempTime { - peer.signal.flushNonceQueue <- true - peer.timer.sendKeepalive.Stop() - ongoing = false - continue - } - - // otherwise, try again (exponental backoff) - - attempts += 1 - err := queueHandshakeInitiation() - if err != nil { - device.log.Error.Println("Failed to create initiation message:", err) - } - peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout) + kp := peer.keyPairs.current + if kp == nil { + return false } + + if !kp.isInitiator { + return false + } + + nonce := atomic.LoadUint64(&kp.sendNonce) + if nonce > RekeyAfterMessages { + return true + } + return time.Now().Sub(kp.created) > RekeyAfterTime + }() + if send { + sendSignal(peer.signal.handshakeBegin) } } -/* Handles packets related to handshake +/* This is the state machine for handshake initiation + * + * Associated with this routine is the signal "handshakeBegin" + * The routine will read from the "handshakeBegin" channel + * at most every RekeyTimeout or with exponential backoff + * + * Implements exponential backoff for retries + */ +func (peer *Peer) RoutineHandshakeInitiator() { + work := new(QueueOutboundElement) + device := peer.device + buffer := make([]byte, 1024) + logger := device.log.Debug + timeout := time.NewTimer(time.Hour) + + logger.Println("Routine, handshake initator, started for peer", peer.id) + + func() { + for { + var attempts uint + var deadline time.Time + + select { + case <-peer.signal.handshakeBegin: + case <-peer.signal.stop: + return + } + + HandshakeLoop: + for run := true; run; { + // clear completed signal + + select { + case <-peer.signal.handshakeCompleted: + case <-peer.signal.stop: + return + default: + } + + // queue handshake + + err := func() error { + work.mutex.Lock() + defer work.mutex.Unlock() + + // create initiation + + msg, err := device.CreateMessageInitiation(peer) + if err != nil { + return err + } + + // marshal + + writer := bytes.NewBuffer(buffer[:0]) + binary.Write(writer, binary.LittleEndian, msg) + work.packet = writer.Bytes() + peer.mac.AddMacs(work.packet) + peer.InsertOutbound(work) + return nil + }() + if err != nil { + device.log.Error.Println("Failed to create initiation message:", err) + break + } + if attempts == 0 { + deadline = time.Now().Add(MaxHandshakeAttemptTime) + } + + // set timeout + + if !timeout.Stop() { + select { + case <-timeout.C: + default: + } + } + timeout.Reset((1 << attempts) * RekeyTimeout) + attempts += 1 + device.log.Debug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id) + time.Sleep(RekeyTimeout) + + // wait for handshake or timeout + + select { + case <-peer.signal.stop: + return + + case <-peer.signal.handshakeCompleted: + break HandshakeLoop + + default: + select { + + case <-peer.signal.stop: + return + + case <-peer.signal.handshakeCompleted: + break HandshakeLoop + + case <-timeout.C: + nextTimeout := (1 << attempts) * RekeyTimeout + if deadline.Before(time.Now().Add(nextTimeout)) { + // we do not have time for another attempt + peer.signal.flushNonceQueue <- struct{}{} + if !peer.timer.sendKeepalive.Stop() { + <-peer.timer.sendKeepalive.C + } + break HandshakeLoop + } + } + } + } + } + }() + + logger.Println("Routine, handshake initator, stopped for peer", peer.id) +} + +/* Handles incomming packets related to handshake * * */ @@ -140,33 +212,12 @@ func (device *Device) HandshakeWorker(queue chan struct { // check for cookie case MessageCookieReplyType: + if len(elem.msg) != MessageCookieReplySize { + continue + } - case MessageTransportType: + default: + device.log.Error.Println("Invalid message type in handshake queue") } - - } -} - -func (device *Device) KeepKeyFresh(peer *Peer) { - - send := func() bool { - peer.keyPairs.mutex.RLock() - defer peer.keyPairs.mutex.RUnlock() - - kp := peer.keyPairs.current - if kp == nil { - return false - } - - nonce := atomic.LoadUint64(&kp.sendNonce) - if nonce > RekeyAfterMessage { - return true - } - - return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime - }() - - if send { - } } diff --git a/src/helper_test.go b/src/helper_test.go index 3a5c331..464292f 100644 --- a/src/helper_test.go +++ b/src/helper_test.go @@ -35,7 +35,7 @@ func (tun *DummyTUN) Read(d []byte) (int, error) { func CreateDummyTUN(name string) (TUNDevice, error) { var dummy DummyTUN - dummy.mtu = 1024 + dummy.mtu = 0 dummy.packets = make(chan []byte, 100) return &dummy, nil } @@ -58,7 +58,7 @@ func randDevice(t *testing.T) *Device { t.Fatal(err) } tun, _ := CreateDummyTUN("dummy") - device := NewDevice(tun) + device := NewDevice(tun, LogLevelError) device.SetPrivateKey(sk) return device } diff --git a/src/index.go b/src/index.go index 9178510..59e2079 100644 --- a/src/index.go +++ b/src/index.go @@ -41,7 +41,7 @@ func (table *IndexTable) Init() { table.mutex.Unlock() } -func (table *IndexTable) ClearIndex(index uint32) { +func (table *IndexTable) Delete(index uint32) { if index == 0 { return } diff --git a/src/keypair.go b/src/keypair.go index 0b029ce..0e845f7 100644 --- a/src/keypair.go +++ b/src/keypair.go @@ -13,20 +13,27 @@ type KeyPair struct { sendNonce uint64 isInitiator bool created time.Time + id uint32 } type KeyPairs struct { - mutex sync.RWMutex - current *KeyPair - previous *KeyPair - next *KeyPair // not yet "confirmed by transport" - newKeyPair chan bool // signals when "current" has been updated + mutex sync.RWMutex + current *KeyPair + previous *KeyPair + next *KeyPair // not yet "confirmed by transport" } -func (kp *KeyPairs) Init() { - kp.mutex.Lock() - kp.newKeyPair = make(chan bool, 5) - kp.mutex.Unlock() +/* Called during recieving to confirm the handshake + * was completed correctly + */ +func (kp *KeyPairs) Used(key *KeyPair) { + if key == kp.next { + kp.mutex.Lock() + kp.previous = kp.current + kp.current = key + kp.next = nil + kp.mutex.Unlock() + } } func (kp *KeyPairs) Current() *KeyPair { diff --git a/src/logger.go b/src/logger.go index 117fe5b..827f9e9 100644 --- a/src/logger.go +++ b/src/logger.go @@ -1,6 +1,8 @@ package main import ( + "io" + "io/ioutil" "log" "os" ) @@ -17,17 +19,30 @@ type Logger struct { Error *log.Logger } -func NewLogger() *Logger { +func NewLogger(level int) *Logger { + output := os.Stdout logger := new(Logger) - logger.Debug = log.New(os.Stdout, + + logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) { + if level >= LogLevelDebug { + return output, output, output + } + if level >= LogLevelInfo { + return output, output, ioutil.Discard + } + return output, ioutil.Discard, ioutil.Discard + }() + + logger.Debug = log.New(logDebug, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile, ) - logger.Info = log.New(os.Stdout, + + logger.Info = log.New(logInfo, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile, ) - logger.Error = log.New(os.Stdout, + logger.Error = log.New(logErr, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile, ) diff --git a/src/macs_test.go b/src/macs_test.go index fcb64ea..a2a6503 100644 --- a/src/macs_test.go +++ b/src/macs_test.go @@ -11,6 +11,9 @@ func TestMAC1(t *testing.T) { dev1 := randDevice(t) dev2 := randDevice(t) + defer dev1.Close() + defer dev2.Close() + peer1 := dev2.NewPeer(dev1.privateKey.publicKey()) peer2 := dev1.NewPeer(dev2.privateKey.publicKey()) @@ -40,6 +43,9 @@ func TestMACs(t *testing.T) { device2 := randDevice(t) device2.SetPrivateKey(sk2) + defer device1.Close() + defer device2.Close() + peer1 := device2.NewPeer(device1.privateKey.publicKey()) peer2 := device1.NewPeer(device2.privateKey.publicKey()) diff --git a/src/main.go b/src/main.go index 9c76ff4..b89af17 100644 --- a/src/main.go +++ b/src/main.go @@ -28,7 +28,7 @@ func main() { return } - device := NewDevice(tun) + device := NewDevice(tun, LogLevelDebug) // Start configuration lister diff --git a/src/misc.go b/src/misc.go index e1244d6..2bcb148 100644 --- a/src/misc.go +++ b/src/misc.go @@ -6,3 +6,10 @@ func min(a uint, b uint) uint { } return a } + +func sendSignal(c chan struct{}) { + select { + case c <- struct{}{}: + default: + } +} diff --git a/src/noise_helpers.go b/src/noise_helpers.go index e163ace..1e622a5 100644 --- a/src/noise_helpers.go +++ b/src/noise_helpers.go @@ -33,6 +33,7 @@ func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt HMAC(&prk, key, input) HMAC(&t0, prk[:], []byte{0x1}) HMAC(&t1, prk[:], append(t0[:], 0x2)) + prk = [blake2s.Size]byte{} return } @@ -42,6 +43,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt HMAC(&t0, prk[:], []byte{0x1}) HMAC(&t1, prk[:], append(t0[:], 0x2)) HMAC(&t2, prk[:], append(t1[:], 0x3)) + prk = [blake2s.Size]byte{} return } diff --git a/src/noise_protocol.go b/src/noise_protocol.go index 46ceeda..a1a1c7b 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -31,8 +31,9 @@ const ( ) const ( - MessageInitiationSize = 148 - MessageResponseSize = 92 + MessageInitiationSize = 148 + MessageResponseSize = 92 + MessageCookieReplySize = 64 ) /* Type is an 8-bit field, followed by 3 nul bytes, @@ -91,16 +92,11 @@ type Handshake struct { } var ( - InitalChainKey [blake2s.Size]byte - InitalHash [blake2s.Size]byte - ZeroNonce [chacha20poly1305.NonceSize]byte + InitialChainKey [blake2s.Size]byte + InitialHash [blake2s.Size]byte + ZeroNonce [chacha20poly1305.NonceSize]byte ) -func init() { - InitalChainKey = blake2s.Sum256([]byte(NoiseConstruction)) - InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...)) -} - func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { return KDF1(c[:], data) } @@ -117,6 +113,13 @@ func (h *Handshake) mixKey(data []byte) { h.chainKey = mixKey(h.chainKey, data) } +/* Do basic precomputations + */ +func init() { + InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) + InitialHash = mixHash(InitialChainKey, []byte(WGIdentifier)) +} + func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { handshake := &peer.handshake handshake.mutex.Lock() @@ -125,28 +128,30 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e // create ephemeral key var err error - handshake.chainKey = InitalChainKey - handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:]) + handshake.hash = InitialHash + handshake.chainKey = InitialChainKey handshake.localEphemeral, err = newPrivateKey() if err != nil { return nil, err } - device.indices.ClearIndex(handshake.localIndex) - handshake.localIndex, err = device.indices.NewIndex(peer) - // assign index - var msg MessageInitiation - - msg.Type = MessageInitiationType - msg.Ephemeral = handshake.localEphemeral.publicKey() + device.indices.Delete(handshake.localIndex) + handshake.localIndex, err = device.indices.NewIndex(peer) if err != nil { return nil, err } - msg.Sender = handshake.localIndex + handshake.mixHash(handshake.remoteStatic[:]) + + msg := MessageInitiation{ + Type: MessageInitiationType, + Ephemeral: handshake.localEphemeral.publicKey(), + Sender: handshake.localIndex, + } + handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -185,9 +190,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { return nil } - hash := mixHash(InitalHash, device.publicKey[:]) + hash := mixHash(InitialHash, device.publicKey[:]) hash = mixHash(hash, msg.Ephemeral[:]) - chainKey := mixKey(InitalChainKey, msg.Ephemeral[:]) + chainKey := mixKey(InitialChainKey, msg.Ephemeral[:]) // decrypt static key @@ -278,7 +283,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error // assign index var err error - device.indices.ClearIndex(handshake.localIndex) + device.indices.Delete(handshake.localIndex) handshake.localIndex, err = device.indices.NewIndex(peer) if err != nil { return nil, err @@ -420,10 +425,15 @@ func (peer *Peer) NewKeyPair() *KeyPair { return nil } + // zero handshake + + handshake.chainKey = [blake2s.Size]byte{} + handshake.localEphemeral = NoisePrivateKey{} + peer.handshake.state = HandshakeZeroed + // create AEAD instances - var keyPair KeyPair - + keyPair := new(KeyPair) keyPair.send, _ = chacha20poly1305.New(sendKey[:]) keyPair.recv, _ = chacha20poly1305.New(recvKey[:]) keyPair.sendNonce = 0 @@ -433,30 +443,32 @@ func (peer *Peer) NewKeyPair() *KeyPair { peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{ peer: peer, - keyPair: &keyPair, + keyPair: keyPair, handshake: nil, }) handshake.localIndex = 0 + // start timer for keypair + // rotate key pairs + kp := &peer.keyPairs func() { - kp := &peer.keyPairs kp.mutex.Lock() defer kp.mutex.Unlock() if isInitiator { - kp.previous = peer.keyPairs.current - kp.current = &keyPair - kp.newKeyPair <- true + if kp.previous != nil { + kp.previous.send = nil + kp.previous.recv = nil + peer.device.indices.Delete(kp.previous.id) + } + kp.previous = kp.current + kp.current = keyPair + sendSignal(peer.signal.newKeyPair) } else { - kp.next = &keyPair + kp.next = keyPair } }() - // zero handshake - - handshake.chainKey = [blake2s.Size]byte{} - handshake.localEphemeral = NoisePrivateKey{} - peer.handshake.state = HandshakeZeroed - return &keyPair + return keyPair } diff --git a/src/noise_test.go b/src/noise_test.go index 02f6bf3..9b50ff3 100644 --- a/src/noise_test.go +++ b/src/noise_test.go @@ -25,10 +25,12 @@ func TestCurveWrappers(t *testing.T) { } func TestNoiseHandshake(t *testing.T) { - dev1 := randDevice(t) dev2 := randDevice(t) + defer dev1.Close() + defer dev2.Close() + peer1 := dev2.NewPeer(dev1.privateKey.publicKey()) peer2 := dev1.NewPeer(dev2.privateKey.publicKey()) diff --git a/src/peer.go b/src/peer.go index 21cad9d..e885cee 100644 --- a/src/peer.go +++ b/src/peer.go @@ -10,26 +10,29 @@ import ( const () type Peer struct { + id uint mutex sync.RWMutex endpoint *net.UDPAddr - persistentKeepaliveInterval time.Duration // 0 = disabled + persistentKeepaliveInterval uint64 keyPairs KeyPairs handshake Handshake device *Device tx_bytes uint64 rx_bytes uint64 time struct { - lastSend time.Time // last send message + lastSend time.Time // last send message + lastHandshake time.Time // last completed handshake } signal struct { - newHandshake chan bool - flushNonceQueue chan bool // empty queued packets - stopSending chan bool // stop sending pipeline - stopInitiator chan bool // stop initiator timer + newKeyPair chan struct{} // (size 1) : a new key pair was generated + handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake") + handshakeCompleted chan struct{} // (size 1) : handshake completed + flushNonceQueue chan struct{} // (size 1) : empty queued packets + stop chan struct{} // (size 0) : close to stop all goroutines for peer } timer struct { - sendKeepalive time.Timer - handshakeTimeout time.Timer + sendKeepalive *time.Timer + handshakeTimeout *time.Timer } queue struct { nonce chan []byte // nonce / pre-handshake queue @@ -39,25 +42,30 @@ type Peer struct { } func (device *Device) NewPeer(pk NoisePublicKey) *Peer { - var peer Peer - // create peer + peer := new(Peer) peer.mutex.Lock() + defer peer.mutex.Unlock() peer.device = device - peer.keyPairs.Init() peer.mac.Init(pk) peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.nonce = make(chan []byte, QueueOutboundSize) + peer.timer.sendKeepalive = StoppedTimer() + + // assign id for debugging + + device.mutex.Lock() + peer.id = device.idCounter + device.idCounter += 1 // map public key - device.mutex.Lock() _, ok := device.peers[pk] if ok { panic(errors.New("bug: adding existing peer")) } - device.peers[pk] = &peer + device.peers[pk] = peer device.mutex.Unlock() // precompute DH @@ -67,22 +75,24 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { handshake.remoteStatic = pk handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) handshake.mutex.Unlock() - peer.mutex.Unlock() - // start workers + // prepare signaling - peer.signal.stopSending = make(chan bool, 1) - peer.signal.stopInitiator = make(chan bool, 1) - peer.signal.newHandshake = make(chan bool, 1) - peer.signal.flushNonceQueue = make(chan bool, 1) + peer.signal.stop = make(chan struct{}) + peer.signal.newKeyPair = make(chan struct{}, 1) + peer.signal.handshakeBegin = make(chan struct{}, 1) + peer.signal.handshakeCompleted = make(chan struct{}, 1) + peer.signal.flushNonceQueue = make(chan struct{}, 1) + + // outbound pipeline go peer.RoutineNonce() go peer.RoutineHandshakeInitiator() + go peer.RoutineSequentialSender() - return &peer + return peer } func (peer *Peer) Close() { - peer.signal.stopSending <- true - peer.signal.stopInitiator <- true + close(peer.signal.stop) } diff --git a/src/send.go b/src/send.go index ab75750..d4f9342 100644 --- a/src/send.go +++ b/src/send.go @@ -5,6 +5,8 @@ import ( "golang.org/x/crypto/chacha20poly1305" "net" "sync" + "sync/atomic" + "time" ) /* Handles outbound flow @@ -29,6 +31,7 @@ type QueueOutboundElement struct { packet []byte nonce uint64 keyPair *KeyPair + peer *Peer } func (peer *Peer) FlushNonceQueue() { @@ -46,6 +49,7 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) { for { select { case peer.queue.outbound <- elem: + return default: select { case <-peer.queue.outbound: @@ -61,11 +65,15 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) { * Obs. Single instance per TUN device */ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { + if tun.MTU() == 0 { + // Dummy + return + } + device.log.Debug.Println("Routine, TUN Reader: started") for { // read packet - device.log.Debug.Println("Read") packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation size, err := tun.Read(packet) if err != nil { @@ -94,13 +102,16 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { default: device.log.Debug.Println("Receieved packet with unknown IP version") - return } if peer == nil { device.log.Debug.Println("No peer configured for IP") continue } + if peer.endpoint == nil { + device.log.Debug.Println("No known endpoint for peer", peer.id) + continue + } // insert into nonce/pre-handshake queue @@ -131,69 +142,95 @@ func (peer *Peer) RoutineNonce() { var packet []byte var keyPair *KeyPair - for { + device := peer.device + logger := device.log.Debug - // wait for packet + logger.Println("Routine, nonce worker, started for peer", peer.id) - if packet == nil { - select { - case packet = <-peer.queue.nonce: - case <-peer.signal.stopSending: - close(peer.queue.outbound) - return + func() { + + for { + NextPacket: + + // wait for packet + + if packet == nil { + select { + case packet = <-peer.queue.nonce: + case <-peer.signal.stop: + return + } } - } - // wait for key pair + // wait for key pair + + for { + select { + case <-peer.signal.newKeyPair: + default: + } - for keyPair == nil { - peer.signal.newHandshake <- true - select { - case <-peer.keyPairs.newKeyPair: keyPair = peer.keyPairs.Current() - continue - case <-peer.signal.flushNonceQueue: - peer.FlushNonceQueue() - packet = nil - continue - case <-peer.signal.stopSending: - close(peer.queue.outbound) - return - } - } - - // process current packet - - if packet != nil { - - // create work element - - work := new(QueueOutboundElement) // TODO: profile, maybe use pool - work.keyPair = keyPair - work.packet = packet - work.nonce = keyPair.sendNonce - work.mutex.Lock() - - packet = nil - keyPair.sendNonce += 1 - - // drop packets until there is space - - func() { - for { - select { - case peer.device.queue.encryption <- work: - return - default: - drop := <-peer.device.queue.encryption - drop.packet = nil - drop.mutex.Unlock() + if keyPair != nil && keyPair.sendNonce < RejectAfterMessages { + if time.Now().Sub(keyPair.created) < RejectAfterTime { + break } } - }() - peer.queue.outbound <- work + + sendSignal(peer.signal.handshakeBegin) + logger.Println("Waiting for key-pair, peer", peer.id) + + select { + case <-peer.signal.newKeyPair: + logger.Println("Key-pair negotiated for peer", peer.id) + goto NextPacket + + case <-peer.signal.flushNonceQueue: + logger.Println("Clearing queue for peer", peer.id) + peer.FlushNonceQueue() + packet = nil + goto NextPacket + + case <-peer.signal.stop: + return + } + } + + // process current packet + + if packet != nil { + + // create work element + + work := new(QueueOutboundElement) // TODO: profile, maybe use pool + work.keyPair = keyPair + work.packet = packet + work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) + work.peer = peer + work.mutex.Lock() + + packet = nil + + // drop packets until there is space + + func() { + for { + select { + case peer.device.queue.encryption <- work: + return + default: + drop := <-peer.device.queue.encryption + drop.packet = nil + drop.mutex.Unlock() + } + } + }() + peer.queue.outbound <- work + } } - } + }() + + logger.Println("Routine, nonce worker, stopped for peer", peer.id) } /* Encrypts the elements in the queue @@ -227,6 +264,10 @@ func (device *Device) RoutineEncryption() { nil, ) work.mutex.Unlock() + + // initiate new handshake + + work.peer.KeepKeyFreshSending() } } @@ -235,21 +276,54 @@ func (device *Device) RoutineEncryption() { * Obs. Single instance per peer. * The routine terminates then the outbound queue is closed. */ -func (peer *Peer) RoutineSequential() { - for work := range peer.queue.outbound { - work.mutex.Lock() - func() { - peer.mutex.RLock() - defer peer.mutex.RUnlock() - if work.packet == nil { - return - } - if peer.endpoint == nil { - return - } - peer.device.conn.WriteToUDP(work.packet, peer.endpoint) - peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval) - }() - work.mutex.Unlock() +func (peer *Peer) RoutineSequentialSender() { + logger := peer.device.log.Debug + logger.Println("Routine, sequential sender, started for peer", peer.id) + + device := peer.device + + for { + select { + case <-peer.signal.stop: + logger.Println("Routine, sequential sender, stopped for peer", peer.id) + return + case work := <-peer.queue.outbound: + work.mutex.Lock() + func() { + if work.packet == nil { + return + } + + peer.mutex.RLock() + defer peer.mutex.RUnlock() + + if peer.endpoint == nil { + logger.Println("No endpoint for peer:", peer.id) + return + } + + device.net.mutex.RLock() + defer device.net.mutex.RUnlock() + + if device.net.conn == nil { + logger.Println("No source for device") + return + } + + logger.Println("Sending packet for peer", peer.id, work.packet) + + _, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint) + logger.Println("SEND:", peer.endpoint, err) + atomic.AddUint64(&peer.tx_bytes, uint64(len(work.packet))) + + // shift keep-alive timer + + if peer.persistentKeepaliveInterval != 0 { + interval := time.Duration(peer.persistentKeepaliveInterval) * time.Second + peer.timer.sendKeepalive.Reset(interval) + } + }() + work.mutex.Unlock() + } } } diff --git a/src/tun_linux.go b/src/tun_linux.go index cbbcb70..db13fb0 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -74,5 +74,6 @@ func CreateTUN(name string) (TUNDevice, error) { return &NativeTun{ fd: fd, name: newName, + mtu: 1500, // TODO: FIX }, nil }