261 lines
7.8 KiB
Go
261 lines
7.8 KiB
Go
package wgctl
|
||
|
||
import (
	"bytes"
	"encoding/base64"
	"fmt"
	"net"
	"os/exec"
	"syscall"
	"text/template"

	"github.com/mdlayher/wireguardctrl"
	"github.com/mdlayher/wireguardctrl/wgtypes"
	log "github.com/sirupsen/logrus"
	"github.com/vishvananda/netlink"
)
|
||
|
||
type Config struct {
|
||
wgtypes.Config
|
||
|
||
// Address — a comma-separated list of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned to the interface. May be specified multiple times.
|
||
Address []*net.IPNet
|
||
|
||
// — a comma-separated list of IP (v4 or v6) addresses to be set as the interface’s DNS servers. May be specified multiple times. Upon bringing the interface up, this runs ‘resolvconf -a tun.INTERFACE -m 0 -x‘ and upon bringing it down, this runs ‘resolvconf -d tun.INTERFACE‘. If these particular invocations of resolvconf(8) are undesirable, the PostUp and PostDown keys below may be used instead.
|
||
DNS []net.IP
|
||
// — if not specified, the MTU is automatically determined from the endpoint addresses or the system default route, which is usually a sane choice. However, to manually specify an MTU to override this automatic discovery, this value may be specified explicitly.
|
||
MTU int
|
||
|
||
// Table — Controls the routing table to which routes are added. There are two special values: ‘off’ disables the creation of routes altogether, and ‘auto’ (the default) adds routes to the default table and enables special handling of default routes.
|
||
Table int
|
||
|
||
// PreUp, PostUp, PreDown, PostDown — script snippets which will be executed by bash(1) before/after setting up/tearing down the interface, most commonly used to configure custom DNS options or firewall rules. The special string ‘%i’ is expanded to INTERFACE. Each one may be specified multiple times, in which case the commands are executed in order.
|
||
|
||
PreUp string
|
||
PostUp string
|
||
PreDown string
|
||
PostDown string
|
||
|
||
// SaveConfig — if set to ‘true’, the configuration is saved from the current state of the interface upon shutdown.
|
||
SaveConfig bool
|
||
}
|
||
|
||
func (cfg *Config) String() string {
|
||
b, err := cfg.MarshalText()
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
return string(b)
|
||
}
|
||
|
||
func (cfg *Config) MarshalText() (text []byte, err error) {
|
||
buff := &bytes.Buffer{}
|
||
if err := cfgTemplate.Execute(buff, cfg); err != nil {
|
||
return nil, err
|
||
}
|
||
return buff.Bytes(), nil
|
||
}
|
||
|
||
// wgtypeTemplateSpec is the text/template source used to render a Config as
// a wg-quick style configuration file.
//
// Optional keys are written with the `{{- if X }}{{ "\n" }}Key = ...{{ end }}`
// idiom: the leading "{{-" swallows the newline of the template source and
// the explicit {{ "\n" }} re-emits it only when the key is actually present,
// so absent keys leave no blank lines behind. The template expects a
// "wgKey" function to be registered for rendering keys.
const wgtypeTemplateSpec = `[Interface]
{{- range .Address }}{{ "\n" }}Address = {{ . }}{{ end }}
{{- range .DNS }}{{ "\n" }}DNS = {{ . }}{{ end }}
PrivateKey = {{ .PrivateKey | wgKey }}
{{- if .ListenPort }}{{ "\n" }}ListenPort = {{ .ListenPort }}{{ end }}
{{- if .MTU }}{{ "\n" }}MTU = {{ .MTU }}{{ end }}
{{- if .Table }}{{ "\n" }}Table = {{ .Table }}{{ end }}
{{- if .PreUp }}{{ "\n" }}PreUp = {{ .PreUp }}{{ end }}
{{- if .PostUp }}{{ "\n" }}PostUp = {{ .PostUp }}{{ end }}
{{- if .PreDown }}{{ "\n" }}PreDown = {{ .PreDown }}{{ end }}
{{- if .PostDown }}{{ "\n" }}PostDown = {{ .PostDown }}{{ end }}
{{- if .SaveConfig }}{{ "\n" }}SaveConfig = {{ .SaveConfig }}{{ end }}
{{- range .Peers }}

[Peer]
PublicKey = {{ .PublicKey | wgKey }}
AllowedIPs = {{ range $i, $el := .AllowedIPs }}{{ if $i }}, {{ end }}{{ $el }}{{ end }}
{{- if .Endpoint }}{{ "\n" }}Endpoint = {{ .Endpoint }}{{ end }}
{{- end }}
`
|
||
|
||
func serializeKey(key *wgtypes.Key) string {
|
||
return base64.StdEncoding.EncodeToString(key[:])
|
||
}
|
||
|
||
func ParseKey(key string) (wgtypes.Key, error) {
|
||
var pkey wgtypes.Key
|
||
pkeySlice, err := base64.StdEncoding.DecodeString(key)
|
||
if err != nil {
|
||
return pkey, err
|
||
}
|
||
copy(pkey[:], pkeySlice[:])
|
||
return pkey, nil
|
||
}
|
||
|
||
var cfgTemplate = template.Must(
|
||
template.
|
||
New("wg-cfg").
|
||
Funcs(template.FuncMap(map[string]interface{}{"wgKey": serializeKey})).
|
||
Parse(wgtypeTemplateSpec))
|
||
|
||
func (cfg *Config) Up(iface string) error {
|
||
|
||
link, err := netlink.LinkByName(iface)
|
||
if err != nil {
|
||
if _, ok := err.(netlink.LinkNotFoundError); !ok {
|
||
log.Error(err, "cannot read link, probably doesn't exist")
|
||
return err
|
||
}
|
||
log.Info("link not found, creating")
|
||
wgLink := &netlink.GenericLink{
|
||
LinkAttrs: netlink.LinkAttrs{
|
||
Name: iface,
|
||
},
|
||
LinkType: "wireguard",
|
||
}
|
||
if err := netlink.LinkAdd(wgLink); err != nil {
|
||
log.Error(err, "cannot create link", "iface", iface)
|
||
return err
|
||
}
|
||
if err := exec.Command("ip", "link", "add", "dev", iface, "type", "wireguard").Run(); err != nil {
|
||
}
|
||
|
||
link, err = netlink.LinkByName(iface)
|
||
if err != nil {
|
||
log.Error(err, "cannot read link")
|
||
return err
|
||
}
|
||
}
|
||
log.Info("link", "type", link.Type(), "attrs", link.Attrs())
|
||
if err := netlink.LinkSetUp(link); err != nil {
|
||
log.Error(err, "cannot set link up", "type", link.Type(), "attrs", link.Attrs())
|
||
return err
|
||
}
|
||
log.Info("set device up", "iface", iface)
|
||
|
||
cl, err := wireguardctrl.New()
|
||
if err != nil {
|
||
log.Error(err, "cannot setup wireguard device")
|
||
return err
|
||
}
|
||
|
||
if err := cl.ConfigureDevice(iface, cfg.Config); err != nil {
|
||
log.Error(err, "cannot configure device", "iface", iface)
|
||
return err
|
||
}
|
||
|
||
if err := syncAddress(link, cfg); err != nil {
|
||
log.Error(err, "cannot sync addresses")
|
||
return err
|
||
}
|
||
|
||
if err := syncRoutes(link, cfg); err != nil {
|
||
log.Error(err, "cannot sync routes")
|
||
return err
|
||
}
|
||
|
||
log.Info("Successfully setup device", "iface", iface)
|
||
return nil
|
||
|
||
}
|
||
|
||
func syncAddress(link netlink.Link, cfg *Config) error {
|
||
addrs, err := netlink.AddrList(link, syscall.AF_INET)
|
||
if err != nil {
|
||
log.Error(err, "cannot read link address")
|
||
return err
|
||
}
|
||
|
||
presentAddresses := make(map[string]int, 0)
|
||
for _, addr := range addrs {
|
||
presentAddresses[addr.IPNet.String()] = 1
|
||
}
|
||
|
||
for _, addr := range cfg.Address {
|
||
_, present := presentAddresses[addr.String()]
|
||
presentAddresses[addr.String()] = 2
|
||
if present {
|
||
log.Info("address present", "addr", addr, "iface", link.Attrs().Name)
|
||
continue
|
||
}
|
||
|
||
if err := netlink.AddrAdd(link, &netlink.Addr{
|
||
IPNet: addr,
|
||
}); err != nil {
|
||
log.Error(err, "cannot add addr", "iface", link.Attrs().Name)
|
||
return err
|
||
}
|
||
log.Info("address added", "addr", addr, "iface", link.Attrs().Name)
|
||
}
|
||
|
||
for addr, p := range presentAddresses {
|
||
if p < 2 {
|
||
nlAddr, err := netlink.ParseAddr(addr)
|
||
if err != nil {
|
||
log.Error(err, "cannot parse del addr", "iface", link.Attrs().Name, "addr", addr)
|
||
return err
|
||
}
|
||
if err := netlink.AddrAdd(link, nlAddr); err != nil {
|
||
log.Error(err, "cannot delete addr", "iface", link.Attrs().Name, "addr", addr)
|
||
return err
|
||
}
|
||
log.Info("address deleted", "addr", addr, "iface", link.Attrs().Name)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func syncRoutes(link netlink.Link, cfg *Config) error {
|
||
routes, err := netlink.RouteList(link, syscall.AF_INET)
|
||
if err != nil {
|
||
log.Error(err, "cannot read existing routes")
|
||
return err
|
||
}
|
||
|
||
presentRoutes := make(map[string]int, 0)
|
||
for _, r := range routes {
|
||
presentRoutes[r.Dst.String()] = 1
|
||
}
|
||
|
||
for _, peer := range cfg.Peers {
|
||
for _, rt := range peer.AllowedIPs {
|
||
_, present := presentRoutes[rt.String()]
|
||
presentRoutes[rt.String()] = 2
|
||
if present {
|
||
log.Info("route present", "iface", link.Attrs().Name, "route", rt.String())
|
||
continue
|
||
}
|
||
if err := netlink.RouteAdd(&netlink.Route{
|
||
LinkIndex: link.Attrs().Index,
|
||
Dst: &rt,
|
||
}); err != nil {
|
||
log.Error(err, "cannot setup route", "iface", link.Attrs().Name, "route", rt.String())
|
||
return err
|
||
}
|
||
log.Info("route added", "iface", link.Attrs().Name, "route", rt.String())
|
||
}
|
||
}
|
||
|
||
// Clean extra routes
|
||
for rtStr, p := range presentRoutes {
|
||
_, rt, err := net.ParseCIDR(rtStr)
|
||
if err != nil {
|
||
log.Info("cannot parse route", "iface", link.Attrs().Name, "route", rtStr)
|
||
return err
|
||
}
|
||
if p < 2 {
|
||
log.Info("extra manual route found", "iface", link.Attrs().Name, "route", rt.String())
|
||
if err := netlink.RouteDel(&netlink.Route{
|
||
LinkIndex: link.Attrs().Index,
|
||
Dst: rt,
|
||
}); err != nil {
|
||
log.Error(err, "cannot setup route", "iface", link.Attrs().Name, "route", rt.String())
|
||
return err
|
||
}
|
||
log.Info("route deleted", "iface", link.Attrs().Name, "route", rt)
|
||
}
|
||
}
|
||
return nil
|
||
}
|