1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Rework of entire locking system

Locking on the Device instance is now much more fined-grained,
seperating out the fields into "resources" st. most common interactions
only require a small number.
This commit is contained in:
Mathias Hall-Andersen 2018-02-02 16:40:14 +01:00
parent 1e42b14022
commit 029410b118
10 changed files with 386 additions and 239 deletions

View File

@ -65,12 +65,12 @@ func unsafeCloseBind(device *Device) error {
}
func (device *Device) BindUpdate() error {
device.mutex.Lock()
defer device.mutex.Unlock()
netc := &device.net
netc.mutex.Lock()
defer netc.mutex.Unlock()
device.net.mutex.Lock()
defer device.net.mutex.Unlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// close existing sockets
@ -85,6 +85,7 @@ func (device *Device) BindUpdate() error {
// bind to new port
var err error
netc := &device.net
netc.bind, netc.port, err = CreateBind(netc.port)
if err != nil {
netc.bind = nil
@ -100,12 +101,12 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses
for _, peer := range device.peers {
for _, peer := range device.peers.keyMap {
peer.mutex.Lock()
defer peer.mutex.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.mutex.Unlock()
}
// start receiving routines
@ -120,10 +121,8 @@ func (device *Device) BindUpdate() error {
}
func (device *Device) BindClose() error {
device.mutex.Lock()
device.net.mutex.Lock()
err := unsafeCloseBind(device)
device.net.mutex.Unlock()
device.mutex.Unlock()
return err
}

View File

@ -9,46 +9,110 @@ import (
)
type Device struct {
isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard)
log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
fwMark uint32
tun struct {
device TUNDevice
mtu int32
}
isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard)
log *Logger
// synchronized resources (locks acquired in order)
state struct {
mutex deadlock.Mutex
changing AtomicBool
current bool
}
pool struct {
messageBuffers sync.Pool
}
net struct {
mutex deadlock.RWMutex
bind Bind // bind interface
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
}
mutex deadlock.RWMutex
privateKey NoisePrivateKey
publicKey NoisePublicKey
routingTable RoutingTable
indices IndexTable
queue struct {
noise struct {
mutex deadlock.RWMutex
privateKey NoisePrivateKey
publicKey NoisePublicKey
}
routing struct {
mutex deadlock.RWMutex
table RoutingTable
}
peers struct {
mutex deadlock.RWMutex
keyMap map[NoisePublicKey]*Peer
}
// unprotected / "self-synchronising resources"
indices IndexTable
mac CookieChecker
rate struct {
underLoadUntil atomic.Value
limiter Ratelimiter
}
pool struct {
messageBuffers sync.Pool
}
queue struct {
encryption chan *QueueOutboundElement
decryption chan *QueueInboundElement
handshake chan QueueHandshakeElement
}
signal struct {
stop Signal
}
underLoadUntil atomic.Value
ratelimiter Ratelimiter
peers map[NoisePublicKey]*Peer
mac CookieChecker
tun struct {
device TUNDevice
mtu int32
}
}
/* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table.
*
* Must hold:
* device.peers.mutex : exclusive lock
* device.routing : exclusive lock
*/
func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
// stop routing and processing of packets
device.routing.table.RemovePeer(peer)
peer.Stop()
// clean index table
kp := &peer.keyPairs
kp.mutex.Lock()
if kp.previous != nil {
device.indices.Delete(kp.previous.localIndex)
}
if kp.current != nil {
device.indices.Delete(kp.current.localIndex)
}
if kp.next != nil {
device.indices.Delete(kp.next.localIndex)
}
kp.previous = nil
kp.current = nil
kp.next = nil
kp.mutex.Unlock()
// remove from peer map
delete(device.peers.keyMap, key)
}
func deviceUpdateState(device *Device) {
@ -59,56 +123,56 @@ func deviceUpdateState(device *Device) {
return
}
// compare to current state of device
func() {
device.state.mutex.Lock()
// compare to current state of device
newIsUp := device.isUp.Get()
device.state.mutex.Lock()
defer device.state.mutex.Unlock()
if newIsUp == device.state.current {
device.state.mutex.Unlock()
newIsUp := device.isUp.Get()
if newIsUp == device.state.current {
device.state.changing.Set(false)
return
}
// change state of device
switch newIsUp {
case true:
if err := device.BindUpdate(); err != nil {
device.isUp.Set(false)
break
}
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
for _, peer := range device.peers.keyMap {
peer.Start()
}
case false:
device.BindClose()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
for _, peer := range device.peers.keyMap {
println("stopping peer")
peer.Stop()
}
}
// update state variables
device.state.current = newIsUp
device.state.changing.Set(false)
return
}
}()
device.state.mutex.Unlock()
// check for state change in the mean time
// change state of device
switch newIsUp {
case true:
// start listener
if err := device.BindUpdate(); err != nil {
device.isUp.Set(false)
break
}
// start every peer
for _, peer := range device.peers {
peer.Start()
}
case false:
// stop listening
device.BindClose()
// stop every peer
for _, peer := range device.peers {
peer.Stop()
}
}
// update state variables
// and check for state change in the mean time
device.state.current = newIsUp
device.state.changing.Set(false)
deviceUpdateState(device)
}
@ -133,18 +197,6 @@ func (device *Device) Down() {
deviceUpdateState(device)
}
/* Warning:
* The caller must hold the device mutex (write lock)
*/
func removePeerUnsafe(device *Device, key NoisePublicKey) {
peer, ok := device.peers[key]
if !ok {
return
}
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
}
func (device *Device) IsUnderLoad() bool {
// check if currently under load
@ -152,54 +204,66 @@ func (device *Device) IsUnderLoad() bool {
now := time.Now()
underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
if underLoad {
device.underLoadUntil.Store(now.Add(time.Second))
device.rate.underLoadUntil.Store(now.Add(time.Second))
return true
}
// check if recently under load
until := device.underLoadUntil.Load().(time.Time)
until := device.rate.underLoadUntil.Load().(time.Time)
return until.After(now)
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
device.mutex.Lock()
defer device.mutex.Unlock()
// lock required resources
device.noise.mutex.Lock()
defer device.noise.mutex.Unlock()
device.routing.mutex.Lock()
defer device.routing.mutex.Unlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
for _, peer := range device.peers.keyMap {
peer.handshake.mutex.RLock()
defer peer.handshake.mutex.RUnlock()
}
// remove peers with matching public keys
publicKey := sk.publicKey()
for key, peer := range device.peers {
h := &peer.handshake
h.mutex.RLock()
if h.remoteStatic.Equals(publicKey) {
removePeerUnsafe(device, key)
for key, peer := range device.peers.keyMap {
if peer.handshake.remoteStatic.Equals(publicKey) {
unsafeRemovePeer(device, peer, key)
}
h.mutex.RUnlock()
}
// update key material
device.privateKey = sk
device.publicKey = publicKey
device.noise.privateKey = sk
device.noise.publicKey = publicKey
device.mac.Init(publicKey)
// do DH pre-computations
// do static-static DH pre-computations
rmKey := device.privateKey.IsZero()
rmKey := device.noise.privateKey.IsZero()
for key, peer := range device.peers.keyMap {
hs := &peer.handshake
for key, peer := range device.peers {
h := &peer.handshake
h.mutex.Lock()
if rmKey {
h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
hs.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
if isZero(h.precomputedStaticStatic[:]) {
removePeerUnsafe(device, key)
}
hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic)
}
if isZero(hs.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key)
}
h.mutex.Unlock()
}
return nil
@ -215,21 +279,23 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
func NewDevice(tun TUNDevice, logger *Logger) *Device {
device := new(Device)
device.mutex.Lock()
defer device.mutex.Unlock()
device.isUp.Set(false)
device.isClosed.Set(false)
device.log = logger
device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
// initialize anti-DoS / anti-scanning features
device.rate.limiter.Init()
device.rate.underLoadUntil.Store(time.Time{})
// initialize noise & crypt-key routine
device.indices.Init()
device.ratelimiter.Init()
device.routingTable.Reset()
device.underLoadUntil.Store(time.Time{})
device.routing.table.Reset()
// setup buffer pool
@ -264,36 +330,50 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
go device.rate.limiter.RoutineGarbageCollector(device.signal.stop)
return device
}
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.mutex.RLock()
defer device.mutex.RUnlock()
return device.peers[pk]
device.peers.mutex.RLock()
defer device.peers.mutex.RUnlock()
return device.peers.keyMap[pk]
}
func (device *Device) RemovePeer(key NoisePublicKey) {
device.mutex.Lock()
defer device.mutex.Unlock()
removePeerUnsafe(device, key)
device.noise.mutex.Lock()
defer device.noise.mutex.Unlock()
device.routing.mutex.Lock()
defer device.routing.mutex.Unlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// stop peer and remove from routing
peer, ok := device.peers.keyMap[key]
if ok {
unsafeRemovePeer(device, peer, key)
}
}
func (device *Device) RemoveAllPeers() {
device.mutex.Lock()
defer device.mutex.Unlock()
for key, peer := range device.peers {
peer.Stop()
peer, ok := device.peers[key]
if !ok {
return
}
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
device.routing.mutex.Lock()
defer device.routing.mutex.Unlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
for key, peer := range device.peers.keyMap {
println("rm", peer.String())
unsafeRemovePeer(device, peer, key)
}
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}
func (device *Device) Close() {
@ -305,7 +385,6 @@ func (device *Device) Close() {
device.tun.device.Close()
device.BindClose()
device.isUp.Set(false)
println("remove")
device.RemoveAllPeers()
device.log.Info.Println("Interface closed")
}

View File

@ -3,6 +3,7 @@ package main
import (
"crypto/hmac"
"crypto/rand"
"crypto/subtle"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/curve25519"
"hash"
@ -58,11 +59,11 @@ func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
}
func isZero(val []byte) bool {
var acc byte
acc := 1
for _, b := range val {
acc |= b
acc &= subtle.ConstantTimeByteEq(b, 0)
}
return acc == 0
return acc == 1
}
func setZero(arr []byte) {

View File

@ -137,6 +137,10 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
device.noise.mutex.Lock()
defer device.noise.mutex.Unlock()
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
@ -187,7 +191,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
ss[:],
)
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:])
}()
handshake.mixHash(msg.Static[:])
@ -212,16 +216,19 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
}
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if msg.Type != MessageInitiationType {
return nil
}
var (
hash [blake2s.Size]byte
chainKey [blake2s.Size]byte
)
mixHash(&hash, &InitialHash, device.publicKey[:])
if msg.Type != MessageInitiationType {
return nil
}
device.noise.mutex.RLock()
defer device.noise.mutex.RUnlock()
mixHash(&hash, &InitialHash, device.noise.publicKey[:])
mixHash(&hash, &hash, msg.Ephemeral[:])
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
@ -231,7 +238,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
var peerPK NoisePublicKey
func() {
var key [chacha20poly1305.KeySize]byte
ss := device.privateKey.sharedSecret(msg.Ephemeral)
ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
@ -407,7 +414,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
}()
func() {
ss := device.privateKey.sharedSecret(msg.Ephemeral)
ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}()

View File

@ -14,7 +14,6 @@ const (
)
type Peer struct {
id uint
isRunning AtomicBool
mutex deadlock.RWMutex
persistentKeepaliveInterval uint64
@ -22,17 +21,20 @@ type Peer struct {
handshake Handshake
device *Device
endpoint Endpoint
stats struct {
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch
}
time struct {
mutex deadlock.RWMutex
lastSend time.Time // last send message
lastHandshake time.Time // last completed handshake
nextKeepalive time.Time
}
signal struct {
newKeyPair Signal // size 1, new key pair was generated
handshakeCompleted Signal // size 1, handshake completed
@ -41,7 +43,9 @@ type Peer struct {
messageSend Signal // size 1, message was send to peer
messageReceived Signal // size 1, authenticated message recv
}
timer struct {
// state related to WireGuard timers
keepalivePersistent Timer // set for persistent keepalives
@ -54,17 +58,20 @@ type Peer struct {
sendLastMinuteHandshake bool
needAnotherKeepalive bool
}
queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work
inbound chan *QueueInboundElement // sequential ordering of work
}
routines struct {
mutex deadlock.Mutex // held when stopping / starting routines
starting sync.WaitGroup // routines pending start
stopping sync.WaitGroup // routines pending stop
stop Signal // size 0, stop all goroutines in peer
}
mac CookieGenerator
}
@ -74,8 +81,22 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return nil, errors.New("Device closed")
}
device.mutex.Lock()
defer device.mutex.Unlock()
// lock resources
device.state.mutex.Lock()
defer device.state.mutex.Unlock()
device.noise.mutex.RLock()
defer device.noise.mutex.RUnlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// check if over limit
if len(device.peers.keyMap) >= MaxPeers {
return nil, errors.New("Too many peers")
}
// create peer
@ -94,32 +115,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.timer.handshakeDeadline = NewTimer()
peer.timer.handshakeTimeout = NewTimer()
// assign id for debugging
peer.id = device.idCounter
device.idCounter += 1
// check if over limit
if len(device.peers) >= MaxPeers {
return nil, errors.New("Too many peers")
}
// map public key
_, ok := device.peers[pk]
_, ok := device.peers.keyMap[pk]
if ok {
return nil, errors.New("Adding existing peer")
}
device.peers[pk] = peer
device.peers.keyMap[pk] = peer
// precompute DH
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.remoteStatic = pk
handshake.precomputedStaticStatic =
device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk)
handshake.mutex.Unlock()
// reset endpoint
@ -134,11 +143,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// start peer
peer.device.state.mutex.Lock()
if peer.device.isUp.Get() {
peer.Start()
}
peer.device.state.mutex.Unlock()
return peer, nil
}
@ -166,14 +173,12 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
func (peer *Peer) String() string {
if peer.endpoint == nil {
return fmt.Sprintf(
"peer(%d unknown %s)",
peer.id,
"peer(unknown %s)",
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}
return fmt.Sprintf(
"peer(%d %s %s)",
peer.id,
"peer(%s %s)",
peer.endpoint.DstToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
@ -181,8 +186,12 @@ func (peer *Peer) String() string {
func (peer *Peer) Start() {
if peer.device.isClosed.Get() {
return
}
peer.routines.mutex.Lock()
defer peer.routines.mutex.Lock()
defer peer.routines.mutex.Unlock()
peer.device.log.Debug.Println("Starting:", peer.String())
@ -222,7 +231,7 @@ func (peer *Peer) Start() {
func (peer *Peer) Stop() {
peer.routines.mutex.Lock()
defer peer.routines.mutex.Lock()
defer peer.routines.mutex.Unlock()
peer.device.log.Debug.Println("Stopping:", peer.String())

View File

@ -372,7 +372,7 @@ func (device *Device) RoutineHandshake() {
// check ratelimiter
if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
continue
}
}
@ -495,19 +495,23 @@ func (device *Device) RoutineHandshake() {
func (peer *Peer) RoutineSequentialReceiver() {
defer peer.routines.stopping.Done()
device := peer.device
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
logDebug.Println("Routine, sequential receiver, started for peer", peer.String())
peer.routines.starting.Done()
for {
select {
case <-peer.routines.stop.Wait():
logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
logDebug.Println("Routine, sequential receiver, stopped for peer", peer.String())
return
case elem := <-peer.queue.inbound:
@ -581,7 +585,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer {
if device.routing.table.LookupIPv4(src) != peer {
logInfo.Println(
"IPv4 packet with disallowed source address from",
peer.String(),
@ -609,7 +613,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer {
if device.routing.table.LookupIPv6(src) != peer {
logInfo.Println(
"IPv6 packet with disallowed source address from",
peer.String(),

View File

@ -151,14 +151,14 @@ func (device *Device) RoutineReadFromTUN() {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst)
peer = device.routing.table.LookupIPv4(dst)
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst)
peer = device.routing.table.LookupIPv6(dst)
default:
logDebug.Println("Received packet with unknown IP version")
@ -187,10 +187,14 @@ func (device *Device) RoutineReadFromTUN() {
func (peer *Peer) RoutineNonce() {
var keyPair *KeyPair
defer peer.routines.stopping.Done()
device := peer.device
logDebug := device.log.Debug
logDebug.Println("Routine, nonce worker, started for peer", peer.String())
peer.routines.starting.Done()
for {
NextPacket:
select {

View File

@ -303,7 +303,7 @@ func (peer *Peer) RoutineTimerHandler() {
err := peer.sendNewHandshake()
if err != nil {
logInfo.Println(
"Failed to send handshake to peer:", peer.String())
"Failed to send handshake to peer:", peer.String(), "(", err, ")")
}
case <-peer.timer.handshakeDeadline.Wait():
@ -326,7 +326,7 @@ func (peer *Peer) RoutineTimerHandler() {
err := peer.sendNewHandshake()
if err != nil {
logInfo.Println(
"Failed to send handshake to peer:", peer.String())
"Failed to send handshake to peer:", peer.String(), "(", err, ")")
}
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)

View File

@ -313,7 +313,7 @@ func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) {
}
go device.RoutineNetlinkListener()
go device.RoutineHackListener() // cross namespace
// go device.RoutineHackListener() // cross namespace
// set default MTU
@ -369,7 +369,7 @@ func CreateTUN(name string) (TUNDevice, error) {
}
go device.RoutineNetlinkListener()
go device.RoutineHackListener() // cross namespace
// go device.RoutineHackListener() // cross namespace
// set default MTU

View File

@ -25,32 +25,51 @@ func (s *IPCError) ErrorCode() int64 {
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// create lines
device.log.Debug.Println("UAPI: Processing get operation")
device.mutex.RLock()
device.net.mutex.RLock()
// create lines
lines := make([]string, 0, 100)
send := func(line string) {
lines = append(lines, line)
}
if !device.privateKey.IsZero() {
send("private_key=" + device.privateKey.ToHex())
}
func() {
if device.net.port != 0 {
send(fmt.Sprintf("listen_port=%d", device.net.port))
}
// lock required resources
if device.net.fwmark != 0 {
send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
}
device.net.mutex.RLock()
defer device.net.mutex.RUnlock()
for _, peer := range device.peers {
func() {
device.noise.mutex.RLock()
defer device.noise.mutex.RUnlock()
device.routing.mutex.RLock()
defer device.routing.mutex.RUnlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// serialize device related values
if !device.noise.privateKey.IsZero() {
send("private_key=" + device.noise.privateKey.ToHex())
}
if device.net.port != 0 {
send(fmt.Sprintf("listen_port=%d", device.net.port))
}
if device.net.fwmark != 0 {
send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
}
// serialize each peer state
for _, peer := range device.peers.keyMap {
peer.mutex.RLock()
defer peer.mutex.RUnlock()
send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
if peer.endpoint != nil {
@ -69,16 +88,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
atomic.LoadUint64(&peer.persistentKeepaliveInterval),
))
for _, ip := range device.routingTable.AllowedIPs(peer) {
for _, ip := range device.routing.table.AllowedIPs(peer) {
send("allowed_ip=" + ip.String())
}
}()
}
device.net.mutex.RUnlock()
device.mutex.RUnlock()
}
}()
// send lines
// send lines (does not require resource locks)
for _, line := range lines {
_, err := socket.WriteString(line + "\n")
@ -94,7 +111,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
scanner := bufio.NewScanner(socket)
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
@ -130,6 +146,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalid}
}
logDebug.Println("UAPI: Updating device private key")
device.SetPrivateKey(sk)
case "listen_port":
@ -144,6 +161,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update port and rebind
logDebug.Println("UAPI: Updating listen port")
device.net.mutex.Lock()
device.net.port = uint16(port)
device.net.mutex.Unlock()
@ -170,6 +189,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid}
}
logDebug.Println("UAPI: Updating fwmark")
device.net.mutex.Lock()
device.net.fwmark = uint32(fwmark)
device.net.mutex.Unlock()
@ -181,6 +202,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "public_key":
// switch to peer configuration
logDebug.Println("UAPI: Transition to peer configuration")
deviceConfig = false
case "replace_peers":
@ -188,6 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
logDebug.Println("UAPI: Removing all peers")
device.RemoveAllPeers()
default:
@ -203,43 +226,41 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
switch key {
case "public_key":
var pubKey NoisePublicKey
err := pubKey.FromHex(value)
var publicKey NoisePublicKey
err := publicKey.FromHex(value)
if err != nil {
logError.Println("Failed to get peer by public_key:", err)
return &IPCError{Code: ipcErrorInvalid}
}
// check if public key of peer equal to device
// ignore peer with public key of device
device.mutex.RLock()
if device.publicKey.Equals(pubKey) {
// create dummy instance (not added to device)
device.noise.mutex.RLock()
equals := device.noise.publicKey.Equals(publicKey)
device.noise.mutex.RUnlock()
if equals {
peer = &Peer{}
dummy = true
device.mutex.RUnlock()
logInfo.Println("Ignoring peer with public key of device")
} else {
// find peer referenced
peer, _ = device.peers[pubKey]
device.mutex.RUnlock()
if peer == nil {
peer, err = device.NewPeer(pubKey)
if err != nil {
logError.Println("Failed to create new peer:", err)
return &IPCError{Code: ipcErrorInvalid}
}
}
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
dummy = false
}
// find peer referenced
peer = device.LookupPeer(publicKey)
if peer == nil {
peer, err = device.NewPeer(publicKey)
if err != nil {
logError.Println("Failed to create new peer:", err)
return &IPCError{Code: ipcErrorInvalid}
}
logDebug.Println("UAPI: Created new peer:", peer.String())
}
peer.mutex.Lock()
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
peer.mutex.Unlock()
case "remove":
// remove currently selected peer from device
@ -249,7 +270,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid}
}
if !dummy {
logDebug.Println("Removing", peer.String())
logDebug.Println("UAPI: Removing peer:", peer.String())
device.RemovePeer(peer.handshake.remoteStatic)
}
peer = &Peer{}
@ -259,9 +280,12 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update PSK
peer.mutex.Lock()
logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String())
peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
peer.mutex.Unlock()
peer.handshake.mutex.Unlock()
if err != nil {
logError.Println("Failed to set preshared_key:", err)
return &IPCError{Code: ipcErrorInvalid}
@ -271,6 +295,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// set endpoint destination
logDebug.Println("UAPI: Updating endpoint for peer:", peer.String())
err := func() error {
peer.mutex.Lock()
defer peer.mutex.Unlock()
@ -292,6 +318,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update keep-alive interval
logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String())
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
logError.Println("Failed to set persistent_keepalive_interval:", err)
@ -316,25 +344,41 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
case "replace_allowed_ips":
logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String())
if value != "true" {
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
if !dummy {
device.routingTable.RemovePeer(peer)
if dummy {
continue
}
device.routing.mutex.Lock()
device.routing.table.RemovePeer(peer)
device.routing.mutex.Unlock()
case "allowed_ip":
logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String())
_, network, err := net.ParseCIDR(value)
if err != nil {
logError.Println("Failed to set allowed_ip:", err)
return &IPCError{Code: ipcErrorInvalid}
}
ones, _ := network.Mask.Size()
if !dummy {
device.routingTable.Insert(network.IP, uint(ones), peer)
if dummy {
continue
}
ones, _ := network.Mask.Size()
device.routing.mutex.Lock()
device.routing.table.Insert(network.IP, uint(ones), peer)
device.routing.mutex.Unlock()
default:
logError.Println("Invalid UAPI key (peer configuration):", key)
return &IPCError{Code: ipcErrorInvalid}