diff --git a/README.md b/README.md index 2d8826d..24aa6c6 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,29 @@ # wg-quick-go + +[![Build Status](https://gitlab.com/neven-miculinic/wg-quick-go/badges/master/pipeline.svg)](https://gitlab.com/neven-miculinic/wg-quick-go/pipelines) [![GoDoc](https://godoc.org/github.com/nmiculinic/wireguardctrl?status.svg)](https://godoc.org/github.com/nmiculinic/wg-quick-go) [![Go Report Card](https://goreportcard.com/badge/github.com/nmiculinic/wg-quick-go)](https://goreportcard.com/report/github.com/nmiculinic/wg-quick-go) + wg-quick like library in go for embedding # Roadmap -* [ ] full wg-quick feature parity +* [x] full wg-quick feature parity + * [x] PreUp + * [x] PostUp + * [x] PreDown + * [x] PostDown + * [x] DNS + * [x] MTU + * [x] Save --> Use MarshallText interface to save config * [x] Sync +* [x] Up +* [x] Down * [x] MarshallText * [x] UnmarshallText * [x] Minimal test * [ ] Integration tests ((TODO; have some virtual machines/kvm and wreck havoc :) )) -* [ ] Up -* [ ] Down # Caveats * Endpoints DNS MarshallText is unsupported +* Pre/Post Up/Down doesn't support escaped `%i`, that is all `%i` are expanded to interface name. +* SaveConfig in config is only a placeholder (( since there's no reading/writing from files )). Use Unmarshall/Marshall Text to save/load config (( you're responsible for IO)). diff --git a/config.go b/config.go index 0591f12..d41cfde 100644 --- a/config.go +++ b/config.go @@ -41,7 +41,6 @@ type Config struct { SaveConfig bool } - var _ encoding.TextMarshaler = (*Config)(nil) var _ encoding.TextUnmarshaler = (*Config)(nil) @@ -62,7 +61,7 @@ func toSeconds(duration time.Duration) int { } var funcMap = template.FuncMap(map[string]interface{}{ - "wgKey": serializeKey, + "wgKey": serializeKey, "toSeconds": toSeconds, }) @@ -119,11 +118,13 @@ func ParseKey(key string) (wgtypes.Key, error) { } type parseState int + const ( unknown parseState = iota - inter = iota - peer = iota + inter = iota + peer = iota ) + func (cfg *Config) UnmarshalText(text []byte) error { *cfg = Config{} // Zero out the config state := unknown @@ -139,7 +140,7 @@ func (cfg *Config) UnmarshalText(text []byte) error { case "[Peer]": state = peer cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{}) - peerCfg = &cfg.Peers[len(cfg.Peers) - 1] + peerCfg = &cfg.Peers[len(cfg.Peers)-1] default: parts := strings.Split(ln, "=") if len(parts) < 2 { @@ -172,11 +173,11 @@ func parseInterfaceLine(cfg *Config, lhs string, rhs string) error { if err != nil { return fmt.Errorf("%v", err) } - cfg.Address = append(cfg.Address, &net.IPNet{IP:ip, Mask:cidr.Mask}) + cfg.Address = append(cfg.Address, &net.IPNet{IP: ip, Mask: cidr.Mask}) } case "DNS": for _, addr := range strings.Split(rhs, ",") { - ip:= net.ParseIP(strings.TrimSpace(addr)) + ip := net.ParseIP(strings.TrimSpace(addr)) if ip == nil { return fmt.Errorf("cannot parse IP") } @@ -250,7 +251,7 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error { if err != nil { return fmt.Errorf("%v", err) } - peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP:ip, Mask:cidr.Mask}) + peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask}) } case "Endpoint": addr, err := net.ResolveUDPAddr("", rhs) @@ -262,4 +263,4 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error { return fmt.Errorf("unknown directive %s", lhs) } return nil -} \ No newline at end of file +} diff --git a/config_test.go b/config_test.go index 566c069..321838a 100644 --- a/config_test.go +++ b/config_test.go @@ -6,7 +6,7 @@ import ( ) var testConfigs = map[string]string{ -"simple": `[Interface] + "simple": `[Interface] Address = 10.200.100.8/24 DNS = 10.200.100.1 PrivateKey = oK56DE9Ue9zK76rAc8pBl6opph+1v36lm7cXXsQKrQM= @@ -17,7 +17,7 @@ AllowedIPs = 0.0.0.0/0 PresharedKey = /UwcSPg38hW/D9Y3tcS1FOV0K1wuURMbS0sesJEP5ak= Endpoint = 123.12.12.1:51820 `, -"sample-2": `[Interface] + "sample-2": `[Interface] Address = 10.192.122.1/24 Address = 10.10.0.1/16 PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= @@ -36,7 +36,7 @@ AllowedIPs = 10.192.122.4/32, 192.168.0.0/16 PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= AllowedIPs = 10.10.10.230/32 `, -"sample-3":`[Interface] + "sample-3": `[Interface] Address = 10.192.122.1/24 PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= ListenPort = 51820 @@ -50,7 +50,7 @@ AllowedIPs = 0.0.0.0/0 `, } -func TestExampleConfig(t *testing.T) { +func TestExampleConfig(t *testing.T) { c := &Config{} for name, cfg := range testConfigs { t.Run(name, func(t *testing.T) { @@ -62,4 +62,4 @@ func TestExampleConfig(t *testing.T) { assert.Equal(t, cfg, string(tt)) }) } -} \ No newline at end of file +} diff --git a/wg.go b/wg.go index 81ea34c..a207551 100644 --- a/wg.go +++ b/wg.go @@ -1,25 +1,156 @@ package wgquick import ( + "bytes" + "fmt" "github.com/mdlayher/wireguardctrl" "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + "os" + "os/exec" + "strings" "syscall" ) - const ( defaultRoutingTable = 254 ) -// Sync the config to the current setup for given interface -func (cfg *Config) Sync(iface string, logger logrus.FieldLogger) error { +// Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface` +func Up(cfg *Config, iface string, logger logrus.FieldLogger) error { log := logger.WithField("iface", iface) + _, err := netlink.LinkByName(iface) + if err == nil { + return os.ErrExist + } + if _, ok := err.(netlink.LinkNotFoundError); !ok { + return err + } + + for _, dns := range cfg.DNS { + if err := execSh("resolvconf -a tun.%i -m 0 -x", iface, log, fmt.Sprintf("nameserver %s\n", dns)); err != nil { + return err + } + } + + if err := execSh(cfg.PreUp, iface, log); err != nil { + return err + } + log.Infoln("applied pre-up command") + if err := Sync(cfg, iface, logger); err != nil { + return err + } + if err := execSh(cfg.PostUp, iface, log); err != nil { + return err + } + log.Infoln("applied post-up command") + return nil +} + +// Down destroyes the wg interface. Mostly equivalent to `wg-quick down iface` +func Down(cfg *Config, iface string, logger logrus.FieldLogger) error { + log := logger.WithField("iface", iface) + link, err := netlink.LinkByName(iface) + if err != nil { + return err + } + + if len(cfg.DNS) > 1 { + if err := execSh("resolvconf -d tun.%s", iface, log); err != nil { + return err + } + } + + if err := execSh(cfg.PreDown, iface, log); err != nil { + return err + } + log.Infoln("applied pre-down command") + if err := netlink.LinkDel(link); err != nil { + return err + } + log.Infoln("link deleted") + if err := execSh(cfg.PostDown, iface, log); err != nil { + return err + } + log.Infoln("applied post-down command") + return nil +} + +func execSh(command string, iface string, log logrus.FieldLogger, stdin ...string) error { + cmd := exec.Command("sh", "-ce", strings.ReplaceAll(command, "%i", iface)) + if len(stdin) > 0 { + log = log.WithField("stdin", strings.Join(stdin, "")) + b := &bytes.Buffer{} + for _, ln := range stdin { + if _, err := fmt.Fprint(b, ln); err != nil { + return err + } + } + cmd.Stdin = b + } + out, err := cmd.CombinedOutput() + if err != nil { + log.WithError(err).Errorf("failed to execute %s:\n%s", cmd.Args, out) + return err + } + log.Infof("executed %s:\n%s", cmd.Args, out) + return nil +} + +// Sync the config to the current setup for given interface +func Sync(cfg *Config, iface string, logger logrus.FieldLogger) error { + log := logger.WithField("iface", iface) + + link, err := SyncLink(cfg, iface, log) + if err != nil { + log.WithError(err).Errorln("cannot sync wireguard link") + return err + } + log.Info("synced link") + + if err := SyncWireguardDevice(cfg, link, log); err != nil { + log.WithError(err).Errorln("cannot sync wireguard link") + return err + } + log.Info("synced link") + + if err := SyncAddress(cfg, link, log); err != nil { + log.WithError(err).Errorln("cannot sync addresses") + return err + } + log.Info("synced addresss") + + if err := SyncRoutes(cfg, link, log); err != nil { + log.WithError(err).Errorln("cannot sync routes") + return err + } + log.Info("synced routed") + log.Info("Successfully synced device") + return nil + +} + +// SyncWireguardDevice synces wireguard vpn setting on the given link. It does not set routes/addresses beyond wg internal crypto-key routing +func SyncWireguardDevice(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { + cl, err := wireguardctrl.New() + if err != nil { + log.WithError(err).Errorln("cannot setup wireguard device") + return err + } + if err := cl.ConfigureDevice(link.Attrs().Name, cfg.Config); err != nil { + log.WithError(err).Error("cannot configure device") + return err + } + return nil +} + +// SyncLink synces link state with the config. It does not sync Wireguard settings, just makes sure the device is up and type wireguard +func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link, error) { link, err := netlink.LinkByName(iface) if err != nil { if _, ok := err.(netlink.LinkNotFoundError); !ok { log.WithError(err).Error("cannot read link") - return err + return nil, err } log.Info("link not found, creating") wgLink := &netlink.GenericLink{ @@ -31,48 +162,25 @@ func (cfg *Config) Sync(iface string, logger logrus.FieldLogger) error { } if err := netlink.LinkAdd(wgLink); err != nil { log.WithError(err).Error("cannot create link") - return err + return nil, err } link, err = netlink.LinkByName(iface) if err != nil { log.WithError(err).Error("cannot read link") - return err + return nil, err } } if err := netlink.LinkSetUp(link); err != nil { log.WithError(err).Error("cannot set link up") - return err + return nil, err } log.Info("set device up") - - cl, err := wireguardctrl.New() - if err != nil { - log.Error(err, "cannot setup wireguard device") - return err - } - - if err := cl.ConfigureDevice(iface, cfg.Config); err != nil { - log.WithError(err).Error("cannot configure device") - return err - } - - if err := syncAddress(link, cfg, log); err != nil { - log.Error(err, "cannot sync addresses") - return err - } - - if err := syncRoutes(link, cfg, log); err != nil { - log.Error(err, "cannot sync routes") - return err - } - - log.Info("Successfully setup device") - return nil - + return link, nil } -func syncAddress(link netlink.Link, cfg *Config, log logrus.FieldLogger) error { +// SyncAddress adds/deletes all lind assigned IPV4 addressed as specified in the config +func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { addrs, err := netlink.AddrList(link, syscall.AF_INET) if err != nil { log.Error(err, "cannot read link address") @@ -116,7 +224,8 @@ func syncAddress(link netlink.Link, cfg *Config, log logrus.FieldLogger) error { return nil } -func syncRoutes(link netlink.Link, cfg *Config, log logrus.FieldLogger) error { +// SyncAddress adds/deletes all route assigned IPV4 addressed as specified in the config +func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error { routes, err := netlink.RouteList(link, syscall.AF_INET) if err != nil { log.Error(err, "cannot read existing routes")