mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 09:15:14 +01:00
Fixed deadlock in index.go
This commit is contained in:
parent
dd4da93749
commit
c5d7efc246
162
src/config.go
162
src/config.go
@ -8,39 +8,36 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
// #include <errno.h>
|
|
||||||
import "C"
|
|
||||||
|
|
||||||
/* TODO: More fine grained?
|
|
||||||
*/
|
|
||||||
const (
|
const (
|
||||||
ipcErrorNoPeer = C.EPROTO
|
ipcErrorIO = syscall.EIO
|
||||||
ipcErrorNoKeyValue = C.EPROTO
|
ipcErrorNoPeer = syscall.EPROTO
|
||||||
ipcErrorInvalidKey = C.EPROTO
|
ipcErrorNoKeyValue = syscall.EPROTO
|
||||||
ipcErrorInvalidValue = C.EPROTO
|
ipcErrorInvalidKey = syscall.EPROTO
|
||||||
|
ipcErrorInvalidValue = syscall.EPROTO
|
||||||
)
|
)
|
||||||
|
|
||||||
type IPCError struct {
|
type IPCError struct {
|
||||||
Code int
|
Code syscall.Errno
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IPCError) Error() string {
|
func (s *IPCError) Error() string {
|
||||||
return fmt.Sprintf("IPC error: %d", s.Code)
|
return fmt.Sprintf("IPC error: %d", s.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IPCError) ErrorCode() int {
|
func (s *IPCError) ErrorCode() uintptr {
|
||||||
return s.Code
|
return uintptr(s.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
|
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
|
|
||||||
device.mutex.RLock()
|
|
||||||
defer device.mutex.RUnlock()
|
|
||||||
|
|
||||||
// create lines
|
// create lines
|
||||||
|
|
||||||
|
device.mutex.RLock()
|
||||||
|
|
||||||
lines := make([]string, 0, 100)
|
lines := make([]string, 0, 100)
|
||||||
send := func(line string) {
|
send := func(line string) {
|
||||||
lines = append(lines, line)
|
lines = append(lines, line)
|
||||||
@ -63,19 +60,25 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
|
|||||||
}
|
}
|
||||||
send(fmt.Sprintf("tx_bytes=%d", peer.txBytes))
|
send(fmt.Sprintf("tx_bytes=%d", peer.txBytes))
|
||||||
send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes))
|
send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes))
|
||||||
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
|
send(fmt.Sprintf("persistent_keepalive_interval=%d",
|
||||||
|
atomic.LoadUint64(&peer.persistentKeepaliveInterval),
|
||||||
|
))
|
||||||
for _, ip := range device.routingTable.AllowedIPs(peer) {
|
for _, ip := range device.routingTable.AllowedIPs(peer) {
|
||||||
send("allowed_ip=" + ip.String())
|
send("allowed_ip=" + ip.String())
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
device.mutex.RUnlock()
|
||||||
|
|
||||||
// send lines
|
// send lines
|
||||||
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
_, err := socket.WriteString(line + "\n")
|
_, err := socket.WriteString(line + "\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return &IPCError{
|
||||||
|
Code: ipcErrorIO,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -83,13 +86,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
logger := device.log.Debug
|
|
||||||
scanner := bufio.NewScanner(socket)
|
scanner := bufio.NewScanner(socket)
|
||||||
|
logError := device.log.Error
|
||||||
|
logDebug := device.log.Debug
|
||||||
|
|
||||||
var peer *Peer
|
var peer *Peer
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
|
|
||||||
// Parse line
|
// parse line
|
||||||
|
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if line == "" {
|
if line == "" {
|
||||||
@ -97,7 +101,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
}
|
}
|
||||||
parts := strings.Split(line, "=")
|
parts := strings.Split(line, "=")
|
||||||
if len(parts) != 2 {
|
if len(parts) != 2 {
|
||||||
device.log.Debug.Println(parts)
|
|
||||||
return &IPCError{Code: ipcErrorNoKeyValue}
|
return &IPCError{Code: ipcErrorNoKeyValue}
|
||||||
}
|
}
|
||||||
key := parts[0]
|
key := parts[0]
|
||||||
@ -105,7 +108,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
|
|
||||||
switch key {
|
switch key {
|
||||||
|
|
||||||
/* Interface configuration */
|
/* interface configuration */
|
||||||
|
|
||||||
case "private_key":
|
case "private_key":
|
||||||
if value == "" {
|
if value == "" {
|
||||||
@ -116,7 +119,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
var sk NoisePrivateKey
|
var sk NoisePrivateKey
|
||||||
err := sk.FromHex(value)
|
err := sk.FromHex(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Println("Failed to set private_key:", err)
|
logError.Println("Failed to set private_key:", err)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalidValue}
|
||||||
}
|
}
|
||||||
device.SetPrivateKey(sk)
|
device.SetPrivateKey(sk)
|
||||||
@ -126,22 +129,26 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
var port int
|
var port int
|
||||||
_, err := fmt.Sscanf(value, "%d", &port)
|
_, err := fmt.Sscanf(value, "%d", &port)
|
||||||
if err != nil || port > (1<<16) || port < 0 {
|
if err != nil || port > (1<<16) || port < 0 {
|
||||||
logger.Println("Failed to set listen_port:", err)
|
logError.Println("Failed to set listen_port:", err)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalidValue}
|
||||||
}
|
}
|
||||||
device.net.mutex.Lock()
|
device.net.mutex.Lock()
|
||||||
device.net.addr.Port = port
|
device.net.addr.Port = port
|
||||||
device.net.conn, err = net.ListenUDP("udp", device.net.addr)
|
device.net.conn, err = net.ListenUDP("udp", device.net.addr)
|
||||||
device.net.mutex.Unlock()
|
device.net.mutex.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
logError.Println("Failed to create UDP listener:", err)
|
||||||
|
return &IPCError{Code: ipcErrorInvalidValue}
|
||||||
|
}
|
||||||
|
|
||||||
case "fwmark":
|
case "fwmark":
|
||||||
logger.Println("FWMark not handled yet")
|
logError.Println("FWMark not handled yet")
|
||||||
|
|
||||||
case "public_key":
|
case "public_key":
|
||||||
var pubKey NoisePublicKey
|
var pubKey NoisePublicKey
|
||||||
err := pubKey.FromHex(value)
|
err := pubKey.FromHex(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Println("Failed to get peer by public_key:", err)
|
logError.Println("Failed to get peer by public_key:", err)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalidValue}
|
||||||
}
|
}
|
||||||
device.mutex.RLock()
|
device.mutex.RLock()
|
||||||
@ -153,22 +160,23 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
peer = device.NewPeer(pubKey)
|
peer = device.NewPeer(pubKey)
|
||||||
}
|
}
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
panic(errors.New("bug: failed to find peer"))
|
panic(errors.New("bug: failed to find / create peer"))
|
||||||
}
|
}
|
||||||
|
|
||||||
case "replace_peers":
|
case "replace_peers":
|
||||||
if value == "true" {
|
if value == "true" {
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
} else {
|
} else {
|
||||||
logger.Println("Failed to set replace_peers, invalid value:", value)
|
logError.Println("Failed to set replace_peers, invalid value:", value)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalidValue}
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
/* Peer configuration */
|
|
||||||
|
/* peer configuration */
|
||||||
|
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
logger.Println("No peer referenced, before peer operation")
|
logError.Println("No peer referenced, before peer operation")
|
||||||
return &IPCError{Code: ipcErrorNoPeer}
|
return &IPCError{Code: ipcErrorNoPeer}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,7 +186,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
device.RemovePeer(peer.handshake.remoteStatic)
|
device.RemovePeer(peer.handshake.remoteStatic)
|
||||||
peer.mutex.Unlock()
|
peer.mutex.Unlock()
|
||||||
logger.Println("Remove peer")
|
logDebug.Println("Removing", peer.String())
|
||||||
peer = nil
|
peer = nil
|
||||||
|
|
||||||
case "preshared_key":
|
case "preshared_key":
|
||||||
@ -188,14 +196,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
return peer.handshake.presharedKey.FromHex(value)
|
return peer.handshake.presharedKey.FromHex(value)
|
||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Println("Failed to set preshared_key:", err)
|
logError.Println("Failed to set preshared_key:", err)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalidValue}
|
||||||
}
|
}
|
||||||
|
|
||||||
case "endpoint":
|
case "endpoint":
|
||||||
addr, err := net.ResolveUDPAddr("udp", value)
|
addr, err := net.ResolveUDPAddr("udp", value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Println("Failed to set endpoint:", value)
|
logError.Println("Failed to set endpoint:", value)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalidValue}
|
||||||
}
|
}
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
@ -205,35 +213,34 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
case "persistent_keepalive_interval":
|
case "persistent_keepalive_interval":
|
||||||
secs, err := strconv.ParseInt(value, 10, 64)
|
secs, err := strconv.ParseInt(value, 10, 64)
|
||||||
if secs < 0 || err != nil {
|
if secs < 0 || err != nil {
|
||||||
logger.Println("Failed to set persistent_keepalive_interval:", err)
|
logError.Println("Failed to set persistent_keepalive_interval:", err)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalidValue}
|
||||||
}
|
}
|
||||||
peer.mutex.Lock()
|
atomic.StoreUint64(
|
||||||
peer.persistentKeepaliveInterval = uint64(secs)
|
&peer.persistentKeepaliveInterval,
|
||||||
peer.mutex.Unlock()
|
uint64(secs),
|
||||||
|
)
|
||||||
|
|
||||||
case "replace_allowed_ips":
|
case "replace_allowed_ips":
|
||||||
if value == "true" {
|
if value == "true" {
|
||||||
device.routingTable.RemovePeer(peer)
|
device.routingTable.RemovePeer(peer)
|
||||||
} else {
|
} else {
|
||||||
logger.Println("Failed to set replace_allowed_ips, invalid value:", value)
|
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalidValue}
|
||||||
}
|
}
|
||||||
|
|
||||||
case "allowed_ip":
|
case "allowed_ip":
|
||||||
_, network, err := net.ParseCIDR(value)
|
_, network, err := net.ParseCIDR(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Println("Failed to set allowed_ip:", err)
|
logError.Println("Failed to set allowed_ip:", err)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalidValue}
|
||||||
}
|
}
|
||||||
ones, _ := network.Mask.Size()
|
ones, _ := network.Mask.Size()
|
||||||
logger.Println(network, ones, network.IP)
|
logError.Println(network, ones, network.IP)
|
||||||
device.routingTable.Insert(network.IP, uint(ones), peer)
|
device.routingTable.Insert(network.IP, uint(ones), peer)
|
||||||
|
|
||||||
/* Invalid key */
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logger.Println("Invalid key:", key)
|
logError.Println("Invalid UAPI key:", key)
|
||||||
return &IPCError{Code: ipcErrorInvalidKey}
|
return &IPCError{Code: ipcErrorInvalidKey}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -244,46 +251,45 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
|
|
||||||
func ipcHandle(device *Device, socket net.Conn) {
|
func ipcHandle(device *Device, socket net.Conn) {
|
||||||
|
|
||||||
func() {
|
defer socket.Close()
|
||||||
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
|
||||||
reader := bufio.NewReader(s)
|
|
||||||
writer := bufio.NewWriter(s)
|
|
||||||
return bufio.NewReadWriter(reader, writer)
|
|
||||||
}(socket)
|
|
||||||
|
|
||||||
defer buffered.Flush()
|
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
||||||
|
reader := bufio.NewReader(s)
|
||||||
|
writer := bufio.NewWriter(s)
|
||||||
|
return bufio.NewReadWriter(reader, writer)
|
||||||
|
}(socket)
|
||||||
|
|
||||||
op, err := buffered.ReadString('\n')
|
defer buffered.Flush()
|
||||||
|
|
||||||
|
op, err := buffered.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch op {
|
||||||
|
|
||||||
|
case "set=1\n":
|
||||||
|
device.log.Debug.Println("Config, set operation")
|
||||||
|
err := ipcSetOperation(device, buffered)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(buffered, "errno=0\n\n")
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
|
||||||
switch op {
|
case "get=1\n":
|
||||||
|
device.log.Debug.Println("Config, get operation")
|
||||||
case "set=1\n":
|
err := ipcGetOperation(device, buffered)
|
||||||
device.log.Debug.Println("Config, set operation")
|
if err != nil {
|
||||||
err := ipcSetOperation(device, buffered)
|
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
|
||||||
if err != nil {
|
} else {
|
||||||
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
|
fmt.Fprintf(buffered, "errno=0\n\n")
|
||||||
} else {
|
|
||||||
fmt.Fprintf(buffered, "errno=0\n\n")
|
|
||||||
}
|
|
||||||
break
|
|
||||||
|
|
||||||
case "get=1\n":
|
|
||||||
device.log.Debug.Println("Config, get operation")
|
|
||||||
err := ipcGetOperation(device, buffered)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(buffered, "errno=1\n\n") // fix
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(buffered, "errno=0\n\n")
|
|
||||||
}
|
|
||||||
break
|
|
||||||
|
|
||||||
default:
|
|
||||||
device.log.Info.Println("Invalid UAPI operation:", op)
|
|
||||||
}
|
}
|
||||||
}()
|
return
|
||||||
|
|
||||||
socket.Close()
|
default:
|
||||||
|
device.log.Error.Println("Invalid UAPI operation:", op)
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -78,7 +78,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
|||||||
defer device.mutex.Unlock()
|
defer device.mutex.Unlock()
|
||||||
|
|
||||||
device.log = NewLogger(logLevel)
|
device.log = NewLogger(logLevel)
|
||||||
// device.mtu = tun.MTU()
|
|
||||||
device.peers = make(map[NoisePublicKey]*Peer)
|
device.peers = make(map[NoisePublicKey]*Peer)
|
||||||
device.indices.Init()
|
device.indices.Init()
|
||||||
device.ratelimiter.Init()
|
device.ratelimiter.Init()
|
||||||
@ -131,12 +130,21 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
|||||||
|
|
||||||
func (device *Device) RoutineMTUUpdater(tun TUNDevice) {
|
func (device *Device) RoutineMTUUpdater(tun TUNDevice) {
|
||||||
logError := device.log.Error
|
logError := device.log.Error
|
||||||
for ; ; time.Sleep(time.Second) {
|
for ; ; time.Sleep(5 * time.Second) {
|
||||||
|
|
||||||
|
// load updated MTU
|
||||||
|
|
||||||
mtu, err := tun.MTU()
|
mtu, err := tun.MTU()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to load updated MTU of device:", err)
|
logError.Println("Failed to load updated MTU of device:", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// upper bound of mtu
|
||||||
|
|
||||||
|
if mtu+MessageTransportSize > MaxMessageSize {
|
||||||
|
mtu = MaxMessageSize - MessageTransportSize
|
||||||
|
}
|
||||||
atomic.StoreInt32(&device.mtu, int32(mtu))
|
atomic.StoreInt32(&device.mtu, int32(mtu))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,8 +6,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
/* Index=0 is reserved for unset indecies
|
/* Index=0 is reserved for unset indecies
|
||||||
*
|
|
||||||
* TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer
|
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@ -72,12 +70,12 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
|
|||||||
|
|
||||||
table.mutex.RLock()
|
table.mutex.RLock()
|
||||||
_, ok := table.table[index]
|
_, ok := table.table[index]
|
||||||
|
table.mutex.RUnlock()
|
||||||
if ok {
|
if ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
table.mutex.RUnlock()
|
|
||||||
|
|
||||||
// replace index
|
// map index to handshake
|
||||||
|
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
_, found := table.table[index]
|
_, found := table.table[index]
|
||||||
|
20
src/main.go
20
src/main.go
@ -17,12 +17,14 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch os.Args[1] {
|
switch os.Args[1] {
|
||||||
|
|
||||||
case "-f", "--foreground":
|
case "-f", "--foreground":
|
||||||
foreground = true
|
foreground = true
|
||||||
if len(os.Args) != 3 {
|
if len(os.Args) != 3 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
interfaceName = os.Args[2]
|
interfaceName = os.Args[2]
|
||||||
|
|
||||||
default:
|
default:
|
||||||
foreground = false
|
foreground = false
|
||||||
if len(os.Args) != 2 {
|
if len(os.Args) != 2 {
|
||||||
@ -48,8 +50,8 @@ func main() {
|
|||||||
// open TUN device
|
// open TUN device
|
||||||
|
|
||||||
tun, err := CreateTUN(interfaceName)
|
tun, err := CreateTUN(interfaceName)
|
||||||
log.Println(tun, err)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Println("Failed to create tun device:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,11 +71,15 @@ func main() {
|
|||||||
}
|
}
|
||||||
defer uapi.Close()
|
defer uapi.Close()
|
||||||
|
|
||||||
for {
|
go func() {
|
||||||
conn, err := uapi.Accept()
|
for {
|
||||||
if err != nil {
|
conn, err := uapi.Accept()
|
||||||
logError.Fatal("accept error:", err)
|
if err != nil {
|
||||||
|
logError.Fatal("UAPI accept error:", err)
|
||||||
|
}
|
||||||
|
go ipcHandle(device, conn)
|
||||||
}
|
}
|
||||||
go ipcHandle(device, conn)
|
}()
|
||||||
}
|
|
||||||
|
device.Wait()
|
||||||
}
|
}
|
||||||
|
@ -459,7 +459,8 @@ func (peer *Peer) NewKeyPair() *KeyPair {
|
|||||||
|
|
||||||
// remap index
|
// remap index
|
||||||
|
|
||||||
peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
|
indices := &peer.device.indices
|
||||||
|
indices.Insert(handshake.localIndex, IndexTableEntry{
|
||||||
peer: peer,
|
peer: peer,
|
||||||
keyPair: keyPair,
|
keyPair: keyPair,
|
||||||
handshake: nil,
|
handshake: nil,
|
||||||
@ -476,7 +477,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
|
|||||||
if kp.previous != nil {
|
if kp.previous != nil {
|
||||||
kp.previous.send = nil
|
kp.previous.send = nil
|
||||||
kp.previous.receive = nil
|
kp.previous.receive = nil
|
||||||
peer.device.indices.Delete(kp.previous.localIndex)
|
indices.Delete(kp.previous.localIndex)
|
||||||
}
|
}
|
||||||
kp.previous = kp.current
|
kp.previous = kp.current
|
||||||
kp.current = keyPair
|
kp.current = keyPair
|
||||||
|
@ -212,18 +212,18 @@ func (device *Device) RoutineReceiveIncomming() {
|
|||||||
// add to peer queue
|
// add to peer queue
|
||||||
|
|
||||||
peer := value.peer
|
peer := value.peer
|
||||||
work := &QueueInboundElement{
|
elem := &QueueInboundElement{
|
||||||
packet: packet,
|
packet: packet,
|
||||||
buffer: buffer,
|
buffer: buffer,
|
||||||
keyPair: keyPair,
|
keyPair: keyPair,
|
||||||
dropped: AtomicFalse,
|
dropped: AtomicFalse,
|
||||||
}
|
}
|
||||||
work.mutex.Lock()
|
elem.mutex.Lock()
|
||||||
|
|
||||||
// add to decryption queues
|
// add to decryption queues
|
||||||
|
|
||||||
device.addToInboundQueue(device.queue.decryption, work)
|
device.addToInboundQueue(device.queue.decryption, elem)
|
||||||
device.addToInboundQueue(peer.queue.inbound, work)
|
device.addToInboundQueue(peer.queue.inbound, elem)
|
||||||
buffer = nil
|
buffer = nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
81
src/send.go
81
src/send.go
@ -270,50 +270,65 @@ func (peer *Peer) RoutineNonce() {
|
|||||||
* Obs. One instance per core
|
* Obs. One instance per core
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineEncryption() {
|
func (device *Device) RoutineEncryption() {
|
||||||
|
|
||||||
|
var elem *QueueOutboundElement
|
||||||
var nonce [chacha20poly1305.NonceSize]byte
|
var nonce [chacha20poly1305.NonceSize]byte
|
||||||
for work := range device.queue.encryption {
|
|
||||||
|
logDebug := device.log.Debug
|
||||||
|
logDebug.Println("Routine, encryption worker, started")
|
||||||
|
|
||||||
|
for {
|
||||||
|
|
||||||
|
// fetch next element
|
||||||
|
|
||||||
|
select {
|
||||||
|
case elem = <-device.queue.encryption:
|
||||||
|
case <-device.signal.stop:
|
||||||
|
logDebug.Println("Routine, encryption worker, stopped")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// check if dropped
|
// check if dropped
|
||||||
|
|
||||||
if work.IsDropped() {
|
if elem.IsDropped() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// populate header fields
|
// populate header fields
|
||||||
|
|
||||||
header := work.buffer[:MessageTransportHeaderSize]
|
header := elem.buffer[:MessageTransportHeaderSize]
|
||||||
|
|
||||||
fieldType := header[0:4]
|
fieldType := header[0:4]
|
||||||
fieldReceiver := header[4:8]
|
fieldReceiver := header[4:8]
|
||||||
fieldNonce := header[8:16]
|
fieldNonce := header[8:16]
|
||||||
|
|
||||||
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
|
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
|
||||||
binary.LittleEndian.PutUint32(fieldReceiver, work.keyPair.remoteIndex)
|
binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex)
|
||||||
binary.LittleEndian.PutUint64(fieldNonce, work.nonce)
|
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
||||||
|
|
||||||
// pad content to MTU size
|
// pad content to MTU size
|
||||||
|
|
||||||
mtu := int(atomic.LoadInt32(&device.mtu))
|
mtu := int(atomic.LoadInt32(&device.mtu))
|
||||||
for i := len(work.packet); i < mtu; i++ {
|
for i := len(elem.packet); i < mtu; i++ {
|
||||||
work.packet = append(work.packet, 0)
|
elem.packet = append(elem.packet, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// encrypt content
|
// encrypt content
|
||||||
|
|
||||||
binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
|
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
|
||||||
work.packet = work.keyPair.send.Seal(
|
elem.packet = elem.keyPair.send.Seal(
|
||||||
work.packet[:0],
|
elem.packet[:0],
|
||||||
nonce[:],
|
nonce[:],
|
||||||
work.packet,
|
elem.packet,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
length := MessageTransportHeaderSize + len(work.packet)
|
length := MessageTransportHeaderSize + len(elem.packet)
|
||||||
work.packet = work.buffer[:length]
|
elem.packet = elem.buffer[:length]
|
||||||
work.mutex.Unlock()
|
elem.mutex.Unlock()
|
||||||
|
|
||||||
// refresh key if necessary
|
// refresh key if necessary
|
||||||
|
|
||||||
work.peer.KeepKeyFreshSending()
|
elem.peer.KeepKeyFreshSending()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -334,49 +349,43 @@ func (peer *Peer) RoutineSequentialSender() {
|
|||||||
logDebug.Println("Routine, sequential sender, stopped for", peer.String())
|
logDebug.Println("Routine, sequential sender, stopped for", peer.String())
|
||||||
return
|
return
|
||||||
|
|
||||||
case work := <-peer.queue.outbound:
|
case elem := <-peer.queue.outbound:
|
||||||
work.mutex.Lock()
|
elem.mutex.Lock()
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
|
if elem.IsDropped() {
|
||||||
// return buffer to pool after processing
|
|
||||||
|
|
||||||
defer device.PutMessageBuffer(work.buffer)
|
|
||||||
if work.IsDropped() {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// send to endpoint
|
// get endpoint and connection
|
||||||
|
|
||||||
peer.mutex.RLock()
|
peer.mutex.RLock()
|
||||||
defer peer.mutex.RUnlock()
|
endpoint := peer.endpoint
|
||||||
|
peer.mutex.RUnlock()
|
||||||
if peer.endpoint == nil {
|
if endpoint == nil {
|
||||||
logDebug.Println("No endpoint for", peer.String())
|
logDebug.Println("No endpoint for", peer.String())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
device.net.mutex.RLock()
|
device.net.mutex.RLock()
|
||||||
defer device.net.mutex.RUnlock()
|
conn := device.net.conn
|
||||||
|
device.net.mutex.RUnlock()
|
||||||
if device.net.conn == nil {
|
if conn == nil {
|
||||||
logDebug.Println("No source for device")
|
logDebug.Println("No source for device")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// send message and return buffer to pool
|
// send message and refresh keys
|
||||||
|
|
||||||
_, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint)
|
_, err := conn.WriteToUDP(elem.packet, endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
atomic.AddUint64(&peer.txBytes, uint64(len(elem.packet)))
|
||||||
atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
|
|
||||||
|
|
||||||
// reset keep-alive
|
|
||||||
|
|
||||||
peer.TimerResetKeepalive()
|
peer.TimerResetKeepalive()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -138,6 +138,7 @@ func (peer *Peer) BeginHandshakeInitiation() (*QueueOutboundElement, error) {
|
|||||||
|
|
||||||
func (peer *Peer) RoutineTimerHandler() {
|
func (peer *Peer) RoutineTimerHandler() {
|
||||||
device := peer.device
|
device := peer.device
|
||||||
|
indices := &device.indices
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
logDebug.Println("Routine, timer handler, started for peer", peer.String())
|
logDebug.Println("Routine, timer handler, started for peer", peer.String())
|
||||||
@ -170,29 +171,42 @@ func (peer *Peer) RoutineTimerHandler() {
|
|||||||
|
|
||||||
logDebug.Println("Clearing all key material for", peer.String())
|
logDebug.Println("Clearing all key material for", peer.String())
|
||||||
|
|
||||||
// zero out key pairs
|
kp := &peer.keyPairs
|
||||||
|
kp.mutex.Lock()
|
||||||
|
|
||||||
func() {
|
hs := &peer.handshake
|
||||||
kp := &peer.keyPairs
|
hs.mutex.Lock()
|
||||||
kp.mutex.Lock()
|
|
||||||
// best we can do is wait for GC :( ?
|
// unmap local indecies
|
||||||
kp.current = nil
|
|
||||||
kp.previous = nil
|
indices.mutex.Lock()
|
||||||
kp.next = nil
|
if kp.previous != nil {
|
||||||
kp.mutex.Unlock()
|
delete(indices.table, kp.previous.localIndex)
|
||||||
}()
|
}
|
||||||
|
if kp.current != nil {
|
||||||
|
delete(indices.table, kp.current.localIndex)
|
||||||
|
}
|
||||||
|
if kp.next != nil {
|
||||||
|
delete(indices.table, kp.next.localIndex)
|
||||||
|
}
|
||||||
|
delete(indices.table, hs.localIndex)
|
||||||
|
indices.mutex.Unlock()
|
||||||
|
|
||||||
|
// zero out key pairs (TODO: better than wait for GC)
|
||||||
|
|
||||||
|
kp.current = nil
|
||||||
|
kp.previous = nil
|
||||||
|
kp.next = nil
|
||||||
|
kp.mutex.Unlock()
|
||||||
|
|
||||||
// zero out handshake
|
// zero out handshake
|
||||||
|
|
||||||
func() {
|
hs.localIndex = 0
|
||||||
hs := &peer.handshake
|
hs.localEphemeral = NoisePrivateKey{}
|
||||||
hs.mutex.Lock()
|
hs.remoteEphemeral = NoisePublicKey{}
|
||||||
hs.localEphemeral = NoisePrivateKey{}
|
hs.chainKey = [blake2s.Size]byte{}
|
||||||
hs.remoteEphemeral = NoisePublicKey{}
|
hs.hash = [blake2s.Size]byte{}
|
||||||
hs.chainKey = [blake2s.Size]byte{}
|
hs.mutex.Unlock()
|
||||||
hs.hash = [blake2s.Size]byte{}
|
|
||||||
hs.mutex.Unlock()
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user