1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Change queueing drop order and fix memory leaks

If the queues are full, we drop the present packet, which is better for
network traffic flow. Also, we try to fix up the memory leaks with not
putting buffers from our shared pool.
This commit is contained in:
Jason A. Donenfeld 2018-09-16 21:50:58 +02:00
parent 1c02557013
commit 39d6e4f2f1
2 changed files with 47 additions and 75 deletions

View File

@ -43,59 +43,28 @@ func (elem *QueueInboundElement) IsDropped() bool {
return atomic.LoadInt32(&elem.dropped) == AtomicTrue return atomic.LoadInt32(&elem.dropped) == AtomicTrue
} }
func (device *Device) addToInboundQueue( func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool {
queue chan *QueueInboundElement, select {
element *QueueInboundElement, case inboundQueue <- element:
) {
for {
select { select {
case queue <- element: case decryptionQueue <- element:
return return true
default: default:
select { element.Drop()
case old := <-queue: element.mutex.Unlock()
old.Drop() return false
default:
}
} }
default:
return false
} }
} }
func (device *Device) addToDecryptionQueue( func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool {
queue chan *QueueInboundElement, select {
element *QueueInboundElement, case queue <- element:
) { return true
for { default:
select { return false
case queue <- element:
return
default:
select {
case old := <-queue:
// drop & release to potential consumer
old.Drop()
old.mutex.Unlock()
default:
}
}
}
}
func (device *Device) addToHandshakeQueue(
queue chan QueueHandshakeElement,
element QueueHandshakeElement,
) {
for {
select {
case queue <- element:
return
default:
select {
case elem := <-queue:
device.PutMessageBuffer(elem.buffer)
default:
}
}
} }
} }
@ -154,6 +123,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
} }
if err != nil { if err != nil {
device.PutMessageBuffer(buffer)
return return
} }
@ -212,9 +182,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
// add to decryption queues // add to decryption queues
if peer.isRunning.Get() { if peer.isRunning.Get() {
device.addToDecryptionQueue(device.queue.decryption, elem) if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
device.addToInboundQueue(peer.queue.inbound, elem) buffer = device.GetMessageBuffer()
buffer = device.GetMessageBuffer() }
} }
continue continue
@ -235,7 +205,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
} }
if okay { if okay {
device.addToHandshakeQueue( if (device.addToHandshakeQueue(
device.queue.handshake, device.queue.handshake,
QueueHandshakeElement{ QueueHandshakeElement{
msgType: msgType, msgType: msgType,
@ -243,8 +213,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
packet: packet, packet: packet,
endpoint: endpoint, endpoint: endpoint,
}, },
) )) {
buffer = device.GetMessageBuffer() buffer = device.GetMessageBuffer()
}
} }
} }
} }
@ -307,6 +278,8 @@ func (device *Device) RoutineDecryption() {
) )
if err != nil { if err != nil {
elem.Drop() elem.Drop()
device.PutMessageBuffer(elem.buffer)
elem.mutex.Unlock()
} }
elem.mutex.Unlock() elem.mutex.Unlock()
} }

43
send.go
View File

@ -66,10 +66,7 @@ func (elem *QueueOutboundElement) IsDropped() bool {
return atomic.LoadInt32(&elem.dropped) == AtomicTrue return atomic.LoadInt32(&elem.dropped) == AtomicTrue
} }
func addToOutboundQueue( func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) {
queue chan *QueueOutboundElement,
element *QueueOutboundElement,
) {
for { for {
select { select {
case queue <- element: case queue <- element:
@ -78,32 +75,30 @@ func addToOutboundQueue(
select { select {
case old := <-queue: case old := <-queue:
old.Drop() old.Drop()
device.PutMessageBuffer(element.buffer)
default: default:
} }
} }
} }
} }
func addToEncryptionQueue( func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) {
queue chan *QueueOutboundElement, select {
element *QueueOutboundElement, case outboundQueue <- element:
) {
for {
select { select {
case queue <- element: case encryptionQueue <- element:
return return
default: default:
select { element.Drop()
case old := <-queue: element.peer.device.PutMessageBuffer(element.buffer)
// drop & release to potential consumer element.mutex.Unlock()
old.Drop()
old.mutex.Unlock()
default:
}
} }
default:
element.peer.device.PutMessageBuffer(element.buffer)
} }
} }
/* Queues a keepalive if no packets are queued for peer /* Queues a keepalive if no packets are queued for peer
*/ */
func (peer *Peer) SendKeepalive() bool { func (peer *Peer) SendKeepalive() bool {
@ -117,6 +112,7 @@ func (peer *Peer) SendKeepalive() bool {
peer.device.log.Debug.Println(peer, "- Sending keepalive packet") peer.device.log.Debug.Println(peer, "- Sending keepalive packet")
return true return true
default: default:
peer.device.PutMessageBuffer(elem.buffer)
return false return false
} }
} }
@ -267,6 +263,7 @@ func (device *Device) RoutineReadFromTUN() {
logError.Println("Failed to read packet from TUN device:", err) logError.Println("Failed to read packet from TUN device:", err)
device.Close() device.Close()
} }
device.PutMessageBuffer(elem.buffer)
return return
} }
@ -308,7 +305,7 @@ func (device *Device) RoutineReadFromTUN() {
if peer.queue.packetInNonceQueueIsAwaitingKey.Get() { if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
addToOutboundQueue(peer.queue.nonce, elem) addToNonceQueue(peer.queue.nonce, elem, device)
elem = device.NewOutboundElement() elem = device.NewOutboundElement()
} }
} }
@ -342,7 +339,8 @@ func (peer *Peer) RoutineNonce() {
flush := func() { flush := func() {
for { for {
select { select {
case <-peer.queue.nonce: case elem := <-peer.queue.nonce:
device.PutMessageBuffer(elem.buffer)
default: default:
return return
} }
@ -402,10 +400,12 @@ func (peer *Peer) RoutineNonce() {
logDebug.Println(peer, "- Obtained awaited keypair") logDebug.Println(peer, "- Obtained awaited keypair")
case <-peer.signals.flushNonceQueue: case <-peer.signals.flushNonceQueue:
device.PutMessageBuffer(elem.buffer)
flush() flush()
goto NextPacket goto NextPacket
case <-peer.routines.stop: case <-peer.routines.stop:
device.PutMessageBuffer(elem.buffer)
return return
} }
} }
@ -420,6 +420,7 @@ func (peer *Peer) RoutineNonce() {
if elem.nonce >= RejectAfterMessages { if elem.nonce >= RejectAfterMessages {
atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
device.PutMessageBuffer(elem.buffer)
goto NextPacket goto NextPacket
} }
@ -428,9 +429,7 @@ func (peer *Peer) RoutineNonce() {
elem.mutex.Lock() elem.mutex.Lock()
// add to parallel and sequential queue // add to parallel and sequential queue
addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem)
addToEncryptionQueue(device.queue.encryption, elem)
addToOutboundQueue(peer.queue.outbound, elem)
} }
} }
} }