diff --git a/receive.go b/receive.go index b23c5e0..6b6543c 100644 --- a/receive.go +++ b/receive.go @@ -43,59 +43,28 @@ func (elem *QueueInboundElement) IsDropped() bool { return atomic.LoadInt32(&elem.dropped) == AtomicTrue } -func (device *Device) addToInboundQueue( - queue chan *QueueInboundElement, - element *QueueInboundElement, -) { - for { +func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool { + select { + case inboundQueue <- element: select { - case queue <- element: - return + case decryptionQueue <- element: + return true default: - select { - case old := <-queue: - old.Drop() - default: - } + element.Drop() + element.mutex.Unlock() + return false } + default: + return false } } -func (device *Device) addToDecryptionQueue( - queue chan *QueueInboundElement, - element *QueueInboundElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - // drop & release to potential consumer - old.Drop() - old.mutex.Unlock() - default: - } - } - } -} - -func (device *Device) addToHandshakeQueue( - queue chan QueueHandshakeElement, - element QueueHandshakeElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case elem := <-queue: - device.PutMessageBuffer(elem.buffer) - default: - } - } +func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool { + select { + case queue <- element: + return true + default: + return false } } @@ -154,6 +123,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { } if err != nil { + device.PutMessageBuffer(buffer) return } @@ -212,9 +182,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { // add to decryption queues if peer.isRunning.Get() { - device.addToDecryptionQueue(device.queue.decryption, elem) - device.addToInboundQueue(peer.queue.inbound, elem) - buffer = device.GetMessageBuffer() + if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) { + buffer = device.GetMessageBuffer() + } } continue @@ -235,7 +205,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { } if okay { - device.addToHandshakeQueue( + if (device.addToHandshakeQueue( device.queue.handshake, QueueHandshakeElement{ msgType: msgType, @@ -243,8 +213,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { packet: packet, endpoint: endpoint, }, - ) - buffer = device.GetMessageBuffer() + )) { + buffer = device.GetMessageBuffer() + } } } } @@ -307,6 +278,8 @@ func (device *Device) RoutineDecryption() { ) if err != nil { elem.Drop() + device.PutMessageBuffer(elem.buffer) + elem.mutex.Unlock() } elem.mutex.Unlock() } diff --git a/send.go b/send.go index 37ae738..3b6cfa3 100644 --- a/send.go +++ b/send.go @@ -66,10 +66,7 @@ func (elem *QueueOutboundElement) IsDropped() bool { return atomic.LoadInt32(&elem.dropped) == AtomicTrue } -func addToOutboundQueue( - queue chan *QueueOutboundElement, - element *QueueOutboundElement, -) { +func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) { for { select { case queue <- element: @@ -78,32 +75,30 @@ func addToOutboundQueue( select { case old := <-queue: old.Drop() + device.PutMessageBuffer(element.buffer) default: } } } } -func addToEncryptionQueue( - queue chan *QueueOutboundElement, - element *QueueOutboundElement, -) { - for { +func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) { + select { + case outboundQueue <- element: select { - case queue <- element: + case encryptionQueue <- element: return default: - select { - case old := <-queue: - // drop & release to potential consumer - old.Drop() - old.mutex.Unlock() - default: - } + element.Drop() + element.peer.device.PutMessageBuffer(element.buffer) + element.mutex.Unlock() } + default: + element.peer.device.PutMessageBuffer(element.buffer) } } + /* Queues a keepalive if no packets are queued for peer */ func (peer *Peer) SendKeepalive() bool { @@ -117,6 +112,7 @@ func (peer *Peer) SendKeepalive() bool { peer.device.log.Debug.Println(peer, "- Sending keepalive packet") return true default: + peer.device.PutMessageBuffer(elem.buffer) return false } } @@ -267,6 +263,7 @@ func (device *Device) RoutineReadFromTUN() { logError.Println("Failed to read packet from TUN device:", err) device.Close() } + device.PutMessageBuffer(elem.buffer) return } @@ -308,7 +305,7 @@ func (device *Device) RoutineReadFromTUN() { if peer.queue.packetInNonceQueueIsAwaitingKey.Get() { peer.SendHandshakeInitiation(false) } - addToOutboundQueue(peer.queue.nonce, elem) + addToNonceQueue(peer.queue.nonce, elem, device) elem = device.NewOutboundElement() } } @@ -342,7 +339,8 @@ func (peer *Peer) RoutineNonce() { flush := func() { for { select { - case <-peer.queue.nonce: + case elem := <-peer.queue.nonce: + device.PutMessageBuffer(elem.buffer) default: return } @@ -402,10 +400,12 @@ func (peer *Peer) RoutineNonce() { logDebug.Println(peer, "- Obtained awaited keypair") case <-peer.signals.flushNonceQueue: + device.PutMessageBuffer(elem.buffer) flush() goto NextPacket case <-peer.routines.stop: + device.PutMessageBuffer(elem.buffer) return } } @@ -420,6 +420,7 @@ func (peer *Peer) RoutineNonce() { if elem.nonce >= RejectAfterMessages { atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) + device.PutMessageBuffer(elem.buffer) goto NextPacket } @@ -428,9 +429,7 @@ func (peer *Peer) RoutineNonce() { elem.mutex.Lock() // add to parallel and sequential queue - - addToEncryptionQueue(device.queue.encryption, elem) - addToOutboundQueue(peer.queue.outbound, elem) + addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem) } } }