Added route protocol filtering
This commit is contained in:
parent
8f048ca6d8
commit
9a0ad14ca6
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
/build
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
|
29
.idea/watcherTasks.xml
Normal file
29
.idea/watcherTasks.xml
Normal file
@ -0,0 +1,29 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectTasksOptions">
|
||||
<TaskOptions isEnabled="true">
|
||||
<option name="arguments" value="-w $FilePath$" />
|
||||
<option name="checkSyntaxErrors" value="true" />
|
||||
<option name="description" />
|
||||
<option name="exitCodeBehavior" value="ERROR" />
|
||||
<option name="fileExtension" value="go" />
|
||||
<option name="immediateSync" value="false" />
|
||||
<option name="name" value="goimports" />
|
||||
<option name="output" value="$FilePath$" />
|
||||
<option name="outputFilters">
|
||||
<array />
|
||||
</option>
|
||||
<option name="outputFromStdout" value="false" />
|
||||
<option name="program" value="goimports" />
|
||||
<option name="runOnExternalChanges" value="false" />
|
||||
<option name="scopeName" value="Project Files" />
|
||||
<option name="trackOnlyRoot" value="true" />
|
||||
<option name="workingDir" value="$ProjectFileDir$" />
|
||||
<envs>
|
||||
<env name="GOROOT" value="$GOROOT$" />
|
||||
<env name="GOPATH" value="$GOPATH$" />
|
||||
<env name="PATH" value="$GoBinDirs$" />
|
||||
</envs>
|
||||
</TaskOptions>
|
||||
</component>
|
||||
</project>
|
@ -1,5 +1,10 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="WEB_MODULE" version="4">
|
||||
<component name="Go">
|
||||
<buildTags>
|
||||
<option name="os" value="linux" />
|
||||
</buildTags>
|
||||
</component>
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
|
@ -10,18 +10,25 @@ import (
|
||||
)
|
||||
|
||||
func printHelp() {
|
||||
fmt.Println("wg-quick [-iface=wg0] [ up | down ] config_file")
|
||||
fmt.Print("wg-quick [flags] [ up | down | sync ] [ config_file | interface ]\n\n")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.String("iface", "", "interface")
|
||||
verbose := flag.Bool("v", false, "verbose")
|
||||
protocol := flag.Int("route-protocol", 0, "route protocol to use for our routes")
|
||||
flag.Parse()
|
||||
args := flag.Args()
|
||||
if len(args) != 2 {
|
||||
printHelp()
|
||||
}
|
||||
|
||||
if *verbose {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
}
|
||||
|
||||
iface := flag.Lookup("iface").Value.String()
|
||||
log := logrus.WithField("iface", iface)
|
||||
|
||||
@ -55,6 +62,8 @@ func main() {
|
||||
logrus.WithError(err).Fatalln("cannot parse config file")
|
||||
}
|
||||
|
||||
c.RouteProtocol = *protocol
|
||||
|
||||
switch args[0] {
|
||||
case "up":
|
||||
if err := wgquick.Up(c, iface, log); err != nil {
|
||||
@ -64,5 +73,11 @@ func main() {
|
||||
if err := wgquick.Down(c, iface, log); err != nil {
|
||||
logrus.WithError(err).Errorln("cannot down interface")
|
||||
}
|
||||
case "sync":
|
||||
if err := wgquick.Sync(c, iface, log); err != nil {
|
||||
logrus.WithError(err).Errorln("cannot sync interface")
|
||||
}
|
||||
default:
|
||||
printHelp()
|
||||
}
|
||||
}
|
||||
|
15
config.go
15
config.go
@ -157,14 +157,14 @@ func (cfg *Config) UnmarshalText(text []byte) error {
|
||||
switch state {
|
||||
case inter:
|
||||
if err := parseInterfaceLine(cfg, lhs, rhs); err != nil {
|
||||
return fmt.Errorf("[line %d]: %v", no, err)
|
||||
return fmt.Errorf("[line %d]: %v", no+1, err)
|
||||
}
|
||||
case peer:
|
||||
if err := parsePeerLine(peerCfg, lhs, rhs); err != nil {
|
||||
return fmt.Errorf("[line %d]: %v", no, err)
|
||||
return fmt.Errorf("[line %d]: %v", no+1, err)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("cannot parse line %d, unknown state", no)
|
||||
return fmt.Errorf("[line %d] cannot parse, unknown state", no+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -254,7 +254,7 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
|
||||
for _, addr := range strings.Split(rhs, ",") {
|
||||
ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v", err)
|
||||
return fmt.Errorf("cannot parse %s: %v", addr, err)
|
||||
}
|
||||
peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask})
|
||||
}
|
||||
@ -264,6 +264,13 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
|
||||
return err
|
||||
}
|
||||
peerCfg.Endpoint = addr
|
||||
case "PersistentKeepalive":
|
||||
t, err := strconv.ParseInt(rhs, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dur := time.Duration(t * int64(time.Second))
|
||||
peerCfg.PersistentKeepaliveInterval = &dur
|
||||
default:
|
||||
return fmt.Errorf("unknown directive %s", lhs)
|
||||
}
|
||||
|
84
wg.go
84
wg.go
@ -3,17 +3,49 @@ package wgquick
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/mdlayher/wireguardctrl"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/mdlayher/wireguardctrl"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultRoutingTable = 254
|
||||
|
||||
// From linux/rtnetlink.h
|
||||
RTPROT_UNSPEC = 0
|
||||
RTPROT_REDIRECT = 1 /* Route installed by ICMP redirects; not used by current IPv4 */
|
||||
RTPROT_KERNEL = 2 /* Route installed by kernel */
|
||||
RTPROT_BOOT = 3 /* Route installed during boot */
|
||||
RTPROT_STATIC = 4 /* Route installed by administrator */
|
||||
|
||||
/* Values of protocol >= RTPROT_STATIC are not interpreted by kernel;
|
||||
they are just passed from user and back as is.
|
||||
It will be used by hypothetical multiple routing daemons.
|
||||
Note that protocol values should be standardized in order to
|
||||
avoid conflicts.
|
||||
*/
|
||||
|
||||
RTPROT_GATED = 8 /* Apparently, GateD */
|
||||
RTPROT_RA = 9 /* RDISC/ND router advertisements */
|
||||
RTPROT_MRT = 10 /* Merit MRT */
|
||||
RTPROT_ZEBRA = 11 /* Zebra */
|
||||
RTPROT_BIRD = 12 /* BIRD */
|
||||
RTPROT_DNROUTED = 13 /* DECnet routing daemon */
|
||||
RTPROT_XORP = 14 /* XORP */
|
||||
RTPROT_NTK = 15 /* Netsukuku */
|
||||
RTPROT_DHCP = 16 /* DHCP client */
|
||||
RTPROT_MROUTED = 17 /* Multicast daemon */
|
||||
RTPROT_BABEL = 42 /* Babel daemon */
|
||||
RTPROT_BGP = 186 /* BGP Routes */
|
||||
RTPROT_ISIS = 187 /* ISIS Routes */
|
||||
RTPROT_OSPF = 188 /* OSPF Routes */
|
||||
RTPROT_RIP = 189 /* RIP Routes */
|
||||
RTPROT_EIGRP = 192 /* EIGRP Routes */
|
||||
)
|
||||
|
||||
// Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface`
|
||||
@ -198,19 +230,19 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
||||
}
|
||||
|
||||
// nil addr means I've used it
|
||||
presentAddresses := make(map[string]*netlink.Addr, 0)
|
||||
presentAddresses := make(map[string]netlink.Addr, 0)
|
||||
for _, addr := range addrs {
|
||||
log.WithFields(map[string]interface{}{
|
||||
"addr": addr.IPNet.String(),
|
||||
"addr": fmt.Sprint(addr.IPNet),
|
||||
"label": addr.Label,
|
||||
}).Debugln("found existing address")
|
||||
presentAddresses[addr.IPNet.String()] = &addr
|
||||
}).Debugf("found existing address: %v", addr)
|
||||
presentAddresses[addr.IPNet.String()] = addr
|
||||
}
|
||||
|
||||
for _, addr := range cfg.Address {
|
||||
log := log.WithField("addr", addr)
|
||||
log := log.WithField("addr", addr.String())
|
||||
_, present := presentAddresses[addr.String()]
|
||||
presentAddresses[addr.String()] = nil // mark as present
|
||||
presentAddresses[addr.String()] = netlink.Addr{} // mark as present
|
||||
if present {
|
||||
log.Info("address present")
|
||||
continue
|
||||
@ -226,14 +258,14 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
||||
}
|
||||
|
||||
for _, addr := range presentAddresses {
|
||||
if addr == nil {
|
||||
if addr.IPNet == nil {
|
||||
continue
|
||||
}
|
||||
log := log.WithFields(map[string]interface{}{
|
||||
"addr": addr.IPNet.String(),
|
||||
"label": addr.Label,
|
||||
})
|
||||
if err := netlink.AddrDel(link, addr); err != nil {
|
||||
if err := netlink.AddrDel(link, &addr); err != nil {
|
||||
log.WithError(err).Error("cannot delete addr")
|
||||
return err
|
||||
}
|
||||
@ -250,7 +282,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
||||
return err
|
||||
}
|
||||
|
||||
presentRoutes := make(map[string]*netlink.Route, 0)
|
||||
presentRoutes := make(map[string]netlink.Route, 0)
|
||||
for _, r := range routes {
|
||||
log := log.WithFields(map[string]interface{}{
|
||||
"route": r.Dst.String(),
|
||||
@ -258,21 +290,26 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
||||
"table": r.Table,
|
||||
"type": r.Type,
|
||||
})
|
||||
if r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable) {
|
||||
presentRoutes[r.Dst.String()] = &r
|
||||
log.WithField("table", r.Table).Debug("detected existing route")
|
||||
} else {
|
||||
log.Debugf("detected existing route: %v", r)
|
||||
if !(r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable)) {
|
||||
log.Debug("wrong table for route, skipping")
|
||||
continue
|
||||
}
|
||||
presentRoutes[r.Dst.String()] = r
|
||||
log.Debug("added route to consideration")
|
||||
}
|
||||
|
||||
for _, peer := range cfg.Peers {
|
||||
for _, rt := range peer.AllowedIPs {
|
||||
_, present := presentRoutes[rt.String()]
|
||||
presentRoutes[rt.String()] = nil // mark as visited
|
||||
log := log.WithField("route", rt.String())
|
||||
route, present := presentRoutes[rt.String()]
|
||||
presentRoutes[rt.String()] = netlink.Route{} // mark as visited
|
||||
if present {
|
||||
log.Info("route present")
|
||||
if route.Dst != nil && route.Protocol != cfg.RouteProtocol {
|
||||
log.Warnf("route present; proto=%d != defined root proto=%d", route.Protocol, cfg.RouteProtocol)
|
||||
} else {
|
||||
log.Info("route present")
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := netlink.RouteAdd(&netlink.Route{
|
||||
@ -290,7 +327,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
||||
|
||||
// Clean extra routes
|
||||
for _, rt := range presentRoutes {
|
||||
if rt == nil { // skip visited routes
|
||||
if rt.Dst == nil { // skip visited routes
|
||||
continue
|
||||
}
|
||||
log := log.WithFields(map[string]interface{}{
|
||||
@ -300,7 +337,12 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
||||
"type": rt.Type,
|
||||
})
|
||||
log.Info("extra manual route found")
|
||||
if err := netlink.RouteDel(rt); err != nil {
|
||||
// RTPROT_BOOT is default one when other proto isn't defined
|
||||
if !(rt.Protocol == cfg.RouteProtocol || rt.Protocol == RTPROT_BOOT && cfg.RouteProtocol == 0) {
|
||||
log.Debug("skipping route deletion, not owned by this daemon")
|
||||
continue
|
||||
}
|
||||
if err := netlink.RouteDel(&rt); err != nil {
|
||||
log.WithError(err).Error("cannot setup route")
|
||||
return err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user