1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Beginning work noise handshake

This commit is contained in:
Mathias Hall-Andersen 2017-06-23 13:41:59 +02:00
parent 1868d15914
commit 50aeefcb51
7 changed files with 422 additions and 4 deletions

View File

@ -1,12 +1,17 @@
package main package main
import ( import (
"math/rand"
"sync" "sync"
) )
/* TODO: Locking may be a little broad here
*/
type Device struct { type Device struct {
mutex sync.RWMutex mutex sync.RWMutex
peers map[NoisePublicKey]*Peer peers map[NoisePublicKey]*Peer
sessions map[uint32]*Handshake
privateKey NoisePrivateKey privateKey NoisePrivateKey
publicKey NoisePublicKey publicKey NoisePublicKey
fwMark uint32 fwMark uint32
@ -14,6 +19,19 @@ type Device struct {
routingTable RoutingTable routingTable RoutingTable
} }
func (dev *Device) NewID(h *Handshake) uint32 {
dev.mutex.Lock()
defer dev.mutex.Unlock()
for {
id := rand.Uint32()
_, ok := dev.sessions[id]
if !ok {
dev.sessions[id] = h
return id
}
}
}
func (dev *Device) RemovePeer(key NoisePublicKey) { func (dev *Device) RemovePeer(key NoisePublicKey) {
dev.mutex.Lock() dev.mutex.Lock()
defer dev.mutex.Unlock() defer dev.mutex.Unlock()

76
src/kdf_test.go Normal file
View File

@ -0,0 +1,76 @@
package main
import (
"encoding/hex"
"testing"
)
type KDFTest struct {
key string
input string
t0 string
t1 string
t2 string
}
func assertEquals(t *testing.T, a string, b string) {
if a != b {
t.Fatal("expected", a, "=", b)
}
}
func TestKDF(t *testing.T) {
tests := []KDFTest{
{
key: "746573742d6b6579",
input: "746573742d696e707574",
t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633",
t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a",
t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24",
},
{
key: "776972656775617264",
input: "776972656775617264",
t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8",
t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f",
t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160",
},
{
key: "",
input: "",
t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0",
t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e",
t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e",
},
}
for _, test := range tests {
key, _ := hex.DecodeString(test.key)
input, _ := hex.DecodeString(test.input)
t0, t1, t2 := KDF3(key, input)
t0s := hex.EncodeToString(t0[:])
t1s := hex.EncodeToString(t1[:])
t2s := hex.EncodeToString(t2[:])
assertEquals(t, t0s, test.t0)
assertEquals(t, t1s, test.t1)
assertEquals(t, t2s, test.t2)
}
for _, test := range tests {
key, _ := hex.DecodeString(test.key)
input, _ := hex.DecodeString(test.input)
t0, t1 := KDF2(key, input)
t0s := hex.EncodeToString(t0[:])
t1s := hex.EncodeToString(t1[:])
assertEquals(t, t0s, test.t0)
assertEquals(t, t1s, test.t1)
}
for _, test := range tests {
key, _ := hex.DecodeString(test.key)
input, _ := hex.DecodeString(test.input)
t0 := KDF1(key, input)
t0s := hex.EncodeToString(t0[:])
assertEquals(t, t0s, test.t0)
}
}

86
src/noise_helpers.go Normal file
View File

@ -0,0 +1,86 @@
package main
import (
"crypto/hmac"
"crypto/rand"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/curve25519"
"hash"
)
/* KDF related functions.
* HMAC-based Key Derivation Function (HKDF)
* https://tools.ietf.org/html/rfc5869
*/
func HMAC(sum *[blake2s.Size]byte, key []byte, input []byte) {
mac := hmac.New(func() hash.Hash {
h, _ := blake2s.New256(nil)
return h
}, key)
mac.Write(input)
mac.Sum(sum[:0])
}
func KDF1(key []byte, input []byte) (t0 [blake2s.Size]byte) {
HMAC(&t0, key, input)
HMAC(&t0, t0[:], []byte{0x1})
return
}
func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte) {
var prk [blake2s.Size]byte
HMAC(&prk, key, input)
HMAC(&t0, prk[:], []byte{0x1})
HMAC(&t1, prk[:], append(t0[:], 0x2))
return
}
func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte, t2 [blake2s.Size]byte) {
var prk [blake2s.Size]byte
HMAC(&prk, key, input)
HMAC(&t0, prk[:], []byte{0x1})
HMAC(&t1, prk[:], append(t0[:], 0x2))
HMAC(&t2, prk[:], append(t1[:], 0x3))
return
}
/*
*
*/
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return KDF1(c[:], data)
}
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return blake2s.Sum256(append(h[:], data...))
}
/* Curve25519 wrappers
*
* TODO: Rethink this
*/
func newPrivateKey() (sk NoisePrivateKey, err error) {
// clamping: https://cr.yp.to/ecdh.html
_, err = rand.Read(sk[:])
sk[0] &= 248
sk[31] &= 127
sk[31] |= 64
return
}
func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarBaseMult(apk, ask)
return
}
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, apk, ask)
return ss
}

179
src/noise_protocol.go Normal file
View File

@ -0,0 +1,179 @@
package main
import (
"errors"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
"sync"
)
const (
HandshakeInitialCreated = iota
HandshakeInitialConsumed
HandshakeResponseCreated
)
const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
WGLabelMAC1 = "mac1----"
WGLabelCookie = "cookie--"
)
const (
MessageInitalType = 1
MessageResponseType = 2
MessageCookieResponseType = 3
MessageTransportType = 4
)
type MessageInital struct {
Type uint32
Sender uint32
Ephemeral NoisePublicKey
Static [NoisePublicKeySize + poly1305.TagSize]byte
Timestamp [TAI64NSize + poly1305.TagSize]byte
Mac1 [blake2s.Size128]byte
Mac2 [blake2s.Size128]byte
}
type MessageResponse struct {
Type uint32
Sender uint32
Reciever uint32
Ephemeral NoisePublicKey
Empty [poly1305.TagSize]byte
Mac1 [blake2s.Size128]byte
Mac2 [blake2s.Size128]byte
}
type MessageTransport struct {
Type uint32
Reciever uint32
Counter uint64
Content []byte
}
type Handshake struct {
lock sync.Mutex
state int
chainKey [blake2s.Size]byte // chain key
hash [blake2s.Size]byte // hash value
staticStatic NoisePublicKey // precomputed DH(S_i, S_r)
ephemeral NoisePrivateKey // ephemeral secret key
remoteIndex uint32 // index for sending
device *Device
peer *Peer
}
var (
ZeroNonce [chacha20poly1305.NonceSize]byte
InitalChainKey [blake2s.Size]byte
InitalHash [blake2s.Size]byte
)
func init() {
InitalChainKey = blake2s.Sum256([]byte(NoiseConstruction))
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
}
func (h *Handshake) Precompute() {
h.staticStatic = h.device.privateKey.sharedSecret(h.peer.publicKey)
}
func (h *Handshake) ConsumeMessageResponse(msg *MessageResponse) {
}
func (h *Handshake) addHash(data []byte) {
h.hash = addToHash(h.hash, data)
}
func (h *Handshake) addChain(data []byte) {
h.chainKey = addToChainKey(h.chainKey, data)
}
func (h *Handshake) CreateMessageInital() (*MessageInital, error) {
h.lock.Lock()
defer h.lock.Unlock()
// reset handshake
var err error
h.ephemeral, err = newPrivateKey()
if err != nil {
return nil, err
}
h.chainKey = InitalChainKey
h.hash = addToHash(InitalHash, h.device.publicKey[:])
// create ephemeral key
var msg MessageInital
msg.Type = MessageInitalType
msg.Sender = h.device.NewID(h)
msg.Ephemeral = h.ephemeral.publicKey()
h.chainKey = addToChainKey(h.chainKey, msg.Ephemeral[:])
h.hash = addToHash(h.hash, msg.Ephemeral[:])
// encrypt long-term "identity key"
func() {
var key [chacha20poly1305.KeySize]byte
ss := h.ephemeral.sharedSecret(h.peer.publicKey)
h.chainKey, key = KDF2(h.chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], h.device.publicKey[:], nil)
}()
h.addHash(msg.Static[:])
// encrypt timestamp
timestamp := Timestamp()
func() {
var key [chacha20poly1305.KeySize]byte
h.chainKey, key = KDF2(h.chainKey[:], h.staticStatic[:])
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], nil)
}()
h.addHash(msg.Timestamp[:])
h.state = HandshakeInitialCreated
return &msg, nil
}
func (h *Handshake) ConsumeMessageInitial(msg *MessageInital) error {
if msg.Type != MessageInitalType {
panic(errors.New("bug: invalid inital message type"))
}
hash := addToHash(InitalHash, h.device.publicKey[:])
chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
hash = addToHash(hash, msg.Ephemeral[:])
//
ephemeral, err := newPrivateKey()
if err != nil {
return err
}
// update handshake state
h.lock.Lock()
defer h.lock.Unlock()
h.hash = hash
h.chainKey = chainKey
h.remoteIndex = msg.Sender
h.ephemeral = ephemeral
h.state = HandshakeInitialConsumed
return nil
}
func (h *Handshake) CreateMessageResponse() []byte {
return nil
}

38
src/noise_test.go Normal file
View File

@ -0,0 +1,38 @@
package main
import (
"testing"
)
func TestHandshake(t *testing.T) {
var dev1 Device
var dev2 Device
var err error
dev1.privateKey, err = newPrivateKey()
if err != nil {
t.Fatal(err)
}
dev2.privateKey, err = newPrivateKey()
if err != nil {
t.Fatal(err)
}
var peer1 Peer
var peer2 Peer
peer1.publicKey = dev1.privateKey.publicKey()
peer2.publicKey = dev2.privateKey.publicKey()
var handshake1 Handshake
var handshake2 Handshake
handshake1.device = &dev1
handshake2.device = &dev2
handshake1.peer = &peer2
handshake2.peer = &peer1
}

View File

@ -14,8 +14,6 @@ const (
type ( type (
NoisePublicKey [NoisePublicKeySize]byte NoisePublicKey [NoisePublicKeySize]byte
NoisePrivateKey [NoisePrivateKeySize]byte NoisePrivateKey [NoisePrivateKeySize]byte
NoiseSymmetricKey [NoiseSymmetricKeySize]byte
NoiseNonce uint64 // padded to 12-bytes
) )
func loadExactHex(dst []byte, src string) error { func loadExactHex(dst []byte, src string) error {

23
src/tai64.go Normal file
View File

@ -0,0 +1,23 @@
package main
import (
"encoding/binary"
"time"
)
const (
TAI64NBase = uint64(4611686018427387914)
TAI64NSize = 12
)
type TAI64N [TAI64NSize]byte
func Timestamp() TAI64N {
var tai64n TAI64N
now := time.Now()
secs := TAI64NBase + uint64(now.Unix())
nano := uint32(now.UnixNano())
binary.BigEndian.PutUint64(tai64n[:], secs)
binary.BigEndian.PutUint32(tai64n[8:], nano)
return tai64n
}