mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
Allows passing UAPI fd to service
This commit is contained in:
parent
88801529fd
commit
e1227d3af4
59
src/main.go
59
src/main.go
@ -9,7 +9,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EnvWGTunFD = "WG_TUN_FD"
|
ENV_WG_TUN_FD = "WG_TUN_FD"
|
||||||
|
ENV_WG_UAPI_FD = "WG_UAPI_FD"
|
||||||
)
|
)
|
||||||
|
|
||||||
func printUsage() {
|
func printUsage() {
|
||||||
@ -65,46 +66,69 @@ func main() {
|
|||||||
logLevel,
|
logLevel,
|
||||||
fmt.Sprintf("(%s) ", interfaceName),
|
fmt.Sprintf("(%s) ", interfaceName),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.Debug.Println("Debug log enabled")
|
logger.Debug.Println("Debug log enabled")
|
||||||
|
|
||||||
// open TUN device
|
// open TUN device (or use supplied fd)
|
||||||
|
|
||||||
tun, err := func() (TUNDevice, error) {
|
tun, err := func() (TUNDevice, error) {
|
||||||
tunFdStr := os.Getenv(EnvWGTunFD)
|
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||||
if tunFdStr == "" {
|
if tunFdStr == "" {
|
||||||
return CreateTUN(interfaceName)
|
return CreateTUN(interfaceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// construct tun device from supplied FD
|
// construct tun device from supplied fd
|
||||||
|
|
||||||
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
file := os.NewFile(uintptr(fd), "")
|
||||||
return CreateTUNFromFile(interfaceName, file)
|
return CreateTUNFromFile(interfaceName, file)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Println("Failed to create TUN device:", err)
|
logger.Error.Println("Failed to create TUN device:", err)
|
||||||
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// open UAPI file (or use supplied fd)
|
||||||
|
|
||||||
|
fileUAPI, err := func() (*os.File, error) {
|
||||||
|
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
|
||||||
|
if uapiFdStr == "" {
|
||||||
|
return UAPIOpen(interfaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// use supplied fd
|
||||||
|
|
||||||
|
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.NewFile(uintptr(fd), ""), nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error.Println("UAPI listen error:", err)
|
||||||
|
os.Exit(ExitSetupFailed)
|
||||||
|
return
|
||||||
|
}
|
||||||
// daemonize the process
|
// daemonize the process
|
||||||
|
|
||||||
if !foreground {
|
if !foreground {
|
||||||
env := os.Environ()
|
env := os.Environ()
|
||||||
_, ok := os.LookupEnv(EnvWGTunFD)
|
env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
|
||||||
if !ok {
|
env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
|
||||||
kvp := fmt.Sprintf("%s=3", EnvWGTunFD)
|
|
||||||
env = append(env, kvp)
|
|
||||||
}
|
|
||||||
attr := &os.ProcAttr{
|
attr := &os.ProcAttr{
|
||||||
Files: []*os.File{
|
Files: []*os.File{
|
||||||
nil, // stdin
|
nil, // stdin
|
||||||
nil, // stdout
|
nil, // stdout
|
||||||
nil, // stderr
|
nil, // stderr
|
||||||
tun.File(),
|
tun.File(),
|
||||||
|
fileUAPI,
|
||||||
},
|
},
|
||||||
Dir: ".",
|
Dir: ".",
|
||||||
Env: env,
|
Env: env,
|
||||||
@ -112,6 +136,7 @@ func main() {
|
|||||||
err = Daemonize(attr)
|
err = Daemonize(attr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Println("Failed to daemonize:", err)
|
logger.Error.Println("Failed to daemonize:", err)
|
||||||
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -123,20 +148,17 @@ func main() {
|
|||||||
// create wireguard device
|
// create wireguard device
|
||||||
|
|
||||||
device := NewDevice(tun, logger)
|
device := NewDevice(tun, logger)
|
||||||
|
|
||||||
logger.Info.Println("Device started")
|
logger.Info.Println("Device started")
|
||||||
|
|
||||||
// start configuration lister
|
// start uapi listener
|
||||||
|
|
||||||
uapi, err := NewUAPIListener(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error.Println("UAPI listen error:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
errs := make(chan error)
|
errs := make(chan error)
|
||||||
term := make(chan os.Signal)
|
term := make(chan os.Signal)
|
||||||
wait := device.WaitChannel()
|
wait := device.WaitChannel()
|
||||||
|
|
||||||
|
uapi, err := UAPIListen(interfaceName, fileUAPI)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
conn, err := uapi.Accept()
|
conn, err := uapi.Accept()
|
||||||
@ -161,9 +183,10 @@ func main() {
|
|||||||
case <-errs:
|
case <-errs:
|
||||||
}
|
}
|
||||||
|
|
||||||
// clean up UAPI bind
|
// clean up
|
||||||
|
|
||||||
uapi.Close()
|
uapi.Close()
|
||||||
|
device.Close()
|
||||||
|
|
||||||
logger.Info.Println("Shutting down")
|
logger.Info.Println("Shutting down")
|
||||||
}
|
}
|
||||||
|
@ -227,7 +227,7 @@ func (tun *NativeTun) MTU() (int, error) {
|
|||||||
|
|
||||||
val := binary.LittleEndian.Uint32(ifr[16:20])
|
val := binary.LittleEndian.Uint32(ifr[16:20])
|
||||||
if val >= (1 << 31) {
|
if val >= (1 << 31) {
|
||||||
return int(val-(1<<31)) - (1 << 31), nil
|
return int(toInt32(val)), nil
|
||||||
}
|
}
|
||||||
return int(val), nil
|
return int(val), nil
|
||||||
}
|
}
|
||||||
|
@ -10,12 +10,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ipcErrorIO = -int64(unix.EIO)
|
ipcErrorIO = -int64(unix.EIO)
|
||||||
ipcErrorProtocol = -int64(unix.EPROTO)
|
ipcErrorProtocol = -int64(unix.EPROTO)
|
||||||
ipcErrorInvalid = -int64(unix.EINVAL)
|
ipcErrorInvalid = -int64(unix.EINVAL)
|
||||||
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
|
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
|
||||||
socketDirectory = "/var/run/wireguard"
|
socketDirectory = "/var/run/wireguard"
|
||||||
socketName = "%s.sock"
|
socketName = "%s.sock"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UAPIListener struct {
|
type UAPIListener struct {
|
||||||
@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func connectUnixSocket(path string) (net.Listener, error) {
|
func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||||
|
|
||||||
// attempt inital connection
|
// wrap file in listener
|
||||||
|
|
||||||
listener, err := net.Listen("unix", path)
|
listener, err := net.FileListener(file)
|
||||||
if err == nil {
|
|
||||||
return listener, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if active
|
|
||||||
|
|
||||||
_, err = net.Dial("unix", path)
|
|
||||||
if err == nil {
|
|
||||||
return nil, errors.New("Unix socket in use")
|
|
||||||
}
|
|
||||||
|
|
||||||
// attempt cleanup
|
|
||||||
|
|
||||||
err = os.Remove(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return net.Listen("unix", path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUAPIListener(name string) (net.Listener, error) {
|
|
||||||
|
|
||||||
// check if path exist
|
|
||||||
|
|
||||||
err := os.MkdirAll(socketDirectory, 077)
|
|
||||||
if err != nil && !os.IsExist(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open UNIX socket
|
|
||||||
|
|
||||||
socketPath := path.Join(
|
|
||||||
socketDirectory,
|
|
||||||
fmt.Sprintf(socketName, name),
|
|
||||||
)
|
|
||||||
|
|
||||||
listener, err := connectUnixSocket(socketPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
|
|||||||
|
|
||||||
// watch for deletion of socket
|
// watch for deletion of socket
|
||||||
|
|
||||||
|
socketPath := path.Join(
|
||||||
|
socketDirectory,
|
||||||
|
fmt.Sprintf(socketName, name),
|
||||||
|
)
|
||||||
|
|
||||||
uapi.inotifyFd, err = unix.InotifyInit()
|
uapi.inotifyFd, err = unix.InotifyInit()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
|
|||||||
go func(l *UAPIListener) {
|
go func(l *UAPIListener) {
|
||||||
var buff [4096]byte
|
var buff [4096]byte
|
||||||
for {
|
for {
|
||||||
unix.Read(uapi.inotifyFd, buff[:])
|
// start with lstat to avoid race condition
|
||||||
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
|
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
|
||||||
l.connErr <- err
|
l.connErr <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
unix.Read(uapi.inotifyFd, buff[:])
|
||||||
}
|
}
|
||||||
}(uapi)
|
}(uapi)
|
||||||
|
|
||||||
@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
|
|||||||
|
|
||||||
return uapi, nil
|
return uapi, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UAPIOpen(name string) (*os.File, error) {
|
||||||
|
|
||||||
|
// check if path exist
|
||||||
|
|
||||||
|
err := os.MkdirAll(socketDirectory, 0600)
|
||||||
|
if err != nil && !os.IsExist(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// open UNIX socket
|
||||||
|
|
||||||
|
socketPath := path.Join(
|
||||||
|
socketDirectory,
|
||||||
|
fmt.Sprintf(socketName, name),
|
||||||
|
)
|
||||||
|
|
||||||
|
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := func() (*net.UnixListener, error) {
|
||||||
|
|
||||||
|
// initial connection attempt
|
||||||
|
|
||||||
|
listener, err := net.ListenUnix("unix", addr)
|
||||||
|
if err == nil {
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if socket already active
|
||||||
|
|
||||||
|
_, err = net.Dial("unix", socketPath)
|
||||||
|
if err == nil {
|
||||||
|
return nil, errors.New("unix socket in use")
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup & attempt again
|
||||||
|
|
||||||
|
err = os.Remove(socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return net.ListenUnix("unix", addr)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return listener.File()
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user