diff --git a/cmd/wg-quick/main.go b/cmd/wg-quick/main.go index ab764c0..aa9cece 100644 --- a/cmd/wg-quick/main.go +++ b/cmd/wg-quick/main.go @@ -3,10 +3,11 @@ package main import ( "flag" "fmt" - "github.com/nmiculinic/wg-quick-go" - "github.com/sirupsen/logrus" "io/ioutil" "os" + + "github.com/nmiculinic/wg-quick-go" + "github.com/sirupsen/logrus" ) func printHelp() { @@ -19,6 +20,7 @@ func main() { flag.String("iface", "", "interface") verbose := flag.Bool("v", false, "verbose") protocol := flag.Int("route-protocol", 0, "route protocol to use for our routes") + metric := flag.Int("route-metric", 0, "route metric to use for our routes") flag.Parse() args := flag.Args() if len(args) != 2 { @@ -63,6 +65,7 @@ func main() { } c.RouteProtocol = *protocol + c.RouteMetric = *metric switch args[0] { case "up": diff --git a/config.go b/config.go index 9cbb353..f156036 100644 --- a/config.go +++ b/config.go @@ -5,12 +5,13 @@ import ( "encoding" "encoding/base64" "fmt" - "github.com/mdlayher/wireguardctrl/wgtypes" "net" "strconv" "strings" "text/template" "time" + + "github.com/mdlayher/wireguardctrl/wgtypes" ) // Config represents full wg-quick like config structure @@ -38,6 +39,9 @@ type Config struct { // RouteProtocol to set on the route. See linux/rtnetlink.h Use value > 4 or default 0 RouteProtocol int + // RouteMetric sets this metric on all managed routes. Lower number means pick this one + RouteMetric int + // Address label to set on the link AddressLabel string diff --git a/wg.go b/wg.go index d5e4d04..7e0a315 100644 --- a/wg.go +++ b/wg.go @@ -286,77 +286,99 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { return nil } +func fillRouteDefaults(rt *netlink.Route) { + // fill defaults + if rt.Table == 0 { + rt.Table = defaultRoutingTable + } + + if rt.Protocol == 0 { + rt.Protocol = RTPROT_BOOT + } + + if rt.Type == 0 { + rt.Type = 1 // RTN_UNICAST + } +} + // SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error { - routes, err := netlink.RouteList(link, syscall.AF_INET) + var wantedRoutes = make(map[string][]netlink.Route, len(managedRoutes)) + presentRoutes, err := netlink.RouteList(link, syscall.AF_INET) if err != nil { log.Error(err, "cannot read existing routes") return err } - - presentRoutes := make(map[string]netlink.Route, 0) - for _, r := range routes { - log := log.WithFields(map[string]interface{}{ - "route": r.Dst.String(), - "protocol": r.Protocol, - "table": r.Table, - "type": r.Type, - }) - log.Debugf("detected existing route: %v", r) - if !(r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable)) { - log.Debug("wrong table for route, skipping") - continue - } - presentRoutes[r.Dst.String()] = r - log.Debug("added route to consideration") - } - for _, rt := range managedRoutes { - log := log.WithField("route", rt.String()) - route, present := presentRoutes[rt.String()] - presentRoutes[rt.String()] = netlink.Route{} // mark as visited - if present { - if route.Dst != nil && route.Protocol != cfg.RouteProtocol { - log.Warnf("route present; proto=%d != defined root proto=%d", route.Protocol, cfg.RouteProtocol) - } else { - log.Info("route present") - } - continue - } - if err := netlink.RouteAdd(&netlink.Route{ + rt := rt // make copy + log.WithField("dst", rt.String()).Debug("managing route") + + nrt := netlink.Route{ LinkIndex: link.Attrs().Index, Dst: &rt, Table: cfg.Table, Protocol: cfg.RouteProtocol, - }); err != nil { - log.WithError(err).Error("cannot setup route") - return err - } - log.Info("route added") + Priority: cfg.RouteMetric} + fillRouteDefaults(&nrt) + wantedRoutes[rt.String()] = append(wantedRoutes[rt.String()], nrt) } - // Clean extra routes - for _, rt := range presentRoutes { - if rt.Dst == nil { // skip visited routes - continue + for _, rtLst := range wantedRoutes { + for _, rt := range rtLst { + rt := rt // make copy + log := log.WithFields(map[string]interface{}{ + "route": rt.Dst.String(), + "protocol": rt.Protocol, + "table": rt.Table, + "type": rt.Type, + "metric": rt.Priority, + }) + if err := netlink.RouteReplace(&rt); err != nil { + log.WithError(err).Errorln("cannot add/replace route") + return err + } + log.Infoln("route added/replaced") } + } + + checkWanted := func(rt netlink.Route) bool { + for _, candidateRt := range wantedRoutes[rt.Dst.String()] { + if rt.Equal(candidateRt) { + return true + } + } + return false + } + + for _, rt := range presentRoutes { log := log.WithFields(map[string]interface{}{ "route": rt.Dst.String(), "protocol": rt.Protocol, "table": rt.Table, "type": rt.Type, + "metric": rt.Priority, }) - log.Info("extra manual route found") - // RTPROT_BOOT is default one when other proto isn't defined - if !(rt.Protocol == cfg.RouteProtocol || rt.Protocol == RTPROT_BOOT && cfg.RouteProtocol == 0) { + if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == defaultRoutingTable)) { + log.Debug("wrong table for route, skipping") + continue + } + + if !(rt.Protocol == cfg.RouteProtocol) { log.Infof("skipping route deletion, not owned by this daemon") continue } + + if checkWanted(rt) { + log.Debug("route wanted, skipping deleting") + continue + } + if err := netlink.RouteDel(&rt); err != nil { log.WithError(err).Error("cannot delete route") return err } log.Info("route deleted") } + return nil }