From 168ef61a638e4875b260edbc51551bae0dc34ac3 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Tue, 1 May 2018 16:59:13 +0200 Subject: [PATCH] Add missing locks and fix debug output, and try to flush queues Flushing queues on exit is sort of a partial solution, but this could be better. Really what we want is for no more packets to be enqueued after isUp is set to false. --- device.go | 15 ++++++++++++-- peer.go | 4 ++-- receive.go | 58 ++++++++++++++++++++++++++++++++++++++++++++---------- send.go | 53 ++++++++++++++++++++++++++++++++++++------------- timers.go | 28 +++++++++++++------------- 5 files changed, 116 insertions(+), 42 deletions(-) diff --git a/device.go b/device.go index 3ad53c9..dddb547 100644 --- a/device.go +++ b/device.go @@ -339,6 +339,8 @@ func (device *Device) RemovePeer(key NoisePublicKey) { } func (device *Device) RemoveAllPeers() { + device.noise.mutex.Lock() + defer device.noise.mutex.Unlock() device.routing.mutex.Lock() defer device.routing.mutex.Unlock() @@ -354,16 +356,25 @@ func (device *Device) RemoveAllPeers() { } func (device *Device) Close() { - device.log.Info.Println("Device closing") if device.isClosed.Swap(true) { return } - device.signal.stop.Broadcast() + device.log.Info.Println("Device closing") + device.state.changing.Set(true) + device.state.mutex.Lock() + defer device.state.mutex.Unlock() + device.tun.device.Close() device.BindClose() + device.isUp.Set(false) + + device.signal.stop.Broadcast() + device.RemoveAllPeers() device.rate.limiter.Close() + + device.state.changing.Set(false) device.log.Info.Println("Interface closed") } diff --git a/peer.go b/peer.go index f10bfbb..ec411b2 100644 --- a/peer.go +++ b/peer.go @@ -195,7 +195,7 @@ func (peer *Peer) Start() { } device := peer.device - device.log.Debug.Println(peer.String(), ": Starting...") + device.log.Debug.Println(peer.String() + ": Starting...") // sanity check : these should be 0 @@ -242,7 +242,7 @@ func (peer *Peer) Stop() { } device := peer.device - device.log.Debug.Println(peer.String(), ": Stopping...") + device.log.Debug.Println(peer.String() + ": Stopping...") // stop & wait for ongoing peer routines diff --git a/receive.go b/receive.go index ca20900..7d35497 100644 --- a/receive.go +++ b/receive.go @@ -7,6 +7,7 @@ import ( "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "net" + "strconv" "sync" "sync/atomic" "time" @@ -101,7 +102,11 @@ func (device *Device) addToHandshakeQueue( func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { logDebug := device.log.Debug - logDebug.Println("Routine, receive incoming, IP version:", IP) + defer func() { + logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") + }() + + logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - starting") // receive datagrams until conn is closed @@ -224,15 +229,31 @@ func (device *Device) RoutineDecryption() { var nonce [chacha20poly1305.NonceSize]byte logDebug := device.log.Debug - logDebug.Println("Routine, decryption, started for device") + defer func() { + for { + select { + case elem, ok := <-device.queue.decryption: + if ok { + elem.Drop() + } + default: + break + } + } + logDebug.Println("Routine: decryption worker - stopped") + }() + logDebug.Println("Routine: decryption worker - started") for { select { case <-device.signal.stop.Wait(): - logDebug.Println("Routine, decryption worker, stopped") return - case elem := <-device.queue.decryption: + case elem, ok := <-device.queue.decryption: + + if !ok { + return + } // check if dropped @@ -282,18 +303,35 @@ func (device *Device) RoutineHandshake() { logInfo := device.log.Info logError := device.log.Error logDebug := device.log.Debug - logDebug.Println("Routine, handshake routine, started for device") + + defer func() { + for { + select { + case <-device.queue.handshake: + default: + return + } + } + logDebug.Println("Routine: handshake worker - stopped") + }() + + logDebug.Println("Routine: handshake worker - started") var temp [MessageHandshakeSize]byte var elem QueueHandshakeElement + var ok bool for { select { - case elem = <-device.queue.handshake: + case elem, ok = <-device.queue.handshake: case <-device.signal.stop.Wait(): return } + if !ok { + return + } + // handle cookie fields and ratelimiting switch elem.msgType { @@ -419,7 +457,7 @@ func (device *Device) RoutineHandshake() { peer.endpoint = elem.endpoint peer.mutex.Unlock() - logDebug.Println(peer, ": Received handshake initiation") + logDebug.Println(peer.String() + ": Received handshake initiation") // create response @@ -477,7 +515,7 @@ func (device *Device) RoutineHandshake() { peer.endpoint = elem.endpoint peer.mutex.Unlock() - logDebug.Println(peer, ": Received handshake response") + logDebug.Println(peer.String() + ": Received handshake response") peer.TimerEphemeralKeyCreated() @@ -504,10 +542,10 @@ func (peer *Peer) RoutineSequentialReceiver() { defer func() { peer.routines.stopping.Done() - logDebug.Println(peer.String(), ": Routine, Sequential Receiver, Stopped") + logDebug.Println(peer.String() + ": Routine: sequential receiver - stopped") }() - logDebug.Println(peer.String(), ": Routine, Sequential Receiver, Started") + logDebug.Println(peer.String() + ": Routine: sequential receiver - started") peer.routines.starting.Done() diff --git a/send.go b/send.go index df8efdb..5c6b350 100644 --- a/send.go +++ b/send.go @@ -121,7 +121,11 @@ func (device *Device) RoutineReadFromTUN() { logDebug := device.log.Debug logError := device.log.Error - logDebug.Println("Routine, TUN Reader started") + defer func() { + logDebug.Println("Routine: TUN reader - stopped") + }() + + logDebug.Println("Routine: TUN reader - started") for { @@ -192,11 +196,11 @@ func (peer *Peer) RoutineNonce() { defer func() { peer.routines.stopping.Done() - logDebug.Println(peer.String(), ": Routine, Nonce Worker, Stopped") + logDebug.Println(peer.String() + ": Routine: nonce worker - stopped") }() peer.routines.starting.Done() - logDebug.Println(peer.String(), ": Routine, Nonce Worker, Started") + logDebug.Println(peer.String() + ": Routine: nonce worker - started") for { NextPacket: @@ -204,7 +208,11 @@ func (peer *Peer) RoutineNonce() { case <-peer.routines.stop.Wait(): return - case elem := <-peer.queue.nonce: + case elem, ok := <-peer.queue.nonce: + + if !ok { + return + } // wait for key pair @@ -218,13 +226,13 @@ func (peer *Peer) RoutineNonce() { peer.signal.handshakeBegin.Send() - logDebug.Println(peer.String(), ": Awaiting key-pair") + logDebug.Println(peer.String() + ": Awaiting key-pair") select { case <-peer.signal.newKeyPair.Wait(): - logDebug.Println(peer.String(), ": Obtained awaited key-pair") + logDebug.Println(peer.String() + ": Obtained awaited key-pair") case <-peer.signal.flushNonceQueue.Wait(): - logDebug.Println(peer.String(), ": Flushing nonce queue") + logDebug.Println(peer.String() + ": Flushing nonce queue") peer.FlushNonceQueue() goto NextPacket case <-peer.routines.stop.Wait(): @@ -258,7 +266,22 @@ func (device *Device) RoutineEncryption() { var nonce [chacha20poly1305.NonceSize]byte logDebug := device.log.Debug - logDebug.Println("Routine, encryption worker, started") + + defer func() { + for { + select { + case elem, ok := <-device.queue.encryption: + if ok { + elem.Drop() + } + default: + break + } + } + logDebug.Println("Routine: encryption worker - stopped") + }() + + logDebug.Println("Routine: encryption worker - started") for { @@ -266,10 +289,13 @@ func (device *Device) RoutineEncryption() { select { case <-device.signal.stop.Wait(): - logDebug.Println("Routine, encryption worker, stopped") return - case elem := <-device.queue.encryption: + case elem, ok := <-device.queue.encryption: + + if !ok { + return + } // check if dropped @@ -323,21 +349,20 @@ func (peer *Peer) RoutineSequentialSender() { device := peer.device logDebug := device.log.Debug - logDebug.Println("Routine, sequential sender, started for", peer.String()) defer func() { peer.routines.stopping.Done() - logDebug.Println(peer.String(), ": Routine, Sequential sender, Stopped") + logDebug.Println(peer.String() + ": Routine: sequential sender - stopped") }() + logDebug.Println(peer.String() + ": Routine: sequential sender - started") + peer.routines.starting.Done() for { select { case <-peer.routines.stop.Wait(): - logDebug.Println( - "Routine, sequential sender, stopped for", peer.String()) return case elem, ok := <-peer.queue.outbound: diff --git a/timers.go b/timers.go index 8725570..ba0d0e5 100644 --- a/timers.go +++ b/timers.go @@ -120,7 +120,7 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { */ func (peer *Peer) TimerHandshakeComplete() { peer.signal.handshakeCompleted.Send() - peer.device.log.Info.Println(peer.String(), ": New handshake completed") + peer.device.log.Info.Println(peer.String() + ": New handshake completed") } /* Event: @@ -189,10 +189,12 @@ func (peer *Peer) RoutineTimerHandler() { logDebug := device.log.Debug defer func() { - logDebug.Println(peer.String(), ": Routine, Timer handler, Stopped") + logDebug.Println(peer.String() + ": Routine: timer handler - stopped") peer.routines.stopping.Done() }() + logDebug.Println(peer.String() + ": Routine: timer handler - started") + // reset all timers peer.timer.keepalivePassive.Stop() @@ -207,8 +209,6 @@ func (peer *Peer) RoutineTimerHandler() { peer.timer.keepalivePersistent.Reset(duration) } - logDebug.Println("Routine, timer handler, started for peer", peer.String()) - // signal synchronised setup complete peer.routines.starting.Done() @@ -231,14 +231,14 @@ func (peer *Peer) RoutineTimerHandler() { interval := peer.persistentKeepaliveInterval if interval > 0 { - logDebug.Println(peer.String(), ": Send keep-alive (persistent)") + logDebug.Println(peer.String() + ": Send keep-alive (persistent)") peer.timer.keepalivePassive.Stop() peer.SendKeepAlive() } case <-peer.timer.keepalivePassive.Wait(): - logDebug.Println(peer.String(), ": Send keep-alive (passive)") + logDebug.Println(peer.String() + ": Send keep-alive (passive)") peer.SendKeepAlive() @@ -250,7 +250,7 @@ func (peer *Peer) RoutineTimerHandler() { case <-peer.timer.zeroAllKeys.Wait(): - logDebug.Println(peer.String(), ": Clear all key-material (timer event)") + logDebug.Println(peer.String() + ": Clear all key-material (timer event)") hs := &peer.handshake hs.mutex.Lock() @@ -283,7 +283,7 @@ func (peer *Peer) RoutineTimerHandler() { // handshake timers case <-peer.timer.handshakeNew.Wait(): - logInfo.Println(peer.String(), ": Retrying handshake (timer event)") + logInfo.Println(peer.String() + ": Retrying handshake (timer event)") peer.signal.handshakeBegin.Send() case <-peer.timer.handshakeTimeout.Wait(): @@ -301,16 +301,16 @@ func (peer *Peer) RoutineTimerHandler() { err := peer.sendNewHandshake() if err != nil { - logInfo.Println(peer.String(), ": Failed to send handshake initiation", err) + logInfo.Println(peer.String()+": Failed to send handshake initiation", err) } else { - logDebug.Println(peer.String(), ": Send handshake initiation (subsequent)") + logDebug.Println(peer.String() + ": Send handshake initiation (subsequent)") } case <-peer.timer.handshakeDeadline.Wait(): // clear all queued packets and stop keep-alive - logInfo.Println(peer.String(), ": Handshake negotiation timed-out") + logInfo.Println(peer.String() + ": Handshake negotiation timed-out") peer.signal.flushNonceQueue.Send() peer.timer.keepalivePersistent.Stop() @@ -325,16 +325,16 @@ func (peer *Peer) RoutineTimerHandler() { err := peer.sendNewHandshake() if err != nil { - logInfo.Println(peer.String(), ": Failed to send handshake initiation", err) + logInfo.Println(peer.String()+": Failed to send handshake initiation", err) } else { - logDebug.Println(peer.String(), ": Send handshake initiation (initial)") + logDebug.Println(peer.String() + ": Send handshake initiation (initial)") } peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) case <-peer.signal.handshakeCompleted.Wait(): - logInfo.Println(peer.String(), ": Handshake completed") + logInfo.Println(peer.String() + ": Handshake completed") atomic.StoreInt64( &peer.stats.lastHandshakeNano,