Added route protocol filtering

This commit is contained in:
Neven Miculinic 2019-03-29 10:48:33 +01:00
parent 8f048ca6d8
commit 9a0ad14ca6
6 changed files with 125 additions and 26 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
/build
# Binaries for programs and plugins # Binaries for programs and plugins
*.exe *.exe
*.exe~ *.exe~

29
.idea/watcherTasks.xml Normal file
View File

@ -0,0 +1,29 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectTasksOptions">
<TaskOptions isEnabled="true">
<option name="arguments" value="-w $FilePath$" />
<option name="checkSyntaxErrors" value="true" />
<option name="description" />
<option name="exitCodeBehavior" value="ERROR" />
<option name="fileExtension" value="go" />
<option name="immediateSync" value="false" />
<option name="name" value="goimports" />
<option name="output" value="$FilePath$" />
<option name="outputFilters">
<array />
</option>
<option name="outputFromStdout" value="false" />
<option name="program" value="goimports" />
<option name="runOnExternalChanges" value="false" />
<option name="scopeName" value="Project Files" />
<option name="trackOnlyRoot" value="true" />
<option name="workingDir" value="$ProjectFileDir$" />
<envs>
<env name="GOROOT" value="$GOROOT$" />
<env name="GOPATH" value="$GOPATH$" />
<env name="PATH" value="$GoBinDirs$" />
</envs>
</TaskOptions>
</component>
</project>

View File

@ -1,5 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4"> <module type="WEB_MODULE" version="4">
<component name="Go">
<buildTags>
<option name="os" value="linux" />
</buildTags>
</component>
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" /> <orderEntry type="inheritedJdk" />

View File

@ -10,18 +10,25 @@ import (
) )
func printHelp() { func printHelp() {
fmt.Println("wg-quick [-iface=wg0] [ up | down ] config_file") fmt.Print("wg-quick [flags] [ up | down | sync ] [ config_file | interface ]\n\n")
flag.Usage()
os.Exit(1) os.Exit(1)
} }
func main() { func main() {
flag.String("iface", "", "interface") flag.String("iface", "", "interface")
verbose := flag.Bool("v", false, "verbose")
protocol := flag.Int("route-protocol", 0, "route protocol to use for our routes")
flag.Parse() flag.Parse()
args := flag.Args() args := flag.Args()
if len(args) != 2 { if len(args) != 2 {
printHelp() printHelp()
} }
if *verbose {
logrus.SetLevel(logrus.DebugLevel)
}
iface := flag.Lookup("iface").Value.String() iface := flag.Lookup("iface").Value.String()
log := logrus.WithField("iface", iface) log := logrus.WithField("iface", iface)
@ -55,6 +62,8 @@ func main() {
logrus.WithError(err).Fatalln("cannot parse config file") logrus.WithError(err).Fatalln("cannot parse config file")
} }
c.RouteProtocol = *protocol
switch args[0] { switch args[0] {
case "up": case "up":
if err := wgquick.Up(c, iface, log); err != nil { if err := wgquick.Up(c, iface, log); err != nil {
@ -64,5 +73,11 @@ func main() {
if err := wgquick.Down(c, iface, log); err != nil { if err := wgquick.Down(c, iface, log); err != nil {
logrus.WithError(err).Errorln("cannot down interface") logrus.WithError(err).Errorln("cannot down interface")
} }
case "sync":
if err := wgquick.Sync(c, iface, log); err != nil {
logrus.WithError(err).Errorln("cannot sync interface")
}
default:
printHelp()
} }
} }

View File

@ -157,14 +157,14 @@ func (cfg *Config) UnmarshalText(text []byte) error {
switch state { switch state {
case inter: case inter:
if err := parseInterfaceLine(cfg, lhs, rhs); err != nil { if err := parseInterfaceLine(cfg, lhs, rhs); err != nil {
return fmt.Errorf("[line %d]: %v", no, err) return fmt.Errorf("[line %d]: %v", no+1, err)
} }
case peer: case peer:
if err := parsePeerLine(peerCfg, lhs, rhs); err != nil { if err := parsePeerLine(peerCfg, lhs, rhs); err != nil {
return fmt.Errorf("[line %d]: %v", no, err) return fmt.Errorf("[line %d]: %v", no+1, err)
} }
default: default:
return fmt.Errorf("cannot parse line %d, unknown state", no) return fmt.Errorf("[line %d] cannot parse, unknown state", no+1)
} }
} }
} }
@ -254,7 +254,7 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
for _, addr := range strings.Split(rhs, ",") { for _, addr := range strings.Split(rhs, ",") {
ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr)) ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
if err != nil { if err != nil {
return fmt.Errorf("%v", err) return fmt.Errorf("cannot parse %s: %v", addr, err)
} }
peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask}) peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask})
} }
@ -264,6 +264,13 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
return err return err
} }
peerCfg.Endpoint = addr peerCfg.Endpoint = addr
case "PersistentKeepalive":
t, err := strconv.ParseInt(rhs, 10, 64)
if err != nil {
return err
}
dur := time.Duration(t * int64(time.Second))
peerCfg.PersistentKeepaliveInterval = &dur
default: default:
return fmt.Errorf("unknown directive %s", lhs) return fmt.Errorf("unknown directive %s", lhs)
} }

84
wg.go
View File

@ -3,17 +3,49 @@ package wgquick
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/mdlayher/wireguardctrl"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"syscall" "syscall"
"github.com/mdlayher/wireguardctrl"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
) )
const ( const (
defaultRoutingTable = 254 defaultRoutingTable = 254
// From linux/rtnetlink.h
RTPROT_UNSPEC = 0
RTPROT_REDIRECT = 1 /* Route installed by ICMP redirects; not used by current IPv4 */
RTPROT_KERNEL = 2 /* Route installed by kernel */
RTPROT_BOOT = 3 /* Route installed during boot */
RTPROT_STATIC = 4 /* Route installed by administrator */
/* Values of protocol >= RTPROT_STATIC are not interpreted by kernel;
they are just passed from user and back as is.
It will be used by hypothetical multiple routing daemons.
Note that protocol values should be standardized in order to
avoid conflicts.
*/
RTPROT_GATED = 8 /* Apparently, GateD */
RTPROT_RA = 9 /* RDISC/ND router advertisements */
RTPROT_MRT = 10 /* Merit MRT */
RTPROT_ZEBRA = 11 /* Zebra */
RTPROT_BIRD = 12 /* BIRD */
RTPROT_DNROUTED = 13 /* DECnet routing daemon */
RTPROT_XORP = 14 /* XORP */
RTPROT_NTK = 15 /* Netsukuku */
RTPROT_DHCP = 16 /* DHCP client */
RTPROT_MROUTED = 17 /* Multicast daemon */
RTPROT_BABEL = 42 /* Babel daemon */
RTPROT_BGP = 186 /* BGP Routes */
RTPROT_ISIS = 187 /* ISIS Routes */
RTPROT_OSPF = 188 /* OSPF Routes */
RTPROT_RIP = 189 /* RIP Routes */
RTPROT_EIGRP = 192 /* EIGRP Routes */
) )
// Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface` // Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface`
@ -198,19 +230,19 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
} }
// nil addr means I've used it // nil addr means I've used it
presentAddresses := make(map[string]*netlink.Addr, 0) presentAddresses := make(map[string]netlink.Addr, 0)
for _, addr := range addrs { for _, addr := range addrs {
log.WithFields(map[string]interface{}{ log.WithFields(map[string]interface{}{
"addr": addr.IPNet.String(), "addr": fmt.Sprint(addr.IPNet),
"label": addr.Label, "label": addr.Label,
}).Debugln("found existing address") }).Debugf("found existing address: %v", addr)
presentAddresses[addr.IPNet.String()] = &addr presentAddresses[addr.IPNet.String()] = addr
} }
for _, addr := range cfg.Address { for _, addr := range cfg.Address {
log := log.WithField("addr", addr) log := log.WithField("addr", addr.String())
_, present := presentAddresses[addr.String()] _, present := presentAddresses[addr.String()]
presentAddresses[addr.String()] = nil // mark as present presentAddresses[addr.String()] = netlink.Addr{} // mark as present
if present { if present {
log.Info("address present") log.Info("address present")
continue continue
@ -226,14 +258,14 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
} }
for _, addr := range presentAddresses { for _, addr := range presentAddresses {
if addr == nil { if addr.IPNet == nil {
continue continue
} }
log := log.WithFields(map[string]interface{}{ log := log.WithFields(map[string]interface{}{
"addr": addr.IPNet.String(), "addr": addr.IPNet.String(),
"label": addr.Label, "label": addr.Label,
}) })
if err := netlink.AddrDel(link, addr); err != nil { if err := netlink.AddrDel(link, &addr); err != nil {
log.WithError(err).Error("cannot delete addr") log.WithError(err).Error("cannot delete addr")
return err return err
} }
@ -250,7 +282,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
return err return err
} }
presentRoutes := make(map[string]*netlink.Route, 0) presentRoutes := make(map[string]netlink.Route, 0)
for _, r := range routes { for _, r := range routes {
log := log.WithFields(map[string]interface{}{ log := log.WithFields(map[string]interface{}{
"route": r.Dst.String(), "route": r.Dst.String(),
@ -258,21 +290,26 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
"table": r.Table, "table": r.Table,
"type": r.Type, "type": r.Type,
}) })
if r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable) { log.Debugf("detected existing route: %v", r)
presentRoutes[r.Dst.String()] = &r if !(r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable)) {
log.WithField("table", r.Table).Debug("detected existing route")
} else {
log.Debug("wrong table for route, skipping") log.Debug("wrong table for route, skipping")
continue
} }
presentRoutes[r.Dst.String()] = r
log.Debug("added route to consideration")
} }
for _, peer := range cfg.Peers { for _, peer := range cfg.Peers {
for _, rt := range peer.AllowedIPs { for _, rt := range peer.AllowedIPs {
_, present := presentRoutes[rt.String()]
presentRoutes[rt.String()] = nil // mark as visited
log := log.WithField("route", rt.String()) log := log.WithField("route", rt.String())
route, present := presentRoutes[rt.String()]
presentRoutes[rt.String()] = netlink.Route{} // mark as visited
if present { if present {
log.Info("route present") if route.Dst != nil && route.Protocol != cfg.RouteProtocol {
log.Warnf("route present; proto=%d != defined root proto=%d", route.Protocol, cfg.RouteProtocol)
} else {
log.Info("route present")
}
continue continue
} }
if err := netlink.RouteAdd(&netlink.Route{ if err := netlink.RouteAdd(&netlink.Route{
@ -290,7 +327,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
// Clean extra routes // Clean extra routes
for _, rt := range presentRoutes { for _, rt := range presentRoutes {
if rt == nil { // skip visited routes if rt.Dst == nil { // skip visited routes
continue continue
} }
log := log.WithFields(map[string]interface{}{ log := log.WithFields(map[string]interface{}{
@ -300,7 +337,12 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
"type": rt.Type, "type": rt.Type,
}) })
log.Info("extra manual route found") log.Info("extra manual route found")
if err := netlink.RouteDel(rt); err != nil { // RTPROT_BOOT is default one when other proto isn't defined
if !(rt.Protocol == cfg.RouteProtocol || rt.Protocol == RTPROT_BOOT && cfg.RouteProtocol == 0) {
log.Debug("skipping route deletion, not owned by this daemon")
continue
}
if err := netlink.RouteDel(&rt); err != nil {
log.WithError(err).Error("cannot setup route") log.WithError(err).Error("cannot setup route")
return err return err
} }