From beb25cc4fd31da09590fed3200628baf4c701f8b Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Fri, 29 Jan 2021 18:24:45 +0100 Subject: [PATCH] device: use new model queues for handshakes Signed-off-by: Jason A. Donenfeld --- device/device.go | 55 +++++++++++++++++----------------- device/receive.go | 76 +++++++++++++++-------------------------------- 2 files changed, 52 insertions(+), 79 deletions(-) diff --git a/device/device.go b/device/device.go index 08db244..fd88855 100644 --- a/device/device.go +++ b/device/device.go @@ -13,6 +13,7 @@ import ( "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" @@ -77,11 +78,7 @@ type Device struct { queue struct { encryption *outboundQueue decryption *inboundQueue - handshake chan QueueHandshakeElement - } - - signals struct { - stop chan struct{} + handshake *handshakeQueue } tun struct { @@ -90,6 +87,7 @@ type Device struct { } ipcMutex sync.RWMutex + closed chan struct{} } // An outboundQueue is a channel of QueueOutboundElements awaiting encryption. @@ -135,6 +133,24 @@ func newInboundQueue() *inboundQueue { return q } +// A handshakeQueue is similar to an outboundQueue; see those docs. +type handshakeQueue struct { + c chan QueueHandshakeElement + wg sync.WaitGroup +} + +func newHandshakeQueue() *handshakeQueue { + q := &handshakeQueue{ + c: make(chan QueueHandshakeElement, QueueHandshakeSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + /* Converts the peer into a "zombie", which remains in the peer map, * but processes no packets and does not exists in the routing table. * @@ -233,7 +249,7 @@ func (device *Device) IsUnderLoad() bool { // check if currently under load now := time.Now() - underLoad := len(device.queue.handshake) >= UnderLoadQueueSize + underLoad := len(device.queue.handshake.c) >= UnderLoadQueueSize if underLoad { device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime)) return true @@ -302,6 +318,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { func NewDevice(tunDevice tun.Device, logger *Logger) *Device { device := new(Device) + device.closed = make(chan struct{}) device.log = logger device.tun.device = tunDevice mtu, err := device.tun.device.MTU() @@ -322,14 +339,10 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { // create queues - device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) + device.queue.handshake = newHandshakeQueue() device.queue.encryption = newOutboundQueue() device.queue.decryption = newInboundQueue() - // prepare signals - - device.signals.stop = make(chan struct{}) - // prepare net device.net.port = 0 @@ -382,18 +395,6 @@ func (device *Device) RemoveAllPeers() { device.peers.keyMap = make(map[NoisePublicKey]*Peer) } -func (device *Device) FlushPacketQueues() { - for { - select { - case elem := <-device.queue.handshake: - device.PutMessageBuffer(elem.buffer) - default: - return - } - } - -} - func (device *Device) Close() { if device.isClosed.Swap(true) { return @@ -414,21 +415,20 @@ func (device *Device) Close() { // No new peers are coming; we are done with these queues. device.queue.encryption.wg.Done() device.queue.decryption.wg.Done() - close(device.signals.stop) + device.queue.handshake.wg.Done() device.state.stopping.Wait() device.RemoveAllPeers() - device.FlushPacketQueues() - device.rate.limiter.Close() device.state.changing.Set(false) device.log.Verbosef("Interface closed") + close(device.closed) } func (device *Device) Wait() chan struct{} { - return device.signals.stop + return device.closed } func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { @@ -561,6 +561,7 @@ func (device *Device) BindUpdate() error { device.net.stopping.Add(2) device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption + device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) diff --git a/device/receive.go b/device/receive.go index abaf5af..0b70137 100644 --- a/device/receive.go +++ b/device/receive.go @@ -48,15 +48,6 @@ func (elem *QueueInboundElement) clearPointers() { elem.endpoint = nil } -func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem QueueHandshakeElement) bool { - select { - case queue <- elem: - return true - default: - return false - } -} - /* Called when a new authenticated message has been received * * NOTE: Not thread safe, but called by sequential receiver! @@ -81,6 +72,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { defer func() { device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP) device.queue.decryption.wg.Done() + device.queue.handshake.wg.Done() device.net.stopping.Done() }() @@ -202,16 +194,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { } if okay { - if (device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - endpoint: endpoint, - }, - )) { + select { + case device.queue.handshake.c <- QueueHandshakeElement{ + msgType: msgType, + buffer: buffer, + packet: packet, + endpoint: endpoint, + }: buffer = device.GetMessageBuffer() + default: } } } @@ -251,34 +242,13 @@ func (device *Device) RoutineDecryption() { /* Handles incoming packets related to handshake */ func (device *Device) RoutineHandshake() { - var elem QueueHandshakeElement - var ok bool - defer func() { device.log.Verbosef("Routine: handshake worker - stopped") device.state.stopping.Done() - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - } }() - device.log.Verbosef("Routine: handshake worker - started") - for { - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - elem.buffer = nil - } - - select { - case elem, ok = <-device.queue.handshake: - case <-device.signals.stop: - return - } - - if !ok { - return - } + for elem := range device.queue.handshake.c { // handle cookie fields and ratelimiting @@ -293,7 +263,7 @@ func (device *Device) RoutineHandshake() { err := binary.Read(reader, binary.LittleEndian, &reply) if err != nil { device.log.Verbosef("Failed to decode cookie reply") - return + goto skip } // lookup peer from index @@ -301,7 +271,7 @@ func (device *Device) RoutineHandshake() { entry := device.indexTable.Lookup(reply.Receiver) if entry.peer == nil { - continue + goto skip } // consume reply @@ -313,7 +283,7 @@ func (device *Device) RoutineHandshake() { } } - continue + goto skip case MessageInitiationType, MessageResponseType: @@ -321,7 +291,7 @@ func (device *Device) RoutineHandshake() { if !device.cookieChecker.CheckMAC1(elem.packet) { device.log.Verbosef("Received packet with invalid mac1") - continue + goto skip } // endpoints destination address is the source of the datagram @@ -332,19 +302,19 @@ func (device *Device) RoutineHandshake() { if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { device.SendHandshakeCookie(&elem) - continue + goto skip } // check ratelimiter if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { - continue + goto skip } } default: device.log.Errorf("Invalid packet ended up in the handshake queue") - continue + goto skip } // handle handshake initiation/response content @@ -359,7 +329,7 @@ func (device *Device) RoutineHandshake() { err := binary.Read(reader, binary.LittleEndian, &msg) if err != nil { device.log.Errorf("Failed to decode initiation message") - continue + goto skip } // consume initiation @@ -367,7 +337,7 @@ func (device *Device) RoutineHandshake() { peer := device.ConsumeMessageInitiation(&msg) if peer == nil { device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) - continue + goto skip } // update timers @@ -392,7 +362,7 @@ func (device *Device) RoutineHandshake() { err := binary.Read(reader, binary.LittleEndian, &msg) if err != nil { device.log.Errorf("Failed to decode response message") - continue + goto skip } // consume response @@ -400,7 +370,7 @@ func (device *Device) RoutineHandshake() { peer := device.ConsumeMessageResponse(&msg) if peer == nil { device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) - continue + goto skip } // update endpoint @@ -420,13 +390,15 @@ func (device *Device) RoutineHandshake() { if err != nil { device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) - continue + goto skip } peer.timersSessionDerived() peer.timersHandshakeComplete() peer.SendKeepalive() } + skip: + device.PutMessageBuffer(elem.buffer) } }