1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Begin work on outbound packet flow

This commit is contained in:
Mathias Hall-Andersen 2017-06-26 13:14:02 +02:00
parent cf3a5130d3
commit 9d806d3853
9 changed files with 319 additions and 99 deletions

39
src/cookie.go Normal file
View File

@ -0,0 +1,39 @@
package main
import (
"errors"
"golang.org/x/crypto/blake2s"
)
func CalculateCookie(peer *Peer, msg []byte) {
size := len(msg)
if size < blake2s.Size128*2 {
panic(errors.New("bug: message too short"))
}
startMac1 := size - (blake2s.Size128 * 2)
startMac2 := size - blake2s.Size128
mac1 := msg[startMac1 : startMac1+blake2s.Size128]
mac2 := msg[startMac2 : startMac2+blake2s.Size128]
peer.mutex.RLock()
defer peer.mutex.RUnlock()
// set mac1
func() {
mac, _ := blake2s.New128(peer.macKey[:])
mac.Write(msg[:startMac1])
mac.Sum(mac1[:0])
}()
// set mac2
if peer.cookie != nil {
mac, _ := blake2s.New128(peer.cookie)
mac.Write(msg[:startMac2])
mac.Sum(mac2[:0])
}
}

View File

@ -1,18 +1,22 @@
package main package main
import ( import (
"log"
"sync" "sync"
) )
type Device struct { type Device struct {
mutex sync.RWMutex mtu int
peers map[NoisePublicKey]*Peer mutex sync.RWMutex
indices IndexTable peers map[NoisePublicKey]*Peer
privateKey NoisePrivateKey indices IndexTable
publicKey NoisePublicKey privateKey NoisePrivateKey
fwMark uint32 publicKey NoisePublicKey
listenPort uint16 fwMark uint32
routingTable RoutingTable listenPort uint16
routingTable RoutingTable
logger log.Logger
queueWorkOutbound chan *OutboundWorkQueueElement
} }
func (device *Device) SetPrivateKey(sk NoisePrivateKey) { func (device *Device) SetPrivateKey(sk NoisePrivateKey) {

View File

@ -2,11 +2,20 @@ package main
import ( import (
"crypto/cipher" "crypto/cipher"
"sync"
) )
type KeyPair struct { type KeyPair struct {
recv cipher.AEAD recv cipher.AEAD
recvNonce NoiseNonce recvNonce uint64
send cipher.AEAD send cipher.AEAD
sendNonce NoiseNonce sendNonce uint64
}
type KeyPairs struct {
mutex sync.RWMutex
current *KeyPair
previous *KeyPair
next *KeyPair
newKeyPair chan bool
} }

View File

@ -1,6 +1,8 @@
package main package main
import "fmt" import (
"fmt"
)
func main() { func main() {
fd, err := CreateTUN("test0") fd, err := CreateTUN("test0")
@ -8,9 +10,9 @@ func main() {
queue := make(chan []byte, 1000) queue := make(chan []byte, 1000)
var device Device // var device Device
go OutgoingRoutingWorker(&device, queue) // go OutgoingRoutingWorker(&device, queue)
for { for {
tmp := make([]byte, 1<<16) tmp := make([]byte, 1<<16)

View File

@ -9,9 +9,9 @@ import (
) )
const ( const (
HandshakeReset = iota HandshakeZeroed = iota
HandshakeInitialCreated HandshakeInitiationCreated
HandshakeInitialConsumed HandshakeInitiationConsumed
HandshakeResponseCreated HandshakeResponseCreated
HandshakeResponseConsumed HandshakeResponseConsumed
) )
@ -24,13 +24,19 @@ const (
) )
const ( const (
MessageInitalType = 1 MessageInitiationType = 1
MessageResponseType = 2 MessageResponseType = 2
MessageCookieResponseType = 3 MessageCookieResponseType = 3
MessageTransportType = 4 MessageTransportType = 4
) )
type MessageInital struct { /* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder
* we can treat these as a 32-bit int
*
*/
type MessageInitiation struct {
Type uint32 Type uint32
Sender uint32 Sender uint32
Ephemeral NoisePublicKey Ephemeral NoisePublicKey
@ -73,9 +79,9 @@ type Handshake struct {
} }
var ( var (
ZeroNonce [chacha20poly1305.NonceSize]byte
InitalChainKey [blake2s.Size]byte InitalChainKey [blake2s.Size]byte
InitalHash [blake2s.Size]byte InitalHash [blake2s.Size]byte
ZeroNonce [chacha20poly1305.NonceSize]byte
) )
func init() { func init() {
@ -83,23 +89,23 @@ func init() {
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...)) InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
} }
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return KDF1(c[:], data) return KDF1(c[:], data)
} }
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return blake2s.Sum256(append(h[:], data...)) return blake2s.Sum256(append(h[:], data...))
} }
func (h *Handshake) addToHash(data []byte) { func (h *Handshake) mixHash(data []byte) {
h.hash = addToHash(h.hash, data) h.hash = mixHash(h.hash, data)
} }
func (h *Handshake) addToChainKey(data []byte) { func (h *Handshake) mixKey(data []byte) {
h.chainKey = addToChainKey(h.chainKey, data) h.chainKey = mixKey(h.chainKey, data)
} }
func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
@ -108,7 +114,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
var err error var err error
handshake.chainKey = InitalChainKey handshake.chainKey = InitalChainKey
handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:]) handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:])
handshake.localEphemeral, err = newPrivateKey() handshake.localEphemeral, err = newPrivateKey()
if err != nil { if err != nil {
return nil, err return nil, err
@ -116,9 +122,9 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
// assign index // assign index
var msg MessageInital var msg MessageInitiation
msg.Type = MessageInitalType msg.Type = MessageInitiationType
msg.Ephemeral = handshake.localEphemeral.publicKey() msg.Ephemeral = handshake.localEphemeral.publicKey()
handshake.localIndex, err = device.indices.NewIndex(peer) handshake.localIndex, err = device.indices.NewIndex(peer)
@ -127,10 +133,10 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
} }
msg.Sender = handshake.localIndex msg.Sender = handshake.localIndex
handshake.addToChainKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.addToHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
// encrypt identity key // encrypt static key
func() { func() {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
@ -139,7 +145,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
}() }()
handshake.addToHash(msg.Static[:]) handshake.mixHash(msg.Static[:])
// encrypt timestamp // encrypt timestamp
@ -154,22 +160,22 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
}() }()
handshake.addToHash(msg.Timestamp[:]) handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitialCreated handshake.state = HandshakeInitiationCreated
return &msg, nil return &msg, nil
} }
func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if msg.Type != MessageInitalType { if msg.Type != MessageInitiationType {
panic(errors.New("bug: invalid inital message type")) return nil
} }
hash := addToHash(InitalHash, device.publicKey[:]) hash := mixHash(InitalHash, device.publicKey[:])
hash = addToHash(hash, msg.Ephemeral[:]) hash = mixHash(hash, msg.Ephemeral[:])
chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:]) chainKey := mixKey(InitalChainKey, msg.Ephemeral[:])
// decrypt identity key // decrypt static key
var err error var err error
var peerPK NoisePublicKey var peerPK NoisePublicKey
@ -183,7 +189,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
if err != nil { if err != nil {
return nil return nil
} }
hash = addToHash(hash, msg.Static[:]) hash = mixHash(hash, msg.Static[:])
// find peer // find peer
@ -210,7 +216,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
if err != nil { if err != nil {
return nil return nil
} }
hash = addToHash(hash, msg.Timestamp[:]) hash = mixHash(hash, msg.Timestamp[:])
// check for replay attack // check for replay attack
@ -218,7 +224,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
return nil return nil
} }
// check for flood attack // TODO: check for flood attack
// update handshake state // update handshake state
@ -227,7 +233,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral handshake.remoteEphemeral = msg.Ephemeral
handshake.lastTimestamp = timestamp handshake.lastTimestamp = timestamp
handshake.state = HandshakeInitialConsumed handshake.state = HandshakeInitiationConsumed
return peer return peer
} }
@ -236,8 +242,8 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitialConsumed { if handshake.state != HandshakeInitiationConsumed {
panic(errors.New("bug: handshake initation must be consumed first")) return nil, errors.New("handshake initation must be consumed first")
} }
// assign index // assign index
@ -260,13 +266,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
return nil, err return nil, err
} }
msg.Ephemeral = handshake.localEphemeral.publicKey() msg.Ephemeral = handshake.localEphemeral.publicKey()
handshake.addToHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
func() { func() {
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
handshake.addToChainKey(ss[:]) handshake.mixKey(ss[:])
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
handshake.addToChainKey(ss[:]) handshake.mixKey(ss[:])
}() }()
// add preshared key (psk) // add preshared key (psk)
@ -274,12 +280,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
var tau [blake2s.Size]byte var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:]) handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
handshake.addToHash(tau[:]) handshake.mixHash(tau[:])
func() { func() {
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
handshake.addToHash(msg.Empty[:]) handshake.mixHash(msg.Empty[:])
}() }()
handshake.state = HandshakeResponseCreated handshake.state = HandshakeResponseCreated
@ -288,7 +294,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
if msg.Type != MessageResponseType { if msg.Type != MessageResponseType {
panic(errors.New("bug: invalid message type")) return nil
} }
// lookup handshake by reciever // lookup handshake by reciever
@ -300,20 +306,20 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitialCreated { if handshake.state != HandshakeInitiationCreated {
return nil return nil
} }
// finish 3-way DH // finish 3-way DH
hash := addToHash(handshake.hash, msg.Ephemeral[:]) hash := mixHash(handshake.hash, msg.Ephemeral[:])
chainKey := handshake.chainKey chainKey := handshake.chainKey
func() { func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
chainKey = addToChainKey(chainKey, ss[:]) chainKey = mixKey(chainKey, ss[:])
ss = device.privateKey.sharedSecret(msg.Ephemeral) ss = device.privateKey.sharedSecret(msg.Ephemeral)
chainKey = addToChainKey(chainKey, ss[:]) chainKey = mixKey(chainKey, ss[:])
}() }()
// add preshared key (psk) // add preshared key (psk)
@ -321,7 +327,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
var tau [blake2s.Size]byte var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
hash = addToHash(hash, tau[:]) hash = mixHash(hash, tau[:])
// authenticate // authenticate
@ -330,7 +336,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
if err != nil { if err != nil {
return nil return nil
} }
hash = addToHash(hash, msg.Empty[:]) hash = mixHash(hash, msg.Empty[:])
// update handshake state // update handshake state
@ -368,7 +374,11 @@ func (peer *Peer) NewKeyPair() *KeyPair {
keyPair.sendNonce = 0 keyPair.sendNonce = 0
keyPair.recvNonce = 0 keyPair.recvNonce = 0
peer.handshake.state = HandshakeReset // zero handshake
handshake.chainKey = [blake2s.Size]byte{}
handshake.localEphemeral = NoisePrivateKey{}
peer.handshake.state = HandshakeZeroed
return &keyPair return &keyPair
} }

View File

@ -67,13 +67,13 @@ func TestNoiseHandshake(t *testing.T) {
t.Log("exchange initiation message") t.Log("exchange initiation message")
msg1, err := dev1.CreateMessageInitial(peer2) msg1, err := dev1.CreateMessageInitiation(peer2)
assertNil(t, err) assertNil(t, err)
packet := make([]byte, 0, 256) packet := make([]byte, 0, 256)
writer := bytes.NewBuffer(packet) writer := bytes.NewBuffer(packet)
err = binary.Write(writer, binary.LittleEndian, msg1) err = binary.Write(writer, binary.LittleEndian, msg1)
peer := dev2.ConsumeMessageInitial(msg1) peer := dev2.ConsumeMessageInitiation(msg1)
if peer == nil { if peer == nil {
t.Fatal("handshake failed at initiation message") t.Fatal("handshake failed at initiation message")
} }

View File

@ -1,39 +1,64 @@
package main package main
import ( import (
"errors"
"golang.org/x/crypto/blake2s"
"net" "net"
"sync" "sync"
"time" "time"
) )
const (
OutboundQueueSize = 64
)
type Peer struct { type Peer struct {
mutex sync.RWMutex mutex sync.RWMutex
endpointIP net.IP // endpointIP net.IP //
endpointPort uint16 // endpointPort uint16 //
persistentKeepaliveInterval time.Duration // 0 = disabled persistentKeepaliveInterval time.Duration // 0 = disabled
keyPairs KeyPairs
handshake Handshake handshake Handshake
device *Device device *Device
macKey [blake2s.Size]byte // Hash(Label-Mac1 || publicKey)
cookie []byte // cookie
cookieExpire time.Time
queueInbound chan []byte
queueOutbound chan *OutboundWorkQueueElement
queueOutboundRouting chan []byte
} }
func (device *Device) NewPeer(pk NoisePublicKey) *Peer { func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
var peer Peer var peer Peer
// map public key // create peer
device.mutex.Lock()
device.peers[pk] = &peer
device.mutex.Unlock()
// precompute
peer.mutex.Lock() peer.mutex.Lock()
peer.device = device peer.device = device
func(h *Handshake) { peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
h.mutex.Lock()
h.remoteStatic = pk // map public key
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
h.mutex.Unlock() device.mutex.Lock()
}(&peer.handshake) _, ok := device.peers[pk]
if ok {
panic(errors.New("bug: adding existing peer"))
}
device.peers[pk] = &peer
device.mutex.Unlock()
// precompute DH
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.remoteStatic = pk
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
// compute mac key
peer.macKey = blake2s.Sum256(append([]byte(WGLabelMAC1[:]), handshake.remoteStatic[:]...))
handshake.mutex.Unlock()
peer.mutex.Unlock() peer.mutex.Unlock()
return &peer return &peer

View File

@ -2,7 +2,6 @@ package main
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
) )
@ -52,25 +51,3 @@ func (table *RoutingTable) LookupIPv6(address []byte) *Peer {
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
return table.IPv6.Lookup(address) return table.IPv6.Lookup(address)
} }
func OutgoingRoutingWorker(device *Device, queue chan []byte) {
for {
packet := <-queue
switch packet[0] >> 4 {
case IPv4version:
dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer := device.routingTable.LookupIPv4(dst)
fmt.Println("IPv4", peer)
case IPv6version:
dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer := device.routingTable.LookupIPv6(dst)
fmt.Println("IPv6", peer)
default:
// todo: log
fmt.Println("Unknown IP version")
}
}
}

154
src/send.go Normal file
View File

@ -0,0 +1,154 @@
package main
import (
"net"
"sync"
"sync/atomic"
)
/* Handles outbound flow
*
* 1. TUN queue
* 2. Routing
* 3. Per peer queuing
* 4. (work queuing)
*
*/
type OutboundWorkQueueElement struct {
wg sync.WaitGroup
packet []byte
nonce uint64
keyPair *KeyPair
}
func (device *Device) SendPacket(packet []byte) {
// lookup peer
var peer *Peer
switch packet[0] >> 4 {
case IPv4version:
dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst)
case IPv6version:
dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst)
default:
device.logger.Println("unknown IP version")
return
}
if peer == nil {
return
}
// insert into peer queue
for {
select {
case peer.queueOutboundRouting <- packet:
default:
select {
case <-peer.queueOutboundRouting:
default:
}
continue
}
break
}
}
/* Go routine
*
*
* 1. waits for handshake.
* 2. assigns key pair & nonce
* 3. inserts to working queue
*
* TODO: avoid dynamic allocation of work queue elements
*/
func (peer *Peer) ConsumeOutboundPackets() {
for {
// wait for key pair
keyPair := func() *KeyPair {
peer.keyPairs.mutex.RLock()
defer peer.keyPairs.mutex.RUnlock()
return peer.keyPairs.current
}()
if keyPair == nil {
if len(peer.queueOutboundRouting) > 0 {
// TODO: start handshake
<-peer.keyPairs.newKeyPair
}
continue
}
// assign packets key pair
for {
select {
case <-peer.keyPairs.newKeyPair:
default:
case <-peer.keyPairs.newKeyPair:
case packet := <-peer.queueOutboundRouting:
// create new work element
work := new(OutboundWorkQueueElement)
work.wg.Add(1)
work.keyPair = keyPair
work.packet = packet
work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
peer.queueOutbound <- work
// drop packets until there is room
for {
select {
case peer.device.queueWorkOutbound <- work:
break
default:
drop := <-peer.device.queueWorkOutbound
drop.packet = nil
drop.wg.Done()
}
}
}
}
}
}
func (peer *Peer) RoutineSequential() {
for work := range peer.queueOutbound {
work.wg.Wait()
if work.packet == nil {
continue
}
}
}
func (device *Device) EncryptionWorker() {
for {
work := <-device.queueWorkOutbound
func() {
defer work.wg.Done()
// pad packet
padding := device.mtu - len(work.packet)
if padding < 0 {
work.packet = nil
return
}
for n := 0; n < padding; n += 1 {
work.packet = append(work.packet, 0) // TODO: gotta be a faster way
}
//
}()
}
}