diff --git a/src/conn.go b/src/conn.go index ddb7ed1..6d292d3 100644 --- a/src/conn.go +++ b/src/conn.go @@ -82,7 +82,7 @@ func updateBind(device *Device) error { // open new sockets - if device.isUp.Get() { + if device.tun.isUp.Get() { device.log.Debug.Println("UDP bind updating") diff --git a/src/device.go b/src/device.go index f4a087c..a3461ad 100644 --- a/src/device.go +++ b/src/device.go @@ -8,13 +8,13 @@ import ( ) type Device struct { - isUp AtomicBool // device is up (TUN interface up)? - isClosed AtomicBool // device is closed? (acting as guard) + closed AtomicBool // device is closed? (acting as guard) log *Logger // collection of loggers for levels idCounter uint // for assigning debug ids to peers fwMark uint32 tun struct { device TUNDevice + isUp AtomicBool mtu int32 } pool struct { @@ -45,28 +45,6 @@ type Device struct { mac CookieChecker } -func (device *Device) Up() { - device.mutex.Lock() - defer device.mutex.Unlock() - - device.isUp.Set(true) - updateBind(device) - for _, peer := range device.peers { - peer.Start() - } -} - -func (device *Device) Down() { - device.mutex.Lock() - defer device.mutex.Unlock() - - device.isUp.Set(false) - closeBind(device) - for _, peer := range device.peers { - peer.Stop() - } -} - /* Warning: * The caller must hold the device mutex (write lock) */ @@ -76,9 +54,9 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) { return } peer.mutex.Lock() - peer.Stop() device.routingTable.RemovePeer(peer) delete(device.peers, key) + peer.Close() } func (device *Device) IsUnderLoad() bool { @@ -120,7 +98,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { device.publicKey = publicKey device.mac.Init(publicKey) - // do DH pre-computations + // do DH precomputations rmKey := device.privateKey.IsZero() @@ -154,12 +132,10 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { device.mutex.Lock() defer device.mutex.Unlock() - device.isUp.Set(false) - device.isClosed.Set(false) - device.log = logger device.peers = make(map[NoisePublicKey]*Peer) device.tun.device = tun + device.tun.isUp.Set(false) device.indices.Init() device.ratelimiter.Init() @@ -220,13 +196,17 @@ func (device *Device) RemovePeer(key NoisePublicKey) { func (device *Device) RemoveAllPeers() { device.mutex.Lock() defer device.mutex.Unlock() - for key := range device.peers { - removePeerUnsafe(device, key) + + for key, peer := range device.peers { + peer.mutex.Lock() + delete(device.peers, key) + peer.Close() + peer.mutex.Unlock() } } func (device *Device) Close() { - if device.isClosed.Swap(true) { + if device.closed.Swap(true) { return } device.log.Info.Println("Closing device") diff --git a/src/peer.go b/src/peer.go index 7c6ad47..f582556 100644 --- a/src/peer.go +++ b/src/peer.go @@ -34,15 +34,15 @@ type Peer struct { flushNonceQueue Signal // size 1, empty queued packets messageSend Signal // size 1, message was send to peer messageReceived Signal // size 1, authenticated message recv - stop Signal // size 0, stop all goroutines in peer + stop Signal // size 0, stop all goroutines } timer struct { // state related to WireGuard timers keepalivePersistent Timer // set for persistent keepalives keepalivePassive Timer // set upon recieving messages + newHandshake Timer // begin a new handshake (stale) zeroAllKeys Timer // zero all key material - handshakeNew Timer // begin a new handshake (stale) handshakeDeadline Timer // complete handshake timeout handshakeTimeout Timer // current handshake message timeout @@ -69,8 +69,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.timer.keepalivePersistent = NewTimer() peer.timer.keepalivePassive = NewTimer() + peer.timer.newHandshake = NewTimer() peer.timer.zeroAllKeys = NewTimer() - peer.timer.handshakeNew = NewTimer() peer.timer.handshakeDeadline = NewTimer() peer.timer.handshakeTimeout = NewTimer() @@ -116,29 +116,32 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // prepare signaling & routines + peer.signal.stop = NewSignal() peer.signal.newKeyPair = NewSignal() peer.signal.handshakeBegin = NewSignal() peer.signal.handshakeCompleted = NewSignal() peer.signal.flushNonceQueue = NewSignal() + go peer.RoutineNonce() + go peer.RoutineTimerHandler() + go peer.RoutineSequentialSender() + go peer.RoutineSequentialReceiver() + return peer, nil } func (peer *Peer) SendBuffer(buffer []byte) error { peer.device.net.mutex.RLock() defer peer.device.net.mutex.RUnlock() - peer.mutex.RLock() defer peer.mutex.RUnlock() - if peer.endpoint == nil { return errors.New("No known endpoint for peer") } - return peer.device.net.bind.Send(buffer, peer.endpoint) } -/* Returns a short string identifier for logging +/* Returns a short string identification for logging */ func (peer *Peer) String() string { if peer.endpoint == nil { @@ -156,32 +159,6 @@ func (peer *Peer) String() string { ) } -/* Starts all routines for a given peer - * - * Requires that the caller holds the exclusive peer lock! - */ -func unsafePeerStart(peer *Peer) { - peer.signal.stop.Broadcast() - peer.signal.stop = NewSignal() - - var wait sync.WaitGroup - - wait.Add(1) - - go peer.RoutineNonce() - go peer.RoutineTimerHandler(&wait) - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() - - wait.Wait() -} - -func (peer *Peer) Start() { - peer.mutex.Lock() - unsafePeerStart(peer) - peer.mutex.Unlock() -} - -func (peer *Peer) Stop() { +func (peer *Peer) Close() { peer.signal.stop.Broadcast() } diff --git a/src/timer.go b/src/timer.go index f00ca49..3def253 100644 --- a/src/timer.go +++ b/src/timer.go @@ -43,6 +43,12 @@ func (t *Timer) Reset(dur time.Duration) { t.Start(dur) } +func (t *Timer) Push(dur time.Duration) { + if t.pending.Get() { + t.Reset(dur) + } +} + func (t *Timer) Wait() <-chan time.Time { return t.timer.C } diff --git a/src/timers.go b/src/timers.go index f2fed30..ee47393 100644 --- a/src/timers.go +++ b/src/timers.go @@ -4,17 +4,10 @@ import ( "bytes" "encoding/binary" "math/rand" - "sync" "sync/atomic" "time" ) -/* NOTE: - * Notion of validity - * - * - */ - /* Called when a new authenticated message has been send * */ @@ -51,25 +44,25 @@ func (peer *Peer) KeepKeyFreshReceiving() { send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving if send { // do a last minute attempt at initiating a new handshake - peer.timer.sendLastMinuteHandshake = true peer.signal.handshakeBegin.Send() + peer.timer.sendLastMinuteHandshake = true } } /* Queues a keep-alive if no packets are queued for peer */ func (peer *Peer) SendKeepAlive() bool { - if len(peer.queue.nonce) != 0 { - return false - } elem := peer.device.NewOutboundElement() elem.packet = nil - select { - case peer.queue.nonce <- elem: - return true - default: - return false + if len(peer.queue.nonce) == 0 { + select { + case peer.queue.nonce <- elem: + return true + default: + return false + } } + return true } /* Event: @@ -77,7 +70,9 @@ func (peer *Peer) SendKeepAlive() bool { */ func (peer *Peer) TimerDataSent() { peer.timer.keepalivePassive.Stop() - peer.timer.handshakeNew.Start(NewHandshakeTime) + if peer.timer.newHandshake.Pending() { + peer.timer.newHandshake.Reset(NewHandshakeTime) + } } /* Event: @@ -96,7 +91,7 @@ func (peer *Peer) TimerDataReceived() { * Any (authenticated) packet received */ func (peer *Peer) TimerAnyAuthenticatedPacketReceived() { - peer.timer.handshakeNew.Stop() + peer.timer.newHandshake.Stop() } /* Event: @@ -120,6 +115,10 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { * - First transport message under the "next" key */ func (peer *Peer) TimerHandshakeComplete() { + atomic.StoreInt64( + &peer.stats.lastHandshakeNano, + time.Now().UnixNano(), + ) peer.signal.handshakeCompleted.Send() peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) } @@ -140,75 +139,13 @@ func (peer *Peer) TimerEphemeralKeyCreated() { peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) } -/* Sends a new handshake initiation message to the peer (endpoint) - */ -func (peer *Peer) sendNewHandshake() error { - - // temporarily disable the handshake complete signal - - peer.signal.handshakeCompleted.Disable() - - // create initiation message - - msg, err := peer.device.CreateMessageInitiation(peer) - if err != nil { - return err - } - - // marshal handshake message - - var buff [MessageInitiationSize]byte - writer := bytes.NewBuffer(buff[:0]) - binary.Write(writer, binary.LittleEndian, msg) - packet := writer.Bytes() - peer.mac.AddMacs(packet) - - // send to endpoint - - peer.TimerAnyAuthenticatedPacketTraversal() - - err = peer.SendBuffer(packet) - if err == nil { - peer.signal.handshakeCompleted.Enable() - } - - // set timeout - - jitter := time.Millisecond * time.Duration(rand.Uint32()%334) - - peer.timer.keepalivePassive.Stop() - peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter) - - return err -} - -func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { +func (peer *Peer) RoutineTimerHandler() { device := peer.device logInfo := device.log.Info logDebug := device.log.Debug logDebug.Println("Routine, timer handler, started for peer", peer.String()) - // reset all timers - - peer.timer.keepalivePassive.Stop() - peer.timer.handshakeDeadline.Stop() - peer.timer.handshakeTimeout.Stop() - peer.timer.handshakeNew.Stop() - peer.timer.zeroAllKeys.Stop() - - interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) - if interval > 0 { - duration := time.Duration(interval) * time.Second - peer.timer.keepalivePersistent.Reset(duration) - } - - // signal that timers are reset - - ready.Done() - - // handle timer events - for { select { @@ -221,7 +158,6 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) if interval > 0 { logDebug.Println("Sending keep-alive to", peer.String()) - peer.timer.keepalivePassive.Stop() peer.SendKeepAlive() } @@ -232,8 +168,8 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { peer.SendKeepAlive() if peer.timer.needAnotherKeepalive { - peer.timer.needAnotherKeepalive = false peer.timer.keepalivePassive.Reset(KeepaliveTimeout) + peer.timer.needAnotherKeepalive = false } // clear key material timer @@ -277,7 +213,7 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { // handshake timers - case <-peer.timer.handshakeNew.Wait(): + case <-peer.timer.newHandshake.Wait(): logInfo.Println("Retrying handshake with", peer.String()) peer.signal.handshakeBegin.Send() @@ -332,16 +268,48 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { logInfo.Println( "Handshake completed for:", peer.String()) - atomic.StoreInt64( - &peer.stats.lastHandshakeNano, - time.Now().UnixNano(), - ) - peer.timer.handshakeTimeout.Stop() peer.timer.handshakeDeadline.Stop() peer.signal.handshakeBegin.Enable() - - peer.timer.sendLastMinuteHandshake = false } } } + +/* Sends a new handshake initiation message to the peer (endpoint) + */ +func (peer *Peer) sendNewHandshake() error { + + // temporarily disable the handshake complete signal + + peer.signal.handshakeCompleted.Disable() + + // create initiation message + + msg, err := peer.device.CreateMessageInitiation(peer) + if err != nil { + return err + } + + // marshal handshake message + + var buff [MessageInitiationSize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, msg) + packet := writer.Bytes() + peer.mac.AddMacs(packet) + + // send to endpoint + + err = peer.SendBuffer(packet) + if err == nil { + peer.TimerAnyAuthenticatedPacketTraversal() + peer.signal.handshakeCompleted.Enable() + } + + // set timeout + + jitter := time.Millisecond * time.Duration(rand.Uint32()%334) + peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter) + + return err +} diff --git a/src/tun.go b/src/tun.go index 024f0f0..54253b4 100644 --- a/src/tun.go +++ b/src/tun.go @@ -46,13 +46,21 @@ func (device *Device) RoutineTUNEventReader() { } if event&TUNEventUp != 0 { - logInfo.Println("Interface set up") - device.Up() + if !device.tun.isUp.Get() { + // begin listening for incomming datagrams + logInfo.Println("Interface set up") + device.tun.isUp.Set(true) + updateBind(device) + } } if event&TUNEventDown != 0 { - logInfo.Println("Interface set down") - device.Up() + if device.tun.isUp.Get() { + // stop listening for incomming datagrams + logInfo.Println("Interface set down") + device.tun.isUp.Set(false) + closeBind(device) + } } } } diff --git a/src/uapi.go b/src/uapi.go index a67bff1..155f483 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -296,7 +296,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to get tun device status:", err) return &IPCError{Code: ipcErrorIO} } - if device.isUp.Get() && !dummy { + if device.tun.isUp.Get() && !dummy { peer.SendKeepAlive() } }