From 8c34c4cbb3780c433148966a004f5a51aace0f64 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 4 Aug 2017 16:15:53 +0200 Subject: [PATCH] First set of code review patches --- src/config.go | 229 +++++++++++++++++++++++++----------------- src/constants.go | 3 +- src/device.go | 44 ++++++-- src/index.go | 10 +- src/macs.go | 15 ++- src/noise_helpers.go | 8 ++ src/noise_protocol.go | 9 ++ src/noise_types.go | 22 ++-- src/receive.go | 44 ++++++-- src/send.go | 51 ++++++---- src/timers.go | 33 ++---- src/trie.go | 9 +- src/tun.go | 1 + src/tun_linux.go | 6 ++ src/uapi_linux.go | 13 ++- 15 files changed, 315 insertions(+), 182 deletions(-) diff --git a/src/config.go b/src/config.go index 72a604f..e2d7f20 100644 --- a/src/config.go +++ b/src/config.go @@ -61,6 +61,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { send(fmt.Sprintf("persistent_keepalive_interval=%d", atomic.LoadUint64(&peer.persistentKeepaliveInterval), )) + for _, ip := range device.routingTable.AllowedIPs(peer) { send("allowed_ip=" + ip.String()) } @@ -89,6 +90,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logDebug := device.log.Debug var peer *Peer + + deviceConfig := true + for scanner.Scan() { // parse line @@ -99,86 +103,110 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } parts := strings.Split(line, "=") if len(parts) != 2 { - return &IPCError{Code: ipcErrorNoKeyValue} + return &IPCError{Code: ipcErrorProtocol} } key := parts[0] value := parts[1] - switch key { + /* device configuration */ - /* interface configuration */ + if deviceConfig { - case "private_key": - var sk NoisePrivateKey - if value == "" { - device.SetPrivateKey(sk) - } else { - err := sk.FromHex(value) + switch key { + case "private_key": + var sk NoisePrivateKey + if value == "" { + device.SetPrivateKey(sk) + } else { + err := sk.FromHex(value) + if err != nil { + logError.Println("Failed to set private_key:", err) + return &IPCError{Code: ipcErrorInvalid} + } + device.SetPrivateKey(sk) + } + + case "listen_port": + port, err := strconv.ParseUint(value, 10, 16) if err != nil { - logError.Println("Failed to set private_key:", err) - return &IPCError{Code: ipcErrorInvalidValue} + logError.Println("Failed to set listen_port:", err) + return &IPCError{Code: ipcErrorInvalid} } - device.SetPrivateKey(sk) - } - - case "listen_port": - port, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{Code: ipcErrorInvalidValue} - } - netc := &device.net - netc.mutex.Lock() - if netc.addr.Port != int(port) { - if netc.conn != nil { - netc.conn.Close() + netc := &device.net + netc.mutex.Lock() + if netc.addr.Port != int(port) { + if netc.conn != nil { + netc.conn.Close() + } + netc.addr.Port = int(port) + netc.conn, err = net.ListenUDP("udp", netc.addr) } - netc.addr.Port = int(port) - netc.conn, err = net.ListenUDP("udp", netc.addr) - } - netc.mutex.Unlock() - if err != nil { - logError.Println("Failed to create UDP listener:", err) - return &IPCError{Code: ipcErrorInvalidValue} - } + netc.mutex.Unlock() + if err != nil { + logError.Println("Failed to create UDP listener:", err) + return &IPCError{Code: ipcErrorIO} + } + // TODO: Clear source address of all peers - case "fwmark": - logError.Println("FWMark not handled yet") + case "fwmark": + logError.Println("FWMark not handled yet") + // TODO: Clear source address of all peers - case "public_key": - var pubKey NoisePublicKey - err := pubKey.FromHex(value) - if err != nil { - logError.Println("Failed to get peer by public_key:", err) - return &IPCError{Code: ipcErrorInvalidValue} - } - device.mutex.RLock() - peer, _ = device.peers[pubKey] - device.mutex.RUnlock() - if peer == nil { - peer = device.NewPeer(pubKey) - } + case "public_key": - case "replace_peers": - if value == "true" { + // switch to peer configuration + + deviceConfig = false + + case "replace_peers": + if value != "true" { + logError.Println("Failed to set replace_peers, invalid value:", value) + return &IPCError{Code: ipcErrorInvalid} + } device.RemoveAllPeers() - } else { - logError.Println("Failed to set replace_peers, invalid value:", value) - return &IPCError{Code: ipcErrorInvalidValue} + + default: + logError.Println("Invalid UAPI key (device configuration):", key) + return &IPCError{Code: ipcErrorInvalid} } + } - default: + /* peer configuration */ - /* peer configuration */ - - if peer == nil { - logError.Println("No peer referenced, before peer operation") - return &IPCError{Code: ipcErrorNoPeer} - } + if !deviceConfig { switch key { + case "public_key": + var pubKey NoisePublicKey + err := pubKey.FromHex(value) + if err != nil { + logError.Println("Failed to get peer by public_key:", err) + return &IPCError{Code: ipcErrorInvalid} + } + + // check if public key of peer equal to device + + device.mutex.RLock() + if device.publicKey.Equals(pubKey) { + device.mutex.RUnlock() + logError.Println("Public key of peer matches private key of device") + return &IPCError{Code: ipcErrorInvalid} + } + + // find peer referenced + + peer, _ = device.peers[pubKey] + device.mutex.RUnlock() + if peer == nil { + peer = device.NewPeer(pubKey) + } + case "remove": + if value != "true" { + logError.Println("Failed to set remove, invalid value:", value) + return &IPCError{Code: ipcErrorInvalid} + } device.RemovePeer(peer.handshake.remoteStatic) logDebug.Println("Removing", peer.String()) peer = nil @@ -191,50 +219,67 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { }() if err != nil { logError.Println("Failed to set preshared_key:", err) - return &IPCError{Code: ipcErrorInvalidValue} + return &IPCError{Code: ipcErrorInvalid} } case "endpoint": + // TODO: Only IP and port addr, err := net.ResolveUDPAddr("udp", value) if err != nil { logError.Println("Failed to set endpoint:", value) - return &IPCError{Code: ipcErrorInvalidValue} + return &IPCError{Code: ipcErrorInvalid} } peer.mutex.Lock() peer.endpoint = addr peer.mutex.Unlock() case "persistent_keepalive_interval": - secs, err := strconv.ParseInt(value, 10, 64) - if secs < 0 || err != nil { + + // update keep-alive interval + + secs, err := strconv.ParseUint(value, 10, 16) + if err != nil { logError.Println("Failed to set persistent_keepalive_interval:", err) - return &IPCError{Code: ipcErrorInvalidValue} + return &IPCError{Code: ipcErrorInvalid} } - atomic.StoreUint64( + + old := atomic.SwapUint64( &peer.persistentKeepaliveInterval, - uint64(secs), + secs, ) - case "replace_allowed_ips": - if value == "true" { - device.routingTable.RemovePeer(peer) - } else { - logError.Println("Failed to set replace_allowed_ips, invalid value:", value) - return &IPCError{Code: ipcErrorInvalidValue} + // send immediate keep-alive + + if old == 0 && secs != 0 { + up, err := device.tun.IsUp() + if err != nil { + logError.Println("Failed to get tun device status:", err) + return &IPCError{Code: ipcErrorIO} + } + if up { + peer.SendKeepAlive() + } } + case "replace_allowed_ips": + if value != "true" { + logError.Println("Failed to set replace_allowed_ips, invalid value:", value) + return &IPCError{Code: ipcErrorInvalid} + } + device.routingTable.RemovePeer(peer) + case "allowed_ip": _, network, err := net.ParseCIDR(value) if err != nil { logError.Println("Failed to set allowed_ip:", err) - return &IPCError{Code: ipcErrorInvalidValue} + return &IPCError{Code: ipcErrorInvalid} } ones, _ := network.Mask.Size() device.routingTable.Insert(network.IP, uint(ones), peer) default: - logError.Println("Invalid UAPI key:", key) - return &IPCError{Code: ipcErrorInvalidKey} + logError.Println("Invalid UAPI key (peer configuration):", key) + return &IPCError{Code: ipcErrorInvalid} } } } @@ -244,6 +289,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { func ipcHandle(device *Device, socket net.Conn) { + // create buffered read/writer + defer socket.Close() buffered := func(s io.ReadWriter) *bufio.ReadWriter { @@ -259,30 +306,30 @@ func ipcHandle(device *Device, socket net.Conn) { return } - switch op { + // handle operation + var status *IPCError + + switch op { case "set=1\n": device.log.Debug.Println("Config, set operation") - err := ipcSetOperation(device, buffered) - if err != nil { - fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") - } - return + status = ipcSetOperation(device, buffered) case "get=1\n": device.log.Debug.Println("Config, get operation") - err := ipcGetOperation(device, buffered) - if err != nil { - fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") - } - return + status = ipcGetOperation(device, buffered) default: device.log.Error.Println("Invalid UAPI operation:", op) + return + } + // write status + + if status != nil { + device.log.Error.Println(status) + fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") } } diff --git a/src/constants.go b/src/constants.go index 09d33d8..f09ded6 100644 --- a/src/constants.go +++ b/src/constants.go @@ -16,6 +16,7 @@ const ( KeepaliveTimeout = time.Second * 10 CookieRefreshTime = time.Second * 120 MaxHandshakeAttemptTime = time.Second * 90 + PaddingMultiple = 16 ) const ( @@ -31,5 +32,5 @@ const ( QueueHandshakeSize = 1024 QueueHandshakeBusySize = QueueHandshakeSize / 8 MinMessageSize = MessageTransportSize // size of keep-alive - MaxMessageSize = (1 << 16) - 1 + MaxMessageSize = ((1 << 16) - 1) + MessageTransportHeaderSize ) diff --git a/src/device.go b/src/device.go index 1185d60..de96f0b 100644 --- a/src/device.go +++ b/src/device.go @@ -1,6 +1,8 @@ package main import ( + "errors" + "fmt" "net" "runtime" "sync" @@ -10,6 +12,7 @@ import ( type Device struct { mtu int32 + tun TUNDevice log *Logger // collection of loggers for levels idCounter uint // for assigning debug ids to peers fwMark uint32 @@ -43,24 +46,46 @@ type Device struct { mac MACStateDevice } -func (device *Device) SetPrivateKey(sk NoisePrivateKey) { +func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { device.mutex.Lock() defer device.mutex.Unlock() + // check if public key is matching any peer + + publicKey := sk.publicKey() + for _, peer := range device.peers { + h := &peer.handshake + h.mutex.RLock() + if h.remoteStatic.Equals(publicKey) { + h.mutex.RUnlock() + return errors.New("Private key matches public key of peer") + } + h.mutex.RUnlock() + } + // update key material device.privateKey = sk - device.publicKey = sk.publicKey() - device.mac.Init(device.publicKey) + device.publicKey = publicKey + device.mac.Init(publicKey) // do DH precomputations + isZero := device.privateKey.IsZero() + for _, peer := range device.peers { h := &peer.handshake h.mutex.Lock() - h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) + if isZero { + h.precomputedStaticStatic = [NoisePublicKeySize]byte{} + } else { + h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) + } + fmt.Println(h.precomputedStaticStatic) h.mutex.Unlock() } + + return nil } func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { @@ -77,6 +102,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.mutex.Lock() defer device.mutex.Unlock() + device.tun = tun device.log = NewLogger(logLevel) device.peers = make(map[NoisePublicKey]*Peer) device.indices.Init() @@ -119,22 +145,22 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { } go device.RoutineBusyMonitor() - go device.RoutineMTUUpdater(tun) - go device.RoutineWriteToTUN(tun) - go device.RoutineReadFromTUN(tun) + go device.RoutineMTUUpdater() + go device.RoutineWriteToTUN() + go device.RoutineReadFromTUN() go device.RoutineReceiveIncomming() go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) return device } -func (device *Device) RoutineMTUUpdater(tun TUNDevice) { +func (device *Device) RoutineMTUUpdater() { logError := device.log.Error for ; ; time.Sleep(5 * time.Second) { // load updated MTU - mtu, err := tun.MTU() + mtu, err := device.tun.MTU() if err != nil { logError.Println("Failed to load updated MTU of device:", err) continue diff --git a/src/index.go b/src/index.go index 44b4974..e518b0f 100644 --- a/src/index.go +++ b/src/index.go @@ -3,6 +3,7 @@ package main import ( "crypto/rand" "sync" + "unsafe" ) /* Index=0 is reserved for unset indecies @@ -23,14 +24,7 @@ type IndexTable struct { func randUint32() (uint32, error) { var buff [4]byte _, err := rand.Read(buff[:]) - id := uint32(buff[0]) - id <<= 8 - id |= uint32(buff[1]) - id <<= 8 - id |= uint32(buff[2]) - id <<= 8 - id |= uint32(buff[3]) - return id, err + return *((*uint32)(unsafe.Pointer(&buff))), err } func (table *IndexTable) Init() { diff --git a/src/macs.go b/src/macs.go index 841ef31..beb5f76 100644 --- a/src/macs.go +++ b/src/macs.go @@ -3,7 +3,6 @@ package main import ( "crypto/hmac" "crypto/rand" - "errors" "golang.org/x/crypto/blake2s" "net" "sync" @@ -15,14 +14,14 @@ type MACStateDevice struct { refreshed time.Time secret [blake2s.Size]byte keyMAC1 [blake2s.Size]byte - keyMAC2 [blake2s.Size]byte + keyMAC2 [blake2s.Size]byte // TODO: Change to more descriptive size constant, rename to something. } type MACStatePeer struct { mutex sync.RWMutex cookieSet time.Time cookie [blake2s.Size128]byte - lastMAC1 [blake2s.Size128]byte + lastMAC1 [blake2s.Size128]byte // TODO: Check if set keyMAC1 [blake2s.Size]byte keyMAC2 [blake2s.Size]byte } @@ -83,7 +82,7 @@ func (state *MACStateDevice) CheckMAC2(msg []byte, addr *net.UDPAddr) bool { port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)} mac, _ := blake2s.New128(state.secret[:]) mac.Write(addr.IP) - mac.Write(port[:]) + mac.Write(port[:]) // TODO: Be faster and more platform dependent? mac.Sum(cookie[:0]) }() @@ -130,7 +129,7 @@ func (device *Device) CreateMessageCookieReply( port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)} mac, _ := blake2s.New128(state.secret[:]) mac.Write(addr.IP) - mac.Write(port[:]) + mac.Write(port[:]) // TODO: Do whatever we did above mac.Sum(cookie[:0]) }() @@ -196,6 +195,7 @@ func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool { if err != nil { return false } + state.cookieSet = time.Now() state.cookie = cookie return true @@ -229,10 +229,6 @@ func (state *MACStatePeer) Init(pk NoisePublicKey) { func (state *MACStatePeer) AddMacs(msg []byte) { size := len(msg) - if size < blake2s.Size128*2 { - panic(errors.New("bug: message too short")) - } - startMac1 := size - (blake2s.Size128 * 2) startMac2 := size - blake2s.Size128 @@ -250,6 +246,7 @@ func (state *MACStatePeer) AddMacs(msg []byte) { mac.Sum(mac1[:0]) }() copy(state.lastMAC1[:], mac1) + // TODO: Set lastMac flag // set mac2 diff --git a/src/noise_helpers.go b/src/noise_helpers.go index 1e622a5..105f78f 100644 --- a/src/noise_helpers.go +++ b/src/noise_helpers.go @@ -47,6 +47,14 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt return } +func isZero(val []byte) bool { + var acc byte + for _, b := range val { + acc |= b + } + return acc == 0 +} + /* curve25519 wrappers */ func newPrivateKey() (sk NoisePrivateKey, err error) { diff --git a/src/noise_protocol.go b/src/noise_protocol.go index e2ff573..5c776a8 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -135,6 +135,10 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mutex.Lock() defer handshake.mutex.Unlock() + if isZero(handshake.precomputedStaticStatic[:]) { + return nil, errors.New("Static shared secret is zero") + } + // create ephemeral key var err error @@ -226,7 +230,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { if peer == nil { return nil } + handshake := &peer.handshake + if isZero(handshake.precomputedStaticStatic[:]) { + return nil + } // verify identity @@ -472,6 +480,7 @@ func (peer *Peer) NewKeyPair() *KeyPair { func() { kp.mutex.Lock() defer kp.mutex.Unlock() + // TODO: Adapt kernel behavior noise.c:161 if isInitiator { if kp.previous != nil { kp.previous.send = nil diff --git a/src/noise_types.go b/src/noise_types.go index 5ebc130..1a944df 100644 --- a/src/noise_types.go +++ b/src/noise_types.go @@ -1,6 +1,7 @@ package main import ( + "crypto/subtle" "encoding/hex" "errors" "golang.org/x/crypto/chacha20poly1305" @@ -31,12 +32,12 @@ func loadExactHex(dst []byte, src string) error { } func (key NoisePrivateKey) IsZero() bool { - for _, b := range key[:] { - if b != 0 { - return false - } - } - return true + var zero NoisePrivateKey + return key.Equals(zero) +} + +func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 } func (key *NoisePrivateKey) FromHex(src string) error { @@ -55,6 +56,15 @@ func (key NoisePublicKey) ToHex() string { return hex.EncodeToString(key[:]) } +func (key NoisePublicKey) IsZero() bool { + var zero NoisePublicKey + return key.Equals(zero) +} + +func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 +} + func (key *NoiseSymmetricKey) FromHex(src string) error { return loadExactHex(key[:], src) } diff --git a/src/receive.go b/src/receive.go index 700b894..fb5c51f 100644 --- a/src/receive.go +++ b/src/receive.go @@ -73,6 +73,8 @@ func (device *Device) addToHandshakeQueue( } /* Routine determining the busy state of the interface + * + * TODO: Under load for some time */ func (device *Device) RoutineBusyMonitor() { samples := 0 @@ -131,6 +133,7 @@ func (device *Device) RoutineReceiveIncomming() { buffer = device.GetMessageBuffer() } + // TODO: Take writelock to sleep device.net.mutex.RLock() conn := device.net.conn device.net.mutex.RUnlock() @@ -139,6 +142,7 @@ func (device *Device) RoutineReceiveIncomming() { continue } + // TODO: Wait for new conn or message conn.SetReadDeadline(time.Now().Add(time.Second)) size, raddr, err := conn.ReadFromUDP(buffer[:]) @@ -156,6 +160,8 @@ func (device *Device) RoutineReceiveIncomming() { case MessageInitiationType, MessageResponseType: + // TODO: Check size early + // add to handshake queue device.addToHandshakeQueue( @@ -171,6 +177,8 @@ func (device *Device) RoutineReceiveIncomming() { case MessageCookieReplyType: + // TODO: Queue all the things + // verify and update peer cookie state if len(packet) != MessageCookieReplySize { @@ -250,7 +258,7 @@ func (device *Device) RoutineDecryption() { // check if dropped if elem.IsDropped() { - elem.mutex.Unlock() + elem.mutex.Unlock() // TODO: Make consistent with send continue } @@ -318,6 +326,7 @@ func (device *Device) RoutineHandshake() { logError.Println("Failed to create cookie reply:", err) return } + // TODO: Use temp writer := bytes.NewBuffer(elem.packet[:0]) binary.Write(writer, binary.LittleEndian, reply) elem.packet = writer.Bytes() @@ -330,6 +339,8 @@ func (device *Device) RoutineHandshake() { // ratelimit + // TODO: Only ratelimit when busy + if !device.ratelimiter.Allow(elem.source.IP) { return } @@ -364,9 +375,14 @@ func (device *Device) RoutineHandshake() { ) return } - peer.TimerPacketReceived() + + // update timers + + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() // update endpoint + // TODO: Add a race condition \s peer.mutex.Lock() peer.endpoint = elem.source @@ -381,6 +397,7 @@ func (device *Device) RoutineHandshake() { } peer.TimerEphemeralKeyCreated() + peer.NewKeyPair() logDebug.Println("Creating response message for", peer.String()) @@ -392,8 +409,7 @@ func (device *Device) RoutineHandshake() { // send response peer.SendBuffer(packet) - peer.TimerPacketSent() - peer.NewKeyPair() + peer.TimerAnyAuthenticatedPacketTraversal() case MessageResponseType: @@ -423,8 +439,14 @@ func (device *Device) RoutineHandshake() { return } - peer.TimerPacketReceived() + // update timers + + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() peer.TimerHandshakeComplete() + + // derive key-pair + peer.NewKeyPair() peer.SendKeepAlive() @@ -467,8 +489,8 @@ func (peer *Peer) RoutineSequentialReceiver() { return } - peer.TimerPacketReceived() - peer.TimerTransportReceived() + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() peer.KeepKeyFreshReceiving() // check if using new key-pair @@ -504,6 +526,7 @@ func (peer *Peer) RoutineSequentialReceiver() { field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] length := binary.BigEndian.Uint16(field) + // TODO: check length of packet & NOT TOO SMALL either elem.packet = elem.packet[:length] // verify IPv4 source @@ -525,6 +548,7 @@ func (peer *Peer) RoutineSequentialReceiver() { field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] length := binary.BigEndian.Uint16(field) length += ipv6.HeaderLen + // TODO: check length of packet elem.packet = elem.packet[:length] // verify IPv6 source @@ -542,11 +566,13 @@ func (peer *Peer) RoutineSequentialReceiver() { atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) device.addToInboundQueue(device.queue.inbound, elem) + + // TODO: move TUN write into per peer routine }() } } -func (device *Device) RoutineWriteToTUN(tun TUNDevice) { +func (device *Device) RoutineWriteToTUN() { logError := device.log.Error logDebug := device.log.Debug @@ -557,7 +583,7 @@ func (device *Device) RoutineWriteToTUN(tun TUNDevice) { case <-device.signal.stop: return case elem := <-device.queue.inbound: - _, err := tun.Write(elem.packet) + _, err := device.tun.Write(elem.packet) device.PutMessageBuffer(elem.buffer) if err != nil { logError.Println("Failed to write packet to TUN device:", err) diff --git a/src/send.go b/src/send.go index 37078b9..fc35732 100644 --- a/src/send.go +++ b/src/send.go @@ -110,17 +110,19 @@ func addToEncryptionQueue( } func (peer *Peer) SendBuffer(buffer []byte) (int, error) { + peer.device.net.mutex.RLock() + defer peer.device.net.mutex.RUnlock() peer.mutex.RLock() + defer peer.mutex.RUnlock() + endpoint := peer.endpoint - peer.mutex.RUnlock() + conn := peer.device.net.conn + if endpoint == nil { return 0, ErrorNoEndpoint } - peer.device.net.mutex.RLock() - conn := peer.device.net.conn - peer.device.net.mutex.RUnlock() if conn == nil { return 0, ErrorNoConnection } @@ -133,13 +135,13 @@ func (peer *Peer) SendBuffer(buffer []byte) (int, error) { * * Obs. Single instance per TUN device */ -func (device *Device) RoutineReadFromTUN(tun TUNDevice) { +func (device *Device) RoutineReadFromTUN() { - if tun == nil { + if device.tun == nil { return } - elem := device.NewOutboundElement() + var elem *QueueOutboundElement logDebug := device.log.Debug logError := device.log.Error @@ -153,32 +155,38 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { elem = device.NewOutboundElement() } + // TODO: THIS! elem.packet = elem.buffer[MessageTransportHeaderSize:] - size, err := tun.Read(elem.packet) + size, err := device.tun.Read(elem.packet) if err != nil { - - // stop process - logError.Println("Failed to read packet from TUN device:", err) device.Close() return } - elem.packet = elem.packet[:size] - if len(elem.packet) < ipv4.HeaderLen { - logError.Println("Packet too short, length:", size) + if size == 0 { continue } + println(size, err) + + elem.packet = elem.packet[:size] + // lookup peer var peer *Peer switch elem.packet[0] >> 4 { case ipv4.Version: + if len(elem.packet) < ipv4.HeaderLen { + continue + } dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] peer = device.routingTable.LookupIPv4(dst) case ipv6.Version: + if len(elem.packet) < ipv6.HeaderLen { + continue + } dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] peer = device.routingTable.LookupIPv6(dst) @@ -190,10 +198,15 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { continue } + // check if known endpoint + + peer.mutex.RLock() if peer.endpoint == nil { + peer.mutex.RUnlock() logDebug.Println("No known endpoint for peer", peer.String()) continue } + peer.mutex.RUnlock() // insert into nonce/pre-handshake queue @@ -334,8 +347,12 @@ func (device *Device) RoutineEncryption() { // pad content to MTU size mtu := int(atomic.LoadInt32(&device.mtu)) - for i := len(elem.packet); i < mtu; i++ { - elem.packet = append(elem.packet, 0) + pad := len(elem.packet) % PaddingMultiple + if pad > 0 { + for i := 0; i < PaddingMultiple-pad && len(elem.packet) < mtu; i++ { + elem.packet = append(elem.packet, 0) + } + // TODO: How good is this code } // encrypt content (append to header) @@ -390,7 +407,7 @@ func (peer *Peer) RoutineSequentialSender() { // update timers - peer.TimerPacketSent() + peer.TimerAnyAuthenticatedPacketTraversal() if len(elem.packet) != MessageKeepaliveSize { peer.TimerDataSent() } diff --git a/src/timers.go b/src/timers.go index 5a16e9b..1be85f0 100644 --- a/src/timers.go +++ b/src/timers.go @@ -60,10 +60,8 @@ func (peer *Peer) SendKeepAlive() bool { return true } -/* Authenticated data packet send - * Always called together with peer.EventPacketSend - * - * - Start new handshake timer +/* Event: + * Sent non-empty (authenticated) transport message */ func (peer *Peer) TimerDataSent() { timerStop(peer.timer.keepalivePassive) @@ -75,8 +73,6 @@ func (peer *Peer) TimerDataSent() { /* Event: * Received non-empty (authenticated) transport message - * - * - Start passive keep-alive timer */ func (peer *Peer) TimerDataReceived() { if peer.timer.pendingKeepalivePassive { @@ -88,17 +84,16 @@ func (peer *Peer) TimerDataReceived() { } /* Event: - * Any (authenticated) transport message received - * (keep-alive or data) + * Any (authenticated) packet received */ -func (peer *Peer) TimerTransportReceived() { +func (peer *Peer) TimerAnyAuthenticatedPacketReceived() { timerStop(peer.timer.newHandshake) } /* Event: - * Any packet send to the peer. + * Any authenticated packet send / received. */ -func (peer *Peer) TimerPacketSent() { +func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) if interval > 0 { duration := time.Duration(interval) * time.Second @@ -106,13 +101,6 @@ func (peer *Peer) TimerPacketSent() { } } -/* Event: - * Any authenticated packet received from peer - */ -func (peer *Peer) TimerPacketReceived() { - peer.TimerPacketSent() -} - /* Called after succesfully completing a handshake. * i.e. after: * @@ -129,7 +117,9 @@ func (peer *Peer) TimerHandshakeComplete() { peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) } -/* Called whenever an ephemeral key is generated +/* Event: + * An ephemeral key is generated + * * i.e after: * * CreateMessageInitiation @@ -257,7 +247,6 @@ func (peer *Peer) RoutineHandshakeInitiator() { select { case <-peer.signal.handshakeBegin: - signalSend(peer.signal.handshakeBegin) case <-peer.signal.stop: return } @@ -303,7 +292,6 @@ func (peer *Peer) RoutineHandshakeInitiator() { binary.Write(writer, binary.LittleEndian, msg) packet := writer.Bytes() peer.mac.AddMacs(packet) - peer.TimerPacketSent() _, err = peer.SendBuffer(packet) if err != nil { @@ -314,6 +302,8 @@ func (peer *Peer) RoutineHandshakeInitiator() { continue } + peer.TimerAnyAuthenticatedPacketTraversal() + // set timeout timeout := time.NewTimer(RekeyTimeout) @@ -337,7 +327,6 @@ func (peer *Peer) RoutineHandshakeInitiator() { continue } - } // allow new signal to be set diff --git a/src/trie.go b/src/trie.go index e81b5b6..aa96a8a 100644 --- a/src/trie.go +++ b/src/trie.go @@ -32,11 +32,14 @@ type Trie struct { /* Finds length of matching prefix * TODO: Make faster * - * Assumption: len(ip1) == len(ip2) + * Assumption: + * len(ip1) == len(ip2) + * len(ip1) mod 4 = 0 */ -func commonBits(ip1 net.IP, ip2 net.IP) uint { +func commonBits(ip1 []byte, ip2 []byte) uint { var i uint - size := uint(len(ip1)) + size := uint(len(ip1)) / 4 + for i = 0; i < size; i++ { v := ip1[i] ^ ip2[i] if v != 0 { diff --git a/src/tun.go b/src/tun.go index f529c54..d782bd5 100644 --- a/src/tun.go +++ b/src/tun.go @@ -9,6 +9,7 @@ const DefaultMTU = 1420 type TUNDevice interface { Read([]byte) (int, error) // read a packet from the device (without any additional headers) Write([]byte) (int, error) // writes a packet to the device (without any additional headers) + IsUp() (bool, error) // is the interface up? MTU() (int, error) // returns the MTU of the device Name() string // returns the current name } diff --git a/src/tun_linux.go b/src/tun_linux.go index 261d142..d0e2f47 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "errors" "golang.org/x/sys/unix" + "net" "os" "strings" "unsafe" @@ -19,6 +20,11 @@ type NativeTun struct { name string } +func (tun *NativeTun) IsUp() (bool, error) { + inter, err := net.InterfaceByName(tun.name) + return inter.Flags&net.FlagUp != 0, err +} + func (tun *NativeTun) Name() string { return tun.name } diff --git a/src/uapi_linux.go b/src/uapi_linux.go index fd83918..d6d78e7 100644 --- a/src/uapi_linux.go +++ b/src/uapi_linux.go @@ -11,13 +11,12 @@ import ( ) const ( - ipcErrorIO = int64(unix.EIO) - ipcErrorNoPeer = int64(unix.EPROTO) - ipcErrorNoKeyValue = int64(unix.EPROTO) - ipcErrorInvalidKey = int64(unix.EPROTO) - ipcErrorInvalidValue = int64(unix.EPROTO) - socketDirectory = "/var/run/wireguard" - socketName = "%s.sock" + ipcErrorIO = -int64(unix.EIO) + ipcErrorNotDefined = -int64(unix.ENODEV) + ipcErrorProtocol = -int64(unix.EPROTO) + ipcErrorInvalid = -int64(unix.EINVAL) + socketDirectory = "/var/run/wireguard" + socketName = "%s.sock" ) /* TODO: