1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Begin implementation of outbound work queue

This commit is contained in:
Mathias Hall-Andersen 2017-06-26 22:07:29 +02:00
parent 9d806d3853
commit eb75ff430d
6 changed files with 181 additions and 84 deletions

View File

@ -2,11 +2,14 @@ package main
import ( import (
"log" "log"
"net"
"sync" "sync"
) )
type Device struct { type Device struct {
mtu int mtu int
source *net.UDPAddr // UDP source address
conn *net.UDPConn // UDP "connection"
mutex sync.RWMutex mutex sync.RWMutex
peers map[NoisePublicKey]*Peer peers map[NoisePublicKey]*Peer
indices IndexTable indices IndexTable

View File

@ -11,10 +11,15 @@ import (
* *
*/ */
type IndexTableEntry struct {
peer *Peer
handshake *Handshake
keyPair *KeyPair
}
type IndexTable struct { type IndexTable struct {
mutex sync.RWMutex mutex sync.RWMutex
keypairs map[uint32]*KeyPair table map[uint32]IndexTableEntry
handshakes map[uint32]*Peer
} }
func randUint32() (uint32, error) { func randUint32() (uint32, error) {
@ -32,52 +37,66 @@ func randUint32() (uint32, error) {
func (table *IndexTable) Init() { func (table *IndexTable) Init() {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() table.table = make(map[uint32]IndexTableEntry)
table.keypairs = make(map[uint32]*KeyPair) table.mutex.Unlock()
table.handshakes = make(map[uint32]*Peer) }
func (table *IndexTable) ClearIndex(index uint32) {
if index == 0 {
return
}
table.mutex.Lock()
delete(table.table, index)
table.mutex.Unlock()
}
func (table *IndexTable) Insert(key uint32, value IndexTableEntry) {
table.mutex.Lock()
table.table[key] = value
table.mutex.Unlock()
} }
func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
table.mutex.Lock()
defer table.mutex.Unlock()
for { for {
// generate random index // generate random index
id, err := randUint32() index, err := randUint32()
if err != nil { if err != nil {
return id, err return index, err
} }
if id == 0 { if index == 0 {
continue continue
} }
// check if index used // check if index used
_, ok := table.keypairs[id] table.mutex.RLock()
_, ok := table.table[index]
if ok { if ok {
continue continue
} }
_, ok = table.handshakes[id] table.mutex.RUnlock()
if ok {
// replace index
table.mutex.Lock()
_, found := table.table[index]
if found {
table.mutex.Unlock()
continue continue
} }
table.table[index] = IndexTableEntry{
// clean old index peer: peer,
handshake: &peer.handshake,
delete(table.handshakes, peer.handshake.localIndex) keyPair: nil,
table.handshakes[id] = peer }
return id, nil table.mutex.Unlock()
return index, nil
} }
} }
func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair { func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
return table.keypairs[id] return table.table[id]
}
func (table *IndexTable) LookupHandshake(id uint32) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.handshakes[id]
} }

View File

@ -16,6 +16,18 @@ type KeyPairs struct {
mutex sync.RWMutex mutex sync.RWMutex
current *KeyPair current *KeyPair
previous *KeyPair previous *KeyPair
next *KeyPair next *KeyPair // not yet "confirmed by transport"
newKeyPair chan bool newKeyPair chan bool // signals when "current" has been updated
}
func (kp *KeyPairs) Init() {
kp.mutex.Lock()
kp.newKeyPair = make(chan bool, 5)
kp.mutex.Unlock()
}
func (kp *KeyPairs) Current() *KeyPair {
kp.mutex.RLock()
defer kp.mutex.RUnlock()
return kp.current
} }

View File

@ -120,13 +120,15 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err return nil, err
} }
device.indices.ClearIndex(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer)
// assign index // assign index
var msg MessageInitiation var msg MessageInitiation
msg.Type = MessageInitiationType msg.Type = MessageInitiationType
msg.Ephemeral = handshake.localEphemeral.publicKey() msg.Ephemeral = handshake.localEphemeral.publicKey()
handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil { if err != nil {
return nil, err return nil, err
@ -249,6 +251,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
// assign index // assign index
var err error var err error
device.indices.ClearIndex(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer) handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil { if err != nil {
return nil, err return nil, err
@ -299,11 +302,12 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// lookup handshake by reciever // lookup handshake by reciever
peer := device.indices.LookupHandshake(msg.Reciever) lookup := device.indices.Lookup(msg.Reciever)
if peer == nil { handshake := lookup.handshake
if handshake == nil {
return nil return nil
} }
handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitiationCreated { if handshake.state != HandshakeInitiationCreated {
@ -345,7 +349,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed handshake.state = HandshakeResponseConsumed
return peer return lookup.peer
} }
func (peer *Peer) NewKeyPair() *KeyPair { func (peer *Peer) NewKeyPair() *KeyPair {
@ -355,13 +359,16 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// derive keys // derive keys
var isInitiator bool
var sendKey [chacha20poly1305.KeySize]byte var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte var recvKey [chacha20poly1305.KeySize]byte
if handshake.state == HandshakeResponseConsumed { if handshake.state == HandshakeResponseConsumed {
sendKey, recvKey = KDF2(handshake.chainKey[:], nil) sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
isInitiator = true
} else if handshake.state == HandshakeResponseCreated { } else if handshake.state == HandshakeResponseCreated {
recvKey, sendKey = KDF2(handshake.chainKey[:], nil) recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
isInitiator = false
} else { } else {
return nil return nil
} }
@ -369,16 +376,40 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// create AEAD instances // create AEAD instances
var keyPair KeyPair var keyPair KeyPair
keyPair.send, _ = chacha20poly1305.New(sendKey[:]) keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.recv, _ = chacha20poly1305.New(recvKey[:]) keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0 keyPair.sendNonce = 0
keyPair.recvNonce = 0 keyPair.recvNonce = 0
// remap index
peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
peer: peer,
keyPair: &keyPair,
handshake: nil,
})
handshake.localIndex = 0
// rotate key pairs
func() {
kp := &peer.keyPairs
kp.mutex.Lock()
defer kp.mutex.Unlock()
if isInitiator {
kp.previous = peer.keyPairs.current
kp.current = &keyPair
kp.newKeyPair <- true
} else {
kp.next = &keyPair
}
}()
// zero handshake // zero handshake
handshake.chainKey = [blake2s.Size]byte{} handshake.chainKey = [blake2s.Size]byte{}
handshake.localEphemeral = NoisePrivateKey{} handshake.localEphemeral = NoisePrivateKey{}
peer.handshake.state = HandshakeZeroed peer.handshake.state = HandshakeZeroed
return &keyPair return &keyPair
} }

View File

@ -14,8 +14,7 @@ const (
type Peer struct { type Peer struct {
mutex sync.RWMutex mutex sync.RWMutex
endpointIP net.IP // endpoint *net.UDPAddr
endpointPort uint16 //
persistentKeepaliveInterval time.Duration // 0 = disabled persistentKeepaliveInterval time.Duration // 0 = disabled
keyPairs KeyPairs keyPairs KeyPairs
handshake Handshake handshake Handshake
@ -35,6 +34,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
peer.mutex.Lock() peer.mutex.Lock()
peer.device = device peer.device = device
peer.keyPairs.Init()
peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize) peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
// map public key // map public key

View File

@ -1,9 +1,11 @@
package main package main
import ( import (
"encoding/binary"
"golang.org/x/crypto/chacha20poly1305"
"net" "net"
"sync" "sync"
"sync/atomic" "time"
) )
/* Handles outbound flow /* Handles outbound flow
@ -70,85 +72,115 @@ func (device *Device) SendPacket(packet []byte) {
* *
* TODO: avoid dynamic allocation of work queue elements * TODO: avoid dynamic allocation of work queue elements
*/ */
func (peer *Peer) ConsumeOutboundPackets() { func (peer *Peer) RoutineOutboundNonceWorker() {
var packet []byte
var keyPair *KeyPair
var flushTimer time.Timer
for { for {
// wait for key pair
keyPair := func() *KeyPair { // wait for packet
peer.keyPairs.mutex.RLock()
defer peer.keyPairs.mutex.RUnlock() if packet == nil {
return peer.keyPairs.current packet = <-peer.queueOutboundRouting
}()
if keyPair == nil {
if len(peer.queueOutboundRouting) > 0 {
// TODO: start handshake
<-peer.keyPairs.newKeyPair
}
continue
} }
// assign packets key pair // wait for key pair
for {
for keyPair == nil {
flushTimer.Reset(time.Second * 10)
// TODO: Handshake or NOP
select { select {
case <-peer.keyPairs.newKeyPair: case <-peer.keyPairs.newKeyPair:
default: keyPair = peer.keyPairs.Current()
case <-peer.keyPairs.newKeyPair: continue
case packet := <-peer.queueOutboundRouting: case <-flushTimer.C:
size := len(peer.queueOutboundRouting)
for i := 0; i < size; i += 1 {
<-peer.queueOutboundRouting
}
packet = nil
}
break
}
// create new work element // process current packet
work := new(OutboundWorkQueueElement) if packet != nil {
work.wg.Add(1)
work.keyPair = keyPair
work.packet = packet
work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
peer.queueOutbound <- work // create work element
// drop packets until there is room work := new(OutboundWorkQueueElement)
work.wg.Add(1)
work.keyPair = keyPair
work.packet = packet
work.nonce = keyPair.sendNonce
packet = nil
peer.queueOutbound <- work
keyPair.sendNonce += 1
// drop packets until there is space
func() {
for { for {
select { select {
case peer.device.queueWorkOutbound <- work: case peer.device.queueWorkOutbound <- work:
break return
default: default:
drop := <-peer.device.queueWorkOutbound drop := <-peer.device.queueWorkOutbound
drop.packet = nil drop.packet = nil
drop.wg.Done() drop.wg.Done()
} }
} }
} }()
} }
} }
} }
/* Go routine
*
* sequentially reads packets from queue and sends to endpoint
*
*/
func (peer *Peer) RoutineSequential() { func (peer *Peer) RoutineSequential() {
for work := range peer.queueOutbound { for work := range peer.queueOutbound {
work.wg.Wait() work.wg.Wait()
// check if dropped ("ghost packet")
if work.packet == nil { if work.packet == nil {
continue continue
} }
//
} }
} }
func (device *Device) EncryptionWorker() { func (device *Device) RoutineEncryptionWorker() {
for { var nonce [chacha20poly1305.NonceSize]byte
work := <-device.queueWorkOutbound for work := range device.queueWorkOutbound {
// pad packet
func() { padding := device.mtu - len(work.packet)
defer work.wg.Done() if padding < 0 {
work.packet = nil
work.wg.Done()
}
for n := 0; n < padding; n += 1 {
work.packet = append(work.packet, 0)
}
// pad packet // encrypt
padding := device.mtu - len(work.packet)
if padding < 0 {
work.packet = nil
return
}
for n := 0; n < padding; n += 1 {
work.packet = append(work.packet, 0) // TODO: gotta be a faster way
}
// binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
work.packet = work.keyPair.send.Seal(
}() work.packet[:0],
nonce[:],
work.packet,
nil,
)
work.wg.Done()
} }
} }