mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
Beginning work noise handshake
This commit is contained in:
parent
1868d15914
commit
50aeefcb51
@ -1,12 +1,17 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/* TODO: Locking may be a little broad here
|
||||||
|
*/
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
peers map[NoisePublicKey]*Peer
|
peers map[NoisePublicKey]*Peer
|
||||||
|
sessions map[uint32]*Handshake
|
||||||
privateKey NoisePrivateKey
|
privateKey NoisePrivateKey
|
||||||
publicKey NoisePublicKey
|
publicKey NoisePublicKey
|
||||||
fwMark uint32
|
fwMark uint32
|
||||||
@ -14,6 +19,19 @@ type Device struct {
|
|||||||
routingTable RoutingTable
|
routingTable RoutingTable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dev *Device) NewID(h *Handshake) uint32 {
|
||||||
|
dev.mutex.Lock()
|
||||||
|
defer dev.mutex.Unlock()
|
||||||
|
for {
|
||||||
|
id := rand.Uint32()
|
||||||
|
_, ok := dev.sessions[id]
|
||||||
|
if !ok {
|
||||||
|
dev.sessions[id] = h
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (dev *Device) RemovePeer(key NoisePublicKey) {
|
func (dev *Device) RemovePeer(key NoisePublicKey) {
|
||||||
dev.mutex.Lock()
|
dev.mutex.Lock()
|
||||||
defer dev.mutex.Unlock()
|
defer dev.mutex.Unlock()
|
||||||
|
76
src/kdf_test.go
Normal file
76
src/kdf_test.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KDFTest struct {
|
||||||
|
key string
|
||||||
|
input string
|
||||||
|
t0 string
|
||||||
|
t1 string
|
||||||
|
t2 string
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEquals(t *testing.T, a string, b string) {
|
||||||
|
if a != b {
|
||||||
|
t.Fatal("expected", a, "=", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKDF(t *testing.T) {
|
||||||
|
tests := []KDFTest{
|
||||||
|
{
|
||||||
|
key: "746573742d6b6579",
|
||||||
|
input: "746573742d696e707574",
|
||||||
|
t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633",
|
||||||
|
t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a",
|
||||||
|
t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: "776972656775617264",
|
||||||
|
input: "776972656775617264",
|
||||||
|
t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8",
|
||||||
|
t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f",
|
||||||
|
t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: "",
|
||||||
|
input: "",
|
||||||
|
t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0",
|
||||||
|
t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e",
|
||||||
|
t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
key, _ := hex.DecodeString(test.key)
|
||||||
|
input, _ := hex.DecodeString(test.input)
|
||||||
|
t0, t1, t2 := KDF3(key, input)
|
||||||
|
t0s := hex.EncodeToString(t0[:])
|
||||||
|
t1s := hex.EncodeToString(t1[:])
|
||||||
|
t2s := hex.EncodeToString(t2[:])
|
||||||
|
assertEquals(t, t0s, test.t0)
|
||||||
|
assertEquals(t, t1s, test.t1)
|
||||||
|
assertEquals(t, t2s, test.t2)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
key, _ := hex.DecodeString(test.key)
|
||||||
|
input, _ := hex.DecodeString(test.input)
|
||||||
|
t0, t1 := KDF2(key, input)
|
||||||
|
t0s := hex.EncodeToString(t0[:])
|
||||||
|
t1s := hex.EncodeToString(t1[:])
|
||||||
|
assertEquals(t, t0s, test.t0)
|
||||||
|
assertEquals(t, t1s, test.t1)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
key, _ := hex.DecodeString(test.key)
|
||||||
|
input, _ := hex.DecodeString(test.input)
|
||||||
|
t0 := KDF1(key, input)
|
||||||
|
t0s := hex.EncodeToString(t0[:])
|
||||||
|
assertEquals(t, t0s, test.t0)
|
||||||
|
}
|
||||||
|
}
|
86
src/noise_helpers.go
Normal file
86
src/noise_helpers.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"golang.org/x/crypto/blake2s"
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
"hash"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* KDF related functions.
|
||||||
|
* HMAC-based Key Derivation Function (HKDF)
|
||||||
|
* https://tools.ietf.org/html/rfc5869
|
||||||
|
*/
|
||||||
|
|
||||||
|
func HMAC(sum *[blake2s.Size]byte, key []byte, input []byte) {
|
||||||
|
mac := hmac.New(func() hash.Hash {
|
||||||
|
h, _ := blake2s.New256(nil)
|
||||||
|
return h
|
||||||
|
}, key)
|
||||||
|
mac.Write(input)
|
||||||
|
mac.Sum(sum[:0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func KDF1(key []byte, input []byte) (t0 [blake2s.Size]byte) {
|
||||||
|
HMAC(&t0, key, input)
|
||||||
|
HMAC(&t0, t0[:], []byte{0x1})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte) {
|
||||||
|
var prk [blake2s.Size]byte
|
||||||
|
HMAC(&prk, key, input)
|
||||||
|
HMAC(&t0, prk[:], []byte{0x1})
|
||||||
|
HMAC(&t1, prk[:], append(t0[:], 0x2))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte, t2 [blake2s.Size]byte) {
|
||||||
|
var prk [blake2s.Size]byte
|
||||||
|
HMAC(&prk, key, input)
|
||||||
|
HMAC(&t0, prk[:], []byte{0x1})
|
||||||
|
HMAC(&t1, prk[:], append(t0[:], 0x2))
|
||||||
|
HMAC(&t2, prk[:], append(t1[:], 0x3))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
||||||
|
return KDF1(c[:], data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
||||||
|
return blake2s.Sum256(append(h[:], data...))
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Curve25519 wrappers
|
||||||
|
*
|
||||||
|
* TODO: Rethink this
|
||||||
|
*/
|
||||||
|
|
||||||
|
func newPrivateKey() (sk NoisePrivateKey, err error) {
|
||||||
|
// clamping: https://cr.yp.to/ecdh.html
|
||||||
|
_, err = rand.Read(sk[:])
|
||||||
|
sk[0] &= 248
|
||||||
|
sk[31] &= 127
|
||||||
|
sk[31] |= 64
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
|
||||||
|
apk := (*[NoisePublicKeySize]byte)(&pk)
|
||||||
|
ask := (*[NoisePrivateKeySize]byte)(sk)
|
||||||
|
curve25519.ScalarBaseMult(apk, ask)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
|
||||||
|
apk := (*[NoisePublicKeySize]byte)(&pk)
|
||||||
|
ask := (*[NoisePrivateKeySize]byte)(sk)
|
||||||
|
curve25519.ScalarMult(&ss, apk, ask)
|
||||||
|
return ss
|
||||||
|
}
|
179
src/noise_protocol.go
Normal file
179
src/noise_protocol.go
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"golang.org/x/crypto/blake2s"
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"golang.org/x/crypto/poly1305"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
HandshakeInitialCreated = iota
|
||||||
|
HandshakeInitialConsumed
|
||||||
|
HandshakeResponseCreated
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
|
||||||
|
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
|
||||||
|
WGLabelMAC1 = "mac1----"
|
||||||
|
WGLabelCookie = "cookie--"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MessageInitalType = 1
|
||||||
|
MessageResponseType = 2
|
||||||
|
MessageCookieResponseType = 3
|
||||||
|
MessageTransportType = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
type MessageInital struct {
|
||||||
|
Type uint32
|
||||||
|
Sender uint32
|
||||||
|
Ephemeral NoisePublicKey
|
||||||
|
Static [NoisePublicKeySize + poly1305.TagSize]byte
|
||||||
|
Timestamp [TAI64NSize + poly1305.TagSize]byte
|
||||||
|
Mac1 [blake2s.Size128]byte
|
||||||
|
Mac2 [blake2s.Size128]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageResponse struct {
|
||||||
|
Type uint32
|
||||||
|
Sender uint32
|
||||||
|
Reciever uint32
|
||||||
|
Ephemeral NoisePublicKey
|
||||||
|
Empty [poly1305.TagSize]byte
|
||||||
|
Mac1 [blake2s.Size128]byte
|
||||||
|
Mac2 [blake2s.Size128]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageTransport struct {
|
||||||
|
Type uint32
|
||||||
|
Reciever uint32
|
||||||
|
Counter uint64
|
||||||
|
Content []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type Handshake struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
state int
|
||||||
|
chainKey [blake2s.Size]byte // chain key
|
||||||
|
hash [blake2s.Size]byte // hash value
|
||||||
|
staticStatic NoisePublicKey // precomputed DH(S_i, S_r)
|
||||||
|
ephemeral NoisePrivateKey // ephemeral secret key
|
||||||
|
remoteIndex uint32 // index for sending
|
||||||
|
device *Device
|
||||||
|
peer *Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ZeroNonce [chacha20poly1305.NonceSize]byte
|
||||||
|
InitalChainKey [blake2s.Size]byte
|
||||||
|
InitalHash [blake2s.Size]byte
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
InitalChainKey = blake2s.Sum256([]byte(NoiseConstruction))
|
||||||
|
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handshake) Precompute() {
|
||||||
|
h.staticStatic = h.device.privateKey.sharedSecret(h.peer.publicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handshake) ConsumeMessageResponse(msg *MessageResponse) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handshake) addHash(data []byte) {
|
||||||
|
h.hash = addToHash(h.hash, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handshake) addChain(data []byte) {
|
||||||
|
h.chainKey = addToChainKey(h.chainKey, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handshake) CreateMessageInital() (*MessageInital, error) {
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
|
// reset handshake
|
||||||
|
|
||||||
|
var err error
|
||||||
|
h.ephemeral, err = newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h.chainKey = InitalChainKey
|
||||||
|
h.hash = addToHash(InitalHash, h.device.publicKey[:])
|
||||||
|
|
||||||
|
// create ephemeral key
|
||||||
|
|
||||||
|
var msg MessageInital
|
||||||
|
msg.Type = MessageInitalType
|
||||||
|
msg.Sender = h.device.NewID(h)
|
||||||
|
msg.Ephemeral = h.ephemeral.publicKey()
|
||||||
|
h.chainKey = addToChainKey(h.chainKey, msg.Ephemeral[:])
|
||||||
|
h.hash = addToHash(h.hash, msg.Ephemeral[:])
|
||||||
|
|
||||||
|
// encrypt long-term "identity key"
|
||||||
|
|
||||||
|
func() {
|
||||||
|
var key [chacha20poly1305.KeySize]byte
|
||||||
|
ss := h.ephemeral.sharedSecret(h.peer.publicKey)
|
||||||
|
h.chainKey, key = KDF2(h.chainKey[:], ss[:])
|
||||||
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
|
aead.Seal(msg.Static[:0], ZeroNonce[:], h.device.publicKey[:], nil)
|
||||||
|
}()
|
||||||
|
h.addHash(msg.Static[:])
|
||||||
|
|
||||||
|
// encrypt timestamp
|
||||||
|
|
||||||
|
timestamp := Timestamp()
|
||||||
|
func() {
|
||||||
|
var key [chacha20poly1305.KeySize]byte
|
||||||
|
h.chainKey, key = KDF2(h.chainKey[:], h.staticStatic[:])
|
||||||
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
|
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], nil)
|
||||||
|
}()
|
||||||
|
h.addHash(msg.Timestamp[:])
|
||||||
|
h.state = HandshakeInitialCreated
|
||||||
|
return &msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handshake) ConsumeMessageInitial(msg *MessageInital) error {
|
||||||
|
if msg.Type != MessageInitalType {
|
||||||
|
panic(errors.New("bug: invalid inital message type"))
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := addToHash(InitalHash, h.device.publicKey[:])
|
||||||
|
chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
|
||||||
|
hash = addToHash(hash, msg.Ephemeral[:])
|
||||||
|
|
||||||
|
//
|
||||||
|
|
||||||
|
ephemeral, err := newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// update handshake state
|
||||||
|
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
|
h.hash = hash
|
||||||
|
h.chainKey = chainKey
|
||||||
|
h.remoteIndex = msg.Sender
|
||||||
|
h.ephemeral = ephemeral
|
||||||
|
h.state = HandshakeInitialConsumed
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handshake) CreateMessageResponse() []byte {
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
38
src/noise_test.go
Normal file
38
src/noise_test.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandshake(t *testing.T) {
|
||||||
|
var dev1 Device
|
||||||
|
var dev2 Device
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
dev1.privateKey, err = newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dev2.privateKey, err = newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var peer1 Peer
|
||||||
|
var peer2 Peer
|
||||||
|
|
||||||
|
peer1.publicKey = dev1.privateKey.publicKey()
|
||||||
|
peer2.publicKey = dev2.privateKey.publicKey()
|
||||||
|
|
||||||
|
var handshake1 Handshake
|
||||||
|
var handshake2 Handshake
|
||||||
|
|
||||||
|
handshake1.device = &dev1
|
||||||
|
handshake2.device = &dev2
|
||||||
|
|
||||||
|
handshake1.peer = &peer2
|
||||||
|
handshake2.peer = &peer1
|
||||||
|
|
||||||
|
}
|
@ -14,8 +14,6 @@ const (
|
|||||||
type (
|
type (
|
||||||
NoisePublicKey [NoisePublicKeySize]byte
|
NoisePublicKey [NoisePublicKeySize]byte
|
||||||
NoisePrivateKey [NoisePrivateKeySize]byte
|
NoisePrivateKey [NoisePrivateKeySize]byte
|
||||||
NoiseSymmetricKey [NoiseSymmetricKeySize]byte
|
|
||||||
NoiseNonce uint64 // padded to 12-bytes
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadExactHex(dst []byte, src string) error {
|
func loadExactHex(dst []byte, src string) error {
|
23
src/tai64.go
Normal file
23
src/tai64.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TAI64NBase = uint64(4611686018427387914)
|
||||||
|
TAI64NSize = 12
|
||||||
|
)
|
||||||
|
|
||||||
|
type TAI64N [TAI64NSize]byte
|
||||||
|
|
||||||
|
func Timestamp() TAI64N {
|
||||||
|
var tai64n TAI64N
|
||||||
|
now := time.Now()
|
||||||
|
secs := TAI64NBase + uint64(now.Unix())
|
||||||
|
nano := uint32(now.UnixNano())
|
||||||
|
binary.BigEndian.PutUint64(tai64n[:], secs)
|
||||||
|
binary.BigEndian.PutUint32(tai64n[8:], nano)
|
||||||
|
return tai64n
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user