diff --git a/src/conn.go b/src/conn.go index 1d033ff..c2f5dee 100644 --- a/src/conn.go +++ b/src/conn.go @@ -64,9 +64,13 @@ func unsafeCloseBind(device *Device) error { return err } -/* Must hold device and net lock - */ -func unsafeUpdateBind(device *Device) error { +func (device *Device) BindUpdate() error { + device.mutex.Lock() + defer device.mutex.Unlock() + + netc := &device.net + netc.mutex.Lock() + defer netc.mutex.Unlock() // close existing sockets @@ -74,18 +78,13 @@ func unsafeUpdateBind(device *Device) error { return err } - // assumption: netc.update WaitGroup should be exactly 1 - // open new sockets if device.isUp.Get() { - device.log.Debug.Println("UDP bind updating") - // bind to new port var err error - netc := &device.net netc.bind, netc.port, err = CreateBind(netc.port) if err != nil { netc.bind = nil @@ -109,7 +108,7 @@ func unsafeUpdateBind(device *Device) error { peer.mutex.Unlock() } - // decrease waitgroup to 0 + // start receiving routines go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) @@ -120,7 +119,7 @@ func unsafeUpdateBind(device *Device) error { return nil } -func closeBind(device *Device) error { +func (device *Device) BindClose() error { device.mutex.Lock() device.net.mutex.Lock() err := unsafeCloseBind(device) diff --git a/src/device.go b/src/device.go index 5f8e91b..f1c09c6 100644 --- a/src/device.go +++ b/src/device.go @@ -9,7 +9,7 @@ import ( ) type Device struct { - isUp AtomicBool // device is up (TUN interface up)? + isUp AtomicBool // device is (going) up isClosed AtomicBool // device is closed? (acting as guard) log *Logger // collection of loggers for levels idCounter uint // for assigning debug ids to peers @@ -18,6 +18,11 @@ type Device struct { device TUNDevice mtu int32 } + state struct { + mutex deadlock.Mutex + changing AtomicBool + current bool + } pool struct { messageBuffers sync.Pool } @@ -46,37 +51,86 @@ type Device struct { mac CookieChecker } -func (device *Device) Up() { - device.mutex.Lock() - defer device.mutex.Unlock() +func deviceUpdateState(device *Device) { - device.net.mutex.Lock() - defer device.net.mutex.Unlock() + // check if state already being updated (guard) - if device.isUp.Swap(true) { + if device.state.changing.Swap(true) { return } - unsafeUpdateBind(device) + // compare to current state of device - for _, peer := range device.peers { - peer.Start() + device.state.mutex.Lock() + + newIsUp := device.isUp.Get() + + if newIsUp == device.state.current { + device.state.mutex.Unlock() + device.state.changing.Set(false) + return } + + device.state.mutex.Unlock() + + // change state of device + + switch newIsUp { + case true: + + // start listener + + if err := device.BindUpdate(); err != nil { + device.isUp.Set(false) + break + } + + // start every peer + + for _, peer := range device.peers { + peer.Start() + } + + case false: + + // stop listening + + device.BindClose() + + // stop every peer + + for _, peer := range device.peers { + peer.Stop() + } + } + + // update state variables + // and check for state change in the mean time + + device.state.current = newIsUp + device.state.changing.Set(false) + deviceUpdateState(device) +} + +func (device *Device) Up() { + + // closed device cannot be brought up + + if device.isClosed.Get() { + return + } + + device.state.mutex.Lock() + device.isUp.Set(true) + device.state.mutex.Unlock() + deviceUpdateState(device) } func (device *Device) Down() { - device.mutex.Lock() - defer device.mutex.Unlock() - - if !device.isUp.Swap(false) { - return - } - - closeBind(device) - - for _, peer := range device.peers { - peer.Stop() - } + device.state.mutex.Lock() + device.isUp.Set(false) + device.state.mutex.Unlock() + deviceUpdateState(device) } /* Warning: @@ -87,7 +141,6 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) { if !ok { return } - peer.Stop() device.routingTable.RemovePeer(peer) delete(device.peers, key) } @@ -231,20 +284,30 @@ func (device *Device) RemovePeer(key NoisePublicKey) { func (device *Device) RemoveAllPeers() { device.mutex.Lock() defer device.mutex.Unlock() - for key := range device.peers { - removePeerUnsafe(device, key) + + for key, peer := range device.peers { + peer.Stop() + peer, ok := device.peers[key] + if !ok { + return + } + device.routingTable.RemovePeer(peer) + delete(device.peers, key) } } func (device *Device) Close() { + device.log.Info.Println("Device closing") if device.isClosed.Swap(true) { return } - device.log.Info.Println("Closing device") - device.RemoveAllPeers() device.signal.stop.Broadcast() device.tun.device.Close() - closeBind(device) + device.BindClose() + device.isUp.Set(false) + println("remove") + device.RemoveAllPeers() + device.log.Info.Println("Interface closed") } func (device *Device) Wait() chan struct{} { diff --git a/src/peer.go b/src/peer.go index 3d82989..5ad4511 100644 --- a/src/peer.go +++ b/src/peer.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "errors" "fmt" + "github.com/sasha-s/go-deadlock" "sync" "time" ) @@ -14,7 +15,8 @@ const ( type Peer struct { id uint - mutex sync.RWMutex + isRunning AtomicBool + mutex deadlock.RWMutex persistentKeepaliveInterval uint64 keyPairs KeyPairs handshake Handshake @@ -26,7 +28,7 @@ type Peer struct { lastHandshakeNano int64 // nano seconds since epoch } time struct { - mutex sync.RWMutex + mutex deadlock.RWMutex lastSend time.Time // last send message lastHandshake time.Time // last completed handshake nextKeepalive time.Time @@ -58,7 +60,7 @@ type Peer struct { inbound chan *QueueInboundElement // sequential ordering of work } routines struct { - mutex sync.Mutex // held when stopping / starting routines + mutex deadlock.Mutex // held when stopping / starting routines starting sync.WaitGroup // routines pending start stopping sync.WaitGroup // routines pending stop stop Signal // size 0, stop all goroutines in peer @@ -67,6 +69,14 @@ type Peer struct { } func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { + + if device.isClosed.Get() { + return nil, errors.New("Device closed") + } + + device.mutex.Lock() + defer device.mutex.Unlock() + // create peer peer := new(Peer) @@ -75,17 +85,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.mac.Init(pk) peer.device = device + peer.isRunning.Set(false) + peer.timer.zeroAllKeys = NewTimer() peer.timer.keepalivePersistent = NewTimer() peer.timer.keepalivePassive = NewTimer() - peer.timer.zeroAllKeys = NewTimer() peer.timer.handshakeNew = NewTimer() peer.timer.handshakeDeadline = NewTimer() peer.timer.handshakeTimeout = NewTimer() // assign id for debugging - device.mutex.Lock() peer.id = device.idCounter device.idCounter += 1 @@ -102,7 +112,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return nil, errors.New("Adding existing peer") } device.peers[pk] = peer - device.mutex.Unlock() // precompute DH @@ -117,23 +126,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.endpoint = nil - // prepare queuing - - peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) - // prepare signaling & routines - peer.signal.newKeyPair = NewSignal() - peer.signal.handshakeBegin = NewSignal() - peer.signal.handshakeCompleted = NewSignal() - peer.signal.flushNonceQueue = NewSignal() - peer.routines.mutex.Lock() peer.routines.stop = NewSignal() peer.routines.mutex.Unlock() + // start peer + + peer.device.state.mutex.Lock() + if peer.device.isUp.Get() { + peer.Start() + } + peer.device.state.mutex.Unlock() + return peer, nil } @@ -148,6 +154,10 @@ func (peer *Peer) SendBuffer(buffer []byte) error { return errors.New("No known endpoint for peer") } + if peer.device.net.bind == nil { + return errors.New("No bind") + } + return peer.device.net.bind.Send(buffer, peer.endpoint) } @@ -174,12 +184,26 @@ func (peer *Peer) Start() { peer.routines.mutex.Lock() defer peer.routines.mutex.Lock() + peer.device.log.Debug.Println("Starting:", peer.String()) + // stop & wait for ungoing routines (if any) + peer.isRunning.Set(false) peer.routines.stop.Broadcast() peer.routines.starting.Wait() peer.routines.stopping.Wait() + // prepare queues + + peer.signal.newKeyPair = NewSignal() + peer.signal.handshakeBegin = NewSignal() + peer.signal.handshakeCompleted = NewSignal() + peer.signal.flushNonceQueue = NewSignal() + + peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + // reset signal and start (new) routines peer.routines.stop = NewSignal() @@ -192,6 +216,7 @@ func (peer *Peer) Start() { go peer.RoutineSequentialReceiver() peer.routines.starting.Wait() + peer.isRunning.Set(true) } func (peer *Peer) Stop() { @@ -199,13 +224,22 @@ func (peer *Peer) Stop() { peer.routines.mutex.Lock() defer peer.routines.mutex.Lock() + peer.device.log.Debug.Println("Stopping:", peer.String()) + // stop & wait for ungoing routines (if any) peer.routines.stop.Broadcast() peer.routines.starting.Wait() peer.routines.stopping.Wait() + // close queues + + close(peer.queue.nonce) + close(peer.queue.outbound) + close(peer.queue.inbound) + // reset signal (to handle repeated stopping) peer.routines.stop = NewSignal() + peer.isRunning.Set(false) } diff --git a/src/receive.go b/src/receive.go index 0b87a3c..5ad7c4b 100644 --- a/src/receive.go +++ b/src/receive.go @@ -123,7 +123,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { case ipv6.Version: size, endpoint, err = bind.ReceiveIPv6(buffer[:]) default: - return + panic("invalid IP version") } if err != nil { @@ -184,9 +184,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { // add to decryption queues - device.addToDecryptionQueue(device.queue.decryption, elem) - device.addToInboundQueue(peer.queue.inbound, elem) - buffer = device.GetMessageBuffer() + if peer.isRunning.Get() { + device.addToDecryptionQueue(device.queue.decryption, elem) + device.addToInboundQueue(peer.queue.inbound, elem) + buffer = device.GetMessageBuffer() + } continue @@ -308,13 +310,20 @@ func (device *Device) RoutineHandshake() { return } - // lookup peer and consume response + // lookup peer from index entry := device.indices.Lookup(reply.Receiver) + if entry.peer == nil { continue } - entry.peer.mac.ConsumeReply(&reply) + + // consume reply + + if peer := entry.peer; peer.isRunning.Get() { + peer.mac.ConsumeReply(&reply) + } + continue case MessageInitiationType, MessageResponseType: diff --git a/src/send.go b/src/send.go index fa13c91..e0a546d 100644 --- a/src/send.go +++ b/src/send.go @@ -170,9 +170,11 @@ func (device *Device) RoutineReadFromTUN() { // insert into nonce/pre-handshake queue - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - addToOutboundQueue(peer.queue.nonce, elem) - elem = device.NewOutboundElement() + if peer.isRunning.Get() { + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) + addToOutboundQueue(peer.queue.nonce, elem) + elem = device.NewOutboundElement() + } } } diff --git a/src/uapi.go b/src/uapi.go index f66528c..68ebe43 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -144,16 +144,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // update port and rebind - device.mutex.Lock() device.net.mutex.Lock() - device.net.port = uint16(port) - err = unsafeUpdateBind(device) - device.net.mutex.Unlock() - device.mutex.Unlock() - if err != nil { + if err := device.BindUpdate(); err != nil { logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorPortInUse} } @@ -179,6 +174,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { device.net.fwmark = uint32(fwmark) device.net.mutex.Unlock() + if err := device.BindUpdate(); err != nil { + logError.Println("Failed to update fwmark:", err) + return &IPCError{Code: ipcErrorPortInUse} + } + case "public_key": // switch to peer configuration deviceConfig = false