mirror of https://git.zx2c4.com/wireguard-go

Fixed deadlock in index.go

Mathias Hall-Andersen 2017-07-17 16:16:18 +02:00
parent dd4da93749
commit c5d7efc246
8 changed files with 194 additions and 152 deletions

View File

@@ -8,39 +8,36 @@ import (
     "net"
     "strconv"
     "strings"
+    "sync/atomic"
+    "syscall"
 )
 
-// #include <errno.h>
-import "C"
-
-/* TODO: More fine grained?
- */
-
 const (
-    ipcErrorNoPeer       = C.EPROTO
-    ipcErrorNoKeyValue   = C.EPROTO
-    ipcErrorInvalidKey   = C.EPROTO
-    ipcErrorInvalidValue = C.EPROTO
+    ipcErrorIO           = syscall.EIO
+    ipcErrorNoPeer       = syscall.EPROTO
+    ipcErrorNoKeyValue   = syscall.EPROTO
+    ipcErrorInvalidKey   = syscall.EPROTO
+    ipcErrorInvalidValue = syscall.EPROTO
 )
 
 type IPCError struct {
-    Code int
+    Code syscall.Errno
 }
 
 func (s *IPCError) Error() string {
     return fmt.Sprintf("IPC error: %d", s.Code)
 }
 
-func (s *IPCError) ErrorCode() int {
-    return s.Code
+func (s *IPCError) ErrorCode() uintptr {
+    return uintptr(s.Code)
 }
 
-func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
-    device.mutex.RLock()
-    defer device.mutex.RUnlock()
-
+func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
     // create lines
 
+    device.mutex.RLock()
+
     lines := make([]string, 0, 100)
     send := func(line string) {
         lines = append(lines, line)
@@ -63,19 +60,25 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
             }
             send(fmt.Sprintf("tx_bytes=%d", peer.txBytes))
             send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes))
-            send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
+            send(fmt.Sprintf("persistent_keepalive_interval=%d",
+                atomic.LoadUint64(&peer.persistentKeepaliveInterval),
+            ))
             for _, ip := range device.routingTable.AllowedIPs(peer) {
                 send("allowed_ip=" + ip.String())
             }
         }()
     }
 
+    device.mutex.RUnlock()
+
     // send lines
 
     for _, line := range lines {
         _, err := socket.WriteString(line + "\n")
         if err != nil {
-            return err
+            return &IPCError{
+                Code: ipcErrorIO,
+            }
         }
     }
@@ -83,13 +86,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
 }
 
 func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
-    logger := device.log.Debug
     scanner := bufio.NewScanner(socket)
+    logError := device.log.Error
+    logDebug := device.log.Debug
 
     var peer *Peer
     for scanner.Scan() {
 
-        // Parse line
+        // parse line
 
         line := scanner.Text()
         if line == "" {
@@ -97,7 +101,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
         }
         parts := strings.Split(line, "=")
         if len(parts) != 2 {
-            device.log.Debug.Println(parts)
             return &IPCError{Code: ipcErrorNoKeyValue}
         }
         key := parts[0]
@@ -105,7 +108,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
         switch key {
 
-        /* Interface configuration */
+        /* interface configuration */
 
         case "private_key":
             if value == "" {
@@ -116,7 +119,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
             var sk NoisePrivateKey
             err := sk.FromHex(value)
             if err != nil {
-                logger.Println("Failed to set private_key:", err)
+                logError.Println("Failed to set private_key:", err)
                 return &IPCError{Code: ipcErrorInvalidValue}
             }
             device.SetPrivateKey(sk)
@@ -126,22 +129,26 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
             var port int
             _, err := fmt.Sscanf(value, "%d", &port)
             if err != nil || port > (1<<16) || port < 0 {
-                logger.Println("Failed to set listen_port:", err)
+                logError.Println("Failed to set listen_port:", err)
                 return &IPCError{Code: ipcErrorInvalidValue}
             }
             device.net.mutex.Lock()
             device.net.addr.Port = port
             device.net.conn, err = net.ListenUDP("udp", device.net.addr)
             device.net.mutex.Unlock()
+            if err != nil {
+                logError.Println("Failed to create UDP listener:", err)
+                return &IPCError{Code: ipcErrorInvalidValue}
+            }
 
         case "fwmark":
-            logger.Println("FWMark not handled yet")
+            logError.Println("FWMark not handled yet")
 
         case "public_key":
             var pubKey NoisePublicKey
             err := pubKey.FromHex(value)
             if err != nil {
-                logger.Println("Failed to get peer by public_key:", err)
+                logError.Println("Failed to get peer by public_key:", err)
                 return &IPCError{Code: ipcErrorInvalidValue}
             }
             device.mutex.RLock()
@@ -153,22 +160,23 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                 peer = device.NewPeer(pubKey)
             }
             if peer == nil {
-                panic(errors.New("bug: failed to find peer"))
+                panic(errors.New("bug: failed to find / create peer"))
             }
 
         case "replace_peers":
             if value == "true" {
                 device.RemoveAllPeers()
             } else {
-                logger.Println("Failed to set replace_peers, invalid value:", value)
+                logError.Println("Failed to set replace_peers, invalid value:", value)
                 return &IPCError{Code: ipcErrorInvalidValue}
             }
 
         default:
-            /* Peer configuration */
+
+            /* peer configuration */
+
             if peer == nil {
-                logger.Println("No peer referenced, before peer operation")
+                logError.Println("No peer referenced, before peer operation")
                 return &IPCError{Code: ipcErrorNoPeer}
             }
@@ -178,7 +186,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                 peer.mutex.Lock()
                 device.RemovePeer(peer.handshake.remoteStatic)
                 peer.mutex.Unlock()
-                logger.Println("Remove peer")
+                logDebug.Println("Removing", peer.String())
                 peer = nil
 
             case "preshared_key":
@@ -188,14 +196,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                     return peer.handshake.presharedKey.FromHex(value)
                 }()
                 if err != nil {
-                    logger.Println("Failed to set preshared_key:", err)
+                    logError.Println("Failed to set preshared_key:", err)
                     return &IPCError{Code: ipcErrorInvalidValue}
                 }
 
             case "endpoint":
                 addr, err := net.ResolveUDPAddr("udp", value)
                 if err != nil {
-                    logger.Println("Failed to set endpoint:", value)
+                    logError.Println("Failed to set endpoint:", value)
                     return &IPCError{Code: ipcErrorInvalidValue}
                 }
                 peer.mutex.Lock()
@@ -205,35 +213,34 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
             case "persistent_keepalive_interval":
                 secs, err := strconv.ParseInt(value, 10, 64)
                 if secs < 0 || err != nil {
-                    logger.Println("Failed to set persistent_keepalive_interval:", err)
+                    logError.Println("Failed to set persistent_keepalive_interval:", err)
                     return &IPCError{Code: ipcErrorInvalidValue}
                 }
-                peer.mutex.Lock()
-                peer.persistentKeepaliveInterval = uint64(secs)
-                peer.mutex.Unlock()
+                atomic.StoreUint64(
+                    &peer.persistentKeepaliveInterval,
+                    uint64(secs),
+                )
 
             case "replace_allowed_ips":
                 if value == "true" {
                     device.routingTable.RemovePeer(peer)
                 } else {
-                    logger.Println("Failed to set replace_allowed_ips, invalid value:", value)
+                    logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
                     return &IPCError{Code: ipcErrorInvalidValue}
                 }
 
             case "allowed_ip":
                 _, network, err := net.ParseCIDR(value)
                 if err != nil {
-                    logger.Println("Failed to set allowed_ip:", err)
+                    logError.Println("Failed to set allowed_ip:", err)
                     return &IPCError{Code: ipcErrorInvalidValue}
                 }
                 ones, _ := network.Mask.Size()
-                logger.Println(network, ones, network.IP)
+                logError.Println(network, ones, network.IP)
                 device.routingTable.Insert(network.IP, uint(ones), peer)
 
-            /* Invalid key */
             default:
-                logger.Println("Invalid key:", key)
+                logError.Println("Invalid UAPI key:", key)
                 return &IPCError{Code: ipcErrorInvalidKey}
             }
         }
@@ -244,7 +251,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
 func ipcHandle(device *Device, socket net.Conn) {
-    func() {
+    defer socket.Close()
+
     buffered := func(s io.ReadWriter) *bufio.ReadWriter {
         reader := bufio.NewReader(s)
         writer := bufio.NewWriter(s)
@@ -268,22 +276,20 @@ func ipcHandle(device *Device, socket net.Conn) {
         } else {
             fmt.Fprintf(buffered, "errno=0\n\n")
         }
-        break
+        return
 
     case "get=1\n":
         device.log.Debug.Println("Config, get operation")
         err := ipcGetOperation(device, buffered)
         if err != nil {
-            fmt.Fprintf(buffered, "errno=1\n\n") // fix
+            fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
         } else {
             fmt.Fprintf(buffered, "errno=0\n\n")
         }
-        break
+        return
 
     default:
-        device.log.Info.Println("Invalid UAPI operation:", op)
-        }
-    }()
-    socket.Close()
+        device.log.Error.Println("Invalid UAPI operation:", op)
+    }
 }
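
For reference, a minimal sketch of how an *IPCError built on syscall.Errno feeds the errno=... reply written above. This is not the project's code; it only assumes the names shown in the diff and Linux errno values:

package main

import (
    "bufio"
    "fmt"
    "os"
    "syscall"
)

// IPCError mirrors the type above: a UAPI failure that carries a concrete errno.
type IPCError struct {
    Code syscall.Errno
}

func (s *IPCError) Error() string      { return fmt.Sprintf("IPC error: %d", s.Code) }
func (s *IPCError) ErrorCode() uintptr { return uintptr(s.Code) }

// reply writes the UAPI status line: the errno value on failure, 0 on success.
func reply(w *bufio.Writer, err *IPCError) {
    if err != nil {
        fmt.Fprintf(w, "errno=%d\n\n", err.ErrorCode())
    } else {
        fmt.Fprintf(w, "errno=0\n\n")
    }
    w.Flush()
}

func main() {
    w := bufio.NewWriter(os.Stdout)
    reply(w, &IPCError{Code: syscall.EPROTO}) // prints "errno=71" on Linux
    reply(w, nil)                             // prints "errno=0"
}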

View File

@@ -78,7 +78,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
     defer device.mutex.Unlock()
 
     device.log = NewLogger(logLevel)
-    // device.mtu = tun.MTU()
     device.peers = make(map[NoisePublicKey]*Peer)
     device.indices.Init()
     device.ratelimiter.Init()
@@ -131,12 +130,21 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
 
 func (device *Device) RoutineMTUUpdater(tun TUNDevice) {
     logError := device.log.Error
-    for ; ; time.Sleep(time.Second) {
+    for ; ; time.Sleep(5 * time.Second) {
+
+        // load updated MTU
+
         mtu, err := tun.MTU()
         if err != nil {
             logError.Println("Failed to load updated MTU of device:", err)
             continue
         }
+
+        // upper bound of mtu
+
+        if mtu+MessageTransportSize > MaxMessageSize {
+            mtu = MaxMessageSize - MessageTransportSize
+        }
         atomic.StoreInt32(&device.mtu, int32(mtu))
     }
 }
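
A small, self-contained sketch of the MTU clamp introduced above. The two constants here are placeholders standing in for the repository's real message-size constants, which are not shown in this diff:

package main

import (
    "fmt"
    "sync/atomic"
)

const (
    MessageTransportSize = 32   // assumed per-packet transport overhead
    MaxMessageSize       = 4096 // assumed upper bound on one UDP payload
)

func main() {
    var deviceMTU int32

    mtu := 9000 // e.g. a jumbo-frame MTU reported by the TUN device
    if mtu+MessageTransportSize > MaxMessageSize {
        mtu = MaxMessageSize - MessageTransportSize
    }

    // publish the clamped value the same way the updater does
    atomic.StoreInt32(&deviceMTU, int32(mtu))
    fmt.Println("clamped mtu:", atomic.LoadInt32(&deviceMTU))
}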

View File

@@ -6,8 +6,6 @@ import (
 )
 
 /* Index=0 is reserved for unset indecies
- *
- * TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer
  *
  */
@@ -72,12 +70,12 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
 
         table.mutex.RLock()
         _, ok := table.table[index]
+        table.mutex.RUnlock()
         if ok {
             continue
         }
-        table.mutex.RUnlock()
 
-        // replace index
+        // map index to handshake
 
         table.mutex.Lock()
         _, found := table.table[index]
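
This hunk is the deadlock fix named in the commit message. A reduced sketch of the pattern, with IndexTable trimmed to the relevant fields and the random index replaced by a counter; the point is that the read lock must be released before the continue, otherwise a later Lock() in the same loop can never be acquired:

package main

import (
    "fmt"
    "sync"
)

type IndexTable struct {
    mutex sync.RWMutex
    table map[uint32]string
}

// NewIndex picks an unused index and claims it. The pre-fix code reached
// `continue` while still holding the read lock, so the write lock below
// eventually blocked forever; unlocking right after the lookup avoids that.
func (t *IndexTable) NewIndex(value string) uint32 {
    for index := uint32(1); ; index++ {
        t.mutex.RLock()
        _, ok := t.table[index]
        t.mutex.RUnlock() // release before continuing or taking the write lock

        if ok {
            continue
        }

        t.mutex.Lock()
        if _, found := t.table[index]; !found {
            t.table[index] = value
            t.mutex.Unlock()
            return index
        }
        t.mutex.Unlock() // lost the race, try another index
    }
}

func main() {
    t := &IndexTable{table: map[uint32]string{1: "taken"}}
    fmt.Println(t.NewIndex("peer-a")) // prints 2
}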

View File

@@ -17,12 +17,14 @@ func main() {
     }
 
     switch os.Args[1] {
+
     case "-f", "--foreground":
         foreground = true
         if len(os.Args) != 3 {
             return
         }
         interfaceName = os.Args[2]
+
     default:
         foreground = false
         if len(os.Args) != 2 {
@@ -48,8 +50,8 @@ func main() {
     // open TUN device
 
     tun, err := CreateTUN(interfaceName)
-    log.Println(tun, err)
     if err != nil {
+        log.Println("Failed to create tun device:", err)
         return
     }
@@ -69,11 +71,15 @@ func main() {
     }
     defer uapi.Close()
 
+    go func() {
         for {
             conn, err := uapi.Accept()
             if err != nil {
-                logError.Fatal("accept error:", err)
+                logError.Fatal("UAPI accept error:", err)
             }
             go ipcHandle(device, conn)
         }
+    }()
+
+    device.Wait()
 }
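
A hedged sketch of the restructured main(): the blocking accept loop moves into its own goroutine so the main goroutine can block until shutdown. A loopback TCP listener and a plain channel stand in here for the UAPI socket and device.Wait():

package main

import (
    "fmt"
    "log"
    "net"
)

// handle plays the role of ipcHandle: answer and close the connection.
func handle(conn net.Conn) {
    defer conn.Close()
    fmt.Fprintf(conn, "errno=0\n\n")
}

func main() {
    uapi, err := net.Listen("tcp", "127.0.0.1:0")
    if err != nil {
        log.Fatal(err)
    }
    defer uapi.Close()

    wait := make(chan struct{}) // stands in for device.Wait()

    go func() {
        for {
            conn, err := uapi.Accept()
            if err != nil {
                log.Println("UAPI accept error:", err)
                close(wait)
                return
            }
            go handle(conn)
        }
    }()

    // exercise the listener once, then shut it down
    c, err := net.Dial("tcp", uapi.Addr().String())
    if err != nil {
        log.Fatal(err)
    }
    buf := make([]byte, 16)
    n, _ := c.Read(buf)
    fmt.Printf("reply: %q\n", buf[:n])
    c.Close()

    uapi.Close() // makes Accept fail, which releases the wait below
    <-wait
}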

View File

@@ -459,7 +459,8 @@ func (peer *Peer) NewKeyPair() *KeyPair {
     // remap index
 
-    peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
+    indices := &peer.device.indices
+    indices.Insert(handshake.localIndex, IndexTableEntry{
         peer:      peer,
         keyPair:   keyPair,
         handshake: nil,
@@ -476,7 +477,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
     if kp.previous != nil {
         kp.previous.send = nil
         kp.previous.receive = nil
-        peer.device.indices.Delete(kp.previous.localIndex)
+        indices.Delete(kp.previous.localIndex)
     }
     kp.previous = kp.current
     kp.current = keyPair
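
A reduced sketch of the rotation this alias serves (kp.next and the crypto state omitted): the superseded previous entry is unmapped from the index table before the key-pair slots shift down:

package main

import "fmt"

type KeyPair struct {
    localIndex uint32
}

type KeyPairs struct {
    previous, current *KeyPair
}

// rotate maps the fresh key pair and drops the stale mapping of the old one.
func rotate(kp *KeyPairs, indices map[uint32]*KeyPair, fresh *KeyPair) {
    indices[fresh.localIndex] = fresh
    if kp.previous != nil {
        delete(indices, kp.previous.localIndex) // unmap before overwriting the slot
    }
    kp.previous = kp.current
    kp.current = fresh
}

func main() {
    indices := map[uint32]*KeyPair{}
    kp := &KeyPairs{}

    for i := uint32(1); i <= 3; i++ {
        rotate(kp, indices, &KeyPair{localIndex: i})
    }
    fmt.Println("mapped indices:", len(indices)) // 2: only current and previous remain
}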

View File

@@ -212,18 +212,18 @@ func (device *Device) RoutineReceiveIncomming() {
             // add to peer queue
 
             peer := value.peer
-            work := &QueueInboundElement{
+            elem := &QueueInboundElement{
                 packet:  packet,
                 buffer:  buffer,
                 keyPair: keyPair,
                 dropped: AtomicFalse,
             }
-            work.mutex.Lock()
+            elem.mutex.Lock()
 
             // add to decryption queues
 
-            device.addToInboundQueue(device.queue.decryption, work)
-            device.addToInboundQueue(peer.queue.inbound, work)
+            device.addToInboundQueue(device.queue.decryption, elem)
+            device.addToInboundQueue(peer.queue.inbound, elem)
             buffer = nil
 
         default:

View File

@@ -270,50 +270,65 @@ func (peer *Peer) RoutineNonce() {
  * Obs. One instance per core
  */
 func (device *Device) RoutineEncryption() {
+    var elem *QueueOutboundElement
     var nonce [chacha20poly1305.NonceSize]byte
-    for work := range device.queue.encryption {
+
+    logDebug := device.log.Debug
+    logDebug.Println("Routine, encryption worker, started")
+
+    for {
+
+        // fetch next element
+
+        select {
+        case elem = <-device.queue.encryption:
+        case <-device.signal.stop:
+            logDebug.Println("Routine, encryption worker, stopped")
+            return
+        }
 
         // check if dropped
 
-        if work.IsDropped() {
+        if elem.IsDropped() {
             continue
         }
 
         // populate header fields
 
-        header := work.buffer[:MessageTransportHeaderSize]
+        header := elem.buffer[:MessageTransportHeaderSize]
 
         fieldType := header[0:4]
         fieldReceiver := header[4:8]
         fieldNonce := header[8:16]
 
         binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
-        binary.LittleEndian.PutUint32(fieldReceiver, work.keyPair.remoteIndex)
-        binary.LittleEndian.PutUint64(fieldNonce, work.nonce)
+        binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex)
+        binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
 
         // pad content to MTU size
 
         mtu := int(atomic.LoadInt32(&device.mtu))
-        for i := len(work.packet); i < mtu; i++ {
-            work.packet = append(work.packet, 0)
+        for i := len(elem.packet); i < mtu; i++ {
+            elem.packet = append(elem.packet, 0)
         }
 
         // encrypt content
 
-        binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
-        work.packet = work.keyPair.send.Seal(
-            work.packet[:0],
+        binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
+        elem.packet = elem.keyPair.send.Seal(
+            elem.packet[:0],
             nonce[:],
-            work.packet,
+            elem.packet,
             nil,
         )
-        length := MessageTransportHeaderSize + len(work.packet)
-        work.packet = work.buffer[:length]
-        work.mutex.Unlock()
+        length := MessageTransportHeaderSize + len(elem.packet)
+        elem.packet = elem.buffer[:length]
+        elem.mutex.Unlock()
 
         // refresh key if necessary
 
-        work.peer.KeepKeyFreshSending()
+        elem.peer.KeepKeyFreshSending()
     }
 }
@@ -334,49 +349,43 @@ func (peer *Peer) RoutineSequentialSender() {
             logDebug.Println("Routine, sequential sender, stopped for", peer.String())
             return
-        case work := <-peer.queue.outbound:
-            work.mutex.Lock()
+        case elem := <-peer.queue.outbound:
+            elem.mutex.Lock()
             func() {
-
-                // return buffer to pool after processing
-
-                defer device.PutMessageBuffer(work.buffer)
-
-                if work.IsDropped() {
+                if elem.IsDropped() {
                     return
                 }
 
-                // send to endpoint
+                // get endpoint and connection
 
                 peer.mutex.RLock()
-                defer peer.mutex.RUnlock()
+                endpoint := peer.endpoint
+                peer.mutex.RUnlock()
 
-                if peer.endpoint == nil {
+                if endpoint == nil {
                     logDebug.Println("No endpoint for", peer.String())
                     return
                 }
 
                 device.net.mutex.RLock()
-                defer device.net.mutex.RUnlock()
+                conn := device.net.conn
+                device.net.mutex.RUnlock()
 
-                if device.net.conn == nil {
+                if conn == nil {
                     logDebug.Println("No source for device")
                     return
                 }
 
-                // send message and return buffer to pool
+                // send message and refresh keys
 
-                _, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint)
+                _, err := conn.WriteToUDP(elem.packet, endpoint)
                 if err != nil {
                     return
                 }
+                atomic.AddUint64(&peer.txBytes, uint64(len(elem.packet)))
 
-                atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
-
-                // reset keep-alive
-
                 peer.TimerResetKeepalive()
             }()
+            device.PutMessageBuffer(elem.buffer)
         }
     }
 }
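
A sketch of the stoppable-worker pattern the encryption routine adopts above: block on either the work queue or a stop channel instead of ranging over the queue forever. The element type and the stop channel are simplified stand-ins for the real queue element and device.signal.stop:

package main

import (
    "fmt"
    "sync"
)

type QueueOutboundElement struct {
    packet []byte
}

func worker(id int, queue <-chan *QueueOutboundElement, stop <-chan struct{}, wg *sync.WaitGroup) {
    defer wg.Done()
    for {
        select {
        case elem := <-queue:
            fmt.Printf("worker %d sealed %d bytes\n", id, len(elem.packet))
        case <-stop:
            fmt.Printf("worker %d stopped\n", id)
            return
        }
    }
}

func main() {
    queue := make(chan *QueueOutboundElement) // unbuffered: a send completes only when a worker takes it
    stop := make(chan struct{})
    var wg sync.WaitGroup

    wg.Add(2)
    for i := 1; i <= 2; i++ {
        go worker(i, queue, stop, &wg)
    }

    queue <- &QueueOutboundElement{packet: make([]byte, 60)}
    queue <- &QueueOutboundElement{packet: make([]byte, 1280)}

    close(stop) // a closed channel releases every worker at once
    wg.Wait()
}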

View File

@@ -138,6 +138,7 @@ func (peer *Peer) BeginHandshakeInitiation() (*QueueOutboundElement, error) {
 
 func (peer *Peer) RoutineTimerHandler() {
     device := peer.device
+    indices := &device.indices
 
     logDebug := device.log.Debug
     logDebug.Println("Routine, timer handler, started for peer", peer.String())
@@ -170,29 +171,42 @@ func (peer *Peer) RoutineTimerHandler() {
             logDebug.Println("Clearing all key material for", peer.String())
 
-            // zero out key pairs
-
-            func() {
             kp := &peer.keyPairs
             kp.mutex.Lock()
-            // best we can do is wait for GC :( ?
+
+            hs := &peer.handshake
+            hs.mutex.Lock()
+
+            // unmap local indecies
+
+            indices.mutex.Lock()
+            if kp.previous != nil {
+                delete(indices.table, kp.previous.localIndex)
+            }
+            if kp.current != nil {
+                delete(indices.table, kp.current.localIndex)
+            }
+            if kp.next != nil {
+                delete(indices.table, kp.next.localIndex)
+            }
+            delete(indices.table, hs.localIndex)
+            indices.mutex.Unlock()
+
+            // zero out key pairs (TODO: better than wait for GC)
+
             kp.current = nil
             kp.previous = nil
             kp.next = nil
             kp.mutex.Unlock()
-            }()
 
             // zero out handshake
 
-            func() {
-            hs := &peer.handshake
-            hs.mutex.Lock()
+            hs.localIndex = 0
             hs.localEphemeral = NoisePrivateKey{}
             hs.remoteEphemeral = NoisePublicKey{}
             hs.chainKey = [blake2s.Size]byte{}
             hs.hash = [blake2s.Size]byte{}
             hs.mutex.Unlock()
-            }()
         }
     }
 }
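
A trimmed sketch of the clearing logic above: every local index the peer still owns (previous/current/next key pair plus the in-flight handshake) is removed from the shared table under its lock before the key material itself is dropped. The structures are reduced stand-ins for the real ones:

package main

import (
    "fmt"
    "sync"
)

type KeyPair struct {
    localIndex uint32
}

type IndexTable struct {
    mutex sync.RWMutex
    table map[uint32]*KeyPair
}

// unmapAll deletes every given index while holding the write lock once.
func (t *IndexTable) unmapAll(indices ...uint32) {
    t.mutex.Lock()
    defer t.mutex.Unlock()
    for _, idx := range indices {
        delete(t.table, idx)
    }
}

func main() {
    previous := &KeyPair{localIndex: 11}
    current := &KeyPair{localIndex: 12}
    handshakeIndex := uint32(13) // index held by an unfinished handshake

    indices := &IndexTable{table: map[uint32]*KeyPair{
        11: previous, 12: current, 13: nil,
    }}

    indices.unmapAll(previous.localIndex, current.localIndex, handshakeIndex)
    fmt.Println("entries left:", len(indices.table)) // 0

    // only after unmapping is it safe to drop and zero the key material itself
}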