diff --git a/src/config.go b/src/config.go index e2d7f20..d952a3a 100644 --- a/src/config.go +++ b/src/config.go @@ -84,13 +84,47 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return nil } +func updateUDPConn(device *Device) error { + var err error + netc := &device.net + netc.mutex.Lock() + + // close existing connection + + if netc.conn != nil { + netc.conn.Close() + netc.conn = nil + } + + // open new existing connection + + conn, err := net.ListenUDP("udp", netc.addr) + if err == nil { + netc.conn = conn + signalSend(device.signal.newUDPConn) + } + + netc.mutex.Unlock() + return err +} + +func closeUDPConn(device *Device) { + device.net.mutex.Lock() + device.net.conn = nil + device.net.mutex.Unlock() + println("send signal") + signalSend(device.signal.newUDPConn) +} + func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { scanner := bufio.NewScanner(socket) + logInfo := device.log.Info logError := device.log.Error logDebug := device.log.Debug var peer *Peer + dummy := false deviceConfig := true for scanner.Scan() { @@ -135,17 +169,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { netc := &device.net netc.mutex.Lock() if netc.addr.Port != int(port) { - if netc.conn != nil { - netc.conn.Close() - } netc.addr.Port = int(port) - netc.conn, err = net.ListenUDP("udp", netc.addr) } netc.mutex.Unlock() - if err != nil { - logError.Println("Failed to create UDP listener:", err) - return &IPCError{Code: ipcErrorIO} - } + updateUDPConn(device) + // TODO: Clear source address of all peers case "fwmark": @@ -189,17 +217,30 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { device.mutex.RLock() if device.publicKey.Equals(pubKey) { + + // create dummy instance + + peer = &Peer{} + dummy = true device.mutex.RUnlock() - logError.Println("Public key of peer matches private key of device") - return &IPCError{Code: ipcErrorInvalid} - } + logInfo.Println("Ignoring peer with public key of device") - // find peer referenced + } else { + + // find peer referenced + + peer, _ = device.peers[pubKey] + device.mutex.RUnlock() + if peer == nil { + peer, err = device.NewPeer(pubKey) + if err != nil { + logError.Println("Failed to create new peer:", err) + return &IPCError{Code: ipcErrorInvalid} + } + } + signalSend(peer.signal.handshakeReset) + dummy = false - peer, _ = device.peers[pubKey] - device.mutex.RUnlock() - if peer == nil { - peer = device.NewPeer(pubKey) } case "remove": @@ -207,16 +248,17 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to set remove, invalid value:", value) return &IPCError{Code: ipcErrorInvalid} } - device.RemovePeer(peer.handshake.remoteStatic) - logDebug.Println("Removing", peer.String()) - peer = nil + if !dummy { + logDebug.Println("Removing", peer.String()) + device.RemovePeer(peer.handshake.remoteStatic) + } + peer = &Peer{} + dummy = true case "preshared_key": - err := func() error { - peer.mutex.Lock() - defer peer.mutex.Unlock() - return peer.handshake.presharedKey.FromHex(value) - }() + peer.mutex.Lock() + err := peer.handshake.presharedKey.FromHex(value) + peer.mutex.Unlock() if err != nil { logError.Println("Failed to set preshared_key:", err) return &IPCError{Code: ipcErrorInvalid} @@ -232,6 +274,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { peer.mutex.Lock() peer.endpoint = addr peer.mutex.Unlock() + signalSend(peer.signal.handshakeReset) case "persistent_keepalive_interval": @@ -251,12 +294,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // send immediate keep-alive if old == 0 && secs != 0 { - up, err := device.tun.IsUp() if err != nil { logError.Println("Failed to get tun device status:", err) return &IPCError{Code: ipcErrorIO} } - if up { + if atomic.LoadInt32(&device.isUp) == AtomicTrue && !dummy { peer.SendKeepAlive() } } @@ -266,7 +308,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to set replace_allowed_ips, invalid value:", value) return &IPCError{Code: ipcErrorInvalid} } - device.routingTable.RemovePeer(peer) + if !dummy { + device.routingTable.RemovePeer(peer) + } case "allowed_ip": _, network, err := net.ParseCIDR(value) @@ -275,7 +319,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorInvalid} } ones, _ := network.Mask.Size() - device.routingTable.Insert(network.IP, uint(ones), peer) + if !dummy { + device.routingTable.Insert(network.IP, uint(ones), peer) + } default: logError.Println("Invalid UAPI key (peer configuration):", key) diff --git a/src/constants.go b/src/constants.go index f09ded6..37603e8 100644 --- a/src/constants.go +++ b/src/constants.go @@ -7,16 +7,15 @@ import ( /* Specification constants */ const ( - RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 - RejectAfterMessages = (1 << 64) - (1 << 4) - 1 - RekeyAfterTime = time.Second * 120 - RekeyAttemptTime = time.Second * 90 - RekeyTimeout = time.Second * 5 - RejectAfterTime = time.Second * 180 - KeepaliveTimeout = time.Second * 10 - CookieRefreshTime = time.Second * 120 - MaxHandshakeAttemptTime = time.Second * 90 - PaddingMultiple = 16 + RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 + RejectAfterMessages = (1 << 64) - (1 << 4) - 1 + RekeyAfterTime = time.Second * 120 + RekeyAttemptTime = time.Second * 90 + RekeyTimeout = time.Second * 5 + RejectAfterTime = time.Second * 180 + KeepaliveTimeout = time.Second * 10 + CookieRefreshTime = time.Second * 120 + PaddingMultiple = 16 ) const ( @@ -33,4 +32,5 @@ const ( QueueHandshakeBusySize = QueueHandshakeSize / 8 MinMessageSize = MessageTransportSize // size of keep-alive MaxMessageSize = ((1 << 16) - 1) + MessageTransportHeaderSize + MaxPeers = 1 << 16 ) diff --git a/src/daemon_linux.go b/src/daemon_linux.go index 809c176..730f89e 100644 --- a/src/daemon_linux.go +++ b/src/daemon_linux.go @@ -7,6 +7,8 @@ import ( /* Daemonizes the process on linux * * This is done by spawning and releasing a copy with the --foreground flag + * + * TODO: Use env variable to spawn in background */ func Daemonize() error { diff --git a/src/device.go b/src/device.go index de96f0b..4aa90e3 100644 --- a/src/device.go +++ b/src/device.go @@ -1,13 +1,10 @@ package main import ( - "errors" - "fmt" "net" "runtime" "sync" "sync/atomic" - "time" ) type Device struct { @@ -34,31 +31,45 @@ type Device struct { queue struct { encryption chan *QueueOutboundElement decryption chan *QueueInboundElement - inbound chan *QueueInboundElement handshake chan QueueHandshakeElement } signal struct { - stop chan struct{} + stop chan struct{} // halts all go routines + newUDPConn chan struct{} // a net.conn was set } - underLoad int32 // used as an atomic bool + isUp int32 // atomic bool: interface is up + underLoad int32 // atomic bool: device is under load ratelimiter Ratelimiter peers map[NoisePublicKey]*Peer mac MACStateDevice } +/* Warning: + * The caller must hold the device mutex (write lock) + */ +func removePeerUnsafe(device *Device, key NoisePublicKey) { + peer, ok := device.peers[key] + if !ok { + return + } + peer.mutex.Lock() + device.routingTable.RemovePeer(peer) + delete(device.peers, key) + peer.Close() +} + func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { device.mutex.Lock() defer device.mutex.Unlock() - // check if public key is matching any peer + // remove peers with matching public keys publicKey := sk.publicKey() - for _, peer := range device.peers { + for key, peer := range device.peers { h := &peer.handshake h.mutex.RLock() if h.remoteStatic.Equals(publicKey) { - h.mutex.RUnlock() - return errors.New("Private key matches public key of peer") + removePeerUnsafe(device, key) } h.mutex.RUnlock() } @@ -71,17 +82,19 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { // do DH precomputations - isZero := device.privateKey.IsZero() + rmKey := device.privateKey.IsZero() - for _, peer := range device.peers { + for key, peer := range device.peers { h := &peer.handshake h.mutex.Lock() - if isZero { + if rmKey { h.precomputedStaticStatic = [NoisePublicKeySize]byte{} } else { h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) + if isZero(h.precomputedStaticStatic[:]) { + removePeerUnsafe(device, key) + } } - fmt.Println(h.precomputedStaticStatic) h.mutex.Unlock() } @@ -130,11 +143,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) - device.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) // prepare signals device.signal.stop = make(chan struct{}) + device.signal.newUDPConn = make(chan struct{}, 1) // start workers @@ -145,33 +158,42 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { } go device.RoutineBusyMonitor() - go device.RoutineMTUUpdater() - go device.RoutineWriteToTUN() go device.RoutineReadFromTUN() + go device.RoutineTUNEventReader() go device.RoutineReceiveIncomming() go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) return device } -func (device *Device) RoutineMTUUpdater() { +func (device *Device) RoutineTUNEventReader() { + events := device.tun.Events() logError := device.log.Error - for ; ; time.Sleep(5 * time.Second) { - // load updated MTU - - mtu, err := device.tun.MTU() - if err != nil { - logError.Println("Failed to load updated MTU of device:", err) - continue + for event := range events { + if event&TUNEventMTUUpdate != 0 { + mtu, err := device.tun.MTU() + if err != nil { + logError.Println("Failed to load updated MTU of device:", err) + } else { + if mtu+MessageTransportSize > MaxMessageSize { + mtu = MaxMessageSize - MessageTransportSize + } + atomic.StoreInt32(&device.mtu, int32(mtu)) + } } - // upper bound of mtu - - if mtu+MessageTransportSize > MaxMessageSize { - mtu = MaxMessageSize - MessageTransportSize + if event&TUNEventUp != 0 { + println("handle 1") + atomic.StoreInt32(&device.isUp, AtomicTrue) + updateUDPConn(device) + println("handle 2", device.net.conn) + } + + if event&TUNEventDown != 0 { + atomic.StoreInt32(&device.isUp, AtomicFalse) + closeUDPConn(device) } - atomic.StoreInt32(&device.mtu, int32(mtu)) } } @@ -184,15 +206,7 @@ func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { func (device *Device) RemovePeer(key NoisePublicKey) { device.mutex.Lock() defer device.mutex.Unlock() - - peer, ok := device.peers[key] - if !ok { - return - } - peer.mutex.Lock() - device.routingTable.RemovePeer(peer) - delete(device.peers, key) - peer.Close() + removePeerUnsafe(device, key) } func (device *Device) RemoveAllPeers() { diff --git a/src/macs.go b/src/macs.go index beb5f76..d55e18f 100644 --- a/src/macs.go +++ b/src/macs.go @@ -18,12 +18,13 @@ type MACStateDevice struct { } type MACStatePeer struct { - mutex sync.RWMutex - cookieSet time.Time - cookie [blake2s.Size128]byte - lastMAC1 [blake2s.Size128]byte // TODO: Check if set - keyMAC1 [blake2s.Size]byte - keyMAC2 [blake2s.Size]byte + mutex sync.RWMutex + cookieSet time.Time + cookie [blake2s.Size128]byte + lastMAC1Set bool + lastMAC1 [blake2s.Size128]byte + keyMAC1 [blake2s.Size]byte + keyMAC2 [blake2s.Size]byte } /* Methods for verifing MAC fields @@ -184,6 +185,10 @@ func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool { state.mutex.Lock() defer state.mutex.Unlock() + if !state.lastMAC1Set { + return false + } + _, err := XChaCha20Poly1305Decrypt( cookie[:0], &msg.Nonce, @@ -246,7 +251,7 @@ func (state *MACStatePeer) AddMacs(msg []byte) { mac.Sum(mac1[:0]) }() copy(state.lastMAC1[:], mac1) - // TODO: Set lastMac flag + state.lastMAC1Set = true // set mac2 diff --git a/src/peer.go b/src/peer.go index 9136959..02aac3b 100644 --- a/src/peer.go +++ b/src/peer.go @@ -9,16 +9,14 @@ import ( "time" ) -const () - type Peer struct { id uint mutex sync.RWMutex - endpoint *net.UDPAddr persistentKeepaliveInterval uint64 keyPairs KeyPairs handshake Handshake device *Device + endpoint *net.UDPAddr stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer @@ -34,6 +32,7 @@ type Peer struct { newKeyPair chan struct{} // (size 1) : a new key pair was generated handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake") handshakeCompleted chan struct{} // (size 1) : handshake completed + handshakeReset chan struct{} // (size 1) : reset handshake negotiation state flushNonceQueue chan struct{} // (size 1) : empty queued packets messageSend chan struct{} // (size 1) : a message was send to the peer messageReceived chan struct{} // (size 1) : an authenticated message was received @@ -44,6 +43,7 @@ type Peer struct { keepalivePassive *time.Timer // set upon recieving messages newHandshake *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout) zeroAllKeys *time.Timer // zero all key material (after RejectAfterTime*3) + handshakeDeadline *time.Timer // Current handshake must be completed pendingKeepalivePassive bool pendingNewHandshake bool @@ -59,7 +59,7 @@ type Peer struct { mac MACStatePeer } -func (device *Device) NewPeer(pk NoisePublicKey) *Peer { +func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // create peer peer := new(Peer) @@ -80,11 +80,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { peer.id = device.idCounter device.idCounter += 1 + // check if over limit + + if len(device.peers) >= MaxPeers { + return nil, errors.New("Too many peers") + } + // map public key _, ok := device.peers[pk] if ok { - panic(errors.New("bug: adding existing peer")) + return nil, errors.New("Adding existing peer") } device.peers[pk] = peer device.mutex.Unlock() @@ -108,6 +114,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { peer.signal.stop = make(chan struct{}) peer.signal.newKeyPair = make(chan struct{}, 1) peer.signal.handshakeBegin = make(chan struct{}, 1) + peer.signal.handshakeReset = make(chan struct{}, 1) peer.signal.handshakeCompleted = make(chan struct{}, 1) peer.signal.flushNonceQueue = make(chan struct{}, 1) @@ -117,7 +124,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { go peer.RoutineSequentialSender() go peer.RoutineSequentialReceiver() - return peer + return peer, nil } func (peer *Peer) String() string { diff --git a/src/receive.go b/src/receive.go index fb5c51f..5f46925 100644 --- a/src/receive.go +++ b/src/receive.go @@ -111,113 +111,84 @@ func (device *Device) RoutineBusyMonitor() { func (device *Device) RoutineReceiveIncomming() { - logInfo := device.log.Info logDebug := device.log.Debug logDebug.Println("Routine, receive incomming, started") - var buffer *[MaxMessageSize]byte - for { - // check if stopped + // wait for new conn + + var conn *net.UDPConn select { + case <-device.signal.newUDPConn: + device.net.mutex.RLock() + conn = device.net.conn + device.net.mutex.RUnlock() + case <-device.signal.stop: return - default: } - // read next datagram - - if buffer == nil { - buffer = device.GetMessageBuffer() - } - - // TODO: Take writelock to sleep - device.net.mutex.RLock() - conn := device.net.conn - device.net.mutex.RUnlock() if conn == nil { - time.Sleep(time.Second) continue } - // TODO: Wait for new conn or message - conn.SetReadDeadline(time.Now().Add(time.Second)) + // receive datagrams until closed - size, raddr, err := conn.ReadFromUDP(buffer[:]) - if err != nil || size < MinMessageSize { - continue - } + buffer := device.GetMessageBuffer() - // handle packet + for { - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) + // read next datagram + + size, raddr, err := conn.ReadFromUDP(buffer[:]) // TODO: This is broken + + if err != nil { + break + } + + if size < MinMessageSize { + continue + } + + // check size of packet + + packet := buffer[:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) + + var okay bool - func() { switch msgType { - case MessageInitiationType, MessageResponseType: - - // TODO: Check size early - - // add to handshake queue - - device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - source: raddr, - }, - ) - buffer = nil - - case MessageCookieReplyType: - - // TODO: Queue all the things - - // verify and update peer cookie state - - if len(packet) != MessageCookieReplySize { - return - } - - var reply MessageCookieReply - reader := bytes.NewReader(packet) - err := binary.Read(reader, binary.LittleEndian, &reply) - if err != nil { - logDebug.Println("Failed to decode cookie reply") - return - } - device.ConsumeMessageCookieReply(&reply) + // check if transport case MessageTransportType: - // lookup key pair + // check size - if len(packet) < MessageTransportSize { - return + if len(packet) < MessageTransportType { + continue } + // lookup key pair + receiver := binary.LittleEndian.Uint32( packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], ) value := device.indices.Lookup(receiver) keyPair := value.keyPair if keyPair == nil { - return + continue } // check key-pair expiry if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { - return + continue } - // add to peer queue + // create work element peer := value.peer elem := &QueueInboundElement{ @@ -233,11 +204,33 @@ func (device *Device) RoutineReceiveIncomming() { device.addToInboundQueue(device.queue.decryption, elem) device.addToInboundQueue(peer.queue.inbound, elem) buffer = nil + continue - default: - logInfo.Println("Got unknown message from:", raddr) + // otherwise it is a handshake related packet + + case MessageInitiationType: + okay = len(packet) == MessageInitiationSize + + case MessageResponseType: + okay = len(packet) == MessageResponseSize + + case MessageCookieReplyType: + okay = len(packet) == MessageCookieReplySize } - }() + + if okay { + device.addToHandshakeQueue( + device.queue.handshake, + QueueHandshakeElement{ + msgType: msgType, + buffer: buffer, + packet: packet, + source: raddr, + }, + ) + buffer = device.GetMessageBuffer() + } + } } } @@ -306,154 +299,165 @@ func (device *Device) RoutineHandshake() { return } - func() { + // handle cookie fields and ratelimiting - // verify mac1 + switch elem.msgType { + + case MessageCookieReplyType: + + // verify and update peer cookie state + + var reply MessageCookieReply + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &reply) + if err != nil { + logDebug.Println("Failed to decode cookie reply") + return + } + device.ConsumeMessageCookieReply(&reply) + continue + + case MessageInitiationType, MessageResponseType: + + // check mac fields and ratelimit if !device.mac.CheckMAC1(elem.packet) { logDebug.Println("Received packet with invalid mac1") return } - // verify mac2 - busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue - if busy && !device.mac.CheckMAC2(elem.packet, elem.source) { - sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" - reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source) - if err != nil { - logError.Println("Failed to create cookie reply:", err) - return - } - // TODO: Use temp - writer := bytes.NewBuffer(elem.packet[:0]) - binary.Write(writer, binary.LittleEndian, reply) - elem.packet = writer.Bytes() - _, err = device.net.conn.WriteToUDP(elem.packet, elem.source) - if err != nil { - logDebug.Println("Failed to send cookie reply:", err) - } - return - } - - // ratelimit - - // TODO: Only ratelimit when busy - - if !device.ratelimiter.Allow(elem.source.IP) { - return - } - - // handle messages - - switch elem.msgType { - case MessageInitiationType: - - // unmarshal - - if len(elem.packet) != MessageInitiationSize { - return - } - - var msg MessageInitiation - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) - if err != nil { - logError.Println("Failed to decode initiation message") - return - } - - // consume initiation - - peer := device.ConsumeMessageInitiation(&msg) - if peer == nil { - logInfo.Println( - "Recieved invalid initiation message from", - elem.source.IP.String(), - elem.source.Port, + if busy { + if !device.mac.CheckMAC2(elem.packet, elem.source) { + sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" + reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source) + if err != nil { + logError.Println("Failed to create cookie reply:", err) + return + } + writer := bytes.NewBuffer(temp[:0]) + binary.Write(writer, binary.LittleEndian, reply) + _, err = device.net.conn.WriteToUDP( + writer.Bytes(), + elem.source, ) - return + if err != nil { + logDebug.Println("Failed to send cookie reply:", err) + } + continue } - - // update timers - - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - - // update endpoint - // TODO: Add a race condition \s - - peer.mutex.Lock() - peer.endpoint = elem.source - peer.mutex.Unlock() - - // create response - - response, err := device.CreateMessageResponse(peer) - if err != nil { - logError.Println("Failed to create response message:", err) - return + if !device.ratelimiter.Allow(elem.source.IP) { + continue } - - peer.TimerEphemeralKeyCreated() - peer.NewKeyPair() - - logDebug.Println("Creating response message for", peer.String()) - - writer := bytes.NewBuffer(temp[:0]) - binary.Write(writer, binary.LittleEndian, response) - packet := writer.Bytes() - peer.mac.AddMacs(packet) - - // send response - - peer.SendBuffer(packet) - peer.TimerAnyAuthenticatedPacketTraversal() - - case MessageResponseType: - - // unmarshal - - if len(elem.packet) != MessageResponseSize { - return - } - - var msg MessageResponse - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) - if err != nil { - logError.Println("Failed to decode response message") - return - } - - // consume response - - peer := device.ConsumeMessageResponse(&msg) - if peer == nil { - logInfo.Println( - "Recieved invalid response message from", - elem.source.IP.String(), - elem.source.Port, - ) - return - } - - // update timers - - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - peer.TimerHandshakeComplete() - - // derive key-pair - - peer.NewKeyPair() - peer.SendKeepAlive() - - default: - logError.Println("Invalid message type in handshake queue") } - }() + + default: + logError.Println("Invalid packet ended up in the handshake queue") + continue + } + + // handle handshake initation/response content + + switch elem.msgType { + case MessageInitiationType: + + // unmarshal + + var msg MessageInitiation + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode initiation message") + continue + } + + // consume initiation + + peer := device.ConsumeMessageInitiation(&msg) + if peer == nil { + logInfo.Println( + "Recieved invalid initiation message from", + elem.source.IP.String(), + elem.source.Port, + ) + continue + } + + // update timers + + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() + + // update endpoint + // TODO: Discover destination address also, only update on change + + peer.mutex.Lock() + peer.endpoint = elem.source + peer.mutex.Unlock() + + // create response + + response, err := device.CreateMessageResponse(peer) + if err != nil { + logError.Println("Failed to create response message:", err) + continue + } + + peer.TimerEphemeralKeyCreated() + peer.NewKeyPair() + + logDebug.Println("Creating response message for", peer.String()) + + writer := bytes.NewBuffer(temp[:0]) + binary.Write(writer, binary.LittleEndian, response) + packet := writer.Bytes() + peer.mac.AddMacs(packet) + + // send response + + _, err = peer.SendBuffer(packet) + if err == nil { + peer.TimerAnyAuthenticatedPacketTraversal() + } + + case MessageResponseType: + + // unmarshal + + var msg MessageResponse + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode response message") + continue + } + + // consume response + + peer := device.ConsumeMessageResponse(&msg) + if peer == nil { + logInfo.Println( + "Recieved invalid response message from", + elem.source.IP.String(), + elem.source.Port, + ) + continue + } + + peer.TimerEphemeralKeyCreated() + + // update timers + + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() + peer.TimerHandshakeComplete() + + // derive key-pair + + peer.NewKeyPair() + peer.SendKeepAlive() + } } } @@ -463,6 +467,7 @@ func (peer *Peer) RoutineSequentialReceiver() { device := peer.device logInfo := device.log.Info + logError := device.log.Error logDebug := device.log.Debug logDebug.Println("Routine, sequential receiver, started for peer", peer.id) @@ -478,116 +483,104 @@ func (peer *Peer) RoutineSequentialReceiver() { // process packet - func() { - if elem.IsDropped() { - return + if elem.IsDropped() { + continue + } + + // check for replay + + if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { + continue + } + + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() + peer.KeepKeyFreshReceiving() + + // check if using new key-pair + + kp := &peer.keyPairs + kp.mutex.Lock() + if kp.next == elem.keyPair { + peer.TimerHandshakeComplete() + kp.previous = kp.current + kp.current = kp.next + kp.next = nil + } + kp.mutex.Unlock() + + // check for keep-alive + + if len(elem.packet) == 0 { + logDebug.Println("Received keep-alive from", peer.String()) + continue + } + peer.TimerDataReceived() + + // verify source and strip padding + + switch elem.packet[0] >> 4 { + case ipv4.Version: + + // strip padding + + if len(elem.packet) < ipv4.HeaderLen { + continue } - // check for replay - - if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { - return + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue } - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - peer.KeepKeyFreshReceiving() + elem.packet = elem.packet[:length] - // check if using new key-pair + // verify IPv4 source - kp := &peer.keyPairs - kp.mutex.Lock() - if kp.next == elem.keyPair { - peer.TimerHandshakeComplete() - kp.previous = kp.current - kp.current = kp.next - kp.next = nil - } - kp.mutex.Unlock() - - // check for keep-alive - - if len(elem.packet) == 0 { - logDebug.Println("Received keep-alive from", peer.String()) - return - } - peer.TimerDataReceived() - - // verify source and strip padding - - switch elem.packet[0] >> 4 { - case ipv4.Version: - - // strip padding - - if len(elem.packet) < ipv4.HeaderLen { - return - } - - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - // TODO: check length of packet & NOT TOO SMALL either - elem.packet = elem.packet[:length] - - // verify IPv4 source - - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.routingTable.LookupIPv4(src) != peer { - logInfo.Println("Packet with unallowed source IP from", peer.String()) - return - } - - case ipv6.Version: - - // strip padding - - if len(elem.packet) < ipv6.HeaderLen { - return - } - - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - // TODO: check length of packet - elem.packet = elem.packet[:length] - - // verify IPv6 source - - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.routingTable.LookupIPv6(src) != peer { - logInfo.Println("Packet with unallowed source IP from", peer.String()) - return - } - - default: - logInfo.Println("Packet with invalid IP version from", peer.String()) - return + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.routingTable.LookupIPv4(src) != peer { + logInfo.Println("Packet with unallowed source IP from", peer.String()) + continue } - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) - device.addToInboundQueue(device.queue.inbound, elem) + case ipv6.Version: - // TODO: move TUN write into per peer routine - }() - } -} + // strip padding -func (device *Device) RoutineWriteToTUN() { - - logError := device.log.Error - logDebug := device.log.Debug - logDebug.Println("Routine, sequential tun writer, started") - - for { - select { - case <-device.signal.stop: - return - case elem := <-device.queue.inbound: - _, err := device.tun.Write(elem.packet) - device.PutMessageBuffer(elem.buffer) - if err != nil { - logError.Println("Failed to write packet to TUN device:", err) + if len(elem.packet) < ipv6.HeaderLen { + continue } + + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + + elem.packet = elem.packet[:length] + + // verify IPv6 source + + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.routingTable.LookupIPv6(src) != peer { + logInfo.Println("Packet with unallowed source IP from", peer.String()) + continue + } + + default: + logInfo.Println("Packet with invalid IP version from", peer.String()) + continue + } + + // write to tun + + atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + _, err := device.tun.Write(elem.packet) + device.PutMessageBuffer(elem.buffer) + if err != nil { + logError.Println("Failed to write packet to TUN device:", err) } } } diff --git a/src/send.go b/src/send.go index fc35732..cf1f018 100644 --- a/src/send.go +++ b/src/send.go @@ -168,8 +168,6 @@ func (device *Device) RoutineReadFromTUN() { continue } - println(size, err) - elem.packet = elem.packet[:size] // lookup peer @@ -210,6 +208,7 @@ func (device *Device) RoutineReadFromTUN() { // insert into nonce/pre-handshake queue + signalSend(peer.signal.handshakeReset) addToOutboundQueue(peer.queue.nonce, elem) elem = nil diff --git a/src/timers.go b/src/timers.go index 1be85f0..ab2e7ad 100644 --- a/src/timers.go +++ b/src/timers.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "golang.org/x/crypto/blake2s" + "math/rand" "sync/atomic" "time" ) @@ -16,12 +17,11 @@ func (peer *Peer) KeepKeyFreshSending() { if kp == nil { return } - if !kp.isInitiator { - return - } nonce := atomic.LoadUint64(&kp.sendNonce) - send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime - if send { + if nonce > RekeyAfterMessages { + signalSend(peer.signal.handshakeBegin) + } + if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime { signalSend(peer.signal.handshakeBegin) } } @@ -30,6 +30,7 @@ func (peer *Peer) KeepKeyFreshSending() { * */ func (peer *Peer) KeepKeyFreshReceiving() { + // TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete) kp := peer.keyPairs.Current() if kp == nil { return @@ -108,7 +109,6 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { * - First transport message under the "next" key */ func (peer *Peer) TimerHandshakeComplete() { - timerStop(peer.timer.zeroAllKeys) atomic.StoreInt64( &peer.stats.lastHandshakeNano, time.Now().UnixNano(), @@ -129,10 +129,7 @@ func (peer *Peer) TimerHandshakeComplete() { * upon failure to complete a handshake */ func (peer *Peer) TimerEphemeralKeyCreated() { - if !peer.timer.pendingZeroAllKeys { - peer.timer.pendingZeroAllKeys = true - peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) - } + peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) } func (peer *Peer) RoutineTimerHandler() { @@ -154,19 +151,19 @@ func (peer *Peer) RoutineTimerHandler() { interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) if interval > 0 { - logDebug.Println("Sending persistent keep-alive to", peer.String()) + logDebug.Println("Sending keep-alive to", peer.String()) peer.SendKeepAlive() } case <-peer.timer.keepalivePassive.C: - logDebug.Println("Sending passive keep-alive to", peer.String()) + logDebug.Println("Sending keep-alive to", peer.String()) peer.SendKeepAlive() if peer.timer.needAnotherKeepalive { peer.timer.keepalivePassive.Reset(KeepaliveTimeout) - peer.timer.needAnotherKeepalive = true + peer.timer.needAnotherKeepalive = false } // unresponsive session @@ -189,8 +186,6 @@ func (peer *Peer) RoutineTimerHandler() { kp := &peer.keyPairs kp.mutex.Lock() - peer.timer.pendingZeroAllKeys = false - // unmap indecies indices.mutex.Lock() @@ -251,40 +246,41 @@ func (peer *Peer) RoutineHandshakeInitiator() { return } - // wait for handshake + // set deadline - deadline := time.Now().Add(MaxHandshakeAttemptTime) + BeginHandshakes: + + signalClear(peer.signal.handshakeReset) + deadline := time.NewTimer(RekeyAttemptTime) + + AttemptHandshakes: - Loop: for attempts := uint(1); ; attempts++ { - // clear completed signal + // check if deadline reached select { - case <-peer.signal.handshakeCompleted: + case <-deadline.C: + logInfo.Println("Handshake negotiation timed out for:", peer.String()) + signalSend(peer.signal.flushNonceQueue) + timerStop(peer.timer.keepalivePersistent) + break case <-peer.signal.stop: return default: } - // check if sufficient time for retry - - if deadline.Before(time.Now().Add(RekeyTimeout)) { - logInfo.Println("Handshake negotiation timed out for", peer.String()) - signalSend(peer.signal.flushNonceQueue) - timerStop(peer.timer.keepalivePersistent) - timerStop(peer.timer.keepalivePassive) - break Loop - } + signalClear(peer.signal.handshakeCompleted) // create initiation message msg, err := peer.device.CreateMessageInitiation(peer) if err != nil { logError.Println("Failed to create handshake initiation message:", err) - break Loop + break AttemptHandshakes } - peer.TimerEphemeralKeyCreated() + + jitter := time.Millisecond * time.Duration(rand.Uint32()%334) // marshal and send @@ -299,14 +295,14 @@ func (peer *Peer) RoutineHandshakeInitiator() { "Failed to send handshake initiation message to", peer.String(), ":", err, ) - continue + break } peer.TimerAnyAuthenticatedPacketTraversal() - // set timeout + // set handshake timeout - timeout := time.NewTimer(RekeyTimeout) + timeout := time.NewTimer(RekeyTimeout + jitter) logDebug.Println( "Handshake initiation attempt", attempts, "sent to", peer.String(), @@ -321,15 +317,19 @@ func (peer *Peer) RoutineHandshakeInitiator() { case <-peer.signal.handshakeCompleted: <-timeout.C - break Loop + break AttemptHandshakes + + case <-peer.signal.handshakeReset: + <-timeout.C + goto BeginHandshakes case <-timeout.C: + // TODO: Clear source address for peer continue - } } - // allow new signal to be set + // clear signal set in the meantime signalClear(peer.signal.handshakeBegin) } diff --git a/src/tun.go b/src/tun.go index d782bd5..1c4c281 100644 --- a/src/tun.go +++ b/src/tun.go @@ -6,10 +6,19 @@ package main const DefaultMTU = 1420 +type TUNEvent int + +const ( + TUNEventUp = 1 << iota + TUNEventDown + TUNEventMTUUpdate +) + type TUNDevice interface { Read([]byte) (int, error) // read a packet from the device (without any additional headers) Write([]byte) (int, error) // writes a packet to the device (without any additional headers) - IsUp() (bool, error) // is the interface up? MTU() (int, error) // returns the MTU of the device Name() string // returns the current name + Events() chan TUNEvent // returns a constant channel of events related to the device + Close() error // stops the device and closes the event channel } diff --git a/src/tun_linux.go b/src/tun_linux.go index d0e2f47..34f746a 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -16,11 +16,12 @@ import ( const CloneDevicePath = "/dev/net/tun" type NativeTun struct { - fd *os.File - name string + fd *os.File + name string + events chan TUNEvent } -func (tun *NativeTun) IsUp() (bool, error) { +func (tun *NativeTun) isUp() (bool, error) { inter, err := net.InterfaceByName(tun.name) return inter.Flags&net.FlagUp != 0, err } @@ -111,6 +112,14 @@ func (tun *NativeTun) Read(d []byte) (int, error) { return tun.fd.Read(d) } +func (tun *NativeTun) Events() chan TUNEvent { + return tun.events +} + +func (tun *NativeTun) Close() error { + return nil +} + func CreateTUN(name string) (TUNDevice, error) { // open clone device @@ -146,10 +155,14 @@ func CreateTUN(name string) (TUNDevice, error) { newName := string(ifr[:]) newName = newName[:strings.Index(newName, "\000")] device := &NativeTun{ - fd: fd, - name: newName, + fd: fd, + name: newName, + events: make(chan TUNEvent, 5), } + // TODO: Wait for device to be upped + device.events <- TUNEventUp + // set default MTU err = device.setMTU(DefaultMTU) diff --git a/src/uapi_linux.go b/src/uapi_linux.go index d6d78e7..fd56b5a 100644 --- a/src/uapi_linux.go +++ b/src/uapi_linux.go @@ -7,7 +7,6 @@ import ( "net" "os" "path" - "time" ) const ( @@ -26,9 +25,10 @@ const ( */ type UAPIListener struct { - listener net.Listener // unix socket listener - connNew chan net.Conn - connErr chan error + listener net.Listener // unix socket listener + connNew chan net.Conn + connErr chan error + inotifyFd int } func (l *UAPIListener) Accept() (net.Conn, error) { @@ -106,9 +106,28 @@ func NewUAPIListener(name string) (net.Listener, error) { // watch for deletion of socket + uapi.inotifyFd, err = unix.InotifyInit() + if err != nil { + return nil, err + } + + _, err = unix.InotifyAddWatch( + uapi.inotifyFd, + socketPath, + unix.IN_ATTRIB| + unix.IN_DELETE| + unix.IN_DELETE_SELF, + ) + + if err != nil { + return nil, err + } + go func(l *UAPIListener) { - for ; ; time.Sleep(time.Second) { - if _, err := os.Stat(socketPath); os.IsNotExist(err) { + var buff [4096]byte + for { + unix.Read(uapi.inotifyFd, buff[:]) + if _, err := os.Lstat(socketPath); os.IsNotExist(err) { l.connErr <- err return }