From 029410b118f079d77fa448cf56a97b949faee126 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 2 Feb 2018 16:40:14 +0100 Subject: [PATCH] Rework of entire locking system Locking on the Device instance is now much more fined-grained, seperating out the fields into "resources" st. most common interactions only require a small number. --- src/conn.go | 17 +-- src/device.go | 331 ++++++++++++++++++++++++++---------------- src/noise_helpers.go | 7 +- src/noise_protocol.go | 23 ++- src/peer.go | 63 ++++---- src/receive.go | 14 +- src/send.go | 8 +- src/timers.go | 4 +- src/tun_linux.go | 4 +- src/uapi.go | 154 +++++++++++++------- 10 files changed, 386 insertions(+), 239 deletions(-) diff --git a/src/conn.go b/src/conn.go index c2f5dee..fb30ec2 100644 --- a/src/conn.go +++ b/src/conn.go @@ -65,12 +65,12 @@ func unsafeCloseBind(device *Device) error { } func (device *Device) BindUpdate() error { - device.mutex.Lock() - defer device.mutex.Unlock() - netc := &device.net - netc.mutex.Lock() - defer netc.mutex.Unlock() + device.net.mutex.Lock() + defer device.net.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() // close existing sockets @@ -85,6 +85,7 @@ func (device *Device) BindUpdate() error { // bind to new port var err error + netc := &device.net netc.bind, netc.port, err = CreateBind(netc.port) if err != nil { netc.bind = nil @@ -100,12 +101,12 @@ func (device *Device) BindUpdate() error { // clear cached source addresses - for _, peer := range device.peers { + for _, peer := range device.peers.keyMap { peer.mutex.Lock() + defer peer.mutex.Unlock() if peer.endpoint != nil { peer.endpoint.ClearSrc() } - peer.mutex.Unlock() } // start receiving routines @@ -120,10 +121,8 @@ func (device *Device) BindUpdate() error { } func (device *Device) BindClose() error { - device.mutex.Lock() device.net.mutex.Lock() err := unsafeCloseBind(device) device.net.mutex.Unlock() - device.mutex.Unlock() return err } diff --git a/src/device.go b/src/device.go index f1c09c6..0317b60 100644 --- a/src/device.go +++ b/src/device.go @@ -9,46 +9,110 @@ import ( ) type Device struct { - isUp AtomicBool // device is (going) up - isClosed AtomicBool // device is closed? (acting as guard) - log *Logger // collection of loggers for levels - idCounter uint // for assigning debug ids to peers - fwMark uint32 - tun struct { - device TUNDevice - mtu int32 - } + isUp AtomicBool // device is (going) up + isClosed AtomicBool // device is closed? (acting as guard) + log *Logger + + // synchronized resources (locks acquired in order) + state struct { mutex deadlock.Mutex changing AtomicBool current bool } - pool struct { - messageBuffers sync.Pool - } + net struct { mutex deadlock.RWMutex bind Bind // bind interface port uint16 // listening port fwmark uint32 // mark value (0 = disabled) } - mutex deadlock.RWMutex - privateKey NoisePrivateKey - publicKey NoisePublicKey - routingTable RoutingTable - indices IndexTable - queue struct { + + noise struct { + mutex deadlock.RWMutex + privateKey NoisePrivateKey + publicKey NoisePublicKey + } + + routing struct { + mutex deadlock.RWMutex + table RoutingTable + } + + peers struct { + mutex deadlock.RWMutex + keyMap map[NoisePublicKey]*Peer + } + + // unprotected / "self-synchronising resources" + + indices IndexTable + mac CookieChecker + + rate struct { + underLoadUntil atomic.Value + limiter Ratelimiter + } + + pool struct { + messageBuffers sync.Pool + } + + queue struct { encryption chan *QueueOutboundElement decryption chan *QueueInboundElement handshake chan QueueHandshakeElement } + signal struct { stop Signal } - underLoadUntil atomic.Value - ratelimiter Ratelimiter - peers map[NoisePublicKey]*Peer - mac CookieChecker + + tun struct { + device TUNDevice + mtu int32 + } +} + +/* Converts the peer into a "zombie", which remains in the peer map, + * but processes no packets and does not exists in the routing table. + * + * Must hold: + * device.peers.mutex : exclusive lock + * device.routing : exclusive lock + */ +func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { + + // stop routing and processing of packets + + device.routing.table.RemovePeer(peer) + peer.Stop() + + // clean index table + + kp := &peer.keyPairs + kp.mutex.Lock() + + if kp.previous != nil { + device.indices.Delete(kp.previous.localIndex) + } + + if kp.current != nil { + device.indices.Delete(kp.current.localIndex) + } + + if kp.next != nil { + device.indices.Delete(kp.next.localIndex) + } + + kp.previous = nil + kp.current = nil + kp.next = nil + kp.mutex.Unlock() + + // remove from peer map + + delete(device.peers.keyMap, key) } func deviceUpdateState(device *Device) { @@ -59,56 +123,56 @@ func deviceUpdateState(device *Device) { return } - // compare to current state of device + func() { - device.state.mutex.Lock() + // compare to current state of device - newIsUp := device.isUp.Get() + device.state.mutex.Lock() + defer device.state.mutex.Unlock() - if newIsUp == device.state.current { - device.state.mutex.Unlock() + newIsUp := device.isUp.Get() + + if newIsUp == device.state.current { + device.state.changing.Set(false) + return + } + + // change state of device + + switch newIsUp { + case true: + if err := device.BindUpdate(); err != nil { + device.isUp.Set(false) + break + } + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for _, peer := range device.peers.keyMap { + peer.Start() + } + + case false: + device.BindClose() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for _, peer := range device.peers.keyMap { + println("stopping peer") + peer.Stop() + } + } + + // update state variables + + device.state.current = newIsUp device.state.changing.Set(false) - return - } + }() - device.state.mutex.Unlock() + // check for state change in the mean time - // change state of device - - switch newIsUp { - case true: - - // start listener - - if err := device.BindUpdate(); err != nil { - device.isUp.Set(false) - break - } - - // start every peer - - for _, peer := range device.peers { - peer.Start() - } - - case false: - - // stop listening - - device.BindClose() - - // stop every peer - - for _, peer := range device.peers { - peer.Stop() - } - } - - // update state variables - // and check for state change in the mean time - - device.state.current = newIsUp - device.state.changing.Set(false) deviceUpdateState(device) } @@ -133,18 +197,6 @@ func (device *Device) Down() { deviceUpdateState(device) } -/* Warning: - * The caller must hold the device mutex (write lock) - */ -func removePeerUnsafe(device *Device, key NoisePublicKey) { - peer, ok := device.peers[key] - if !ok { - return - } - device.routingTable.RemovePeer(peer) - delete(device.peers, key) -} - func (device *Device) IsUnderLoad() bool { // check if currently under load @@ -152,54 +204,66 @@ func (device *Device) IsUnderLoad() bool { now := time.Now() underLoad := len(device.queue.handshake) >= UnderLoadQueueSize if underLoad { - device.underLoadUntil.Store(now.Add(time.Second)) + device.rate.underLoadUntil.Store(now.Add(time.Second)) return true } // check if recently under load - until := device.underLoadUntil.Load().(time.Time) + until := device.rate.underLoadUntil.Load().(time.Time) return until.After(now) } func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { - device.mutex.Lock() - defer device.mutex.Unlock() + + // lock required resources + + device.noise.mutex.Lock() + defer device.noise.mutex.Unlock() + + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for _, peer := range device.peers.keyMap { + peer.handshake.mutex.RLock() + defer peer.handshake.mutex.RUnlock() + } // remove peers with matching public keys publicKey := sk.publicKey() - for key, peer := range device.peers { - h := &peer.handshake - h.mutex.RLock() - if h.remoteStatic.Equals(publicKey) { - removePeerUnsafe(device, key) + for key, peer := range device.peers.keyMap { + if peer.handshake.remoteStatic.Equals(publicKey) { + unsafeRemovePeer(device, peer, key) } - h.mutex.RUnlock() } // update key material - device.privateKey = sk - device.publicKey = publicKey + device.noise.privateKey = sk + device.noise.publicKey = publicKey device.mac.Init(publicKey) - // do DH pre-computations + // do static-static DH pre-computations - rmKey := device.privateKey.IsZero() + rmKey := device.noise.privateKey.IsZero() + + for key, peer := range device.peers.keyMap { + + hs := &peer.handshake - for key, peer := range device.peers { - h := &peer.handshake - h.mutex.Lock() if rmKey { - h.precomputedStaticStatic = [NoisePublicKeySize]byte{} + hs.precomputedStaticStatic = [NoisePublicKeySize]byte{} } else { - h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) - if isZero(h.precomputedStaticStatic[:]) { - removePeerUnsafe(device, key) - } + hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic) + } + + if isZero(hs.precomputedStaticStatic[:]) { + unsafeRemovePeer(device, peer, key) } - h.mutex.Unlock() } return nil @@ -215,21 +279,23 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { func NewDevice(tun TUNDevice, logger *Logger) *Device { device := new(Device) - device.mutex.Lock() - defer device.mutex.Unlock() device.isUp.Set(false) device.isClosed.Set(false) device.log = logger - device.peers = make(map[NoisePublicKey]*Peer) device.tun.device = tun + device.peers.keyMap = make(map[NoisePublicKey]*Peer) + + // initialize anti-DoS / anti-scanning features + + device.rate.limiter.Init() + device.rate.underLoadUntil.Store(time.Time{}) + + // initialize noise & crypt-key routine device.indices.Init() - device.ratelimiter.Init() - - device.routingTable.Reset() - device.underLoadUntil.Store(time.Time{}) + device.routing.table.Reset() // setup buffer pool @@ -264,36 +330,50 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() - go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) + go device.rate.limiter.RoutineGarbageCollector(device.signal.stop) return device } func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { - device.mutex.RLock() - defer device.mutex.RUnlock() - return device.peers[pk] + device.peers.mutex.RLock() + defer device.peers.mutex.RUnlock() + + return device.peers.keyMap[pk] } func (device *Device) RemovePeer(key NoisePublicKey) { - device.mutex.Lock() - defer device.mutex.Unlock() - removePeerUnsafe(device, key) + device.noise.mutex.Lock() + defer device.noise.mutex.Unlock() + + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + // stop peer and remove from routing + + peer, ok := device.peers.keyMap[key] + if ok { + unsafeRemovePeer(device, peer, key) + } } func (device *Device) RemoveAllPeers() { - device.mutex.Lock() - defer device.mutex.Unlock() - for key, peer := range device.peers { - peer.Stop() - peer, ok := device.peers[key] - if !ok { - return - } - device.routingTable.RemovePeer(peer) - delete(device.peers, key) + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for key, peer := range device.peers.keyMap { + println("rm", peer.String()) + unsafeRemovePeer(device, peer, key) } + + device.peers.keyMap = make(map[NoisePublicKey]*Peer) } func (device *Device) Close() { @@ -305,7 +385,6 @@ func (device *Device) Close() { device.tun.device.Close() device.BindClose() device.isUp.Set(false) - println("remove") device.RemoveAllPeers() device.log.Info.Println("Interface closed") } diff --git a/src/noise_helpers.go b/src/noise_helpers.go index 24302c0..1e2de5f 100644 --- a/src/noise_helpers.go +++ b/src/noise_helpers.go @@ -3,6 +3,7 @@ package main import ( "crypto/hmac" "crypto/rand" + "crypto/subtle" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/curve25519" "hash" @@ -58,11 +59,11 @@ func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { } func isZero(val []byte) bool { - var acc byte + acc := 1 for _, b := range val { - acc |= b + acc &= subtle.ConstantTimeByteEq(b, 0) } - return acc == 0 + return acc == 1 } func setZero(arr []byte) { diff --git a/src/noise_protocol.go b/src/noise_protocol.go index 2f9e1d5..d620a0d 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -137,6 +137,10 @@ func init() { } func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { + + device.noise.mutex.Lock() + defer device.noise.mutex.Unlock() + handshake := &peer.handshake handshake.mutex.Lock() defer handshake.mutex.Unlock() @@ -187,7 +191,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e ss[:], ) aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:]) }() handshake.mixHash(msg.Static[:]) @@ -212,16 +216,19 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e } func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { - if msg.Type != MessageInitiationType { - return nil - } - var ( hash [blake2s.Size]byte chainKey [blake2s.Size]byte ) - mixHash(&hash, &InitialHash, device.publicKey[:]) + if msg.Type != MessageInitiationType { + return nil + } + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + + mixHash(&hash, &InitialHash, device.noise.publicKey[:]) mixHash(&hash, &hash, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) @@ -231,7 +238,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { var peerPK NoisePublicKey func() { var key [chacha20poly1305.KeySize]byte - ss := device.privateKey.sharedSecret(msg.Ephemeral) + ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) KDF2(&chainKey, &key, chainKey[:], ss[:]) aead, _ := chacha20poly1305.New(key[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) @@ -407,7 +414,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { }() func() { - ss := device.privateKey.sharedSecret(msg.Ephemeral) + ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) mixKey(&chainKey, &chainKey, ss[:]) setZero(ss[:]) }() diff --git a/src/peer.go b/src/peer.go index 5ad4511..3b8f7cc 100644 --- a/src/peer.go +++ b/src/peer.go @@ -14,7 +14,6 @@ const ( ) type Peer struct { - id uint isRunning AtomicBool mutex deadlock.RWMutex persistentKeepaliveInterval uint64 @@ -22,17 +21,20 @@ type Peer struct { handshake Handshake device *Device endpoint Endpoint - stats struct { + + stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer lastHandshakeNano int64 // nano seconds since epoch } + time struct { mutex deadlock.RWMutex lastSend time.Time // last send message lastHandshake time.Time // last completed handshake nextKeepalive time.Time } + signal struct { newKeyPair Signal // size 1, new key pair was generated handshakeCompleted Signal // size 1, handshake completed @@ -41,7 +43,9 @@ type Peer struct { messageSend Signal // size 1, message was send to peer messageReceived Signal // size 1, authenticated message recv } + timer struct { + // state related to WireGuard timers keepalivePersistent Timer // set for persistent keepalives @@ -54,17 +58,20 @@ type Peer struct { sendLastMinuteHandshake bool needAnotherKeepalive bool } + queue struct { nonce chan *QueueOutboundElement // nonce / pre-handshake queue outbound chan *QueueOutboundElement // sequential ordering of work inbound chan *QueueInboundElement // sequential ordering of work } + routines struct { mutex deadlock.Mutex // held when stopping / starting routines starting sync.WaitGroup // routines pending start stopping sync.WaitGroup // routines pending stop stop Signal // size 0, stop all goroutines in peer } + mac CookieGenerator } @@ -74,8 +81,22 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return nil, errors.New("Device closed") } - device.mutex.Lock() - defer device.mutex.Unlock() + // lock resources + + device.state.mutex.Lock() + defer device.state.mutex.Unlock() + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + // check if over limit + + if len(device.peers.keyMap) >= MaxPeers { + return nil, errors.New("Too many peers") + } // create peer @@ -94,32 +115,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.timer.handshakeDeadline = NewTimer() peer.timer.handshakeTimeout = NewTimer() - // assign id for debugging - - peer.id = device.idCounter - device.idCounter += 1 - - // check if over limit - - if len(device.peers) >= MaxPeers { - return nil, errors.New("Too many peers") - } - // map public key - _, ok := device.peers[pk] + _, ok := device.peers.keyMap[pk] if ok { return nil, errors.New("Adding existing peer") } - device.peers[pk] = peer + device.peers.keyMap[pk] = peer // precompute DH handshake := &peer.handshake handshake.mutex.Lock() handshake.remoteStatic = pk - handshake.precomputedStaticStatic = - device.privateKey.sharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk) handshake.mutex.Unlock() // reset endpoint @@ -134,11 +143,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // start peer - peer.device.state.mutex.Lock() if peer.device.isUp.Get() { peer.Start() } - peer.device.state.mutex.Unlock() return peer, nil } @@ -166,14 +173,12 @@ func (peer *Peer) SendBuffer(buffer []byte) error { func (peer *Peer) String() string { if peer.endpoint == nil { return fmt.Sprintf( - "peer(%d unknown %s)", - peer.id, + "peer(unknown %s)", base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } return fmt.Sprintf( - "peer(%d %s %s)", - peer.id, + "peer(%s %s)", peer.endpoint.DstToString(), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) @@ -181,8 +186,12 @@ func (peer *Peer) String() string { func (peer *Peer) Start() { + if peer.device.isClosed.Get() { + return + } + peer.routines.mutex.Lock() - defer peer.routines.mutex.Lock() + defer peer.routines.mutex.Unlock() peer.device.log.Debug.Println("Starting:", peer.String()) @@ -222,7 +231,7 @@ func (peer *Peer) Start() { func (peer *Peer) Stop() { peer.routines.mutex.Lock() - defer peer.routines.mutex.Lock() + defer peer.routines.mutex.Unlock() peer.device.log.Debug.Println("Stopping:", peer.String()) diff --git a/src/receive.go b/src/receive.go index 5ad7c4b..1f44df2 100644 --- a/src/receive.go +++ b/src/receive.go @@ -372,7 +372,7 @@ func (device *Device) RoutineHandshake() { // check ratelimiter - if !device.ratelimiter.Allow(elem.endpoint.DstIP()) { + if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { continue } } @@ -495,19 +495,23 @@ func (device *Device) RoutineHandshake() { func (peer *Peer) RoutineSequentialReceiver() { + defer peer.routines.stopping.Done() + device := peer.device logInfo := device.log.Info logError := device.log.Error logDebug := device.log.Debug - logDebug.Println("Routine, sequential receiver, started for peer", peer.id) + logDebug.Println("Routine, sequential receiver, started for peer", peer.String()) + + peer.routines.starting.Done() for { select { case <-peer.routines.stop.Wait(): - logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id) + logDebug.Println("Routine, sequential receiver, stopped for peer", peer.String()) return case elem := <-peer.queue.inbound: @@ -581,7 +585,7 @@ func (peer *Peer) RoutineSequentialReceiver() { // verify IPv4 source src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.routingTable.LookupIPv4(src) != peer { + if device.routing.table.LookupIPv4(src) != peer { logInfo.Println( "IPv4 packet with disallowed source address from", peer.String(), @@ -609,7 +613,7 @@ func (peer *Peer) RoutineSequentialReceiver() { // verify IPv6 source src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.routingTable.LookupIPv6(src) != peer { + if device.routing.table.LookupIPv6(src) != peer { logInfo.Println( "IPv6 packet with disallowed source address from", peer.String(), diff --git a/src/send.go b/src/send.go index e0a546d..7488d3a 100644 --- a/src/send.go +++ b/src/send.go @@ -151,14 +151,14 @@ func (device *Device) RoutineReadFromTUN() { continue } dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.routingTable.LookupIPv4(dst) + peer = device.routing.table.LookupIPv4(dst) case ipv6.Version: if len(elem.packet) < ipv6.HeaderLen { continue } dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.routingTable.LookupIPv6(dst) + peer = device.routing.table.LookupIPv6(dst) default: logDebug.Println("Received packet with unknown IP version") @@ -187,10 +187,14 @@ func (device *Device) RoutineReadFromTUN() { func (peer *Peer) RoutineNonce() { var keyPair *KeyPair + defer peer.routines.stopping.Done() + device := peer.device logDebug := device.log.Debug logDebug.Println("Routine, nonce worker, started for peer", peer.String()) + peer.routines.starting.Done() + for { NextPacket: select { diff --git a/src/timers.go b/src/timers.go index f1ed9c5..2ef105e 100644 --- a/src/timers.go +++ b/src/timers.go @@ -303,7 +303,7 @@ func (peer *Peer) RoutineTimerHandler() { err := peer.sendNewHandshake() if err != nil { logInfo.Println( - "Failed to send handshake to peer:", peer.String()) + "Failed to send handshake to peer:", peer.String(), "(", err, ")") } case <-peer.timer.handshakeDeadline.Wait(): @@ -326,7 +326,7 @@ func (peer *Peer) RoutineTimerHandler() { err := peer.sendNewHandshake() if err != nil { logInfo.Println( - "Failed to send handshake to peer:", peer.String()) + "Failed to send handshake to peer:", peer.String(), "(", err, ")") } peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) diff --git a/src/tun_linux.go b/src/tun_linux.go index daa2462..9756169 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -313,7 +313,7 @@ func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) { } go device.RoutineNetlinkListener() - go device.RoutineHackListener() // cross namespace + // go device.RoutineHackListener() // cross namespace // set default MTU @@ -369,7 +369,7 @@ func CreateTUN(name string) (TUNDevice, error) { } go device.RoutineNetlinkListener() - go device.RoutineHackListener() // cross namespace + // go device.RoutineHackListener() // cross namespace // set default MTU diff --git a/src/uapi.go b/src/uapi.go index 68ebe43..caaa498 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -25,32 +25,51 @@ func (s *IPCError) ErrorCode() int64 { func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { - // create lines + device.log.Debug.Println("UAPI: Processing get operation") - device.mutex.RLock() - device.net.mutex.RLock() + // create lines lines := make([]string, 0, 100) send := func(line string) { lines = append(lines, line) } - if !device.privateKey.IsZero() { - send("private_key=" + device.privateKey.ToHex()) - } + func() { - if device.net.port != 0 { - send(fmt.Sprintf("listen_port=%d", device.net.port)) - } + // lock required resources - if device.net.fwmark != 0 { - send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) - } + device.net.mutex.RLock() + defer device.net.mutex.RUnlock() - for _, peer := range device.peers { - func() { + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + + device.routing.mutex.RLock() + defer device.routing.mutex.RUnlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + // serialize device related values + + if !device.noise.privateKey.IsZero() { + send("private_key=" + device.noise.privateKey.ToHex()) + } + + if device.net.port != 0 { + send(fmt.Sprintf("listen_port=%d", device.net.port)) + } + + if device.net.fwmark != 0 { + send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) + } + + // serialize each peer state + + for _, peer := range device.peers.keyMap { peer.mutex.RLock() defer peer.mutex.RUnlock() + send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex()) if peer.endpoint != nil { @@ -69,16 +88,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { atomic.LoadUint64(&peer.persistentKeepaliveInterval), )) - for _, ip := range device.routingTable.AllowedIPs(peer) { + for _, ip := range device.routing.table.AllowedIPs(peer) { send("allowed_ip=" + ip.String()) } - }() - } - device.net.mutex.RUnlock() - device.mutex.RUnlock() + } + }() - // send lines + // send lines (does not require resource locks) for _, line := range lines { _, err := socket.WriteString(line + "\n") @@ -94,7 +111,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { scanner := bufio.NewScanner(socket) - logInfo := device.log.Info logError := device.log.Error logDebug := device.log.Debug @@ -130,6 +146,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to set private_key:", err) return &IPCError{Code: ipcErrorInvalid} } + logDebug.Println("UAPI: Updating device private key") device.SetPrivateKey(sk) case "listen_port": @@ -144,6 +161,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // update port and rebind + logDebug.Println("UAPI: Updating listen port") + device.net.mutex.Lock() device.net.port = uint16(port) device.net.mutex.Unlock() @@ -170,6 +189,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorInvalid} } + logDebug.Println("UAPI: Updating fwmark") + device.net.mutex.Lock() device.net.fwmark = uint32(fwmark) device.net.mutex.Unlock() @@ -181,6 +202,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "public_key": // switch to peer configuration + logDebug.Println("UAPI: Transition to peer configuration") deviceConfig = false case "replace_peers": @@ -188,6 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to set replace_peers, invalid value:", value) return &IPCError{Code: ipcErrorInvalid} } + logDebug.Println("UAPI: Removing all peers") device.RemoveAllPeers() default: @@ -203,43 +226,41 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { switch key { case "public_key": - var pubKey NoisePublicKey - err := pubKey.FromHex(value) + var publicKey NoisePublicKey + err := publicKey.FromHex(value) if err != nil { logError.Println("Failed to get peer by public_key:", err) return &IPCError{Code: ipcErrorInvalid} } - // check if public key of peer equal to device + // ignore peer with public key of device - device.mutex.RLock() - if device.publicKey.Equals(pubKey) { - - // create dummy instance (not added to device) + device.noise.mutex.RLock() + equals := device.noise.publicKey.Equals(publicKey) + device.noise.mutex.RUnlock() + if equals { peer = &Peer{} dummy = true - device.mutex.RUnlock() - logInfo.Println("Ignoring peer with public key of device") - - } else { - - // find peer referenced - - peer, _ = device.peers[pubKey] - device.mutex.RUnlock() - if peer == nil { - peer, err = device.NewPeer(pubKey) - if err != nil { - logError.Println("Failed to create new peer:", err) - return &IPCError{Code: ipcErrorInvalid} - } - } - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - dummy = false - } + // find peer referenced + + peer = device.LookupPeer(publicKey) + + if peer == nil { + peer, err = device.NewPeer(publicKey) + if err != nil { + logError.Println("Failed to create new peer:", err) + return &IPCError{Code: ipcErrorInvalid} + } + logDebug.Println("UAPI: Created new peer:", peer.String()) + } + + peer.mutex.Lock() + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) + peer.mutex.Unlock() + case "remove": // remove currently selected peer from device @@ -249,7 +270,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorInvalid} } if !dummy { - logDebug.Println("Removing", peer.String()) + logDebug.Println("UAPI: Removing peer:", peer.String()) device.RemovePeer(peer.handshake.remoteStatic) } peer = &Peer{} @@ -259,9 +280,12 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // update PSK - peer.mutex.Lock() + logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String()) + + peer.handshake.mutex.Lock() err := peer.handshake.presharedKey.FromHex(value) - peer.mutex.Unlock() + peer.handshake.mutex.Unlock() + if err != nil { logError.Println("Failed to set preshared_key:", err) return &IPCError{Code: ipcErrorInvalid} @@ -271,6 +295,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // set endpoint destination + logDebug.Println("UAPI: Updating endpoint for peer:", peer.String()) + err := func() error { peer.mutex.Lock() defer peer.mutex.Unlock() @@ -292,6 +318,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // update keep-alive interval + logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String()) + secs, err := strconv.ParseUint(value, 10, 16) if err != nil { logError.Println("Failed to set persistent_keepalive_interval:", err) @@ -316,25 +344,41 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } case "replace_allowed_ips": + + logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String()) + if value != "true" { logError.Println("Failed to set replace_allowed_ips, invalid value:", value) return &IPCError{Code: ipcErrorInvalid} } - if !dummy { - device.routingTable.RemovePeer(peer) + + if dummy { + continue } + device.routing.mutex.Lock() + device.routing.table.RemovePeer(peer) + device.routing.mutex.Unlock() + case "allowed_ip": + + logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String()) + _, network, err := net.ParseCIDR(value) if err != nil { logError.Println("Failed to set allowed_ip:", err) return &IPCError{Code: ipcErrorInvalid} } - ones, _ := network.Mask.Size() - if !dummy { - device.routingTable.Insert(network.IP, uint(ones), peer) + + if dummy { + continue } + ones, _ := network.Mask.Size() + device.routing.mutex.Lock() + device.routing.table.Insert(network.IP, uint(ones), peer) + device.routing.mutex.Unlock() + default: logError.Println("Invalid UAPI key (peer configuration):", key) return &IPCError{Code: ipcErrorInvalid}