1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Add missing locks and fix debug output, and try to flush queues

Flushing queues on exit is sort of a partial solution, but this could be
better. Really what we want is for no more packets to be enqueued after
isUp is set to false.
This commit is contained in:
Jason A. Donenfeld 2018-05-01 16:59:13 +02:00
parent b34604245e
commit 168ef61a63
5 changed files with 116 additions and 42 deletions

View File

@ -339,6 +339,8 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
} }
func (device *Device) RemoveAllPeers() { func (device *Device) RemoveAllPeers() {
device.noise.mutex.Lock()
defer device.noise.mutex.Unlock()
device.routing.mutex.Lock() device.routing.mutex.Lock()
defer device.routing.mutex.Unlock() defer device.routing.mutex.Unlock()
@ -354,16 +356,25 @@ func (device *Device) RemoveAllPeers() {
} }
func (device *Device) Close() { func (device *Device) Close() {
device.log.Info.Println("Device closing")
if device.isClosed.Swap(true) { if device.isClosed.Swap(true) {
return return
} }
device.signal.stop.Broadcast() device.log.Info.Println("Device closing")
device.state.changing.Set(true)
device.state.mutex.Lock()
defer device.state.mutex.Unlock()
device.tun.device.Close() device.tun.device.Close()
device.BindClose() device.BindClose()
device.isUp.Set(false) device.isUp.Set(false)
device.signal.stop.Broadcast()
device.RemoveAllPeers() device.RemoveAllPeers()
device.rate.limiter.Close() device.rate.limiter.Close()
device.state.changing.Set(false)
device.log.Info.Println("Interface closed") device.log.Info.Println("Interface closed")
} }

View File

@ -195,7 +195,7 @@ func (peer *Peer) Start() {
} }
device := peer.device device := peer.device
device.log.Debug.Println(peer.String(), ": Starting...") device.log.Debug.Println(peer.String() + ": Starting...")
// sanity check : these should be 0 // sanity check : these should be 0
@ -242,7 +242,7 @@ func (peer *Peer) Stop() {
} }
device := peer.device device := peer.device
device.log.Debug.Println(peer.String(), ": Stopping...") device.log.Debug.Println(peer.String() + ": Stopping...")
// stop & wait for ongoing peer routines // stop & wait for ongoing peer routines

View File

@ -7,6 +7,7 @@ import (
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"net" "net"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -101,7 +102,11 @@ func (device *Device) addToHandshakeQueue(
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, receive incoming, IP version:", IP) defer func() {
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
}()
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - starting")
// receive datagrams until conn is closed // receive datagrams until conn is closed
@ -224,15 +229,31 @@ func (device *Device) RoutineDecryption() {
var nonce [chacha20poly1305.NonceSize]byte var nonce [chacha20poly1305.NonceSize]byte
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, decryption, started for device") defer func() {
for {
select {
case elem, ok := <-device.queue.decryption:
if ok {
elem.Drop()
}
default:
break
}
}
logDebug.Println("Routine: decryption worker - stopped")
}()
logDebug.Println("Routine: decryption worker - started")
for { for {
select { select {
case <-device.signal.stop.Wait(): case <-device.signal.stop.Wait():
logDebug.Println("Routine, decryption worker, stopped")
return return
case elem := <-device.queue.decryption: case elem, ok := <-device.queue.decryption:
if !ok {
return
}
// check if dropped // check if dropped
@ -282,18 +303,35 @@ func (device *Device) RoutineHandshake() {
logInfo := device.log.Info logInfo := device.log.Info
logError := device.log.Error logError := device.log.Error
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, handshake routine, started for device")
defer func() {
for {
select {
case <-device.queue.handshake:
default:
return
}
}
logDebug.Println("Routine: handshake worker - stopped")
}()
logDebug.Println("Routine: handshake worker - started")
var temp [MessageHandshakeSize]byte var temp [MessageHandshakeSize]byte
var elem QueueHandshakeElement var elem QueueHandshakeElement
var ok bool
for { for {
select { select {
case elem = <-device.queue.handshake: case elem, ok = <-device.queue.handshake:
case <-device.signal.stop.Wait(): case <-device.signal.stop.Wait():
return return
} }
if !ok {
return
}
// handle cookie fields and ratelimiting // handle cookie fields and ratelimiting
switch elem.msgType { switch elem.msgType {
@ -419,7 +457,7 @@ func (device *Device) RoutineHandshake() {
peer.endpoint = elem.endpoint peer.endpoint = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
logDebug.Println(peer, ": Received handshake initiation") logDebug.Println(peer.String() + ": Received handshake initiation")
// create response // create response
@ -477,7 +515,7 @@ func (device *Device) RoutineHandshake() {
peer.endpoint = elem.endpoint peer.endpoint = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
logDebug.Println(peer, ": Received handshake response") logDebug.Println(peer.String() + ": Received handshake response")
peer.TimerEphemeralKeyCreated() peer.TimerEphemeralKeyCreated()
@ -504,10 +542,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
defer func() { defer func() {
peer.routines.stopping.Done() peer.routines.stopping.Done()
logDebug.Println(peer.String(), ": Routine, Sequential Receiver, Stopped") logDebug.Println(peer.String() + ": Routine: sequential receiver - stopped")
}() }()
logDebug.Println(peer.String(), ": Routine, Sequential Receiver, Started") logDebug.Println(peer.String() + ": Routine: sequential receiver - started")
peer.routines.starting.Done() peer.routines.starting.Done()

53
send.go
View File

@ -121,7 +121,11 @@ func (device *Device) RoutineReadFromTUN() {
logDebug := device.log.Debug logDebug := device.log.Debug
logError := device.log.Error logError := device.log.Error
logDebug.Println("Routine, TUN Reader started") defer func() {
logDebug.Println("Routine: TUN reader - stopped")
}()
logDebug.Println("Routine: TUN reader - started")
for { for {
@ -192,11 +196,11 @@ func (peer *Peer) RoutineNonce() {
defer func() { defer func() {
peer.routines.stopping.Done() peer.routines.stopping.Done()
logDebug.Println(peer.String(), ": Routine, Nonce Worker, Stopped") logDebug.Println(peer.String() + ": Routine: nonce worker - stopped")
}() }()
peer.routines.starting.Done() peer.routines.starting.Done()
logDebug.Println(peer.String(), ": Routine, Nonce Worker, Started") logDebug.Println(peer.String() + ": Routine: nonce worker - started")
for { for {
NextPacket: NextPacket:
@ -204,7 +208,11 @@ func (peer *Peer) RoutineNonce() {
case <-peer.routines.stop.Wait(): case <-peer.routines.stop.Wait():
return return
case elem := <-peer.queue.nonce: case elem, ok := <-peer.queue.nonce:
if !ok {
return
}
// wait for key pair // wait for key pair
@ -218,13 +226,13 @@ func (peer *Peer) RoutineNonce() {
peer.signal.handshakeBegin.Send() peer.signal.handshakeBegin.Send()
logDebug.Println(peer.String(), ": Awaiting key-pair") logDebug.Println(peer.String() + ": Awaiting key-pair")
select { select {
case <-peer.signal.newKeyPair.Wait(): case <-peer.signal.newKeyPair.Wait():
logDebug.Println(peer.String(), ": Obtained awaited key-pair") logDebug.Println(peer.String() + ": Obtained awaited key-pair")
case <-peer.signal.flushNonceQueue.Wait(): case <-peer.signal.flushNonceQueue.Wait():
logDebug.Println(peer.String(), ": Flushing nonce queue") logDebug.Println(peer.String() + ": Flushing nonce queue")
peer.FlushNonceQueue() peer.FlushNonceQueue()
goto NextPacket goto NextPacket
case <-peer.routines.stop.Wait(): case <-peer.routines.stop.Wait():
@ -258,7 +266,22 @@ func (device *Device) RoutineEncryption() {
var nonce [chacha20poly1305.NonceSize]byte var nonce [chacha20poly1305.NonceSize]byte
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, encryption worker, started")
defer func() {
for {
select {
case elem, ok := <-device.queue.encryption:
if ok {
elem.Drop()
}
default:
break
}
}
logDebug.Println("Routine: encryption worker - stopped")
}()
logDebug.Println("Routine: encryption worker - started")
for { for {
@ -266,10 +289,13 @@ func (device *Device) RoutineEncryption() {
select { select {
case <-device.signal.stop.Wait(): case <-device.signal.stop.Wait():
logDebug.Println("Routine, encryption worker, stopped")
return return
case elem := <-device.queue.encryption: case elem, ok := <-device.queue.encryption:
if !ok {
return
}
// check if dropped // check if dropped
@ -323,21 +349,20 @@ func (peer *Peer) RoutineSequentialSender() {
device := peer.device device := peer.device
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, sequential sender, started for", peer.String())
defer func() { defer func() {
peer.routines.stopping.Done() peer.routines.stopping.Done()
logDebug.Println(peer.String(), ": Routine, Sequential sender, Stopped") logDebug.Println(peer.String() + ": Routine: sequential sender - stopped")
}() }()
logDebug.Println(peer.String() + ": Routine: sequential sender - started")
peer.routines.starting.Done() peer.routines.starting.Done()
for { for {
select { select {
case <-peer.routines.stop.Wait(): case <-peer.routines.stop.Wait():
logDebug.Println(
"Routine, sequential sender, stopped for", peer.String())
return return
case elem, ok := <-peer.queue.outbound: case elem, ok := <-peer.queue.outbound:

View File

@ -120,7 +120,7 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
*/ */
func (peer *Peer) TimerHandshakeComplete() { func (peer *Peer) TimerHandshakeComplete() {
peer.signal.handshakeCompleted.Send() peer.signal.handshakeCompleted.Send()
peer.device.log.Info.Println(peer.String(), ": New handshake completed") peer.device.log.Info.Println(peer.String() + ": New handshake completed")
} }
/* Event: /* Event:
@ -189,10 +189,12 @@ func (peer *Peer) RoutineTimerHandler() {
logDebug := device.log.Debug logDebug := device.log.Debug
defer func() { defer func() {
logDebug.Println(peer.String(), ": Routine, Timer handler, Stopped") logDebug.Println(peer.String() + ": Routine: timer handler - stopped")
peer.routines.stopping.Done() peer.routines.stopping.Done()
}() }()
logDebug.Println(peer.String() + ": Routine: timer handler - started")
// reset all timers // reset all timers
peer.timer.keepalivePassive.Stop() peer.timer.keepalivePassive.Stop()
@ -207,8 +209,6 @@ func (peer *Peer) RoutineTimerHandler() {
peer.timer.keepalivePersistent.Reset(duration) peer.timer.keepalivePersistent.Reset(duration)
} }
logDebug.Println("Routine, timer handler, started for peer", peer.String())
// signal synchronised setup complete // signal synchronised setup complete
peer.routines.starting.Done() peer.routines.starting.Done()
@ -231,14 +231,14 @@ func (peer *Peer) RoutineTimerHandler() {
interval := peer.persistentKeepaliveInterval interval := peer.persistentKeepaliveInterval
if interval > 0 { if interval > 0 {
logDebug.Println(peer.String(), ": Send keep-alive (persistent)") logDebug.Println(peer.String() + ": Send keep-alive (persistent)")
peer.timer.keepalivePassive.Stop() peer.timer.keepalivePassive.Stop()
peer.SendKeepAlive() peer.SendKeepAlive()
} }
case <-peer.timer.keepalivePassive.Wait(): case <-peer.timer.keepalivePassive.Wait():
logDebug.Println(peer.String(), ": Send keep-alive (passive)") logDebug.Println(peer.String() + ": Send keep-alive (passive)")
peer.SendKeepAlive() peer.SendKeepAlive()
@ -250,7 +250,7 @@ func (peer *Peer) RoutineTimerHandler() {
case <-peer.timer.zeroAllKeys.Wait(): case <-peer.timer.zeroAllKeys.Wait():
logDebug.Println(peer.String(), ": Clear all key-material (timer event)") logDebug.Println(peer.String() + ": Clear all key-material (timer event)")
hs := &peer.handshake hs := &peer.handshake
hs.mutex.Lock() hs.mutex.Lock()
@ -283,7 +283,7 @@ func (peer *Peer) RoutineTimerHandler() {
// handshake timers // handshake timers
case <-peer.timer.handshakeNew.Wait(): case <-peer.timer.handshakeNew.Wait():
logInfo.Println(peer.String(), ": Retrying handshake (timer event)") logInfo.Println(peer.String() + ": Retrying handshake (timer event)")
peer.signal.handshakeBegin.Send() peer.signal.handshakeBegin.Send()
case <-peer.timer.handshakeTimeout.Wait(): case <-peer.timer.handshakeTimeout.Wait():
@ -301,16 +301,16 @@ func (peer *Peer) RoutineTimerHandler() {
err := peer.sendNewHandshake() err := peer.sendNewHandshake()
if err != nil { if err != nil {
logInfo.Println(peer.String(), ": Failed to send handshake initiation", err) logInfo.Println(peer.String()+": Failed to send handshake initiation", err)
} else { } else {
logDebug.Println(peer.String(), ": Send handshake initiation (subsequent)") logDebug.Println(peer.String() + ": Send handshake initiation (subsequent)")
} }
case <-peer.timer.handshakeDeadline.Wait(): case <-peer.timer.handshakeDeadline.Wait():
// clear all queued packets and stop keep-alive // clear all queued packets and stop keep-alive
logInfo.Println(peer.String(), ": Handshake negotiation timed-out") logInfo.Println(peer.String() + ": Handshake negotiation timed-out")
peer.signal.flushNonceQueue.Send() peer.signal.flushNonceQueue.Send()
peer.timer.keepalivePersistent.Stop() peer.timer.keepalivePersistent.Stop()
@ -325,16 +325,16 @@ func (peer *Peer) RoutineTimerHandler() {
err := peer.sendNewHandshake() err := peer.sendNewHandshake()
if err != nil { if err != nil {
logInfo.Println(peer.String(), ": Failed to send handshake initiation", err) logInfo.Println(peer.String()+": Failed to send handshake initiation", err)
} else { } else {
logDebug.Println(peer.String(), ": Send handshake initiation (initial)") logDebug.Println(peer.String() + ": Send handshake initiation (initial)")
} }
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
case <-peer.signal.handshakeCompleted.Wait(): case <-peer.signal.handshakeCompleted.Wait():
logInfo.Println(peer.String(), ": Handshake completed") logInfo.Println(peer.String() + ": Handshake completed")
atomic.StoreInt64( atomic.StoreInt64(
&peer.stats.lastHandshakeNano, &peer.stats.lastHandshakeNano,