diff --git a/conn/conn_linux.go b/conn/conn_linux.go index ef98100..ef5c0ba 100644 --- a/conn/conn_linux.go +++ b/conn/conn_linux.go @@ -18,10 +18,6 @@ import ( "golang.org/x/sys/unix" ) -const ( - FD_ERR = -1 -) - type IPv4Source struct { Src [4]byte Ifindex int32 @@ -63,6 +59,7 @@ type nativeBind struct { sock4 int sock6 int lastMark uint32 + closing sync.RWMutex } var _ Endpoint = (*NativeEndpoint)(nil) @@ -129,7 +126,7 @@ func createBind(port uint16) (Bind, uint16, error) { port = newPort } - if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR { + if bind.sock4 == -1 && bind.sock6 == -1 { return nil, 0, errors.New("ipv4 and ipv6 not supported") } @@ -141,6 +138,9 @@ func (bind *nativeBind) LastMark() uint32 { } func (bind *nativeBind) SetMark(value uint32) error { + bind.closing.RLock() + defer bind.closing.RUnlock() + if bind.sock6 != -1 { err := unix.SetsockoptInt( bind.sock6, @@ -171,20 +171,26 @@ func (bind *nativeBind) SetMark(value uint32) error { return nil } -func closeUnblock(fd int) error { - // shutdown to unblock readers and writers - unix.Shutdown(fd, unix.SHUT_RDWR) - return unix.Close(fd) -} - func (bind *nativeBind) Close() error { var err1, err2 error + bind.closing.RLock() if bind.sock6 != -1 { - err1 = closeUnblock(bind.sock6) + unix.Shutdown(bind.sock6, unix.SHUT_RDWR) } if bind.sock4 != -1 { - err2 = closeUnblock(bind.sock4) + unix.Shutdown(bind.sock4, unix.SHUT_RDWR) } + bind.closing.RUnlock() + bind.closing.Lock() + if bind.sock6 != -1 { + err1 = unix.Close(bind.sock6) + bind.sock6 = -1 + } + if bind.sock4 != -1 { + err2 = unix.Close(bind.sock4) + bind.sock4 = -1 + } + bind.closing.Unlock() if err1 != nil { return err1 @@ -193,6 +199,9 @@ func (bind *nativeBind) Close() error { } func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + bind.closing.RLock() + defer bind.closing.RUnlock() + var end NativeEndpoint if bind.sock6 == -1 { return 0, nil, syscall.EAFNOSUPPORT @@ -206,6 +215,9 @@ func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { } func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + bind.closing.RLock() + defer bind.closing.RUnlock() + var end NativeEndpoint if bind.sock4 == -1 { return 0, nil, syscall.EAFNOSUPPORT @@ -219,6 +231,9 @@ func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { } func (bind *nativeBind) Send(buff []byte, end Endpoint) error { + bind.closing.RLock() + defer bind.closing.RUnlock() + nend := end.(*NativeEndpoint) if !nend.isV6 { if bind.sock4 == -1 { @@ -316,7 +331,7 @@ func create4(port uint16) (int, uint16, error) { ) if err != nil { - return FD_ERR, 0, err + return -1, 0, err } addr := unix.SockaddrInet4{ @@ -338,7 +353,7 @@ func create4(port uint16) (int, uint16, error) { return unix.Bind(fd, &addr) }(); err != nil { unix.Close(fd) - return FD_ERR, 0, err + return -1, 0, err } sa, err := unix.Getsockname(fd) @@ -360,7 +375,7 @@ func create6(port uint16) (int, uint16, error) { ) if err != nil { - return FD_ERR, 0, err + return -1, 0, err } // set sockopts and bind @@ -392,7 +407,7 @@ func create6(port uint16) (int, uint16, error) { }(); err != nil { unix.Close(fd) - return FD_ERR, 0, err + return -1, 0, err } sa, err := unix.Getsockname(fd)