diff --git a/.gitignore b/.gitignore index ecfd606..82d5309 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +/build # Binaries for programs and plugins *.exe *.exe~ diff --git a/.idea/watcherTasks.xml b/.idea/watcherTasks.xml new file mode 100644 index 0000000..5243a75 --- /dev/null +++ b/.idea/watcherTasks.xml @@ -0,0 +1,29 @@ + + + + + + + + diff --git a/.idea/wg-quick-go.iml b/.idea/wg-quick-go.iml index 9ef4a9b..6972aef 100644 --- a/.idea/wg-quick-go.iml +++ b/.idea/wg-quick-go.iml @@ -1,5 +1,10 @@ + + + + diff --git a/cmd/wg-quick/main.go b/cmd/wg-quick/main.go index 4a63e2f..ab764c0 100644 --- a/cmd/wg-quick/main.go +++ b/cmd/wg-quick/main.go @@ -10,18 +10,25 @@ import ( ) func printHelp() { - fmt.Println("wg-quick [-iface=wg0] [ up | down ] config_file") + fmt.Print("wg-quick [flags] [ up | down | sync ] [ config_file | interface ]\n\n") + flag.Usage() os.Exit(1) } func main() { flag.String("iface", "", "interface") + verbose := flag.Bool("v", false, "verbose") + protocol := flag.Int("route-protocol", 0, "route protocol to use for our routes") flag.Parse() args := flag.Args() if len(args) != 2 { printHelp() } + if *verbose { + logrus.SetLevel(logrus.DebugLevel) + } + iface := flag.Lookup("iface").Value.String() log := logrus.WithField("iface", iface) @@ -55,6 +62,8 @@ func main() { logrus.WithError(err).Fatalln("cannot parse config file") } + c.RouteProtocol = *protocol + switch args[0] { case "up": if err := wgquick.Up(c, iface, log); err != nil { @@ -64,5 +73,11 @@ func main() { if err := wgquick.Down(c, iface, log); err != nil { logrus.WithError(err).Errorln("cannot down interface") } + case "sync": + if err := wgquick.Sync(c, iface, log); err != nil { + logrus.WithError(err).Errorln("cannot sync interface") + } + default: + printHelp() } } diff --git a/config.go b/config.go index 65f395a..9cbb353 100644 --- a/config.go +++ b/config.go @@ -157,14 +157,14 @@ func (cfg *Config) UnmarshalText(text []byte) error { switch state { case inter: if err := parseInterfaceLine(cfg, lhs, rhs); err != nil { - return fmt.Errorf("[line %d]: %v", no, err) + return fmt.Errorf("[line %d]: %v", no+1, err) } case peer: if err := parsePeerLine(peerCfg, lhs, rhs); err != nil { - return fmt.Errorf("[line %d]: %v", no, err) + return fmt.Errorf("[line %d]: %v", no+1, err) } default: - return fmt.Errorf("cannot parse line %d, unknown state", no) + return fmt.Errorf("[line %d] cannot parse, unknown state", no+1) } } } @@ -254,7 +254,7 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error { for _, addr := range strings.Split(rhs, ",") { ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr)) if err != nil { - return fmt.Errorf("%v", err) + return fmt.Errorf("cannot parse %s: %v", addr, err) } peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask}) } @@ -264,6 +264,13 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error { return err } peerCfg.Endpoint = addr + case "PersistentKeepalive": + t, err := strconv.ParseInt(rhs, 10, 64) + if err != nil { + return err + } + dur := time.Duration(t * int64(time.Second)) + peerCfg.PersistentKeepaliveInterval = &dur default: return fmt.Errorf("unknown directive %s", lhs) } diff --git a/wg.go b/wg.go index 135b677..cbf6aa4 100644 --- a/wg.go +++ b/wg.go @@ -3,17 +3,49 @@ package wgquick import ( "bytes" "fmt" - "github.com/mdlayher/wireguardctrl" - "github.com/sirupsen/logrus" - "github.com/vishvananda/netlink" "os" "os/exec" "strings" "syscall" + + "github.com/mdlayher/wireguardctrl" + "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" ) const ( defaultRoutingTable = 254 + + // From linux/rtnetlink.h + RTPROT_UNSPEC = 0 + RTPROT_REDIRECT = 1 /* Route installed by ICMP redirects; not used by current IPv4 */ + RTPROT_KERNEL = 2 /* Route installed by kernel */ + RTPROT_BOOT = 3 /* Route installed during boot */ + RTPROT_STATIC = 4 /* Route installed by administrator */ + + /* Values of protocol >= RTPROT_STATIC are not interpreted by kernel; + they are just passed from user and back as is. + It will be used by hypothetical multiple routing daemons. + Note that protocol values should be standardized in order to + avoid conflicts. + */ + + RTPROT_GATED = 8 /* Apparently, GateD */ + RTPROT_RA = 9 /* RDISC/ND router advertisements */ + RTPROT_MRT = 10 /* Merit MRT */ + RTPROT_ZEBRA = 11 /* Zebra */ + RTPROT_BIRD = 12 /* BIRD */ + RTPROT_DNROUTED = 13 /* DECnet routing daemon */ + RTPROT_XORP = 14 /* XORP */ + RTPROT_NTK = 15 /* Netsukuku */ + RTPROT_DHCP = 16 /* DHCP client */ + RTPROT_MROUTED = 17 /* Multicast daemon */ + RTPROT_BABEL = 42 /* Babel daemon */ + RTPROT_BGP = 186 /* BGP Routes */ + RTPROT_ISIS = 187 /* ISIS Routes */ + RTPROT_OSPF = 188 /* OSPF Routes */ + RTPROT_RIP = 189 /* RIP Routes */ + RTPROT_EIGRP = 192 /* EIGRP Routes */ ) // Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface` @@ -198,19 +230,19 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { } // nil addr means I've used it - presentAddresses := make(map[string]*netlink.Addr, 0) + presentAddresses := make(map[string]netlink.Addr, 0) for _, addr := range addrs { log.WithFields(map[string]interface{}{ - "addr": addr.IPNet.String(), + "addr": fmt.Sprint(addr.IPNet), "label": addr.Label, - }).Debugln("found existing address") - presentAddresses[addr.IPNet.String()] = &addr + }).Debugf("found existing address: %v", addr) + presentAddresses[addr.IPNet.String()] = addr } for _, addr := range cfg.Address { - log := log.WithField("addr", addr) + log := log.WithField("addr", addr.String()) _, present := presentAddresses[addr.String()] - presentAddresses[addr.String()] = nil // mark as present + presentAddresses[addr.String()] = netlink.Addr{} // mark as present if present { log.Info("address present") continue @@ -226,14 +258,14 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { } for _, addr := range presentAddresses { - if addr == nil { + if addr.IPNet == nil { continue } log := log.WithFields(map[string]interface{}{ "addr": addr.IPNet.String(), "label": addr.Label, }) - if err := netlink.AddrDel(link, addr); err != nil { + if err := netlink.AddrDel(link, &addr); err != nil { log.WithError(err).Error("cannot delete addr") return err } @@ -250,7 +282,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { return err } - presentRoutes := make(map[string]*netlink.Route, 0) + presentRoutes := make(map[string]netlink.Route, 0) for _, r := range routes { log := log.WithFields(map[string]interface{}{ "route": r.Dst.String(), @@ -258,21 +290,26 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { "table": r.Table, "type": r.Type, }) - if r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable) { - presentRoutes[r.Dst.String()] = &r - log.WithField("table", r.Table).Debug("detected existing route") - } else { + log.Debugf("detected existing route: %v", r) + if !(r.Table == cfg.Table || (cfg.Table == 0 && r.Table == defaultRoutingTable)) { log.Debug("wrong table for route, skipping") + continue } + presentRoutes[r.Dst.String()] = r + log.Debug("added route to consideration") } for _, peer := range cfg.Peers { for _, rt := range peer.AllowedIPs { - _, present := presentRoutes[rt.String()] - presentRoutes[rt.String()] = nil // mark as visited log := log.WithField("route", rt.String()) + route, present := presentRoutes[rt.String()] + presentRoutes[rt.String()] = netlink.Route{} // mark as visited if present { - log.Info("route present") + if route.Dst != nil && route.Protocol != cfg.RouteProtocol { + log.Warnf("route present; proto=%d != defined root proto=%d", route.Protocol, cfg.RouteProtocol) + } else { + log.Info("route present") + } continue } if err := netlink.RouteAdd(&netlink.Route{ @@ -290,7 +327,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { // Clean extra routes for _, rt := range presentRoutes { - if rt == nil { // skip visited routes + if rt.Dst == nil { // skip visited routes continue } log := log.WithFields(map[string]interface{}{ @@ -300,7 +337,12 @@ func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { "type": rt.Type, }) log.Info("extra manual route found") - if err := netlink.RouteDel(rt); err != nil { + // RTPROT_BOOT is default one when other proto isn't defined + if !(rt.Protocol == cfg.RouteProtocol || rt.Protocol == RTPROT_BOOT && cfg.RouteProtocol == 0) { + log.Debug("skipping route deletion, not owned by this daemon") + continue + } + if err := netlink.RouteDel(&rt); err != nil { log.WithError(err).Error("cannot setup route") return err }