mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 09:15:14 +01:00
wgcfg: rename Key to PublicKey
A few minor review cleanups while here (e.g. remove unused LessThan). Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
parent
75a41b24ad
commit
1ecbb3313c
@ -23,7 +23,7 @@ type Config struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
PublicKey Key
|
PublicKey PublicKey
|
||||||
PresharedKey SymmetricKey
|
PresharedKey SymmetricKey
|
||||||
AllowedIPs []CIDR
|
AllowedIPs []CIDR
|
||||||
Endpoints []Endpoint
|
Endpoints []Endpoint
|
||||||
|
106
wgcfg/key.go
106
wgcfg/key.go
@ -2,7 +2,7 @@ package wgcfg
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
cryptorand "crypto/rand"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
@ -16,32 +16,22 @@ import (
|
|||||||
|
|
||||||
const KeySize = 32
|
const KeySize = 32
|
||||||
|
|
||||||
// Key is curve25519 key.
|
// PublicKey is curve25519 key.
|
||||||
// It is used by WireGuard to represent public and preshared keys.
|
// It is used by WireGuard to represent public and preshared keys.
|
||||||
type Key [KeySize]byte
|
type PublicKey [KeySize]byte
|
||||||
|
|
||||||
// NewPresharedKey generates a new random key.
|
func ParseKey(b64 string) (*PublicKey, error) { return parseKeyBase64(base64.StdEncoding, b64) }
|
||||||
func NewPresharedKey() (*Key, error) {
|
|
||||||
var k [KeySize]byte
|
|
||||||
_, err := rand.Read(k[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return (*Key)(&k), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseKey(b64 string) (*Key, error) { return parseKeyBase64(base64.StdEncoding, b64) }
|
func ParseHexKey(s string) (PublicKey, error) {
|
||||||
|
|
||||||
func ParseHexKey(s string) (Key, error) {
|
|
||||||
b, err := hex.DecodeString(s)
|
b, err := hex.DecodeString(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Key{}, &ParseError{"invalid hex key: " + err.Error(), s}
|
return PublicKey{}, &ParseError{"invalid hex key: " + err.Error(), s}
|
||||||
}
|
}
|
||||||
if len(b) != KeySize {
|
if len(b) != KeySize {
|
||||||
return Key{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s}
|
return PublicKey{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s}
|
||||||
}
|
}
|
||||||
|
|
||||||
var key Key
|
var key PublicKey
|
||||||
copy(key[:], b)
|
copy(key[:], b)
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
@ -62,31 +52,22 @@ func ParsePrivateHexKey(v string) (PrivateKey, error) {
|
|||||||
return pk, nil
|
return pk, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k Key) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) }
|
func (k PublicKey) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) }
|
||||||
func (k Key) String() string { return "pub:" + k.Base64()[:8] }
|
func (k PublicKey) String() string { return k.ShortString() }
|
||||||
func (k Key) HexString() string { return hex.EncodeToString(k[:]) }
|
func (k PublicKey) HexString() string { return hex.EncodeToString(k[:]) }
|
||||||
func (k Key) Equal(k2 Key) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 }
|
func (k PublicKey) Equal(k2 PublicKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 }
|
||||||
|
|
||||||
func (k *Key) ShortString() string {
|
func (k *PublicKey) ShortString() string {
|
||||||
if k.IsZero() {
|
long := k.Base64()
|
||||||
return "[empty]"
|
return "[" + long[0:5] + "]"
|
||||||
}
|
|
||||||
long := k.String()
|
|
||||||
if len(long) < 10 {
|
|
||||||
return "invalid"
|
|
||||||
}
|
|
||||||
return "[" + long[0:4] + "…" + long[len(long)-5:len(long)-1] + "]"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *Key) IsZero() bool {
|
func (k PublicKey) IsZero() bool {
|
||||||
if k == nil {
|
var zeros PublicKey
|
||||||
return true
|
|
||||||
}
|
|
||||||
var zeros Key
|
|
||||||
return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1
|
return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *Key) MarshalJSON() ([]byte, error) {
|
func (k *PublicKey) MarshalJSON() ([]byte, error) {
|
||||||
if k == nil {
|
if k == nil {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
}
|
}
|
||||||
@ -95,47 +76,35 @@ func (k *Key) MarshalJSON() ([]byte, error) {
|
|||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *Key) UnmarshalJSON(b []byte) error {
|
func (k *PublicKey) UnmarshalJSON(b []byte) error {
|
||||||
if k == nil {
|
if k == nil {
|
||||||
return errors.New("wgcfg.Key: UnmarshalJSON on nil pointer")
|
return errors.New("wgcfg.PublicKey: UnmarshalJSON on nil pointer")
|
||||||
}
|
}
|
||||||
if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' {
|
if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' {
|
||||||
return errors.New("wgcfg.Key: UnmarshalJSON not given a string")
|
return errors.New("wgcfg.PublicKey: UnmarshalJSON not given a string")
|
||||||
}
|
}
|
||||||
b = b[1 : len(b)-1]
|
b = b[1 : len(b)-1]
|
||||||
key, err := ParseHexKey(string(b))
|
key, err := ParseHexKey(string(b))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("wgcfg.Key: UnmarshalJSON: %v", err)
|
return fmt.Errorf("wgcfg.PublicKey: UnmarshalJSON: %v", err)
|
||||||
}
|
}
|
||||||
copy(k[:], key[:])
|
copy(k[:], key[:])
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Key) LessThan(b *Key) bool {
|
|
||||||
for i := range a {
|
|
||||||
if a[i] < b[i] {
|
|
||||||
return true
|
|
||||||
} else if a[i] > b[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// PrivateKey is curve25519 key.
|
// PrivateKey is curve25519 key.
|
||||||
// It is used by WireGuard to represent private keys.
|
// It is used by WireGuard to represent private keys.
|
||||||
type PrivateKey [KeySize]byte
|
type PrivateKey [KeySize]byte
|
||||||
|
|
||||||
// NewPrivateKey generates a new curve25519 secret key.
|
// NewPrivateKey generates a new curve25519 secret key.
|
||||||
// It conforms to the format described on https://cr.yp.to/ecdh.html.
|
// It conforms to the format described on https://cr.yp.to/ecdh.html.
|
||||||
func NewPrivateKey() (PrivateKey, error) {
|
func NewPrivateKey() (pk PrivateKey, err error) {
|
||||||
k, err := NewPresharedKey()
|
_, err = cryptorand.Read(pk[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return PrivateKey{}, err
|
return PrivateKey{}, err
|
||||||
}
|
}
|
||||||
k[0] &= 248
|
pk.clamp()
|
||||||
k[31] = (k[31] & 127) | 64
|
return pk, nil
|
||||||
return (PrivateKey)(*k), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParsePrivateKey(b64 string) (*PrivateKey, error) {
|
func ParsePrivateKey(b64 string) (*PrivateKey, error) {
|
||||||
@ -147,9 +116,9 @@ func (k *PrivateKey) String() string { return base64.StdEncoding.Encod
|
|||||||
func (k *PrivateKey) HexString() string { return hex.EncodeToString(k[:]) }
|
func (k *PrivateKey) HexString() string { return hex.EncodeToString(k[:]) }
|
||||||
func (k *PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 }
|
func (k *PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 }
|
||||||
|
|
||||||
func (k *PrivateKey) IsZero() bool {
|
func (k PrivateKey) IsZero() bool {
|
||||||
pk := Key(*k)
|
var zeros PrivateKey
|
||||||
return pk.IsZero()
|
return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *PrivateKey) clamp() {
|
func (k *PrivateKey) clamp() {
|
||||||
@ -158,14 +127,13 @@ func (k *PrivateKey) clamp() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Public computes the public key matching this curve25519 secret key.
|
// Public computes the public key matching this curve25519 secret key.
|
||||||
func (k *PrivateKey) Public() Key {
|
func (k PrivateKey) Public() PublicKey {
|
||||||
pk := Key(*k)
|
if k.IsZero() {
|
||||||
if pk.IsZero() {
|
panic("wgcfg: tried to generate public key for a zero key")
|
||||||
panic("Tried to generate emptyPrivateKey.Public()")
|
|
||||||
}
|
}
|
||||||
var p [KeySize]byte
|
var p [KeySize]byte
|
||||||
curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(k))
|
curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(&k))
|
||||||
return (Key)(p)
|
return (PublicKey)(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k PrivateKey) MarshalText() ([]byte, error) {
|
func (k PrivateKey) MarshalText() ([]byte, error) {
|
||||||
@ -188,14 +156,14 @@ func (k *PrivateKey) UnmarshalText(b []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k PrivateKey) SharedSecret(pub Key) (ss [KeySize]byte) {
|
func (k PrivateKey) SharedSecret(pub PublicKey) (ss [KeySize]byte) {
|
||||||
apk := (*[KeySize]byte)(&pub)
|
apk := (*[KeySize]byte)(&pub)
|
||||||
ask := (*[KeySize]byte)(&k)
|
ask := (*[KeySize]byte)(&k)
|
||||||
curve25519.ScalarMult(&ss, ask, apk)
|
curve25519.ScalarMult(&ss, ask, apk)
|
||||||
return ss
|
return ss
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) {
|
func parseKeyBase64(enc *base64.Encoding, s string) (*PublicKey, error) {
|
||||||
k, err := enc.DecodeString(s)
|
k, err := enc.DecodeString(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &ParseError{"Invalid key: " + err.Error(), s}
|
return nil, &ParseError{"Invalid key: " + err.Error(), s}
|
||||||
@ -203,7 +171,7 @@ func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) {
|
|||||||
if len(k) != KeySize {
|
if len(k) != KeySize {
|
||||||
return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
|
return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
|
||||||
}
|
}
|
||||||
var key Key
|
var key PublicKey
|
||||||
copy(key[:], k)
|
copy(key[:], k)
|
||||||
return &key, nil
|
return &key, nil
|
||||||
}
|
}
|
||||||
|
@ -6,10 +6,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestKeyBasics(t *testing.T) {
|
func TestKeyBasics(t *testing.T) {
|
||||||
k1, err := NewPresharedKey()
|
pk1, err := NewPrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
k1 := pk1.Public()
|
||||||
|
|
||||||
b, err := k1.MarshalJSON()
|
b, err := k1.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -18,7 +19,7 @@ func TestKeyBasics(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("JSON round-trip", func(t *testing.T) {
|
t.Run("JSON round-trip", func(t *testing.T) {
|
||||||
// should preserve the keys
|
// should preserve the keys
|
||||||
k2 := new(Key)
|
k2 := new(PublicKey)
|
||||||
if err := k2.UnmarshalJSON(b); err != nil {
|
if err := k2.UnmarshalJSON(b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -39,10 +40,11 @@ func TestKeyBasics(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("second key", func(t *testing.T) {
|
t.Run("second key", func(t *testing.T) {
|
||||||
// A second call to NewPresharedKey should make a new key.
|
// A second call to NewPresharedKey should make a new key.
|
||||||
k3, err := NewPresharedKey()
|
pk3, err := NewPrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
k3 := pk3.Public()
|
||||||
if bytes.Equal(k1[:], k3[:]) {
|
if bytes.Equal(k1[:], k3[:]) {
|
||||||
t.Fatalf("k1 %v == k3 %v", k1[:], k3[:])
|
t.Fatalf("k1 %v == k3 %v", k1[:], k3[:])
|
||||||
}
|
}
|
||||||
@ -52,6 +54,7 @@ func TestKeyBasics(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrivateKeyBasics(t *testing.T) {
|
func TestPrivateKeyBasics(t *testing.T) {
|
||||||
pri, err := NewPrivateKey()
|
pri, err := NewPrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -81,7 +84,7 @@ func TestPrivateKeyBasics(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("JSON incompatible with Key", func(t *testing.T) {
|
t.Run("JSON incompatible with Key", func(t *testing.T) {
|
||||||
k2 := new(Key)
|
k2 := new(PublicKey)
|
||||||
if err := k2.UnmarshalJSON(b); err == nil {
|
if err := k2.UnmarshalJSON(b); err == nil {
|
||||||
t.Fatalf("successfully decoded private key as key")
|
t.Fatalf("successfully decoded private key as key")
|
||||||
}
|
}
|
||||||
|
@ -100,7 +100,7 @@ func parsePersistentKeepalive(s string) (uint16, error) {
|
|||||||
return uint16(m), nil
|
return uint16(m), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseKeyHex(s string) (*Key, error) {
|
func parseKeyHex(s string) (*PublicKey, error) {
|
||||||
k, err := hex.DecodeString(s)
|
k, err := hex.DecodeString(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &ParseError{"Invalid key: " + err.Error(), s}
|
return nil, &ParseError{"Invalid key: " + err.Error(), s}
|
||||||
@ -108,7 +108,7 @@ func parseKeyHex(s string) (*Key, error) {
|
|||||||
if len(k) != KeySize {
|
if len(k) != KeySize {
|
||||||
return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
|
return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
|
||||||
}
|
}
|
||||||
var key Key
|
var key PublicKey
|
||||||
copy(key[:], k)
|
copy(key[:], k)
|
||||||
return &key, nil
|
return &key, nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user