happy linting
This commit is contained in:
parent
d9764025be
commit
a5695600e2
2
Makefile
2
Makefile
@ -52,4 +52,4 @@ test:
|
|||||||
|
|
||||||
.PHONY: lint
|
.PHONY: lint
|
||||||
lint:
|
lint:
|
||||||
golangci-lint run --enable-all --disable gomnd --disable godox --timeout 5m
|
golangci-lint run --enable-all --disable gomnd --disable godox --disable exhaustivestruct --timeout 5m
|
||||||
|
95
config.go
95
config.go
@ -1,3 +1,4 @@
|
|||||||
|
// nolint: gochecknoglobals, gomnd, goerr113
|
||||||
package wgquick
|
package wgquick
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -14,23 +15,32 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config represents full wg-quick like config structure
|
// Config represents full wg-quick like config structure.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
wgtypes.Config
|
wgtypes.Config
|
||||||
|
|
||||||
// Address list of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned to the interface. May be specified multiple times.
|
// Address list of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned
|
||||||
|
// to the interface. May be specified multiple times.
|
||||||
Address []net.IPNet
|
Address []net.IPNet
|
||||||
|
|
||||||
// list of IP (v4 or v6) addresses to be set as the interface’s DNS servers. May be specified multiple times. Upon bringing the interface up, this runs ‘resolvconf -a tun.INTERFACE -m 0 -x‘ and upon bringing it down, this runs ‘resolvconf -d tun.INTERFACE‘. If these particular invocations of resolvconf(8) are undesirable, the PostUp and PostDown keys below may be used instead.
|
// list of IP (v4 or v6) addresses to be set as the interface’s DNS servers. May be specified multiple times.
|
||||||
|
// Upon bringing the interface up, this runs ‘resolvconf -a tun.INTERFACE -m 0 -x‘ and upon bringing it down,
|
||||||
|
// this runs ‘resolvconf -d tun.INTERFACE‘. If these particular invocations of resolvconf(8) are undesirable,
|
||||||
|
// the PostUp and PostDown keys below may be used instead.
|
||||||
DNS []net.IP
|
DNS []net.IP
|
||||||
|
|
||||||
// MTU is automatically determined from the endpoint addresses or the system default route, which is usually a sane choice. However, to manually specify an MTU to override this automatic discovery, this value may be specified explicitly.
|
// MTU is automatically determined from the endpoint addresses or the system default route,
|
||||||
|
// which is usually a sane choice. However, to manually specify an MTU to override this automatic discovery,
|
||||||
|
// this value may be specified explicitly.
|
||||||
MTU int
|
MTU int
|
||||||
|
|
||||||
// Table — Controls the routing table to which routes are added.
|
// Table — Controls the routing table to which routes are added.
|
||||||
Table int
|
Table int
|
||||||
|
|
||||||
// PreUp, PostUp, PreDown, PostDown — script snippets which will be executed by bash(1) before/after setting up/tearing down the interface, most commonly used to configure custom DNS options or firewall rules. The special string ‘%i’ is expanded to INTERFACE. Each one may be specified multiple times, in which case the commands are executed in order.
|
// PreUp, PostUp, PreDown, PostDown — script snippets which will be executed by bash(1)
|
||||||
|
// before/after setting up/tearing down the interface, most commonly used to configure
|
||||||
|
// custom DNS options or firewall rules. The special string ‘%i’ is expanded to INTERFACE.
|
||||||
|
// Each one may be specified multiple times, in which case the commands are executed in order.
|
||||||
PreUp string
|
PreUp string
|
||||||
PostUp string
|
PostUp string
|
||||||
PreDown string
|
PreDown string
|
||||||
@ -50,14 +60,17 @@ type Config struct {
|
|||||||
SaveConfig bool
|
SaveConfig bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ encoding.TextMarshaler = (*Config)(nil)
|
var (
|
||||||
var _ encoding.TextUnmarshaler = (*Config)(nil)
|
_ encoding.TextMarshaler = (*Config)(nil)
|
||||||
|
_ encoding.TextUnmarshaler = (*Config)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
func (cfg *Config) String() string {
|
func (cfg *Config) String() string {
|
||||||
b, err := cfg.MarshalText()
|
b, err := cfg.MarshalText()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -83,11 +96,13 @@ var cfgTemplate = template.Must(
|
|||||||
func (cfg *Config) MarshalText() (text []byte, err error) {
|
func (cfg *Config) MarshalText() (text []byte, err error) {
|
||||||
buff := &bytes.Buffer{}
|
buff := &bytes.Buffer{}
|
||||||
if err := cfgTemplate.Execute(buff, cfg); err != nil {
|
if err := cfgTemplate.Execute(buff, cfg); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return buff.Bytes(), nil
|
return buff.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint: lll
|
||||||
const wgtypeTemplateSpec = `[Interface]
|
const wgtypeTemplateSpec = `[Interface]
|
||||||
{{- range .Address }}
|
{{- range .Address }}
|
||||||
Address = {{ . }}
|
Address = {{ . }}
|
||||||
@ -115,14 +130,17 @@ AllowedIPs = {{ range $i, $el := .AllowedIPs }}{{if $i}}, {{ end }}{{ $el }}{{ e
|
|||||||
{{- end }}
|
{{- end }}
|
||||||
`
|
`
|
||||||
|
|
||||||
// ParseKey parses the base64 encoded wireguard private key
|
// ParseKey parses the base64 encoded wireguard private key.
|
||||||
func ParseKey(key string) (wgtypes.Key, error) {
|
func ParseKey(key string) (wgtypes.Key, error) {
|
||||||
var pkey wgtypes.Key
|
var pkey wgtypes.Key
|
||||||
|
|
||||||
pkeySlice, err := base64.StdEncoding.DecodeString(key)
|
pkeySlice, err := base64.StdEncoding.DecodeString(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pkey, err
|
return pkey, fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
copy(pkey[:], pkeySlice[:])
|
|
||||||
|
copy(pkey[:], pkeySlice)
|
||||||
|
|
||||||
return pkey, nil
|
return pkey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -137,17 +155,21 @@ const (
|
|||||||
func (cfg *Config) UnmarshalText(text []byte) error {
|
func (cfg *Config) UnmarshalText(text []byte) error {
|
||||||
*cfg = Config{} // Zero out the config
|
*cfg = Config{} // Zero out the config
|
||||||
state := unknown
|
state := unknown
|
||||||
|
|
||||||
var peerCfg *wgtypes.PeerConfig
|
var peerCfg *wgtypes.PeerConfig
|
||||||
|
|
||||||
for no, line := range strings.Split(string(text), "\n") {
|
for no, line := range strings.Split(string(text), "\n") {
|
||||||
ln := strings.TrimSpace(line)
|
ln := strings.TrimSpace(line)
|
||||||
if len(ln) == 0 || ln[0] == '#' {
|
if len(ln) == 0 || ln[0] == '#' {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
switch ln {
|
switch ln {
|
||||||
case "[Interface]":
|
case "[Interface]":
|
||||||
state = inter
|
state = inter
|
||||||
case "[Peer]":
|
case "[Peer]":
|
||||||
state = peer
|
state = peer
|
||||||
|
|
||||||
cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{})
|
cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{})
|
||||||
peerCfg = &cfg.Peers[len(cfg.Peers)-1]
|
peerCfg = &cfg.Peers[len(cfg.Peers)-1]
|
||||||
default:
|
default:
|
||||||
@ -155,33 +177,40 @@ func (cfg *Config) UnmarshalText(text []byte) error {
|
|||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return fmt.Errorf("cannot parse line %d, missing =", no)
|
return fmt.Errorf("cannot parse line %d, missing =", no)
|
||||||
}
|
}
|
||||||
|
|
||||||
lhs := strings.TrimSpace(parts[0])
|
lhs := strings.TrimSpace(parts[0])
|
||||||
rhs := strings.TrimSpace(strings.Join(parts[1:], "="))
|
rhs := strings.TrimSpace(strings.Join(parts[1:], "="))
|
||||||
|
|
||||||
switch state {
|
switch state {
|
||||||
case inter:
|
case inter:
|
||||||
if err := parseInterfaceLine(cfg, lhs, rhs); err != nil {
|
if err := parseInterfaceLine(cfg, lhs, rhs); err != nil {
|
||||||
return fmt.Errorf("[line %d]: %v", no+1, err)
|
return fmt.Errorf("[line %d]: %w", no+1, err)
|
||||||
}
|
}
|
||||||
case peer:
|
case peer:
|
||||||
if err := parsePeerLine(peerCfg, lhs, rhs); err != nil {
|
if err := parsePeerLine(peerCfg, lhs, rhs); err != nil {
|
||||||
return fmt.Errorf("[line %d]: %v", no+1, err)
|
return fmt.Errorf("[line %d]: %w", no+1, err)
|
||||||
}
|
}
|
||||||
default:
|
case unknown:
|
||||||
return fmt.Errorf("[line %d] cannot parse, unknown state", no+1)
|
return fmt.Errorf("[line %d] cannot parse, unknown state", no+1)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("switch couldnt find state")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint: funlen
|
||||||
func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
|
func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
|
||||||
switch lhs {
|
switch lhs {
|
||||||
case "Address":
|
case "Address":
|
||||||
for _, addr := range strings.Split(rhs, ",") {
|
for _, addr := range strings.Split(rhs, ",") {
|
||||||
ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
|
ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.Address = append(cfg.Address, net.IPNet{IP: ip, Mask: cidr.Mask})
|
cfg.Address = append(cfg.Address, net.IPNet{IP: ip, Mask: cidr.Mask})
|
||||||
}
|
}
|
||||||
case "DNS":
|
case "DNS":
|
||||||
@ -190,25 +219,29 @@ func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
|
|||||||
if ip == nil {
|
if ip == nil {
|
||||||
return fmt.Errorf("cannot parse IP")
|
return fmt.Errorf("cannot parse IP")
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.DNS = append(cfg.DNS, ip)
|
cfg.DNS = append(cfg.DNS, ip)
|
||||||
}
|
}
|
||||||
case "MTU":
|
case "MTU":
|
||||||
mtu, err := strconv.ParseInt(rhs, 10, 64)
|
mtu, err := strconv.ParseInt(rhs, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.MTU = int(mtu)
|
cfg.MTU = int(mtu)
|
||||||
case "Table":
|
case "Table":
|
||||||
tbl, err := strconv.ParseInt(rhs, 10, 64)
|
tbl, err := strconv.ParseInt(rhs, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.Table = int(tbl)
|
cfg.Table = int(tbl)
|
||||||
case "ListenPort":
|
case "ListenPort":
|
||||||
portI64, err := strconv.ParseInt(rhs, 10, 64)
|
portI64, err := strconv.ParseInt(rhs, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
port := int(portI64)
|
port := int(portI64)
|
||||||
cfg.ListenPort = &port
|
cfg.ListenPort = &port
|
||||||
case "PreUp":
|
case "PreUp":
|
||||||
@ -222,18 +255,21 @@ func parseInterfaceLine(cfg *Config, lhs string, rhs string) error {
|
|||||||
case "SaveConfig":
|
case "SaveConfig":
|
||||||
save, err := strconv.ParseBool(rhs)
|
save, err := strconv.ParseBool(rhs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.SaveConfig = save
|
cfg.SaveConfig = save
|
||||||
case "PrivateKey":
|
case "PrivateKey":
|
||||||
key, err := ParseKey(rhs)
|
key, err := ParseKey(rhs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot decode key %v", err)
|
return fmt.Errorf("cannot decode key %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.PrivateKey = &key
|
cfg.PrivateKey = &key
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown directive %s", lhs)
|
return fmt.Errorf("unknown directive %s", lhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -242,41 +278,48 @@ func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error {
|
|||||||
case "PublicKey":
|
case "PublicKey":
|
||||||
key, err := ParseKey(rhs)
|
key, err := ParseKey(rhs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot decode key %v", err)
|
return fmt.Errorf("cannot decode key %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerCfg.PublicKey = key
|
peerCfg.PublicKey = key
|
||||||
case "PresharedKey":
|
case "PresharedKey":
|
||||||
key, err := ParseKey(rhs)
|
key, err := ParseKey(rhs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot decode key %v", err)
|
return fmt.Errorf("cannot decode key %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if peerCfg.PresharedKey != nil {
|
if peerCfg.PresharedKey != nil {
|
||||||
return fmt.Errorf("preshared key already defined %v", err)
|
return fmt.Errorf("preshared key already defined %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerCfg.PresharedKey = &key
|
peerCfg.PresharedKey = &key
|
||||||
case "AllowedIPs":
|
case "AllowedIPs":
|
||||||
for _, addr := range strings.Split(rhs, ",") {
|
for _, addr := range strings.Split(rhs, ",") {
|
||||||
ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
|
ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot parse %s: %v", addr, err)
|
return fmt.Errorf("cannot parse %s: %w", addr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask})
|
peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask})
|
||||||
}
|
}
|
||||||
case "Endpoint":
|
case "Endpoint":
|
||||||
addr, err := net.ResolveUDPAddr("", rhs)
|
addr, err := net.ResolveUDPAddr("", rhs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerCfg.Endpoint = addr
|
peerCfg.Endpoint = addr
|
||||||
case "PersistentKeepalive":
|
case "PersistentKeepalive":
|
||||||
t, err := strconv.ParseInt(rhs, 10, 64)
|
t, err := strconv.ParseInt(rhs, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dur := time.Duration(t * int64(time.Second))
|
dur := time.Duration(t * int64(time.Second))
|
||||||
peerCfg.PersistentKeepaliveInterval = &dur
|
peerCfg.PersistentKeepaliveInterval = &dur
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown directive %s", lhs)
|
return fmt.Errorf("unknown directive %s", lhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
// nolint: scopelint, gochecknoglobals, paralleltest, testpackage
|
||||||
package wgquick
|
package wgquick
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -54,6 +55,7 @@ PersistentKeepalive = 25
|
|||||||
|
|
||||||
func TestExampleConfig(t *testing.T) {
|
func TestExampleConfig(t *testing.T) {
|
||||||
c := &Config{}
|
c := &Config{}
|
||||||
|
|
||||||
for name, cfg := range testConfigs {
|
for name, cfg := range testConfigs {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
err := c.UnmarshalText([]byte(cfg))
|
err := c.UnmarshalText([]byte(cfg))
|
||||||
|
153
wg.go
153
wg.go
@ -2,6 +2,7 @@ package wgquick
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@ -31,20 +32,24 @@ func wgGo(iface string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cmd := wgo.Command(iface)
|
cmd := wgo.Command(iface)
|
||||||
cmd.Start()
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("could not start wireguard-go: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface`
|
// Up sets and configures the wg interface. Mostly equivalent to `wg-quick up iface`.
|
||||||
func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
|
func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
|
||||||
log := logger.WithField("iface", iface)
|
log := logger.WithField("iface", iface)
|
||||||
|
|
||||||
_, err := netlink.LinkByName(iface)
|
_, err := netlink.LinkByName(iface)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return os.ErrExist
|
return os.ErrExist
|
||||||
}
|
}
|
||||||
if _, ok := err.(netlink.LinkNotFoundError); !ok {
|
|
||||||
return err
|
if errors.As(err, &netlink.LinkNotFoundError{}) {
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, dns := range cfg.DNS {
|
for _, dns := range cfg.DNS {
|
||||||
@ -57,8 +62,10 @@ func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
|
|||||||
if err := execSh(cfg.PreUp, iface, log); err != nil {
|
if err := execSh(cfg.PreUp, iface, log); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infoln("applied pre-up command")
|
log.Infoln("applied pre-up command")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := Sync(cfg, iface, logger); err != nil {
|
if err := Sync(cfg, iface, logger); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -67,17 +74,20 @@ func Up(cfg *Config, iface string, logger logrus.FieldLogger) error {
|
|||||||
if err := execSh(cfg.PostUp, iface, log); err != nil {
|
if err := execSh(cfg.PostUp, iface, log); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infoln("applied post-up command")
|
log.Infoln("applied post-up command")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Down destroys the wg interface. Mostly equivalent to `wg-quick down iface`
|
// Down destroys the wg interface. Mostly equivalent to `wg-quick down iface`.
|
||||||
func Down(cfg *Config, iface string, logger logrus.FieldLogger) error {
|
func Down(cfg *Config, iface string, logger logrus.FieldLogger) error {
|
||||||
log := logger.WithField("iface", iface)
|
log := logger.WithField("iface", iface)
|
||||||
|
|
||||||
link, err := netlink.LinkByName(iface)
|
link, err := netlink.LinkByName(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cfg.DNS) > 1 {
|
if len(cfg.DNS) > 1 {
|
||||||
@ -90,108 +100,138 @@ func Down(cfg *Config, iface string, logger logrus.FieldLogger) error {
|
|||||||
if err := execSh(cfg.PreDown, iface, log); err != nil {
|
if err := execSh(cfg.PreDown, iface, log); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infoln("applied pre-down command")
|
log.Infoln("applied pre-down command")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.LinkDel(link); err != nil {
|
if err := netlink.LinkDel(link); err != nil {
|
||||||
return err
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infoln("link deleted")
|
log.Infoln("link deleted")
|
||||||
|
|
||||||
if cfg.PostDown != "" {
|
if cfg.PostDown != "" {
|
||||||
if err := execSh(cfg.PostDown, iface, log); err != nil {
|
if err := execSh(cfg.PostDown, iface, log); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infoln("applied post-down command")
|
log.Infoln("applied post-down command")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func execSh(command string, iface string, log logrus.FieldLogger, stdin ...string) error {
|
func execSh(command string, iface string, log logrus.FieldLogger, stdin ...string) error {
|
||||||
cmd := exec.Command("sh", "-ce", strings.ReplaceAll(command, "%i", iface))
|
cmd := exec.Command("sh", "-ce", strings.ReplaceAll(command, "%i", iface)) // nolint: gosec
|
||||||
|
|
||||||
if len(stdin) > 0 {
|
if len(stdin) > 0 {
|
||||||
log = log.WithField("stdin", strings.Join(stdin, ""))
|
log = log.WithField("stdin", strings.Join(stdin, ""))
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
|
|
||||||
for _, ln := range stdin {
|
for _, ln := range stdin {
|
||||||
if _, err := fmt.Fprint(b, ln); err != nil {
|
if _, err := fmt.Fprint(b, ln); err != nil {
|
||||||
return err
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Stdin = b
|
cmd.Stdin = b
|
||||||
}
|
}
|
||||||
|
|
||||||
out, err := cmd.CombinedOutput()
|
out, err := cmd.CombinedOutput()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Errorf("failed to execute %s:\n%s", cmd.Args, out)
|
log.WithError(err).Errorf("failed to execute %s:\n%s", cmd.Args, out)
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("executed %s:\n%s", cmd.Args, out)
|
log.Infof("executed %s:\n%s", cmd.Args, out)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync the config to the current setup for given interface
|
// Sync the config to the current setup for given interface.
|
||||||
// It perform 4 operations:
|
// It perform 4 operations:
|
||||||
// * SyncLink --> makes sure link is up and type wireguard
|
// * SyncLink --> makes sure link is up and type wireguard.
|
||||||
// * SyncWireguardDevice --> configures allowedIP & other wireguard specific settings
|
// * SyncWireguardDevice --> configures allowedIP & other wireguard specific settings.
|
||||||
// * SyncAddress --> synces linux addresses bounded to this interface
|
// * SyncAddress --> synces linux addresses bounded to this interface.
|
||||||
// * SyncRoutes --> synces all allowedIP routes to route to this interface
|
// * SyncRoutes --> synces all allowedIP routes to route to this interface.
|
||||||
func Sync(cfg *Config, iface string, logger logrus.FieldLogger) error {
|
func Sync(cfg *Config, iface string, logger logrus.FieldLogger) error {
|
||||||
log := logger.WithField("iface", iface)
|
log := logger.WithField("iface", iface)
|
||||||
|
|
||||||
link, err := SyncLink(cfg, iface, log)
|
link, err := SyncLink(cfg, iface, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Errorln("cannot sync wireguard link")
|
log.WithError(err).Errorln("cannot sync wireguard link")
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("synced link")
|
log.Info("synced link")
|
||||||
|
|
||||||
if err := SyncWireguardDevice(cfg, link, log); err != nil {
|
if err := SyncWireguardDevice(cfg, link, log); err != nil {
|
||||||
log.WithError(err).Errorln("cannot sync wireguard link")
|
log.WithError(err).Errorln("cannot sync wireguard link")
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("synced link")
|
log.Info("synced link")
|
||||||
|
|
||||||
if err := SyncAddress(cfg, link, log); err != nil {
|
if err := SyncAddress(cfg, link, log); err != nil {
|
||||||
log.WithError(err).Errorln("cannot sync addresses")
|
log.WithError(err).Errorln("cannot sync addresses")
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("synced addresss")
|
log.Info("synced addresss")
|
||||||
|
|
||||||
var managedRoutes []net.IPNet
|
managedRoutes := make([]net.IPNet, 0)
|
||||||
|
|
||||||
for _, peer := range cfg.Peers {
|
for _, peer := range cfg.Peers {
|
||||||
for _, rt := range peer.AllowedIPs {
|
managedRoutes = append(managedRoutes, peer.AllowedIPs...)
|
||||||
managedRoutes = append(managedRoutes, rt)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := SyncRoutes(cfg, link, managedRoutes, log); err != nil {
|
if err := SyncRoutes(cfg, link, managedRoutes, log); err != nil {
|
||||||
log.WithError(err).Errorln("cannot sync routes")
|
log.WithError(err).Errorln("cannot sync routes")
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("synced routed")
|
log.Info("synced routed")
|
||||||
log.Info("Successfully synced device")
|
log.Info("Successfully synced device")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncWireguardDevice synces wireguard vpn setting on the given link. It does not set routes/addresses beyond wg internal crypto-key routing, only handles wireguard specific settings
|
// SyncWireguardDevice synces wireguard vpn setting on the given link.
|
||||||
|
// It does not set routes/addresses beyond wg internal crypto-key routing, only handles wireguard specific settings.
|
||||||
func SyncWireguardDevice(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
func SyncWireguardDevice(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
||||||
cl, err := wgctrl.New()
|
cl, err := wgctrl.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Errorln("cannot setup wireguard device")
|
log.WithError(err).Errorln("cannot setup wireguard device")
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := cl.ConfigureDevice(link.Attrs().Name, cfg.Config); err != nil {
|
if err := cl.ConfigureDevice(link.Attrs().Name, cfg.Config); err != nil {
|
||||||
log.WithError(err).Error("cannot configure device")
|
log.WithError(err).Error("cannot configure device")
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncLink synces link state with the config. It does not sync Wireguard settings, just makes sure the device is up and type wireguard
|
// SyncLink synces link state with the config.
|
||||||
|
// It does not sync Wireguard settings, just makes sure the device is up and type wireguard.
|
||||||
func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link, error) {
|
func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link, error) {
|
||||||
link, err := netlink.LinkByName(iface)
|
link, err := netlink.LinkByName(iface)
|
||||||
|
// nolint: nestif
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(netlink.LinkNotFoundError); !ok {
|
if errors.As(err, &netlink.LinkNotFoundError{}) {
|
||||||
log.WithError(err).Error("cannot read link")
|
log.WithError(err).Error("cannot read link")
|
||||||
return nil, err
|
|
||||||
|
return nil, fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("link not found, creating")
|
log.Info("link not found, creating")
|
||||||
|
|
||||||
wgLink := &netlink.GenericLink{
|
wgLink := &netlink.GenericLink{
|
||||||
@ -204,9 +244,11 @@ func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link,
|
|||||||
if err := netlink.LinkAdd(wgLink); err != nil {
|
if err := netlink.LinkAdd(wgLink); err != nil {
|
||||||
log.WithError(err).Errorf("cannot create link: %s", err.Error())
|
log.WithError(err).Errorf("cannot create link: %s", err.Error())
|
||||||
log.Info("trying to use embedded wireguard-go...")
|
log.Info("trying to use embedded wireguard-go...")
|
||||||
|
|
||||||
if err := wgGo(iface); err != nil {
|
if err := wgGo(iface); err != nil {
|
||||||
log.WithError(err).Errorf("cannot create link through wireguard-go: %s", err.Error())
|
log.WithError(err).Errorf("cannot create link through wireguard-go: %s", err.Error())
|
||||||
return nil, fmt.Errorf("cannot create link")
|
|
||||||
|
return nil, fmt.Errorf("cannot create link: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -216,32 +258,41 @@ func SyncLink(cfg *Config, iface string, log logrus.FieldLogger) (netlink.Link,
|
|||||||
link, err = netlink.LinkByName(iface)
|
link, err = netlink.LinkByName(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("cannot read link")
|
log.WithError(err).Error("cannot read link")
|
||||||
return nil, err
|
|
||||||
|
return nil, fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.LinkSetUp(link); err != nil {
|
if err := netlink.LinkSetUp(link); err != nil {
|
||||||
log.WithError(err).Error("cannot set link up")
|
log.WithError(err).Error("cannot set link up")
|
||||||
return nil, err
|
|
||||||
|
return nil, fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("set device up")
|
log.Info("set device up")
|
||||||
|
|
||||||
return link, nil
|
return link, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncAddress adds/deletes all lind assigned IPV4 addressed as specified in the config
|
// SyncAddress adds/deletes all lind assigned IPV4 addressed as specified in the config
|
||||||
|
// nolint: funlen, gosec, scopelint
|
||||||
func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
||||||
addrs, err := netlink.AddrList(link, syscall.AF_INET)
|
addrs, err := netlink.AddrList(link, syscall.AF_INET)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err, "cannot read link address")
|
log.Error(err, "cannot read link address")
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// nil addr means I've used it
|
// nil addr means I've used it
|
||||||
presentAddresses := make(map[string]netlink.Addr, 0)
|
presentAddresses := make(map[string]netlink.Addr)
|
||||||
|
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
log.WithFields(map[string]interface{}{
|
log.WithFields(map[string]interface{}{
|
||||||
"addr": fmt.Sprint(addr.IPNet),
|
"addr": fmt.Sprint(addr.IPNet),
|
||||||
"label": addr.Label,
|
"label": addr.Label,
|
||||||
}).Debugf("found existing address: %v", addr)
|
}).Debugf("found existing address: %v", addr)
|
||||||
|
|
||||||
presentAddresses[addr.IPNet.String()] = addr
|
presentAddresses[addr.IPNet.String()] = addr
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -249,19 +300,24 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
|||||||
log := log.WithField("addr", addr.String())
|
log := log.WithField("addr", addr.String())
|
||||||
_, present := presentAddresses[addr.String()]
|
_, present := presentAddresses[addr.String()]
|
||||||
presentAddresses[addr.String()] = netlink.Addr{} // mark as present
|
presentAddresses[addr.String()] = netlink.Addr{} // mark as present
|
||||||
|
|
||||||
if present {
|
if present {
|
||||||
log.Info("address present")
|
log.Info("address present")
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.AddrAdd(link, &netlink.Addr{
|
if err := netlink.AddrAdd(link, &netlink.Addr{
|
||||||
IPNet: &addr,
|
IPNet: &addr,
|
||||||
Label: cfg.AddressLabel,
|
Label: cfg.AddressLabel,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
if err != syscall.EEXIST {
|
if errors.Is(err, syscall.EEXIST) {
|
||||||
log.WithError(err).Error("cannot add addr")
|
log.WithError(err).Error("cannot add addr")
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("address added")
|
log.Info("address added")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -269,16 +325,21 @@ func SyncAddress(cfg *Config, link netlink.Link, log logrus.FieldLogger) error {
|
|||||||
if addr.IPNet == nil {
|
if addr.IPNet == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log := log.WithFields(map[string]interface{}{
|
log := log.WithFields(map[string]interface{}{
|
||||||
"addr": addr.IPNet.String(),
|
"addr": addr.IPNet.String(),
|
||||||
"label": addr.Label,
|
"label": addr.Label,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := netlink.AddrDel(link, &addr); err != nil {
|
if err := netlink.AddrDel(link, &addr); err != nil {
|
||||||
log.WithError(err).Error("cannot delete addr")
|
log.WithError(err).Error("cannot delete addr")
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("addr deleted")
|
log.Info("addr deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -298,13 +359,17 @@ func fillRouteDefaults(rt *netlink.Route) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config
|
// SyncRoutes adds/deletes all route assigned IPV4 addressed as specified in the config
|
||||||
|
// nolint: funlen, gosec, scopelint
|
||||||
func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error {
|
func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log logrus.FieldLogger) error {
|
||||||
wantedRoutes := make(map[string][]netlink.Route, len(managedRoutes))
|
wantedRoutes := make(map[string][]netlink.Route, len(managedRoutes))
|
||||||
|
|
||||||
presentRoutes, err := netlink.RouteList(link, syscall.AF_INET)
|
presentRoutes, err := netlink.RouteList(link, syscall.AF_INET)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err, "cannot read existing routes")
|
log.Error(err, "cannot read existing routes")
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rt := range managedRoutes {
|
for _, rt := range managedRoutes {
|
||||||
rt := rt // make copy
|
rt := rt // make copy
|
||||||
log.WithField("dst", rt.String()).Debug("managing route")
|
log.WithField("dst", rt.String()).Debug("managing route")
|
||||||
@ -330,10 +395,13 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l
|
|||||||
"type": rt.Type,
|
"type": rt.Type,
|
||||||
"metric": rt.Priority,
|
"metric": rt.Priority,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := netlink.RouteReplace(&rt); err != nil {
|
if err := netlink.RouteReplace(&rt); err != nil {
|
||||||
log.WithError(err).Errorln("cannot add/replace route")
|
log.WithError(err).Errorln("cannot add/replace route")
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infoln("route added/replaced")
|
log.Infoln("route added/replaced")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -344,6 +412,7 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -355,25 +424,31 @@ func SyncRoutes(cfg *Config, link netlink.Link, managedRoutes []net.IPNet, log l
|
|||||||
"type": rt.Type,
|
"type": rt.Type,
|
||||||
"metric": rt.Priority,
|
"metric": rt.Priority,
|
||||||
})
|
})
|
||||||
|
|
||||||
if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == unix.RT_CLASS_MAIN)) {
|
if !(rt.Table == cfg.Table || (cfg.Table == 0 && rt.Table == unix.RT_CLASS_MAIN)) {
|
||||||
log.Debug("wrong table for route, skipping")
|
log.Debug("wrong table for route, skipping")
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(rt.Protocol == cfg.RouteProtocol) {
|
if !(rt.Protocol == cfg.RouteProtocol) {
|
||||||
log.Infof("skipping route deletion, not owned by this daemon")
|
log.Infof("skipping route deletion, not owned by this daemon")
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if checkWanted(rt) {
|
if checkWanted(rt) {
|
||||||
log.Debug("route wanted, skipping deleting")
|
log.Debug("route wanted, skipping deleting")
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.RouteDel(&rt); err != nil {
|
if err := netlink.RouteDel(&rt); err != nil {
|
||||||
log.WithError(err).Error("cannot delete route")
|
log.WithError(err).Error("cannot delete route")
|
||||||
return err
|
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("route deleted")
|
log.Info("route deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user