mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
Begin work on outbound packet flow
This commit is contained in:
parent
cf3a5130d3
commit
9d806d3853
39
src/cookie.go
Normal file
39
src/cookie.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"golang.org/x/crypto/blake2s"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CalculateCookie(peer *Peer, msg []byte) {
|
||||||
|
size := len(msg)
|
||||||
|
|
||||||
|
if size < blake2s.Size128*2 {
|
||||||
|
panic(errors.New("bug: message too short"))
|
||||||
|
}
|
||||||
|
|
||||||
|
startMac1 := size - (blake2s.Size128 * 2)
|
||||||
|
startMac2 := size - blake2s.Size128
|
||||||
|
|
||||||
|
mac1 := msg[startMac1 : startMac1+blake2s.Size128]
|
||||||
|
mac2 := msg[startMac2 : startMac2+blake2s.Size128]
|
||||||
|
|
||||||
|
peer.mutex.RLock()
|
||||||
|
defer peer.mutex.RUnlock()
|
||||||
|
|
||||||
|
// set mac1
|
||||||
|
|
||||||
|
func() {
|
||||||
|
mac, _ := blake2s.New128(peer.macKey[:])
|
||||||
|
mac.Write(msg[:startMac1])
|
||||||
|
mac.Sum(mac1[:0])
|
||||||
|
}()
|
||||||
|
|
||||||
|
// set mac2
|
||||||
|
|
||||||
|
if peer.cookie != nil {
|
||||||
|
mac, _ := blake2s.New128(peer.cookie)
|
||||||
|
mac.Write(msg[:startMac2])
|
||||||
|
mac.Sum(mac2[:0])
|
||||||
|
}
|
||||||
|
}
|
@ -1,18 +1,22 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
mutex sync.RWMutex
|
mtu int
|
||||||
peers map[NoisePublicKey]*Peer
|
mutex sync.RWMutex
|
||||||
indices IndexTable
|
peers map[NoisePublicKey]*Peer
|
||||||
privateKey NoisePrivateKey
|
indices IndexTable
|
||||||
publicKey NoisePublicKey
|
privateKey NoisePrivateKey
|
||||||
fwMark uint32
|
publicKey NoisePublicKey
|
||||||
listenPort uint16
|
fwMark uint32
|
||||||
routingTable RoutingTable
|
listenPort uint16
|
||||||
|
routingTable RoutingTable
|
||||||
|
logger log.Logger
|
||||||
|
queueWorkOutbound chan *OutboundWorkQueueElement
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
|
func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
|
||||||
|
@ -2,11 +2,20 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type KeyPair struct {
|
type KeyPair struct {
|
||||||
recv cipher.AEAD
|
recv cipher.AEAD
|
||||||
recvNonce NoiseNonce
|
recvNonce uint64
|
||||||
send cipher.AEAD
|
send cipher.AEAD
|
||||||
sendNonce NoiseNonce
|
sendNonce uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyPairs struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
current *KeyPair
|
||||||
|
previous *KeyPair
|
||||||
|
next *KeyPair
|
||||||
|
newKeyPair chan bool
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
fd, err := CreateTUN("test0")
|
fd, err := CreateTUN("test0")
|
||||||
@ -8,9 +10,9 @@ func main() {
|
|||||||
|
|
||||||
queue := make(chan []byte, 1000)
|
queue := make(chan []byte, 1000)
|
||||||
|
|
||||||
var device Device
|
// var device Device
|
||||||
|
|
||||||
go OutgoingRoutingWorker(&device, queue)
|
// go OutgoingRoutingWorker(&device, queue)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
tmp := make([]byte, 1<<16)
|
tmp := make([]byte, 1<<16)
|
||||||
|
@ -9,9 +9,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HandshakeReset = iota
|
HandshakeZeroed = iota
|
||||||
HandshakeInitialCreated
|
HandshakeInitiationCreated
|
||||||
HandshakeInitialConsumed
|
HandshakeInitiationConsumed
|
||||||
HandshakeResponseCreated
|
HandshakeResponseCreated
|
||||||
HandshakeResponseConsumed
|
HandshakeResponseConsumed
|
||||||
)
|
)
|
||||||
@ -24,13 +24,19 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MessageInitalType = 1
|
MessageInitiationType = 1
|
||||||
MessageResponseType = 2
|
MessageResponseType = 2
|
||||||
MessageCookieResponseType = 3
|
MessageCookieResponseType = 3
|
||||||
MessageTransportType = 4
|
MessageTransportType = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
type MessageInital struct {
|
/* Type is an 8-bit field, followed by 3 nul bytes,
|
||||||
|
* by marshalling the messages in little-endian byteorder
|
||||||
|
* we can treat these as a 32-bit int
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
type MessageInitiation struct {
|
||||||
Type uint32
|
Type uint32
|
||||||
Sender uint32
|
Sender uint32
|
||||||
Ephemeral NoisePublicKey
|
Ephemeral NoisePublicKey
|
||||||
@ -73,9 +79,9 @@ type Handshake struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ZeroNonce [chacha20poly1305.NonceSize]byte
|
|
||||||
InitalChainKey [blake2s.Size]byte
|
InitalChainKey [blake2s.Size]byte
|
||||||
InitalHash [blake2s.Size]byte
|
InitalHash [blake2s.Size]byte
|
||||||
|
ZeroNonce [chacha20poly1305.NonceSize]byte
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@ -83,23 +89,23 @@ func init() {
|
|||||||
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
|
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
|
||||||
}
|
}
|
||||||
|
|
||||||
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
||||||
return KDF1(c[:], data)
|
return KDF1(c[:], data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
||||||
return blake2s.Sum256(append(h[:], data...))
|
return blake2s.Sum256(append(h[:], data...))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshake) addToHash(data []byte) {
|
func (h *Handshake) mixHash(data []byte) {
|
||||||
h.hash = addToHash(h.hash, data)
|
h.hash = mixHash(h.hash, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshake) addToChainKey(data []byte) {
|
func (h *Handshake) mixKey(data []byte) {
|
||||||
h.chainKey = addToChainKey(h.chainKey, data)
|
h.chainKey = mixKey(h.chainKey, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
defer handshake.mutex.Unlock()
|
defer handshake.mutex.Unlock()
|
||||||
@ -108,7 +114,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
|||||||
|
|
||||||
var err error
|
var err error
|
||||||
handshake.chainKey = InitalChainKey
|
handshake.chainKey = InitalChainKey
|
||||||
handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:])
|
handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:])
|
||||||
handshake.localEphemeral, err = newPrivateKey()
|
handshake.localEphemeral, err = newPrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -116,9 +122,9 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
|||||||
|
|
||||||
// assign index
|
// assign index
|
||||||
|
|
||||||
var msg MessageInital
|
var msg MessageInitiation
|
||||||
|
|
||||||
msg.Type = MessageInitalType
|
msg.Type = MessageInitiationType
|
||||||
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
||||||
handshake.localIndex, err = device.indices.NewIndex(peer)
|
handshake.localIndex, err = device.indices.NewIndex(peer)
|
||||||
|
|
||||||
@ -127,10 +133,10 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
msg.Sender = handshake.localIndex
|
msg.Sender = handshake.localIndex
|
||||||
handshake.addToChainKey(msg.Ephemeral[:])
|
handshake.mixKey(msg.Ephemeral[:])
|
||||||
handshake.addToHash(msg.Ephemeral[:])
|
handshake.mixHash(msg.Ephemeral[:])
|
||||||
|
|
||||||
// encrypt identity key
|
// encrypt static key
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
@ -139,7 +145,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
|||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
|
aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
|
||||||
}()
|
}()
|
||||||
handshake.addToHash(msg.Static[:])
|
handshake.mixHash(msg.Static[:])
|
||||||
|
|
||||||
// encrypt timestamp
|
// encrypt timestamp
|
||||||
|
|
||||||
@ -154,22 +160,22 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
|||||||
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
|
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
handshake.addToHash(msg.Timestamp[:])
|
handshake.mixHash(msg.Timestamp[:])
|
||||||
handshake.state = HandshakeInitialCreated
|
handshake.state = HandshakeInitiationCreated
|
||||||
|
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
|
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
if msg.Type != MessageInitalType {
|
if msg.Type != MessageInitiationType {
|
||||||
panic(errors.New("bug: invalid inital message type"))
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
hash := addToHash(InitalHash, device.publicKey[:])
|
hash := mixHash(InitalHash, device.publicKey[:])
|
||||||
hash = addToHash(hash, msg.Ephemeral[:])
|
hash = mixHash(hash, msg.Ephemeral[:])
|
||||||
chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
|
chainKey := mixKey(InitalChainKey, msg.Ephemeral[:])
|
||||||
|
|
||||||
// decrypt identity key
|
// decrypt static key
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var peerPK NoisePublicKey
|
var peerPK NoisePublicKey
|
||||||
@ -183,7 +189,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
hash = addToHash(hash, msg.Static[:])
|
hash = mixHash(hash, msg.Static[:])
|
||||||
|
|
||||||
// find peer
|
// find peer
|
||||||
|
|
||||||
@ -210,7 +216,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
hash = addToHash(hash, msg.Timestamp[:])
|
hash = mixHash(hash, msg.Timestamp[:])
|
||||||
|
|
||||||
// check for replay attack
|
// check for replay attack
|
||||||
|
|
||||||
@ -218,7 +224,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// check for flood attack
|
// TODO: check for flood attack
|
||||||
|
|
||||||
// update handshake state
|
// update handshake state
|
||||||
|
|
||||||
@ -227,7 +233,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
|
|||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.remoteEphemeral = msg.Ephemeral
|
handshake.remoteEphemeral = msg.Ephemeral
|
||||||
handshake.lastTimestamp = timestamp
|
handshake.lastTimestamp = timestamp
|
||||||
handshake.state = HandshakeInitialConsumed
|
handshake.state = HandshakeInitiationConsumed
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,8 +242,8 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
defer handshake.mutex.Unlock()
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
if handshake.state != HandshakeInitialConsumed {
|
if handshake.state != HandshakeInitiationConsumed {
|
||||||
panic(errors.New("bug: handshake initation must be consumed first"))
|
return nil, errors.New("handshake initation must be consumed first")
|
||||||
}
|
}
|
||||||
|
|
||||||
// assign index
|
// assign index
|
||||||
@ -260,13 +266,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
||||||
handshake.addToHash(msg.Ephemeral[:])
|
handshake.mixHash(msg.Ephemeral[:])
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
||||||
handshake.addToChainKey(ss[:])
|
handshake.mixKey(ss[:])
|
||||||
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
handshake.addToChainKey(ss[:])
|
handshake.mixKey(ss[:])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// add preshared key (psk)
|
// add preshared key (psk)
|
||||||
@ -274,12 +280,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
var tau [blake2s.Size]byte
|
var tau [blake2s.Size]byte
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
|
handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
|
||||||
handshake.addToHash(tau[:])
|
handshake.mixHash(tau[:])
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
||||||
handshake.addToHash(msg.Empty[:])
|
handshake.mixHash(msg.Empty[:])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
handshake.state = HandshakeResponseCreated
|
handshake.state = HandshakeResponseCreated
|
||||||
@ -288,7 +294,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
|
|
||||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
if msg.Type != MessageResponseType {
|
if msg.Type != MessageResponseType {
|
||||||
panic(errors.New("bug: invalid message type"))
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookup handshake by reciever
|
// lookup handshake by reciever
|
||||||
@ -300,20 +306,20 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
defer handshake.mutex.Unlock()
|
defer handshake.mutex.Unlock()
|
||||||
if handshake.state != HandshakeInitialCreated {
|
if handshake.state != HandshakeInitiationCreated {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// finish 3-way DH
|
// finish 3-way DH
|
||||||
|
|
||||||
hash := addToHash(handshake.hash, msg.Ephemeral[:])
|
hash := mixHash(handshake.hash, msg.Ephemeral[:])
|
||||||
chainKey := handshake.chainKey
|
chainKey := handshake.chainKey
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
||||||
chainKey = addToChainKey(chainKey, ss[:])
|
chainKey = mixKey(chainKey, ss[:])
|
||||||
ss = device.privateKey.sharedSecret(msg.Ephemeral)
|
ss = device.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
chainKey = addToChainKey(chainKey, ss[:])
|
chainKey = mixKey(chainKey, ss[:])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// add preshared key (psk)
|
// add preshared key (psk)
|
||||||
@ -321,7 +327,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
var tau [blake2s.Size]byte
|
var tau [blake2s.Size]byte
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
|
chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
|
||||||
hash = addToHash(hash, tau[:])
|
hash = mixHash(hash, tau[:])
|
||||||
|
|
||||||
// authenticate
|
// authenticate
|
||||||
|
|
||||||
@ -330,7 +336,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
hash = addToHash(hash, msg.Empty[:])
|
hash = mixHash(hash, msg.Empty[:])
|
||||||
|
|
||||||
// update handshake state
|
// update handshake state
|
||||||
|
|
||||||
@ -368,7 +374,11 @@ func (peer *Peer) NewKeyPair() *KeyPair {
|
|||||||
keyPair.sendNonce = 0
|
keyPair.sendNonce = 0
|
||||||
keyPair.recvNonce = 0
|
keyPair.recvNonce = 0
|
||||||
|
|
||||||
peer.handshake.state = HandshakeReset
|
// zero handshake
|
||||||
|
|
||||||
|
handshake.chainKey = [blake2s.Size]byte{}
|
||||||
|
handshake.localEphemeral = NoisePrivateKey{}
|
||||||
|
peer.handshake.state = HandshakeZeroed
|
||||||
|
|
||||||
return &keyPair
|
return &keyPair
|
||||||
}
|
}
|
||||||
|
@ -67,13 +67,13 @@ func TestNoiseHandshake(t *testing.T) {
|
|||||||
|
|
||||||
t.Log("exchange initiation message")
|
t.Log("exchange initiation message")
|
||||||
|
|
||||||
msg1, err := dev1.CreateMessageInitial(peer2)
|
msg1, err := dev1.CreateMessageInitiation(peer2)
|
||||||
assertNil(t, err)
|
assertNil(t, err)
|
||||||
|
|
||||||
packet := make([]byte, 0, 256)
|
packet := make([]byte, 0, 256)
|
||||||
writer := bytes.NewBuffer(packet)
|
writer := bytes.NewBuffer(packet)
|
||||||
err = binary.Write(writer, binary.LittleEndian, msg1)
|
err = binary.Write(writer, binary.LittleEndian, msg1)
|
||||||
peer := dev2.ConsumeMessageInitial(msg1)
|
peer := dev2.ConsumeMessageInitiation(msg1)
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
t.Fatal("handshake failed at initiation message")
|
t.Fatal("handshake failed at initiation message")
|
||||||
}
|
}
|
||||||
|
51
src/peer.go
51
src/peer.go
@ -1,39 +1,64 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"golang.org/x/crypto/blake2s"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
OutboundQueueSize = 64
|
||||||
|
)
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
endpointIP net.IP //
|
endpointIP net.IP //
|
||||||
endpointPort uint16 //
|
endpointPort uint16 //
|
||||||
persistentKeepaliveInterval time.Duration // 0 = disabled
|
persistentKeepaliveInterval time.Duration // 0 = disabled
|
||||||
|
keyPairs KeyPairs
|
||||||
handshake Handshake
|
handshake Handshake
|
||||||
device *Device
|
device *Device
|
||||||
|
macKey [blake2s.Size]byte // Hash(Label-Mac1 || publicKey)
|
||||||
|
cookie []byte // cookie
|
||||||
|
cookieExpire time.Time
|
||||||
|
queueInbound chan []byte
|
||||||
|
queueOutbound chan *OutboundWorkQueueElement
|
||||||
|
queueOutboundRouting chan []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
|
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
|
||||||
var peer Peer
|
var peer Peer
|
||||||
|
|
||||||
// map public key
|
// create peer
|
||||||
|
|
||||||
device.mutex.Lock()
|
|
||||||
device.peers[pk] = &peer
|
|
||||||
device.mutex.Unlock()
|
|
||||||
|
|
||||||
// precompute
|
|
||||||
|
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
peer.device = device
|
peer.device = device
|
||||||
func(h *Handshake) {
|
peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
|
||||||
h.mutex.Lock()
|
|
||||||
h.remoteStatic = pk
|
// map public key
|
||||||
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
|
|
||||||
h.mutex.Unlock()
|
device.mutex.Lock()
|
||||||
}(&peer.handshake)
|
_, ok := device.peers[pk]
|
||||||
|
if ok {
|
||||||
|
panic(errors.New("bug: adding existing peer"))
|
||||||
|
}
|
||||||
|
device.peers[pk] = &peer
|
||||||
|
device.mutex.Unlock()
|
||||||
|
|
||||||
|
// precompute DH
|
||||||
|
|
||||||
|
handshake := &peer.handshake
|
||||||
|
handshake.mutex.Lock()
|
||||||
|
handshake.remoteStatic = pk
|
||||||
|
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
|
||||||
|
|
||||||
|
// compute mac key
|
||||||
|
|
||||||
|
peer.macKey = blake2s.Sum256(append([]byte(WGLabelMAC1[:]), handshake.remoteStatic[:]...))
|
||||||
|
|
||||||
|
handshake.mutex.Unlock()
|
||||||
peer.mutex.Unlock()
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
return &peer
|
return &peer
|
||||||
|
@ -2,7 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
@ -52,25 +51,3 @@ func (table *RoutingTable) LookupIPv6(address []byte) *Peer {
|
|||||||
defer table.mutex.RUnlock()
|
defer table.mutex.RUnlock()
|
||||||
return table.IPv6.Lookup(address)
|
return table.IPv6.Lookup(address)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OutgoingRoutingWorker(device *Device, queue chan []byte) {
|
|
||||||
for {
|
|
||||||
packet := <-queue
|
|
||||||
switch packet[0] >> 4 {
|
|
||||||
|
|
||||||
case IPv4version:
|
|
||||||
dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
|
||||||
peer := device.routingTable.LookupIPv4(dst)
|
|
||||||
fmt.Println("IPv4", peer)
|
|
||||||
|
|
||||||
case IPv6version:
|
|
||||||
dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
|
||||||
peer := device.routingTable.LookupIPv6(dst)
|
|
||||||
fmt.Println("IPv6", peer)
|
|
||||||
|
|
||||||
default:
|
|
||||||
// todo: log
|
|
||||||
fmt.Println("Unknown IP version")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
154
src/send.go
Normal file
154
src/send.go
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Handles outbound flow
|
||||||
|
*
|
||||||
|
* 1. TUN queue
|
||||||
|
* 2. Routing
|
||||||
|
* 3. Per peer queuing
|
||||||
|
* 4. (work queuing)
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
type OutboundWorkQueueElement struct {
|
||||||
|
wg sync.WaitGroup
|
||||||
|
packet []byte
|
||||||
|
nonce uint64
|
||||||
|
keyPair *KeyPair
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) SendPacket(packet []byte) {
|
||||||
|
|
||||||
|
// lookup peer
|
||||||
|
|
||||||
|
var peer *Peer
|
||||||
|
switch packet[0] >> 4 {
|
||||||
|
case IPv4version:
|
||||||
|
dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||||
|
peer = device.routingTable.LookupIPv4(dst)
|
||||||
|
|
||||||
|
case IPv6version:
|
||||||
|
dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||||
|
peer = device.routingTable.LookupIPv6(dst)
|
||||||
|
|
||||||
|
default:
|
||||||
|
device.logger.Println("unknown IP version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert into peer queue
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case peer.queueOutboundRouting <- packet:
|
||||||
|
default:
|
||||||
|
select {
|
||||||
|
case <-peer.queueOutboundRouting:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Go routine
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* 1. waits for handshake.
|
||||||
|
* 2. assigns key pair & nonce
|
||||||
|
* 3. inserts to working queue
|
||||||
|
*
|
||||||
|
* TODO: avoid dynamic allocation of work queue elements
|
||||||
|
*/
|
||||||
|
func (peer *Peer) ConsumeOutboundPackets() {
|
||||||
|
for {
|
||||||
|
// wait for key pair
|
||||||
|
keyPair := func() *KeyPair {
|
||||||
|
peer.keyPairs.mutex.RLock()
|
||||||
|
defer peer.keyPairs.mutex.RUnlock()
|
||||||
|
return peer.keyPairs.current
|
||||||
|
}()
|
||||||
|
if keyPair == nil {
|
||||||
|
if len(peer.queueOutboundRouting) > 0 {
|
||||||
|
// TODO: start handshake
|
||||||
|
<-peer.keyPairs.newKeyPair
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// assign packets key pair
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-peer.keyPairs.newKeyPair:
|
||||||
|
default:
|
||||||
|
case <-peer.keyPairs.newKeyPair:
|
||||||
|
case packet := <-peer.queueOutboundRouting:
|
||||||
|
|
||||||
|
// create new work element
|
||||||
|
|
||||||
|
work := new(OutboundWorkQueueElement)
|
||||||
|
work.wg.Add(1)
|
||||||
|
work.keyPair = keyPair
|
||||||
|
work.packet = packet
|
||||||
|
work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
|
||||||
|
|
||||||
|
peer.queueOutbound <- work
|
||||||
|
|
||||||
|
// drop packets until there is room
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case peer.device.queueWorkOutbound <- work:
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
drop := <-peer.device.queueWorkOutbound
|
||||||
|
drop.packet = nil
|
||||||
|
drop.wg.Done()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) RoutineSequential() {
|
||||||
|
for work := range peer.queueOutbound {
|
||||||
|
work.wg.Wait()
|
||||||
|
if work.packet == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) EncryptionWorker() {
|
||||||
|
for {
|
||||||
|
work := <-device.queueWorkOutbound
|
||||||
|
|
||||||
|
func() {
|
||||||
|
defer work.wg.Done()
|
||||||
|
|
||||||
|
// pad packet
|
||||||
|
padding := device.mtu - len(work.packet)
|
||||||
|
if padding < 0 {
|
||||||
|
work.packet = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for n := 0; n < padding; n += 1 {
|
||||||
|
work.packet = append(work.packet, 0) // TODO: gotta be a faster way
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user