// nolint: gochecknoglobals, gomnd, goerr113 package wgquick import ( "bytes" "encoding" "encoding/base64" "fmt" "net" "strconv" "strings" "text/template" "time" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) // Config represents full wg-quick like config structure. type Config struct { wgtypes.Config // Address list of IP (v4 or v6) addresses (optionally with CIDR masks) to be assigned // to the interface. May be specified multiple times. Address []net.IPNet // list of IP (v4 or v6) addresses to be set as the interface’s DNS servers. May be specified multiple times. // Upon bringing the interface up, this runs ‘resolvconf -a tun.INTERFACE -m 0 -x‘ and upon bringing it down, // this runs ‘resolvconf -d tun.INTERFACE‘. If these particular invocations of resolvconf(8) are undesirable, // the PostUp and PostDown keys below may be used instead. DNS []net.IP // MTU is automatically determined from the endpoint addresses or the system default route, // which is usually a sane choice. However, to manually specify an MTU to override this automatic discovery, // this value may be specified explicitly. MTU int // Table — Controls the routing table to which routes are added. Table int // PreUp, PostUp, PreDown, PostDown — script snippets which will be executed by bash(1) // before/after setting up/tearing down the interface, most commonly used to configure // custom DNS options or firewall rules. The special string ‘%i’ is expanded to INTERFACE. // Each one may be specified multiple times, in which case the commands are executed in order. PreUp string PostUp string PreDown string PostDown string // RouteProtocol to set on the route. See linux/rtnetlink.h Use value > 4 or default 0 RouteProtocol int // RouteMetric sets this metric on all managed routes. Lower number means pick this one RouteMetric int // Address label to set on the link AddressLabel string // SaveConfig — if set to ‘true’, the configuration is saved from the current state of the interface upon shutdown. // Currently unsupported SaveConfig bool } var ( _ encoding.TextMarshaler = (*Config)(nil) _ encoding.TextUnmarshaler = (*Config)(nil) ) func (cfg *Config) String() string { b, err := cfg.MarshalText() if err != nil { panic(err) } return string(b) } func serializeKey(key *wgtypes.Key) string { return base64.StdEncoding.EncodeToString(key[:]) } func toSeconds(duration time.Duration) int { return int(duration / time.Second) } var funcMap = template.FuncMap(map[string]interface{}{ "wgKey": serializeKey, "toSeconds": toSeconds, }) var cfgTemplate = template.Must( template. New("wg-cfg"). Funcs(funcMap). Parse(wgtypeTemplateSpec)) func (cfg *Config) MarshalText() (text []byte, err error) { buff := &bytes.Buffer{} if err := cfgTemplate.Execute(buff, cfg); err != nil { return nil, fmt.Errorf("%w", err) } return buff.Bytes(), nil } // nolint: lll const wgtypeTemplateSpec = `[Interface] {{- range .Address }} Address = {{ . }} {{- end }} {{- range .DNS }} DNS = {{ . }} {{- end }} PrivateKey = {{ .PrivateKey | wgKey }} {{- if .ListenPort }}{{ "\n" }}ListenPort = {{ .ListenPort }}{{ end }} {{- if .MTU }}{{ "\n" }}MTU = {{ .MTU }}{{ end }} {{- if .Table }}{{ "\n" }}Table = {{ .Table }}{{ end }} {{- if .PreUp }}{{ "\n" }}PreUp = {{ .PreUp }}{{ end }} {{- if .PostUp }}{{ "\n" }}PostUp = {{ .PostUp }}{{ end }} {{- if .PreDown }}{{ "\n" }}PreDown = {{ .PreDown }}{{ end }} {{- if .PostDown }}{{ "\n" }}PostDown = {{ .PostDown }}{{ end }} {{- if .SaveConfig }}{{ "\n" }}SaveConfig = {{ .SaveConfig }}{{ end }} {{- range .Peers }} {{- "\n" }} [Peer] PublicKey = {{ .PublicKey | wgKey }} AllowedIPs = {{ range $i, $el := .AllowedIPs }}{{if $i}}, {{ end }}{{ $el }}{{ end }} {{- if .PresharedKey }}{{ "\n" }}PresharedKey = {{ .PresharedKey }}{{ end }} {{- if .PersistentKeepaliveInterval }}{{ "\n" }}PersistentKeepalive = {{ .PersistentKeepaliveInterval | toSeconds }}{{ end }} {{- if .Endpoint }}{{ "\n" }}Endpoint = {{ .Endpoint }}{{ end }} {{- end }} ` // ParseKey parses the base64 encoded wireguard private key. func ParseKey(key string) (wgtypes.Key, error) { var pkey wgtypes.Key pkeySlice, err := base64.StdEncoding.DecodeString(key) if err != nil { return pkey, fmt.Errorf("%w", err) } copy(pkey[:], pkeySlice) return pkey, nil } type parseState int const ( unknown parseState = iota inter = iota peer = iota ) func (cfg *Config) UnmarshalText(text []byte) error { *cfg = Config{} // Zero out the config state := unknown var peerCfg *wgtypes.PeerConfig for no, line := range strings.Split(string(text), "\n") { ln := strings.TrimSpace(line) if len(ln) == 0 || ln[0] == '#' { continue } switch ln { case "[Interface]": state = inter case "[Peer]": state = peer cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{}) peerCfg = &cfg.Peers[len(cfg.Peers)-1] default: parts := strings.Split(ln, "=") if len(parts) < 2 { return fmt.Errorf("cannot parse line %d, missing =", no) } lhs := strings.TrimSpace(parts[0]) rhs := strings.TrimSpace(strings.Join(parts[1:], "=")) switch state { case inter: if err := parseInterfaceLine(cfg, lhs, rhs); err != nil { return fmt.Errorf("[line %d]: %w", no+1, err) } case peer: if err := parsePeerLine(peerCfg, lhs, rhs); err != nil { return fmt.Errorf("[line %d]: %w", no+1, err) } case unknown: return fmt.Errorf("[line %d] cannot parse, unknown state", no+1) default: return fmt.Errorf("switch couldnt find state") } } } return nil } // nolint: funlen func parseInterfaceLine(cfg *Config, lhs string, rhs string) error { switch lhs { case "Address": for _, addr := range strings.Split(rhs, ",") { ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr)) if err != nil { return fmt.Errorf("%w", err) } cfg.Address = append(cfg.Address, net.IPNet{IP: ip, Mask: cidr.Mask}) } case "DNS": for _, addr := range strings.Split(rhs, ",") { ip := net.ParseIP(strings.TrimSpace(addr)) if ip == nil { return fmt.Errorf("cannot parse IP") } cfg.DNS = append(cfg.DNS, ip) } case "MTU": mtu, err := strconv.ParseInt(rhs, 10, 64) if err != nil { return fmt.Errorf("%w", err) } cfg.MTU = int(mtu) case "Table": tbl, err := strconv.ParseInt(rhs, 10, 64) if err != nil { return fmt.Errorf("%w", err) } cfg.Table = int(tbl) case "ListenPort": portI64, err := strconv.ParseInt(rhs, 10, 64) if err != nil { return fmt.Errorf("%w", err) } port := int(portI64) cfg.ListenPort = &port case "PreUp": cfg.PreUp = rhs case "PostUp": cfg.PostUp = rhs case "PreDown": cfg.PreDown = rhs case "PostDown": cfg.PostDown = rhs case "SaveConfig": save, err := strconv.ParseBool(rhs) if err != nil { return fmt.Errorf("%w", err) } cfg.SaveConfig = save case "PrivateKey": key, err := ParseKey(rhs) if err != nil { return fmt.Errorf("cannot decode key %w", err) } cfg.PrivateKey = &key default: return fmt.Errorf("unknown directive %s", lhs) } return nil } func parsePeerLine(peerCfg *wgtypes.PeerConfig, lhs string, rhs string) error { switch lhs { case "PublicKey": key, err := ParseKey(rhs) if err != nil { return fmt.Errorf("cannot decode key %w", err) } peerCfg.PublicKey = key case "PresharedKey": key, err := ParseKey(rhs) if err != nil { return fmt.Errorf("cannot decode key %w", err) } if peerCfg.PresharedKey != nil { return fmt.Errorf("preshared key already defined %w", err) } peerCfg.PresharedKey = &key case "AllowedIPs": for _, addr := range strings.Split(rhs, ",") { ip, cidr, err := net.ParseCIDR(strings.TrimSpace(addr)) if err != nil { return fmt.Errorf("cannot parse %s: %w", addr, err) } peerCfg.AllowedIPs = append(peerCfg.AllowedIPs, net.IPNet{IP: ip, Mask: cidr.Mask}) } case "Endpoint": addr, err := net.ResolveUDPAddr("", rhs) if err != nil { return fmt.Errorf("%w", err) } peerCfg.Endpoint = addr case "PersistentKeepalive": t, err := strconv.ParseInt(rhs, 10, 64) if err != nil { return fmt.Errorf("%w", err) } dur := time.Duration(t * int64(time.Second)) peerCfg.PersistentKeepaliveInterval = &dur default: return fmt.Errorf("unknown directive %s", lhs) } return nil }