update
This commit is contained in:
parent
6b3b3f100f
commit
6bc3601354
@ -3,10 +3,11 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/nmiculinic/wg-quick-go"
|
||||
"github.com/sirupsen/logrus"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"github.com/nmiculinic/wg-quick-go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func printHelp() {
|
||||
@ -19,6 +20,7 @@ func main() {
|
||||
flag.String("iface", "", "interface")
|
||||
verbose := flag.Bool("v", false, "verbose")
|
||||
protocol := flag.Int("route-protocol", 0, "route protocol to use for our routes")
|
||||
metric := flag.Int("route-metric", 0, "route metric to use for our routes")
|
||||
flag.Parse()
|
||||
args := flag.Args()
|
||||
if len(args) != 2 {
|
||||
@ -63,6 +65,7 @@ func main() {
|
||||
}
|
||||
|
||||
c.RouteProtocol = *protocol
|
||||
c.RouteMetric = *metric
|
||||
|
||||
switch args[0] {
|
||||
case "up":
|
||||
|
@ -5,12 +5,13 @@ import (
|
||||
"encoding"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/mdlayher/wireguardctrl/wgtypes"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/mdlayher/wireguardctrl/wgtypes"
|
||||
)
|
||||
|
||||
// Config represents full wg-quick like config structure
|
||||
@ -38,6 +39,9 @@ type Config struct {
|
||||
// RouteProtocol to set on the route. See linux/rtnetlink.h Use value > 4 or default 0
|
||||
RouteProtocol int
|
||||
|
||||
// RouteMetric sets this metric on all managed routes. Lower number means pick this one
|
||||
RouteMetric int
|
||||
|
||||
// Address label to set on the link
|
||||
AddressLabel string
|
||||
|
||||
|
110
wg.go
110
wg.go
@ -286,77 +286,99 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func fillRouteDefaults(rt *netlink.Route) {
|
||||
// fill defaults
|
||||
if rt.Table == 0 {
|
||||
rt.Table = defaultRoutingTable
|
||||
}
|
||||
|
||||
if rt.Protocol == 0 {
|
||||
rt.Protocol = RTPROT_BOOT
|
||||
}
|
||||
|
||||
if rt.Type == 0 {
|
||||
rt.Type = 1 // RTN_UNICAST
|
||||
}
|
||||
}
|
||||
|
||||
// SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config
|
||||
func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error {
|
||||
routes, err := netlink.RouteList(link, syscall.AF_INET)
|
||||
var wantedRoutes = make(map[string][]netlink.Route, len(managedRoutes))
|
||||
presentRoutes, err := netlink.RouteList(link, syscall.AF_INET)
|
||||
if err != nil {
|
||||
log.Error(err, "cannot read existing routes")
|
||||
return err
|
||||
}
|
||||
|
||||
presentRoutes := make(map[string]netlink.Route, 0)
|
||||
for _, r := range routes {
|
||||
log := log.WithFields(map[string]interface{}{
|
||||
"route": r.Dst.String(),
|
||||
"protocol": r.Protocol,
|
||||
"table": r.Table,
|
||||
"type": r.Type,
|
||||
})
|
||||
log.Debugf("detected existing route: %v", r)
|
||||
if !(r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable)) {
|
||||
log.Debug("wrong table for route, skipping")
|
||||
continue
|
||||
}
|
||||
presentRoutes[r.Dst.String()] = r
|
||||
log.Debug("added route to consideration")
|
||||
}
|
||||
|
||||
for _, rt := range managedRoutes {
|
||||
log := log.WithField("route", rt.String())
|
||||
route, present := presentRoutes[rt.String()]
|
||||
presentRoutes[rt.String()] = netlink.Route{} // mark as visited
|
||||
if present {
|
||||
if route.Dst != nil && route.Protocol != cfg.RouteProtocol {
|
||||
log.Warnf("route present; proto=%d != defined root proto=%d", route.Protocol, cfg.RouteProtocol)
|
||||
} else {
|
||||
log.Info("route present")
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := netlink.RouteAdd(&netlink.Route{
|
||||
rt := rt // make copy
|
||||
log.WithField("dst", rt.String()).Debug("managing route")
|
||||
|
||||
nrt := netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: &rt,
|
||||
Table: cfg.Table,
|
||||
Protocol: cfg.RouteProtocol,
|
||||
}); err != nil {
|
||||
log.WithError(err).Error("cannot setup route")
|
||||
return err
|
||||
}
|
||||
log.Info("route added")
|
||||
Priority: cfg.RouteMetric}
|
||||
fillRouteDefaults(&nrt)
|
||||
wantedRoutes[rt.String()] = append(wantedRoutes[rt.String()], nrt)
|
||||
}
|
||||
|
||||
// Clean extra routes
|
||||
for _, rt := range presentRoutes {
|
||||
if rt.Dst == nil { // skip visited routes
|
||||
continue
|
||||
}
|
||||
for _, rtLst := range wantedRoutes {
|
||||
for _, rt := range rtLst {
|
||||
rt := rt // make copy
|
||||
log := log.WithFields(map[string]interface{}{
|
||||
"route": rt.Dst.String(),
|
||||
"protocol": rt.Protocol,
|
||||
"table": rt.Table,
|
||||
"type": rt.Type,
|
||||
"metric": rt.Priority,
|
||||
})
|
||||
log.Info("extra manual route found")
|
||||
// RTPROT_BOOT is default one when other proto isn't defined
|
||||
if !(rt.Protocol == cfg.RouteProtocol || rt.Protocol == RTPROT_BOOT && cfg.RouteProtocol == 0) {
|
||||
if err := netlink.RouteReplace(&rt); err != nil {
|
||||
log.WithError(err).Errorln("cannot add/replace route")
|
||||
return err
|
||||
}
|
||||
log.Infoln("route added/replaced")
|
||||
}
|
||||
}
|
||||
|
||||
checkWanted := func(rt netlink.Route) bool {
|
||||
for _, candidateRt := range wantedRoutes[rt.Dst.String()] {
|
||||
if rt.Equal(candidateRt) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, rt := range presentRoutes {
|
||||
log := log.WithFields(map[string]interface{}{
|
||||
"route": rt.Dst.String(),
|
||||
"protocol": rt.Protocol,
|
||||
"table": rt.Table,
|
||||
"type": rt.Type,
|
||||
"metric": rt.Priority,
|
||||
})
|
||||
if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == defaultRoutingTable)) {
|
||||
log.Debug("wrong table for route, skipping")
|
||||
continue
|
||||
}
|
||||
|
||||
if !(rt.Protocol == cfg.RouteProtocol) {
|
||||
log.Infof("skipping route deletion, not owned by this daemon")
|
||||
continue
|
||||
}
|
||||
|
||||
if checkWanted(rt) {
|
||||
log.Debug("route wanted, skipping deleting")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := netlink.RouteDel(&rt); err != nil {
|
||||
log.WithError(err).Error("cannot delete route")
|
||||
return err
|
||||
}
|
||||
log.Info("route deleted")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user