This commit is contained in:
Neven Miculinic 2019-03-29 12:40:15 +01:00
parent 6b3b3f100f
commit 6bc3601354
3 changed files with 75 additions and 46 deletions

View File

@ -3,10 +3,11 @@ package main
import (
"flag"
"fmt"
"github.com/nmiculinic/wg-quick-go"
"github.com/sirupsen/logrus"
"io/ioutil"
"os"
"github.com/nmiculinic/wg-quick-go"
"github.com/sirupsen/logrus"
)
func printHelp() {
@ -19,6 +20,7 @@ func main() {
flag.String("iface", "", "interface")
verbose := flag.Bool("v", false, "verbose")
protocol := flag.Int("route-protocol", 0, "route protocol to use for our routes")
metric := flag.Int("route-metric", 0, "route metric to use for our routes")
flag.Parse()
args := flag.Args()
if len(args) != 2 {
@ -63,6 +65,7 @@ func main() {
}
c.RouteProtocol = *protocol
c.RouteMetric = *metric
switch args[0] {
case "up":

View File

@ -5,12 +5,13 @@ import (
"encoding"
"encoding/base64"
"fmt"
"github.com/mdlayher/wireguardctrl/wgtypes"
"net"
"strconv"
"strings"
"text/template"
"time"
"github.com/mdlayher/wireguardctrl/wgtypes"
)
// Config represents full wg-quick like config structure
@ -38,6 +39,9 @@ type Config struct {
// RouteProtocol to set on the route. See linux/rtnetlink.h Use value > 4 or default 0
RouteProtocol int
// RouteMetric sets this metric on all managed routes. Lower number means pick this one
RouteMetric int
// Address label to set on the link
AddressLabel string

110
wg.go
View File

@ -286,77 +286,99 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
return nil
}
func fillRouteDefaults(rt *netlink.Route) {
// fill defaults
if rt.Table == 0 {
rt.Table = defaultRoutingTable
}
if rt.Protocol == 0 {
rt.Protocol = RTPROT_BOOT
}
if rt.Type == 0 {
rt.Type = 1 // RTN_UNICAST
}
}
// SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config
func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error {
routes, err := netlink.RouteList(link, syscall.AF_INET)
var wantedRoutes = make(map[string][]netlink.Route, len(managedRoutes))
presentRoutes, err := netlink.RouteList(link, syscall.AF_INET)
if err != nil {
log.Error(err, "cannot read existing routes")
return err
}
presentRoutes := make(map[string]netlink.Route, 0)
for _, r := range routes {
log := log.WithFields(map[string]interface{}{
"route": r.Dst.String(),
"protocol": r.Protocol,
"table": r.Table,
"type": r.Type,
})
log.Debugf("detected existing route: %v", r)
if !(r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable)) {
log.Debug("wrong table for route, skipping")
continue
}
presentRoutes[r.Dst.String()] = r
log.Debug("added route to consideration")
}
for _, rt := range managedRoutes {
log := log.WithField("route", rt.String())
route, present := presentRoutes[rt.String()]
presentRoutes[rt.String()] = netlink.Route{} // mark as visited
if present {
if route.Dst != nil && route.Protocol != cfg.RouteProtocol {
log.Warnf("route present; proto=%d != defined root proto=%d", route.Protocol, cfg.RouteProtocol)
} else {
log.Info("route present")
}
continue
}
if err := netlink.RouteAdd(&netlink.Route{
rt := rt // make copy
log.WithField("dst", rt.String()).Debug("managing route")
nrt := netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: &rt,
Table: cfg.Table,
Protocol: cfg.RouteProtocol,
}); err != nil {
log.WithError(err).Error("cannot setup route")
return err
}
log.Info("route added")
Priority: cfg.RouteMetric}
fillRouteDefaults(&nrt)
wantedRoutes[rt.String()] = append(wantedRoutes[rt.String()], nrt)
}
// Clean extra routes
for _, rt := range presentRoutes {
if rt.Dst == nil { // skip visited routes
continue
}
for _, rtLst := range wantedRoutes {
for _, rt := range rtLst {
rt := rt // make copy
log := log.WithFields(map[string]interface{}{
"route": rt.Dst.String(),
"protocol": rt.Protocol,
"table": rt.Table,
"type": rt.Type,
"metric": rt.Priority,
})
log.Info("extra manual route found")
// RTPROT_BOOT is default one when other proto isn't defined
if !(rt.Protocol == cfg.RouteProtocol || rt.Protocol == RTPROT_BOOT && cfg.RouteProtocol == 0) {
if err := netlink.RouteReplace(&rt); err != nil {
log.WithError(err).Errorln("cannot add/replace route")
return err
}
log.Infoln("route added/replaced")
}
}
checkWanted := func(rt netlink.Route) bool {
for _, candidateRt := range wantedRoutes[rt.Dst.String()] {
if rt.Equal(candidateRt) {
return true
}
}
return false
}
for _, rt := range presentRoutes {
log := log.WithFields(map[string]interface{}{
"route": rt.Dst.String(),
"protocol": rt.Protocol,
"table": rt.Table,
"type": rt.Type,
"metric": rt.Priority,
})
if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == defaultRoutingTable)) {
log.Debug("wrong table for route, skipping")
continue
}
if !(rt.Protocol == cfg.RouteProtocol) {
log.Infof("skipping route deletion, not owned by this daemon")
continue
}
if checkWanted(rt) {
log.Debug("route wanted, skipping deleting")
continue
}
if err := netlink.RouteDel(&rt); err != nil {
log.WithError(err).Error("cannot delete route")
return err
}
log.Info("route deleted")
}
return nil
}