Added route protocol filtering

This commit is contained in:
Neven Miculinic 2019-03-29 10:48:33 +01:00
parent 8f048ca6d8
commit 9a0ad14ca6
6 changed files with 125 additions and 26 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
/build
# Binaries for programs and plugins
*.exe
*.exe~

29
.idea/watcherTasks.xml Normal file
View File

@ -0,0 +1,29 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectTasksOptions">
<TaskOptions isEnabled="true">
<option name="arguments" value="-w $FilePath$" />
<option name="checkSyntaxErrors" value="true" />
<option name="description" />
<option name="exitCodeBehavior" value="ERROR" />
<option name="fileExtension" value="go" />
<option name="immediateSync" value="false" />
<option name="name" value="goimports" />
<option name="output" value="$FilePath$" />
<option name="outputFilters">
<array />
</option>
<option name="outputFromStdout" value="false" />
<option name="program" value="goimports" />
<option name="runOnExternalChanges" value="false" />
<option name="scopeName" value="Project Files" />
<option name="trackOnlyRoot" value="true" />
<option name="workingDir" value="$ProjectFileDir$" />
<envs>
<env name="GOROOT" value="$GOROOT$" />
<env name="GOPATH" value="$GOPATH$" />
<env name="PATH" value="$GoBinDirs$" />
</envs>
</TaskOptions>
</component>
</project>

View File

@ -1,5 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="Go">
<buildTags>
<option name="os" value="linux" />
</buildTags>
</component>
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />

View File

@ -10,18 +10,25 @@ import (
)
func printHelp() {
fmt.Println("wg-quick [-iface=wg0] [ up | down ] config_file")
fmt.Print("wg-quick [flags] [ up | down | sync ] [ config_file | interface ]\n\n")
flag.Usage()
os.Exit(1)
}
func main() {
flag.String("iface", "", "interface")
verbose := flag.Bool("v", false, "verbose")
protocol := flag.Int("route-protocol", 0, "route protocol to use for our routes")
flag.Parse()
args := flag.Args()
if len(args) != 2 {
printHelp()
}
if *verbose {
logrus.SetLevel(logrus.DebugLevel)
}
iface := flag.Lookup("iface").Value.String()
log := logrus.WithField("iface", iface)
@ -55,6 +62,8 @@ func main() {
logrus.WithError(err).Fatalln("cannot parse config file")
}
c.RouteProtocol = *protocol
switch args[0] {
case "up":
if err := wgquick.Up(c, iface, log); err != nil {
@ -64,5 +73,11 @@ func main() {
if err := wgquick.Down(c, iface, log); err != nil {
logrus.WithError(err).Errorln("cannot down interface")
}
case "sync":
if err := wgquick.Sync(c, iface, log); err != nil {
logrus.WithError(err).Errorln("cannot sync interface")
}
default:
printHelp()
}
}

View File

@ -157,14 +157,14 @@ func (cfg *Config) UnmarshalText(text []byte) error {
switch state {
case inter:
if err := parseInterfaceLine(cfg, lhs, rhs); err != nil {
return fmt.Errorf("[line %d]: %v", no, err)
return fmt.Errorf("[line %d]: %v", no+1, err)
}
case peer:
if err := parsePeerLine(peerCfg, lhs, rhs); err != nil {
return fmt.Errorf("[line %d]: %v", no, err)
return fmt.Errorf("[line %d]: %v", no+1, err)
}
default:
return fmt.Errorf("cannot parse line %d, unknown state", no)
return fmt.Errorf("[line %d] cannot parse, unknown state", no+1)
}
}
}
@ -254,7 +254,7 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
for _, addr := range strings.Split(rhs, ",") {
ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
if err != nil {
return fmt.Errorf("%v", err)
return fmt.Errorf("cannot parse %s: %v", addr, err)
}
peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask})
}
@ -264,6 +264,13 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
return err
}
peerCfg.Endpoint = addr
case "PersistentKeepalive":
t, err := strconv.ParseInt(rhs, 10, 64)
if err != nil {
return err
}
dur := time.Duration(t * int64(time.Second))
peerCfg.PersistentKeepaliveInterval = &dur
default:
return fmt.Errorf("unknown directive %s", lhs)
}

84
wg.go
View File

@ -3,17 +3,49 @@ package wgquick
import (
"bytes"
"fmt"
"github.com/mdlayher/wireguardctrl"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"os"
"os/exec"
"strings"
"syscall"
"github.com/mdlayher/wireguardctrl"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
)
const (
defaultRoutingTable = 254
// From linux/rtnetlink.h
RTPROT_UNSPEC = 0
RTPROT_REDIRECT = 1 /* Route installed by ICMP redirects; not used by current IPv4 */
RTPROT_KERNEL = 2 /* Route installed by kernel */
RTPROT_BOOT = 3 /* Route installed during boot */
RTPROT_STATIC = 4 /* Route installed by administrator */
/* Values of protocol >= RTPROT_STATIC are not interpreted by kernel;
they are just passed from user and back as is.
It will be used by hypothetical multiple routing daemons.
Note that protocol values should be standardized in order to
avoid conflicts.
*/
RTPROT_GATED = 8 /* Apparently, GateD */
RTPROT_RA = 9 /* RDISC/ND router advertisements */
RTPROT_MRT = 10 /* Merit MRT */
RTPROT_ZEBRA = 11 /* Zebra */
RTPROT_BIRD = 12 /* BIRD */
RTPROT_DNROUTED = 13 /* DECnet routing daemon */
RTPROT_XORP = 14 /* XORP */
RTPROT_NTK = 15 /* Netsukuku */
RTPROT_DHCP = 16 /* DHCP client */
RTPROT_MROUTED = 17 /* Multicast daemon */
RTPROT_BABEL = 42 /* Babel daemon */
RTPROT_BGP = 186 /* BGP Routes */
RTPROT_ISIS = 187 /* ISIS Routes */
RTPROT_OSPF = 188 /* OSPF Routes */
RTPROT_RIP = 189 /* RIP Routes */
RTPROT_EIGRP = 192 /* EIGRP Routes */
)
// Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface`
@ -198,19 +230,19 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
}
// nil addr means I've used it
presentAddresses := make(map[string]*netlink.Addr, 0)
presentAddresses := make(map[string]netlink.Addr, 0)
for _, addr := range addrs {
log.WithFields(map[string]interface{}{
"addr": addr.IPNet.String(),
"addr": fmt.Sprint(addr.IPNet),
"label": addr.Label,
}).Debugln("found existing address")
presentAddresses[addr.IPNet.String()] = &addr
}).Debugf("found existing address: %v", addr)
presentAddresses[addr.IPNet.String()] = addr
}
for _, addr := range cfg.Address {
log := log.WithField("addr", addr)
log := log.WithField("addr", addr.String())
_, present := presentAddresses[addr.String()]
presentAddresses[addr.String()] = nil // mark as present
presentAddresses[addr.String()] = netlink.Addr{} // mark as present
if present {
log.Info("address present")
continue
@ -226,14 +258,14 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
}
for _, addr := range presentAddresses {
if addr == nil {
if addr.IPNet == nil {
continue
}
log := log.WithFields(map[string]interface{}{
"addr": addr.IPNet.String(),
"label": addr.Label,
})
if err := netlink.AddrDel(link, addr); err != nil {
if err := netlink.AddrDel(link, &addr); err != nil {
log.WithError(err).Error("cannot delete addr")
return err
}
@ -250,7 +282,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
return err
}
presentRoutes := make(map[string]*netlink.Route, 0)
presentRoutes := make(map[string]netlink.Route, 0)
for _, r := range routes {
log := log.WithFields(map[string]interface{}{
"route": r.Dst.String(),
@ -258,21 +290,26 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
"table": r.Table,
"type": r.Type,
})
if r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable) {
presentRoutes[r.Dst.String()] = &r
log.WithField("table", r.Table).Debug("detected existing route")
} else {
log.Debugf("detected existing route: %v", r)
if !(r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable)) {
log.Debug("wrong table for route, skipping")
continue
}
presentRoutes[r.Dst.String()] = r
log.Debug("added route to consideration")
}
for _, peer := range cfg.Peers {
for _, rt := range peer.AllowedIPs {
_, present := presentRoutes[rt.String()]
presentRoutes[rt.String()] = nil // mark as visited
log := log.WithField("route", rt.String())
route, present := presentRoutes[rt.String()]
presentRoutes[rt.String()] = netlink.Route{} // mark as visited
if present {
log.Info("route present")
if route.Dst != nil && route.Protocol != cfg.RouteProtocol {
log.Warnf("route present; proto=%d != defined root proto=%d", route.Protocol, cfg.RouteProtocol)
} else {
log.Info("route present")
}
continue
}
if err := netlink.RouteAdd(&netlink.Route{
@ -290,7 +327,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
// Clean extra routes
for _, rt := range presentRoutes {
if rt == nil { // skip visited routes
if rt.Dst == nil { // skip visited routes
continue
}
log := log.WithFields(map[string]interface{}{
@ -300,7 +337,12 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
"type": rt.Type,
})
log.Info("extra manual route found")
if err := netlink.RouteDel(rt); err != nil {
// RTPROT_BOOT is default one when other proto isn't defined
if !(rt.Protocol == cfg.RouteProtocol || rt.Protocol == RTPROT_BOOT && cfg.RouteProtocol == 0) {
log.Debug("skipping route deletion, not owned by this daemon")
continue
}
if err := netlink.RouteDel(&rt); err != nil {
log.WithError(err).Error("cannot setup route")
return err
}