happy linting

This commit is contained in:
Marvin Steadfast 2021-01-21 10:25:04 +01:00
parent d9764025be
commit a5695600e2
4 changed files with 186 additions and 66 deletions

View File

@ -52,4 +52,4 @@ test:
.PHONY: lint
lint:
golangci-lint run --enable-all --disable gomnd --disable godox --timeout 5m
golangci-lint run --enable-all --disable gomnd --disable godox --disable exhaustivestruct --timeout 5m

View File

@ -1,3 +1,4 @@
// nolint: gochecknoglobals, gomnd, goerr113
package wgquick
import (
@ -14,23 +15,32 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// Config represents full wg-quick like config structure
// Config represents full wg-quick like config structure.
type Config struct {
wgtypes.Config
// Address list of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned to the interface. May be specified multiple times.
// Address list of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned
// to the interface. May be specified multiple times.
Address []net.IPNet
// list of IP (v4 or v6) addresses to be set as the interfaces DNS servers. May be specified multiple times. Upon bringing the interface up, this runs resolvconf -a tun.INTERFACE -m 0 -x and upon bringing it down, this runs resolvconf -d tun.INTERFACE. If these particular invocations of resolvconf(8) are undesirable, the PostUp and PostDown keys below may be used instead.
// list of IP (v4 or v6) addresses to be set as the interfaces DNS servers. May be specified multiple times.
// Upon bringing the interface up, this runs resolvconf -a tun.INTERFACE -m 0 -x and upon bringing it down,
// this runs resolvconf -d tun.INTERFACE. If these particular invocations of resolvconf(8) are undesirable,
// the PostUp and PostDown keys below may be used instead.
DNS []net.IP
// MTU is automatically determined from the endpoint addresses or the system default route, which is usually a sane choice. However, to manually specify an MTU to override this automatic discovery, this value may be specified explicitly.
// MTU is automatically determined from the endpoint addresses or the system default route,
// which is usually a sane choice. However, to manually specify an MTU to override this automatic discovery,
// this value may be specified explicitly.
MTU int
// Table — Controls the routing table to which routes are added.
Table int
// PreUp, PostUp, PreDown, PostDown — script snippets which will be executed by bash(1) before/after setting up/tearing down the interface, most commonly used to configure custom DNS options or firewall rules. The special string %i is expanded to INTERFACE. Each one may be specified multiple times, in which case the commands are executed in order.
// PreUp, PostUp, PreDown, PostDown — script snippets which will be executed by bash(1)
// before/after setting up/tearing down the interface, most commonly used to configure
// custom DNS options or firewall rules. The special string %i is expanded to INTERFACE.
// Each one may be specified multiple times, in which case the commands are executed in order.
PreUp string
PostUp string
PreDown string
@ -50,14 +60,17 @@ type Config struct {
SaveConfig bool
}
var _ encoding.TextMarshaler = (*Config)(nil)
var _ encoding.TextUnmarshaler = (*Config)(nil)
var (
_ encoding.TextMarshaler = (*Config)(nil)
_ encoding.TextUnmarshaler = (*Config)(nil)
)
func (cfg *Config) String() string {
b, err := cfg.MarshalText()
if err != nil {
panic(err)
}
return string(b)
}
@ -83,11 +96,13 @@ var cfgTemplate = template.Must(
func (cfg *Config) MarshalText() (text []byte, err error) {
buff := &bytes.Buffer{}
if err := cfgTemplate.Execute(buff, cfg); err != nil {
return nil, err
return nil, fmt.Errorf("%w", err)
}
return buff.Bytes(), nil
}
// nolint: lll
const wgtypeTemplateSpec = `[Interface]
{{- range .Address }}
Address = {{ . }}
@ -115,14 +130,17 @@ AllowedIPs = {{ range $i, $el := .AllowedIPs }}{{if $i}}, {{ end }}{{ $el }}{{ e
{{- end }}
`
// ParseKey parses the base64 encoded wireguard private key
// ParseKey parses the base64 encoded wireguard private key.
func ParseKey(key string) (wgtypes.Key, error) {
var pkey wgtypes.Key
pkeySlice, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return pkey, err
return pkey, fmt.Errorf("%w", err)
}
copy(pkey[:], pkeySlice[:])
copy(pkey[:], pkeySlice)
return pkey, nil
}
@ -137,17 +155,21 @@ const (
func (cfg *Config) UnmarshalText(text []byte) error {
*cfg = Config{} // Zero out the config
state := unknown
var peerCfg *wgtypes.PeerConfig
for no, line := range strings.Split(string(text), "\n") {
ln := strings.TrimSpace(line)
if len(ln) == 0 || ln[0] == '#' {
continue
}
switch ln {
case "[Interface]":
state = inter
case "[Peer]":
state = peer
cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{})
peerCfg = &cfg.Peers[len(cfg.Peers)-1]
default:
@ -155,33 +177,40 @@ func (cfg *Config) UnmarshalText(text []byte) error {
if len(parts) < 2 {
return fmt.Errorf("cannot parse line %d, missing =", no)
}
lhs := strings.TrimSpace(parts[0])
rhs := strings.TrimSpace(strings.Join(parts[1:], "="))
switch state {
case inter:
if err := parseInterfaceLine(cfg, lhs, rhs); err != nil {
return fmt.Errorf("[line %d]: %v", no+1, err)
return fmt.Errorf("[line %d]: %w", no+1, err)
}
case peer:
if err := parsePeerLine(peerCfg, lhs, rhs); err != nil {
return fmt.Errorf("[line %d]: %v", no+1, err)
return fmt.Errorf("[line %d]: %w", no+1, err)
}
default:
case unknown:
return fmt.Errorf("[line %d] cannot parse, unknown state", no+1)
default:
return fmt.Errorf("switch couldnt find state")
}
}
}
return nil
}
// nolint: funlen
func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
switch lhs {
case "Address":
for _, addr := range strings.Split(rhs, ",") {
ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
if err != nil {
return err
return fmt.Errorf("%w", err)
}
cfg.Address = append(cfg.Address, net.IPNet{IP: ip, Mask: cidr.Mask})
}
case "DNS":
@ -190,25 +219,29 @@ func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
if ip == nil {
return fmt.Errorf("cannot parse IP")
}
cfg.DNS = append(cfg.DNS, ip)
}
case "MTU":
mtu, err := strconv.ParseInt(rhs, 10, 64)
if err != nil {
return err
return fmt.Errorf("%w", err)
}
cfg.MTU = int(mtu)
case "Table":
tbl, err := strconv.ParseInt(rhs, 10, 64)
if err != nil {
return err
return fmt.Errorf("%w", err)
}
cfg.Table = int(tbl)
case "ListenPort":
portI64, err := strconv.ParseInt(rhs, 10, 64)
if err != nil {
return err
return fmt.Errorf("%w", err)
}
port := int(portI64)
cfg.ListenPort = &port
case "PreUp":
@ -222,18 +255,21 @@ func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
case "SaveConfig":
save, err := strconv.ParseBool(rhs)
if err != nil {
return err
return fmt.Errorf("%w", err)
}
cfg.SaveConfig = save
case "PrivateKey":
key, err := ParseKey(rhs)
if err != nil {
return fmt.Errorf("cannot decode key %v", err)
return fmt.Errorf("cannot decode key %w", err)
}
cfg.PrivateKey = &key
default:
return fmt.Errorf("unknown directive %s", lhs)
}
return nil
}
@ -242,41 +278,48 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
case "PublicKey":
key, err := ParseKey(rhs)
if err != nil {
return fmt.Errorf("cannot decode key %v", err)
return fmt.Errorf("cannot decode key %w", err)
}
peerCfg.PublicKey = key
case "PresharedKey":
key, err := ParseKey(rhs)
if err != nil {
return fmt.Errorf("cannot decode key %v", err)
return fmt.Errorf("cannot decode key %w", err)
}
if peerCfg.PresharedKey != nil {
return fmt.Errorf("preshared key already defined %v", err)
return fmt.Errorf("preshared key already defined %w", err)
}
peerCfg.PresharedKey = &key
case "AllowedIPs":
for _, addr := range strings.Split(rhs, ",") {
ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
if err != nil {
return fmt.Errorf("cannot parse %s: %v", addr, err)
return fmt.Errorf("cannot parse %s: %w", addr, err)
}
peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask})
}
case "Endpoint":
addr, err := net.ResolveUDPAddr("", rhs)
if err != nil {
return err
return fmt.Errorf("%w", err)
}
peerCfg.Endpoint = addr
case "PersistentKeepalive":
t, err := strconv.ParseInt(rhs, 10, 64)
if err != nil {
return err
return fmt.Errorf("%w", err)
}
dur := time.Duration(t * int64(time.Second))
peerCfg.PersistentKeepaliveInterval = &dur
default:
return fmt.Errorf("unknown directive %s", lhs)
}
return nil
}

View File

@ -1,3 +1,4 @@
// nolint: scopelint, gochecknoglobals, paralleltest, testpackage
package wgquick
import (
@ -54,6 +55,7 @@ PersistentKeepalive = 25
func TestExampleConfig(t *testing.T) {
c := &Config{}
for name, cfg := range testConfigs {
t.Run(name, func(t *testing.T) {
err := c.UnmarshalText([]byte(cfg))

153
wg.go
View File

@ -2,6 +2,7 @@ package wgquick
import (
"bytes"
"errors"
"fmt"
"net"
"os"
@ -31,20 +32,24 @@ func wgGo(iface string) error {
}
cmd := wgo.Command(iface)
cmd.Start()
if err := cmd.Start(); err != nil {
return fmt.Errorf("could not start wireguard-go: %w", err)
}
return nil
}
// Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface`
// Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface`.
func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
log := logger.WithField("iface", iface)
_, err := netlink.LinkByName(iface)
if err == nil {
return os.ErrExist
}
if _, ok := err.(netlink.LinkNotFoundError); !ok {
return err
if errors.As(err, &netlink.LinkNotFoundError{}) {
return fmt.Errorf("%w", err)
}
for _, dns := range cfg.DNS {
@ -57,8 +62,10 @@ func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
if err := execSh(cfg.PreUp, iface, log); err != nil {
return err
}
log.Infoln("applied pre-up command")
}
if err := Sync(cfg, iface, logger); err != nil {
return err
}
@ -67,17 +74,20 @@ func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
if err := execSh(cfg.PostUp, iface, log); err != nil {
return err
}
log.Infoln("applied post-up command")
}
return nil
}
// Down destroys the wg interface. Mostly equivalent to `wg-quick down iface`
// Down destroys the wg interface. Mostly equivalent to `wg-quick down iface`.
func Down(cfg *Config, iface string, logger logrus.FieldLogger) error {
log := logger.WithField("iface", iface)
link, err := netlink.LinkByName(iface)
if err != nil {
return err
return fmt.Errorf("%w", err)
}
if len(cfg.DNS) > 1 {
@ -90,108 +100,138 @@ func Down(cfg *Config, iface string, logger logrus.FieldLogger) error {
if err := execSh(cfg.PreDown, iface, log); err != nil {
return err
}
log.Infoln("applied pre-down command")
}
if err := netlink.LinkDel(link); err != nil {
return err
return fmt.Errorf("%w", err)
}
log.Infoln("link deleted")
if cfg.PostDown != "" {
if err := execSh(cfg.PostDown, iface, log); err != nil {
return err
}
log.Infoln("applied post-down command")
}
return nil
}
func execSh(command string, iface string, log logrus.FieldLogger, stdin ...string) error {
cmd := exec.Command("sh", "-ce", strings.ReplaceAll(command, "%i", iface))
cmd := exec.Command("sh", "-ce", strings.ReplaceAll(command, "%i", iface)) // nolint: gosec
if len(stdin) > 0 {
log = log.WithField("stdin", strings.Join(stdin, ""))
b := &bytes.Buffer{}
for _, ln := range stdin {
if _, err := fmt.Fprint(b, ln); err != nil {
return err
return fmt.Errorf("%w", err)
}
}
cmd.Stdin = b
}
out, err := cmd.CombinedOutput()
if err != nil {
log.WithError(err).Errorf("failed to execute %s:\n%s", cmd.Args, out)
return err
return fmt.Errorf("%w", err)
}
log.Infof("executed %s:\n%s", cmd.Args, out)
return nil
}
// Sync the config to the current setup for given interface
// Sync the config to the current setup for given interface.
// It perform 4 operations:
// * SyncLink --> makes sure link is up and type wireguard
// * SyncWireguardDevice --> configures allowedIP & other wireguard specific settings
// * SyncAddress --> synces linux addresses bounded to this interface
// * SyncRoutes --> synces all allowedIP routes to route to this interface
// * SyncLink --> makes sure link is up and type wireguard.
// * SyncWireguardDevice --> configures allowedIP & other wireguard specific settings.
// * SyncAddress --> synces linux addresses bounded to this interface.
// * SyncRoutes --> synces all allowedIP routes to route to this interface.
func Sync(cfg *Config, iface string, logger logrus.FieldLogger) error {
log := logger.WithField("iface", iface)
link, err := SyncLink(cfg, iface, log)
if err != nil {
log.WithError(err).Errorln("cannot sync wireguard link")
return err
return fmt.Errorf("%w", err)
}
log.Info("synced link")
if err := SyncWireguardDevice(cfg, link, log); err != nil {
log.WithError(err).Errorln("cannot sync wireguard link")
return err
}
log.Info("synced link")
if err := SyncAddress(cfg, link, log); err != nil {
log.WithError(err).Errorln("cannot sync addresses")
return err
return fmt.Errorf("%w", err)
}
log.Info("synced addresss")
var managedRoutes []net.IPNet
managedRoutes := make([]net.IPNet, 0)
for _, peer := range cfg.Peers {
for _, rt := range peer.AllowedIPs {
managedRoutes = append(managedRoutes, rt)
}
managedRoutes = append(managedRoutes, peer.AllowedIPs...)
}
if err := SyncRoutes(cfg, link, managedRoutes, log); err != nil {
log.WithError(err).Errorln("cannot sync routes")
return err
return fmt.Errorf("%w", err)
}
log.Info("synced routed")
log.Info("Successfully synced device")
return nil
}
// SyncWireguardDevice synces wireguard vpn setting on the given link. It does not set routes/addresses beyond wg internal crypto-key routing, only handles wireguard specific settings
// SyncWireguardDevice synces wireguard vpn setting on the given link.
// It does not set routes/addresses beyond wg internal crypto-key routing, only handles wireguard specific settings.
func SyncWireguardDevice(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
cl, err := wgctrl.New()
if err != nil {
log.WithError(err).Errorln("cannot setup wireguard device")
return err
return fmt.Errorf("%w", err)
}
if err := cl.ConfigureDevice(link.Attrs().Name, cfg.Config); err != nil {
log.WithError(err).Error("cannot configure device")
return err
return fmt.Errorf("%w", err)
}
return nil
}
// SyncLink synces link state with the config. It does not sync Wireguard settings, just makes sure the device is up and type wireguard
// SyncLink synces link state with the config.
// It does not sync Wireguard settings, just makes sure the device is up and type wireguard.
func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link, error) {
link, err := netlink.LinkByName(iface)
// nolint: nestif
if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); !ok {
if errors.As(err, &netlink.LinkNotFoundError{}) {
log.WithError(err).Error("cannot read link")
return nil, err
return nil, fmt.Errorf("%w", err)
}
log.Info("link not found, creating")
wgLink := &netlink.GenericLink{
@ -204,9 +244,11 @@ func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link,
if err := netlink.LinkAdd(wgLink); err != nil {
log.WithError(err).Errorf("cannot create link: %s", err.Error())
log.Info("trying to use embedded wireguard-go...")
if err := wgGo(iface); err != nil {
log.WithError(err).Errorf("cannot create link through wireguard-go: %s", err.Error())
return nil, fmt.Errorf("cannot create link")
return nil, fmt.Errorf("cannot create link: %w", err)
}
}
@ -216,32 +258,41 @@ func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link,
link, err = netlink.LinkByName(iface)
if err != nil {
log.WithError(err).Error("cannot read link")
return nil, err
return nil, fmt.Errorf("%w", err)
}
}
if err := netlink.LinkSetUp(link); err != nil {
log.WithError(err).Error("cannot set link up")
return nil, err
return nil, fmt.Errorf("%w", err)
}
log.Info("set device up")
return link, nil
}
// SyncAddress adds/deletes all lind assigned IPV4 addressed as specified in the config
// nolint: funlen, gosec, scopelint
func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
addrs, err := netlink.AddrList(link, syscall.AF_INET)
if err != nil {
log.Error(err, "cannot read link address")
return err
return fmt.Errorf("%w", err)
}
// nil addr means I've used it
presentAddresses := make(map[string]netlink.Addr, 0)
presentAddresses := make(map[string]netlink.Addr)
for _, addr := range addrs {
log.WithFields(map[string]interface{}{
"addr": fmt.Sprint(addr.IPNet),
"label": addr.Label,
}).Debugf("found existing address: %v", addr)
presentAddresses[addr.IPNet.String()] = addr
}
@ -249,19 +300,24 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
log := log.WithField("addr", addr.String())
_, present := presentAddresses[addr.String()]
presentAddresses[addr.String()] = netlink.Addr{} // mark as present
if present {
log.Info("address present")
continue
}
if err := netlink.AddrAdd(link, &netlink.Addr{
IPNet: &addr,
Label: cfg.AddressLabel,
}); err != nil {
if err != syscall.EEXIST {
if errors.Is(err, syscall.EEXIST) {
log.WithError(err).Error("cannot add addr")
return err
return fmt.Errorf("%w", err)
}
}
log.Info("address added")
}
@ -269,16 +325,21 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
if addr.IPNet == nil {
continue
}
log := log.WithFields(map[string]interface{}{
"addr": addr.IPNet.String(),
"label": addr.Label,
})
if err := netlink.AddrDel(link, &addr); err != nil {
log.WithError(err).Error("cannot delete addr")
return err
return fmt.Errorf("%w", err)
}
log.Info("addr deleted")
}
return nil
}
@ -298,13 +359,17 @@ func fillRouteDefaults(rt *netlink.Route) {
}
// SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config
// nolint: funlen, gosec, scopelint
func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error {
wantedRoutes := make(map[string][]netlink.Route, len(managedRoutes))
presentRoutes, err := netlink.RouteList(link, syscall.AF_INET)
if err != nil {
log.Error(err, "cannot read existing routes")
return err
return fmt.Errorf("%w", err)
}
for _, rt := range managedRoutes {
rt := rt // make copy
log.WithField("dst", rt.String()).Debug("managing route")
@ -330,10 +395,13 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l
"type": rt.Type,
"metric": rt.Priority,
})
if err := netlink.RouteReplace(&rt); err != nil {
log.WithError(err).Errorln("cannot add/replace route")
return err
return fmt.Errorf("%w", err)
}
log.Infoln("route added/replaced")
}
}
@ -344,6 +412,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l
return true
}
}
return false
}
@ -355,25 +424,31 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l
"type": rt.Type,
"metric": rt.Priority,
})
if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == unix.RT_CLASS_MAIN)) {
log.Debug("wrong table for route, skipping")
continue
}
if !(rt.Protocol == cfg.RouteProtocol) {
log.Infof("skipping route deletion, not owned by this daemon")
continue
}
if checkWanted(rt) {
log.Debug("route wanted, skipping deleting")
continue
}
if err := netlink.RouteDel(&rt); err != nil {
log.WithError(err).Error("cannot delete route")
return err
return fmt.Errorf("%w", err)
}
log.Info("route deleted")
}