wg-quicker/wgquick/config.go
Marvin Steadfast 3b0b3b5a82
Some checks failed
continuous-integration/drone/push Build is failing
continuous-integration/drone/tag Build is failing
happy linting
2021-03-08 09:32:33 +01:00

326 lines
8.2 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// nolint: gochecknoglobals, gomnd, goerr113, cyclop
package wgquick
import (
"bytes"
"encoding"
"encoding/base64"
"fmt"
"net"
"strconv"
"strings"
"text/template"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// Config represents a full wg-quick like config structure.
//
// It embeds wgtypes.Config, whose PrivateKey, ListenPort and Peers fields are
// read and written by the text (un)marshalling code in this file, and adds the
// interface-level keys on top of it.
type Config struct {
wgtypes.Config
// Address is the list of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned
// to the interface. May be specified multiple times.
Address []net.IPNet
// DNS is the list of IP (v4 or v6) addresses to be set as the interface's DNS servers.
// May be specified multiple times.
// Upon bringing the interface up, this runs resolvconf -a tun.INTERFACE -m 0 -x and upon bringing it down,
// this runs resolvconf -d tun.INTERFACE. If these particular invocations of resolvconf(8) are undesirable,
// the PostUp and PostDown keys below may be used instead.
DNS []net.IP
// MTU is automatically determined from the endpoint addresses or the system default route,
// which is usually a sane choice. However, to manually specify an MTU to override this automatic discovery,
// this value may be specified explicitly.
MTU int
// Table controls the routing table to which routes are added.
Table int
// PreUp, PostUp, PreDown, PostDown are script snippets which will be executed by bash(1)
// before/after setting up/tearing down the interface, most commonly used to configure
// custom DNS options or firewall rules. The special string %i is expanded to INTERFACE.
// Each field holds a single string: when a key appears several times in the parsed text,
// the parser in this file keeps only the last value.
PreUp string
PostUp string
PreDown string
PostDown string
// RouteProtocol to set on the route. See linux/rtnetlink.h. Use a value > 4 or the default 0.
// NOTE(review): not rendered by the config template nor accepted by the parser in this file.
RouteProtocol int
// RouteMetric sets this metric on all managed routes. Lower number means pick this one.
// NOTE(review): not rendered by the config template nor accepted by the parser in this file.
RouteMetric int
// AddressLabel is the address label to set on the link.
// NOTE(review): not rendered by the config template nor accepted by the parser in this file.
AddressLabel string
// SaveConfig, if set to true, means the configuration is saved from the current state of
// the interface upon shutdown.
// Currently unsupported.
SaveConfig bool
}
// Compile-time checks that *Config implements the text marshalling and
// unmarshalling interfaces.
var _ encoding.TextMarshaler = (*Config)(nil)

var _ encoding.TextUnmarshaler = (*Config)(nil)
// String returns the text form of cfg as produced by MarshalText.
// It panics if MarshalText returns an error.
func (cfg *Config) String() string {
	text, err := cfg.MarshalText()
	if err == nil {
		return string(text)
	}

	panic(err)
}
// serializeKey returns the standard (padded) base64 encoding of the bytes of
// key. It is registered in the template function map as "wgKey".
func serializeKey(key *wgtypes.Key) string {
	raw := key[:]

	return base64.StdEncoding.EncodeToString(raw)
}
// toSeconds converts duration to a whole number of seconds. Any fractional
// second is discarded (integer division, truncating toward zero).
func toSeconds(duration time.Duration) int {
	const nanosPerSecond = int64(time.Second)

	return int(int64(duration) / nanosPerSecond)
}
var funcMap = template.FuncMap(map[string]interface{}{
"wgKey": serializeKey,
"toSeconds": toSeconds,
})
// cfgTemplate is the parsed form of wgtypeTemplateSpec with funcMap attached.
// template.Must panics during package initialisation if the spec fails to parse.
var cfgTemplate = template.Must(
	template.New("wg-cfg").Funcs(funcMap).Parse(wgtypeTemplateSpec),
)
// MarshalText renders cfg through cfgTemplate and returns the generated bytes.
// If template execution fails, it returns a nil slice and the wrapped error.
func (cfg *Config) MarshalText() (text []byte, err error) {
	var out bytes.Buffer

	if execErr := cfgTemplate.Execute(&out, cfg); execErr != nil {
		return nil, fmt.Errorf("%w", execErr)
	}

	return out.Bytes(), nil
}
// wgtypeTemplateSpec is the text/template source used to render a Config.
//
// It emits an [Interface] section with one Address and one DNS line per entry
// and the PrivateKey, followed by ListenPort, MTU, Table, PreUp, PostUp,
// PreDown, PostDown and SaveConfig only when the respective value is non-zero.
// Each peer then gets its own [Peer] section, separated by a blank line, with
// PublicKey, a comma-separated AllowedIPs list and the optional PresharedKey,
// PersistentKeepalive (in seconds) and Endpoint keys.
//
// RouteProtocol, RouteMetric and AddressLabel are not rendered.
// NOTE(review): PresharedKey is printed with the template's default formatting
// rather than through wgKey - confirm this yields the base64 form.
// nolint: lll
const wgtypeTemplateSpec = `[Interface]
{{- range .Address }}
Address = {{ . }}
{{- end }}
{{- range .DNS }}
DNS = {{ . }}
{{- end }}
PrivateKey = {{ .PrivateKey | wgKey }}
{{- if .ListenPort }}{{ "\n" }}ListenPort = {{ .ListenPort }}{{ end }}
{{- if .MTU }}{{ "\n" }}MTU = {{ .MTU }}{{ end }}
{{- if .Table }}{{ "\n" }}Table = {{ .Table }}{{ end }}
{{- if .PreUp }}{{ "\n" }}PreUp = {{ .PreUp }}{{ end }}
{{- if .PostUp }}{{ "\n" }}PostUp = {{ .PostUp }}{{ end }}
{{- if .PreDown }}{{ "\n" }}PreDown = {{ .PreDown }}{{ end }}
{{- if .PostDown }}{{ "\n" }}PostDown = {{ .PostDown }}{{ end }}
{{- if .SaveConfig }}{{ "\n" }}SaveConfig = {{ .SaveConfig }}{{ end }}
{{- range .Peers }}
{{- "\n" }}
[Peer]
PublicKey = {{ .PublicKey | wgKey }}
AllowedIPs = {{ range $i, $el := .AllowedIPs }}{{if $i}}, {{ end }}{{ $el }}{{ end }}
{{- if .PresharedKey }}{{ "\n" }}PresharedKey = {{ .PresharedKey }}{{ end }}
{{- if .PersistentKeepaliveInterval }}{{ "\n" }}PersistentKeepalive = {{ .PersistentKeepaliveInterval | toSeconds }}{{ end }}
{{- if .Endpoint }}{{ "\n" }}Endpoint = {{ .Endpoint }}{{ end }}
{{- end }}
`
// ParseKey decodes a base64 (standard encoding) WireGuard key. It is used for
// private, public and preshared keys alike.
//
// It returns an error if key is not valid base64 or if the decoded data does
// not have exactly the length of a wgtypes.Key. On error the returned key is
// the zero value.
func ParseKey(key string) (wgtypes.Key, error) {
	var pkey wgtypes.Key

	raw, err := base64.StdEncoding.DecodeString(key)
	if err != nil {
		return pkey, fmt.Errorf("%w", err)
	}

	// copy would silently zero-pad a short key or truncate a long one, so
	// reject anything that does not fill the key exactly.
	if len(raw) != len(pkey) {
		return pkey, fmt.Errorf("invalid key length %d, expected %d bytes", len(raw), len(pkey))
	}

	copy(pkey[:], raw)

	return pkey, nil
}
// parseState identifies the section of the config text currently being parsed.
type parseState int

const (
	// unknown means no section header has been seen yet.
	unknown parseState = iota
	// inter means the parser is inside an [Interface] section.
	inter
	// peer means the parser is inside a [Peer] section.
	peer
)
// UnmarshalText parses wg-quick style configuration text into cfg.
//
// cfg is reset to its zero value before parsing. The text is processed line by
// line: blank lines and lines whose first non-space character is '#' are
// skipped, "[Interface]" and "[Peer]" switch the current section (each
// "[Peer]" appends a new peer), and every other line must have the form
// "key = value" and is handed to the parser of the current section.
//
// All returned errors carry the 1-based number of the offending line.
func (cfg *Config) UnmarshalText(text []byte) error {
	*cfg = Config{} // Zero out the config

	state := unknown

	var peerCfg *wgtypes.PeerConfig

	for no, line := range strings.Split(string(text), "\n") {
		lineNo := no + 1 // report human-friendly, 1-based line numbers

		ln := strings.TrimSpace(line)
		if len(ln) == 0 || ln[0] == '#' {
			continue
		}

		switch ln {
		case "[Interface]":
			state = inter
		case "[Peer]":
			state = peer
			cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{})
			// Re-take the pointer after every append: the slice may have been
			// reallocated.
			peerCfg = &cfg.Peers[len(cfg.Peers)-1]
		default:
			// Split on the first '=' only; the value itself may contain '='
			// (base64 padding, shell snippets).
			parts := strings.SplitN(ln, "=", 2)
			if len(parts) < 2 {
				return fmt.Errorf("cannot parse line %d, missing =", lineNo)
			}

			lhs := strings.TrimSpace(parts[0])
			rhs := strings.TrimSpace(parts[1])

			switch state {
			case inter:
				if err := parseInterfaceLine(cfg, lhs, rhs); err != nil {
					return fmt.Errorf("[line %d]: %w", lineNo, err)
				}
			case peer:
				if err := parsePeerLine(peerCfg, lhs, rhs); err != nil {
					return fmt.Errorf("[line %d]: %w", lineNo, err)
				}
			case unknown:
				// A key/value line appeared before any section header.
				return fmt.Errorf("[line %d] cannot parse, unknown state", lineNo)
			default:
				return fmt.Errorf("switch couldnt find state")
			}
		}
	}

	return nil
}
// parseInterfaceLine applies a single "key = value" pair from an [Interface]
// section to cfg.
//
// lhs is the directive name and rhs its value. Address and DNS accept a
// comma-separated list and append to the existing entries; all other
// directives overwrite the previous value. MTU and Table must be decimal
// integers that fit in an int, and ListenPort must be in the range 0-65535.
//
// It returns an error for a value that cannot be parsed and for any directive
// it does not know.
// nolint: funlen
func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
	switch lhs {
	case "Address":
		for _, addr := range strings.Split(rhs, ",") {
			ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
			if err != nil {
				return fmt.Errorf("%w", err)
			}

			// Keep the address exactly as written (host bits included) and
			// take only the mask from the parsed network.
			cfg.Address = append(cfg.Address, net.IPNet{IP: ip, Mask: cidr.Mask})
		}
	case "DNS":
		for _, addr := range strings.Split(rhs, ",") {
			addr = strings.TrimSpace(addr)

			ip := net.ParseIP(addr)
			if ip == nil {
				return fmt.Errorf("cannot parse IP %q", addr)
			}

			cfg.DNS = append(cfg.DNS, ip)
		}
	case "MTU":
		// Atoi reports out-of-range values instead of silently truncating
		// them on platforms where int is 32 bits wide.
		mtu, err := strconv.Atoi(rhs)
		if err != nil {
			return fmt.Errorf("%w", err)
		}

		cfg.MTU = mtu
	case "Table":
		tbl, err := strconv.Atoi(rhs)
		if err != nil {
			return fmt.Errorf("%w", err)
		}

		cfg.Table = tbl
	case "ListenPort":
		// A 16-bit unsigned parse rejects negative and too-large port numbers.
		port16, err := strconv.ParseUint(rhs, 10, 16)
		if err != nil {
			return fmt.Errorf("%w", err)
		}

		port := int(port16)
		cfg.ListenPort = &port
	case "PreUp":
		cfg.PreUp = rhs
	case "PostUp":
		cfg.PostUp = rhs
	case "PreDown":
		cfg.PreDown = rhs
	case "PostDown":
		cfg.PostDown = rhs
	case "SaveConfig":
		save, err := strconv.ParseBool(rhs)
		if err != nil {
			return fmt.Errorf("%w", err)
		}

		cfg.SaveConfig = save
	case "PrivateKey":
		key, err := ParseKey(rhs)
		if err != nil {
			return fmt.Errorf("cannot decode key %w", err)
		}

		cfg.PrivateKey = &key
	default:
		return fmt.Errorf("unknown directive %s", lhs)
	}

	return nil
}
// parsePeerLine applies a single "key = value" pair from a [Peer] section to
// peerCfg.
//
// lhs is the directive name and rhs its value. AllowedIPs accepts a
// comma-separated list of CIDR prefixes and appends to the existing entries.
// PresharedKey may be given only once per peer. Endpoint is resolved with
// net.ResolveUDPAddr, and PersistentKeepalive is a whole number of seconds.
//
// It returns an error for a value that cannot be parsed, for a repeated
// PresharedKey and for any directive it does not know.
func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
	switch lhs {
	case "PublicKey":
		key, err := ParseKey(rhs)
		if err != nil {
			return fmt.Errorf("cannot decode key %w", err)
		}

		peerCfg.PublicKey = key
	case "PresharedKey":
		// There is no underlying error to wrap here, so the message carries
		// no %w verb.
		if peerCfg.PresharedKey != nil {
			return fmt.Errorf("preshared key already defined")
		}

		key, err := ParseKey(rhs)
		if err != nil {
			return fmt.Errorf("cannot decode key %w", err)
		}

		peerCfg.PresharedKey = &key
	case "AllowedIPs":
		for _, addr := range strings.Split(rhs, ",") {
			ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
			if err != nil {
				return fmt.Errorf("cannot parse %s: %w", addr, err)
			}

			peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask})
		}
	case "Endpoint":
		addr, err := net.ResolveUDPAddr("", rhs)
		if err != nil {
			return fmt.Errorf("%w", err)
		}

		peerCfg.Endpoint = addr
	case "PersistentKeepalive":
		secs, err := strconv.ParseInt(rhs, 10, 64)
		if err != nil {
			return fmt.Errorf("%w", err)
		}

		dur := time.Duration(secs) * time.Second
		peerCfg.PersistentKeepaliveInterval = &dur
	default:
		return fmt.Errorf("unknown directive %s", lhs)
	}

	return nil
}