gofmt + various feature implementations

This commit is contained in:
Neven Miculinic 2019-03-28 12:38:22 +01:00
parent 940aafb98d
commit f0b2e9f85a
4 changed files with 173 additions and 51 deletions

View File

@ -1,17 +1,29 @@
# wg-quick-go # wg-quick-go
[![Build Status](https://gitlab.com/neven-miculinic/wg-quick-go/badges/master/pipeline.svg)](https://gitlab.com/neven-miculinic/wg-quick-go/pipelines) [![GoDoc](https://godoc.org/github.com/nmiculinic/wireguardctrl?status.svg)](https://godoc.org/github.com/nmiculinic/wg-quick-go) [![Go Report Card](https://goreportcard.com/badge/github.com/nmiculinic/wg-quick-go)](https://goreportcard.com/report/github.com/nmiculinic/wg-quick-go)
wg-quick like library in go for embedding wg-quick like library in go for embedding
# Roadmap # Roadmap
* [ ] full wg-quick feature parity * [x] full wg-quick feature parity
* [x] PreUp
* [x] PostUp
* [x] PreDown
* [x] PostDown
* [x] DNS
* [x] MTU
* [x] Save --> Use MarshallText interface to save config
* [x] Sync * [x] Sync
* [x] Up
* [x] Down
* [x] MarshallText * [x] MarshallText
* [x] UnmarshallText * [x] UnmarshallText
* [x] Minimal test * [x] Minimal test
* [ ] Integration tests ((TODO; have some virtual machines/kvm and wreck havoc :) )) * [ ] Integration tests ((TODO; have some virtual machines/kvm and wreck havoc :) ))
* [ ] Up
* [ ] Down
# Caveats # Caveats
* Endpoints DNS MarshallText is unsupported * Endpoints DNS MarshallText is unsupported
* Pre/Post Up/Down doesn't support escaped `%i`, that is all `%i` are expanded to interface name.
* SaveConfig in config is only a placeholder (( since there's no reading/writing from files )). Use Unmarshall/Marshall Text to save/load config (( you're responsible for IO)).

View File

@ -41,7 +41,6 @@ type Config struct {
SaveConfig bool SaveConfig bool
} }
var _ encoding.TextMarshaler = (*Config)(nil) var _ encoding.TextMarshaler = (*Config)(nil)
var _ encoding.TextUnmarshaler = (*Config)(nil) var _ encoding.TextUnmarshaler = (*Config)(nil)
@ -62,7 +61,7 @@ func toSeconds(duration time.Duration) int {
} }
var funcMap = template.FuncMap(map[string]interface{}{ var funcMap = template.FuncMap(map[string]interface{}{
"wgKey": serializeKey, "wgKey": serializeKey,
"toSeconds": toSeconds, "toSeconds": toSeconds,
}) })
@ -119,11 +118,13 @@ func ParseKey(key string) (wgtypes.Key, error) {
} }
type parseState int type parseState int
const ( const (
unknown parseState = iota unknown parseState = iota
inter = iota inter = iota
peer = iota peer = iota
) )
func (cfg *Config) UnmarshalText(text []byte) error { func (cfg *Config) UnmarshalText(text []byte) error {
*cfg = Config{} // Zero out the config *cfg = Config{} // Zero out the config
state := unknown state := unknown
@ -139,7 +140,7 @@ func (cfg *Config) UnmarshalText(text []byte) error {
case "[Peer]": case "[Peer]":
state = peer state = peer
cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{}) cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{})
peerCfg = &cfg.Peers[len(cfg.Peers) - 1] peerCfg = &cfg.Peers[len(cfg.Peers)-1]
default: default:
parts := strings.Split(ln, "=") parts := strings.Split(ln, "=")
if len(parts) < 2 { if len(parts) < 2 {
@ -172,11 +173,11 @@ func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
if err != nil { if err != nil {
return fmt.Errorf("%v", err) return fmt.Errorf("%v", err)
} }
cfg.Address = append(cfg.Address, &net.IPNet{IP:ip, Mask:cidr.Mask}) cfg.Address = append(cfg.Address, &net.IPNet{IP: ip, Mask: cidr.Mask})
} }
case "DNS": case "DNS":
for _, addr := range strings.Split(rhs, ",") { for _, addr := range strings.Split(rhs, ",") {
ip:= net.ParseIP(strings.TrimSpace(addr)) ip := net.ParseIP(strings.TrimSpace(addr))
if ip == nil { if ip == nil {
return fmt.Errorf("cannot parse IP") return fmt.Errorf("cannot parse IP")
} }
@ -250,7 +251,7 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
if err != nil { if err != nil {
return fmt.Errorf("%v", err) return fmt.Errorf("%v", err)
} }
peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP:ip, Mask:cidr.Mask}) peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask})
} }
case "Endpoint": case "Endpoint":
addr, err := net.ResolveUDPAddr("", rhs) addr, err := net.ResolveUDPAddr("", rhs)
@ -262,4 +263,4 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
return fmt.Errorf("unknown directive %s", lhs) return fmt.Errorf("unknown directive %s", lhs)
} }
return nil return nil
} }

View File

@ -6,7 +6,7 @@ import (
) )
var testConfigs = map[string]string{ var testConfigs = map[string]string{
"simple": `[Interface] "simple": `[Interface]
Address = 10.200.100.8/24 Address = 10.200.100.8/24
DNS = 10.200.100.1 DNS = 10.200.100.1
PrivateKey = oK56DE9Ue9zK76rAc8pBl6opph+1v36lm7cXXsQKrQM= PrivateKey = oK56DE9Ue9zK76rAc8pBl6opph+1v36lm7cXXsQKrQM=
@ -17,7 +17,7 @@ AllowedIPs = 0.0.0.0/0
PresharedKey = /UwcSPg38hW/D9Y3tcS1FOV0K1wuURMbS0sesJEP5ak= PresharedKey = /UwcSPg38hW/D9Y3tcS1FOV0K1wuURMbS0sesJEP5ak=
Endpoint = 123.12.12.1:51820 Endpoint = 123.12.12.1:51820
`, `,
"sample-2": `[Interface] "sample-2": `[Interface]
Address = 10.192.122.1/24 Address = 10.192.122.1/24
Address = 10.10.0.1/16 Address = 10.10.0.1/16
PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
@ -36,7 +36,7 @@ AllowedIPs = 10.192.122.4/32, 192.168.0.0/16
PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=
AllowedIPs = 10.10.10.230/32 AllowedIPs = 10.10.10.230/32
`, `,
"sample-3":`[Interface] "sample-3": `[Interface]
Address = 10.192.122.1/24 Address = 10.192.122.1/24
PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=
ListenPort = 51820 ListenPort = 51820
@ -50,7 +50,7 @@ AllowedIPs = 0.0.0.0/0
`, `,
} }
func TestExampleConfig(t *testing.T) { func TestExampleConfig(t *testing.T) {
c := &Config{} c := &Config{}
for name, cfg := range testConfigs { for name, cfg := range testConfigs {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
@ -62,4 +62,4 @@ func TestExampleConfig(t *testing.T) {
assert.Equal(t, cfg, string(tt)) assert.Equal(t, cfg, string(tt))
}) })
} }
} }

177
wg.go
View File

@ -1,25 +1,156 @@
package wgquick package wgquick
import ( import (
"bytes"
"fmt"
"github.com/mdlayher/wireguardctrl" "github.com/mdlayher/wireguardctrl"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"os"
"os/exec"
"strings"
"syscall" "syscall"
) )
const ( const (
defaultRoutingTable = 254 defaultRoutingTable = 254
) )
// Sync the config to the current setup for given interface // Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface`
func (cfg *Config) Sync(iface string, logger logrus.FieldLogger) error { func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
log := logger.WithField("iface", iface) log := logger.WithField("iface", iface)
_, err := netlink.LinkByName(iface)
if err == nil {
return os.ErrExist
}
if _, ok := err.(netlink.LinkNotFoundError); !ok {
return err
}
for _, dns := range cfg.DNS {
if err := execSh("resolvconf -a tun.%i -m 0 -x", iface, log, fmt.Sprintf("nameserver %s\n", dns)); err != nil {
return err
}
}
if err := execSh(cfg.PreUp, iface, log); err != nil {
return err
}
log.Infoln("applied pre-up command")
if err := Sync(cfg, iface, logger); err != nil {
return err
}
if err := execSh(cfg.PostUp, iface, log); err != nil {
return err
}
log.Infoln("applied post-up command")
return nil
}
// Down destroyes the wg interface. Mostly equivalent to `wg-quick down iface`
func Down(cfg *Config, iface string, logger logrus.FieldLogger) error {
log := logger.WithField("iface", iface)
link, err := netlink.LinkByName(iface)
if err != nil {
return err
}
if len(cfg.DNS) > 1 {
if err := execSh("resolvconf -d tun.%s", iface, log); err != nil {
return err
}
}
if err := execSh(cfg.PreDown, iface, log); err != nil {
return err
}
log.Infoln("applied pre-down command")
if err := netlink.LinkDel(link); err != nil {
return err
}
log.Infoln("link deleted")
if err := execSh(cfg.PostDown, iface, log); err != nil {
return err
}
log.Infoln("applied post-down command")
return nil
}
func execSh(command string, iface string, log logrus.FieldLogger, stdin ...string) error {
cmd := exec.Command("sh", "-ce", strings.ReplaceAll(command, "%i", iface))
if len(stdin) > 0 {
log = log.WithField("stdin", strings.Join(stdin, ""))
b := &bytes.Buffer{}
for _, ln := range stdin {
if _, err := fmt.Fprint(b, ln); err != nil {
return err
}
}
cmd.Stdin = b
}
out, err := cmd.CombinedOutput()
if err != nil {
log.WithError(err).Errorf("failed to execute %s:\n%s", cmd.Args, out)
return err
}
log.Infof("executed %s:\n%s", cmd.Args, out)
return nil
}
// Sync the config to the current setup for given interface
func Sync(cfg *Config, iface string, logger logrus.FieldLogger) error {
log := logger.WithField("iface", iface)
link, err := SyncLink(cfg, iface, log)
if err != nil {
log.WithError(err).Errorln("cannot sync wireguard link")
return err
}
log.Info("synced link")
if err := SyncWireguardDevice(cfg, link, log); err != nil {
log.WithError(err).Errorln("cannot sync wireguard link")
return err
}
log.Info("synced link")
if err := SyncAddress(cfg, link, log); err != nil {
log.WithError(err).Errorln("cannot sync addresses")
return err
}
log.Info("synced addresss")
if err := SyncRoutes(cfg, link, log); err != nil {
log.WithError(err).Errorln("cannot sync routes")
return err
}
log.Info("synced routed")
log.Info("Successfully synced device")
return nil
}
// SyncWireguardDevice synces wireguard vpn setting on the given link. It does not set routes/addresses beyond wg internal crypto-key routing
func SyncWireguardDevice(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
cl, err := wireguardctrl.New()
if err != nil {
log.WithError(err).Errorln("cannot setup wireguard device")
return err
}
if err := cl.ConfigureDevice(link.Attrs().Name, cfg.Config); err != nil {
log.WithError(err).Error("cannot configure device")
return err
}
return nil
}
// SyncLink synces link state with the config. It does not sync Wireguard settings, just makes sure the device is up and type wireguard
func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link, error) {
link, err := netlink.LinkByName(iface) link, err := netlink.LinkByName(iface)
if err != nil { if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); !ok { if _, ok := err.(netlink.LinkNotFoundError); !ok {
log.WithError(err).Error("cannot read link") log.WithError(err).Error("cannot read link")
return err return nil, err
} }
log.Info("link not found, creating") log.Info("link not found, creating")
wgLink := &netlink.GenericLink{ wgLink := &netlink.GenericLink{
@ -31,48 +162,25 @@ func (cfg *Config) Sync(iface string, logger logrus.FieldLogger) error {
} }
if err := netlink.LinkAdd(wgLink); err != nil { if err := netlink.LinkAdd(wgLink); err != nil {
log.WithError(err).Error("cannot create link") log.WithError(err).Error("cannot create link")
return err return nil, err
} }
link, err = netlink.LinkByName(iface) link, err = netlink.LinkByName(iface)
if err != nil { if err != nil {
log.WithError(err).Error("cannot read link") log.WithError(err).Error("cannot read link")
return err return nil, err
} }
} }
if err := netlink.LinkSetUp(link); err != nil { if err := netlink.LinkSetUp(link); err != nil {
log.WithError(err).Error("cannot set link up") log.WithError(err).Error("cannot set link up")
return err return nil, err
} }
log.Info("set device up") log.Info("set device up")
return link, nil
cl, err := wireguardctrl.New()
if err != nil {
log.Error(err, "cannot setup wireguard device")
return err
}
if err := cl.ConfigureDevice(iface, cfg.Config); err != nil {
log.WithError(err).Error("cannot configure device")
return err
}
if err := syncAddress(link, cfg, log); err != nil {
log.Error(err, "cannot sync addresses")
return err
}
if err := syncRoutes(link, cfg, log); err != nil {
log.Error(err, "cannot sync routes")
return err
}
log.Info("Successfully setup device")
return nil
} }
func syncAddress(link netlink.Link, cfg *Config, log logrus.FieldLogger) error { // SyncAddress adds/deletes all lind assigned IPV4 addressed as specified in the config
func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
addrs, err := netlink.AddrList(link, syscall.AF_INET) addrs, err := netlink.AddrList(link, syscall.AF_INET)
if err != nil { if err != nil {
log.Error(err, "cannot read link address") log.Error(err, "cannot read link address")
@ -116,7 +224,8 @@ func syncAddress(link netlink.Link, cfg *Config, log logrus.FieldLogger) error {
return nil return nil
} }
func syncRoutes(link netlink.Link, cfg *Config, log logrus.FieldLogger) error { // SyncAddress adds/deletes all route assigned IPV4 addressed as specified in the config
func SyncRoutes(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
routes, err := netlink.RouteList(link, syscall.AF_INET) routes, err := netlink.RouteList(link, syscall.AF_INET)
if err != nil { if err != nil {
log.Error(err, "cannot read existing routes") log.Error(err, "cannot read existing routes")