diff --git a/src/receive.go b/src/receive.go index 97646d8..09fca77 100644 --- a/src/receive.go +++ b/src/receive.go @@ -128,7 +128,7 @@ func (device *Device) RoutineReceiveIncomming() { // read next datagram - size, raddr, err := conn.ReadFromUDP(buffer[:]) // Blocks sometimes + size, raddr, err := conn.ReadFromUDP(buffer[:]) if err != nil { break @@ -222,7 +222,7 @@ func (device *Device) RoutineReceiveIncomming() { } func (device *Device) RoutineDecryption() { - var elem *QueueInboundElement + var nonce [chacha20poly1305.NonceSize]byte logDebug := device.log.Debug @@ -230,50 +230,51 @@ func (device *Device) RoutineDecryption() { for { select { - case elem = <-device.queue.decryption: case <-device.signal.stop: + logDebug.Println("Routine, decryption worker, stopped") return - } - // check if dropped + case elem := <-device.queue.decryption: - if elem.IsDropped() { - continue - } + // check if dropped - // split message into fields - - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] - content := elem.packet[MessageTransportOffsetContent:] - - // decrypt with key-pair - - var err error - copy(nonce[4:], counter) - elem.counter = binary.LittleEndian.Uint64(counter) - elem.keyPair.receive.mutex.RLock() - if elem.keyPair.receive.aead == nil { - // very unlikely (the key was deleted during queuing) - elem.Drop() - } else { - elem.packet, err = elem.keyPair.receive.aead.Open( - elem.buffer[:0], - nonce[:], - content, - nil, - ) - if err != nil { - elem.Drop() + if elem.IsDropped() { + continue } + + // split message into fields + + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] + content := elem.packet[MessageTransportOffsetContent:] + + // decrypt with key-pair + + copy(nonce[4:], counter) + elem.counter = binary.LittleEndian.Uint64(counter) + elem.keyPair.receive.mutex.RLock() + if elem.keyPair.receive.aead == nil { + // very unlikely (the key was deleted during queuing) + elem.Drop() + } else { + var err error + elem.packet, err = elem.keyPair.receive.aead.Open( + elem.buffer[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.Drop() + } + } + + elem.keyPair.receive.mutex.RUnlock() + elem.mutex.Unlock() } - elem.keyPair.receive.mutex.RUnlock() - elem.mutex.Unlock() } } /* Handles incomming packets related to handshake - * - * */ func (device *Device) RoutineHandshake() { @@ -473,7 +474,6 @@ func (device *Device) RoutineHandshake() { } func (peer *Peer) RoutineSequentialReceiver() { - var elem *QueueInboundElement device := peer.device @@ -483,118 +483,119 @@ func (peer *Peer) RoutineSequentialReceiver() { logDebug.Println("Routine, sequential receiver, started for peer", peer.id) for { - // wait for decryption select { case <-peer.signal.stop: + logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id) return - case elem = <-peer.queue.inbound: - } - elem.mutex.Lock() - // process packet + case elem := <-peer.queue.inbound: - if elem.IsDropped() { - continue - } + // wait for decryption - // check for replay - - if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { - continue - } - - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - peer.KeepKeyFreshReceiving() - - // check if using new key-pair - - kp := &peer.keyPairs - kp.mutex.Lock() - if kp.next == elem.keyPair { - peer.TimerHandshakeComplete() - if kp.previous != nil { - device.DeleteKeyPair(kp.previous) - } - kp.previous = kp.current - kp.current = kp.next - kp.next = nil - } - kp.mutex.Unlock() - - // check for keep-alive - - if len(elem.packet) == 0 { - logDebug.Println("Received keep-alive from", peer.String()) - continue - } - peer.TimerDataReceived() - - // verify source and strip padding - - switch elem.packet[0] >> 4 { - case ipv4.Version: - - // strip padding - - if len(elem.packet) < ipv4.HeaderLen { + elem.mutex.Lock() + if elem.IsDropped() { continue } - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + // check for replay + + if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { continue } - elem.packet = elem.packet[:length] + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() + peer.KeepKeyFreshReceiving() - // verify IPv4 source + // check if using new key-pair - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.routingTable.LookupIPv4(src) != peer { - logInfo.Println("Packet with unallowed source IP from", peer.String()) + kp := &peer.keyPairs + kp.mutex.Lock() + if kp.next == elem.keyPair { + peer.TimerHandshakeComplete() + if kp.previous != nil { + device.DeleteKeyPair(kp.previous) + } + kp.previous = kp.current + kp.current = kp.next + kp.next = nil + } + kp.mutex.Unlock() + + // check for keep-alive + + if len(elem.packet) == 0 { + logDebug.Println("Received keep-alive from", peer.String()) + continue + } + peer.TimerDataReceived() + + // verify source and strip padding + + switch elem.packet[0] >> 4 { + case ipv4.Version: + + // strip padding + + if len(elem.packet) < ipv4.HeaderLen { + continue + } + + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + + elem.packet = elem.packet[:length] + + // verify IPv4 source + + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.routingTable.LookupIPv4(src) != peer { + logInfo.Println("Packet with unallowed source IP from", peer.String()) + continue + } + + case ipv6.Version: + + // strip padding + + if len(elem.packet) < ipv6.HeaderLen { + continue + } + + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + + elem.packet = elem.packet[:length] + + // verify IPv6 source + + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.routingTable.LookupIPv6(src) != peer { + logInfo.Println("Packet with unallowed source IP from", peer.String()) + continue + } + + default: + logInfo.Println("Packet with invalid IP version from", peer.String()) continue } - case ipv6.Version: + // write to tun - // strip padding - - if len(elem.packet) < ipv6.HeaderLen { - continue + atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + _, err := device.tun.device.Write(elem.packet) + device.PutMessageBuffer(elem.buffer) + if err != nil { + logError.Println("Failed to write packet to TUN device:", err) } - - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - continue - } - - elem.packet = elem.packet[:length] - - // verify IPv6 source - - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.routingTable.LookupIPv6(src) != peer { - logInfo.Println("Packet with unallowed source IP from", peer.String()) - continue - } - - default: - logInfo.Println("Packet with invalid IP version from", peer.String()) - continue - } - - // write to tun - - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) - _, err := device.tun.device.Write(elem.packet) - device.PutMessageBuffer(elem.buffer) - if err != nil { - logError.Println("Failed to write packet to TUN device:", err) } } } diff --git a/src/send.go b/src/send.go index c598ad4..e9dfb54 100644 --- a/src/send.go +++ b/src/send.go @@ -35,7 +35,7 @@ type QueueOutboundElement struct { dropped int32 mutex sync.Mutex buffer *[MaxMessageSize]byte // slice holding the packet data - packet []byte // slice of "data" (always!) + packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption keyPair *KeyPair // key-pair for encryption peer *Peer // related peer @@ -52,11 +52,6 @@ func (peer *Peer) FlushNonceQueue() { } } -var ( - ErrorNoEndpoint = errors.New("No known endpoint for peer") - ErrorNoConnection = errors.New("No UDP socket for device") -) - func (device *Device) NewOutboundElement() *QueueOutboundElement { return &QueueOutboundElement{ dropped: AtomicFalse, @@ -118,14 +113,13 @@ func (peer *Peer) SendBuffer(buffer []byte) (int, error) { defer peer.mutex.RUnlock() endpoint := peer.endpoint - conn := peer.device.net.conn - if endpoint == nil { - return 0, ErrorNoEndpoint + return 0, errors.New("No known endpoint for peer") } + conn := peer.device.net.conn if conn == nil { - return 0, ErrorNoConnection + return 0, errors.New("No UDP socket for device") } return conn.WriteToUDP(buffer, endpoint) @@ -189,16 +183,6 @@ func (device *Device) RoutineReadFromTUN() { continue } - // check if known endpoint (drop early) - - peer.mutex.RLock() - if peer.endpoint == nil { - peer.mutex.RUnlock() - logDebug.Println("No known endpoint for peer", peer.String()) - continue - } - peer.mutex.RUnlock() - // insert into nonce/pre-handshake queue signalSend(peer.signal.handshakeReset) @@ -211,86 +195,61 @@ func (device *Device) RoutineReadFromTUN() { * Then assigns nonces to packets sequentially * and creates "work" structs for workers * - * TODO: Avoid dynamic allocation of work queue elements - * * Obs. A single instance per peer */ func (peer *Peer) RoutineNonce() { var keyPair *KeyPair - var elem *QueueOutboundElement device := peer.device logDebug := device.log.Debug logDebug.Println("Routine, nonce worker, started for peer", peer.String()) - func() { + for { + NextPacket: + select { + case <-peer.signal.stop: + return - for { - NextPacket: - - // wait for packet - - if elem == nil { - select { - case elem = <-peer.queue.nonce: - case <-peer.signal.stop: - return - } - } + case elem := <-peer.queue.nonce: // wait for key pair for { - select { - case <-peer.signal.newKeyPair: - default: - } - keyPair = peer.keyPairs.Current() if keyPair != nil && keyPair.sendNonce < RejectAfterMessages { if time.Now().Sub(keyPair.created) < RejectAfterTime { break } } + signalSend(peer.signal.handshakeBegin) logDebug.Println("Awaiting key-pair for", peer.String()) select { case <-peer.signal.newKeyPair: - logDebug.Println("Key-pair negotiated for", peer.String()) - goto NextPacket - case <-peer.signal.flushNonceQueue: logDebug.Println("Clearing queue for", peer.String()) peer.FlushNonceQueue() - elem = nil goto NextPacket - case <-peer.signal.stop: return } } - // process current packet + // populate work element - if elem != nil { + elem.peer = peer + elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 + elem.keyPair = keyPair + elem.dropped = AtomicFalse + elem.mutex.Lock() - // create work element + // add to parallel and sequential queue - elem.keyPair = keyPair - elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 - elem.dropped = AtomicFalse - elem.peer = peer - elem.mutex.Lock() - - // add to parallel and sequential queue - - addToEncryptionQueue(device.queue.encryption, elem) - addToOutboundQueue(peer.queue.outbound, elem) - elem = nil - } + addToEncryptionQueue(device.queue.encryption, elem) + addToOutboundQueue(peer.queue.outbound, elem) } - }() + } } /* Encrypts the elements in the queue @@ -300,7 +259,6 @@ func (peer *Peer) RoutineNonce() { */ func (device *Device) RoutineEncryption() { - var elem *QueueOutboundElement var nonce [chacha20poly1305.NonceSize]byte logDebug := device.log.Debug @@ -311,62 +269,62 @@ func (device *Device) RoutineEncryption() { // fetch next element select { - case elem = <-device.queue.encryption: case <-device.signal.stop: logDebug.Println("Routine, encryption worker, stopped") return - } - // check if dropped + case elem := <-device.queue.encryption: - if elem.IsDropped() { - continue - } + // check if dropped - // populate header fields - - header := elem.buffer[:MessageTransportHeaderSize] - - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] - - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - - // pad content to MTU size - - mtu := int(atomic.LoadInt32(&device.tun.mtu)) - pad := len(elem.packet) % PaddingMultiple - if pad > 0 { - for i := 0; i < PaddingMultiple-pad && len(elem.packet) < mtu; i++ { - elem.packet = append(elem.packet, 0) + if elem.IsDropped() { + continue } - // TODO: How good is this code + + // populate header fields + + header := elem.buffer[:MessageTransportHeaderSize] + + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] + + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + + // pad content to multiple of 16 + + mtu := int(atomic.LoadInt32(&device.tun.mtu)) + rem := len(elem.packet) % PaddingMultiple + if rem > 0 { + for i := 0; i < PaddingMultiple-rem && len(elem.packet) < mtu; i++ { + elem.packet = append(elem.packet, 0) + } + } + + // encrypt content (append to header) + + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.keyPair.send.mutex.RLock() + if elem.keyPair.send.aead == nil { + // very unlikely (the key was deleted during queuing) + elem.Drop() + } else { + elem.packet = elem.keyPair.send.aead.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + } + elem.mutex.Unlock() + elem.keyPair.send.mutex.RUnlock() + + // refresh key if necessary + + elem.peer.KeepKeyFreshSending() } - - // encrypt content (append to header) - - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.keyPair.send.mutex.RLock() - if elem.keyPair.send.aead == nil { - // very unlikely (the key was deleted during queuing) - elem.Drop() - } else { - elem.packet = elem.keyPair.send.aead.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - } - elem.keyPair.send.mutex.RUnlock() - elem.mutex.Unlock() - - // refresh key if necessary - - elem.peer.KeepKeyFreshSending() } } @@ -399,6 +357,7 @@ func (peer *Peer) RoutineSequentialSender() { _, err := peer.SendBuffer(elem.packet) device.PutMessageBuffer(elem.buffer) if err != nil { + logDebug.Println("Failed to send authenticated packet to peer", peer.String()) continue } atomic.AddUint64(&peer.stats.txBytes, length)