diff --git a/src/config.go b/src/config.go index 4edaa2e..d92e8d7 100644 --- a/src/config.go +++ b/src/config.go @@ -8,39 +8,36 @@ import ( "net" "strconv" "strings" + "sync/atomic" + "syscall" ) -// #include -import "C" - -/* TODO: More fine grained? - */ const ( - ipcErrorNoPeer = C.EPROTO - ipcErrorNoKeyValue = C.EPROTO - ipcErrorInvalidKey = C.EPROTO - ipcErrorInvalidValue = C.EPROTO + ipcErrorIO = syscall.EIO + ipcErrorNoPeer = syscall.EPROTO + ipcErrorNoKeyValue = syscall.EPROTO + ipcErrorInvalidKey = syscall.EPROTO + ipcErrorInvalidValue = syscall.EPROTO ) type IPCError struct { - Code int + Code syscall.Errno } func (s *IPCError) Error() string { return fmt.Sprintf("IPC error: %d", s.Code) } -func (s *IPCError) ErrorCode() int { - return s.Code +func (s *IPCError) ErrorCode() uintptr { + return uintptr(s.Code) } -func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { - - device.mutex.RLock() - defer device.mutex.RUnlock() +func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // create lines + device.mutex.RLock() + lines := make([]string, 0, 100) send := func(line string) { lines = append(lines, line) @@ -63,19 +60,25 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { } send(fmt.Sprintf("tx_bytes=%d", peer.txBytes)) send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes)) - send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) + send(fmt.Sprintf("persistent_keepalive_interval=%d", + atomic.LoadUint64(&peer.persistentKeepaliveInterval), + )) for _, ip := range device.routingTable.AllowedIPs(peer) { send("allowed_ip=" + ip.String()) } }() } + device.mutex.RUnlock() + // send lines for _, line := range lines { _, err := socket.WriteString(line + "\n") if err != nil { - return err + return &IPCError{ + Code: ipcErrorIO, + } } } @@ -83,13 +86,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { } func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { - logger := device.log.Debug scanner := bufio.NewScanner(socket) + logError := device.log.Error + logDebug := device.log.Debug var peer *Peer for scanner.Scan() { - // Parse line + // parse line line := scanner.Text() if line == "" { @@ -97,7 +101,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } parts := strings.Split(line, "=") if len(parts) != 2 { - device.log.Debug.Println(parts) return &IPCError{Code: ipcErrorNoKeyValue} } key := parts[0] @@ -105,7 +108,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { switch key { - /* Interface configuration */ + /* interface configuration */ case "private_key": if value == "" { @@ -116,7 +119,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { var sk NoisePrivateKey err := sk.FromHex(value) if err != nil { - logger.Println("Failed to set private_key:", err) + logError.Println("Failed to set private_key:", err) return &IPCError{Code: ipcErrorInvalidValue} } device.SetPrivateKey(sk) @@ -126,22 +129,26 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { var port int _, err := fmt.Sscanf(value, "%d", &port) if err != nil || port > (1<<16) || port < 0 { - logger.Println("Failed to set listen_port:", err) + logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorInvalidValue} } device.net.mutex.Lock() device.net.addr.Port = port device.net.conn, err = net.ListenUDP("udp", device.net.addr) device.net.mutex.Unlock() + if err != nil { + logError.Println("Failed to create UDP listener:", err) + return &IPCError{Code: ipcErrorInvalidValue} + } case "fwmark": - logger.Println("FWMark not handled yet") + logError.Println("FWMark not handled yet") case "public_key": var pubKey NoisePublicKey err := pubKey.FromHex(value) if err != nil { - logger.Println("Failed to get peer by public_key:", err) + logError.Println("Failed to get peer by public_key:", err) return &IPCError{Code: ipcErrorInvalidValue} } device.mutex.RLock() @@ -153,22 +160,23 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { peer = device.NewPeer(pubKey) } if peer == nil { - panic(errors.New("bug: failed to find peer")) + panic(errors.New("bug: failed to find / create peer")) } case "replace_peers": if value == "true" { device.RemoveAllPeers() } else { - logger.Println("Failed to set replace_peers, invalid value:", value) + logError.Println("Failed to set replace_peers, invalid value:", value) return &IPCError{Code: ipcErrorInvalidValue} } default: - /* Peer configuration */ + + /* peer configuration */ if peer == nil { - logger.Println("No peer referenced, before peer operation") + logError.Println("No peer referenced, before peer operation") return &IPCError{Code: ipcErrorNoPeer} } @@ -178,7 +186,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { peer.mutex.Lock() device.RemovePeer(peer.handshake.remoteStatic) peer.mutex.Unlock() - logger.Println("Remove peer") + logDebug.Println("Removing", peer.String()) peer = nil case "preshared_key": @@ -188,14 +196,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return peer.handshake.presharedKey.FromHex(value) }() if err != nil { - logger.Println("Failed to set preshared_key:", err) + logError.Println("Failed to set preshared_key:", err) return &IPCError{Code: ipcErrorInvalidValue} } case "endpoint": addr, err := net.ResolveUDPAddr("udp", value) if err != nil { - logger.Println("Failed to set endpoint:", value) + logError.Println("Failed to set endpoint:", value) return &IPCError{Code: ipcErrorInvalidValue} } peer.mutex.Lock() @@ -205,35 +213,34 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "persistent_keepalive_interval": secs, err := strconv.ParseInt(value, 10, 64) if secs < 0 || err != nil { - logger.Println("Failed to set persistent_keepalive_interval:", err) + logError.Println("Failed to set persistent_keepalive_interval:", err) return &IPCError{Code: ipcErrorInvalidValue} } - peer.mutex.Lock() - peer.persistentKeepaliveInterval = uint64(secs) - peer.mutex.Unlock() + atomic.StoreUint64( + &peer.persistentKeepaliveInterval, + uint64(secs), + ) case "replace_allowed_ips": if value == "true" { device.routingTable.RemovePeer(peer) } else { - logger.Println("Failed to set replace_allowed_ips, invalid value:", value) + logError.Println("Failed to set replace_allowed_ips, invalid value:", value) return &IPCError{Code: ipcErrorInvalidValue} } case "allowed_ip": _, network, err := net.ParseCIDR(value) if err != nil { - logger.Println("Failed to set allowed_ip:", err) + logError.Println("Failed to set allowed_ip:", err) return &IPCError{Code: ipcErrorInvalidValue} } ones, _ := network.Mask.Size() - logger.Println(network, ones, network.IP) + logError.Println(network, ones, network.IP) device.routingTable.Insert(network.IP, uint(ones), peer) - /* Invalid key */ - default: - logger.Println("Invalid key:", key) + logError.Println("Invalid UAPI key:", key) return &IPCError{Code: ipcErrorInvalidKey} } } @@ -244,46 +251,45 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { func ipcHandle(device *Device, socket net.Conn) { - func() { - buffered := func(s io.ReadWriter) *bufio.ReadWriter { - reader := bufio.NewReader(s) - writer := bufio.NewWriter(s) - return bufio.NewReadWriter(reader, writer) - }(socket) + defer socket.Close() - defer buffered.Flush() + buffered := func(s io.ReadWriter) *bufio.ReadWriter { + reader := bufio.NewReader(s) + writer := bufio.NewWriter(s) + return bufio.NewReadWriter(reader, writer) + }(socket) - op, err := buffered.ReadString('\n') + defer buffered.Flush() + + op, err := buffered.ReadString('\n') + if err != nil { + return + } + + switch op { + + case "set=1\n": + device.log.Debug.Println("Config, set operation") + err := ipcSetOperation(device, buffered) if err != nil { - return + fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") } + return - switch op { - - case "set=1\n": - device.log.Debug.Println("Config, set operation") - err := ipcSetOperation(device, buffered) - if err != nil { - fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") - } - break - - case "get=1\n": - device.log.Debug.Println("Config, get operation") - err := ipcGetOperation(device, buffered) - if err != nil { - fmt.Fprintf(buffered, "errno=1\n\n") // fix - } else { - fmt.Fprintf(buffered, "errno=0\n\n") - } - break - - default: - device.log.Info.Println("Invalid UAPI operation:", op) + case "get=1\n": + device.log.Debug.Println("Config, get operation") + err := ipcGetOperation(device, buffered) + if err != nil { + fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") } - }() + return - socket.Close() + default: + device.log.Error.Println("Invalid UAPI operation:", op) + + } } diff --git a/src/device.go b/src/device.go index 4981f51..d32d648 100644 --- a/src/device.go +++ b/src/device.go @@ -78,7 +78,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { defer device.mutex.Unlock() device.log = NewLogger(logLevel) - // device.mtu = tun.MTU() device.peers = make(map[NoisePublicKey]*Peer) device.indices.Init() device.ratelimiter.Init() @@ -131,12 +130,21 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { func (device *Device) RoutineMTUUpdater(tun TUNDevice) { logError := device.log.Error - for ; ; time.Sleep(time.Second) { + for ; ; time.Sleep(5 * time.Second) { + + // load updated MTU + mtu, err := tun.MTU() if err != nil { logError.Println("Failed to load updated MTU of device:", err) continue } + + // upper bound of mtu + + if mtu+MessageTransportSize > MaxMessageSize { + mtu = MaxMessageSize - MessageTransportSize + } atomic.StoreInt32(&device.mtu, int32(mtu)) } } diff --git a/src/index.go b/src/index.go index 59e2079..44b4974 100644 --- a/src/index.go +++ b/src/index.go @@ -6,8 +6,6 @@ import ( ) /* Index=0 is reserved for unset indecies - * - * TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake peer * */ @@ -72,12 +70,12 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { table.mutex.RLock() _, ok := table.table[index] + table.mutex.RUnlock() if ok { continue } - table.mutex.RUnlock() - // replace index + // map index to handshake table.mutex.Lock() _, found := table.table[index] diff --git a/src/main.go b/src/main.go index 74e7ec9..4bece16 100644 --- a/src/main.go +++ b/src/main.go @@ -17,12 +17,14 @@ func main() { } switch os.Args[1] { + case "-f", "--foreground": foreground = true if len(os.Args) != 3 { return } interfaceName = os.Args[2] + default: foreground = false if len(os.Args) != 2 { @@ -48,8 +50,8 @@ func main() { // open TUN device tun, err := CreateTUN(interfaceName) - log.Println(tun, err) if err != nil { + log.Println("Failed to create tun device:", err) return } @@ -69,11 +71,15 @@ func main() { } defer uapi.Close() - for { - conn, err := uapi.Accept() - if err != nil { - logError.Fatal("accept error:", err) + go func() { + for { + conn, err := uapi.Accept() + if err != nil { + logError.Fatal("UAPI accept error:", err) + } + go ipcHandle(device, conn) } - go ipcHandle(device, conn) - } + }() + + device.Wait() } diff --git a/src/noise_protocol.go b/src/noise_protocol.go index bfa3797..5fe6fb2 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -459,7 +459,8 @@ func (peer *Peer) NewKeyPair() *KeyPair { // remap index - peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{ + indices := &peer.device.indices + indices.Insert(handshake.localIndex, IndexTableEntry{ peer: peer, keyPair: keyPair, handshake: nil, @@ -476,7 +477,7 @@ func (peer *Peer) NewKeyPair() *KeyPair { if kp.previous != nil { kp.previous.send = nil kp.previous.receive = nil - peer.device.indices.Delete(kp.previous.localIndex) + indices.Delete(kp.previous.localIndex) } kp.previous = kp.current kp.current = keyPair diff --git a/src/receive.go b/src/receive.go index 31f74e2..e063c99 100644 --- a/src/receive.go +++ b/src/receive.go @@ -212,18 +212,18 @@ func (device *Device) RoutineReceiveIncomming() { // add to peer queue peer := value.peer - work := &QueueInboundElement{ + elem := &QueueInboundElement{ packet: packet, buffer: buffer, keyPair: keyPair, dropped: AtomicFalse, } - work.mutex.Lock() + elem.mutex.Lock() // add to decryption queues - device.addToInboundQueue(device.queue.decryption, work) - device.addToInboundQueue(peer.queue.inbound, work) + device.addToInboundQueue(device.queue.decryption, elem) + device.addToInboundQueue(peer.queue.inbound, elem) buffer = nil default: diff --git a/src/send.go b/src/send.go index 2db74ba..fdbc676 100644 --- a/src/send.go +++ b/src/send.go @@ -270,50 +270,65 @@ func (peer *Peer) RoutineNonce() { * Obs. One instance per core */ func (device *Device) RoutineEncryption() { + + var elem *QueueOutboundElement var nonce [chacha20poly1305.NonceSize]byte - for work := range device.queue.encryption { + + logDebug := device.log.Debug + logDebug.Println("Routine, encryption worker, started") + + for { + + // fetch next element + + select { + case elem = <-device.queue.encryption: + case <-device.signal.stop: + logDebug.Println("Routine, encryption worker, stopped") + return + } // check if dropped - if work.IsDropped() { + if elem.IsDropped() { continue } // populate header fields - header := work.buffer[:MessageTransportHeaderSize] + header := elem.buffer[:MessageTransportHeaderSize] fieldType := header[0:4] fieldReceiver := header[4:8] fieldNonce := header[8:16] binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, work.keyPair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, work.nonce) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) // pad content to MTU size mtu := int(atomic.LoadInt32(&device.mtu)) - for i := len(work.packet); i < mtu; i++ { - work.packet = append(work.packet, 0) + for i := len(elem.packet); i < mtu; i++ { + elem.packet = append(elem.packet, 0) } // encrypt content - binary.LittleEndian.PutUint64(nonce[4:], work.nonce) - work.packet = work.keyPair.send.Seal( - work.packet[:0], + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keyPair.send.Seal( + elem.packet[:0], nonce[:], - work.packet, + elem.packet, nil, ) - length := MessageTransportHeaderSize + len(work.packet) - work.packet = work.buffer[:length] - work.mutex.Unlock() + length := MessageTransportHeaderSize + len(elem.packet) + elem.packet = elem.buffer[:length] + elem.mutex.Unlock() // refresh key if necessary - work.peer.KeepKeyFreshSending() + elem.peer.KeepKeyFreshSending() } } @@ -334,49 +349,43 @@ func (peer *Peer) RoutineSequentialSender() { logDebug.Println("Routine, sequential sender, stopped for", peer.String()) return - case work := <-peer.queue.outbound: - work.mutex.Lock() + case elem := <-peer.queue.outbound: + elem.mutex.Lock() func() { - - // return buffer to pool after processing - - defer device.PutMessageBuffer(work.buffer) - if work.IsDropped() { + if elem.IsDropped() { return } - // send to endpoint + // get endpoint and connection peer.mutex.RLock() - defer peer.mutex.RUnlock() - - if peer.endpoint == nil { + endpoint := peer.endpoint + peer.mutex.RUnlock() + if endpoint == nil { logDebug.Println("No endpoint for", peer.String()) return } device.net.mutex.RLock() - defer device.net.mutex.RUnlock() - - if device.net.conn == nil { + conn := device.net.conn + device.net.mutex.RUnlock() + if conn == nil { logDebug.Println("No source for device") return } - // send message and return buffer to pool + // send message and refresh keys - _, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint) + _, err := conn.WriteToUDP(elem.packet, endpoint) if err != nil { return } - - atomic.AddUint64(&peer.txBytes, uint64(len(work.packet))) - - // reset keep-alive - + atomic.AddUint64(&peer.txBytes, uint64(len(elem.packet))) peer.TimerResetKeepalive() }() + + device.PutMessageBuffer(elem.buffer) } } } diff --git a/src/timers.go b/src/timers.go index 9140e41..fd2bdc3 100644 --- a/src/timers.go +++ b/src/timers.go @@ -138,6 +138,7 @@ func (peer *Peer) BeginHandshakeInitiation() (*QueueOutboundElement, error) { func (peer *Peer) RoutineTimerHandler() { device := peer.device + indices := &device.indices logDebug := device.log.Debug logDebug.Println("Routine, timer handler, started for peer", peer.String()) @@ -170,29 +171,42 @@ func (peer *Peer) RoutineTimerHandler() { logDebug.Println("Clearing all key material for", peer.String()) - // zero out key pairs + kp := &peer.keyPairs + kp.mutex.Lock() - func() { - kp := &peer.keyPairs - kp.mutex.Lock() - // best we can do is wait for GC :( ? - kp.current = nil - kp.previous = nil - kp.next = nil - kp.mutex.Unlock() - }() + hs := &peer.handshake + hs.mutex.Lock() + + // unmap local indecies + + indices.mutex.Lock() + if kp.previous != nil { + delete(indices.table, kp.previous.localIndex) + } + if kp.current != nil { + delete(indices.table, kp.current.localIndex) + } + if kp.next != nil { + delete(indices.table, kp.next.localIndex) + } + delete(indices.table, hs.localIndex) + indices.mutex.Unlock() + + // zero out key pairs (TODO: better than wait for GC) + + kp.current = nil + kp.previous = nil + kp.next = nil + kp.mutex.Unlock() // zero out handshake - func() { - hs := &peer.handshake - hs.mutex.Lock() - hs.localEphemeral = NoisePrivateKey{} - hs.remoteEphemeral = NoisePublicKey{} - hs.chainKey = [blake2s.Size]byte{} - hs.hash = [blake2s.Size]byte{} - hs.mutex.Unlock() - }() + hs.localIndex = 0 + hs.localEphemeral = NoisePrivateKey{} + hs.remoteEphemeral = NoisePublicKey{} + hs.chainKey = [blake2s.Size]byte{} + hs.hash = [blake2s.Size]byte{} + hs.mutex.Unlock() } } }