1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00
wireguard-go/noise-protocol.go
2018-05-10 16:08:03 +02:00

588 lines
13 KiB
Go

/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2015-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
package main
import (
"./tai64n"
"errors"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
"sync"
"time"
)
// Handshake states, stored in Handshake.state.
const (
HandshakeZeroed = iota // no handshake in progress; set by Clear and NewKeypair
HandshakeInitiationCreated // set at the end of CreateMessageInitiation
HandshakeInitiationConsumed // set at the end of ConsumeMessageInitiation
HandshakeResponseCreated // set at the end of CreateMessageResponse
HandshakeResponseConsumed // set at the end of ConsumeMessageResponse
)
// Protocol identification strings and key-derivation labels.
const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" // hashed by init() to produce InitialChainKey
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" // mixed into InitialHash by init()
WGLabelMAC1 = "mac1----" // not used in this part of the file; presumably the label for the MAC1 key
WGLabelCookie = "cookie--" // not used in this part of the file; presumably the label for the cookie key
)
// Values of the Type field that opens every message.
const (
MessageInitiationType = 1 // MessageInitiation
MessageResponseType = 2 // MessageResponse
MessageCookieReplyType = 3 // presumably MessageCookieReply (not set anywhere in this part of the file)
MessageTransportType = 4 // presumably MessageTransport (not set anywhere in this part of the file)
)
// Message sizes in bytes.
const (
MessageInitiationSize = 148 // size of handshake initiation message
MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message
MessageTransportHeaderSize = 16 // size of data preceding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive
MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
)
// Offsets of the MessageTransport fields that follow the 4-byte Type field;
// they match the field sizes declared in the MessageTransport struct.
const (
MessageTransportOffsetReceiver = 4 // Receiver (uint32) follows Type
MessageTransportOffsetCounter = 8 // Counter (uint64) follows Receiver
MessageTransportOffsetContent = 16 // Content starts here; equals MessageTransportHeaderSize
)
/* Type is an 8-bit field followed by 3 nul bytes.
* Because the messages are marshalled in little-endian byte order,
* the four bytes can be treated as a single 32-bit unsigned int (for now).
*/

// MessageInitiation is the first handshake message, built by
// CreateMessageInitiation and processed by ConsumeMessageInitiation.
type MessageInitiation struct {
Type uint32 // MessageInitiationType
Sender uint32 // sender's local index for this handshake
Ephemeral NoisePublicKey // sender's ephemeral public key, sent in the clear
Static [NoisePublicKeySize + poly1305.TagSize]byte // AEAD-sealed static public key of the sender
Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte // AEAD-sealed tai64n timestamp
MAC1 [blake2s.Size128]byte // not filled in by this part of the file
MAC2 [blake2s.Size128]byte // not filled in by this part of the file
}
// MessageResponse is the second handshake message, built by
// CreateMessageResponse and processed by ConsumeMessageResponse.
type MessageResponse struct {
Type uint32 // MessageResponseType
Sender uint32 // responder's local index
Receiver uint32 // index taken from the Sender field of the initiation
Ephemeral NoisePublicKey // responder's ephemeral public key
Empty [poly1305.TagSize]byte // AEAD tag over an empty plaintext; authenticates the transcript hash
MAC1 [blake2s.Size128]byte // not filled in by this part of the file
MAC2 [blake2s.Size128]byte // not filled in by this part of the file
}
// MessageTransport describes a transport (data) message. Its fixed-size
// fields occupy the first 16 bytes, matching the MessageTransportOffset*
// constants.
type MessageTransport struct {
Type uint32
Receiver uint32
Counter uint64
Content []byte // payload following the fixed-size header
}
// MessageCookieReply describes a cookie reply message. It is not produced or
// consumed in this part of the file.
type MessageCookieReply struct {
Type uint32
Receiver uint32
Nonce [24]byte
Cookie [blake2s.Size128 + poly1305.TagSize]byte // sized for a 128-bit value plus a poly1305 tag
}
// Handshake holds the per-peer state of an in-progress Noise handshake.
// The Create*/Consume* functions and NewKeypair lock mutex around their
// accesses. NOTE(review): Clear does not take the mutex itself, so callers
// presumably have to hold it; confirm against the call sites.
type Handshake struct {
state int // one of the Handshake* state constants
mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
presharedKey NoiseSymmetricKey // psk
localEphemeral NoisePrivateKey // ephemeral secret key
localIndex uint32 // used to clear hash-table
remoteIndex uint32 // index for sending
remoteStatic NoisePublicKey // long term key
remoteEphemeral NoisePublicKey // ephemeral public key
precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
lastTimestamp tai64n.Timestamp // timestamp of the last accepted initiation (replay check)
lastInitiationConsumption time.Time // when the last initiation was accepted (flood check)
}
// Precomputed protocol constants. InitialChainKey and InitialHash are filled
// in by init(); ZeroNonce is left as the all-zero value and is the nonce
// passed to every handshake AEAD call in this file.
var (
InitialChainKey [blake2s.Size]byte
InitialHash [blake2s.Size]byte
ZeroNonce [chacha20poly1305.NonceSize]byte
)
// mixKey feeds the chain key c and the input data through KDF1 and stores
// the result in dst. dst and c may point at the same array, which is how
// Handshake.mixKey uses it.
func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
	chain := c[:]
	KDF1(dst, chain, data)
}
// mixHash stores BLAKE2s-256(h || data) in dst. dst and h may point at the
// same array: h is absorbed by the hasher before dst is written.
func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
	// The error is ignored because no key is supplied.
	digest, _ := blake2s.New256(nil)
	for _, part := range [][]byte{h[:], data} {
		digest.Write(part)
	}
	// dst[:0] has the full array as capacity, so Sum writes into dst in place.
	digest.Sum(dst[:0])
	digest.Reset()
}
// Clear wipes the ephemeral keys, chain key and hash, resets the local index
// and returns the handshake to the HandshakeZeroed state. It does not take
// h.mutex.
func (h *Handshake) Clear() {
	secrets := [...][]byte{
		h.localEphemeral[:],
		h.remoteEphemeral[:],
		h.chainKey[:],
		h.hash[:],
	}
	for _, secret := range secrets {
		setZero(secret)
	}
	h.localIndex, h.state = 0, HandshakeZeroed
}
// mixHash folds data into the handshake's running hash, in place.
func (h *Handshake) mixHash(data []byte) {
	running := &h.hash
	mixHash(running, running, data)
}
// mixKey folds data into the handshake's chain key, in place.
func (h *Handshake) mixKey(data []byte) {
	chain := &h.chainKey
	mixKey(chain, chain, data)
}
/* Do basic precomputations
*/
func init() {
InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
}
// CreateMessageInitiation starts a new handshake with peer and returns the
// initiation message to send.
//
// It resets the peer's handshake hash and chain key to their initial values,
// generates a fresh ephemeral key, replaces the handshake's entry in
// device.indices with a newly allocated index, and seals the local static
// public key and a tai64n timestamp into the message. On success the
// handshake is left in the HandshakeInitiationCreated state. MAC1 and MAC2
// are not filled in here.
//
// An error is returned when the precomputed static-static secret is all
// zero, or when ephemeral key generation or index allocation fails.
//
// Locks: device.noise.mutex (read) is taken before handshake.mutex (write);
// both are held until return.
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
	device.noise.mutex.RLock()
	defer device.noise.mutex.RUnlock()

	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	if isZero(handshake.precomputedStaticStatic[:]) {
		return nil, errors.New("Static shared secret is zero")
	}

	// create ephemeral key
	var err error
	handshake.hash = InitialHash
	handshake.chainKey = InitialChainKey
	handshake.localEphemeral, err = newPrivateKey()
	if err != nil {
		return nil, err
	}

	// assign index
	device.indices.Delete(handshake.localIndex)
	handshake.localIndex, err = device.indices.NewIndex(peer)
	if err != nil {
		return nil, err
	}

	handshake.mixHash(handshake.remoteStatic[:])

	msg := MessageInitiation{
		Type:      MessageInitiationType,
		Ephemeral: handshake.localEphemeral.publicKey(),
		Sender:    handshake.localIndex,
	}
	handshake.mixKey(msg.Ephemeral[:])
	handshake.mixHash(msg.Ephemeral[:])

	// encrypt static key
	func() {
		var key [chacha20poly1305.KeySize]byte
		ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
		KDF2(
			&handshake.chainKey,
			&key,
			handshake.chainKey[:],
			ss[:],
		)
		// New only rejects keys of the wrong length; key is exactly KeySize bytes.
		aead, _ := chacha20poly1305.New(key[:])
		aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:])

		// wipe secret scratch material, as ConsumeMessageResponse does
		setZero(ss[:])
		setZero(key[:])
	}()
	handshake.mixHash(msg.Static[:])

	// encrypt timestamp
	timestamp := tai64n.Now()
	func() {
		var key [chacha20poly1305.KeySize]byte
		KDF2(
			&handshake.chainKey,
			&key,
			handshake.chainKey[:],
			handshake.precomputedStaticStatic[:],
		)
		aead, _ := chacha20poly1305.New(key[:])
		aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
		setZero(key[:])
	}()
	handshake.mixHash(msg.Timestamp[:])

	handshake.state = HandshakeInitiationCreated
	return &msg, nil
}
// ConsumeMessageInitiation validates a received initiation message and, if
// it is acceptable, records it in the sending peer's handshake and returns
// that peer. It returns nil when the message is rejected:
//
//   - the Type field is not MessageInitiationType;
//   - the sealed static key or the sealed timestamp fails to authenticate;
//   - the decrypted static key does not belong to a known peer;
//   - the peer's precomputed static-static secret is all zero;
//   - the timestamp is not newer than the last accepted one (replay), or
//     the previous initiation was accepted no more than
//     HandshakeInitationRate ago (flood).
//
// On success the handshake is left in the HandshakeInitiationConsumed state.
// MAC1 and MAC2 are not checked here.
//
// Locks: device.noise.mutex (read) is taken first. handshake.mutex is then
// held exclusively from the first read of handshake state until the update
// is complete, so the replay/flood check and the update are one atomic step;
// two concurrent copies of the same initiation cannot both be accepted.
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
	var (
		hash     [blake2s.Size]byte
		chainKey [blake2s.Size]byte
	)

	if msg.Type != MessageInitiationType {
		return nil
	}

	// The local transcript copies are derived from secret material; wipe
	// them on every exit path. The handshake receives its own array copies.
	defer func() {
		setZero(hash[:])
		setZero(chainKey[:])
	}()

	device.noise.mutex.RLock()
	defer device.noise.mutex.RUnlock()

	mixHash(&hash, &InitialHash, device.noise.publicKey[:])
	mixHash(&hash, &hash, msg.Ephemeral[:])
	mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])

	// decrypt static key
	var err error
	var peerPK NoisePublicKey
	func() {
		var key [chacha20poly1305.KeySize]byte
		ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
		KDF2(&chainKey, &key, chainKey[:], ss[:])
		// New only rejects keys of the wrong length; key is exactly KeySize bytes.
		aead, _ := chacha20poly1305.New(key[:])
		_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
		setZero(ss[:])
		setZero(key[:])
	}()
	if err != nil {
		return nil
	}
	mixHash(&hash, &hash, msg.Static[:])

	// lookup peer
	peer := device.LookupPeer(peerPK)
	if peer == nil {
		return nil
	}

	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	if isZero(handshake.precomputedStaticStatic[:]) {
		return nil
	}

	// verify identity
	var timestamp tai64n.Timestamp
	var key [chacha20poly1305.KeySize]byte
	KDF2(
		&chainKey,
		&key,
		chainKey[:],
		handshake.precomputedStaticStatic[:],
	)
	aead, _ := chacha20poly1305.New(key[:])
	_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
	setZero(key[:])
	if err != nil {
		return nil
	}
	mixHash(&hash, &hash, msg.Timestamp[:])

	// protect against replay & flood
	if !timestamp.After(handshake.lastTimestamp) {
		return nil
	}
	if time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate {
		return nil
	}

	// update handshake state
	handshake.hash = hash
	handshake.chainKey = chainKey
	handshake.remoteIndex = msg.Sender
	handshake.remoteEphemeral = msg.Ephemeral
	handshake.lastTimestamp = timestamp
	handshake.lastInitiationConsumption = time.Now()
	handshake.state = HandshakeInitiationConsumed

	return peer
}
// CreateMessageResponse builds the response to an initiation that was
// previously accepted by ConsumeMessageInitiation for peer.
//
// It replaces the handshake's entry in device.indices with a newly allocated
// index, generates a fresh ephemeral key, mixes the two ephemeral-based DH
// results and the preshared key into the chain key, and seals an empty
// payload over the transcript hash. On success the handshake is left in the
// HandshakeResponseCreated state. MAC1 and MAC2 are not filled in here.
//
// An error is returned when the handshake is not in the
// HandshakeInitiationConsumed state, or when index allocation or ephemeral
// key generation fails.
//
// Locks: handshake.mutex is held exclusively until return.
func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	if handshake.state != HandshakeInitiationConsumed {
		return nil, errors.New("handshake initation must be consumed first")
	}

	// assign index
	var err error
	device.indices.Delete(handshake.localIndex)
	handshake.localIndex, err = device.indices.NewIndex(peer)
	if err != nil {
		return nil, err
	}

	var msg MessageResponse
	msg.Type = MessageResponseType
	msg.Sender = handshake.localIndex
	msg.Receiver = handshake.remoteIndex

	// create ephemeral key
	handshake.localEphemeral, err = newPrivateKey()
	if err != nil {
		return nil, err
	}
	msg.Ephemeral = handshake.localEphemeral.publicKey()
	handshake.mixHash(msg.Ephemeral[:])
	handshake.mixKey(msg.Ephemeral[:])

	func() {
		ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
		handshake.mixKey(ss[:])
		ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
		handshake.mixKey(ss[:])
		// The second assignment overwrote the first secret, so one wipe
		// clears both (same hygiene as ConsumeMessageResponse).
		setZero(ss[:])
	}()

	// add preshared key (psk)
	var tau [blake2s.Size]byte
	var key [chacha20poly1305.KeySize]byte
	KDF3(
		&handshake.chainKey,
		&tau,
		&key,
		handshake.chainKey[:],
		handshake.presharedKey[:],
	)
	handshake.mixHash(tau[:])
	setZero(tau[:])

	func() {
		// New only rejects keys of the wrong length; key is exactly KeySize bytes.
		aead, _ := chacha20poly1305.New(key[:])
		aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
		handshake.mixHash(msg.Empty[:])
	}()
	setZero(key[:])

	handshake.state = HandshakeResponseCreated
	return &msg, nil
}
// ConsumeMessageResponse validates a received handshake response and, if it
// authenticates, completes the initiator side of the handshake and returns
// the peer that owns it. It returns nil when:
//
//   - the Type field is not MessageResponseType;
//   - msg.Receiver does not map to a handshake in device.indices;
//   - that handshake is not in the HandshakeInitiationCreated state;
//   - the sealed empty payload fails to authenticate against the transcript.
//
// On success the handshake's hash, chain key and remote index are updated
// and its state becomes HandshakeResponseConsumed. MAC1 and MAC2 are not
// checked here.
//
// Locks: device.noise.mutex (read) is taken before handshake.mutex, the same
// order used by CreateMessageInitiation and ConsumeMessageInitiation, so the
// two locks are never acquired in opposite orders. handshake.mutex is held
// exclusively across both the state check and the state update, so two
// concurrent copies of the same response cannot both be accepted.
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
	if msg.Type != MessageResponseType {
		return nil
	}

	// lookup handshake by receiver
	lookup := device.indices.Lookup(msg.Receiver)
	handshake := lookup.handshake
	if handshake == nil {
		return nil
	}

	var (
		hash     [blake2s.Size]byte
		chainKey [blake2s.Size]byte
		tau      [blake2s.Size]byte
		key      [chacha20poly1305.KeySize]byte
	)

	// Wipe all secret scratch material on every exit path. The handshake
	// receives its own array copies of hash and chainKey.
	defer func() {
		setZero(hash[:])
		setZero(chainKey[:])
		setZero(tau[:])
		setZero(key[:])
	}()

	// lock private key for reading
	device.noise.mutex.RLock()
	defer device.noise.mutex.RUnlock()

	// lock handshake state
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	if handshake.state != HandshakeInitiationCreated {
		return nil
	}

	// finish 3-way DH
	mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
	mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])

	func() {
		ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
		mixKey(&chainKey, &chainKey, ss[:])
		setZero(ss[:])
	}()

	func() {
		ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
		mixKey(&chainKey, &chainKey, ss[:])
		setZero(ss[:])
	}()

	// add preshared key (psk)
	KDF3(
		&chainKey,
		&tau,
		&key,
		chainKey[:],
		handshake.presharedKey[:],
	)
	mixHash(&hash, &hash, tau[:])

	// authenticate transcript
	// New only rejects keys of the wrong length; key is exactly KeySize bytes.
	aead, _ := chacha20poly1305.New(key[:])
	if _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]); err != nil {
		device.log.Debug.Println("failed to open")
		return nil
	}
	mixHash(&hash, &hash, msg.Empty[:])

	// update handshake state
	handshake.hash = hash
	handshake.chainKey = chainKey
	handshake.remoteIndex = msg.Sender
	handshake.state = HandshakeResponseConsumed

	return lookup.peer
}
// NewKeypair derives a transport key pair from the peer's completed
// handshake and installs it in the peer's key-pair slots.
//
// The handshake must be in the HandshakeResponseConsumed state (this side is
// the initiator) or the HandshakeResponseCreated state (this side is the
// responder); in any other state nil is returned and nothing changes. The
// two KDF2 outputs are assigned to send/receive in opposite order depending
// on the role.
//
// Afterwards the handshake's chain key and ephemeral private key are wiped,
// its state is reset to HandshakeZeroed, and its local index is re-pointed
// in device.indices from the handshake to the new key pair.
//
// Locks: handshake.mutex is held exclusively until return; peer.keyPairs.mutex
// is held while the slots are rotated.
func (peer *Peer) NewKeypair() *Keypair {
	device := peer.device
	handshake := &peer.handshake

	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	// derive keys
	var (
		isInitiator bool
		sendKey     [chacha20poly1305.KeySize]byte
		recvKey     [chacha20poly1305.KeySize]byte
	)

	switch handshake.state {
	case HandshakeResponseConsumed:
		KDF2(&sendKey, &recvKey, handshake.chainKey[:], nil)
		isInitiator = true
	case HandshakeResponseCreated:
		KDF2(&recvKey, &sendKey, handshake.chainKey[:], nil)
		isInitiator = false
	default:
		return nil
	}

	// zero handshake
	setZero(handshake.chainKey[:])
	setZero(handshake.localEphemeral[:])
	handshake.state = HandshakeZeroed

	// create AEAD instances; New only rejects keys of the wrong length,
	// and both keys are exactly KeySize bytes
	sendAEAD, _ := chacha20poly1305.New(sendKey[:])
	recvAEAD, _ := chacha20poly1305.New(recvKey[:])
	setZero(sendKey[:])
	setZero(recvKey[:])

	keyPair := new(Keypair)
	keyPair.send = sendAEAD
	keyPair.receive = recvAEAD
	keyPair.created = time.Now()
	keyPair.sendNonce = 0
	keyPair.replayFilter.Init()
	keyPair.isInitiator = isInitiator
	keyPair.localIndex = handshake.localIndex
	keyPair.remoteIndex = handshake.remoteIndex

	// remap index
	entry := IndexTableEntry{
		peer:      peer,
		keyPair:   keyPair,
		handshake: nil,
	}
	device.indices.Insert(handshake.localIndex, entry)
	handshake.localIndex = 0

	// rotate key pairs
	slots := &peer.keyPairs
	slots.mutex.Lock()
	peer.timersSessionDerived()

	previous, next, current := slots.previous, slots.next, slots.current

	if !isInitiator {
		// responder: park the new pair in "next" and drop what it displaces
		slots.next = keyPair
		device.DeleteKeypair(next)
		slots.previous = nil
		device.DeleteKeypair(previous)
	} else {
		// initiator: the new pair becomes current immediately
		if next == nil {
			slots.previous = current
		} else {
			slots.next = nil
			slots.previous = next
			device.DeleteKeypair(current)
		}
		device.DeleteKeypair(previous)
		slots.current = keyPair
	}
	slots.mutex.Unlock()

	return keyPair
}