1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

device: use channel close to shut down and drain encryption channel

The new test introduced in this commit used to deadlock about 1% of the time.

I believe that the deadlock occurs as follows:

* The test completes, calling device.Close.
* device.Close closes device.signals.stop.
* RoutineEncryption stops.
* The deferred function in RoutineEncryption drains device.queue.encryption.
* RoutineEncryption exits.
* A peer's RoutineNonce processes an element queued in peer.queue.nonce.
* RoutineNonce puts that element into the outbound and encryption queues.
* RoutineSequentialSender reads that elements from the outbound queue.
* It waits for that element to get Unlocked by RoutineEncryption.
* RoutineEncryption has already exited, so RoutineSequentialSender blocks forever.
* device.RemoveAllPeers calls peer.Stop on all peers.
* peer.Stop waits for peer.routines.stopping, which blocks forever.

Rather than attempt to add even more ordering to the already complex
centralized shutdown orchestration, this commit moves towards a
data-flow-oriented shutdown.

The device.queue.encryption gets closed when there will be no more writes to it.
All device.queue.encryption readers always read until the channel is closed and then exit.
We thus guarantee that any element that enters the encryption queue also exits it.
This removes the need for central control of the lifetime of RoutineEncryption,
removes the need to drain the encryption queue on shutdown, and simplifies RoutineEncryption.

This commit also fixes a data race. When RoutineSequentialSender
drains its queue on shutdown, it needs to lock the elem before operating on it,
just as the main body does.

The new test in this commit passed 50k iterations with the race detector enabled
and 150k iterations with the race detector disabled, with no failures.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
Josh Bleecher Snyder 2020-12-14 15:07:23 -08:00 committed by Jason A. Donenfeld
parent 41cd68416c
commit e1fa1cc556
3 changed files with 178 additions and 106 deletions

View File

@ -74,7 +74,7 @@ type Device struct {
} }
queue struct { queue struct {
encryption chan *QueueOutboundElement encryption *encryptionQueue
decryption chan *QueueInboundElement decryption chan *QueueInboundElement
handshake chan QueueHandshakeElement handshake chan QueueHandshakeElement
} }
@ -89,6 +89,31 @@ type Device struct {
} }
} }
// An encryptionQueue is a channel of QueueOutboundElements awaiting encryption.
// An encryptionQueue is ref-counted using its wg field.
// An encryptionQueue created with newEncryptionQueue has one reference.
// Every additional writer must call wg.Add(1).
// Every completed writer must call wg.Done().
// When no further writers will be added,
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type encryptionQueue struct {
c chan *QueueOutboundElement
wg sync.WaitGroup
}
func newEncryptionQueue() *encryptionQueue {
q := &encryptionQueue{
c: make(chan *QueueOutboundElement, QueueOutboundSize),
}
q.wg.Add(1)
go func() {
q.wg.Wait()
close(q.c)
}()
return q
}
/* Converts the peer into a "zombie", which remains in the peer map, /* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table. * but processes no packets and does not exists in the routing table.
* *
@ -280,7 +305,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
// create queues // create queues
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) device.queue.encryption = newEncryptionQueue()
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signals // prepare signals
@ -297,7 +322,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
cpus := runtime.NumCPU() cpus := runtime.NumCPU()
device.state.stopping.Wait() device.state.stopping.Wait()
for i := 0; i < cpus; i += 1 { for i := 0; i < cpus; i += 1 {
device.state.stopping.Add(3) device.state.stopping.Add(2) // decryption and handshake
go device.RoutineEncryption() go device.RoutineEncryption()
go device.RoutineDecryption() go device.RoutineDecryption()
go device.RoutineHandshake() go device.RoutineHandshake()
@ -346,10 +371,6 @@ func (device *Device) FlushPacketQueues() {
if ok { if ok {
elem.Drop() elem.Drop()
} }
case elem, ok := <-device.queue.encryption:
if ok {
elem.Drop()
}
case <-device.queue.handshake: case <-device.queue.handshake:
default: default:
return return
@ -373,6 +394,10 @@ func (device *Device) Close() {
device.isUp.Set(false) device.isUp.Set(false)
// We kept a reference to the encryption queue,
// in case we started any new peers that might write to it.
// No new peers are coming; we are done with the encryption queue.
device.queue.encryption.wg.Done()
close(device.signals.stop) close(device.signals.stop)
device.state.stopping.Wait() device.state.stopping.Wait()

View File

@ -7,9 +7,11 @@ package device
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
"testing" "testing"
"time" "time"
@ -79,18 +81,74 @@ func genConfigs(t *testing.T) (cfgs [2]io.Reader) {
return return
} }
// genChannelTUNs creates a usable pair of ChannelTUNs for use in a test. // A testPair is a pair of testPeers.
func genChannelTUNs(t *testing.T) (tun [2]*tuntest.ChannelTUN) { type testPair [2]testPeer
// A testPeer is a peer used for testing.
type testPeer struct {
tun *tuntest.ChannelTUN
dev *Device
ip net.IP
}
type SendDirection bool
const (
Ping SendDirection = true
Pong SendDirection = false
)
func (pair *testPair) Send(t *testing.T, ping SendDirection, done chan struct{}) {
t.Helper()
p0, p1 := pair[0], pair[1]
if !ping {
// pong is the new ping
p0, p1 = p1, p0
}
msg := tuntest.Ping(p0.ip, p1.ip)
p1.tun.Outbound <- msg
timer := time.NewTimer(5 * time.Second)
defer timer.Stop()
var err error
select {
case msgRecv := <-p0.tun.Inbound:
if !bytes.Equal(msg, msgRecv) {
err = errors.New("ping did not transit correctly")
}
case <-timer.C:
err = errors.New("ping did not transit")
case <-done:
}
if err != nil {
// The error may have occurred because the test is done.
select {
case <-done:
return
default:
}
// Real error.
t.Error(err)
}
}
// genTestPair creates a testPair.
func genTestPair(t *testing.T) (pair testPair) {
const maxAttempts = 10 const maxAttempts = 10
NextAttempt: NextAttempt:
for i := 0; i < maxAttempts; i++ { for i := 0; i < maxAttempts; i++ {
cfg := genConfigs(t) cfg := genConfigs(t)
// Bring up a ChannelTun for each config. // Bring up a ChannelTun for each config.
for i := range tun { for i := range pair {
tun[i] = tuntest.NewChannelTUN() p := &pair[i]
dev := NewDevice(tun[i].TUN(), NewLogger(LogLevelDebug, fmt.Sprintf("dev%d: ", i))) p.tun = tuntest.NewChannelTUN()
dev.Up() if i == 0 {
if err := dev.IpcSetOperation(cfg[i]); err != nil { p.ip = net.ParseIP("1.0.0.1")
} else {
p.ip = net.ParseIP("1.0.0.2")
}
p.dev = NewDevice(p.tun.TUN(), NewLogger(LogLevelDebug, fmt.Sprintf("dev%d: ", i)))
p.dev.Up()
if err := p.dev.IpcSetOperation(cfg[i]); err != nil {
// genConfigs attempted to pick ports that were free. // genConfigs attempted to pick ports that were free.
// There's a tiny window between genConfigs closing the port // There's a tiny window between genConfigs closing the port
// and us opening it, during which another process could // and us opening it, during which another process could
@ -104,12 +162,12 @@ NextAttempt:
// The device might still not be up, e.g. due to an error // The device might still not be up, e.g. due to an error
// in RoutineTUNEventReader's call to dev.Up that got swallowed. // in RoutineTUNEventReader's call to dev.Up that got swallowed.
// Assume it's due to a transient error (port in use), and retry. // Assume it's due to a transient error (port in use), and retry.
if !dev.isUp.Get() { if !p.dev.isUp.Get() {
t.Logf("%v did not come up, trying again", dev) t.Logf("device %d did not come up, trying again", i)
continue NextAttempt continue NextAttempt
} }
// The device is up. Close it when the test completes. // The device is up. Close it when the test completes.
t.Cleanup(dev.Close) t.Cleanup(p.dev.Close)
} }
return // success return // success
} }
@ -119,35 +177,49 @@ NextAttempt:
} }
func TestTwoDevicePing(t *testing.T) { func TestTwoDevicePing(t *testing.T) {
tun := genChannelTUNs(t) pair := genTestPair(t)
t.Run("ping 1.0.0.1", func(t *testing.T) { t.Run("ping 1.0.0.1", func(t *testing.T) {
msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) pair.Send(t, Ping, nil)
tun[1].Outbound <- msg2to1
select {
case msgRecv := <-tun[0].Inbound:
if !bytes.Equal(msg2to1, msgRecv) {
t.Error("ping did not transit correctly")
}
case <-time.After(5 * time.Second):
t.Error("ping did not transit")
}
}) })
t.Run("ping 1.0.0.2", func(t *testing.T) { t.Run("ping 1.0.0.2", func(t *testing.T) {
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) pair.Send(t, Pong, nil)
tun[0].Outbound <- msg1to2
select {
case msgRecv := <-tun[1].Inbound:
if !bytes.Equal(msg1to2, msgRecv) {
t.Error("return ping did not transit correctly")
}
case <-time.After(5 * time.Second):
t.Error("return ping did not transit")
}
}) })
} }
// TestConcurrencySafety does other things concurrently with tunnel use.
// It is intended to be used with the race detector to catch data races.
func TestConcurrencySafety(t *testing.T) {
pair := genTestPair(t)
done := make(chan struct{})
const warmupIters = 10
var warmup sync.WaitGroup
warmup.Add(warmupIters)
go func() {
// Send data continuously back and forth until we're done.
// Note that we may continue to attempt to send data
// even after done is closed.
i := warmupIters
for ping := Ping; ; ping = !ping {
pair.Send(t, ping, done)
select {
case <-done:
return
default:
}
if i > 0 {
warmup.Done()
i--
}
}
}()
warmup.Wait()
// coming soon: more things here...
close(done)
}
func assertNil(t *testing.T, err error) { func assertNil(t *testing.T, err error) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -352,6 +352,9 @@ func (peer *Peer) RoutineNonce() {
device := peer.device device := peer.device
logDebug := device.log.Debug logDebug := device.log.Debug
// We write to the encryption queue; keep it alive until we are done.
device.queue.encryption.wg.Add(1)
flush := func() { flush := func() {
for { for {
select { select {
@ -368,6 +371,7 @@ func (peer *Peer) RoutineNonce() {
flush() flush()
logDebug.Println(peer, "- Routine: nonce worker - stopped") logDebug.Println(peer, "- Routine: nonce worker - stopped")
peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
device.queue.encryption.wg.Done() // no more writes from us
peer.routines.stopping.Done() peer.routines.stopping.Done()
}() }()
@ -455,7 +459,7 @@ NextPacket:
elem.Lock() elem.Lock()
// add to parallel and sequential queue // add to parallel and sequential queue
addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem) addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption.c, elem)
} }
} }
} }
@ -486,76 +490,46 @@ func (device *Device) RoutineEncryption() {
logDebug := device.log.Debug logDebug := device.log.Debug
defer func() { defer logDebug.Println("Routine: encryption worker - stopped")
for {
select {
case elem, ok := <-device.queue.encryption:
if ok && !elem.IsDropped() {
elem.Drop()
device.PutMessageBuffer(elem.buffer)
elem.Unlock()
}
default:
goto out
}
}
out:
logDebug.Println("Routine: encryption worker - stopped")
device.state.stopping.Done()
}()
logDebug.Println("Routine: encryption worker - started") logDebug.Println("Routine: encryption worker - started")
for { for elem := range device.queue.encryption.c {
// fetch next element // check if dropped
select { if elem.IsDropped() {
case <-device.signals.stop: continue
return
case elem, ok := <-device.queue.encryption:
if !ok {
return
}
// check if dropped
if elem.IsDropped() {
continue
}
// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
for i := 0; i < paddingSize; i++ {
elem.packet = append(elem.packet, 0)
}
// encrypt content and release to consumer
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keypair.send.Seal(
header,
nonce[:],
elem.packet,
nil,
)
elem.Unlock()
} }
// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
for i := 0; i < paddingSize; i++ {
elem.packet = append(elem.packet, 0)
}
// encrypt content and release to consumer
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keypair.send.Seal(
header,
nonce[:],
elem.packet,
nil,
)
elem.Unlock()
} }
} }
@ -576,6 +550,7 @@ func (peer *Peer) RoutineSequentialSender() {
select { select {
case elem, ok := <-peer.queue.outbound: case elem, ok := <-peer.queue.outbound:
if ok { if ok {
elem.Lock()
if !elem.IsDropped() { if !elem.IsDropped() {
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
elem.Drop() elem.Drop()