Better address/route deletion code
This commit is contained in:
parent
a4b7e8df32
commit
db19e30e58
53
wg.go
53
wg.go
@ -4,11 +4,14 @@ import (
|
|||||||
"github.com/mdlayher/wireguardctrl"
|
"github.com/mdlayher/wireguardctrl"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"net"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultRoutingTable = 254
|
||||||
|
)
|
||||||
|
|
||||||
// Sync the config to the current setup for given interface
|
// Sync the config to the current setup for given interface
|
||||||
func (cfg *Config) Sync(iface string, logger logrus.FieldLogger) error {
|
func (cfg *Config) Sync(iface string, logger logrus.FieldLogger) error {
|
||||||
log := logger.WithField("iface", iface)
|
log := logger.WithField("iface", iface)
|
||||||
@ -76,20 +79,20 @@ func syncAddress(link netlink.Link, cfg *Config, log logrus.FieldLogger) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
presentAddresses := make(map[string]int, 0)
|
// nil addr means I've used it
|
||||||
|
presentAddresses := make(map[string]*netlink.Addr, 0)
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
presentAddresses[addr.IPNet.String()] = 1
|
presentAddresses[addr.IPNet.String()] = &addr
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, addr := range cfg.Address {
|
for _, addr := range cfg.Address {
|
||||||
log := log.WithField("addr", addr)
|
log := log.WithField("addr", addr)
|
||||||
_, present := presentAddresses[addr.String()]
|
_, present := presentAddresses[addr.String()]
|
||||||
presentAddresses[addr.String()] = 2
|
presentAddresses[addr.String()] = nil // mark as present
|
||||||
if present {
|
if present {
|
||||||
log.Info("address present")
|
log.Info("address present")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.AddrAdd(link, &netlink.Addr{
|
if err := netlink.AddrAdd(link, &netlink.Addr{
|
||||||
IPNet: addr,
|
IPNet: addr,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@ -99,21 +102,17 @@ func syncAddress(link netlink.Link, cfg *Config, log logrus.FieldLogger) error {
|
|||||||
log.Info("address added")
|
log.Info("address added")
|
||||||
}
|
}
|
||||||
|
|
||||||
for addr, p := range presentAddresses {
|
for _, addr := range presentAddresses {
|
||||||
log := log.WithField("addr", addr)
|
if addr == nil {
|
||||||
if p < 2 {
|
continue
|
||||||
nlAddr, err := netlink.ParseAddr(addr)
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Error("cannot parse del addr")
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
if err := netlink.AddrAdd(link, nlAddr); err != nil {
|
log := log.WithField("addr", addr.IPNet.String())
|
||||||
|
if err := netlink.AddrDel(link, addr); err != nil {
|
||||||
log.WithError(err).Error("cannot delete addr")
|
log.WithError(err).Error("cannot delete addr")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Info("addr deleted")
|
log.Info("addr deleted")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -124,11 +123,11 @@ func syncRoutes(link netlink.Link, cfg *Config, log logrus.FieldLogger) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
presentRoutes := make(map[string]int, 0)
|
presentRoutes := make(map[string]*netlink.Route, 0)
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
log := log.WithField("route", r.Dst.String())
|
log := log.WithField("route", r.Dst.String())
|
||||||
if r.Table == cfg.Table || cfg.Table == 0 {
|
if r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable) {
|
||||||
presentRoutes[r.Dst.String()] = 1
|
presentRoutes[r.Dst.String()] = &r
|
||||||
log.WithField("table", r.Table).Debug("detected existing route")
|
log.WithField("table", r.Table).Debug("detected existing route")
|
||||||
} else {
|
} else {
|
||||||
log.Debug("wrong table for route, skipping")
|
log.Debug("wrong table for route, skipping")
|
||||||
@ -138,7 +137,7 @@ func syncRoutes(link netlink.Link, cfg *Config, log logrus.FieldLogger) error {
|
|||||||
for _, peer := range cfg.Peers {
|
for _, peer := range cfg.Peers {
|
||||||
for _, rt := range peer.AllowedIPs {
|
for _, rt := range peer.AllowedIPs {
|
||||||
_, present := presentRoutes[rt.String()]
|
_, present := presentRoutes[rt.String()]
|
||||||
presentRoutes[rt.String()] = 2
|
presentRoutes[rt.String()] = nil // mark as visited
|
||||||
log := log.WithField("route", rt.String())
|
log := log.WithField("route", rt.String())
|
||||||
if present {
|
if present {
|
||||||
log.Info("route present")
|
log.Info("route present")
|
||||||
@ -157,25 +156,17 @@ func syncRoutes(link netlink.Link, cfg *Config, log logrus.FieldLogger) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Clean extra routes
|
// Clean extra routes
|
||||||
for rtStr, p := range presentRoutes {
|
for _, rt := range presentRoutes {
|
||||||
_, rt, err := net.ParseCIDR(rtStr)
|
if rt == nil { // skip visited routes
|
||||||
log := log.WithField("route", rt.String())
|
continue
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Error("cannot parse route")
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
if p < 2 {
|
log := log.WithField("route", rt.Dst.String())
|
||||||
log.Info("extra manual route found")
|
log.Info("extra manual route found")
|
||||||
if err := netlink.RouteDel(&netlink.Route{
|
if err := netlink.RouteDel(rt); err != nil {
|
||||||
LinkIndex: link.Attrs().Index,
|
|
||||||
Dst: rt,
|
|
||||||
Table: cfg.Table,
|
|
||||||
}); err != nil {
|
|
||||||
log.WithError(err).Error("cannot setup route")
|
log.WithError(err).Error("cannot setup route")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Info("route deleted")
|
log.Info("route deleted")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user