wg-quicker/wg.go
2019-03-28 14:22:30 +01:00

292 lines
7.4 KiB
Go

package wgquick
import (
"bytes"
"fmt"
"github.com/mdlayher/wireguardctrl"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"os"
"os/exec"
"strings"
"syscall"
)
// defaultRoutingTable is the ID of the kernel's main routing table
// (RT_TABLE_MAIN). SyncRoutes treats a configured Table of 0 as this table.
const defaultRoutingTable = 254
// Up creates and configures the WireGuard interface iface from cfg. It is
// mostly equivalent to `wg-quick up iface`.
//
// It returns os.ErrExist if a link named iface already exists. Otherwise it
// registers cfg.DNS with resolvconf (when non-empty), runs cfg.PreUp, syncs
// the link, WireGuard device, addresses and routes via Sync, and finally runs
// cfg.PostUp. The first failing step aborts Up and its error is returned.
func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
	log := logger.WithField("iface", iface)
	_, err := netlink.LinkByName(iface)
	if err == nil {
		return os.ErrExist
	}
	if _, ok := err.(netlink.LinkNotFoundError); !ok {
		return err
	}
	if len(cfg.DNS) > 0 {
		// `resolvconf -a` replaces the whole record for the interface, so all
		// nameservers must be supplied in a single invocation (as wg-quick
		// does); one call per server would leave only the last one in place.
		nameservers := make([]string, 0, len(cfg.DNS))
		for _, dns := range cfg.DNS {
			nameservers = append(nameservers, fmt.Sprintf("nameserver %s\n", dns))
		}
		if err := execSh("resolvconf -a tun.%i -m 0 -x", iface, log, nameservers...); err != nil {
			return err
		}
	}
	if cfg.PreUp != "" {
		if err := execSh(cfg.PreUp, iface, log); err != nil {
			return err
		}
		log.Infoln("applied pre-up command")
	}
	if err := Sync(cfg, iface, logger); err != nil {
		return err
	}
	if cfg.PostUp != "" {
		if err := execSh(cfg.PostUp, iface, log); err != nil {
			return err
		}
		log.Infoln("applied post-up command")
	}
	return nil
}
// Down tears down the WireGuard interface iface. It is mostly equivalent to
// `wg-quick down iface`.
//
// It removes the resolvconf record registered by Up (when cfg.DNS is
// non-empty), runs cfg.PreDown, deletes the link and then runs cfg.PostDown.
// An error is returned if the link does not exist or any step fails.
func Down(cfg *Config, iface string, logger logrus.FieldLogger) error {
	log := logger.WithField("iface", iface)
	link, err := netlink.LinkByName(iface)
	if err != nil {
		return err
	}
	// Up registers DNS whenever at least one server is configured, so the
	// record must be removed in that same case. The record name has to match
	// Up's "tun.%i"; execSh only expands %i, not %s.
	if len(cfg.DNS) > 0 {
		if err := execSh("resolvconf -d tun.%i", iface, log); err != nil {
			return err
		}
	}
	if cfg.PreDown != "" {
		if err := execSh(cfg.PreDown, iface, log); err != nil {
			return err
		}
		log.Infoln("applied pre-down command")
	}
	if err := netlink.LinkDel(link); err != nil {
		return err
	}
	log.Infoln("link deleted")
	if cfg.PostDown != "" {
		if err := execSh(cfg.PostDown, iface, log); err != nil {
			return err
		}
		log.Infoln("applied post-down command")
	}
	return nil
}
// execSh runs command via `sh -ce` after replacing every "%i" in it with
// iface. Any stdin fragments are concatenated in order and fed to the
// command's standard input. The combined stdout/stderr is logged, and the
// error from running the command (if any) is returned unchanged.
func execSh(command string, iface string, log logrus.FieldLogger, stdin ...string) error {
	script := strings.ReplaceAll(command, "%i", iface)
	cmd := exec.Command("sh", "-ce", script)
	if len(stdin) != 0 {
		input := strings.Join(stdin, "")
		log = log.WithField("stdin", input)
		cmd.Stdin = bytes.NewBufferString(input)
	}
	output, runErr := cmd.CombinedOutput()
	if runErr == nil {
		log.Infof("executed %s:\n%s", cmd.Args, output)
		return nil
	}
	log.WithError(runErr).Errorf("failed to execute %s:\n%s", cmd.Args, output)
	return runErr
}
// Sync brings the interface iface in line with cfg. It makes sure the link
// exists and is up (SyncLink), applies the WireGuard device configuration
// (SyncWireguardDevice), and reconciles IPv4 addresses (SyncAddress) and
// routes (SyncRoutes). The first failing step aborts the sync, and its error
// is returned unchanged.
func Sync(cfg *Config, iface string, logger logrus.FieldLogger) error {
	log := logger.WithField("iface", iface)
	link, err := SyncLink(cfg, iface, log)
	if err != nil {
		log.WithError(err).Errorln("cannot sync wireguard link")
		return err
	}
	log.Info("synced link")
	if err := SyncWireguardDevice(cfg, link, log); err != nil {
		log.WithError(err).Errorln("cannot sync wireguard device")
		return err
	}
	log.Info("synced wireguard device")
	if err := SyncAddress(cfg, link, log); err != nil {
		log.WithError(err).Errorln("cannot sync addresses")
		return err
	}
	log.Info("synced addresses")
	if err := SyncRoutes(cfg, link, log); err != nil {
		log.WithError(err).Errorln("cannot sync routes")
		return err
	}
	log.Info("synced routes")
	log.Info("Successfully synced device")
	return nil
}
// SyncWireguardDevice applies cfg.Config to the WireGuard device named by
// link, using a wireguardctrl client. It does not set routes/addresses beyond
// WireGuard's internal crypto-key routing.
func SyncWireguardDevice(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
	cl, err := wireguardctrl.New()
	if err != nil {
		log.WithError(err).Errorln("cannot setup wireguard device")
		return err
	}
	// Release the client's resources on every return path. A close failure
	// does not affect the already-applied configuration, so it is only logged.
	defer func() {
		if err := cl.Close(); err != nil {
			log.WithError(err).Warn("cannot close wireguard client")
		}
	}()
	if err := cl.ConfigureDevice(link.Attrs().Name, cfg.Config); err != nil {
		log.WithError(err).Error("cannot configure device")
		return err
	}
	return nil
}
// SyncLink makes sure a WireGuard link named iface exists and is up, and
// returns it. A missing link is created with cfg.MTU. An existing link must
// already be of type "wireguard", otherwise an error is returned; its MTU is
// updated when cfg.MTU is set (> 0) and differs. WireGuard settings,
// addresses and routes are not touched.
func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link, error) {
	link, err := netlink.LinkByName(iface)
	if err != nil {
		if _, ok := err.(netlink.LinkNotFoundError); !ok {
			log.WithError(err).Error("cannot read link")
			return nil, err
		}
		log.Info("link not found, creating")
		wgLink := &netlink.GenericLink{
			LinkAttrs: netlink.LinkAttrs{
				Name: iface,
				MTU:  cfg.MTU,
			},
			LinkType: "wireguard",
		}
		if err := netlink.LinkAdd(wgLink); err != nil {
			log.WithError(err).Error("cannot create link")
			return nil, err
		}
		link, err = netlink.LinkByName(iface)
		if err != nil {
			log.WithError(err).Error("cannot read link")
			return nil, err
		}
	}
	// A pre-existing link of another kind cannot be configured as a
	// WireGuard device; fail here with a clear message.
	if typ := link.Type(); typ != "wireguard" {
		err := fmt.Errorf("link %s has type %q, expected wireguard", iface, typ)
		log.WithError(err).Error("cannot use existing link")
		return nil, err
	}
	// The MTU is only set at creation time above; keep an existing link in
	// sync with the config as well.
	if cfg.MTU > 0 && link.Attrs().MTU != cfg.MTU {
		if err := netlink.LinkSetMTU(link, cfg.MTU); err != nil {
			log.WithError(err).Error("cannot set link mtu")
			return nil, err
		}
		log.WithField("mtu", cfg.MTU).Info("set link mtu")
	}
	if err := netlink.LinkSetUp(link); err != nil {
		log.WithError(err).Error("cannot set link up")
		return nil, err
	}
	log.Info("set device up")
	return link, nil
}
// SyncAddress reconciles the IPv4 addresses assigned to link with
// cfg.Address. Configured addresses that are missing are added, and every
// other IPv4 address on the link is deleted. The first netlink failure aborts
// the sync and is returned.
func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
	addrs, err := netlink.AddrList(link, syscall.AF_INET)
	if err != nil {
		log.WithError(err).Error("cannot read link address")
		return err
	}
	// Keyed by CIDR string. A nil value marks an address that also appears in
	// the config and must therefore be kept.
	presentAddresses := make(map[string]*netlink.Addr, len(addrs))
	for i := range addrs {
		// Point into the slice rather than at the range variable: before
		// Go 1.22 &addr would alias one variable shared by all iterations,
		// so every entry would refer to the last address.
		presentAddresses[addrs[i].IPNet.String()] = &addrs[i]
	}
	for _, addr := range cfg.Address {
		log := log.WithField("addr", addr)
		key := addr.String()
		_, present := presentAddresses[key]
		presentAddresses[key] = nil // mark as wanted
		if present {
			log.Info("address present")
			continue
		}
		// &addr is consumed synchronously by AddrAdd, so aliasing is harmless here.
		if err := netlink.AddrAdd(link, &netlink.Addr{
			IPNet: &addr,
		}); err != nil {
			log.WithError(err).Error("cannot add addr")
			return err
		}
		log.Info("address added")
	}
	for _, addr := range presentAddresses {
		if addr == nil {
			continue
		}
		log := log.WithField("addr", addr.IPNet.String())
		if err := netlink.AddrDel(link, addr); err != nil {
			log.WithError(err).Error("cannot delete addr")
			return err
		}
		log.Info("addr deleted")
	}
	return nil
}
// SyncRoutes reconciles the IPv4 routes through link with the AllowedIPs of
// every peer in cfg. Only routes in the configured table (cfg.Table, or the
// main table when it is 0) are considered. Missing routes are added and any
// other route through link in that table is deleted. The first netlink
// failure aborts the sync and is returned.
func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
	table := cfg.Table
	if table == 0 {
		table = defaultRoutingTable
	}
	// netlink.RouteList only reports main-table routes. Filter explicitly by
	// output link AND table so routes in a custom table are seen as present
	// instead of being re-added (and failing) on every sync.
	routes, err := netlink.RouteListFiltered(syscall.AF_INET, &netlink.Route{
		LinkIndex: link.Attrs().Index,
		Table:     table,
	}, netlink.RT_FILTER_OIF|netlink.RT_FILTER_TABLE)
	if err != nil {
		log.WithError(err).Error("cannot read existing routes")
		return err
	}
	// Keyed by destination CIDR. A nil value marks a route that is also
	// required by the config and must therefore be kept.
	presentRoutes := make(map[string]*netlink.Route, len(routes))
	for i := range routes {
		// Point into the slice rather than at a range variable, which before
		// Go 1.22 is shared by all iterations.
		r := &routes[i]
		// A default route may be reported without a destination; key it the
		// way an AllowedIPs entry of 0.0.0.0/0 is formatted so both match.
		dst := "0.0.0.0/0"
		if r.Dst != nil {
			dst = r.Dst.String()
		}
		log := log.WithField("route", dst)
		if r.Table != table {
			log.Debug("wrong table for route, skipping")
			continue
		}
		presentRoutes[dst] = r
		log.WithField("table", r.Table).Debug("detected existing route")
	}
	for _, peer := range cfg.Peers {
		for _, rt := range peer.AllowedIPs {
			key := rt.String()
			_, present := presentRoutes[key]
			presentRoutes[key] = nil // mark as visited
			log := log.WithField("route", key)
			if present {
				log.Info("route present")
				continue
			}
			// &rt is consumed synchronously by RouteAdd, so aliasing is harmless here.
			if err := netlink.RouteAdd(&netlink.Route{
				LinkIndex: link.Attrs().Index,
				Dst:       &rt,
				Table:     cfg.Table,
			}); err != nil {
				log.WithError(err).Error("cannot setup route")
				return err
			}
			log.Info("route added")
		}
	}
	// Clean extra routes
	for key, rt := range presentRoutes {
		if rt == nil { // skip visited routes
			continue
		}
		log := log.WithField("route", key)
		log.Info("extra manual route found")
		if err := netlink.RouteDel(rt); err != nil {
			log.WithError(err).Error("cannot delete route")
			return err
		}
		log.Info("route deleted")
	}
	return nil
}