diff --git a/device/device.go b/device/device.go index d37fe6f..9a9b1b3 100644 --- a/device/device.go +++ b/device/device.go @@ -76,7 +76,7 @@ type Device struct { queue struct { encryption *encryptionQueue - decryption chan *QueueInboundElement + decryption *decryptionQueue handshake chan QueueHandshakeElement } @@ -115,6 +115,24 @@ func newEncryptionQueue() *encryptionQueue { return q } +// A decryptionQueue is similar to an encryptionQueue; see those docs. +type decryptionQueue struct { + c chan *QueueInboundElement + wg sync.WaitGroup +} + +func newDecryptionQueue() *decryptionQueue { + q := &decryptionQueue{ + c: make(chan *QueueInboundElement, QueueInboundSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + /* Converts the peer into a "zombie", which remains in the peer map, * but processes no packets and does not exists in the routing table. * @@ -308,7 +326,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) device.queue.encryption = newEncryptionQueue() - device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) + device.queue.decryption = newDecryptionQueue() // prepare signals @@ -369,13 +387,6 @@ func (device *Device) RemoveAllPeers() { func (device *Device) FlushPacketQueues() { for { select { - case elem, ok := <-device.queue.decryption: - if ok { - if !elem.IsDropped() { - elem.Drop() - device.PutMessageBuffer(elem.buffer) - } - } case <-device.queue.handshake: default: return @@ -399,10 +410,11 @@ func (device *Device) Close() { device.isUp.Set(false) - // We kept a reference to the encryption queue, - // in case we started any new peers that might write to it. - // No new peers are coming; we are done with the encryption queue. + // We kept a reference to the encryption and decryption queues, + // in case we started any new peers that might write to them. + // No new peers are coming; we are done with these queues. device.queue.encryption.wg.Done() + device.queue.decryption.wg.Done() close(device.signals.stop) device.state.stopping.Wait() @@ -549,6 +561,7 @@ func (device *Device) BindUpdate() error { // start receiving routines device.net.stopping.Add(2) + device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) diff --git a/device/receive.go b/device/receive.go index fa31a1a..20e0c8f 100644 --- a/device/receive.go +++ b/device/receive.go @@ -109,6 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { logDebug := device.log.Debug defer func() { logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") + device.queue.decryption.wg.Done() device.net.stopping.Done() }() @@ -206,7 +207,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { peer.queue.RLock() if peer.isRunning.Get() { - if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) { + if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption.c, elem) { buffer = device.GetMessageBuffer() } } else { @@ -258,59 +259,35 @@ func (device *Device) RoutineDecryption() { }() logDebug.Println("Routine: decryption worker - started") - for { - select { - case <-device.signals.stop: - for { - select { - case elem, ok := <-device.queue.decryption: - if ok { - if !elem.IsDropped() { - elem.Drop() - device.PutMessageBuffer(elem.buffer) - } - elem.Unlock() - } - default: - return - } - } + for elem := range device.queue.decryption.c { + // check if dropped - case elem, ok := <-device.queue.decryption: - - if !ok { - return - } - - // check if dropped - - if elem.IsDropped() { - continue - } - - // split message into fields - - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] - content := elem.packet[MessageTransportOffsetContent:] - - // decrypt and release to consumer - - var err error - elem.counter = binary.LittleEndian.Uint64(counter) - // copy counter to nonce - binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) - elem.packet, err = elem.keypair.receive.Open( - content[:0], - nonce[:], - content, - nil, - ) - if err != nil { - elem.Drop() - device.PutMessageBuffer(elem.buffer) - } - elem.Unlock() + if elem.IsDropped() { + continue } + + // split message into fields + + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] + content := elem.packet[MessageTransportOffsetContent:] + + // decrypt and release to consumer + + var err error + elem.counter = binary.LittleEndian.Uint64(counter) + // copy counter to nonce + binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) + elem.packet, err = elem.keypair.receive.Open( + content[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.Drop() + device.PutMessageBuffer(elem.buffer) + } + elem.Unlock() } }