/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ package ipc import ( "errors" "fmt" "net" "os" "path" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/rwcancel" ) var socketDirectory = "/var/run/wireguard" const ( IpcErrorIO = -int64(unix.EIO) IpcErrorProtocol = -int64(unix.EPROTO) IpcErrorInvalid = -int64(unix.EINVAL) IpcErrorPortInUse = -int64(unix.EADDRINUSE) socketNameFmt = "%s.sock" netnsFmt = "netns-%d" ) type UAPIListener struct { listener net.Listener // unix socket listener connNew chan net.Conn connErr chan error inotifyFd int inotifyRWCancel *rwcancel.RWCancel } func (l *UAPIListener) Accept() (net.Conn, error) { for { select { case conn := <-l.connNew: return conn, nil case err := <-l.connErr: return nil, err } } } func (l *UAPIListener) Close() error { err1 := unix.Close(l.inotifyFd) err2 := l.inotifyRWCancel.Cancel() err3 := l.listener.Close() if err1 != nil { return err1 } if err2 != nil { return err2 } return err3 } func (l *UAPIListener) Addr() net.Addr { return l.listener.Addr() } func currentNetns() (netns uint32, err error) { link, err := os.Readlink("/proc/self/ns/net") if err != nil { return } _, err = fmt.Sscanf(link, "net:[%d]", &netns) return } func ifaceSocketPath(iface string) string { if netns, err := currentNetns(); err == nil { return path.Join( socketDirectory, fmt.Sprintf(netnsFmt, netns), fmt.Sprintf(socketNameFmt, iface), ) } return path.Join( socketDirectory, fmt.Sprintf(socketNameFmt, iface), ) } func UAPIListen(name string, file *os.File) (net.Listener, error) { // wrap file in listener listener, err := net.FileListener(file) if err != nil { return nil, err } if unixListener, ok := listener.(*net.UnixListener); ok { unixListener.SetUnlinkOnClose(true) } uapi := &UAPIListener{ listener: listener, connNew: make(chan net.Conn, 1), connErr: make(chan error, 1), } socketPath := ifaceSocketPath(name) // watch for deletion of socket uapi.inotifyFd, err = unix.InotifyInit() if err != nil { return nil, err } _, err = unix.InotifyAddWatch( uapi.inotifyFd, socketPath, unix.IN_ATTRIB| unix.IN_DELETE| unix.IN_DELETE_SELF, ) if err != nil { return nil, err } uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd) if err != nil { unix.Close(uapi.inotifyFd) return nil, err } go func(l *UAPIListener) { var buff [0]byte for { // start with lstat to avoid race condition if _, err := os.Lstat(socketPath); os.IsNotExist(err) { l.connErr <- err return } _, err := uapi.inotifyRWCancel.Read(buff[:]) if err != nil { l.connErr <- err return } } }(uapi) // watch for new connections go func(l *UAPIListener) { for { conn, err := l.listener.Accept() if err != nil { l.connErr <- err break } l.connNew <- conn } }(uapi) return uapi, nil } func UAPIOpen(name string) (*os.File, error) { socketPath := ifaceSocketPath(name) // check if path exist err := os.MkdirAll(path.Dir(socketPath), 0755) if err != nil && !os.IsExist(err) { return nil, err } // open UNIX socket addr, err := net.ResolveUnixAddr("unix", socketPath) if err != nil { return nil, err } oldUmask := unix.Umask(0077) listener, err := func() (*net.UnixListener, error) { // initial connection attempt listener, err := net.ListenUnix("unix", addr) if err == nil { return listener, nil } // check if socket already active _, err = net.Dial("unix", socketPath) if err == nil { return nil, errors.New("unix socket in use") } // cleanup & attempt again err = os.Remove(socketPath) if err != nil { return nil, err } return net.ListenUnix("unix", addr) }() unix.Umask(oldUmask) if err != nil { return nil, err } return listener.File() }