1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

uapi: linux: put sock files in netns-specific subdir

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2020-05-01 23:39:43 -06:00
parent eb897a7dd8
commit 9e1d4865cb

View File

@ -13,6 +13,7 @@ import (
"path" "path"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/rwcancel"
) )
@ -23,7 +24,8 @@ const (
IpcErrorProtocol = -int64(unix.EPROTO) IpcErrorProtocol = -int64(unix.EPROTO)
IpcErrorInvalid = -int64(unix.EINVAL) IpcErrorInvalid = -int64(unix.EINVAL)
IpcErrorPortInUse = -int64(unix.EADDRINUSE) IpcErrorPortInUse = -int64(unix.EADDRINUSE)
socketName = "%s.sock" socketNameFmt = "%s.sock"
netnsFmt = "netns-%d"
) )
type UAPIListener struct { type UAPIListener struct {
@ -63,6 +65,29 @@ func (l *UAPIListener) Addr() net.Addr {
return l.listener.Addr() return l.listener.Addr()
} }
func currentNetns() (netns uint32, err error) {
link, err := os.Readlink("/proc/self/ns/net")
if err != nil {
return
}
_, err = fmt.Sscanf(link, "net:[%d]", &netns)
return
}
func ifaceSocketPath(iface string) string {
if netns, err := currentNetns(); err == nil {
return path.Join(
socketDirectory,
fmt.Sprintf(netnsFmt, netns),
fmt.Sprintf(socketNameFmt, iface),
)
}
return path.Join(
socketDirectory,
fmt.Sprintf(socketNameFmt, iface),
)
}
func UAPIListen(name string, file *os.File) (net.Listener, error) { func UAPIListen(name string, file *os.File) (net.Listener, error) {
// wrap file in listener // wrap file in listener
@ -82,12 +107,9 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
connErr: make(chan error, 1), connErr: make(chan error, 1),
} }
// watch for deletion of socket socketPath := ifaceSocketPath(name)
socketPath := path.Join( // watch for deletion of socket
socketDirectory,
fmt.Sprintf(socketName, name),
)
uapi.inotifyFd, err = unix.InotifyInit() uapi.inotifyFd, err = unix.InotifyInit()
if err != nil { if err != nil {
@ -145,21 +167,17 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
} }
func UAPIOpen(name string) (*os.File, error) { func UAPIOpen(name string) (*os.File, error) {
socketPath := ifaceSocketPath(name)
// check if path exist // check if path exist
err := os.MkdirAll(socketDirectory, 0755) err := os.MkdirAll(path.Dir(socketPath), 0755)
if err != nil && !os.IsExist(err) { if err != nil && !os.IsExist(err) {
return nil, err return nil, err
} }
// open UNIX socket // open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
addr, err := net.ResolveUnixAddr("unix", socketPath) addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil { if err != nil {
return nil, err return nil, err