diff --git a/conn_linux.go b/conn_linux.go index 8d076ac..e30631f 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -15,6 +15,7 @@ package main import ( + "./rwcancel" "errors" "golang.org/x/sys/unix" "net" @@ -55,10 +56,11 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { } type NativeBind struct { - sock4 int - sock6 int - netlinkSock int - lastMark uint32 + sock4 int + sock6 int + netlinkSock int + netlinkCancel *rwcancel.RWCancel + lastMark uint32 } var _ Endpoint = (*NativeEndpoint)(nil) @@ -125,18 +127,23 @@ func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) { if err != nil { return nil, 0, err } + bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock) + if err != nil { + unix.Close(bind.netlinkSock) + return nil, 0, err + } go bind.routineRouteListener(device) bind.sock6, port, err = create6(port) if err != nil { - unix.Close(bind.netlinkSock) + bind.netlinkCancel.Cancel() return nil, port, err } bind.sock4, port, err = create4(port) if err != nil { - unix.Close(bind.netlinkSock) + bind.netlinkCancel.Cancel() unix.Close(bind.sock6) } return &bind, port, err @@ -178,7 +185,8 @@ func closeUnblock(fd int) error { func (bind *NativeBind) Close() error { err1 := closeUnblock(bind.sock6) err2 := closeUnblock(bind.sock4) - err3 := closeUnblock(bind.netlinkSock) + err3 := bind.netlinkCancel.Cancel() + if err1 != nil { return err1 } @@ -539,8 +547,20 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { func (bind *NativeBind) routineRouteListener(device *Device) { var reqPeer map[uint32]*Peer + defer unix.Close(bind.netlinkSock) + for msg := make([]byte, 1<<16); ; { - msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) + var err error + var msgn int + for { + msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) + if err == nil || !rwcancel.ErrorIsEAGAIN(err) { + break + } + if !bind.netlinkCancel.ReadyRead() { + return + } + } if err != nil { return } diff --git a/main.go b/main.go index c9ef343..6e876df 100644 --- a/main.go +++ b/main.go @@ -221,14 +221,10 @@ func main() { return } - // create wireguard device - device := NewDevice(tun, logger) logger.Info.Println("Device started") - // start uapi listener - errs := make(chan error) term := make(chan os.Signal) diff --git a/tun_darwin.go b/tun_darwin.go index ac8bffd..8f9a5d5 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -122,11 +122,13 @@ func CreateTUNFromFile(file *os.File) (TUNDevice, error) { _, err := tun.Name() if err != nil { + tun.fd.Close() return nil, err } tun.rwcancel, err = rwcancel.NewRWCancel(int(file.Fd())) if err != nil { + tun.fd.Close() return nil, err } diff --git a/tun_linux.go b/tun_linux.go index 8e42d44..32bd95d 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -31,14 +31,16 @@ const ( ) type NativeTun struct { - fd *os.File - index int32 // if index - name string // name of interface - errors chan error // async error handling - events chan TUNEvent // device related events - nopi bool // the device was pased IFF_NO_PI - rwcancel *rwcancel.RWCancel - netlinkSock int + fd *os.File + fdCancel *rwcancel.RWCancel + index int32 // if index + name string // name of interface + errors chan error // async error handling + events chan TUNEvent // device related events + nopi bool // the device was pased IFF_NO_PI + netlinkSock int + netlinkCancel *rwcancel.RWCancel + statusListenersShutdown chan struct{} } @@ -86,9 +88,22 @@ func createNetlinkSocket() (int, error) { } func (tun *NativeTun) RoutineNetlinkListener() { + defer unix.Close(tun.netlinkSock) + for msg := make([]byte, 1<<16); ; { - msgn, _, _, _, err := unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0) + var err error + var msgn int + for { + msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0) + if err == nil || !rwcancel.ErrorIsEAGAIN(err) { + break + } + if !tun.netlinkCancel.ReadyRead() { + tun.errors <- fmt.Errorf("netlink socket closed: %s", err.Error()) + return + } + } if err != nil { tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error()) return @@ -323,7 +338,7 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { if err == nil || !rwcancel.ErrorIsEAGAIN(err) { return n, err } - if !tun.rwcancel.ReadyRead() { + if !tun.fdCancel.ReadyRead() { return 0, errors.New("tun device closed") } } @@ -334,10 +349,13 @@ func (tun *NativeTun) Events() chan TUNEvent { } func (tun *NativeTun) Close() error { + var err1 error close(tun.statusListenersShutdown) - err1 := closeUnblock(tun.netlinkSock) + if tun.netlinkCancel != nil { + err1 = tun.netlinkCancel.Cancel() + } err2 := tun.fd.Close() - err3 := tun.rwcancel.Cancel() + err3 := tun.fdCancel.Cancel() close(tun.events) if err1 != nil { @@ -404,13 +422,15 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) { } var err error - tun.rwcancel, err = rwcancel.NewRWCancel(int(fd.Fd())) + tun.fdCancel, err = rwcancel.NewRWCancel(int(fd.Fd())) if err != nil { + tun.fd.Close() return nil, err } _, err = tun.Name() if err != nil { + tun.fd.Close() return nil, err } @@ -423,6 +443,12 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) { tun.netlinkSock, err = createNetlinkSocket() if err != nil { + tun.fd.Close() + return nil, err + } + tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock) + if err != nil { + tun.fd.Close() return nil, err }