Better address/route deletion code

This commit is contained in:
Neven Miculinic 2019-03-27 17:40:06 +01:00
parent a4b7e8df32
commit db19e30e58

53
wg.go
View File

@ -4,11 +4,14 @@ import (
"github.com/mdlayher/wireguardctrl" "github.com/mdlayher/wireguardctrl"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"net"
"syscall" "syscall"
) )
const (
defaultRoutingTable = 254
)
// Sync the config to the current setup for given interface // Sync the config to the current setup for given interface
func (cfg *Config) Sync(iface string, logger logrus.FieldLogger) error { func (cfg *Config) Sync(iface string, logger logrus.FieldLogger) error {
log := logger.WithField("iface", iface) log := logger.WithField("iface", iface)
@ -76,20 +79,20 @@ func syncAddress(link netlink.Link, cfg *Config, log logrus.FieldLogger) error {
return err return err
} }
presentAddresses := make(map[string]int, 0) // nil addr means I've used it
presentAddresses := make(map[string]*netlink.Addr, 0)
for _, addr := range addrs { for _, addr := range addrs {
presentAddresses[addr.IPNet.String()] = 1 presentAddresses[addr.IPNet.String()] = &addr
} }
for _, addr := range cfg.Address { for _, addr := range cfg.Address {
log := log.WithField("addr", addr) log := log.WithField("addr", addr)
_, present := presentAddresses[addr.String()] _, present := presentAddresses[addr.String()]
presentAddresses[addr.String()] = 2 presentAddresses[addr.String()] = nil // mark as present
if present { if present {
log.Info("address present") log.Info("address present")
continue continue
} }
if err := netlink.AddrAdd(link, &netlink.Addr{ if err := netlink.AddrAdd(link, &netlink.Addr{
IPNet: addr, IPNet: addr,
}); err != nil { }); err != nil {
@ -99,21 +102,17 @@ func syncAddress(link netlink.Link, cfg *Config, log logrus.FieldLogger) error {
log.Info("address added") log.Info("address added")
} }
for addr, p := range presentAddresses { for _, addr := range presentAddresses {
log := log.WithField("addr", addr) if addr == nil {
if p < 2 { continue
nlAddr, err := netlink.ParseAddr(addr)
if err != nil {
log.WithError(err).Error("cannot parse del addr")
return err
} }
if err := netlink.AddrAdd(link, nlAddr); err != nil { log := log.WithField("addr", addr.IPNet.String())
if err := netlink.AddrDel(link, addr); err != nil {
log.WithError(err).Error("cannot delete addr") log.WithError(err).Error("cannot delete addr")
return err return err
} }
log.Info("addr deleted") log.Info("addr deleted")
} }
}
return nil return nil
} }
@ -124,11 +123,11 @@ func syncRoutes(link netlink.Link, cfg *Config, log logrus.FieldLogger) error {
return err return err
} }
presentRoutes := make(map[string]int, 0) presentRoutes := make(map[string]*netlink.Route, 0)
for _, r := range routes { for _, r := range routes {
log := log.WithField("route", r.Dst.String()) log := log.WithField("route", r.Dst.String())
if r.Table == cfg.Table || cfg.Table == 0 { if r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable) {
presentRoutes[r.Dst.String()] = 1 presentRoutes[r.Dst.String()] = &r
log.WithField("table", r.Table).Debug("detected existing route") log.WithField("table", r.Table).Debug("detected existing route")
} else { } else {
log.Debug("wrong table for route, skipping") log.Debug("wrong table for route, skipping")
@ -138,7 +137,7 @@ func syncRoutes(link netlink.Link, cfg *Config, log logrus.FieldLogger) error {
for _, peer := range cfg.Peers { for _, peer := range cfg.Peers {
for _, rt := range peer.AllowedIPs { for _, rt := range peer.AllowedIPs {
_, present := presentRoutes[rt.String()] _, present := presentRoutes[rt.String()]
presentRoutes[rt.String()] = 2 presentRoutes[rt.String()] = nil // mark as visited
log := log.WithField("route", rt.String()) log := log.WithField("route", rt.String())
if present { if present {
log.Info("route present") log.Info("route present")
@ -157,25 +156,17 @@ func syncRoutes(link netlink.Link, cfg *Config, log logrus.FieldLogger) error {
} }
// Clean extra routes // Clean extra routes
for rtStr, p := range presentRoutes { for _, rt := range presentRoutes {
_, rt, err := net.ParseCIDR(rtStr) if rt == nil { // skip visited routes
log := log.WithField("route", rt.String()) continue
if err != nil {
log.WithError(err).Error("cannot parse route")
return err
} }
if p < 2 { log := log.WithField("route", rt.Dst.String())
log.Info("extra manual route found") log.Info("extra manual route found")
if err := netlink.RouteDel(&netlink.Route{ if err := netlink.RouteDel(rt); err != nil {
LinkIndex: link.Attrs().Index,
Dst: rt,
Table: cfg.Table,
}); err != nil {
log.WithError(err).Error("cannot setup route") log.WithError(err).Error("cannot setup route")
return err return err
} }
log.Info("route deleted") log.Info("route deleted")
} }
}
return nil return nil
} }