This commit is contained in:
Neven Miculinic 2019-03-29 12:40:15 +01:00
parent 6b3b3f100f
commit 6bc3601354
3 changed files with 75 additions and 46 deletions

View File

@ -3,10 +3,11 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"github.com/nmiculinic/wg-quick-go"
"github.com/sirupsen/logrus"
"io/ioutil" "io/ioutil"
"os" "os"
"github.com/nmiculinic/wg-quick-go"
"github.com/sirupsen/logrus"
) )
func printHelp() { func printHelp() {
@ -19,6 +20,7 @@ func main() {
flag.String("iface", "", "interface") flag.String("iface", "", "interface")
verbose := flag.Bool("v", false, "verbose") verbose := flag.Bool("v", false, "verbose")
protocol := flag.Int("route-protocol", 0, "route protocol to use for our routes") protocol := flag.Int("route-protocol", 0, "route protocol to use for our routes")
metric := flag.Int("route-metric", 0, "route metric to use for our routes")
flag.Parse() flag.Parse()
args := flag.Args() args := flag.Args()
if len(args) != 2 { if len(args) != 2 {
@ -63,6 +65,7 @@ func main() {
} }
c.RouteProtocol = *protocol c.RouteProtocol = *protocol
c.RouteMetric = *metric
switch args[0] { switch args[0] {
case "up": case "up":

View File

@ -5,12 +5,13 @@ import (
"encoding" "encoding"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/mdlayher/wireguardctrl/wgtypes"
"net" "net"
"strconv" "strconv"
"strings" "strings"
"text/template" "text/template"
"time" "time"
"github.com/mdlayher/wireguardctrl/wgtypes"
) )
// Config represents full wg-quick like config structure // Config represents full wg-quick like config structure
@ -38,6 +39,9 @@ type Config struct {
// RouteProtocol to set on the route. See linux/rtnetlink.h Use value > 4 or default 0 // RouteProtocol to set on the route. See linux/rtnetlink.h Use value > 4 or default 0
RouteProtocol int RouteProtocol int
// RouteMetric sets this metric on all managed routes. Lower number means pick this one
RouteMetric int
// Address label to set on the link // Address label to set on the link
AddressLabel string AddressLabel string

108
wg.go
View File

@ -286,77 +286,99 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
return nil return nil
} }
func fillRouteDefaults(rt *netlink.Route) {
// fill defaults
if rt.Table == 0 {
rt.Table = defaultRoutingTable
}
if rt.Protocol == 0 {
rt.Protocol = RTPROT_BOOT
}
if rt.Type == 0 {
rt.Type = 1 // RTN_UNICAST
}
}
// SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config // SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config
func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error { func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error {
routes, err := netlink.RouteList(link, syscall.AF_INET) var wantedRoutes = make(map[string][]netlink.Route, len(managedRoutes))
presentRoutes, err := netlink.RouteList(link, syscall.AF_INET)
if err != nil { if err != nil {
log.Error(err, "cannot read existing routes") log.Error(err, "cannot read existing routes")
return err return err
} }
presentRoutes := make(map[string]netlink.Route, 0)
for _, r := range routes {
log := log.WithFields(map[string]interface{}{
"route": r.Dst.String(),
"protocol": r.Protocol,
"table": r.Table,
"type": r.Type,
})
log.Debugf("detected existing route: %v", r)
if !(r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable)) {
log.Debug("wrong table for route, skipping")
continue
}
presentRoutes[r.Dst.String()] = r
log.Debug("added route to consideration")
}
for _, rt := range managedRoutes { for _, rt := range managedRoutes {
log := log.WithField("route", rt.String()) rt := rt // make copy
route, present := presentRoutes[rt.String()] log.WithField("dst", rt.String()).Debug("managing route")
presentRoutes[rt.String()] = netlink.Route{} // mark as visited
if present { nrt := netlink.Route{
if route.Dst != nil && route.Protocol != cfg.RouteProtocol {
log.Warnf("route present; proto=%d != defined root proto=%d", route.Protocol, cfg.RouteProtocol)
} else {
log.Info("route present")
}
continue
}
if err := netlink.RouteAdd(&netlink.Route{
LinkIndex: link.Attrs().Index, LinkIndex: link.Attrs().Index,
Dst: &rt, Dst: &rt,
Table: cfg.Table, Table: cfg.Table,
Protocol: cfg.RouteProtocol, Protocol: cfg.RouteProtocol,
}); err != nil { Priority: cfg.RouteMetric}
log.WithError(err).Error("cannot setup route") fillRouteDefaults(&nrt)
return err wantedRoutes[rt.String()] = append(wantedRoutes[rt.String()], nrt)
}
log.Info("route added")
} }
// Clean extra routes for _, rtLst := range wantedRoutes {
for _, rt := range presentRoutes { for _, rt := range rtLst {
if rt.Dst == nil { // skip visited routes rt := rt // make copy
continue log := log.WithFields(map[string]interface{}{
"route": rt.Dst.String(),
"protocol": rt.Protocol,
"table": rt.Table,
"type": rt.Type,
"metric": rt.Priority,
})
if err := netlink.RouteReplace(&rt); err != nil {
log.WithError(err).Errorln("cannot add/replace route")
return err
}
log.Infoln("route added/replaced")
} }
}
checkWanted := func(rt netlink.Route) bool {
for _, candidateRt := range wantedRoutes[rt.Dst.String()] {
if rt.Equal(candidateRt) {
return true
}
}
return false
}
for _, rt := range presentRoutes {
log := log.WithFields(map[string]interface{}{ log := log.WithFields(map[string]interface{}{
"route": rt.Dst.String(), "route": rt.Dst.String(),
"protocol": rt.Protocol, "protocol": rt.Protocol,
"table": rt.Table, "table": rt.Table,
"type": rt.Type, "type": rt.Type,
"metric": rt.Priority,
}) })
log.Info("extra manual route found") if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == defaultRoutingTable)) {
// RTPROT_BOOT is default one when other proto isn't defined log.Debug("wrong table for route, skipping")
if !(rt.Protocol == cfg.RouteProtocol || rt.Protocol == RTPROT_BOOT && cfg.RouteProtocol == 0) { continue
}
if !(rt.Protocol == cfg.RouteProtocol) {
log.Infof("skipping route deletion, not owned by this daemon") log.Infof("skipping route deletion, not owned by this daemon")
continue continue
} }
if checkWanted(rt) {
log.Debug("route wanted, skipping deleting")
continue
}
if err := netlink.RouteDel(&rt); err != nil { if err := netlink.RouteDel(&rt); err != nil {
log.WithError(err).Error("cannot delete route") log.WithError(err).Error("cannot delete route")
return err return err
} }
log.Info("route deleted") log.Info("route deleted")
} }
return nil return nil
} }