mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
device: use channel close to shut down and drain decryption channel
This is similar to commit e1fa1cc556
,
but for the decryption channel.
It is an alternative fix to f9f655567930a4cd78d40fa4ba0d58503335ae6a.
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
parent
675955de5d
commit
48c3b87eb8
@ -76,7 +76,7 @@ type Device struct {
|
|||||||
|
|
||||||
queue struct {
|
queue struct {
|
||||||
encryption *encryptionQueue
|
encryption *encryptionQueue
|
||||||
decryption chan *QueueInboundElement
|
decryption *decryptionQueue
|
||||||
handshake chan QueueHandshakeElement
|
handshake chan QueueHandshakeElement
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,6 +115,24 @@ func newEncryptionQueue() *encryptionQueue {
|
|||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A decryptionQueue is similar to an encryptionQueue; see those docs.
|
||||||
|
type decryptionQueue struct {
|
||||||
|
c chan *QueueInboundElement
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDecryptionQueue() *decryptionQueue {
|
||||||
|
q := &decryptionQueue{
|
||||||
|
c: make(chan *QueueInboundElement, QueueInboundSize),
|
||||||
|
}
|
||||||
|
q.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
q.wg.Wait()
|
||||||
|
close(q.c)
|
||||||
|
}()
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
/* Converts the peer into a "zombie", which remains in the peer map,
|
/* Converts the peer into a "zombie", which remains in the peer map,
|
||||||
* but processes no packets and does not exists in the routing table.
|
* but processes no packets and does not exists in the routing table.
|
||||||
*
|
*
|
||||||
@ -308,7 +326,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
|||||||
|
|
||||||
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
|
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
|
||||||
device.queue.encryption = newEncryptionQueue()
|
device.queue.encryption = newEncryptionQueue()
|
||||||
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
|
device.queue.decryption = newDecryptionQueue()
|
||||||
|
|
||||||
// prepare signals
|
// prepare signals
|
||||||
|
|
||||||
@ -369,13 +387,6 @@ func (device *Device) RemoveAllPeers() {
|
|||||||
func (device *Device) FlushPacketQueues() {
|
func (device *Device) FlushPacketQueues() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case elem, ok := <-device.queue.decryption:
|
|
||||||
if ok {
|
|
||||||
if !elem.IsDropped() {
|
|
||||||
elem.Drop()
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case <-device.queue.handshake:
|
case <-device.queue.handshake:
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
@ -399,10 +410,11 @@ func (device *Device) Close() {
|
|||||||
|
|
||||||
device.isUp.Set(false)
|
device.isUp.Set(false)
|
||||||
|
|
||||||
// We kept a reference to the encryption queue,
|
// We kept a reference to the encryption and decryption queues,
|
||||||
// in case we started any new peers that might write to it.
|
// in case we started any new peers that might write to them.
|
||||||
// No new peers are coming; we are done with the encryption queue.
|
// No new peers are coming; we are done with these queues.
|
||||||
device.queue.encryption.wg.Done()
|
device.queue.encryption.wg.Done()
|
||||||
|
device.queue.decryption.wg.Done()
|
||||||
close(device.signals.stop)
|
close(device.signals.stop)
|
||||||
device.state.stopping.Wait()
|
device.state.stopping.Wait()
|
||||||
|
|
||||||
@ -549,6 +561,7 @@ func (device *Device) BindUpdate() error {
|
|||||||
// start receiving routines
|
// start receiving routines
|
||||||
|
|
||||||
device.net.stopping.Add(2)
|
device.net.stopping.Add(2)
|
||||||
|
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||||
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
||||||
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
||||||
|
|
||||||
|
@ -109,6 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
|||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
defer func() {
|
defer func() {
|
||||||
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
|
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
|
||||||
|
device.queue.decryption.wg.Done()
|
||||||
device.net.stopping.Done()
|
device.net.stopping.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -206,7 +207,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
|||||||
|
|
||||||
peer.queue.RLock()
|
peer.queue.RLock()
|
||||||
if peer.isRunning.Get() {
|
if peer.isRunning.Get() {
|
||||||
if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
|
if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption.c, elem) {
|
||||||
buffer = device.GetMessageBuffer()
|
buffer = device.GetMessageBuffer()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -258,59 +259,35 @@ func (device *Device) RoutineDecryption() {
|
|||||||
}()
|
}()
|
||||||
logDebug.Println("Routine: decryption worker - started")
|
logDebug.Println("Routine: decryption worker - started")
|
||||||
|
|
||||||
for {
|
for elem := range device.queue.decryption.c {
|
||||||
select {
|
// check if dropped
|
||||||
case <-device.signals.stop:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case elem, ok := <-device.queue.decryption:
|
|
||||||
if ok {
|
|
||||||
if !elem.IsDropped() {
|
|
||||||
elem.Drop()
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
}
|
|
||||||
elem.Unlock()
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case elem, ok := <-device.queue.decryption:
|
if elem.IsDropped() {
|
||||||
|
continue
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if dropped
|
|
||||||
|
|
||||||
if elem.IsDropped() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// split message into fields
|
|
||||||
|
|
||||||
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
|
||||||
content := elem.packet[MessageTransportOffsetContent:]
|
|
||||||
|
|
||||||
// decrypt and release to consumer
|
|
||||||
|
|
||||||
var err error
|
|
||||||
elem.counter = binary.LittleEndian.Uint64(counter)
|
|
||||||
// copy counter to nonce
|
|
||||||
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
|
|
||||||
elem.packet, err = elem.keypair.receive.Open(
|
|
||||||
content[:0],
|
|
||||||
nonce[:],
|
|
||||||
content,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
elem.Drop()
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
}
|
|
||||||
elem.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// split message into fields
|
||||||
|
|
||||||
|
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
||||||
|
content := elem.packet[MessageTransportOffsetContent:]
|
||||||
|
|
||||||
|
// decrypt and release to consumer
|
||||||
|
|
||||||
|
var err error
|
||||||
|
elem.counter = binary.LittleEndian.Uint64(counter)
|
||||||
|
// copy counter to nonce
|
||||||
|
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
|
||||||
|
elem.packet, err = elem.keypair.receive.Open(
|
||||||
|
content[:0],
|
||||||
|
nonce[:],
|
||||||
|
content,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
elem.Drop()
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
}
|
||||||
|
elem.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user