1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

conn: linux: do not allow ReceiveIPvX to race with Close

If Close is called after ReceiveIPvX, then ReceiveIPvX will block on an
invalid or potentially reused fd.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-01-07 17:00:21 +01:00
parent 29b0477585
commit 3b3de758ec

View File

@ -18,10 +18,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
const (
FD_ERR = -1
)
type IPv4Source struct { type IPv4Source struct {
Src [4]byte Src [4]byte
Ifindex int32 Ifindex int32
@ -63,6 +59,7 @@ type nativeBind struct {
sock4 int sock4 int
sock6 int sock6 int
lastMark uint32 lastMark uint32
closing sync.RWMutex
} }
var _ Endpoint = (*NativeEndpoint)(nil) var _ Endpoint = (*NativeEndpoint)(nil)
@ -129,7 +126,7 @@ func createBind(port uint16) (Bind, uint16, error) {
port = newPort port = newPort
} }
if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR { if bind.sock4 == -1 && bind.sock6 == -1 {
return nil, 0, errors.New("ipv4 and ipv6 not supported") return nil, 0, errors.New("ipv4 and ipv6 not supported")
} }
@ -141,6 +138,9 @@ func (bind *nativeBind) LastMark() uint32 {
} }
func (bind *nativeBind) SetMark(value uint32) error { func (bind *nativeBind) SetMark(value uint32) error {
bind.closing.RLock()
defer bind.closing.RUnlock()
if bind.sock6 != -1 { if bind.sock6 != -1 {
err := unix.SetsockoptInt( err := unix.SetsockoptInt(
bind.sock6, bind.sock6,
@ -171,20 +171,26 @@ func (bind *nativeBind) SetMark(value uint32) error {
return nil return nil
} }
func closeUnblock(fd int) error {
// shutdown to unblock readers and writers
unix.Shutdown(fd, unix.SHUT_RDWR)
return unix.Close(fd)
}
func (bind *nativeBind) Close() error { func (bind *nativeBind) Close() error {
var err1, err2 error var err1, err2 error
bind.closing.RLock()
if bind.sock6 != -1 { if bind.sock6 != -1 {
err1 = closeUnblock(bind.sock6) unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
} }
if bind.sock4 != -1 { if bind.sock4 != -1 {
err2 = closeUnblock(bind.sock4) unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
} }
bind.closing.RUnlock()
bind.closing.Lock()
if bind.sock6 != -1 {
err1 = unix.Close(bind.sock6)
bind.sock6 = -1
}
if bind.sock4 != -1 {
err2 = unix.Close(bind.sock4)
bind.sock4 = -1
}
bind.closing.Unlock()
if err1 != nil { if err1 != nil {
return err1 return err1
@ -193,6 +199,9 @@ func (bind *nativeBind) Close() error {
} }
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
bind.closing.RLock()
defer bind.closing.RUnlock()
var end NativeEndpoint var end NativeEndpoint
if bind.sock6 == -1 { if bind.sock6 == -1 {
return 0, nil, syscall.EAFNOSUPPORT return 0, nil, syscall.EAFNOSUPPORT
@ -206,6 +215,9 @@ func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
} }
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
bind.closing.RLock()
defer bind.closing.RUnlock()
var end NativeEndpoint var end NativeEndpoint
if bind.sock4 == -1 { if bind.sock4 == -1 {
return 0, nil, syscall.EAFNOSUPPORT return 0, nil, syscall.EAFNOSUPPORT
@ -219,6 +231,9 @@ func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
} }
func (bind *nativeBind) Send(buff []byte, end Endpoint) error { func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
bind.closing.RLock()
defer bind.closing.RUnlock()
nend := end.(*NativeEndpoint) nend := end.(*NativeEndpoint)
if !nend.isV6 { if !nend.isV6 {
if bind.sock4 == -1 { if bind.sock4 == -1 {
@ -316,7 +331,7 @@ func create4(port uint16) (int, uint16, error) {
) )
if err != nil { if err != nil {
return FD_ERR, 0, err return -1, 0, err
} }
addr := unix.SockaddrInet4{ addr := unix.SockaddrInet4{
@ -338,7 +353,7 @@ func create4(port uint16) (int, uint16, error) {
return unix.Bind(fd, &addr) return unix.Bind(fd, &addr)
}(); err != nil { }(); err != nil {
unix.Close(fd) unix.Close(fd)
return FD_ERR, 0, err return -1, 0, err
} }
sa, err := unix.Getsockname(fd) sa, err := unix.Getsockname(fd)
@ -360,7 +375,7 @@ func create6(port uint16) (int, uint16, error) {
) )
if err != nil { if err != nil {
return FD_ERR, 0, err return -1, 0, err
} }
// set sockopts and bind // set sockopts and bind
@ -392,7 +407,7 @@ func create6(port uint16) (int, uint16, error) {
}(); err != nil { }(); err != nil {
unix.Close(fd) unix.Close(fd)
return FD_ERR, 0, err return -1, 0, err
} }
sa, err := unix.Getsockname(fd) sa, err := unix.Getsockname(fd)