diff --git a/cmd/root.go b/cmd/root.go index 7fdecab..fe1d216 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,6 +5,7 @@ import ( "fmt" "io/ioutil" "os" + "strings" "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -18,6 +19,7 @@ var ( ) var ( + cfgDir string iface string verbose bool protocol int @@ -112,7 +114,7 @@ func loadConfig(cfg string) (*wgquick.Config, logrus.FieldLogger) { log = logrus.WithField("iface", iface) } - cfg = "/etc/wireguard/" + cfg + ".conf" + cfg = fmt.Sprintf("%s/%s.conf", cfgDir, cfg) _, err = os.Stat(cfg) if err != nil { @@ -140,6 +142,9 @@ func loadConfig(cfg string) (*wgquick.Config, logrus.FieldLogger) { } func init() { + cobra.OnInitialize(initConfig) + + rootCmd.PersistentFlags().StringVar(&cfgDir, "config-dir", "", "config directory (default is /etc/wireguard)") rootCmd.PersistentFlags().StringVarP(&iface, "iface", "i", "", "if interface name should differ from config") rootCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "verbose") rootCmd.PersistentFlags().IntVarP(&protocol, "route-protocol", "p", 0, "route protocol to use for our routes") @@ -157,6 +162,14 @@ func init() { rootCmd.AddCommand(showCmd) } +func initConfig() { + if cfgDir != "" { + cfgDir = "/etc/wireguard" + } + + cfgDir = strings.TrimSuffix(cfgDir, "/") +} + func Execute() { if err := rootCmd.Execute(); err != nil { logrus.Fatal(err)