From b56af1829d0368c893f8e9e14894f9563afb60ef Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sun, 13 May 2018 23:14:43 +0200 Subject: [PATCH] More refactoring --- conn.go | 20 ++++++--- device.go | 59 ++++++++----------------- noise-protocol.go | 67 ++++++++++++++-------------- noise_test.go | 8 ++-- peer.go | 70 +++++++++++++++--------------- receive.go | 78 +++++++-------------------------- send.go | 108 ++++++++++++++++++++++++++++++++++++---------- timers.go | 26 +---------- uapi.go | 34 ++++++--------- 9 files changed, 219 insertions(+), 251 deletions(-) diff --git a/conn.go b/conn.go index 082bbca..4b347ec 100644 --- a/conn.go +++ b/conn.go @@ -74,9 +74,6 @@ func (device *Device) BindSetMark(mark uint32) error { device.net.mutex.Lock() defer device.net.mutex.Unlock() - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - // check if modified if device.net.fwmark == mark { @@ -92,6 +89,18 @@ func (device *Device) BindSetMark(mark uint32) error { } } + // clear cached source addresses + + device.peers.mutex.RLock() + for _, peer := range device.peers.keyMap { + peer.mutex.Lock() + defer peer.mutex.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.mutex.RUnlock() + return nil } @@ -100,9 +109,6 @@ func (device *Device) BindUpdate() error { device.net.mutex.Lock() defer device.net.mutex.Unlock() - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - // close existing sockets if err := unsafeCloseBind(device); err != nil { @@ -135,6 +141,7 @@ func (device *Device) BindUpdate() error { // clear cached source addresses + device.peers.mutex.RLock() for _, peer := range device.peers.keyMap { peer.mutex.Lock() defer peer.mutex.Unlock() @@ -142,6 +149,7 @@ func (device *Device) BindUpdate() error { peer.endpoint.ClearSrc() } } + device.peers.mutex.RUnlock() // start receiving routines diff --git a/device.go b/device.go index 34af419..cc12ac9 100644 --- a/device.go +++ b/device.go @@ -38,17 +38,12 @@ type Device struct { fwmark uint32 // mark value (0 = disabled) } - noise struct { + staticIdentity struct { mutex sync.RWMutex privateKey NoisePrivateKey publicKey NoisePublicKey } - routing struct { - mutex sync.RWMutex - table AllowedIPs - } - peers struct { mutex sync.RWMutex keyMap map[NoisePublicKey]*Peer @@ -56,8 +51,9 @@ type Device struct { // unprotected / "self-synchronising resources" - indexTable IndexTable - mac CookieChecker + allowedips AllowedIPs + indexTable IndexTable + cookieChecker CookieChecker rate struct { underLoadUntil atomic.Value @@ -87,15 +83,13 @@ type Device struct { /* Converts the peer into a "zombie", which remains in the peer map, * but processes no packets and does not exists in the routing table. * - * Must hold: - * device.peers.mutex : exclusive lock - * device.routing : exclusive lock + * Must hold device.peers.mutex. */ func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { // stop routing and processing of packets - device.routing.table.RemoveByPeer(peer) + device.allowedips.RemoveByPeer(peer) peer.Stop() // remove from peer map @@ -131,19 +125,19 @@ func deviceUpdateState(device *Device) { device.isUp.Set(false) break } - device.peers.mutex.Lock() + device.peers.mutex.RLock() for _, peer := range device.peers.keyMap { peer.Start() } - device.peers.mutex.Unlock() + device.peers.mutex.RUnlock() case false: device.BindClose() - device.peers.mutex.Lock() + device.peers.mutex.RLock() for _, peer := range device.peers.keyMap { peer.Stop() } - device.peers.mutex.Unlock() + device.peers.mutex.RUnlock() } // update state variables @@ -199,11 +193,8 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { // lock required resources - device.noise.mutex.Lock() - defer device.noise.mutex.Unlock() - - device.routing.mutex.Lock() - defer device.routing.mutex.Unlock() + device.staticIdentity.mutex.Lock() + defer device.staticIdentity.mutex.Unlock() device.peers.mutex.Lock() defer device.peers.mutex.Unlock() @@ -224,13 +215,13 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { // update key material - device.noise.privateKey = sk - device.noise.publicKey = publicKey - device.mac.Init(publicKey) + device.staticIdentity.privateKey = sk + device.staticIdentity.publicKey = publicKey + device.cookieChecker.Init(publicKey) // do static-static DH pre-computations - rmKey := device.noise.privateKey.IsZero() + rmKey := device.staticIdentity.privateKey.IsZero() for key, peer := range device.peers.keyMap { @@ -239,7 +230,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { if rmKey { hs.precomputedStaticStatic = [NoisePublicKeySize]byte{} } else { - hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic) + hs.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(hs.remoteStatic) } if isZero(hs.precomputedStaticStatic[:]) { @@ -281,10 +272,10 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { device.rate.limiter.Init() device.rate.underLoadUntil.Store(time.Time{}) - // initialize noise & crypt-key routine + // initialize staticIdentity & crypt-key routine device.indexTable.Init() - device.routing.table.Reset() + device.allowedips.Reset() // setup buffer pool @@ -333,12 +324,6 @@ func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { } func (device *Device) RemovePeer(key NoisePublicKey) { - device.noise.mutex.Lock() - defer device.noise.mutex.Unlock() - - device.routing.mutex.Lock() - defer device.routing.mutex.Unlock() - device.peers.mutex.Lock() defer device.peers.mutex.Unlock() @@ -351,12 +336,6 @@ func (device *Device) RemovePeer(key NoisePublicKey) { } func (device *Device) RemoveAllPeers() { - device.noise.mutex.Lock() - defer device.noise.mutex.Unlock() - - device.routing.mutex.Lock() - defer device.routing.mutex.Unlock() - device.peers.mutex.Lock() defer device.peers.mutex.Unlock() diff --git a/noise-protocol.go b/noise-protocol.go index f72dcc4..ffc2b50 100644 --- a/noise-protocol.go +++ b/noise-protocol.go @@ -107,6 +107,7 @@ type Handshake struct { precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret lastTimestamp tai64n.Timestamp lastInitiationConsumption time.Time + lastSentHandshake time.Time } var ( @@ -153,8 +154,8 @@ func init() { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() + device.staticIdentity.mutex.RLock() + defer device.staticIdentity.mutex.RUnlock() handshake := &peer.handshake handshake.mutex.Lock() @@ -206,7 +207,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e ss[:], ) aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) }() handshake.mixHash(msg.Static[:]) @@ -240,10 +241,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { return nil } - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() + device.staticIdentity.mutex.RLock() + defer device.staticIdentity.mutex.RUnlock() - mixHash(&hash, &InitialHash, device.noise.publicKey[:]) + mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) mixHash(&hash, &hash, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) @@ -253,7 +254,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { var peerPK NoisePublicKey func() { var key [chacha20poly1305.KeySize]byte - ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) KDF2(&chainKey, &key, chainKey[:], ss[:]) aead, _ := chacha20poly1305.New(key[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) @@ -422,8 +423,8 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // lock private key for reading - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() + device.staticIdentity.mutex.RLock() + defer device.staticIdentity.mutex.RUnlock() // finish 3-way DH @@ -437,7 +438,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { }() func() { - ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) mixKey(&chainKey, &chainKey, ss[:]) setZero(ss[:]) }() @@ -490,7 +491,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { /* Derives a new keypair from the current handshake state * */ -func (peer *Peer) DeriveNewKeypair() error { +func (peer *Peer) BeginSymmetricSession() error { device := peer.device handshake := &peer.handshake handshake.mutex.Lock() @@ -552,50 +553,48 @@ func (peer *Peer) DeriveNewKeypair() error { // rotate key pairs - kp := &peer.keypairs - kp.mutex.Lock() + keypairs := &peer.keypairs + keypairs.mutex.Lock() + defer keypairs.mutex.Unlock() - peer.timersSessionDerived() - - previous := kp.previous - next := kp.next - current := kp.current + previous := keypairs.previous + next := keypairs.next + current := keypairs.current if isInitiator { if next != nil { - kp.next = nil - kp.previous = next + keypairs.next = nil + keypairs.previous = next device.DeleteKeypair(current) } else { - kp.previous = current + keypairs.previous = current } device.DeleteKeypair(previous) - kp.current = keypair + keypairs.current = keypair } else { - kp.next = keypair + keypairs.next = keypair device.DeleteKeypair(next) - kp.previous = nil + keypairs.previous = nil device.DeleteKeypair(previous) } - kp.mutex.Unlock() return nil } func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { - kp := &peer.keypairs - if kp.next != receivedKeypair { + keypairs := &peer.keypairs + if keypairs.next != receivedKeypair { return false } - kp.mutex.Lock() - defer kp.mutex.Unlock() - if kp.next != receivedKeypair { + keypairs.mutex.Lock() + defer keypairs.mutex.Unlock() + if keypairs.next != receivedKeypair { return false } - old := kp.previous - kp.previous = kp.current + old := keypairs.previous + keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) - kp.current = kp.next - kp.next = nil + keypairs.current = keypairs.next + keypairs.next = nil return true } diff --git a/noise_test.go b/noise_test.go index ce32097..8e1bd89 100644 --- a/noise_test.go +++ b/noise_test.go @@ -36,8 +36,8 @@ func TestNoiseHandshake(t *testing.T) { defer dev1.Close() defer dev2.Close() - peer1, _ := dev2.NewPeer(dev1.noise.privateKey.publicKey()) - peer2, _ := dev1.NewPeer(dev2.noise.privateKey.publicKey()) + peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) + peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) assertEqual( t, @@ -102,8 +102,8 @@ func TestNoiseHandshake(t *testing.T) { t.Log("deriving keys") - key1 := peer1.DeriveNewKeypair() - key2 := peer2.DeriveNewKeypair() + key1 := peer1.BeginSymmetricSession() + key2 := peer2.BeginSymmetricSession() if key1 == nil { t.Fatal("failed to dervice keypair for peer 1") diff --git a/peer.go b/peer.go index d574c71..1151341 100644 --- a/peer.go +++ b/peer.go @@ -19,7 +19,7 @@ const ( type Peer struct { isRunning AtomicBool - mutex sync.RWMutex + mutex sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer keypairs Keypairs handshake Handshake device *Device @@ -42,7 +42,6 @@ type Peer struct { handshakeAttempts uint needAnotherKeepalive bool sentLastMinuteHandshake bool - lastSentHandshake time.Time } signals struct { @@ -64,7 +63,7 @@ type Peer struct { stop chan struct{} // size 0, stop all go routines in peer } - mac CookieGenerator + cookieGenerator CookieGenerator } func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { @@ -75,11 +74,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // lock resources - device.state.mutex.Lock() - defer device.state.mutex.Unlock() - - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() + device.staticIdentity.mutex.RLock() + defer device.staticIdentity.mutex.RUnlock() device.peers.mutex.Lock() defer device.peers.mutex.Unlock() @@ -96,7 +92,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.mutex.Lock() defer peer.mutex.Unlock() - peer.mac.Init(pk) + peer.cookieGenerator.Init(pk) peer.device = device peer.isRunning.Set(false) @@ -113,7 +109,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake := &peer.handshake handshake.mutex.Lock() handshake.remoteStatic = pk - handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk) + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) handshake.mutex.Unlock() // reset endpoint @@ -191,6 +187,7 @@ func (peer *Peer) Start() { peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) peer.timersInit() + peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) peer.signals.newKeypairArrived = make(chan struct{}, 1) peer.signals.flushNonceQueue = make(chan struct{}, 1) @@ -204,6 +201,32 @@ func (peer *Peer) Start() { peer.isRunning.Set(true) } +func (peer *Peer) ZeroAndFlushAll() { + device := peer.device + + // clear key pairs + + keypairs := &peer.keypairs + keypairs.mutex.Lock() + device.DeleteKeypair(keypairs.previous) + device.DeleteKeypair(keypairs.current) + device.DeleteKeypair(keypairs.next) + keypairs.previous = nil + keypairs.current = nil + keypairs.next = nil + keypairs.mutex.Unlock() + + // clear handshake state + + handshake := &peer.handshake + handshake.mutex.Lock() + device.indexTable.Delete(handshake.localIndex) + handshake.Clear() + handshake.mutex.Unlock() + + peer.FlushNonceQueue() +} + func (peer *Peer) Stop() { // prevent simultaneous start/stop operations @@ -215,8 +238,7 @@ func (peer *Peer) Stop() { return } - device := peer.device - device.log.Debug.Println(peer, ": Stopping...") + peer.device.log.Debug.Println(peer, ": Stopping...") peer.timersStop() @@ -232,27 +254,5 @@ func (peer *Peer) Stop() { close(peer.queue.outbound) close(peer.queue.inbound) - // clear key pairs - - kp := &peer.keypairs - kp.mutex.Lock() - - device.DeleteKeypair(kp.previous) - device.DeleteKeypair(kp.current) - device.DeleteKeypair(kp.next) - - kp.previous = nil - kp.current = nil - kp.next = nil - kp.mutex.Unlock() - - // clear handshake state - - hs := &peer.handshake - hs.mutex.Lock() - device.indexTable.Delete(hs.localIndex) - hs.Clear() - hs.mutex.Unlock() - - peer.FlushNonceQueue() + peer.ZeroAndFlushAll() } diff --git a/receive.go b/receive.go index 64253e6..77062fa 100644 --- a/receive.go +++ b/receive.go @@ -107,8 +107,8 @@ func (peer *Peer) keepKeyFreshReceiving() { if peer.timers.sentLastMinuteHandshake { return } - kp := peer.keypairs.Current() - if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { + keypair := peer.keypairs.Current() + if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { peer.timers.sentLastMinuteHandshake = true peer.SendHandshakeInitiation(false) } @@ -325,7 +325,6 @@ func (device *Device) RoutineHandshake() { logDebug.Println("Routine: handshake worker - started") - var temp [MessageHandshakeSize]byte var elem QueueHandshakeElement var ok bool @@ -367,52 +366,28 @@ func (device *Device) RoutineHandshake() { // consume reply if peer := entry.peer; peer.isRunning.Get() { - peer.mac.ConsumeReply(&reply) + peer.cookieGenerator.ConsumeReply(&reply) } continue case MessageInitiationType, MessageResponseType: - // check mac fields and ratelimit + // check mac fields and maybe ratelimit - if !device.mac.CheckMAC1(elem.packet) { + if !device.cookieChecker.CheckMAC1(elem.packet) { logDebug.Println("Received packet with invalid mac1") continue } // endpoints destination address is the source of the datagram - srcBytes := elem.endpoint.DstToBytes() - if device.IsUnderLoad() { // verify MAC2 field - if !device.mac.CheckMAC2(elem.packet, srcBytes) { - - // construct cookie reply - - logDebug.Println( - "Sending cookie reply to:", - elem.endpoint.DstToString(), - ) - - sender := binary.LittleEndian.Uint32(elem.packet[4:8]) - reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) - if err != nil { - logError.Println("Failed to create cookie reply:", err) - continue - } - - // marshal and send reply - - writer := bytes.NewBuffer(temp[:0]) - binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), elem.endpoint) - if err != nil { - logDebug.Println("Failed to send cookie reply:", err) - } + if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { + device.SendHandshakeCookie(&elem) continue } @@ -467,34 +442,7 @@ func (device *Device) RoutineHandshake() { logDebug.Println(peer, ": Received handshake initiation") - // create response - - response, err := device.CreateMessageResponse(peer) - if err != nil { - logError.Println("Failed to create response message:", err) - continue - } - - if peer.DeriveNewKeypair() != nil { - continue - } - - logDebug.Println(peer, ": Sending handshake response") - - writer := bytes.NewBuffer(temp[:0]) - binary.Write(writer, binary.LittleEndian, response) - packet := writer.Bytes() - peer.mac.AddMacs(packet) - - // send response - - peer.timers.lastSentHandshake = time.Now() - err = peer.SendBuffer(packet) - if err == nil { - peer.timersAnyAuthenticatedPacketTraversal() - } else { - logError.Println(peer, ": Failed to send handshake response", err) - } + peer.SendHandshakeResponse() case MessageResponseType: @@ -534,10 +482,14 @@ func (device *Device) RoutineHandshake() { // derive keypair - if peer.DeriveNewKeypair() != nil { + err = peer.BeginSymmetricSession() + + if err != nil { + logError.Println(peer, ": Failed to derive keypair:", err) continue } + peer.timersSessionDerived() peer.timersHandshakeComplete() peer.SendKeepalive() select { @@ -640,7 +592,7 @@ func (peer *Peer) RoutineSequentialReceiver() { // verify IPv4 source src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.routing.table.LookupIPv4(src) != peer { + if device.allowedips.LookupIPv4(src) != peer { logInfo.Println( "IPv4 packet with disallowed source address from", peer, @@ -668,7 +620,7 @@ func (peer *Peer) RoutineSequentialReceiver() { // verify IPv6 source src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.routing.table.LookupIPv6(src) != peer { + if device.allowedips.LookupIPv6(src) != peer { logInfo.Println( peer, "sent packet with disallowed IPv6 source", diff --git a/send.go b/send.go index a8ec28c..a670c4d 100644 --- a/send.go +++ b/send.go @@ -121,52 +121,114 @@ func (peer *Peer) SendKeepalive() bool { } } -/* Sends a new handshake initiation message to the peer (endpoint) - */ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { if !isRetry { peer.timers.handshakeAttempts = 0 } - if time.Now().Sub(peer.timers.lastSentHandshake) < RekeyTimeout { + peer.handshake.mutex.RLock() + if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout { + peer.handshake.mutex.RUnlock() return nil } - peer.timers.lastSentHandshake = time.Now() //TODO: locking for this variable? + peer.handshake.mutex.RUnlock() - // create initiation message - - msg, err := peer.device.CreateMessageInitiation(peer) - if err != nil { - return err + peer.handshake.mutex.Lock() + if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout { + peer.handshake.mutex.Unlock() + return nil } + peer.handshake.lastSentHandshake = time.Now() + peer.handshake.mutex.Unlock() peer.device.log.Debug.Println(peer, ": Sending handshake initiation") - // marshal handshake message + msg, err := peer.device.CreateMessageInitiation(peer) + if err != nil { + peer.device.log.Error.Println(peer, ": Failed to create initiation message:", err) + return err + } var buff [MessageInitiationSize]byte writer := bytes.NewBuffer(buff[:0]) binary.Write(writer, binary.LittleEndian, msg) packet := writer.Bytes() - peer.mac.AddMacs(packet) - - // send to endpoint + peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() + + err = peer.SendBuffer(packet) + if err != nil { + peer.device.log.Error.Println(peer, ": Failed to send handshake initiation", err) + } peer.timersHandshakeInitiated() - return peer.SendBuffer(packet) + + return err +} + +func (peer *Peer) SendHandshakeResponse() error { + peer.handshake.mutex.Lock() + peer.handshake.lastSentHandshake = time.Now() + peer.handshake.mutex.Unlock() + + peer.device.log.Debug.Println(peer, ": Sending handshake response") + + response, err := peer.device.CreateMessageResponse(peer) + if err != nil { + peer.device.log.Error.Println(peer, ": Failed to create response message:", err) + return err + } + + var buff [MessageResponseSize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, response) + packet := writer.Bytes() + peer.cookieGenerator.AddMacs(packet) + + err = peer.BeginSymmetricSession() + if err != nil { + peer.device.log.Error.Println(peer, ": Failed to derive keypair:", err) + return err + } + + peer.timersSessionDerived() + peer.timersAnyAuthenticatedPacketTraversal() + + err = peer.SendBuffer(packet) + if err != nil { + peer.device.log.Error.Println(peer, ": Failed to send handshake response", err) + } + return err +} + +func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { + + device.log.Debug.Println("Sending cookie reply to:", initiatingElem.endpoint.DstToString()) + + sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) + reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes()) + if err != nil { + device.log.Error.Println("Failed to create cookie reply:", err) + return err + } + + var buff [MessageCookieReplySize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, reply) + device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) + if err != nil { + device.log.Error.Println("Failed to send cookie reply:", err) + } + return err } -/* Called when a new authenticated message has been send - * - */ func (peer *Peer) keepKeyFreshSending() { - kp := peer.keypairs.Current() - if kp == nil { + keypair := peer.keypairs.Current() + if keypair == nil { return } - nonce := atomic.LoadUint64(&kp.sendNonce) - if nonce > RekeyAfterMessages || (kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime) { + nonce := atomic.LoadUint64(&keypair.sendNonce) + if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Now().Sub(keypair.created) > RekeyAfterTime) { peer.SendHandshakeInitiation(false) } } @@ -217,14 +279,14 @@ func (device *Device) RoutineReadFromTUN() { continue } dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.routing.table.LookupIPv4(dst) + peer = device.allowedips.LookupIPv4(dst) case ipv6.Version: if len(elem.packet) < ipv6.HeaderLen { continue } dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.routing.table.LookupIPv6(dst) + peer = device.allowedips.LookupIPv6(dst) default: logDebug.Println("Received packet with unknown IP version") diff --git a/timers.go b/timers.go index 9e633ee..e132376 100644 --- a/timers.go +++ b/timers.go @@ -104,30 +104,7 @@ func expiredNewHandshake(peer *Peer) { func expiredZeroKeyMaterial(peer *Peer) { peer.device.log.Debug.Printf(":%s Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) - - hs := &peer.handshake - hs.mutex.Lock() - - kp := &peer.keypairs - kp.mutex.Lock() - - if kp.previous != nil { - peer.device.DeleteKeypair(kp.previous) - kp.previous = nil - } - if kp.current != nil { - peer.device.DeleteKeypair(kp.current) - kp.current = nil - } - if kp.next != nil { - peer.device.DeleteKeypair(kp.next) - kp.next = nil - } - kp.mutex.Unlock() - - peer.device.indexTable.Delete(hs.localIndex) - hs.Clear() - hs.mutex.Unlock() + peer.ZeroAndFlushAll() } func expiredPersistentKeepalive(peer *Peer) { @@ -209,7 +186,6 @@ func (peer *Peer) timersInit() { peer.timers.handshakeAttempts = 0 peer.timers.sentLastMinuteHandshake = false peer.timers.needAnotherKeepalive = false - peer.timers.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) } func (peer *Peer) timersStop() { diff --git a/uapi.go b/uapi.go index 90c400a..53a598e 100644 --- a/uapi.go +++ b/uapi.go @@ -46,19 +46,16 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { device.net.mutex.RLock() defer device.net.mutex.RUnlock() - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() + device.staticIdentity.mutex.RLock() + defer device.staticIdentity.mutex.RUnlock() - device.routing.mutex.RLock() - defer device.routing.mutex.RUnlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() + device.peers.mutex.RLock() + defer device.peers.mutex.RUnlock() // serialize device related values - if !device.noise.privateKey.IsZero() { - send("private_key=" + device.noise.privateKey.ToHex()) + if !device.staticIdentity.privateKey.IsZero() { + send("private_key=" + device.staticIdentity.privateKey.ToHex()) } if device.net.port != 0 { @@ -91,7 +88,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes)) send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) - for _, ip := range device.routing.table.EntriesForPeer(peer) { + for _, ip := range device.allowedips.EntriesForPeer(peer) { send("allowed_ip=" + ip.String()) } @@ -234,13 +231,12 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // ignore peer with public key of device - device.noise.mutex.RLock() - equals := device.noise.publicKey.Equals(publicKey) - device.noise.mutex.RUnlock() + device.staticIdentity.mutex.RLock() + dummy = device.staticIdentity.publicKey.Equals(publicKey) + device.staticIdentity.mutex.RUnlock() - if equals { + if dummy { peer = &Peer{} - dummy = true } // find peer referenced @@ -348,9 +344,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { continue } - device.routing.mutex.Lock() - device.routing.table.RemoveByPeer(peer) - device.routing.mutex.Unlock() + device.allowedips.RemoveByPeer(peer) case "allowed_ip": @@ -367,9 +361,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } ones, _ := network.Mask.Size() - device.routing.mutex.Lock() - device.routing.table.Insert(network.IP, uint(ones), peer) - device.routing.mutex.Unlock() + device.allowedips.Insert(network.IP, uint(ones), peer) default: logError.Println("Invalid UAPI key (peer configuration):", key)