1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2025-09-18 20:57:50 +02:00

Revert "device: use wgcfg key types"

More cleanup work of wgcfg to do before bringing this in.

This reverts commit 83ca9b47b6.
This commit is contained in:
David Crawshaw 2020-04-07 15:52:04 +10:00
parent ad256f0b73
commit f6020a2085
10 changed files with 190 additions and 68 deletions

View File

@ -13,7 +13,6 @@ import (
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.zx2c4.com/wireguard/wgcfg"
) )
type CookieChecker struct { type CookieChecker struct {
@ -42,7 +41,7 @@ type CookieGenerator struct {
} }
} }
func (st *CookieChecker) Init(pk wgcfg.Key) { func (st *CookieChecker) Init(pk NoisePublicKey) {
st.Lock() st.Lock()
defer st.Unlock() defer st.Unlock()
@ -172,7 +171,7 @@ func (st *CookieChecker) CreateReply(
return reply, nil return reply, nil
} }
func (st *CookieGenerator) Init(pk wgcfg.Key) { func (st *CookieGenerator) Init(pk NoisePublicKey) {
st.Lock() st.Lock()
defer st.Unlock() defer st.Unlock()

View File

@ -7,8 +7,6 @@ package device
import ( import (
"testing" "testing"
"golang.zx2c4.com/wireguard/wgcfg"
) )
func TestCookieMAC1(t *testing.T) { func TestCookieMAC1(t *testing.T) {
@ -20,11 +18,11 @@ func TestCookieMAC1(t *testing.T) {
checker CookieChecker checker CookieChecker
) )
sk, err := wgcfg.NewPrivateKey() sk, err := newPrivateKey()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
pk := sk.Public() pk := sk.publicKey()
generator.Init(pk) generator.Init(pk)
checker.Init(pk) checker.Init(pk)

View File

@ -17,7 +17,6 @@ import (
"golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgcfg"
) )
type Device struct { type Device struct {
@ -47,13 +46,13 @@ type Device struct {
staticIdentity struct { staticIdentity struct {
sync.RWMutex sync.RWMutex
privateKey wgcfg.PrivateKey privateKey NoisePrivateKey
publicKey wgcfg.Key publicKey NoisePublicKey
} }
peers struct { peers struct {
sync.RWMutex sync.RWMutex
keyMap map[wgcfg.Key]*Peer keyMap map[NoisePublicKey]*Peer
} }
// unprotected / "self-synchronising resources" // unprotected / "self-synchronising resources"
@ -97,7 +96,7 @@ type Device struct {
* *
* Must hold device.peers.Mutex * Must hold device.peers.Mutex
*/ */
func unsafeRemovePeer(device *Device, peer *Peer, key wgcfg.Key) { func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
// stop routing and processing of packets // stop routing and processing of packets
@ -201,13 +200,13 @@ func (device *Device) IsUnderLoad() bool {
return until.After(now) return until.After(now)
} }
func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// lock required resources // lock required resources
device.staticIdentity.Lock() device.staticIdentity.Lock()
defer device.staticIdentity.Unlock() defer device.staticIdentity.Unlock()
if sk.Equal(device.staticIdentity.privateKey) { if sk.Equals(device.staticIdentity.privateKey) {
return nil return nil
} }
@ -222,9 +221,9 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error {
// remove peers with matching public keys // remove peers with matching public keys
publicKey := sk.Public() publicKey := sk.publicKey()
for key, peer := range device.peers.keyMap { for key, peer := range device.peers.keyMap {
if peer.handshake.remoteStatic.Equal(publicKey) { if peer.handshake.remoteStatic.Equals(publicKey) {
unsafeRemovePeer(device, peer, key) unsafeRemovePeer(device, peer, key)
} }
} }
@ -240,7 +239,7 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
handshake := &peer.handshake handshake := &peer.handshake
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(handshake.remoteStatic) handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer) expiredPeers = append(expiredPeers, peer)
} }
@ -270,7 +269,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
} }
device.tun.mtu = int32(mtu) device.tun.mtu = int32(mtu)
device.peers.keyMap = make(map[wgcfg.Key]*Peer) device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init() device.rate.limiter.Init()
device.rate.underLoadUntil.Store(time.Time{}) device.rate.underLoadUntil.Store(time.Time{})
@ -318,14 +317,14 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
return device return device
} }
func (device *Device) LookupPeer(pk wgcfg.Key) *Peer { func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock() device.peers.RLock()
defer device.peers.RUnlock() defer device.peers.RUnlock()
return device.peers.keyMap[pk] return device.peers.keyMap[pk]
} }
func (device *Device) RemovePeer(key wgcfg.Key) { func (device *Device) RemovePeer(key NoisePublicKey) {
device.peers.Lock() device.peers.Lock()
defer device.peers.Unlock() defer device.peers.Unlock()
// stop peer and remove from routing // stop peer and remove from routing
@ -344,7 +343,7 @@ func (device *Device) RemoveAllPeers() {
unsafeRemovePeer(device, peer, key) unsafeRemovePeer(device, peer, key)
} }
device.peers.keyMap = make(map[wgcfg.Key]*Peer) device.peers.keyMap = make(map[NoisePublicKey]*Peer)
} }
func (device *Device) FlushPacketQueues() { func (device *Device) FlushPacketQueues() {

View File

@ -14,7 +14,6 @@ import (
"time" "time"
"golang.zx2c4.com/wireguard/tun/tuntest" "golang.zx2c4.com/wireguard/tun/tuntest"
"golang.zx2c4.com/wireguard/wgcfg"
) )
func TestTwoDevicePing(t *testing.T) { func TestTwoDevicePing(t *testing.T) {
@ -91,7 +90,7 @@ func assertEqual(t *testing.T, a, b []byte) {
} }
func randDevice(t *testing.T) *Device { func randDevice(t *testing.T) *Device {
sk, err := wgcfg.NewPrivateKey() sk, err := newPrivateKey()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -7,10 +7,12 @@ package device
import ( import (
"crypto/hmac" "crypto/hmac"
"crypto/rand"
"crypto/subtle" "crypto/subtle"
"hash" "hash"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/curve25519"
) )
/* KDF related functions. /* KDF related functions.
@ -73,3 +75,28 @@ func setZero(arr []byte) {
arr[i] = 0 arr[i] = 0
} }
} }
func (sk *NoisePrivateKey) clamp() {
sk[0] &= 248
sk[31] = (sk[31] & 127) | 64
}
func newPrivateKey() (sk NoisePrivateKey, err error) {
_, err = rand.Read(sk[:])
sk.clamp()
return
}
func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarBaseMult(apk, ask)
return
}
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk)
return ss
}

View File

@ -15,7 +15,6 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n" "golang.zx2c4.com/wireguard/tai64n"
"golang.zx2c4.com/wireguard/wgcfg"
) )
type handshakeState int type handshakeState int
@ -85,8 +84,8 @@ const (
type MessageInitiation struct { type MessageInitiation struct {
Type uint32 Type uint32
Sender uint32 Sender uint32
Ephemeral wgcfg.Key Ephemeral NoisePublicKey
Static [wgcfg.KeySize + poly1305.TagSize]byte Static [NoisePublicKeySize + poly1305.TagSize]byte
Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
MAC1 [blake2s.Size128]byte MAC1 [blake2s.Size128]byte
MAC2 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte
@ -96,7 +95,7 @@ type MessageResponse struct {
Type uint32 Type uint32
Sender uint32 Sender uint32
Receiver uint32 Receiver uint32
Ephemeral wgcfg.Key Ephemeral NoisePublicKey
Empty [poly1305.TagSize]byte Empty [poly1305.TagSize]byte
MAC1 [blake2s.Size128]byte MAC1 [blake2s.Size128]byte
MAC2 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte
@ -121,13 +120,13 @@ type Handshake struct {
mutex sync.RWMutex mutex sync.RWMutex
hash [blake2s.Size]byte // hash value hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key chainKey [blake2s.Size]byte // chain key
presharedKey wgcfg.SymmetricKey // psk presharedKey NoiseSymmetricKey // psk
localEphemeral wgcfg.PrivateKey // ephemeral secret key localEphemeral NoisePrivateKey // ephemeral secret key
localIndex uint32 // used to clear hash-table localIndex uint32 // used to clear hash-table
remoteIndex uint32 // index for sending remoteIndex uint32 // index for sending
remoteStatic wgcfg.Key // long term key remoteStatic NoisePublicKey // long term key
remoteEphemeral wgcfg.Key // ephemeral public key remoteEphemeral NoisePublicKey // ephemeral public key
precomputedStaticStatic [wgcfg.KeySize]byte // precomputed shared secret precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
lastTimestamp tai64n.Timestamp lastTimestamp tai64n.Timestamp
lastInitiationConsumption time.Time lastInitiationConsumption time.Time
lastSentHandshake time.Time lastSentHandshake time.Time
@ -189,7 +188,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
var err error var err error
handshake.hash = InitialHash handshake.hash = InitialHash
handshake.chainKey = InitialChainKey handshake.chainKey = InitialChainKey
handshake.localEphemeral, err = wgcfg.NewPrivateKey() handshake.localEphemeral, err = newPrivateKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -198,14 +197,14 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
msg := MessageInitiation{ msg := MessageInitiation{
Type: MessageInitiationType, Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.Public(), Ephemeral: handshake.localEphemeral.publicKey(),
} }
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
// encrypt static key // encrypt static key
ss := handshake.localEphemeral.SharedSecret(handshake.remoteStatic) ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if isZero(ss[:]) { if isZero(ss[:]) {
return nil, errZeroECDHResult return nil, errZeroECDHResult
} }
@ -266,9 +265,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// decrypt static key // decrypt static key
var err error var err error
var peerPK wgcfg.Key var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral) ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if isZero(ss[:]) { if isZero(ss[:]) {
return nil return nil
} }
@ -377,18 +376,18 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
// create ephemeral key // create ephemeral key
handshake.localEphemeral, err = wgcfg.NewPrivateKey() handshake.localEphemeral, err = newPrivateKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg.Ephemeral = handshake.localEphemeral.Public() msg.Ephemeral = handshake.localEphemeral.publicKey()
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
func() { func() {
ss := handshake.localEphemeral.SharedSecret(handshake.remoteEphemeral) ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
handshake.mixKey(ss[:]) handshake.mixKey(ss[:])
ss = handshake.localEphemeral.SharedSecret(handshake.remoteStatic) ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
handshake.mixKey(ss[:]) handshake.mixKey(ss[:])
}() }()
@ -458,13 +457,13 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
func() { func() {
ss := handshake.localEphemeral.SharedSecret(msg.Ephemeral) ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:]) mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:]) setZero(ss[:])
}() }()
func() { func() {
ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral) ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:]) mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:]) setZero(ss[:])
}() }()

91
device/noise-types.go Normal file
View File

@ -0,0 +1,91 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"crypto/subtle"
"encoding/hex"
"errors"
"golang.org/x/crypto/chacha20poly1305"
)
const (
NoisePublicKeySize = 32
NoisePrivateKeySize = 32
)
type (
NoisePublicKey [NoisePublicKeySize]byte
NoisePrivateKey [NoisePrivateKeySize]byte
NoiseSymmetricKey [chacha20poly1305.KeySize]byte
NoiseNonce uint64 // padded to 12-bytes
)
func loadExactHex(dst []byte, src string) error {
slice, err := hex.DecodeString(src)
if err != nil {
return err
}
if len(slice) != len(dst) {
return errors.New("hex string does not fit the slice")
}
copy(dst, slice)
return nil
}
func (key NoisePrivateKey) IsZero() bool {
var zero NoisePrivateKey
return key.Equals(zero)
}
func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}
func (key *NoisePrivateKey) FromHex(src string) (err error) {
err = loadExactHex(key[:], src)
key.clamp()
return
}
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
err = loadExactHex(key[:], src)
if key.IsZero() {
return
}
key.clamp()
return
}
func (key NoisePrivateKey) ToHex() string {
return hex.EncodeToString(key[:])
}
func (key *NoisePublicKey) FromHex(src string) error {
return loadExactHex(key[:], src)
}
func (key NoisePublicKey) ToHex() string {
return hex.EncodeToString(key[:])
}
func (key NoisePublicKey) IsZero() bool {
var zero NoisePublicKey
return key.Equals(zero)
}
func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}
func (key *NoiseSymmetricKey) FromHex(src string) error {
return loadExactHex(key[:], src)
}
func (key NoiseSymmetricKey) ToHex() string {
return hex.EncodeToString(key[:])
}

View File

@ -11,6 +11,24 @@ import (
"testing" "testing"
) )
func TestCurveWrappers(t *testing.T) {
sk1, err := newPrivateKey()
assertNil(t, err)
sk2, err := newPrivateKey()
assertNil(t, err)
pk1 := sk1.publicKey()
pk2 := sk2.publicKey()
ss1 := sk1.sharedSecret(pk2)
ss2 := sk2.sharedSecret(pk1)
if ss1 != ss2 {
t.Fatal("Failed to compute shared secet")
}
}
func TestNoiseHandshake(t *testing.T) { func TestNoiseHandshake(t *testing.T) {
dev1 := randDevice(t) dev1 := randDevice(t)
dev2 := randDevice(t) dev2 := randDevice(t)
@ -18,14 +36,8 @@ func TestNoiseHandshake(t *testing.T) {
defer dev1.Close() defer dev1.Close()
defer dev2.Close() defer dev2.Close()
peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.Public()) peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
if err != nil { peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
t.Fatal(err)
}
peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.Public())
if err != nil {
t.Fatal(err)
}
assertEqual( assertEqual(
t, t,

View File

@ -14,7 +14,6 @@ import (
"time" "time"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/wgcfg"
) )
const ( const (
@ -77,8 +76,7 @@ type Peer struct {
cookieGenerator CookieGenerator cookieGenerator CookieGenerator
} }
func (device *Device) NewPeer(pk wgcfg.Key) (*Peer, error) { func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
if device.isClosed.Get() { if device.isClosed.Get() {
return nil, errors.New("device closed") return nil, errors.New("device closed")
} }
@ -118,7 +116,7 @@ func (device *Device) NewPeer(pk wgcfg.Key) (*Peer, error) {
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(pk) handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.mutex.Unlock() handshake.mutex.Unlock()

View File

@ -18,7 +18,6 @@ import (
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/wgcfg"
) )
type IPCError struct { type IPCError struct {
@ -55,7 +54,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
// serialize device related values // serialize device related values
if !device.staticIdentity.privateKey.IsZero() { if !device.staticIdentity.privateKey.IsZero() {
send("private_key=" + device.staticIdentity.privateKey.HexString()) send("private_key=" + device.staticIdentity.privateKey.ToHex())
} }
if device.net.port != 0 { if device.net.port != 0 {
@ -72,8 +71,8 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
peer.RLock() peer.RLock()
defer peer.RUnlock() defer peer.RUnlock()
send("public_key=" + peer.handshake.remoteStatic.HexString()) send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.HexString()) send("preshared_key=" + peer.handshake.presharedKey.ToHex())
send("protocol_version=1") send("protocol_version=1")
if peer.endpoint != nil { if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.DstToString()) send("endpoint=" + peer.endpoint.DstToString())
@ -140,7 +139,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
switch key { switch key {
case "private_key": case "private_key":
sk, err := wgcfg.ParsePrivateHexKey(value) var sk NoisePrivateKey
err := sk.FromMaybeZeroHex(value)
if err != nil { if err != nil {
logError.Println("Failed to set private_key:", err) logError.Println("Failed to set private_key:", err)
return &IPCError{ipc.IpcErrorInvalid} return &IPCError{ipc.IpcErrorInvalid}
@ -221,7 +221,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
switch key { switch key {
case "public_key": case "public_key":
publicKey, err := wgcfg.ParseHexKey(value) var publicKey NoisePublicKey
err := publicKey.FromHex(value)
if err != nil { if err != nil {
logError.Println("Failed to get peer by public key:", err) logError.Println("Failed to get peer by public key:", err)
return &IPCError{ipc.IpcErrorInvalid} return &IPCError{ipc.IpcErrorInvalid}
@ -230,7 +231,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
// ignore peer with public key of device // ignore peer with public key of device
device.staticIdentity.RLock() device.staticIdentity.RLock()
dummy = device.staticIdentity.publicKey.Equal(publicKey) dummy = device.staticIdentity.publicKey.Equals(publicKey)
device.staticIdentity.RUnlock() device.staticIdentity.RUnlock()
if dummy { if dummy {
@ -290,8 +291,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
logDebug.Println(peer, "- UAPI: Updating preshared key") logDebug.Println(peer, "- UAPI: Updating preshared key")
peer.handshake.mutex.Lock() peer.handshake.mutex.Lock()
key, err := wgcfg.ParseSymmetricHexKey(value) err := peer.handshake.presharedKey.FromHex(value)
peer.handshake.presharedKey = key
peer.handshake.mutex.Unlock() peer.handshake.mutex.Unlock()
if err != nil { if err != nil {