diff --git a/Makefile b/Makefile index 8c48afb..1fa76c4 100644 --- a/Makefile +++ b/Makefile @@ -52,4 +52,4 @@ test: .PHONY: lint lint: - golangci-lint run --enable-all --disable gomnd --disable godox --timeout 5m + golangci-lint run --enable-all --disable gomnd --disable godox --disable exhaustivestruct --timeout 5m diff --git a/config.go b/config.go index b1c7815..eaa6ecd 100644 --- a/config.go +++ b/config.go @@ -1,3 +1,4 @@ +// nolint: gochecknoglobals, gomnd, goerr113 package wgquick import ( @@ -14,23 +15,32 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -// Config represents full wg-quick like config structure +// Config represents full wg-quick like config structure. type Config struct { wgtypes.Config - // Address list of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned to the interface. May be specified multiple times. + // Address list of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned + // to the interface. May be specified multiple times. Address []net.IPNet - // list of IP (v4 or v6) addresses to be set as the interface’s DNS servers. May be specified multiple times. Upon bringing the interface up, this runs ‘resolvconf -a tun.INTERFACE -m 0 -x‘ and upon bringing it down, this runs ‘resolvconf -d tun.INTERFACE‘. If these particular invocations of resolvconf(8) are undesirable, the PostUp and PostDown keys below may be used instead. + // list of IP (v4 or v6) addresses to be set as the interface’s DNS servers. May be specified multiple times. + // Upon bringing the interface up, this runs ‘resolvconf -a tun.INTERFACE -m 0 -x‘ and upon bringing it down, + // this runs ‘resolvconf -d tun.INTERFACE‘. If these particular invocations of resolvconf(8) are undesirable, + // the PostUp and PostDown keys below may be used instead. DNS []net.IP - // MTU is automatically determined from the endpoint addresses or the system default route, which is usually a sane choice. However, to manually specify an MTU to override this automatic discovery, this value may be specified explicitly. + // MTU is automatically determined from the endpoint addresses or the system default route, + // which is usually a sane choice. However, to manually specify an MTU to override this automatic discovery, + // this value may be specified explicitly. MTU int // Table — Controls the routing table to which routes are added. Table int - // PreUp, PostUp, PreDown, PostDown — script snippets which will be executed by bash(1) before/after setting up/tearing down the interface, most commonly used to configure custom DNS options or firewall rules. The special string ‘%i’ is expanded to INTERFACE. Each one may be specified multiple times, in which case the commands are executed in order. + // PreUp, PostUp, PreDown, PostDown — script snippets which will be executed by bash(1) + // before/after setting up/tearing down the interface, most commonly used to configure + // custom DNS options or firewall rules. The special string ‘%i’ is expanded to INTERFACE. + // Each one may be specified multiple times, in which case the commands are executed in order. PreUp string PostUp string PreDown string @@ -50,14 +60,17 @@ type Config struct { SaveConfig bool } -var _ encoding.TextMarshaler = (*Config)(nil) -var _ encoding.TextUnmarshaler = (*Config)(nil) +var ( + _ encoding.TextMarshaler = (*Config)(nil) + _ encoding.TextUnmarshaler = (*Config)(nil) +) func (cfg *Config) String() string { b, err := cfg.MarshalText() if err != nil { panic(err) } + return string(b) } @@ -83,11 +96,13 @@ var cfgTemplate = template.Must( func (cfg *Config) MarshalText() (text []byte, err error) { buff := &bytes.Buffer{} if err := cfgTemplate.Execute(buff, cfg); err != nil { - return nil, err + return nil, fmt.Errorf("%w", err) } + return buff.Bytes(), nil } +// nolint: lll const wgtypeTemplateSpec = `[Interface] {{- range .Address }} Address = {{ . }} @@ -115,14 +130,17 @@ AllowedIPs = {{ range $i, $el := .AllowedIPs }}{{if $i}}, {{ end }}{{ $el }}{{ e {{- end }} ` -// ParseKey parses the base64 encoded wireguard private key +// ParseKey parses the base64 encoded wireguard private key. func ParseKey(key string) (wgtypes.Key, error) { var pkey wgtypes.Key + pkeySlice, err := base64.StdEncoding.DecodeString(key) if err != nil { - return pkey, err + return pkey, fmt.Errorf("%w", err) } - copy(pkey[:], pkeySlice[:]) + + copy(pkey[:], pkeySlice) + return pkey, nil } @@ -137,17 +155,21 @@ const ( func (cfg *Config) UnmarshalText(text []byte) error { *cfg = Config{} // Zero out the config state := unknown + var peerCfg *wgtypes.PeerConfig + for no, line := range strings.Split(string(text), "\n") { ln := strings.TrimSpace(line) if len(ln) == 0 || ln[0] == '#' { continue } + switch ln { case "[Interface]": state = inter case "[Peer]": state = peer + cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{}) peerCfg = &cfg.Peers[len(cfg.Peers)-1] default: @@ -155,33 +177,40 @@ func (cfg *Config) UnmarshalText(text []byte) error { if len(parts) < 2 { return fmt.Errorf("cannot parse line %d, missing =", no) } + lhs := strings.TrimSpace(parts[0]) rhs := strings.TrimSpace(strings.Join(parts[1:], "=")) switch state { case inter: if err := parseInterfaceLine(cfg, lhs, rhs); err != nil { - return fmt.Errorf("[line %d]: %v", no+1, err) + return fmt.Errorf("[line %d]: %w", no+1, err) } case peer: if err := parsePeerLine(peerCfg, lhs, rhs); err != nil { - return fmt.Errorf("[line %d]: %v", no+1, err) + return fmt.Errorf("[line %d]: %w", no+1, err) } - default: + case unknown: return fmt.Errorf("[line %d] cannot parse, unknown state", no+1) + default: + return fmt.Errorf("switch couldnt find state") } } } + return nil } + +// nolint: funlen func parseInterfaceLine(cfg *Config, lhs string, rhs string) error { switch lhs { case "Address": for _, addr := range strings.Split(rhs, ",") { ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr)) if err != nil { - return err + return fmt.Errorf("%w", err) } + cfg.Address = append(cfg.Address, net.IPNet{IP: ip, Mask: cidr.Mask}) } case "DNS": @@ -190,25 +219,29 @@ func parseInterfaceLine(cfg *Config, lhs string, rhs string) error { if ip == nil { return fmt.Errorf("cannot parse IP") } + cfg.DNS = append(cfg.DNS, ip) } case "MTU": mtu, err := strconv.ParseInt(rhs, 10, 64) if err != nil { - return err + return fmt.Errorf("%w", err) } + cfg.MTU = int(mtu) case "Table": tbl, err := strconv.ParseInt(rhs, 10, 64) if err != nil { - return err + return fmt.Errorf("%w", err) } + cfg.Table = int(tbl) case "ListenPort": portI64, err := strconv.ParseInt(rhs, 10, 64) if err != nil { - return err + return fmt.Errorf("%w", err) } + port := int(portI64) cfg.ListenPort = &port case "PreUp": @@ -222,18 +255,21 @@ func parseInterfaceLine(cfg *Config, lhs string, rhs string) error { case "SaveConfig": save, err := strconv.ParseBool(rhs) if err != nil { - return err + return fmt.Errorf("%w", err) } + cfg.SaveConfig = save case "PrivateKey": key, err := ParseKey(rhs) if err != nil { - return fmt.Errorf("cannot decode key %v", err) + return fmt.Errorf("cannot decode key %w", err) } + cfg.PrivateKey = &key default: return fmt.Errorf("unknown directive %s", lhs) } + return nil } @@ -242,41 +278,48 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error { case "PublicKey": key, err := ParseKey(rhs) if err != nil { - return fmt.Errorf("cannot decode key %v", err) + return fmt.Errorf("cannot decode key %w", err) } + peerCfg.PublicKey = key case "PresharedKey": key, err := ParseKey(rhs) if err != nil { - return fmt.Errorf("cannot decode key %v", err) + return fmt.Errorf("cannot decode key %w", err) } + if peerCfg.PresharedKey != nil { - return fmt.Errorf("preshared key already defined %v", err) + return fmt.Errorf("preshared key already defined %w", err) } + peerCfg.PresharedKey = &key case "AllowedIPs": for _, addr := range strings.Split(rhs, ",") { ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr)) if err != nil { - return fmt.Errorf("cannot parse %s: %v", addr, err) + return fmt.Errorf("cannot parse %s: %w", addr, err) } + peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask}) } case "Endpoint": addr, err := net.ResolveUDPAddr("", rhs) if err != nil { - return err + return fmt.Errorf("%w", err) } + peerCfg.Endpoint = addr case "PersistentKeepalive": t, err := strconv.ParseInt(rhs, 10, 64) if err != nil { - return err + return fmt.Errorf("%w", err) } + dur := time.Duration(t * int64(time.Second)) peerCfg.PersistentKeepaliveInterval = &dur default: return fmt.Errorf("unknown directive %s", lhs) } + return nil } diff --git a/config_test.go b/config_test.go index 7f012ba..d4788b3 100644 --- a/config_test.go +++ b/config_test.go @@ -1,3 +1,4 @@ +// nolint: scopelint, gochecknoglobals, paralleltest, testpackage package wgquick import ( @@ -54,6 +55,7 @@ PersistentKeepalive = 25 func TestExampleConfig(t *testing.T) { c := &Config{} + for name, cfg := range testConfigs { t.Run(name, func(t *testing.T) { err := c.UnmarshalText([]byte(cfg)) diff --git a/wg.go b/wg.go index 9f2252d..76d4b2a 100644 --- a/wg.go +++ b/wg.go @@ -2,6 +2,7 @@ package wgquick import ( "bytes" + "errors" "fmt" "net" "os" @@ -31,20 +32,24 @@ func wgGo(iface string) error { } cmd := wgo.Command(iface) - cmd.Start() + if err := cmd.Start(); err != nil { + return fmt.Errorf("could not start wireguard-go: %w", err) + } return nil } -// Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface` +// Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface`. func Up(cfg *Config, iface string, logger logrus.FieldLogger) error { log := logger.WithField("iface", iface) + _, err := netlink.LinkByName(iface) if err == nil { return os.ErrExist } - if _, ok := err.(netlink.LinkNotFoundError); !ok { - return err + + if errors.As(err, &netlink.LinkNotFoundError{}) { + return fmt.Errorf("%w", err) } for _, dns := range cfg.DNS { @@ -57,8 +62,10 @@ func Up(cfg *Config, iface string, logger logrus.FieldLogger) error { if err := execSh(cfg.PreUp, iface, log); err != nil { return err } + log.Infoln("applied pre-up command") } + if err := Sync(cfg, iface, logger); err != nil { return err } @@ -67,17 +74,20 @@ func Up(cfg *Config, iface string, logger logrus.FieldLogger) error { if err := execSh(cfg.PostUp, iface, log); err != nil { return err } + log.Infoln("applied post-up command") } + return nil } -// Down destroys the wg interface. Mostly equivalent to `wg-quick down iface` +// Down destroys the wg interface. Mostly equivalent to `wg-quick down iface`. func Down(cfg *Config, iface string, logger logrus.FieldLogger) error { log := logger.WithField("iface", iface) + link, err := netlink.LinkByName(iface) if err != nil { - return err + return fmt.Errorf("%w", err) } if len(cfg.DNS) > 1 { @@ -90,108 +100,138 @@ func Down(cfg *Config, iface string, logger logrus.FieldLogger) error { if err := execSh(cfg.PreDown, iface, log); err != nil { return err } + log.Infoln("applied pre-down command") } if err := netlink.LinkDel(link); err != nil { - return err + return fmt.Errorf("%w", err) } + log.Infoln("link deleted") + if cfg.PostDown != "" { if err := execSh(cfg.PostDown, iface, log); err != nil { return err } + log.Infoln("applied post-down command") } + return nil } func execSh(command string, iface string, log logrus.FieldLogger, stdin ...string) error { - cmd := exec.Command("sh", "-ce", strings.ReplaceAll(command, "%i", iface)) + cmd := exec.Command("sh", "-ce", strings.ReplaceAll(command, "%i", iface)) // nolint: gosec + if len(stdin) > 0 { log = log.WithField("stdin", strings.Join(stdin, "")) b := &bytes.Buffer{} + for _, ln := range stdin { if _, err := fmt.Fprint(b, ln); err != nil { - return err + return fmt.Errorf("%w", err) } } + cmd.Stdin = b } + out, err := cmd.CombinedOutput() if err != nil { log.WithError(err).Errorf("failed to execute %s:\n%s", cmd.Args, out) - return err + + return fmt.Errorf("%w", err) } + log.Infof("executed %s:\n%s", cmd.Args, out) + return nil } -// Sync the config to the current setup for given interface +// Sync the config to the current setup for given interface. // It perform 4 operations: -// * SyncLink --> makes sure link is up and type wireguard -// * SyncWireguardDevice --> configures allowedIP & other wireguard specific settings -// * SyncAddress --> synces linux addresses bounded to this interface -// * SyncRoutes --> synces all allowedIP routes to route to this interface +// * SyncLink --> makes sure link is up and type wireguard. +// * SyncWireguardDevice --> configures allowedIP & other wireguard specific settings. +// * SyncAddress --> synces linux addresses bounded to this interface. +// * SyncRoutes --> synces all allowedIP routes to route to this interface. func Sync(cfg *Config, iface string, logger logrus.FieldLogger) error { log := logger.WithField("iface", iface) link, err := SyncLink(cfg, iface, log) if err != nil { log.WithError(err).Errorln("cannot sync wireguard link") - return err + + return fmt.Errorf("%w", err) } + log.Info("synced link") if err := SyncWireguardDevice(cfg, link, log); err != nil { log.WithError(err).Errorln("cannot sync wireguard link") + return err } + log.Info("synced link") if err := SyncAddress(cfg, link, log); err != nil { log.WithError(err).Errorln("cannot sync addresses") - return err + + return fmt.Errorf("%w", err) } + log.Info("synced addresss") - var managedRoutes []net.IPNet + managedRoutes := make([]net.IPNet, 0) + for _, peer := range cfg.Peers { - for _, rt := range peer.AllowedIPs { - managedRoutes = append(managedRoutes, rt) - } + managedRoutes = append(managedRoutes, peer.AllowedIPs...) } + if err := SyncRoutes(cfg, link, managedRoutes, log); err != nil { log.WithError(err).Errorln("cannot sync routes") - return err + + return fmt.Errorf("%w", err) } + log.Info("synced routed") log.Info("Successfully synced device") + return nil } -// SyncWireguardDevice synces wireguard vpn setting on the given link. It does not set routes/addresses beyond wg internal crypto-key routing, only handles wireguard specific settings +// SyncWireguardDevice synces wireguard vpn setting on the given link. +// It does not set routes/addresses beyond wg internal crypto-key routing, only handles wireguard specific settings. func SyncWireguardDevice(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { cl, err := wgctrl.New() if err != nil { log.WithError(err).Errorln("cannot setup wireguard device") - return err + + return fmt.Errorf("%w", err) } + if err := cl.ConfigureDevice(link.Attrs().Name, cfg.Config); err != nil { log.WithError(err).Error("cannot configure device") - return err + + return fmt.Errorf("%w", err) } + return nil } -// SyncLink synces link state with the config. It does not sync Wireguard settings, just makes sure the device is up and type wireguard +// SyncLink synces link state with the config. +// It does not sync Wireguard settings, just makes sure the device is up and type wireguard. func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link, error) { link, err := netlink.LinkByName(iface) + // nolint: nestif if err != nil { - if _, ok := err.(netlink.LinkNotFoundError); !ok { + if errors.As(err, &netlink.LinkNotFoundError{}) { log.WithError(err).Error("cannot read link") - return nil, err + + return nil, fmt.Errorf("%w", err) } + log.Info("link not found, creating") wgLink := &netlink.GenericLink{ @@ -204,9 +244,11 @@ func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link, if err := netlink.LinkAdd(wgLink); err != nil { log.WithError(err).Errorf("cannot create link: %s", err.Error()) log.Info("trying to use embedded wireguard-go...") + if err := wgGo(iface); err != nil { log.WithError(err).Errorf("cannot create link through wireguard-go: %s", err.Error()) - return nil, fmt.Errorf("cannot create link") + + return nil, fmt.Errorf("cannot create link: %w", err) } } @@ -216,32 +258,41 @@ func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link, link, err = netlink.LinkByName(iface) if err != nil { log.WithError(err).Error("cannot read link") - return nil, err + + return nil, fmt.Errorf("%w", err) } } + if err := netlink.LinkSetUp(link); err != nil { log.WithError(err).Error("cannot set link up") - return nil, err + + return nil, fmt.Errorf("%w", err) } + log.Info("set device up") + return link, nil } // SyncAddress adds/deletes all lind assigned IPV4 addressed as specified in the config +// nolint: funlen, gosec, scopelint func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { addrs, err := netlink.AddrList(link, syscall.AF_INET) if err != nil { log.Error(err, "cannot read link address") - return err + + return fmt.Errorf("%w", err) } // nil addr means I've used it - presentAddresses := make(map[string]netlink.Addr, 0) + presentAddresses := make(map[string]netlink.Addr) + for _, addr := range addrs { log.WithFields(map[string]interface{}{ "addr": fmt.Sprint(addr.IPNet), "label": addr.Label, }).Debugf("found existing address: %v", addr) + presentAddresses[addr.IPNet.String()] = addr } @@ -249,19 +300,24 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { log := log.WithField("addr", addr.String()) _, present := presentAddresses[addr.String()] presentAddresses[addr.String()] = netlink.Addr{} // mark as present + if present { log.Info("address present") + continue } + if err := netlink.AddrAdd(link, &netlink.Addr{ IPNet: &addr, Label: cfg.AddressLabel, }); err != nil { - if err != syscall.EEXIST { + if errors.Is(err, syscall.EEXIST) { log.WithError(err).Error("cannot add addr") - return err + + return fmt.Errorf("%w", err) } } + log.Info("address added") } @@ -269,16 +325,21 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { if addr.IPNet == nil { continue } + log := log.WithFields(map[string]interface{}{ "addr": addr.IPNet.String(), "label": addr.Label, }) + if err := netlink.AddrDel(link, &addr); err != nil { log.WithError(err).Error("cannot delete addr") - return err + + return fmt.Errorf("%w", err) } + log.Info("addr deleted") } + return nil } @@ -298,13 +359,17 @@ func fillRouteDefaults(rt *netlink.Route) { } // SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config +// nolint: funlen, gosec, scopelint func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error { wantedRoutes := make(map[string][]netlink.Route, len(managedRoutes)) + presentRoutes, err := netlink.RouteList(link, syscall.AF_INET) if err != nil { log.Error(err, "cannot read existing routes") - return err + + return fmt.Errorf("%w", err) } + for _, rt := range managedRoutes { rt := rt // make copy log.WithField("dst", rt.String()).Debug("managing route") @@ -330,10 +395,13 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l "type": rt.Type, "metric": rt.Priority, }) + if err := netlink.RouteReplace(&rt); err != nil { log.WithError(err).Errorln("cannot add/replace route") - return err + + return fmt.Errorf("%w", err) } + log.Infoln("route added/replaced") } } @@ -344,6 +412,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l return true } } + return false } @@ -355,25 +424,31 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l "type": rt.Type, "metric": rt.Priority, }) + if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == unix.RT_CLASS_MAIN)) { log.Debug("wrong table for route, skipping") + continue } if !(rt.Protocol == cfg.RouteProtocol) { log.Infof("skipping route deletion, not owned by this daemon") + continue } if checkWanted(rt) { log.Debug("route wanted, skipping deleting") + continue } if err := netlink.RouteDel(&rt); err != nil { log.WithError(err).Error("cannot delete route") - return err + + return fmt.Errorf("%w", err) } + log.Info("route deleted") }