1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

First set of code review patches

This commit is contained in:
Mathias Hall-Andersen 2017-08-04 16:15:53 +02:00
parent 22c83f4b8d
commit 8c34c4cbb3
15 changed files with 315 additions and 182 deletions

View File

@ -61,6 +61,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send(fmt.Sprintf("persistent_keepalive_interval=%d", send(fmt.Sprintf("persistent_keepalive_interval=%d",
atomic.LoadUint64(&peer.persistentKeepaliveInterval), atomic.LoadUint64(&peer.persistentKeepaliveInterval),
)) ))
for _, ip := range device.routingTable.AllowedIPs(peer) { for _, ip := range device.routingTable.AllowedIPs(peer) {
send("allowed_ip=" + ip.String()) send("allowed_ip=" + ip.String())
} }
@ -89,6 +90,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logDebug := device.log.Debug logDebug := device.log.Debug
var peer *Peer var peer *Peer
deviceConfig := true
for scanner.Scan() { for scanner.Scan() {
// parse line // parse line
@ -99,86 +103,110 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
} }
parts := strings.Split(line, "=") parts := strings.Split(line, "=")
if len(parts) != 2 { if len(parts) != 2 {
return &IPCError{Code: ipcErrorNoKeyValue} return &IPCError{Code: ipcErrorProtocol}
} }
key := parts[0] key := parts[0]
value := parts[1] value := parts[1]
switch key { /* device configuration */
/* interface configuration */ if deviceConfig {
case "private_key": switch key {
var sk NoisePrivateKey case "private_key":
if value == "" { var sk NoisePrivateKey
device.SetPrivateKey(sk) if value == "" {
} else { device.SetPrivateKey(sk)
err := sk.FromHex(value) } else {
err := sk.FromHex(value)
if err != nil {
logError.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalid}
}
device.SetPrivateKey(sk)
}
case "listen_port":
port, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
logError.Println("Failed to set private_key:", err) logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalid}
} }
device.SetPrivateKey(sk) netc := &device.net
} netc.mutex.Lock()
if netc.addr.Port != int(port) {
case "listen_port": if netc.conn != nil {
port, err := strconv.ParseUint(value, 10, 16) netc.conn.Close()
if err != nil { }
logError.Println("Failed to set listen_port:", err) netc.addr.Port = int(port)
return &IPCError{Code: ipcErrorInvalidValue} netc.conn, err = net.ListenUDP("udp", netc.addr)
}
netc := &device.net
netc.mutex.Lock()
if netc.addr.Port != int(port) {
if netc.conn != nil {
netc.conn.Close()
} }
netc.addr.Port = int(port) netc.mutex.Unlock()
netc.conn, err = net.ListenUDP("udp", netc.addr) if err != nil {
} logError.Println("Failed to create UDP listener:", err)
netc.mutex.Unlock() return &IPCError{Code: ipcErrorIO}
if err != nil { }
logError.Println("Failed to create UDP listener:", err) // TODO: Clear source address of all peers
return &IPCError{Code: ipcErrorInvalidValue}
}
case "fwmark": case "fwmark":
logError.Println("FWMark not handled yet") logError.Println("FWMark not handled yet")
// TODO: Clear source address of all peers
case "public_key": case "public_key":
var pubKey NoisePublicKey
err := pubKey.FromHex(value)
if err != nil {
logError.Println("Failed to get peer by public_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
device.mutex.RLock()
peer, _ = device.peers[pubKey]
device.mutex.RUnlock()
if peer == nil {
peer = device.NewPeer(pubKey)
}
case "replace_peers": // switch to peer configuration
if value == "true" {
deviceConfig = false
case "replace_peers":
if value != "true" {
logError.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
device.RemoveAllPeers() device.RemoveAllPeers()
} else {
logError.Println("Failed to set replace_peers, invalid value:", value) default:
return &IPCError{Code: ipcErrorInvalidValue} logError.Println("Invalid UAPI key (device configuration):", key)
return &IPCError{Code: ipcErrorInvalid}
} }
}
default: /* peer configuration */
/* peer configuration */ if !deviceConfig {
if peer == nil {
logError.Println("No peer referenced, before peer operation")
return &IPCError{Code: ipcErrorNoPeer}
}
switch key { switch key {
case "public_key":
var pubKey NoisePublicKey
err := pubKey.FromHex(value)
if err != nil {
logError.Println("Failed to get peer by public_key:", err)
return &IPCError{Code: ipcErrorInvalid}
}
// check if public key of peer equal to device
device.mutex.RLock()
if device.publicKey.Equals(pubKey) {
device.mutex.RUnlock()
logError.Println("Public key of peer matches private key of device")
return &IPCError{Code: ipcErrorInvalid}
}
// find peer referenced
peer, _ = device.peers[pubKey]
device.mutex.RUnlock()
if peer == nil {
peer = device.NewPeer(pubKey)
}
case "remove": case "remove":
if value != "true" {
logError.Println("Failed to set remove, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
device.RemovePeer(peer.handshake.remoteStatic) device.RemovePeer(peer.handshake.remoteStatic)
logDebug.Println("Removing", peer.String()) logDebug.Println("Removing", peer.String())
peer = nil peer = nil
@ -191,50 +219,67 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}() }()
if err != nil { if err != nil {
logError.Println("Failed to set preshared_key:", err) logError.Println("Failed to set preshared_key:", err)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalid}
} }
case "endpoint": case "endpoint":
// TODO: Only IP and port
addr, err := net.ResolveUDPAddr("udp", value) addr, err := net.ResolveUDPAddr("udp", value)
if err != nil { if err != nil {
logError.Println("Failed to set endpoint:", value) logError.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalid}
} }
peer.mutex.Lock() peer.mutex.Lock()
peer.endpoint = addr peer.endpoint = addr
peer.mutex.Unlock() peer.mutex.Unlock()
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
secs, err := strconv.ParseInt(value, 10, 64)
if secs < 0 || err != nil { // update keep-alive interval
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
logError.Println("Failed to set persistent_keepalive_interval:", err) logError.Println("Failed to set persistent_keepalive_interval:", err)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalid}
} }
atomic.StoreUint64(
old := atomic.SwapUint64(
&peer.persistentKeepaliveInterval, &peer.persistentKeepaliveInterval,
uint64(secs), secs,
) )
case "replace_allowed_ips": // send immediate keep-alive
if value == "true" {
device.routingTable.RemovePeer(peer) if old == 0 && secs != 0 {
} else { up, err := device.tun.IsUp()
logError.Println("Failed to set replace_allowed_ips, invalid value:", value) if err != nil {
return &IPCError{Code: ipcErrorInvalidValue} logError.Println("Failed to get tun device status:", err)
return &IPCError{Code: ipcErrorIO}
}
if up {
peer.SendKeepAlive()
}
} }
case "replace_allowed_ips":
if value != "true" {
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
device.routingTable.RemovePeer(peer)
case "allowed_ip": case "allowed_ip":
_, network, err := net.ParseCIDR(value) _, network, err := net.ParseCIDR(value)
if err != nil { if err != nil {
logError.Println("Failed to set allowed_ip:", err) logError.Println("Failed to set allowed_ip:", err)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalid}
} }
ones, _ := network.Mask.Size() ones, _ := network.Mask.Size()
device.routingTable.Insert(network.IP, uint(ones), peer) device.routingTable.Insert(network.IP, uint(ones), peer)
default: default:
logError.Println("Invalid UAPI key:", key) logError.Println("Invalid UAPI key (peer configuration):", key)
return &IPCError{Code: ipcErrorInvalidKey} return &IPCError{Code: ipcErrorInvalid}
} }
} }
} }
@ -244,6 +289,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
func ipcHandle(device *Device, socket net.Conn) { func ipcHandle(device *Device, socket net.Conn) {
// create buffered read/writer
defer socket.Close() defer socket.Close()
buffered := func(s io.ReadWriter) *bufio.ReadWriter { buffered := func(s io.ReadWriter) *bufio.ReadWriter {
@ -259,30 +306,30 @@ func ipcHandle(device *Device, socket net.Conn) {
return return
} }
switch op { // handle operation
var status *IPCError
switch op {
case "set=1\n": case "set=1\n":
device.log.Debug.Println("Config, set operation") device.log.Debug.Println("Config, set operation")
err := ipcSetOperation(device, buffered) status = ipcSetOperation(device, buffered)
if err != nil {
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
}
return
case "get=1\n": case "get=1\n":
device.log.Debug.Println("Config, get operation") device.log.Debug.Println("Config, get operation")
err := ipcGetOperation(device, buffered) status = ipcGetOperation(device, buffered)
if err != nil {
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
}
return
default: default:
device.log.Error.Println("Invalid UAPI operation:", op) device.log.Error.Println("Invalid UAPI operation:", op)
return
}
// write status
if status != nil {
device.log.Error.Println(status)
fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
} }
} }

View File

@ -16,6 +16,7 @@ const (
KeepaliveTimeout = time.Second * 10 KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 120 CookieRefreshTime = time.Second * 120
MaxHandshakeAttemptTime = time.Second * 90 MaxHandshakeAttemptTime = time.Second * 90
PaddingMultiple = 16
) )
const ( const (
@ -31,5 +32,5 @@ const (
QueueHandshakeSize = 1024 QueueHandshakeSize = 1024
QueueHandshakeBusySize = QueueHandshakeSize / 8 QueueHandshakeBusySize = QueueHandshakeSize / 8
MinMessageSize = MessageTransportSize // size of keep-alive MinMessageSize = MessageTransportSize // size of keep-alive
MaxMessageSize = (1 << 16) - 1 MaxMessageSize = ((1 << 16) - 1) + MessageTransportHeaderSize
) )

View File

@ -1,6 +1,8 @@
package main package main
import ( import (
"errors"
"fmt"
"net" "net"
"runtime" "runtime"
"sync" "sync"
@ -10,6 +12,7 @@ import (
type Device struct { type Device struct {
mtu int32 mtu int32
tun TUNDevice
log *Logger // collection of loggers for levels log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers idCounter uint // for assigning debug ids to peers
fwMark uint32 fwMark uint32
@ -43,24 +46,46 @@ type Device struct {
mac MACStateDevice mac MACStateDevice
} }
func (device *Device) SetPrivateKey(sk NoisePrivateKey) { func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
device.mutex.Lock() device.mutex.Lock()
defer device.mutex.Unlock() defer device.mutex.Unlock()
// check if public key is matching any peer
publicKey := sk.publicKey()
for _, peer := range device.peers {
h := &peer.handshake
h.mutex.RLock()
if h.remoteStatic.Equals(publicKey) {
h.mutex.RUnlock()
return errors.New("Private key matches public key of peer")
}
h.mutex.RUnlock()
}
// update key material // update key material
device.privateKey = sk device.privateKey = sk
device.publicKey = sk.publicKey() device.publicKey = publicKey
device.mac.Init(device.publicKey) device.mac.Init(publicKey)
// do DH precomputations // do DH precomputations
isZero := device.privateKey.IsZero()
for _, peer := range device.peers { for _, peer := range device.peers {
h := &peer.handshake h := &peer.handshake
h.mutex.Lock() h.mutex.Lock()
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) if isZero {
h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
}
fmt.Println(h.precomputedStaticStatic)
h.mutex.Unlock() h.mutex.Unlock()
} }
return nil
} }
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
@ -77,6 +102,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.mutex.Lock() device.mutex.Lock()
defer device.mutex.Unlock() defer device.mutex.Unlock()
device.tun = tun
device.log = NewLogger(logLevel) device.log = NewLogger(logLevel)
device.peers = make(map[NoisePublicKey]*Peer) device.peers = make(map[NoisePublicKey]*Peer)
device.indices.Init() device.indices.Init()
@ -119,22 +145,22 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
} }
go device.RoutineBusyMonitor() go device.RoutineBusyMonitor()
go device.RoutineMTUUpdater(tun) go device.RoutineMTUUpdater()
go device.RoutineWriteToTUN(tun) go device.RoutineWriteToTUN()
go device.RoutineReadFromTUN(tun) go device.RoutineReadFromTUN()
go device.RoutineReceiveIncomming() go device.RoutineReceiveIncomming()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
return device return device
} }
func (device *Device) RoutineMTUUpdater(tun TUNDevice) { func (device *Device) RoutineMTUUpdater() {
logError := device.log.Error logError := device.log.Error
for ; ; time.Sleep(5 * time.Second) { for ; ; time.Sleep(5 * time.Second) {
// load updated MTU // load updated MTU
mtu, err := tun.MTU() mtu, err := device.tun.MTU()
if err != nil { if err != nil {
logError.Println("Failed to load updated MTU of device:", err) logError.Println("Failed to load updated MTU of device:", err)
continue continue

View File

@ -3,6 +3,7 @@ package main
import ( import (
"crypto/rand" "crypto/rand"
"sync" "sync"
"unsafe"
) )
/* Index=0 is reserved for unset indecies /* Index=0 is reserved for unset indecies
@ -23,14 +24,7 @@ type IndexTable struct {
func randUint32() (uint32, error) { func randUint32() (uint32, error) {
var buff [4]byte var buff [4]byte
_, err := rand.Read(buff[:]) _, err := rand.Read(buff[:])
id := uint32(buff[0]) return *((*uint32)(unsafe.Pointer(&buff))), err
id <<= 8
id |= uint32(buff[1])
id <<= 8
id |= uint32(buff[2])
id <<= 8
id |= uint32(buff[3])
return id, err
} }
func (table *IndexTable) Init() { func (table *IndexTable) Init() {

View File

@ -3,7 +3,6 @@ package main
import ( import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"errors"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"net" "net"
"sync" "sync"
@ -15,14 +14,14 @@ type MACStateDevice struct {
refreshed time.Time refreshed time.Time
secret [blake2s.Size]byte secret [blake2s.Size]byte
keyMAC1 [blake2s.Size]byte keyMAC1 [blake2s.Size]byte
keyMAC2 [blake2s.Size]byte keyMAC2 [blake2s.Size]byte // TODO: Change to more descriptive size constant, rename to something.
} }
type MACStatePeer struct { type MACStatePeer struct {
mutex sync.RWMutex mutex sync.RWMutex
cookieSet time.Time cookieSet time.Time
cookie [blake2s.Size128]byte cookie [blake2s.Size128]byte
lastMAC1 [blake2s.Size128]byte lastMAC1 [blake2s.Size128]byte // TODO: Check if set
keyMAC1 [blake2s.Size]byte keyMAC1 [blake2s.Size]byte
keyMAC2 [blake2s.Size]byte keyMAC2 [blake2s.Size]byte
} }
@ -83,7 +82,7 @@ func (state *MACStateDevice) CheckMAC2(msg []byte, addr *net.UDPAddr) bool {
port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)} port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
mac, _ := blake2s.New128(state.secret[:]) mac, _ := blake2s.New128(state.secret[:])
mac.Write(addr.IP) mac.Write(addr.IP)
mac.Write(port[:]) mac.Write(port[:]) // TODO: Be faster and more platform dependent?
mac.Sum(cookie[:0]) mac.Sum(cookie[:0])
}() }()
@ -130,7 +129,7 @@ func (device *Device) CreateMessageCookieReply(
port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)} port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
mac, _ := blake2s.New128(state.secret[:]) mac, _ := blake2s.New128(state.secret[:])
mac.Write(addr.IP) mac.Write(addr.IP)
mac.Write(port[:]) mac.Write(port[:]) // TODO: Do whatever we did above
mac.Sum(cookie[:0]) mac.Sum(cookie[:0])
}() }()
@ -196,6 +195,7 @@ func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool {
if err != nil { if err != nil {
return false return false
} }
state.cookieSet = time.Now() state.cookieSet = time.Now()
state.cookie = cookie state.cookie = cookie
return true return true
@ -229,10 +229,6 @@ func (state *MACStatePeer) Init(pk NoisePublicKey) {
func (state *MACStatePeer) AddMacs(msg []byte) { func (state *MACStatePeer) AddMacs(msg []byte) {
size := len(msg) size := len(msg)
if size < blake2s.Size128*2 {
panic(errors.New("bug: message too short"))
}
startMac1 := size - (blake2s.Size128 * 2) startMac1 := size - (blake2s.Size128 * 2)
startMac2 := size - blake2s.Size128 startMac2 := size - blake2s.Size128
@ -250,6 +246,7 @@ func (state *MACStatePeer) AddMacs(msg []byte) {
mac.Sum(mac1[:0]) mac.Sum(mac1[:0])
}() }()
copy(state.lastMAC1[:], mac1) copy(state.lastMAC1[:], mac1)
// TODO: Set lastMac flag
// set mac2 // set mac2

View File

@ -47,6 +47,14 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
return return
} }
func isZero(val []byte) bool {
var acc byte
for _, b := range val {
acc |= b
}
return acc == 0
}
/* curve25519 wrappers */ /* curve25519 wrappers */
func newPrivateKey() (sk NoisePrivateKey, err error) { func newPrivateKey() (sk NoisePrivateKey, err error) {

View File

@ -135,6 +135,10 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errors.New("Static shared secret is zero")
}
// create ephemeral key // create ephemeral key
var err error var err error
@ -226,7 +230,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if peer == nil { if peer == nil {
return nil return nil
} }
handshake := &peer.handshake handshake := &peer.handshake
if isZero(handshake.precomputedStaticStatic[:]) {
return nil
}
// verify identity // verify identity
@ -472,6 +480,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
func() { func() {
kp.mutex.Lock() kp.mutex.Lock()
defer kp.mutex.Unlock() defer kp.mutex.Unlock()
// TODO: Adapt kernel behavior noise.c:161
if isInitiator { if isInitiator {
if kp.previous != nil { if kp.previous != nil {
kp.previous.send = nil kp.previous.send = nil

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"crypto/subtle"
"encoding/hex" "encoding/hex"
"errors" "errors"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
@ -31,12 +32,12 @@ func loadExactHex(dst []byte, src string) error {
} }
func (key NoisePrivateKey) IsZero() bool { func (key NoisePrivateKey) IsZero() bool {
for _, b := range key[:] { var zero NoisePrivateKey
if b != 0 { return key.Equals(zero)
return false }
}
} func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
return true return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
} }
func (key *NoisePrivateKey) FromHex(src string) error { func (key *NoisePrivateKey) FromHex(src string) error {
@ -55,6 +56,15 @@ func (key NoisePublicKey) ToHex() string {
return hex.EncodeToString(key[:]) return hex.EncodeToString(key[:])
} }
func (key NoisePublicKey) IsZero() bool {
var zero NoisePublicKey
return key.Equals(zero)
}
func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}
func (key *NoiseSymmetricKey) FromHex(src string) error { func (key *NoiseSymmetricKey) FromHex(src string) error {
return loadExactHex(key[:], src) return loadExactHex(key[:], src)
} }

View File

@ -73,6 +73,8 @@ func (device *Device) addToHandshakeQueue(
} }
/* Routine determining the busy state of the interface /* Routine determining the busy state of the interface
*
* TODO: Under load for some time
*/ */
func (device *Device) RoutineBusyMonitor() { func (device *Device) RoutineBusyMonitor() {
samples := 0 samples := 0
@ -131,6 +133,7 @@ func (device *Device) RoutineReceiveIncomming() {
buffer = device.GetMessageBuffer() buffer = device.GetMessageBuffer()
} }
// TODO: Take writelock to sleep
device.net.mutex.RLock() device.net.mutex.RLock()
conn := device.net.conn conn := device.net.conn
device.net.mutex.RUnlock() device.net.mutex.RUnlock()
@ -139,6 +142,7 @@ func (device *Device) RoutineReceiveIncomming() {
continue continue
} }
// TODO: Wait for new conn or message
conn.SetReadDeadline(time.Now().Add(time.Second)) conn.SetReadDeadline(time.Now().Add(time.Second))
size, raddr, err := conn.ReadFromUDP(buffer[:]) size, raddr, err := conn.ReadFromUDP(buffer[:])
@ -156,6 +160,8 @@ func (device *Device) RoutineReceiveIncomming() {
case MessageInitiationType, MessageResponseType: case MessageInitiationType, MessageResponseType:
// TODO: Check size early
// add to handshake queue // add to handshake queue
device.addToHandshakeQueue( device.addToHandshakeQueue(
@ -171,6 +177,8 @@ func (device *Device) RoutineReceiveIncomming() {
case MessageCookieReplyType: case MessageCookieReplyType:
// TODO: Queue all the things
// verify and update peer cookie state // verify and update peer cookie state
if len(packet) != MessageCookieReplySize { if len(packet) != MessageCookieReplySize {
@ -250,7 +258,7 @@ func (device *Device) RoutineDecryption() {
// check if dropped // check if dropped
if elem.IsDropped() { if elem.IsDropped() {
elem.mutex.Unlock() elem.mutex.Unlock() // TODO: Make consistent with send
continue continue
} }
@ -318,6 +326,7 @@ func (device *Device) RoutineHandshake() {
logError.Println("Failed to create cookie reply:", err) logError.Println("Failed to create cookie reply:", err)
return return
} }
// TODO: Use temp
writer := bytes.NewBuffer(elem.packet[:0]) writer := bytes.NewBuffer(elem.packet[:0])
binary.Write(writer, binary.LittleEndian, reply) binary.Write(writer, binary.LittleEndian, reply)
elem.packet = writer.Bytes() elem.packet = writer.Bytes()
@ -330,6 +339,8 @@ func (device *Device) RoutineHandshake() {
// ratelimit // ratelimit
// TODO: Only ratelimit when busy
if !device.ratelimiter.Allow(elem.source.IP) { if !device.ratelimiter.Allow(elem.source.IP) {
return return
} }
@ -364,9 +375,14 @@ func (device *Device) RoutineHandshake() {
) )
return return
} }
peer.TimerPacketReceived()
// update timers
peer.TimerAnyAuthenticatedPacketTraversal()
peer.TimerAnyAuthenticatedPacketReceived()
// update endpoint // update endpoint
// TODO: Add a race condition \s
peer.mutex.Lock() peer.mutex.Lock()
peer.endpoint = elem.source peer.endpoint = elem.source
@ -381,6 +397,7 @@ func (device *Device) RoutineHandshake() {
} }
peer.TimerEphemeralKeyCreated() peer.TimerEphemeralKeyCreated()
peer.NewKeyPair()
logDebug.Println("Creating response message for", peer.String()) logDebug.Println("Creating response message for", peer.String())
@ -392,8 +409,7 @@ func (device *Device) RoutineHandshake() {
// send response // send response
peer.SendBuffer(packet) peer.SendBuffer(packet)
peer.TimerPacketSent() peer.TimerAnyAuthenticatedPacketTraversal()
peer.NewKeyPair()
case MessageResponseType: case MessageResponseType:
@ -423,8 +439,14 @@ func (device *Device) RoutineHandshake() {
return return
} }
peer.TimerPacketReceived() // update timers
peer.TimerAnyAuthenticatedPacketTraversal()
peer.TimerAnyAuthenticatedPacketReceived()
peer.TimerHandshakeComplete() peer.TimerHandshakeComplete()
// derive key-pair
peer.NewKeyPair() peer.NewKeyPair()
peer.SendKeepAlive() peer.SendKeepAlive()
@ -467,8 +489,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
return return
} }
peer.TimerPacketReceived() peer.TimerAnyAuthenticatedPacketTraversal()
peer.TimerTransportReceived() peer.TimerAnyAuthenticatedPacketReceived()
peer.KeepKeyFreshReceiving() peer.KeepKeyFreshReceiving()
// check if using new key-pair // check if using new key-pair
@ -504,6 +526,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field) length := binary.BigEndian.Uint16(field)
// TODO: check length of packet & NOT TOO SMALL either
elem.packet = elem.packet[:length] elem.packet = elem.packet[:length]
// verify IPv4 source // verify IPv4 source
@ -525,6 +548,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field) length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen length += ipv6.HeaderLen
// TODO: check length of packet
elem.packet = elem.packet[:length] elem.packet = elem.packet[:length]
// verify IPv6 source // verify IPv6 source
@ -542,11 +566,13 @@ func (peer *Peer) RoutineSequentialReceiver() {
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
device.addToInboundQueue(device.queue.inbound, elem) device.addToInboundQueue(device.queue.inbound, elem)
// TODO: move TUN write into per peer routine
}() }()
} }
} }
func (device *Device) RoutineWriteToTUN(tun TUNDevice) { func (device *Device) RoutineWriteToTUN() {
logError := device.log.Error logError := device.log.Error
logDebug := device.log.Debug logDebug := device.log.Debug
@ -557,7 +583,7 @@ func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
case <-device.signal.stop: case <-device.signal.stop:
return return
case elem := <-device.queue.inbound: case elem := <-device.queue.inbound:
_, err := tun.Write(elem.packet) _, err := device.tun.Write(elem.packet)
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
if err != nil { if err != nil {
logError.Println("Failed to write packet to TUN device:", err) logError.Println("Failed to write packet to TUN device:", err)

View File

@ -110,17 +110,19 @@ func addToEncryptionQueue(
} }
func (peer *Peer) SendBuffer(buffer []byte) (int, error) { func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock()
peer.mutex.RLock() peer.mutex.RLock()
defer peer.mutex.RUnlock()
endpoint := peer.endpoint endpoint := peer.endpoint
peer.mutex.RUnlock() conn := peer.device.net.conn
if endpoint == nil { if endpoint == nil {
return 0, ErrorNoEndpoint return 0, ErrorNoEndpoint
} }
peer.device.net.mutex.RLock()
conn := peer.device.net.conn
peer.device.net.mutex.RUnlock()
if conn == nil { if conn == nil {
return 0, ErrorNoConnection return 0, ErrorNoConnection
} }
@ -133,13 +135,13 @@ func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
* *
* Obs. Single instance per TUN device * Obs. Single instance per TUN device
*/ */
func (device *Device) RoutineReadFromTUN(tun TUNDevice) { func (device *Device) RoutineReadFromTUN() {
if tun == nil { if device.tun == nil {
return return
} }
elem := device.NewOutboundElement() var elem *QueueOutboundElement
logDebug := device.log.Debug logDebug := device.log.Debug
logError := device.log.Error logError := device.log.Error
@ -153,32 +155,38 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
elem = device.NewOutboundElement() elem = device.NewOutboundElement()
} }
// TODO: THIS!
elem.packet = elem.buffer[MessageTransportHeaderSize:] elem.packet = elem.buffer[MessageTransportHeaderSize:]
size, err := tun.Read(elem.packet) size, err := device.tun.Read(elem.packet)
if err != nil { if err != nil {
// stop process
logError.Println("Failed to read packet from TUN device:", err) logError.Println("Failed to read packet from TUN device:", err)
device.Close() device.Close()
return return
} }
elem.packet = elem.packet[:size] if size == 0 {
if len(elem.packet) < ipv4.HeaderLen {
logError.Println("Packet too short, length:", size)
continue continue
} }
println(size, err)
elem.packet = elem.packet[:size]
// lookup peer // lookup peer
var peer *Peer var peer *Peer
switch elem.packet[0] >> 4 { switch elem.packet[0] >> 4 {
case ipv4.Version: case ipv4.Version:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst) peer = device.routingTable.LookupIPv4(dst)
case ipv6.Version: case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst) peer = device.routingTable.LookupIPv6(dst)
@ -190,10 +198,15 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
continue continue
} }
// check if known endpoint
peer.mutex.RLock()
if peer.endpoint == nil { if peer.endpoint == nil {
peer.mutex.RUnlock()
logDebug.Println("No known endpoint for peer", peer.String()) logDebug.Println("No known endpoint for peer", peer.String())
continue continue
} }
peer.mutex.RUnlock()
// insert into nonce/pre-handshake queue // insert into nonce/pre-handshake queue
@ -334,8 +347,12 @@ func (device *Device) RoutineEncryption() {
// pad content to MTU size // pad content to MTU size
mtu := int(atomic.LoadInt32(&device.mtu)) mtu := int(atomic.LoadInt32(&device.mtu))
for i := len(elem.packet); i < mtu; i++ { pad := len(elem.packet) % PaddingMultiple
elem.packet = append(elem.packet, 0) if pad > 0 {
for i := 0; i < PaddingMultiple-pad && len(elem.packet) < mtu; i++ {
elem.packet = append(elem.packet, 0)
}
// TODO: How good is this code
} }
// encrypt content (append to header) // encrypt content (append to header)
@ -390,7 +407,7 @@ func (peer *Peer) RoutineSequentialSender() {
// update timers // update timers
peer.TimerPacketSent() peer.TimerAnyAuthenticatedPacketTraversal()
if len(elem.packet) != MessageKeepaliveSize { if len(elem.packet) != MessageKeepaliveSize {
peer.TimerDataSent() peer.TimerDataSent()
} }

View File

@ -60,10 +60,8 @@ func (peer *Peer) SendKeepAlive() bool {
return true return true
} }
/* Authenticated data packet send /* Event:
* Always called together with peer.EventPacketSend * Sent non-empty (authenticated) transport message
*
* - Start new handshake timer
*/ */
func (peer *Peer) TimerDataSent() { func (peer *Peer) TimerDataSent() {
timerStop(peer.timer.keepalivePassive) timerStop(peer.timer.keepalivePassive)
@ -75,8 +73,6 @@ func (peer *Peer) TimerDataSent() {
/* Event: /* Event:
* Received non-empty (authenticated) transport message * Received non-empty (authenticated) transport message
*
* - Start passive keep-alive timer
*/ */
func (peer *Peer) TimerDataReceived() { func (peer *Peer) TimerDataReceived() {
if peer.timer.pendingKeepalivePassive { if peer.timer.pendingKeepalivePassive {
@ -88,17 +84,16 @@ func (peer *Peer) TimerDataReceived() {
} }
/* Event: /* Event:
* Any (authenticated) transport message received * Any (authenticated) packet received
* (keep-alive or data)
*/ */
func (peer *Peer) TimerTransportReceived() { func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
timerStop(peer.timer.newHandshake) timerStop(peer.timer.newHandshake)
} }
/* Event: /* Event:
* Any packet send to the peer. * Any authenticated packet send / received.
*/ */
func (peer *Peer) TimerPacketSent() { func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 { if interval > 0 {
duration := time.Duration(interval) * time.Second duration := time.Duration(interval) * time.Second
@ -106,13 +101,6 @@ func (peer *Peer) TimerPacketSent() {
} }
} }
/* Event:
* Any authenticated packet received from peer
*/
func (peer *Peer) TimerPacketReceived() {
peer.TimerPacketSent()
}
/* Called after succesfully completing a handshake. /* Called after succesfully completing a handshake.
* i.e. after: * i.e. after:
* *
@ -129,7 +117,9 @@ func (peer *Peer) TimerHandshakeComplete() {
peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
} }
/* Called whenever an ephemeral key is generated /* Event:
* An ephemeral key is generated
*
* i.e after: * i.e after:
* *
* CreateMessageInitiation * CreateMessageInitiation
@ -257,7 +247,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
select { select {
case <-peer.signal.handshakeBegin: case <-peer.signal.handshakeBegin:
signalSend(peer.signal.handshakeBegin)
case <-peer.signal.stop: case <-peer.signal.stop:
return return
} }
@ -303,7 +292,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
binary.Write(writer, binary.LittleEndian, msg) binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes() packet := writer.Bytes()
peer.mac.AddMacs(packet) peer.mac.AddMacs(packet)
peer.TimerPacketSent()
_, err = peer.SendBuffer(packet) _, err = peer.SendBuffer(packet)
if err != nil { if err != nil {
@ -314,6 +302,8 @@ func (peer *Peer) RoutineHandshakeInitiator() {
continue continue
} }
peer.TimerAnyAuthenticatedPacketTraversal()
// set timeout // set timeout
timeout := time.NewTimer(RekeyTimeout) timeout := time.NewTimer(RekeyTimeout)
@ -337,7 +327,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
continue continue
} }
} }
// allow new signal to be set // allow new signal to be set

View File

@ -32,11 +32,14 @@ type Trie struct {
/* Finds length of matching prefix /* Finds length of matching prefix
* TODO: Make faster * TODO: Make faster
* *
* Assumption: len(ip1) == len(ip2) * Assumption:
* len(ip1) == len(ip2)
* len(ip1) mod 4 = 0
*/ */
func commonBits(ip1 net.IP, ip2 net.IP) uint { func commonBits(ip1 []byte, ip2 []byte) uint {
var i uint var i uint
size := uint(len(ip1)) size := uint(len(ip1)) / 4
for i = 0; i < size; i++ { for i = 0; i < size; i++ {
v := ip1[i] ^ ip2[i] v := ip1[i] ^ ip2[i]
if v != 0 { if v != 0 {

View File

@ -9,6 +9,7 @@ const DefaultMTU = 1420
type TUNDevice interface { type TUNDevice interface {
Read([]byte) (int, error) // read a packet from the device (without any additional headers) Read([]byte) (int, error) // read a packet from the device (without any additional headers)
Write([]byte) (int, error) // writes a packet to the device (without any additional headers) Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
IsUp() (bool, error) // is the interface up?
MTU() (int, error) // returns the MTU of the device MTU() (int, error) // returns the MTU of the device
Name() string // returns the current name Name() string // returns the current name
} }

View File

@ -7,6 +7,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"net"
"os" "os"
"strings" "strings"
"unsafe" "unsafe"
@ -19,6 +20,11 @@ type NativeTun struct {
name string name string
} }
func (tun *NativeTun) IsUp() (bool, error) {
inter, err := net.InterfaceByName(tun.name)
return inter.Flags&net.FlagUp != 0, err
}
func (tun *NativeTun) Name() string { func (tun *NativeTun) Name() string {
return tun.name return tun.name
} }

View File

@ -11,13 +11,12 @@ import (
) )
const ( const (
ipcErrorIO = int64(unix.EIO) ipcErrorIO = -int64(unix.EIO)
ipcErrorNoPeer = int64(unix.EPROTO) ipcErrorNotDefined = -int64(unix.ENODEV)
ipcErrorNoKeyValue = int64(unix.EPROTO) ipcErrorProtocol = -int64(unix.EPROTO)
ipcErrorInvalidKey = int64(unix.EPROTO) ipcErrorInvalid = -int64(unix.EINVAL)
ipcErrorInvalidValue = int64(unix.EPROTO) socketDirectory = "/var/run/wireguard"
socketDirectory = "/var/run/wireguard" socketName = "%s.sock"
socketName = "%s.sock"
) )
/* TODO: /* TODO: