wg-quicker/wgquick/wg.go

469 lines
11 KiB
Go
Raw Normal View History

2021-03-08 09:32:33 +01:00
// nolint: errorlint, cyclop
2019-03-27 17:26:09 +01:00
package wgquick
2019-03-26 13:44:38 +01:00
import (
"bytes"
2021-01-21 10:25:04 +01:00
"errors"
"fmt"
"net"
"os"
"os/exec"
"strings"
2019-03-27 17:26:09 +01:00
"syscall"
"time"
2019-03-29 10:48:33 +01:00
"github.com/getlantern/byteexec"
2019-03-29 10:48:33 +01:00
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"go.xsfx.dev/wg-quicker/assets"
"golang.org/x/sys/unix"
2019-10-30 18:04:10 +01:00
"golang.zx2c4.com/wireguard/wgctrl"
2019-03-27 17:40:06 +01:00
)
2021-01-25 11:06:27 +01:00
// userspace creates the interface iface by launching the wireguard-go binary
// embedded in the assets package (extracted to disk through byteexec).
//
// The process is only started, never waited for, so the interface is not
// guaranteed to exist yet when this returns and the child is not reaped here.
// NOTE(review): a failure of wireguard-go after it has started is not reported.
func userspace(iface string) error {
	binary, err := assets.Asset("third_party/wireguard-go/wireguard-go")
	if err != nil {
		return fmt.Errorf("cannot get wireguard-go: %w", err)
	}

	executable, err := byteexec.New(binary, "wireguard-go")
	if err != nil {
		return fmt.Errorf("unable to build byteexec for wireguard-go: %w", err)
	}

	if err = executable.Command(iface).Start(); err != nil {
		return fmt.Errorf("could not start wireguard-go: %w", err)
	}

	return nil
}
2021-01-21 10:25:04 +01:00
// Up creates and configures the wg interface. Mostly equivalent to `wg-quick up iface`.
//
// It returns os.ErrExist when a link named iface already exists; any other lookup
// error is returned wrapped. Otherwise it runs, in order:
//   - the PreUp hook,
//   - Sync (link, wireguard device, addresses and routes),
//   - DNS registration via resolvconf under the record "tun.<iface>",
//   - the PostUp hook.
//
// The order mirrors wg-quick: DNS is only registered once the interface exists,
// so a failed Sync does not leave a stale resolver record behind.
// Hook commands have every "%i" replaced by iface.
func Up(cfg *Config, iface string, uspace bool, logger logrus.FieldLogger) error {
	log := logger.WithField("iface", iface)

	_, err := netlink.LinkByName(iface)
	if err == nil {
		return os.ErrExist
	}

	if _, ok := err.(netlink.LinkNotFoundError); !ok {
		return fmt.Errorf("%w", err)
	}

	if cfg.PreUp != "" {
		if err := execSh(cfg.PreUp, iface, log); err != nil {
			return err
		}

		log.Infoln("applied pre-up command")
	}

	if err := Sync(cfg, iface, uspace, logger); err != nil {
		return err
	}

	if len(cfg.DNS) > 0 {
		// `resolvconf -a` replaces the whole record for the interface, so every
		// nameserver has to be fed in a single invocation (as wg-quick does);
		// one call per server would keep only the last one.
		nameservers := make([]string, 0, len(cfg.DNS))
		for _, dns := range cfg.DNS {
			nameservers = append(nameservers, fmt.Sprintf("nameserver %s\n", dns))
		}

		if err := execSh("resolvconf -a tun.%i -m 0 -x", iface, log, nameservers...); err != nil {
			return err
		}
	}

	if cfg.PostUp != "" {
		if err := execSh(cfg.PostUp, iface, log); err != nil {
			return err
		}

		log.Infoln("applied post-up command")
	}

	return nil
}
2021-01-21 10:25:04 +01:00
// Down destroys the wg interface. Mostly equivalent to `wg-quick down iface`.
//
// It fails (wrapped netlink error) when iface does not exist. Otherwise it runs, in order:
//   - removal of the resolvconf record "tun.<iface>" when DNS servers are configured,
//   - the PreDown hook,
//   - deletion of the link,
//   - the PostDown hook.
//
// Hook commands have every "%i" replaced by iface.
func Down(cfg *Config, iface string, logger logrus.FieldLogger) error {
	log := logger.WithField("iface", iface)

	link, err := netlink.LinkByName(iface)
	if err != nil {
		return fmt.Errorf("%w", err)
	}

	// A record is registered whenever at least one DNS server is configured, so it
	// must be removed in the single-server case too. execSh only expands "%i".
	if len(cfg.DNS) > 0 {
		if err := execSh("resolvconf -d tun.%i", iface, log); err != nil {
			return err
		}
	}

	if cfg.PreDown != "" {
		if err := execSh(cfg.PreDown, iface, log); err != nil {
			return err
		}

		log.Infoln("applied pre-down command")
	}

	if err := netlink.LinkDel(link); err != nil {
		return fmt.Errorf("%w", err)
	}

	log.Infoln("link deleted")

	if cfg.PostDown != "" {
		if err := execSh(cfg.PostDown, iface, log); err != nil {
			return err
		}

		log.Infoln("applied post-down command")
	}

	return nil
}
// execSh runs command through `sh -ce` after replacing every "%i" with iface,
// mirroring wg-quick's hook expansion. Any stdin strings are concatenated and
// fed to the command's standard input.
//
// The combined stdout/stderr is logged in both the success and failure case.
// On failure the returned error names the executed command and wraps the
// underlying error from exec (e.g. *exec.ExitError).
//
// NOTE(review): iface is substituted into the shell string unquoted, exactly like
// wg-quick does; callers must not pass untrusted interface names.
func execSh(command string, iface string, log logrus.FieldLogger, stdin ...string) error {
	cmd := exec.Command("sh", "-ce", strings.ReplaceAll(command, "%i", iface)) // nolint: gosec

	if len(stdin) > 0 {
		input := strings.Join(stdin, "")
		log = log.WithField("stdin", input)
		cmd.Stdin = bytes.NewBufferString(input)
	}

	out, err := cmd.CombinedOutput()
	if err != nil {
		log.WithError(err).Errorf("failed to execute %s:\n%s", cmd.Args, out)

		return fmt.Errorf("executing %q: %w", cmd.Args, err)
	}

	log.Infof("executed %s:\n%s", cmd.Args, out)

	return nil
}
2021-01-21 10:25:04 +01:00
// Sync the config to the current setup for given interface.
// It performs 4 operations, in this order:
// * SyncLink --> makes sure link exists, is up and type wireguard.
// * SyncWireguardDevice --> configures allowedIP & other wireguard specific settings.
// * SyncAddress --> syncs linux addresses bound to this interface.
// * SyncRoutes --> syncs all allowedIP routes to route to this interface.
// The first failing step aborts the sync; its error is logged and returned wrapped.
func Sync(cfg *Config, iface string, uspace bool, logger logrus.FieldLogger) error {
	log := logger.WithField("iface", iface)

	link, err := SyncLink(cfg, iface, uspace, log)
	if err != nil {
		log.WithError(err).Errorln("cannot sync wireguard link")

		return fmt.Errorf("%w", err)
	}

	log.Info("synced link")

	if err := SyncWireguardDevice(cfg, link, log); err != nil {
		log.WithError(err).Errorln("cannot sync wireguard device")

		return fmt.Errorf("%w", err)
	}

	log.Info("synced wireguard device")

	if err := SyncAddress(cfg, link, log); err != nil {
		log.WithError(err).Errorln("cannot sync addresses")

		return fmt.Errorf("%w", err)
	}

	log.Info("synced addresses")

	// Every peer's AllowedIPs gets a route through this interface.
	managedRoutes := make([]net.IPNet, 0, len(cfg.Peers))
	for _, peer := range cfg.Peers {
		managedRoutes = append(managedRoutes, peer.AllowedIPs...)
	}

	if err := SyncRoutes(cfg, link, managedRoutes, log); err != nil {
		log.WithError(err).Errorln("cannot sync routes")

		return fmt.Errorf("%w", err)
	}

	log.Info("synced routes")
	log.Info("Successfully synced device")

	return nil
}
2021-01-21 10:25:04 +01:00
// SyncWireguardDevice syncs wireguard vpn settings on the given link by applying
// cfg.Config through a wgctrl client.
// It does not set routes/addresses beyond wg internal crypto-key routing, only handles wireguard specific settings.
func SyncWireguardDevice(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
	cl, err := wgctrl.New()
	if err != nil {
		log.WithError(err).Errorln("cannot setup wireguard device")

		return fmt.Errorf("%w", err)
	}

	// The client holds OS handles; release them on every path instead of leaking
	// one per sync.
	defer func() {
		if err := cl.Close(); err != nil {
			log.WithError(err).Warn("cannot close wireguard client")
		}
	}()

	if err := cl.ConfigureDevice(link.Attrs().Name, cfg.Config); err != nil {
		log.WithError(err).Error("cannot configure device")

		return fmt.Errorf("%w", err)
	}

	return nil
}
2021-01-21 10:25:04 +01:00
// SyncLink syncs link state with the config.
// It does not sync Wireguard settings, just makes sure the device exists, is up,
// is type wireguard and, when cfg.MTU is non-zero, uses that MTU.
//
// A missing link is created as a kernel "wireguard" link unless uspace is set.
// When uspace is set, or the kernel refuses to create the link, the embedded
// wireguard-go is started instead.
func SyncLink(cfg *Config, iface string, uspace bool, log logrus.FieldLogger) (netlink.Link, error) {
	link, err := netlink.LinkByName(iface)
	if err != nil {
		if _, ok := err.(netlink.LinkNotFoundError); !ok {
			log.WithError(err).Error("cannot read link")

			return nil, fmt.Errorf("%w", err)
		}

		log.Info("link not found, creating")

		if link, err = createWgLink(cfg, iface, uspace, log); err != nil {
			return nil, err
		}
	}

	// Links created by wireguard-go, or ones that already existed, would
	// otherwise never receive the configured MTU.
	if cfg.MTU > 0 && link.Attrs().MTU != cfg.MTU {
		if err := netlink.LinkSetMTU(link, cfg.MTU); err != nil {
			log.WithError(err).Error("cannot set link mtu")

			return nil, fmt.Errorf("%w", err)
		}

		log.WithField("mtu", cfg.MTU).Info("set link mtu")
	}

	if err := netlink.LinkSetUp(link); err != nil {
		log.WithError(err).Error("cannot set link up")

		return nil, fmt.Errorf("%w", err)
	}

	log.Info("set device up")

	return link, nil
}

const (
	// wgLinkWaitTimeout bounds how long createWgLink waits for a new interface to show up.
	wgLinkWaitTimeout = 5 * time.Second
	// wgLinkPollInterval is the pause between two lookups while waiting.
	wgLinkPollInterval = 100 * time.Millisecond
)

// createWgLink creates the wireguard interface iface (kernel first unless uspace,
// falling back to wireguard-go) and returns it once netlink can see it.
func createWgLink(cfg *Config, iface string, uspace bool, log logrus.FieldLogger) (netlink.Link, error) {
	if uspace {
		log.Info("enforcing embedded wireguard-go")

		if err := userspace(iface); err != nil {
			log.WithError(err).Errorf("cannot create link through wireguard-go: %s", err.Error())

			return nil, fmt.Errorf("cannot create link: %w", err)
		}
	} else {
		wgLink := &netlink.GenericLink{
			LinkAttrs: netlink.LinkAttrs{
				Name: iface,
				MTU:  cfg.MTU,
			},
			LinkType: "wireguard",
		}

		if err := netlink.LinkAdd(wgLink); err != nil {
			log.WithError(err).Errorf("cannot create link: %s", err.Error())
			log.Info("trying to use embedded wireguard-go")

			if err := userspace(iface); err != nil {
				log.WithError(err).Errorf("cannot create link through wireguard-go: %s", err.Error())

				return nil, fmt.Errorf("cannot create link: %w", err)
			}
		}
	}

	// wireguard-go is only started, not waited for, so its interface may appear
	// with a delay. Poll instead of sleeping a fixed time; a kernel-created
	// link is normally found on the first lookup.
	link, err := waitForWgLink(iface)
	if err != nil {
		log.WithError(err).Error("cannot read link")

		return nil, fmt.Errorf("%w", err)
	}

	return link, nil
}

// waitForWgLink looks iface up until it exists or wgLinkWaitTimeout elapses.
// Lookup errors other than "not found" are returned immediately; on timeout
// the last "not found" error is returned.
func waitForWgLink(iface string) (netlink.Link, error) {
	deadline := time.Now().Add(wgLinkWaitTimeout)

	for {
		link, err := netlink.LinkByName(iface)
		if err == nil {
			return link, nil
		}

		if _, notFound := err.(netlink.LinkNotFoundError); !notFound || time.Now().After(deadline) {
			return nil, err
		}

		time.Sleep(wgLinkPollInterval)
	}
}
// SyncAddress makes the IPv4 addresses assigned to link match cfg.Address.
// Addresses from cfg.Address that are missing are added (labelled with
// cfg.AddressLabel). Every other address returned by the AF_INET listing is
// deleted. Adding an address that already exists (EEXIST) is treated as success.
// Because only AF_INET addresses are listed, non-IPv4 entries in cfg.Address are
// never seen as present and are never removed.
// nolint: funlen, gosec, scopelint
func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
	addrs, err := netlink.AddrList(link, syscall.AF_INET)
	if err != nil {
		log.WithError(err).Error("cannot read link address")

		return fmt.Errorf("%w", err)
	}

	// Keyed by CIDR string. Wanted entries are overwritten with a zero
	// netlink.Addr (nil IPNet), so any entry that still has an IPNet afterwards
	// is stale and gets deleted.
	presentAddresses := make(map[string]netlink.Addr, len(addrs))
	for _, addr := range addrs {
		log.WithFields(map[string]interface{}{
			"addr":  fmt.Sprint(addr.IPNet),
			"label": addr.Label,
		}).Debugf("found existing address: %v", addr)

		presentAddresses[addr.IPNet.String()] = addr
	}

	for _, addr := range cfg.Address {
		addr := addr // copy: its address is handed to netlink below
		log := log.WithField("addr", addr.String())
		_, present := presentAddresses[addr.String()]
		presentAddresses[addr.String()] = netlink.Addr{} // mark as wanted

		if present {
			log.Info("address present")

			continue
		}

		err := netlink.AddrAdd(link, &netlink.Addr{
			IPNet: &addr,
			Label: cfg.AddressLabel,
		})

		switch {
		case err == nil:
			log.Info("address added")
		case errors.Is(err, syscall.EEXIST):
			// Already on the link but not reported by the AF_INET listing above.
			log.Info("address present")
		default:
			log.WithError(err).Error("cannot add addr")

			return fmt.Errorf("%w", err)
		}
	}

	for _, addr := range presentAddresses {
		if addr.IPNet == nil {
			continue
		}

		log := log.WithFields(map[string]interface{}{
			"addr":  addr.IPNet.String(),
			"label": addr.Label,
		})

		if err := netlink.AddrDel(link, &addr); err != nil {
			log.WithError(err).Error("cannot delete addr")

			return fmt.Errorf("%w", err)
		}

		log.Info("addr deleted")
	}

	return nil
}
2019-03-29 12:40:15 +01:00
func fillRouteDefaults(rt *netlink.Route) {
// fill defaults
if rt.Table == 0 {
rt.Table = unix.RT_CLASS_MAIN
2019-03-29 12:40:15 +01:00
}
if rt.Protocol == 0 {
rt.Protocol = unix.RTPROT_BOOT
2019-03-29 12:40:15 +01:00
}
if rt.Type == 0 {
rt.Type = unix.RTN_UNICAST
2019-03-29 12:40:15 +01:00
}
}
2019-03-28 12:43:34 +01:00
// SyncRoutes makes the IPv4 routes through link match managedRoutes.
//
// Every destination in managedRoutes is added or replaced as a route via link,
// using cfg.Table, cfg.RouteProtocol and cfg.RouteMetric; zero fields are filled
// by fillRouteDefaults. Afterwards every route returned by the AF_INET listing
// for link is deleted, unless one of the following holds:
//   - it sits in a table other than the configured one (0 meaning main),
//   - its protocol differs from cfg.RouteProtocol,
//   - it equals one of the wanted routes.
//
// NOTE(review): with cfg.RouteProtocol == 0, wanted routes are created with
// RTPROT_BOOT (fillRouteDefaults) but present routes are compared against 0, so
// no route is ever treated as owned and nothing is deleted. This may be an
// intentional "never delete" default — confirm before changing it.
// nolint: funlen, gosec, scopelint
func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error {
	wantedRoutes := make(map[string][]netlink.Route, len(managedRoutes))

	presentRoutes, err := netlink.RouteList(link, syscall.AF_INET)
	if err != nil {
		log.WithError(err).Error("cannot read existing routes")

		return fmt.Errorf("%w", err)
	}

	for _, dst := range managedRoutes {
		dst := dst // copy: &dst is stored in the route
		log.WithField("dst", dst.String()).Debug("managing route")
		nrt := netlink.Route{
			LinkIndex: link.Attrs().Index,
			Dst:       &dst,
			Table:     cfg.Table,
			Protocol:  cfg.RouteProtocol,
			Priority:  cfg.RouteMetric,
		}
		fillRouteDefaults(&nrt)
		wantedRoutes[dst.String()] = append(wantedRoutes[dst.String()], nrt)
	}

	for _, rtLst := range wantedRoutes {
		for _, rt := range rtLst {
			rt := rt // make copy
			log := log.WithFields(map[string]interface{}{
				"route":    rt.Dst.String(),
				"protocol": rt.Protocol,
				"table":    rt.Table,
				"type":     rt.Type,
				"metric":   rt.Priority,
			})

			if err := netlink.RouteReplace(&rt); err != nil {
				log.WithError(err).Errorln("cannot add/replace route")

				return fmt.Errorf("%w", err)
			}

			log.Infoln("route added/replaced")
		}
	}

	checkWanted := func(rt netlink.Route) bool {
		for _, candidateRt := range wantedRoutes[rt.Dst.String()] {
			if rt.Equal(candidateRt) {
				return true
			}
		}

		return false
	}

	for _, rt := range presentRoutes {
		// netlink may report a default route with a nil Dst. Compare it as
		// 0.0.0.0/0 so that a wanted default route (e.g. AllowedIPs 0.0.0.0/0) is
		// recognised instead of being deleted right after it was added.
		cmpRt := rt
		if cmpRt.Dst == nil {
			cmpRt.Dst = &net.IPNet{IP: net.IPv4zero.To4(), Mask: net.CIDRMask(0, 32)}
		}

		log := log.WithFields(map[string]interface{}{
			"route":    cmpRt.Dst.String(),
			"protocol": rt.Protocol,
			"table":    rt.Table,
			"type":     rt.Type,
			"metric":   rt.Priority,
		})

		if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == unix.RT_CLASS_MAIN)) {
			log.Debug("wrong table for route, skipping")

			continue
		}

		if rt.Protocol != cfg.RouteProtocol {
			log.Info("skipping route deletion, not owned by this daemon")

			continue
		}

		if checkWanted(cmpRt) {
			log.Debug("route wanted, skipping deleting")

			continue
		}

		if err := netlink.RouteDel(&rt); err != nil {
			log.WithError(err).Error("cannot delete route")

			return fmt.Errorf("%w", err)
		}

		log.Info("route deleted")
	}

	return nil
}