1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Ensure go routines can exit

This commit is contained in:
Jason A. Donenfeld 2018-05-14 02:14:33 +02:00
parent 29b0453cf1
commit 659106bd6d
2 changed files with 52 additions and 49 deletions

View File

@ -293,7 +293,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
// prepare signals
device.signals.stop = make(chan struct{}, 1)
device.signals.stop = make(chan struct{}, 0)
// prepare net

View File

@ -38,6 +38,8 @@ type NativeTun struct {
events chan TUNEvent // device related events
nopi bool // the device was pased IFF_NO_PI
rwcancel *rwcancel.RWCancel
netlinkSock int
shutdownHackListener chan struct{}
}
func (tun *NativeTun) File() *os.File {
@ -45,10 +47,6 @@ func (tun *NativeTun) File() *os.File {
}
func (tun *NativeTun) RoutineHackListener() {
// TODO: This function never actually exits in response to anything,
// a go routine that goes forever. We'll want to fix that if this is
// to ever be used as any sort of library.
/* This is needed for the detection to work across network namespaces
* If you are reading this and know a better method, please get in touch.
*/
@ -61,47 +59,38 @@ func (tun *NativeTun) RoutineHackListener() {
case unix.EIO:
tun.events <- TUNEventDown
default:
}
time.Sleep(time.Second / 10)
}
}
func toRTMGRP(sc uint) uint {
return 1 << (sc - 1)
}
func (tun *NativeTun) RoutineNetlinkListener() {
groups := toRTMGRP(unix.RTNLGRP_LINK)
groups |= toRTMGRP(unix.RTNLGRP_IPV4_IFADDR)
groups |= toRTMGRP(unix.RTNLGRP_IPV6_IFADDR)
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil {
tun.errors <- errors.New("Failed to create netlink event listener socket")
return
}
defer unix.Close(sock)
select {
case <-time.After(time.Second / 10):
case <-tun.shutdownHackListener:
return
}
}
}
func createNetlinkSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32(groups),
Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))),
}
err = unix.Bind(sock, saddr)
if err != nil {
tun.errors <- errors.New("Failed to bind netlink event listener socket")
return
return -1, err
}
return sock, nil
}
// TODO: This function never actually exits in response to anything,
// a go routine that goes forever. We'll want to fix that if this is
// to ever be used as any sort of library. See what we've done with
// calling shutdown() on the netlink socket in conn_linux.go, and
// change this to be more like that.
func (tun *NativeTun) RoutineNetlinkListener() {
for msg := make([]byte, 1<<16); ; {
msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0)
msgn, _, _, _, err := unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
if err != nil {
tun.errors <- fmt.Errorf("Failed to receive netlink message: %s", err.Error())
tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error())
return
}
@ -339,13 +328,16 @@ func (tun *NativeTun) Events() chan TUNEvent {
}
func (tun *NativeTun) Close() error {
err := tun.fd.Close()
if err != nil {
return err
}
err1 := tun.fd.Close()
err2 := closeUnblock(tun.netlinkSock)
tun.rwcancel.Cancel()
close(tun.events)
return nil
close(tun.shutdownHackListener)
if err1 != nil {
return err1
}
return err2
}
func CreateTUN(name string) (TUNDevice, error) {
@ -375,7 +367,7 @@ func CreateTUN(name string) (TUNDevice, error) {
var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
nameBytes := []byte(name)
if len(nameBytes) >= unix.IFNAMSIZ {
return nil, errors.New("Interface name too long")
return nil, errors.New("interface name too long")
}
copy(ifr[:], nameBytes)
binary.LittleEndian.PutUint16(ifr[16:], flags)
@ -398,6 +390,7 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
fd: fd,
events: make(chan TUNEvent, 5),
errors: make(chan error, 5),
shutdownHackListener: make(chan struct{}, 0),
nopi: false,
}
var err error
@ -419,10 +412,20 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
return nil, err
}
// set default MTU
err = device.setMTU(DefaultMTU)
if err != nil {
return nil, err
}
device.netlinkSock, err = createNetlinkSocket()
if err != nil {
return nil, err
}
go device.RoutineNetlinkListener()
go device.RoutineHackListener() // cross namespace
// set default MTU
return device, device.setMTU(DefaultMTU)
return device, nil
}