update
This commit is contained in:
parent
6b3b3f100f
commit
6bc3601354
@ -3,10 +3,11 @@ package main
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/nmiculinic/wg-quick-go"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/nmiculinic/wg-quick-go"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func printHelp() {
|
func printHelp() {
|
||||||
@ -19,6 +20,7 @@ func main() {
|
|||||||
flag.String("iface", "", "interface")
|
flag.String("iface", "", "interface")
|
||||||
verbose := flag.Bool("v", false, "verbose")
|
verbose := flag.Bool("v", false, "verbose")
|
||||||
protocol := flag.Int("route-protocol", 0, "route protocol to use for our routes")
|
protocol := flag.Int("route-protocol", 0, "route protocol to use for our routes")
|
||||||
|
metric := flag.Int("route-metric", 0, "route metric to use for our routes")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
args := flag.Args()
|
args := flag.Args()
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
@ -63,6 +65,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.RouteProtocol = *protocol
|
c.RouteProtocol = *protocol
|
||||||
|
c.RouteMetric = *metric
|
||||||
|
|
||||||
switch args[0] {
|
switch args[0] {
|
||||||
case "up":
|
case "up":
|
||||||
|
@ -5,12 +5,13 @@ import (
|
|||||||
"encoding"
|
"encoding"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/mdlayher/wireguardctrl/wgtypes"
|
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mdlayher/wireguardctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config represents full wg-quick like config structure
|
// Config represents full wg-quick like config structure
|
||||||
@ -38,6 +39,9 @@ type Config struct {
|
|||||||
// RouteProtocol to set on the route. See linux/rtnetlink.h Use value > 4 or default 0
|
// RouteProtocol to set on the route. See linux/rtnetlink.h Use value > 4 or default 0
|
||||||
RouteProtocol int
|
RouteProtocol int
|
||||||
|
|
||||||
|
// RouteMetric sets this metric on all managed routes. Lower number means pick this one
|
||||||
|
RouteMetric int
|
||||||
|
|
||||||
// Address label to set on the link
|
// Address label to set on the link
|
||||||
AddressLabel string
|
AddressLabel string
|
||||||
|
|
||||||
|
108
wg.go
108
wg.go
@ -286,77 +286,99 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func fillRouteDefaults(rt *netlink.Route) {
|
||||||
|
// fill defaults
|
||||||
|
if rt.Table == 0 {
|
||||||
|
rt.Table = defaultRoutingTable
|
||||||
|
}
|
||||||
|
|
||||||
|
if rt.Protocol == 0 {
|
||||||
|
rt.Protocol = RTPROT_BOOT
|
||||||
|
}
|
||||||
|
|
||||||
|
if rt.Type == 0 {
|
||||||
|
rt.Type = 1 // RTN_UNICAST
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config
|
// SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config
|
||||||
func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error {
|
func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error {
|
||||||
routes, err := netlink.RouteList(link, syscall.AF_INET)
|
var wantedRoutes = make(map[string][]netlink.Route, len(managedRoutes))
|
||||||
|
presentRoutes, err := netlink.RouteList(link, syscall.AF_INET)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err, "cannot read existing routes")
|
log.Error(err, "cannot read existing routes")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
presentRoutes := make(map[string]netlink.Route, 0)
|
|
||||||
for _, r := range routes {
|
|
||||||
log := log.WithFields(map[string]interface{}{
|
|
||||||
"route": r.Dst.String(),
|
|
||||||
"protocol": r.Protocol,
|
|
||||||
"table": r.Table,
|
|
||||||
"type": r.Type,
|
|
||||||
})
|
|
||||||
log.Debugf("detected existing route: %v", r)
|
|
||||||
if !(r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable)) {
|
|
||||||
log.Debug("wrong table for route, skipping")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
presentRoutes[r.Dst.String()] = r
|
|
||||||
log.Debug("added route to consideration")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rt := range managedRoutes {
|
for _, rt := range managedRoutes {
|
||||||
log := log.WithField("route", rt.String())
|
rt := rt // make copy
|
||||||
route, present := presentRoutes[rt.String()]
|
log.WithField("dst", rt.String()).Debug("managing route")
|
||||||
presentRoutes[rt.String()] = netlink.Route{} // mark as visited
|
|
||||||
if present {
|
nrt := netlink.Route{
|
||||||
if route.Dst != nil && route.Protocol != cfg.RouteProtocol {
|
|
||||||
log.Warnf("route present; proto=%d != defined root proto=%d", route.Protocol, cfg.RouteProtocol)
|
|
||||||
} else {
|
|
||||||
log.Info("route present")
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := netlink.RouteAdd(&netlink.Route{
|
|
||||||
LinkIndex: link.Attrs().Index,
|
LinkIndex: link.Attrs().Index,
|
||||||
Dst: &rt,
|
Dst: &rt,
|
||||||
Table: cfg.Table,
|
Table: cfg.Table,
|
||||||
Protocol: cfg.RouteProtocol,
|
Protocol: cfg.RouteProtocol,
|
||||||
}); err != nil {
|
Priority: cfg.RouteMetric}
|
||||||
log.WithError(err).Error("cannot setup route")
|
fillRouteDefaults(&nrt)
|
||||||
return err
|
wantedRoutes[rt.String()] = append(wantedRoutes[rt.String()], nrt)
|
||||||
}
|
|
||||||
log.Info("route added")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean extra routes
|
for _, rtLst := range wantedRoutes {
|
||||||
for _, rt := range presentRoutes {
|
for _, rt := range rtLst {
|
||||||
if rt.Dst == nil { // skip visited routes
|
rt := rt // make copy
|
||||||
continue
|
log := log.WithFields(map[string]interface{}{
|
||||||
|
"route": rt.Dst.String(),
|
||||||
|
"protocol": rt.Protocol,
|
||||||
|
"table": rt.Table,
|
||||||
|
"type": rt.Type,
|
||||||
|
"metric": rt.Priority,
|
||||||
|
})
|
||||||
|
if err := netlink.RouteReplace(&rt); err != nil {
|
||||||
|
log.WithError(err).Errorln("cannot add/replace route")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Infoln("route added/replaced")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
checkWanted := func(rt netlink.Route) bool {
|
||||||
|
for _, candidateRt := range wantedRoutes[rt.Dst.String()] {
|
||||||
|
if rt.Equal(candidateRt) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rt := range presentRoutes {
|
||||||
log := log.WithFields(map[string]interface{}{
|
log := log.WithFields(map[string]interface{}{
|
||||||
"route": rt.Dst.String(),
|
"route": rt.Dst.String(),
|
||||||
"protocol": rt.Protocol,
|
"protocol": rt.Protocol,
|
||||||
"table": rt.Table,
|
"table": rt.Table,
|
||||||
"type": rt.Type,
|
"type": rt.Type,
|
||||||
|
"metric": rt.Priority,
|
||||||
})
|
})
|
||||||
log.Info("extra manual route found")
|
if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == defaultRoutingTable)) {
|
||||||
// RTPROT_BOOT is default one when other proto isn't defined
|
log.Debug("wrong table for route, skipping")
|
||||||
if !(rt.Protocol == cfg.RouteProtocol || rt.Protocol == RTPROT_BOOT && cfg.RouteProtocol == 0) {
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !(rt.Protocol == cfg.RouteProtocol) {
|
||||||
log.Infof("skipping route deletion, not owned by this daemon")
|
log.Infof("skipping route deletion, not owned by this daemon")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if checkWanted(rt) {
|
||||||
|
log.Debug("route wanted, skipping deleting")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if err := netlink.RouteDel(&rt); err != nil {
|
if err := netlink.RouteDel(&rt); err != nil {
|
||||||
log.WithError(err).Error("cannot delete route")
|
log.WithError(err).Error("cannot delete route")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Info("route deleted")
|
log.Info("route deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user