1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Begin implementation of outbound work queue

This commit is contained in:
Mathias Hall-Andersen 2017-06-26 22:07:29 +02:00
parent 9d806d3853
commit eb75ff430d
6 changed files with 181 additions and 84 deletions

View File

@ -2,11 +2,14 @@ package main
import (
"log"
"net"
"sync"
)
type Device struct {
mtu int
source *net.UDPAddr // UDP source address
conn *net.UDPConn // UDP "connection"
mutex sync.RWMutex
peers map[NoisePublicKey]*Peer
indices IndexTable

View File

@ -11,10 +11,15 @@ import (
*
*/
type IndexTableEntry struct {
peer *Peer
handshake *Handshake
keyPair *KeyPair
}
type IndexTable struct {
mutex sync.RWMutex
keypairs map[uint32]*KeyPair
handshakes map[uint32]*Peer
mutex sync.RWMutex
table map[uint32]IndexTableEntry
}
func randUint32() (uint32, error) {
@ -32,52 +37,66 @@ func randUint32() (uint32, error) {
func (table *IndexTable) Init() {
table.mutex.Lock()
defer table.mutex.Unlock()
table.keypairs = make(map[uint32]*KeyPair)
table.handshakes = make(map[uint32]*Peer)
table.table = make(map[uint32]IndexTableEntry)
table.mutex.Unlock()
}
func (table *IndexTable) ClearIndex(index uint32) {
if index == 0 {
return
}
table.mutex.Lock()
delete(table.table, index)
table.mutex.Unlock()
}
func (table *IndexTable) Insert(key uint32, value IndexTableEntry) {
table.mutex.Lock()
table.table[key] = value
table.mutex.Unlock()
}
func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
table.mutex.Lock()
defer table.mutex.Unlock()
for {
// generate random index
id, err := randUint32()
index, err := randUint32()
if err != nil {
return id, err
return index, err
}
if id == 0 {
if index == 0 {
continue
}
// check if index used
_, ok := table.keypairs[id]
table.mutex.RLock()
_, ok := table.table[index]
if ok {
continue
}
_, ok = table.handshakes[id]
if ok {
table.mutex.RUnlock()
// replace index
table.mutex.Lock()
_, found := table.table[index]
if found {
table.mutex.Unlock()
continue
}
// clean old index
delete(table.handshakes, peer.handshake.localIndex)
table.handshakes[id] = peer
return id, nil
table.table[index] = IndexTableEntry{
peer: peer,
handshake: &peer.handshake,
keyPair: nil,
}
table.mutex.Unlock()
return index, nil
}
}
func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.keypairs[id]
}
func (table *IndexTable) LookupHandshake(id uint32) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.handshakes[id]
return table.table[id]
}

View File

@ -16,6 +16,18 @@ type KeyPairs struct {
mutex sync.RWMutex
current *KeyPair
previous *KeyPair
next *KeyPair
newKeyPair chan bool
next *KeyPair // not yet "confirmed by transport"
newKeyPair chan bool // signals when "current" has been updated
}
func (kp *KeyPairs) Init() {
kp.mutex.Lock()
kp.newKeyPair = make(chan bool, 5)
kp.mutex.Unlock()
}
func (kp *KeyPairs) Current() *KeyPair {
kp.mutex.RLock()
defer kp.mutex.RUnlock()
return kp.current
}

View File

@ -120,13 +120,15 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err
}
device.indices.ClearIndex(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer)
// assign index
var msg MessageInitiation
msg.Type = MessageInitiationType
msg.Ephemeral = handshake.localEphemeral.publicKey()
handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil {
return nil, err
@ -249,6 +251,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
// assign index
var err error
device.indices.ClearIndex(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil {
return nil, err
@ -299,11 +302,12 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// lookup handshake by reciever
peer := device.indices.LookupHandshake(msg.Reciever)
if peer == nil {
lookup := device.indices.Lookup(msg.Reciever)
handshake := lookup.handshake
if handshake == nil {
return nil
}
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitiationCreated {
@ -345,7 +349,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed
return peer
return lookup.peer
}
func (peer *Peer) NewKeyPair() *KeyPair {
@ -355,13 +359,16 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// derive keys
var isInitiator bool
var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte
if handshake.state == HandshakeResponseConsumed {
sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
isInitiator = true
} else if handshake.state == HandshakeResponseCreated {
recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
isInitiator = false
} else {
return nil
}
@ -369,16 +376,40 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// create AEAD instances
var keyPair KeyPair
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0
keyPair.recvNonce = 0
// remap index
peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
peer: peer,
keyPair: &keyPair,
handshake: nil,
})
handshake.localIndex = 0
// rotate key pairs
func() {
kp := &peer.keyPairs
kp.mutex.Lock()
defer kp.mutex.Unlock()
if isInitiator {
kp.previous = peer.keyPairs.current
kp.current = &keyPair
kp.newKeyPair <- true
} else {
kp.next = &keyPair
}
}()
// zero handshake
handshake.chainKey = [blake2s.Size]byte{}
handshake.localEphemeral = NoisePrivateKey{}
peer.handshake.state = HandshakeZeroed
return &keyPair
}

View File

@ -14,8 +14,7 @@ const (
type Peer struct {
mutex sync.RWMutex
endpointIP net.IP //
endpointPort uint16 //
endpoint *net.UDPAddr
persistentKeepaliveInterval time.Duration // 0 = disabled
keyPairs KeyPairs
handshake Handshake
@ -35,6 +34,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
peer.mutex.Lock()
peer.device = device
peer.keyPairs.Init()
peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
// map public key

View File

@ -1,9 +1,11 @@
package main
import (
"encoding/binary"
"golang.org/x/crypto/chacha20poly1305"
"net"
"sync"
"sync/atomic"
"time"
)
/* Handles outbound flow
@ -70,85 +72,115 @@ func (device *Device) SendPacket(packet []byte) {
*
* TODO: avoid dynamic allocation of work queue elements
*/
func (peer *Peer) ConsumeOutboundPackets() {
func (peer *Peer) RoutineOutboundNonceWorker() {
var packet []byte
var keyPair *KeyPair
var flushTimer time.Timer
for {
// wait for key pair
keyPair := func() *KeyPair {
peer.keyPairs.mutex.RLock()
defer peer.keyPairs.mutex.RUnlock()
return peer.keyPairs.current
}()
if keyPair == nil {
if len(peer.queueOutboundRouting) > 0 {
// TODO: start handshake
<-peer.keyPairs.newKeyPair
}
continue
// wait for packet
if packet == nil {
packet = <-peer.queueOutboundRouting
}
// assign packets key pair
for {
// wait for key pair
for keyPair == nil {
flushTimer.Reset(time.Second * 10)
// TODO: Handshake or NOP
select {
case <-peer.keyPairs.newKeyPair:
default:
case <-peer.keyPairs.newKeyPair:
case packet := <-peer.queueOutboundRouting:
keyPair = peer.keyPairs.Current()
continue
case <-flushTimer.C:
size := len(peer.queueOutboundRouting)
for i := 0; i < size; i += 1 {
<-peer.queueOutboundRouting
}
packet = nil
}
break
}
// create new work element
// process current packet
work := new(OutboundWorkQueueElement)
work.wg.Add(1)
work.keyPair = keyPair
work.packet = packet
work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
if packet != nil {
peer.queueOutbound <- work
// create work element
// drop packets until there is room
work := new(OutboundWorkQueueElement)
work.wg.Add(1)
work.keyPair = keyPair
work.packet = packet
work.nonce = keyPair.sendNonce
packet = nil
peer.queueOutbound <- work
keyPair.sendNonce += 1
// drop packets until there is space
func() {
for {
select {
case peer.device.queueWorkOutbound <- work:
break
return
default:
drop := <-peer.device.queueWorkOutbound
drop.packet = nil
drop.wg.Done()
}
}
}
}()
}
}
}
/* Go routine
*
* sequentially reads packets from queue and sends to endpoint
*
*/
func (peer *Peer) RoutineSequential() {
for work := range peer.queueOutbound {
work.wg.Wait()
// check if dropped ("ghost packet")
if work.packet == nil {
continue
}
//
}
}
func (device *Device) EncryptionWorker() {
for {
work := <-device.queueWorkOutbound
func (device *Device) RoutineEncryptionWorker() {
var nonce [chacha20poly1305.NonceSize]byte
for work := range device.queueWorkOutbound {
// pad packet
func() {
defer work.wg.Done()
padding := device.mtu - len(work.packet)
if padding < 0 {
work.packet = nil
work.wg.Done()
}
for n := 0; n < padding; n += 1 {
work.packet = append(work.packet, 0)
}
// pad packet
padding := device.mtu - len(work.packet)
if padding < 0 {
work.packet = nil
return
}
for n := 0; n < padding; n += 1 {
work.packet = append(work.packet, 0) // TODO: gotta be a faster way
}
// encrypt
//
}()
binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
work.packet = work.keyPair.send.Seal(
work.packet[:0],
nonce[:],
work.packet,
nil,
)
work.wg.Done()
}
}