1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Ensure go routines can exit

This commit is contained in:
Jason A. Donenfeld 2018-05-14 02:14:33 +02:00
parent 29b0453cf1
commit 659106bd6d
2 changed files with 52 additions and 49 deletions

View File

@ -293,7 +293,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
// prepare signals // prepare signals
device.signals.stop = make(chan struct{}, 1) device.signals.stop = make(chan struct{}, 0)
// prepare net // prepare net

View File

@ -31,13 +31,15 @@ const (
) )
type NativeTun struct { type NativeTun struct {
fd *os.File fd *os.File
index int32 // if index index int32 // if index
name string // name of interface name string // name of interface
errors chan error // async error handling errors chan error // async error handling
events chan TUNEvent // device related events events chan TUNEvent // device related events
nopi bool // the device was pased IFF_NO_PI nopi bool // the device was pased IFF_NO_PI
rwcancel *rwcancel.RWCancel rwcancel *rwcancel.RWCancel
netlinkSock int
shutdownHackListener chan struct{}
} }
func (tun *NativeTun) File() *os.File { func (tun *NativeTun) File() *os.File {
@ -45,10 +47,6 @@ func (tun *NativeTun) File() *os.File {
} }
func (tun *NativeTun) RoutineHackListener() { func (tun *NativeTun) RoutineHackListener() {
// TODO: This function never actually exits in response to anything,
// a go routine that goes forever. We'll want to fix that if this is
// to ever be used as any sort of library.
/* This is needed for the detection to work across network namespaces /* This is needed for the detection to work across network namespaces
* If you are reading this and know a better method, please get in touch. * If you are reading this and know a better method, please get in touch.
*/ */
@ -61,47 +59,38 @@ func (tun *NativeTun) RoutineHackListener() {
case unix.EIO: case unix.EIO:
tun.events <- TUNEventDown tun.events <- TUNEventDown
default: default:
return
}
select {
case <-time.After(time.Second / 10):
case <-tun.shutdownHackListener:
return
} }
time.Sleep(time.Second / 10)
} }
} }
func toRTMGRP(sc uint) uint { func createNetlinkSocket() (int, error) {
return 1 << (sc - 1)
}
func (tun *NativeTun) RoutineNetlinkListener() {
groups := toRTMGRP(unix.RTNLGRP_LINK)
groups |= toRTMGRP(unix.RTNLGRP_IPV4_IFADDR)
groups |= toRTMGRP(unix.RTNLGRP_IPV6_IFADDR)
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil { if err != nil {
tun.errors <- errors.New("Failed to create netlink event listener socket") return -1, err
return
} }
defer unix.Close(sock)
saddr := &unix.SockaddrNetlink{ saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK, Family: unix.AF_NETLINK,
Groups: uint32(groups), Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))),
} }
err = unix.Bind(sock, saddr) err = unix.Bind(sock, saddr)
if err != nil { if err != nil {
tun.errors <- errors.New("Failed to bind netlink event listener socket") return -1, err
return
} }
return sock, nil
}
// TODO: This function never actually exits in response to anything, func (tun *NativeTun) RoutineNetlinkListener() {
// a go routine that goes forever. We'll want to fix that if this is
// to ever be used as any sort of library. See what we've done with
// calling shutdown() on the netlink socket in conn_linux.go, and
// change this to be more like that.
for msg := make([]byte, 1<<16); ; { for msg := make([]byte, 1<<16); ; {
msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0) msgn, _, _, _, err := unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
if err != nil { if err != nil {
tun.errors <- fmt.Errorf("Failed to receive netlink message: %s", err.Error()) tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error())
return return
} }
@ -339,13 +328,16 @@ func (tun *NativeTun) Events() chan TUNEvent {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
err := tun.fd.Close() err1 := tun.fd.Close()
if err != nil { err2 := closeUnblock(tun.netlinkSock)
return err
}
tun.rwcancel.Cancel() tun.rwcancel.Cancel()
close(tun.events) close(tun.events)
return nil close(tun.shutdownHackListener)
if err1 != nil {
return err1
}
return err2
} }
func CreateTUN(name string) (TUNDevice, error) { func CreateTUN(name string) (TUNDevice, error) {
@ -375,7 +367,7 @@ func CreateTUN(name string) (TUNDevice, error) {
var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
nameBytes := []byte(name) nameBytes := []byte(name)
if len(nameBytes) >= unix.IFNAMSIZ { if len(nameBytes) >= unix.IFNAMSIZ {
return nil, errors.New("Interface name too long") return nil, errors.New("interface name too long")
} }
copy(ifr[:], nameBytes) copy(ifr[:], nameBytes)
binary.LittleEndian.PutUint16(ifr[16:], flags) binary.LittleEndian.PutUint16(ifr[16:], flags)
@ -395,10 +387,11 @@ func CreateTUN(name string) (TUNDevice, error) {
func CreateTUNFromFile(fd *os.File) (TUNDevice, error) { func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
device := &NativeTun{ device := &NativeTun{
fd: fd, fd: fd,
events: make(chan TUNEvent, 5), events: make(chan TUNEvent, 5),
errors: make(chan error, 5), errors: make(chan error, 5),
nopi: false, shutdownHackListener: make(chan struct{}, 0),
nopi: false,
} }
var err error var err error
@ -419,10 +412,20 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
return nil, err return nil, err
} }
// set default MTU
err = device.setMTU(DefaultMTU)
if err != nil {
return nil, err
}
device.netlinkSock, err = createNetlinkSocket()
if err != nil {
return nil, err
}
go device.RoutineNetlinkListener() go device.RoutineNetlinkListener()
go device.RoutineHackListener() // cross namespace go device.RoutineHackListener() // cross namespace
// set default MTU return device, nil
return device, device.setMTU(DefaultMTU)
} }