diff --git a/src/constants.go b/src/constants.go index 6b0d414..09d33d8 100644 --- a/src/constants.go +++ b/src/constants.go @@ -20,6 +20,7 @@ const ( const ( RekeyAfterTimeReceiving = RekeyAfterTime - KeepaliveTimeout - RekeyTimeout + NewHandshakeTime = KeepaliveTimeout + RekeyTimeout // upon failure to acknowledge transport message ) /* Implementation specific constants */ diff --git a/src/noise_protocol.go b/src/noise_protocol.go index 5fe6fb2..e2ff573 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -37,6 +37,7 @@ const ( MessageCookieReplySize = 64 MessageTransportHeaderSize = 16 MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport + MessageKeepaliveSize = MessageTransportSize ) const ( @@ -253,8 +254,6 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { } hash = mixHash(hash, msg.Timestamp[:]) - // TODO: check for flood attack - // check for replay attack return timestamp.After(handshake.lastTimestamp) diff --git a/src/peer.go b/src/peer.go index 8eea929..9136959 100644 --- a/src/peer.go +++ b/src/peer.go @@ -40,21 +40,22 @@ type Peer struct { stop chan struct{} // (size 0) : close to stop all goroutines for peer } timer struct { - /* Both keep-alive timers acts as one (see timers.go) - * They are kept seperate to simplify the implementation. - */ keepalivePersistent *time.Timer // set for persistent keepalives keepalivePassive *time.Timer // set upon recieving messages - zeroAllKeys *time.Timer // zero all key material after RejectAfterTime*3 + newHandshake *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout) + zeroAllKeys *time.Timer // zero all key material (after RejectAfterTime*3) + + pendingKeepalivePassive bool + pendingNewHandshake bool + pendingZeroAllKeys bool + + needAnotherKeepalive bool } queue struct { nonce chan *QueueOutboundElement // nonce / pre-handshake queue outbound chan *QueueOutboundElement // sequential ordering of work inbound chan *QueueInboundElement // sequential ordering of work } - flags struct { - keepaliveWaiting int32 - } mac MACStatePeer } @@ -68,12 +69,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { peer.mac.Init(pk) peer.device = device - peer.timer.keepalivePassive = NewStoppedTimer() peer.timer.keepalivePersistent = NewStoppedTimer() + peer.timer.keepalivePassive = NewStoppedTimer() + peer.timer.newHandshake = NewStoppedTimer() peer.timer.zeroAllKeys = NewStoppedTimer() - peer.flags.keepaliveWaiting = AtomicFalse - // assign id for debugging device.mutex.Lock() diff --git a/src/receive.go b/src/receive.go index d97ca41..c74211b 100644 --- a/src/receive.go +++ b/src/receive.go @@ -288,6 +288,7 @@ func (device *Device) RoutineHandshake() { logDebug := device.log.Debug logDebug.Println("Routine, handshake routine, started for device") + var temp [256]byte var elem QueueHandshakeElement for { @@ -363,6 +364,7 @@ func (device *Device) RoutineHandshake() { ) return } + peer.TimerPacketReceived() // update endpoint @@ -378,17 +380,19 @@ func (device *Device) RoutineHandshake() { return } + peer.TimerEphemeralKeyCreated() + logDebug.Println("Creating response message for", peer.String()) - outElem := device.NewOutboundElement() - writer := bytes.NewBuffer(outElem.buffer[:0]) + writer := bytes.NewBuffer(temp[:0]) binary.Write(writer, binary.LittleEndian, response) - outElem.packet = writer.Bytes() - peer.mac.AddMacs(outElem.packet) - addToOutboundQueue(peer.queue.outbound, outElem) + packet := writer.Bytes() + peer.mac.AddMacs(packet) - // create new keypair + // send response + peer.SendBuffer(packet) + peer.TimerPacketSent() peer.NewKeyPair() case MessageResponseType: @@ -418,12 +422,11 @@ func (device *Device) RoutineHandshake() { ) return } - kp := peer.NewKeyPair() - if kp == nil { - logDebug.Println("Failed to derieve key-pair") - } + + peer.TimerPacketReceived() + peer.TimerHandshakeComplete() + peer.NewKeyPair() peer.SendKeepAlive() - peer.EventHandshakeComplete() default: logError.Println("Invalid message type in handshake queue") @@ -464,12 +467,8 @@ func (peer *Peer) RoutineSequentialReceiver() { return } - // time (passive) keep-alive - - peer.TimerStartKeepalive() - - // refresh key material (rekey) - + peer.TimerPacketReceived() + peer.TimerTransportReceived() peer.KeepKeyFreshReceiving() // check if using new key-pair @@ -477,7 +476,7 @@ func (peer *Peer) RoutineSequentialReceiver() { kp := &peer.keyPairs kp.mutex.Lock() if kp.next == elem.keyPair { - peer.EventHandshakeComplete() + peer.TimerHandshakeComplete() kp.previous = kp.current kp.current = kp.next kp.next = nil @@ -490,6 +489,7 @@ func (peer *Peer) RoutineSequentialReceiver() { logDebug.Println("Received keep-alive from", peer.String()) return } + peer.TimerDataReceived() // verify source and strip padding diff --git a/src/send.go b/src/send.go index 7cdb806..37078b9 100644 --- a/src/send.go +++ b/src/send.go @@ -2,6 +2,7 @@ package main import ( "encoding/binary" + "errors" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -51,6 +52,11 @@ func (peer *Peer) FlushNonceQueue() { } } +var ( + ErrorNoEndpoint = errors.New("No known endpoint for peer") + ErrorNoConnection = errors.New("No UDP socket for device") +) + func (device *Device) NewOutboundElement() *QueueOutboundElement { return &QueueOutboundElement{ dropped: AtomicFalse, @@ -103,6 +109,25 @@ func addToEncryptionQueue( } } +func (peer *Peer) SendBuffer(buffer []byte) (int, error) { + + peer.mutex.RLock() + endpoint := peer.endpoint + peer.mutex.RUnlock() + if endpoint == nil { + return 0, ErrorNoEndpoint + } + + peer.device.net.mutex.RLock() + conn := peer.device.net.conn + peer.device.net.mutex.RUnlock() + if conn == nil { + return 0, ErrorNoConnection + } + + return conn.WriteToUDP(buffer, endpoint) +} + /* Reads packets from the TUN and inserts * into nonce queue for peer * @@ -349,42 +374,27 @@ func (peer *Peer) RoutineSequentialSender() { case elem := <-peer.queue.outbound: elem.mutex.Lock() + if elem.IsDropped() { + continue + } - func() { - if elem.IsDropped() { - return - } - - // get endpoint and connection - - peer.mutex.RLock() - endpoint := peer.endpoint - peer.mutex.RUnlock() - if endpoint == nil { - logDebug.Println("No endpoint for", peer.String()) - return - } - - device.net.mutex.RLock() - conn := device.net.conn - device.net.mutex.RUnlock() - if conn == nil { - logDebug.Println("No source for device") - return - } - - // send message and refresh keys - - _, err := conn.WriteToUDP(elem.packet, endpoint) - if err != nil { - return - } - - atomic.AddUint64(&peer.stats.txBytes, uint64(len(elem.packet))) - peer.TimerResetKeepalive() - }() + // send message and return buffer to pool + length := uint64(len(elem.packet)) + _, err := peer.SendBuffer(elem.packet) device.PutMessageBuffer(elem.buffer) + if err != nil { + continue + } + atomic.AddUint64(&peer.stats.txBytes, length) + + // update timers + + peer.TimerPacketSent() + if len(elem.packet) != MessageKeepaliveSize { + peer.TimerDataSent() + } + peer.KeepKeyFreshSending() } } } diff --git a/src/timers.go b/src/timers.go index 2454414..5a16e9b 100644 --- a/src/timers.go +++ b/src/timers.go @@ -44,21 +44,6 @@ func (peer *Peer) KeepKeyFreshReceiving() { } } -/* Called after succesfully completing a handshake. - * i.e. after: - * - Valid handshake response - * - First transport message under the "next" key - */ -func (peer *Peer) EventHandshakeComplete() { - peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) - peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) - atomic.StoreInt64( - &peer.stats.lastHandshakeNano, - time.Now().UnixNano(), - ) - signalSend(peer.signal.handshakeCompleted) -} - /* Queues a keep-alive if no packets are queued for peer */ func (peer *Peer) SendKeepAlive() bool { @@ -75,69 +60,89 @@ func (peer *Peer) SendKeepAlive() bool { return true } -/* Starts the "keep-alive" timer - * (if not already running), - * in response to incomming messages +/* Authenticated data packet send + * Always called together with peer.EventPacketSend + * + * - Start new handshake timer */ -func (peer *Peer) TimerStartKeepalive() { +func (peer *Peer) TimerDataSent() { + timerStop(peer.timer.keepalivePassive) + if !peer.timer.pendingNewHandshake { + peer.timer.pendingNewHandshake = true + peer.timer.newHandshake.Reset(NewHandshakeTime) + } +} - // check if acknowledgement timer set yet - - var waiting int32 = AtomicTrue - waiting = atomic.SwapInt32(&peer.flags.keepaliveWaiting, waiting) - if waiting == AtomicTrue { +/* Event: + * Received non-empty (authenticated) transport message + * + * - Start passive keep-alive timer + */ +func (peer *Peer) TimerDataReceived() { + if peer.timer.pendingKeepalivePassive { + peer.timer.needAnotherKeepalive = true return } + peer.timer.pendingKeepalivePassive = false + peer.timer.keepalivePassive.Reset(KeepaliveTimeout) +} - // timer not yet set, start it +/* Event: + * Any (authenticated) transport message received + * (keep-alive or data) + */ +func (peer *Peer) TimerTransportReceived() { + timerStop(peer.timer.newHandshake) +} - wait := KeepaliveTimeout +/* Event: + * Any packet send to the peer. + */ +func (peer *Peer) TimerPacketSent() { interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) if interval > 0 { duration := time.Duration(interval) * time.Second - if duration < wait { - wait = duration - } + peer.timer.keepalivePersistent.Reset(duration) } } -/* Resets both keep-alive timers +/* Event: + * Any authenticated packet received from peer */ -func (peer *Peer) TimerResetKeepalive() { - - // reset persistent timer - - interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) - if interval > 0 { - peer.timer.keepalivePersistent.Reset( - time.Duration(interval) * time.Second, - ) - } - - // stop acknowledgement timer - - timerStop(peer.timer.keepalivePassive) - atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse) +func (peer *Peer) TimerPacketReceived() { + peer.TimerPacketSent() } -func (peer *Peer) BeginHandshakeInitiation() (*QueueOutboundElement, error) { +/* Called after succesfully completing a handshake. + * i.e. after: + * + * - Valid handshake response + * - First transport message under the "next" key + */ +func (peer *Peer) TimerHandshakeComplete() { + timerStop(peer.timer.zeroAllKeys) + atomic.StoreInt64( + &peer.stats.lastHandshakeNano, + time.Now().UnixNano(), + ) + signalSend(peer.signal.handshakeCompleted) + peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) +} - // create initiation - - elem := peer.device.NewOutboundElement() - msg, err := peer.device.CreateMessageInitiation(peer) - if err != nil { - return nil, err +/* Called whenever an ephemeral key is generated + * i.e after: + * + * CreateMessageInitiation + * CreateMessageResponse + * + * Schedules the deletion of all key material + * upon failure to complete a handshake + */ +func (peer *Peer) TimerEphemeralKeyCreated() { + if !peer.timer.pendingZeroAllKeys { + peer.timer.pendingZeroAllKeys = true + peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) } - - // marshal & schedule for sending - - writer := bytes.NewBuffer(elem.buffer[:0]) - binary.Write(writer, binary.LittleEndian, msg) - elem.packet = writer.Bytes() - peer.mac.AddMacs(elem.packet) - addToOutboundQueue(peer.queue.outbound, elem) - return elem, err } func (peer *Peer) RoutineTimerHandler() { @@ -157,17 +162,30 @@ func (peer *Peer) RoutineTimerHandler() { case <-peer.timer.keepalivePersistent.C: - logDebug.Println("Sending persistent keep-alive to", peer.String()) - - peer.SendKeepAlive() - peer.TimerResetKeepalive() + interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) + if interval > 0 { + logDebug.Println("Sending persistent keep-alive to", peer.String()) + peer.SendKeepAlive() + } case <-peer.timer.keepalivePassive.C: - logDebug.Println("Sending passive persistent keep-alive to", peer.String()) + logDebug.Println("Sending passive keep-alive to", peer.String()) peer.SendKeepAlive() - peer.TimerResetKeepalive() + + if peer.timer.needAnotherKeepalive { + peer.timer.keepalivePassive.Reset(KeepaliveTimeout) + peer.timer.needAnotherKeepalive = true + } + + // unresponsive session + + case <-peer.timer.newHandshake.C: + + logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply") + + signalSend(peer.signal.handshakeBegin) // clear key material @@ -175,13 +193,15 @@ func (peer *Peer) RoutineTimerHandler() { logDebug.Println("Clearing all key material for", peer.String()) - kp := &peer.keyPairs - kp.mutex.Lock() - hs := &peer.handshake hs.mutex.Lock() - // unmap local indecies + kp := &peer.keyPairs + kp.mutex.Lock() + + peer.timer.pendingZeroAllKeys = false + + // unmap indecies indices.mutex.Lock() if kp.previous != nil { @@ -224,80 +244,103 @@ func (peer *Peer) RoutineTimerHandler() { func (peer *Peer) RoutineHandshakeInitiator() { device := peer.device - var elem *QueueOutboundElement - logInfo := device.log.Info logError := device.log.Error logDebug := device.log.Debug logDebug.Println("Routine, handshake initator, started for", peer.String()) + var temp [256]byte + for { // wait for signal select { case <-peer.signal.handshakeBegin: + signalSend(peer.signal.handshakeBegin) case <-peer.signal.stop: return } // wait for handshake - func() { - var err error - var deadline time.Time - for attempts := uint(1); ; attempts++ { + deadline := time.Now().Add(MaxHandshakeAttemptTime) - // clear completed signal + Loop: + for attempts := uint(1); ; attempts++ { - select { - case <-peer.signal.handshakeCompleted: - case <-peer.signal.stop: - return - default: - } + // clear completed signal - // create initiation - - if elem != nil { - elem.Drop() - } - elem, err = peer.BeginHandshakeInitiation() - if err != nil { - logError.Println("Failed to create initiation message", err, "for", peer.String()) - return - } - - // set timeout - - if attempts == 1 { - deadline = time.Now().Add(MaxHandshakeAttemptTime) - } - timeout := time.NewTimer(RekeyTimeout) - logDebug.Println("Handshake initiation attempt", attempts, "queued for", peer.String()) - - // wait for handshake or timeout - - select { - - case <-peer.signal.stop: - return - - case <-peer.signal.handshakeCompleted: - <-timeout.C - return - - case <-timeout.C: - if deadline.Before(time.Now().Add(RekeyTimeout)) { - logInfo.Println("Handshake negotiation timed out for", peer.String()) - signalSend(peer.signal.flushNonceQueue) - timerStop(peer.timer.keepalivePersistent) - timerStop(peer.timer.keepalivePassive) - return - } - } + select { + case <-peer.signal.handshakeCompleted: + case <-peer.signal.stop: + return + default: } - }() + + // check if sufficient time for retry + + if deadline.Before(time.Now().Add(RekeyTimeout)) { + logInfo.Println("Handshake negotiation timed out for", peer.String()) + signalSend(peer.signal.flushNonceQueue) + timerStop(peer.timer.keepalivePersistent) + timerStop(peer.timer.keepalivePassive) + break Loop + } + + // create initiation message + + msg, err := peer.device.CreateMessageInitiation(peer) + if err != nil { + logError.Println("Failed to create handshake initiation message:", err) + break Loop + } + peer.TimerEphemeralKeyCreated() + + // marshal and send + + writer := bytes.NewBuffer(temp[:0]) + binary.Write(writer, binary.LittleEndian, msg) + packet := writer.Bytes() + peer.mac.AddMacs(packet) + peer.TimerPacketSent() + + _, err = peer.SendBuffer(packet) + if err != nil { + logError.Println( + "Failed to send handshake initiation message to", + peer.String(), ":", err, + ) + continue + } + + // set timeout + + timeout := time.NewTimer(RekeyTimeout) + logDebug.Println( + "Handshake initiation attempt", + attempts, "sent to", peer.String(), + ) + + // wait for handshake or timeout + + select { + + case <-peer.signal.stop: + return + + case <-peer.signal.handshakeCompleted: + <-timeout.C + break Loop + + case <-timeout.C: + continue + + } + + } + + // allow new signal to be set signalClear(peer.signal.handshakeBegin) }