diff --git a/wg.go b/wg.go index 6300968..81ea34c 100644 --- a/wg.go +++ b/wg.go @@ -4,11 +4,14 @@ import ( "github.com/mdlayher/wireguardctrl" "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" - "net" "syscall" ) +const ( + defaultRoutingTable = 254 +) + // Sync the config to the current setup for given interface func (cfg *Config) Sync(iface string, logger logrus.FieldLogger) error { log := logger.WithField("iface", iface) @@ -76,20 +79,20 @@ func syncAddress(link netlink.Link, cfg *Config, log logrus.FieldLogger) error { return err } - presentAddresses := make(map[string]int, 0) + // nil addr means I've used it + presentAddresses := make(map[string]*netlink.Addr, 0) for _, addr := range addrs { - presentAddresses[addr.IPNet.String()] = 1 + presentAddresses[addr.IPNet.String()] = &addr } for _, addr := range cfg.Address { log := log.WithField("addr", addr) _, present := presentAddresses[addr.String()] - presentAddresses[addr.String()] = 2 + presentAddresses[addr.String()] = nil // mark as present if present { log.Info("address present") continue } - if err := netlink.AddrAdd(link, &netlink.Addr{ IPNet: addr, }); err != nil { @@ -99,20 +102,16 @@ func syncAddress(link netlink.Link, cfg *Config, log logrus.FieldLogger) error { log.Info("address added") } - for addr, p := range presentAddresses { - log := log.WithField("addr", addr) - if p < 2 { - nlAddr, err := netlink.ParseAddr(addr) - if err != nil { - log.WithError(err).Error("cannot parse del addr") - return err - } - if err := netlink.AddrAdd(link, nlAddr); err != nil { - log.WithError(err).Error("cannot delete addr") - return err - } - log.Info("addr deleted") + for _, addr := range presentAddresses { + if addr == nil { + continue } + log := log.WithField("addr", addr.IPNet.String()) + if err := netlink.AddrDel(link, addr); err != nil { + log.WithError(err).Error("cannot delete addr") + return err + } + log.Info("addr deleted") } return nil } @@ -124,11 +123,11 @@ func syncRoutes(link netlink.Link, cfg *Config, log logrus.FieldLogger) error { return err } - presentRoutes := make(map[string]int, 0) + presentRoutes := make(map[string]*netlink.Route, 0) for _, r := range routes { log := log.WithField("route", r.Dst.String()) - if r.Table == cfg.Table || cfg.Table == 0 { - presentRoutes[r.Dst.String()] = 1 + if r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable) { + presentRoutes[r.Dst.String()] = &r log.WithField("table", r.Table).Debug("detected existing route") } else { log.Debug("wrong table for route, skipping") @@ -138,7 +137,7 @@ func syncRoutes(link netlink.Link, cfg *Config, log logrus.FieldLogger) error { for _, peer := range cfg.Peers { for _, rt := range peer.AllowedIPs { _, present := presentRoutes[rt.String()] - presentRoutes[rt.String()] = 2 + presentRoutes[rt.String()] = nil // mark as visited log := log.WithField("route", rt.String()) if present { log.Info("route present") @@ -157,25 +156,17 @@ func syncRoutes(link netlink.Link, cfg *Config, log logrus.FieldLogger) error { } // Clean extra routes - for rtStr, p := range presentRoutes { - _, rt, err := net.ParseCIDR(rtStr) - log := log.WithField("route", rt.String()) - if err != nil { - log.WithError(err).Error("cannot parse route") + for _, rt := range presentRoutes { + if rt == nil { // skip visited routes + continue + } + log := log.WithField("route", rt.Dst.String()) + log.Info("extra manual route found") + if err := netlink.RouteDel(rt); err != nil { + log.WithError(err).Error("cannot setup route") return err } - if p < 2 { - log.Info("extra manual route found") - if err := netlink.RouteDel(&netlink.Route{ - LinkIndex: link.Attrs().Index, - Dst: rt, - Table: cfg.Table, - }); err != nil { - log.WithError(err).Error("cannot setup route") - return err - } - log.Info("route deleted") - } + log.Info("route deleted") } return nil }