1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Added initial version of peer teardown

There is a double lock issue with device.Close which has yet to be
resolved.
This commit is contained in:
Mathias Hall-Andersen 2018-01-26 22:52:32 +01:00
parent 068d932f2c
commit f73d2fb2d9
6 changed files with 177 additions and 70 deletions

View File

@ -64,9 +64,13 @@ func unsafeCloseBind(device *Device) error {
return err return err
} }
/* Must hold device and net lock func (device *Device) BindUpdate() error {
*/ device.mutex.Lock()
func unsafeUpdateBind(device *Device) error { defer device.mutex.Unlock()
netc := &device.net
netc.mutex.Lock()
defer netc.mutex.Unlock()
// close existing sockets // close existing sockets
@ -74,18 +78,13 @@ func unsafeUpdateBind(device *Device) error {
return err return err
} }
// assumption: netc.update WaitGroup should be exactly 1
// open new sockets // open new sockets
if device.isUp.Get() { if device.isUp.Get() {
device.log.Debug.Println("UDP bind updating")
// bind to new port // bind to new port
var err error var err error
netc := &device.net
netc.bind, netc.port, err = CreateBind(netc.port) netc.bind, netc.port, err = CreateBind(netc.port)
if err != nil { if err != nil {
netc.bind = nil netc.bind = nil
@ -109,7 +108,7 @@ func unsafeUpdateBind(device *Device) error {
peer.mutex.Unlock() peer.mutex.Unlock()
} }
// decrease waitgroup to 0 // start receiving routines
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
@ -120,7 +119,7 @@ func unsafeUpdateBind(device *Device) error {
return nil return nil
} }
func closeBind(device *Device) error { func (device *Device) BindClose() error {
device.mutex.Lock() device.mutex.Lock()
device.net.mutex.Lock() device.net.mutex.Lock()
err := unsafeCloseBind(device) err := unsafeCloseBind(device)

View File

@ -9,7 +9,7 @@ import (
) )
type Device struct { type Device struct {
isUp AtomicBool // device is up (TUN interface up)? isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard) isClosed AtomicBool // device is closed? (acting as guard)
log *Logger // collection of loggers for levels log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers idCounter uint // for assigning debug ids to peers
@ -18,6 +18,11 @@ type Device struct {
device TUNDevice device TUNDevice
mtu int32 mtu int32
} }
state struct {
mutex deadlock.Mutex
changing AtomicBool
current bool
}
pool struct { pool struct {
messageBuffers sync.Pool messageBuffers sync.Pool
} }
@ -46,37 +51,86 @@ type Device struct {
mac CookieChecker mac CookieChecker
} }
func (device *Device) Up() { func deviceUpdateState(device *Device) {
device.mutex.Lock()
defer device.mutex.Unlock()
device.net.mutex.Lock() // check if state already being updated (guard)
defer device.net.mutex.Unlock()
if device.isUp.Swap(true) { if device.state.changing.Swap(true) {
return return
} }
unsafeUpdateBind(device) // compare to current state of device
for _, peer := range device.peers { device.state.mutex.Lock()
peer.Start()
newIsUp := device.isUp.Get()
if newIsUp == device.state.current {
device.state.mutex.Unlock()
device.state.changing.Set(false)
return
} }
device.state.mutex.Unlock()
// change state of device
switch newIsUp {
case true:
// start listener
if err := device.BindUpdate(); err != nil {
device.isUp.Set(false)
break
}
// start every peer
for _, peer := range device.peers {
peer.Start()
}
case false:
// stop listening
device.BindClose()
// stop every peer
for _, peer := range device.peers {
peer.Stop()
}
}
// update state variables
// and check for state change in the mean time
device.state.current = newIsUp
device.state.changing.Set(false)
deviceUpdateState(device)
}
func (device *Device) Up() {
// closed device cannot be brought up
if device.isClosed.Get() {
return
}
device.state.mutex.Lock()
device.isUp.Set(true)
device.state.mutex.Unlock()
deviceUpdateState(device)
} }
func (device *Device) Down() { func (device *Device) Down() {
device.mutex.Lock() device.state.mutex.Lock()
defer device.mutex.Unlock() device.isUp.Set(false)
device.state.mutex.Unlock()
if !device.isUp.Swap(false) { deviceUpdateState(device)
return
}
closeBind(device)
for _, peer := range device.peers {
peer.Stop()
}
} }
/* Warning: /* Warning:
@ -87,7 +141,6 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) {
if !ok { if !ok {
return return
} }
peer.Stop()
device.routingTable.RemovePeer(peer) device.routingTable.RemovePeer(peer)
delete(device.peers, key) delete(device.peers, key)
} }
@ -231,20 +284,30 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
func (device *Device) RemoveAllPeers() { func (device *Device) RemoveAllPeers() {
device.mutex.Lock() device.mutex.Lock()
defer device.mutex.Unlock() defer device.mutex.Unlock()
for key := range device.peers {
removePeerUnsafe(device, key) for key, peer := range device.peers {
peer.Stop()
peer, ok := device.peers[key]
if !ok {
return
}
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
} }
} }
func (device *Device) Close() { func (device *Device) Close() {
device.log.Info.Println("Device closing")
if device.isClosed.Swap(true) { if device.isClosed.Swap(true) {
return return
} }
device.log.Info.Println("Closing device")
device.RemoveAllPeers()
device.signal.stop.Broadcast() device.signal.stop.Broadcast()
device.tun.device.Close() device.tun.device.Close()
closeBind(device) device.BindClose()
device.isUp.Set(false)
println("remove")
device.RemoveAllPeers()
device.log.Info.Println("Interface closed")
} }
func (device *Device) Wait() chan struct{} { func (device *Device) Wait() chan struct{} {

View File

@ -4,6 +4,7 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"github.com/sasha-s/go-deadlock"
"sync" "sync"
"time" "time"
) )
@ -14,7 +15,8 @@ const (
type Peer struct { type Peer struct {
id uint id uint
mutex sync.RWMutex isRunning AtomicBool
mutex deadlock.RWMutex
persistentKeepaliveInterval uint64 persistentKeepaliveInterval uint64
keyPairs KeyPairs keyPairs KeyPairs
handshake Handshake handshake Handshake
@ -26,7 +28,7 @@ type Peer struct {
lastHandshakeNano int64 // nano seconds since epoch lastHandshakeNano int64 // nano seconds since epoch
} }
time struct { time struct {
mutex sync.RWMutex mutex deadlock.RWMutex
lastSend time.Time // last send message lastSend time.Time // last send message
lastHandshake time.Time // last completed handshake lastHandshake time.Time // last completed handshake
nextKeepalive time.Time nextKeepalive time.Time
@ -58,7 +60,7 @@ type Peer struct {
inbound chan *QueueInboundElement // sequential ordering of work inbound chan *QueueInboundElement // sequential ordering of work
} }
routines struct { routines struct {
mutex sync.Mutex // held when stopping / starting routines mutex deadlock.Mutex // held when stopping / starting routines
starting sync.WaitGroup // routines pending start starting sync.WaitGroup // routines pending start
stopping sync.WaitGroup // routines pending stop stopping sync.WaitGroup // routines pending stop
stop Signal // size 0, stop all goroutines in peer stop Signal // size 0, stop all goroutines in peer
@ -67,6 +69,14 @@ type Peer struct {
} }
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
if device.isClosed.Get() {
return nil, errors.New("Device closed")
}
device.mutex.Lock()
defer device.mutex.Unlock()
// create peer // create peer
peer := new(Peer) peer := new(Peer)
@ -75,17 +85,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.mac.Init(pk) peer.mac.Init(pk)
peer.device = device peer.device = device
peer.isRunning.Set(false)
peer.timer.zeroAllKeys = NewTimer()
peer.timer.keepalivePersistent = NewTimer() peer.timer.keepalivePersistent = NewTimer()
peer.timer.keepalivePassive = NewTimer() peer.timer.keepalivePassive = NewTimer()
peer.timer.zeroAllKeys = NewTimer()
peer.timer.handshakeNew = NewTimer() peer.timer.handshakeNew = NewTimer()
peer.timer.handshakeDeadline = NewTimer() peer.timer.handshakeDeadline = NewTimer()
peer.timer.handshakeTimeout = NewTimer() peer.timer.handshakeTimeout = NewTimer()
// assign id for debugging // assign id for debugging
device.mutex.Lock()
peer.id = device.idCounter peer.id = device.idCounter
device.idCounter += 1 device.idCounter += 1
@ -102,7 +112,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return nil, errors.New("Adding existing peer") return nil, errors.New("Adding existing peer")
} }
device.peers[pk] = peer device.peers[pk] = peer
device.mutex.Unlock()
// precompute DH // precompute DH
@ -117,23 +126,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.endpoint = nil peer.endpoint = nil
// prepare queuing
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signaling & routines // prepare signaling & routines
peer.signal.newKeyPair = NewSignal()
peer.signal.handshakeBegin = NewSignal()
peer.signal.handshakeCompleted = NewSignal()
peer.signal.flushNonceQueue = NewSignal()
peer.routines.mutex.Lock() peer.routines.mutex.Lock()
peer.routines.stop = NewSignal() peer.routines.stop = NewSignal()
peer.routines.mutex.Unlock() peer.routines.mutex.Unlock()
// start peer
peer.device.state.mutex.Lock()
if peer.device.isUp.Get() {
peer.Start()
}
peer.device.state.mutex.Unlock()
return peer, nil return peer, nil
} }
@ -148,6 +154,10 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
return errors.New("No known endpoint for peer") return errors.New("No known endpoint for peer")
} }
if peer.device.net.bind == nil {
return errors.New("No bind")
}
return peer.device.net.bind.Send(buffer, peer.endpoint) return peer.device.net.bind.Send(buffer, peer.endpoint)
} }
@ -174,12 +184,26 @@ func (peer *Peer) Start() {
peer.routines.mutex.Lock() peer.routines.mutex.Lock()
defer peer.routines.mutex.Lock() defer peer.routines.mutex.Lock()
peer.device.log.Debug.Println("Starting:", peer.String())
// stop & wait for ungoing routines (if any) // stop & wait for ungoing routines (if any)
peer.isRunning.Set(false)
peer.routines.stop.Broadcast() peer.routines.stop.Broadcast()
peer.routines.starting.Wait() peer.routines.starting.Wait()
peer.routines.stopping.Wait() peer.routines.stopping.Wait()
// prepare queues
peer.signal.newKeyPair = NewSignal()
peer.signal.handshakeBegin = NewSignal()
peer.signal.handshakeCompleted = NewSignal()
peer.signal.flushNonceQueue = NewSignal()
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
// reset signal and start (new) routines // reset signal and start (new) routines
peer.routines.stop = NewSignal() peer.routines.stop = NewSignal()
@ -192,6 +216,7 @@ func (peer *Peer) Start() {
go peer.RoutineSequentialReceiver() go peer.RoutineSequentialReceiver()
peer.routines.starting.Wait() peer.routines.starting.Wait()
peer.isRunning.Set(true)
} }
func (peer *Peer) Stop() { func (peer *Peer) Stop() {
@ -199,13 +224,22 @@ func (peer *Peer) Stop() {
peer.routines.mutex.Lock() peer.routines.mutex.Lock()
defer peer.routines.mutex.Lock() defer peer.routines.mutex.Lock()
peer.device.log.Debug.Println("Stopping:", peer.String())
// stop & wait for ungoing routines (if any) // stop & wait for ungoing routines (if any)
peer.routines.stop.Broadcast() peer.routines.stop.Broadcast()
peer.routines.starting.Wait() peer.routines.starting.Wait()
peer.routines.stopping.Wait() peer.routines.stopping.Wait()
// close queues
close(peer.queue.nonce)
close(peer.queue.outbound)
close(peer.queue.inbound)
// reset signal (to handle repeated stopping) // reset signal (to handle repeated stopping)
peer.routines.stop = NewSignal() peer.routines.stop = NewSignal()
peer.isRunning.Set(false)
} }

View File

@ -123,7 +123,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
case ipv6.Version: case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:]) size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default: default:
return panic("invalid IP version")
} }
if err != nil { if err != nil {
@ -184,9 +184,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
// add to decryption queues // add to decryption queues
device.addToDecryptionQueue(device.queue.decryption, elem) if peer.isRunning.Get() {
device.addToInboundQueue(peer.queue.inbound, elem) device.addToDecryptionQueue(device.queue.decryption, elem)
buffer = device.GetMessageBuffer() device.addToInboundQueue(peer.queue.inbound, elem)
buffer = device.GetMessageBuffer()
}
continue continue
@ -308,13 +310,20 @@ func (device *Device) RoutineHandshake() {
return return
} }
// lookup peer and consume response // lookup peer from index
entry := device.indices.Lookup(reply.Receiver) entry := device.indices.Lookup(reply.Receiver)
if entry.peer == nil { if entry.peer == nil {
continue continue
} }
entry.peer.mac.ConsumeReply(&reply)
// consume reply
if peer := entry.peer; peer.isRunning.Get() {
peer.mac.ConsumeReply(&reply)
}
continue continue
case MessageInitiationType, MessageResponseType: case MessageInitiationType, MessageResponseType:

View File

@ -170,9 +170,11 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue // insert into nonce/pre-handshake queue
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) if peer.isRunning.Get() {
addToOutboundQueue(peer.queue.nonce, elem) peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
elem = device.NewOutboundElement() addToOutboundQueue(peer.queue.nonce, elem)
elem = device.NewOutboundElement()
}
} }
} }

View File

@ -144,16 +144,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update port and rebind // update port and rebind
device.mutex.Lock()
device.net.mutex.Lock() device.net.mutex.Lock()
device.net.port = uint16(port) device.net.port = uint16(port)
err = unsafeUpdateBind(device)
device.net.mutex.Unlock() device.net.mutex.Unlock()
device.mutex.Unlock()
if err != nil { if err := device.BindUpdate(); err != nil {
logError.Println("Failed to set listen_port:", err) logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorPortInUse} return &IPCError{Code: ipcErrorPortInUse}
} }
@ -179,6 +174,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.net.fwmark = uint32(fwmark) device.net.fwmark = uint32(fwmark)
device.net.mutex.Unlock() device.net.mutex.Unlock()
if err := device.BindUpdate(); err != nil {
logError.Println("Failed to update fwmark:", err)
return &IPCError{Code: ipcErrorPortInUse}
}
case "public_key": case "public_key":
// switch to peer configuration // switch to peer configuration
deviceConfig = false deviceConfig = false