happy linting

This commit is contained in:
Marvin Steadfast 2021-01-21 10:25:04 +01:00
parent d9764025be
commit a5695600e2
4 changed files with 186 additions and 66 deletions

View File

@ -52,4 +52,4 @@ test:
.PHONY: lint .PHONY: lint
lint: lint:
golangci-lint run --enable-all --disable gomnd --disable godox --timeout 5m golangci-lint run --enable-all --disable gomnd --disable godox --disable exhaustivestruct --timeout 5m

View File

@ -1,3 +1,4 @@
// nolint: gochecknoglobals, gomnd, goerr113
package wgquick package wgquick
import ( import (
@ -14,23 +15,32 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// Config represents full wg-quick like config structure // Config represents full wg-quick like config structure.
type Config struct { type Config struct {
wgtypes.Config wgtypes.Config
// Address list of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned to the interface. May be specified multiple times. // Address list of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned
// to the interface. May be specified multiple times.
Address []net.IPNet Address []net.IPNet
// list of IP (v4 or v6) addresses to be set as the interfaces DNS servers. May be specified multiple times. Upon bringing the interface up, this runs resolvconf -a tun.INTERFACE -m 0 -x and upon bringing it down, this runs resolvconf -d tun.INTERFACE. If these particular invocations of resolvconf(8) are undesirable, the PostUp and PostDown keys below may be used instead. // list of IP (v4 or v6) addresses to be set as the interfaces DNS servers. May be specified multiple times.
// Upon bringing the interface up, this runs resolvconf -a tun.INTERFACE -m 0 -x and upon bringing it down,
// this runs resolvconf -d tun.INTERFACE. If these particular invocations of resolvconf(8) are undesirable,
// the PostUp and PostDown keys below may be used instead.
DNS []net.IP DNS []net.IP
// MTU is automatically determined from the endpoint addresses or the system default route, which is usually a sane choice. However, to manually specify an MTU to override this automatic discovery, this value may be specified explicitly. // MTU is automatically determined from the endpoint addresses or the system default route,
// which is usually a sane choice. However, to manually specify an MTU to override this automatic discovery,
// this value may be specified explicitly.
MTU int MTU int
// Table — Controls the routing table to which routes are added. // Table — Controls the routing table to which routes are added.
Table int Table int
// PreUp, PostUp, PreDown, PostDown — script snippets which will be executed by bash(1) before/after setting up/tearing down the interface, most commonly used to configure custom DNS options or firewall rules. The special string %i is expanded to INTERFACE. Each one may be specified multiple times, in which case the commands are executed in order. // PreUp, PostUp, PreDown, PostDown — script snippets which will be executed by bash(1)
// before/after setting up/tearing down the interface, most commonly used to configure
// custom DNS options or firewall rules. The special string %i is expanded to INTERFACE.
// Each one may be specified multiple times, in which case the commands are executed in order.
PreUp string PreUp string
PostUp string PostUp string
PreDown string PreDown string
@ -50,14 +60,17 @@ type Config struct {
SaveConfig bool SaveConfig bool
} }
var _ encoding.TextMarshaler = (*Config)(nil) var (
var _ encoding.TextUnmarshaler = (*Config)(nil) _ encoding.TextMarshaler = (*Config)(nil)
_ encoding.TextUnmarshaler = (*Config)(nil)
)
func (cfg *Config) String() string { func (cfg *Config) String() string {
b, err := cfg.MarshalText() b, err := cfg.MarshalText()
if err != nil { if err != nil {
panic(err) panic(err)
} }
return string(b) return string(b)
} }
@ -83,11 +96,13 @@ var cfgTemplate = template.Must(
func (cfg *Config) MarshalText() (text []byte, err error) { func (cfg *Config) MarshalText() (text []byte, err error) {
buff := &bytes.Buffer{} buff := &bytes.Buffer{}
if err := cfgTemplate.Execute(buff, cfg); err != nil { if err := cfgTemplate.Execute(buff, cfg); err != nil {
return nil, err return nil, fmt.Errorf("%w", err)
} }
return buff.Bytes(), nil return buff.Bytes(), nil
} }
// nolint: lll
const wgtypeTemplateSpec = `[Interface] const wgtypeTemplateSpec = `[Interface]
{{- range .Address }} {{- range .Address }}
Address = {{ . }} Address = {{ . }}
@ -115,14 +130,17 @@ AllowedIPs = {{ range $i, $el := .AllowedIPs }}{{if $i}}, {{ end }}{{ $el }}{{ e
{{- end }} {{- end }}
` `
// ParseKey parses the base64 encoded wireguard private key // ParseKey parses the base64 encoded wireguard private key.
func ParseKey(key string) (wgtypes.Key, error) { func ParseKey(key string) (wgtypes.Key, error) {
var pkey wgtypes.Key var pkey wgtypes.Key
pkeySlice, err := base64.StdEncoding.DecodeString(key) pkeySlice, err := base64.StdEncoding.DecodeString(key)
if err != nil { if err != nil {
return pkey, err return pkey, fmt.Errorf("%w", err)
} }
copy(pkey[:], pkeySlice[:])
copy(pkey[:], pkeySlice)
return pkey, nil return pkey, nil
} }
@ -137,17 +155,21 @@ const (
func (cfg *Config) UnmarshalText(text []byte) error { func (cfg *Config) UnmarshalText(text []byte) error {
*cfg = Config{} // Zero out the config *cfg = Config{} // Zero out the config
state := unknown state := unknown
var peerCfg *wgtypes.PeerConfig var peerCfg *wgtypes.PeerConfig
for no, line := range strings.Split(string(text), "\n") { for no, line := range strings.Split(string(text), "\n") {
ln := strings.TrimSpace(line) ln := strings.TrimSpace(line)
if len(ln) == 0 || ln[0] == '#' { if len(ln) == 0 || ln[0] == '#' {
continue continue
} }
switch ln { switch ln {
case "[Interface]": case "[Interface]":
state = inter state = inter
case "[Peer]": case "[Peer]":
state = peer state = peer
cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{}) cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{})
peerCfg = &cfg.Peers[len(cfg.Peers)-1] peerCfg = &cfg.Peers[len(cfg.Peers)-1]
default: default:
@ -155,33 +177,40 @@ func (cfg *Config) UnmarshalText(text []byte) error {
if len(parts) < 2 { if len(parts) < 2 {
return fmt.Errorf("cannot parse line %d, missing =", no) return fmt.Errorf("cannot parse line %d, missing =", no)
} }
lhs := strings.TrimSpace(parts[0]) lhs := strings.TrimSpace(parts[0])
rhs := strings.TrimSpace(strings.Join(parts[1:], "=")) rhs := strings.TrimSpace(strings.Join(parts[1:], "="))
switch state { switch state {
case inter: case inter:
if err := parseInterfaceLine(cfg, lhs, rhs); err != nil { if err := parseInterfaceLine(cfg, lhs, rhs); err != nil {
return fmt.Errorf("[line %d]: %v", no+1, err) return fmt.Errorf("[line %d]: %w", no+1, err)
} }
case peer: case peer:
if err := parsePeerLine(peerCfg, lhs, rhs); err != nil { if err := parsePeerLine(peerCfg, lhs, rhs); err != nil {
return fmt.Errorf("[line %d]: %v", no+1, err) return fmt.Errorf("[line %d]: %w", no+1, err)
} }
default: case unknown:
return fmt.Errorf("[line %d] cannot parse, unknown state", no+1) return fmt.Errorf("[line %d] cannot parse, unknown state", no+1)
default:
return fmt.Errorf("switch couldnt find state")
} }
} }
} }
return nil return nil
} }
// nolint: funlen
func parseInterfaceLine(cfg *Config, lhs string, rhs string) error { func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
switch lhs { switch lhs {
case "Address": case "Address":
for _, addr := range strings.Split(rhs, ",") { for _, addr := range strings.Split(rhs, ",") {
ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr)) ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
if err != nil { if err != nil {
return err return fmt.Errorf("%w", err)
} }
cfg.Address = append(cfg.Address, net.IPNet{IP: ip, Mask: cidr.Mask}) cfg.Address = append(cfg.Address, net.IPNet{IP: ip, Mask: cidr.Mask})
} }
case "DNS": case "DNS":
@ -190,25 +219,29 @@ func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
if ip == nil { if ip == nil {
return fmt.Errorf("cannot parse IP") return fmt.Errorf("cannot parse IP")
} }
cfg.DNS = append(cfg.DNS, ip) cfg.DNS = append(cfg.DNS, ip)
} }
case "MTU": case "MTU":
mtu, err := strconv.ParseInt(rhs, 10, 64) mtu, err := strconv.ParseInt(rhs, 10, 64)
if err != nil { if err != nil {
return err return fmt.Errorf("%w", err)
} }
cfg.MTU = int(mtu) cfg.MTU = int(mtu)
case "Table": case "Table":
tbl, err := strconv.ParseInt(rhs, 10, 64) tbl, err := strconv.ParseInt(rhs, 10, 64)
if err != nil { if err != nil {
return err return fmt.Errorf("%w", err)
} }
cfg.Table = int(tbl) cfg.Table = int(tbl)
case "ListenPort": case "ListenPort":
portI64, err := strconv.ParseInt(rhs, 10, 64) portI64, err := strconv.ParseInt(rhs, 10, 64)
if err != nil { if err != nil {
return err return fmt.Errorf("%w", err)
} }
port := int(portI64) port := int(portI64)
cfg.ListenPort = &port cfg.ListenPort = &port
case "PreUp": case "PreUp":
@ -222,18 +255,21 @@ func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
case "SaveConfig": case "SaveConfig":
save, err := strconv.ParseBool(rhs) save, err := strconv.ParseBool(rhs)
if err != nil { if err != nil {
return err return fmt.Errorf("%w", err)
} }
cfg.SaveConfig = save cfg.SaveConfig = save
case "PrivateKey": case "PrivateKey":
key, err := ParseKey(rhs) key, err := ParseKey(rhs)
if err != nil { if err != nil {
return fmt.Errorf("cannot decode key %v", err) return fmt.Errorf("cannot decode key %w", err)
} }
cfg.PrivateKey = &key cfg.PrivateKey = &key
default: default:
return fmt.Errorf("unknown directive %s", lhs) return fmt.Errorf("unknown directive %s", lhs)
} }
return nil return nil
} }
@ -242,41 +278,48 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
case "PublicKey": case "PublicKey":
key, err := ParseKey(rhs) key, err := ParseKey(rhs)
if err != nil { if err != nil {
return fmt.Errorf("cannot decode key %v", err) return fmt.Errorf("cannot decode key %w", err)
} }
peerCfg.PublicKey = key peerCfg.PublicKey = key
case "PresharedKey": case "PresharedKey":
key, err := ParseKey(rhs) key, err := ParseKey(rhs)
if err != nil { if err != nil {
return fmt.Errorf("cannot decode key %v", err) return fmt.Errorf("cannot decode key %w", err)
} }
if peerCfg.PresharedKey != nil { if peerCfg.PresharedKey != nil {
return fmt.Errorf("preshared key already defined %v", err) return fmt.Errorf("preshared key already defined %w", err)
} }
peerCfg.PresharedKey = &key peerCfg.PresharedKey = &key
case "AllowedIPs": case "AllowedIPs":
for _, addr := range strings.Split(rhs, ",") { for _, addr := range strings.Split(rhs, ",") {
ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr)) ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
if err != nil { if err != nil {
return fmt.Errorf("cannot parse %s: %v", addr, err) return fmt.Errorf("cannot parse %s: %w", addr, err)
} }
peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask}) peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask})
} }
case "Endpoint": case "Endpoint":
addr, err := net.ResolveUDPAddr("", rhs) addr, err := net.ResolveUDPAddr("", rhs)
if err != nil { if err != nil {
return err return fmt.Errorf("%w", err)
} }
peerCfg.Endpoint = addr peerCfg.Endpoint = addr
case "PersistentKeepalive": case "PersistentKeepalive":
t, err := strconv.ParseInt(rhs, 10, 64) t, err := strconv.ParseInt(rhs, 10, 64)
if err != nil { if err != nil {
return err return fmt.Errorf("%w", err)
} }
dur := time.Duration(t * int64(time.Second)) dur := time.Duration(t * int64(time.Second))
peerCfg.PersistentKeepaliveInterval = &dur peerCfg.PersistentKeepaliveInterval = &dur
default: default:
return fmt.Errorf("unknown directive %s", lhs) return fmt.Errorf("unknown directive %s", lhs)
} }
return nil return nil
} }

View File

@ -1,3 +1,4 @@
// nolint: scopelint, gochecknoglobals, paralleltest, testpackage
package wgquick package wgquick
import ( import (
@ -54,6 +55,7 @@ PersistentKeepalive = 25
func TestExampleConfig(t *testing.T) { func TestExampleConfig(t *testing.T) {
c := &Config{} c := &Config{}
for name, cfg := range testConfigs { for name, cfg := range testConfigs {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
err := c.UnmarshalText([]byte(cfg)) err := c.UnmarshalText([]byte(cfg))

153
wg.go
View File

@ -2,6 +2,7 @@ package wgquick
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"net" "net"
"os" "os"
@ -31,20 +32,24 @@ func wgGo(iface string) error {
} }
cmd := wgo.Command(iface) cmd := wgo.Command(iface)
cmd.Start() if err := cmd.Start(); err != nil {
return fmt.Errorf("could not start wireguard-go: %w", err)
}
return nil return nil
} }
// Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface` // Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface`.
func Up(cfg *Config, iface string, logger logrus.FieldLogger) error { func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
log := logger.WithField("iface", iface) log := logger.WithField("iface", iface)
_, err := netlink.LinkByName(iface) _, err := netlink.LinkByName(iface)
if err == nil { if err == nil {
return os.ErrExist return os.ErrExist
} }
if _, ok := err.(netlink.LinkNotFoundError); !ok {
return err if errors.As(err, &netlink.LinkNotFoundError{}) {
return fmt.Errorf("%w", err)
} }
for _, dns := range cfg.DNS { for _, dns := range cfg.DNS {
@ -57,8 +62,10 @@ func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
if err := execSh(cfg.PreUp, iface, log); err != nil { if err := execSh(cfg.PreUp, iface, log); err != nil {
return err return err
} }
log.Infoln("applied pre-up command") log.Infoln("applied pre-up command")
} }
if err := Sync(cfg, iface, logger); err != nil { if err := Sync(cfg, iface, logger); err != nil {
return err return err
} }
@ -67,17 +74,20 @@ func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
if err := execSh(cfg.PostUp, iface, log); err != nil { if err := execSh(cfg.PostUp, iface, log); err != nil {
return err return err
} }
log.Infoln("applied post-up command") log.Infoln("applied post-up command")
} }
return nil return nil
} }
// Down destroys the wg interface. Mostly equivalent to `wg-quick down iface` // Down destroys the wg interface. Mostly equivalent to `wg-quick down iface`.
func Down(cfg *Config, iface string, logger logrus.FieldLogger) error { func Down(cfg *Config, iface string, logger logrus.FieldLogger) error {
log := logger.WithField("iface", iface) log := logger.WithField("iface", iface)
link, err := netlink.LinkByName(iface) link, err := netlink.LinkByName(iface)
if err != nil { if err != nil {
return err return fmt.Errorf("%w", err)
} }
if len(cfg.DNS) > 1 { if len(cfg.DNS) > 1 {
@ -90,108 +100,138 @@ func Down(cfg *Config, iface string, logger logrus.FieldLogger) error {
if err := execSh(cfg.PreDown, iface, log); err != nil { if err := execSh(cfg.PreDown, iface, log); err != nil {
return err return err
} }
log.Infoln("applied pre-down command") log.Infoln("applied pre-down command")
} }
if err := netlink.LinkDel(link); err != nil { if err := netlink.LinkDel(link); err != nil {
return err return fmt.Errorf("%w", err)
} }
log.Infoln("link deleted") log.Infoln("link deleted")
if cfg.PostDown != "" { if cfg.PostDown != "" {
if err := execSh(cfg.PostDown, iface, log); err != nil { if err := execSh(cfg.PostDown, iface, log); err != nil {
return err return err
} }
log.Infoln("applied post-down command") log.Infoln("applied post-down command")
} }
return nil return nil
} }
func execSh(command string, iface string, log logrus.FieldLogger, stdin ...string) error { func execSh(command string, iface string, log logrus.FieldLogger, stdin ...string) error {
cmd := exec.Command("sh", "-ce", strings.ReplaceAll(command, "%i", iface)) cmd := exec.Command("sh", "-ce", strings.ReplaceAll(command, "%i", iface)) // nolint: gosec
if len(stdin) > 0 { if len(stdin) > 0 {
log = log.WithField("stdin", strings.Join(stdin, "")) log = log.WithField("stdin", strings.Join(stdin, ""))
b := &bytes.Buffer{} b := &bytes.Buffer{}
for _, ln := range stdin { for _, ln := range stdin {
if _, err := fmt.Fprint(b, ln); err != nil { if _, err := fmt.Fprint(b, ln); err != nil {
return err return fmt.Errorf("%w", err)
} }
} }
cmd.Stdin = b cmd.Stdin = b
} }
out, err := cmd.CombinedOutput() out, err := cmd.CombinedOutput()
if err != nil { if err != nil {
log.WithError(err).Errorf("failed to execute %s:\n%s", cmd.Args, out) log.WithError(err).Errorf("failed to execute %s:\n%s", cmd.Args, out)
return err
return fmt.Errorf("%w", err)
} }
log.Infof("executed %s:\n%s", cmd.Args, out) log.Infof("executed %s:\n%s", cmd.Args, out)
return nil return nil
} }
// Sync the config to the current setup for given interface // Sync the config to the current setup for given interface.
// It perform 4 operations: // It perform 4 operations:
// * SyncLink --> makes sure link is up and type wireguard // * SyncLink --> makes sure link is up and type wireguard.
// * SyncWireguardDevice --> configures allowedIP & other wireguard specific settings // * SyncWireguardDevice --> configures allowedIP & other wireguard specific settings.
// * SyncAddress --> synces linux addresses bounded to this interface // * SyncAddress --> synces linux addresses bounded to this interface.
// * SyncRoutes --> synces all allowedIP routes to route to this interface // * SyncRoutes --> synces all allowedIP routes to route to this interface.
func Sync(cfg *Config, iface string, logger logrus.FieldLogger) error { func Sync(cfg *Config, iface string, logger logrus.FieldLogger) error {
log := logger.WithField("iface", iface) log := logger.WithField("iface", iface)
link, err := SyncLink(cfg, iface, log) link, err := SyncLink(cfg, iface, log)
if err != nil { if err != nil {
log.WithError(err).Errorln("cannot sync wireguard link") log.WithError(err).Errorln("cannot sync wireguard link")
return err
return fmt.Errorf("%w", err)
} }
log.Info("synced link") log.Info("synced link")
if err := SyncWireguardDevice(cfg, link, log); err != nil { if err := SyncWireguardDevice(cfg, link, log); err != nil {
log.WithError(err).Errorln("cannot sync wireguard link") log.WithError(err).Errorln("cannot sync wireguard link")
return err return err
} }
log.Info("synced link") log.Info("synced link")
if err := SyncAddress(cfg, link, log); err != nil { if err := SyncAddress(cfg, link, log); err != nil {
log.WithError(err).Errorln("cannot sync addresses") log.WithError(err).Errorln("cannot sync addresses")
return err
return fmt.Errorf("%w", err)
} }
log.Info("synced addresss") log.Info("synced addresss")
var managedRoutes []net.IPNet managedRoutes := make([]net.IPNet, 0)
for _, peer := range cfg.Peers { for _, peer := range cfg.Peers {
for _, rt := range peer.AllowedIPs { managedRoutes = append(managedRoutes, peer.AllowedIPs...)
managedRoutes = append(managedRoutes, rt)
}
} }
if err := SyncRoutes(cfg, link, managedRoutes, log); err != nil { if err := SyncRoutes(cfg, link, managedRoutes, log); err != nil {
log.WithError(err).Errorln("cannot sync routes") log.WithError(err).Errorln("cannot sync routes")
return err
return fmt.Errorf("%w", err)
} }
log.Info("synced routed") log.Info("synced routed")
log.Info("Successfully synced device") log.Info("Successfully synced device")
return nil return nil
} }
// SyncWireguardDevice synces wireguard vpn setting on the given link. It does not set routes/addresses beyond wg internal crypto-key routing, only handles wireguard specific settings // SyncWireguardDevice synces wireguard vpn setting on the given link.
// It does not set routes/addresses beyond wg internal crypto-key routing, only handles wireguard specific settings.
func SyncWireguardDevice(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { func SyncWireguardDevice(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
cl, err := wgctrl.New() cl, err := wgctrl.New()
if err != nil { if err != nil {
log.WithError(err).Errorln("cannot setup wireguard device") log.WithError(err).Errorln("cannot setup wireguard device")
return err
return fmt.Errorf("%w", err)
} }
if err := cl.ConfigureDevice(link.Attrs().Name, cfg.Config); err != nil { if err := cl.ConfigureDevice(link.Attrs().Name, cfg.Config); err != nil {
log.WithError(err).Error("cannot configure device") log.WithError(err).Error("cannot configure device")
return err
return fmt.Errorf("%w", err)
} }
return nil return nil
} }
// SyncLink synces link state with the config. It does not sync Wireguard settings, just makes sure the device is up and type wireguard // SyncLink synces link state with the config.
// It does not sync Wireguard settings, just makes sure the device is up and type wireguard.
func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link, error) { func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link, error) {
link, err := netlink.LinkByName(iface) link, err := netlink.LinkByName(iface)
// nolint: nestif
if err != nil { if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); !ok { if errors.As(err, &netlink.LinkNotFoundError{}) {
log.WithError(err).Error("cannot read link") log.WithError(err).Error("cannot read link")
return nil, err
return nil, fmt.Errorf("%w", err)
} }
log.Info("link not found, creating") log.Info("link not found, creating")
wgLink := &netlink.GenericLink{ wgLink := &netlink.GenericLink{
@ -204,9 +244,11 @@ func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link,
if err := netlink.LinkAdd(wgLink); err != nil { if err := netlink.LinkAdd(wgLink); err != nil {
log.WithError(err).Errorf("cannot create link: %s", err.Error()) log.WithError(err).Errorf("cannot create link: %s", err.Error())
log.Info("trying to use embedded wireguard-go...") log.Info("trying to use embedded wireguard-go...")
if err := wgGo(iface); err != nil { if err := wgGo(iface); err != nil {
log.WithError(err).Errorf("cannot create link through wireguard-go: %s", err.Error()) log.WithError(err).Errorf("cannot create link through wireguard-go: %s", err.Error())
return nil, fmt.Errorf("cannot create link")
return nil, fmt.Errorf("cannot create link: %w", err)
} }
} }
@ -216,32 +258,41 @@ func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link,
link, err = netlink.LinkByName(iface) link, err = netlink.LinkByName(iface)
if err != nil { if err != nil {
log.WithError(err).Error("cannot read link") log.WithError(err).Error("cannot read link")
return nil, err
return nil, fmt.Errorf("%w", err)
} }
} }
if err := netlink.LinkSetUp(link); err != nil { if err := netlink.LinkSetUp(link); err != nil {
log.WithError(err).Error("cannot set link up") log.WithError(err).Error("cannot set link up")
return nil, err
return nil, fmt.Errorf("%w", err)
} }
log.Info("set device up") log.Info("set device up")
return link, nil return link, nil
} }
// SyncAddress adds/deletes all lind assigned IPV4 addressed as specified in the config // SyncAddress adds/deletes all lind assigned IPV4 addressed as specified in the config
// nolint: funlen, gosec, scopelint
func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
addrs, err := netlink.AddrList(link, syscall.AF_INET) addrs, err := netlink.AddrList(link, syscall.AF_INET)
if err != nil { if err != nil {
log.Error(err, "cannot read link address") log.Error(err, "cannot read link address")
return err
return fmt.Errorf("%w", err)
} }
// nil addr means I've used it // nil addr means I've used it
presentAddresses := make(map[string]netlink.Addr, 0) presentAddresses := make(map[string]netlink.Addr)
for _, addr := range addrs { for _, addr := range addrs {
log.WithFields(map[string]interface{}{ log.WithFields(map[string]interface{}{
"addr": fmt.Sprint(addr.IPNet), "addr": fmt.Sprint(addr.IPNet),
"label": addr.Label, "label": addr.Label,
}).Debugf("found existing address: %v", addr) }).Debugf("found existing address: %v", addr)
presentAddresses[addr.IPNet.String()] = addr presentAddresses[addr.IPNet.String()] = addr
} }
@ -249,19 +300,24 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
log := log.WithField("addr", addr.String()) log := log.WithField("addr", addr.String())
_, present := presentAddresses[addr.String()] _, present := presentAddresses[addr.String()]
presentAddresses[addr.String()] = netlink.Addr{} // mark as present presentAddresses[addr.String()] = netlink.Addr{} // mark as present
if present { if present {
log.Info("address present") log.Info("address present")
continue continue
} }
if err := netlink.AddrAdd(link, &netlink.Addr{ if err := netlink.AddrAdd(link, &netlink.Addr{
IPNet: &addr, IPNet: &addr,
Label: cfg.AddressLabel, Label: cfg.AddressLabel,
}); err != nil { }); err != nil {
if err != syscall.EEXIST { if errors.Is(err, syscall.EEXIST) {
log.WithError(err).Error("cannot add addr") log.WithError(err).Error("cannot add addr")
return err
return fmt.Errorf("%w", err)
} }
} }
log.Info("address added") log.Info("address added")
} }
@ -269,16 +325,21 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
if addr.IPNet == nil { if addr.IPNet == nil {
continue continue
} }
log := log.WithFields(map[string]interface{}{ log := log.WithFields(map[string]interface{}{
"addr": addr.IPNet.String(), "addr": addr.IPNet.String(),
"label": addr.Label, "label": addr.Label,
}) })
if err := netlink.AddrDel(link, &addr); err != nil { if err := netlink.AddrDel(link, &addr); err != nil {
log.WithError(err).Error("cannot delete addr") log.WithError(err).Error("cannot delete addr")
return err
return fmt.Errorf("%w", err)
} }
log.Info("addr deleted") log.Info("addr deleted")
} }
return nil return nil
} }
@ -298,13 +359,17 @@ func fillRouteDefaults(rt *netlink.Route) {
} }
// SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config // SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config
// nolint: funlen, gosec, scopelint
func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error { func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error {
wantedRoutes := make(map[string][]netlink.Route, len(managedRoutes)) wantedRoutes := make(map[string][]netlink.Route, len(managedRoutes))
presentRoutes, err := netlink.RouteList(link, syscall.AF_INET) presentRoutes, err := netlink.RouteList(link, syscall.AF_INET)
if err != nil { if err != nil {
log.Error(err, "cannot read existing routes") log.Error(err, "cannot read existing routes")
return err
return fmt.Errorf("%w", err)
} }
for _, rt := range managedRoutes { for _, rt := range managedRoutes {
rt := rt // make copy rt := rt // make copy
log.WithField("dst", rt.String()).Debug("managing route") log.WithField("dst", rt.String()).Debug("managing route")
@ -330,10 +395,13 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l
"type": rt.Type, "type": rt.Type,
"metric": rt.Priority, "metric": rt.Priority,
}) })
if err := netlink.RouteReplace(&rt); err != nil { if err := netlink.RouteReplace(&rt); err != nil {
log.WithError(err).Errorln("cannot add/replace route") log.WithError(err).Errorln("cannot add/replace route")
return err
return fmt.Errorf("%w", err)
} }
log.Infoln("route added/replaced") log.Infoln("route added/replaced")
} }
} }
@ -344,6 +412,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l
return true return true
} }
} }
return false return false
} }
@ -355,25 +424,31 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l
"type": rt.Type, "type": rt.Type,
"metric": rt.Priority, "metric": rt.Priority,
}) })
if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == unix.RT_CLASS_MAIN)) { if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == unix.RT_CLASS_MAIN)) {
log.Debug("wrong table for route, skipping") log.Debug("wrong table for route, skipping")
continue continue
} }
if !(rt.Protocol == cfg.RouteProtocol) { if !(rt.Protocol == cfg.RouteProtocol) {
log.Infof("skipping route deletion, not owned by this daemon") log.Infof("skipping route deletion, not owned by this daemon")
continue continue
} }
if checkWanted(rt) { if checkWanted(rt) {
log.Debug("route wanted, skipping deleting") log.Debug("route wanted, skipping deleting")
continue continue
} }
if err := netlink.RouteDel(&rt); err != nil { if err := netlink.RouteDel(&rt); err != nil {
log.WithError(err).Error("cannot delete route") log.WithError(err).Error("cannot delete route")
return err
return fmt.Errorf("%w", err)
} }
log.Info("route deleted") log.Info("route deleted")
} }