mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
conn: fix StdNetBind fallback on Windows
If RIO is unavailable, NewWinRingBind() falls back to StdNetBind. StdNetBind uses x/net/ipv{4,6}.PacketConn for sending and receiving datagrams, specifically via the {Read,Write}Batch methods. These methods are unimplemented on Windows and will return runtime errors as a result. Additionally, only Linux benefits from these x/net types for reading and writing, so we update StdNetBind to fall back to the standard library net package for all platforms other than Linux. Reviewed-by: James Tucker <james@tailscale.com> Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
dbd949307e
commit
2fcdaf9799
192
conn/bind_std.go
192
conn/bind_std.go
@ -10,6 +10,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
@ -22,16 +23,21 @@ var (
|
|||||||
_ Bind = (*StdNetBind)(nil)
|
_ Bind = (*StdNetBind)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
// StdNetBind implements Bind for all platforms except Windows.
|
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
|
||||||
|
// (see bind_windows.go), it may fall back to StdNetBind.
|
||||||
|
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
|
||||||
|
// methods for sending and receiving multiple datagrams per-syscall. See the
|
||||||
|
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
||||||
type StdNetBind struct {
|
type StdNetBind struct {
|
||||||
mu sync.Mutex // protects following fields
|
mu sync.Mutex // protects following fields
|
||||||
ipv4 *net.UDPConn
|
ipv4 *net.UDPConn
|
||||||
ipv6 *net.UDPConn
|
ipv6 *net.UDPConn
|
||||||
blackhole4 bool
|
blackhole4 bool
|
||||||
blackhole6 bool
|
blackhole6 bool
|
||||||
ipv4PC *ipv4.PacketConn
|
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
||||||
ipv6PC *ipv6.PacketConn
|
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
||||||
udpAddrPool sync.Pool
|
|
||||||
|
udpAddrPool sync.Pool // following fields are not guarded by mu
|
||||||
ipv4MsgsPool sync.Pool
|
ipv4MsgsPool sync.Pool
|
||||||
ipv6MsgsPool sync.Pool
|
ipv6MsgsPool sync.Pool
|
||||||
}
|
}
|
||||||
@ -154,6 +160,8 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
|||||||
again:
|
again:
|
||||||
port := int(uport)
|
port := int(uport)
|
||||||
var v4conn, v6conn *net.UDPConn
|
var v4conn, v6conn *net.UDPConn
|
||||||
|
var v4pc *ipv4.PacketConn
|
||||||
|
var v6pc *ipv6.PacketConn
|
||||||
|
|
||||||
v4conn, port, err = listenNet("udp4", port)
|
v4conn, port, err = listenNet("udp4", port)
|
||||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
@ -173,63 +181,92 @@ again:
|
|||||||
}
|
}
|
||||||
var fns []ReceiveFunc
|
var fns []ReceiveFunc
|
||||||
if v4conn != nil {
|
if v4conn != nil {
|
||||||
fns = append(fns, s.receiveIPv4)
|
if runtime.GOOS == "linux" {
|
||||||
|
v4pc = ipv4.NewPacketConn(v4conn)
|
||||||
|
s.ipv4PC = v4pc
|
||||||
|
}
|
||||||
|
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
|
||||||
s.ipv4 = v4conn
|
s.ipv4 = v4conn
|
||||||
}
|
}
|
||||||
if v6conn != nil {
|
if v6conn != nil {
|
||||||
fns = append(fns, s.receiveIPv6)
|
if runtime.GOOS == "linux" {
|
||||||
|
v6pc = ipv6.NewPacketConn(v6conn)
|
||||||
|
s.ipv6PC = v6pc
|
||||||
|
}
|
||||||
|
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
|
||||||
s.ipv6 = v6conn
|
s.ipv6 = v6conn
|
||||||
}
|
}
|
||||||
if len(fns) == 0 {
|
if len(fns) == 0 {
|
||||||
return nil, 0, syscall.EAFNOSUPPORT
|
return nil, 0, syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
|
|
||||||
s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
|
|
||||||
s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
|
|
||||||
|
|
||||||
return fns, uint16(port), nil
|
return fns, uint16(port), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
||||||
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||||
defer s.ipv4MsgsPool.Put(msgs)
|
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||||
for i := range buffs {
|
defer s.ipv4MsgsPool.Put(msgs)
|
||||||
(*msgs)[i].Buffers[0] = buffs[i]
|
for i := range buffs {
|
||||||
|
(*msgs)[i].Buffers[0] = buffs[i]
|
||||||
|
}
|
||||||
|
var numMsgs int
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
msg := &(*msgs)[0]
|
||||||
|
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < numMsgs; i++ {
|
||||||
|
msg := &(*msgs)[i]
|
||||||
|
sizes[i] = msg.N
|
||||||
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
ep := asEndpoint(addrPort)
|
||||||
|
getSrcFromControl(msg.OOB, ep)
|
||||||
|
eps[i] = ep
|
||||||
|
}
|
||||||
|
return numMsgs, nil
|
||||||
}
|
}
|
||||||
numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for i := 0; i < numMsgs; i++ {
|
|
||||||
msg := &(*msgs)[i]
|
|
||||||
sizes[i] = msg.N
|
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
||||||
ep := asEndpoint(addrPort)
|
|
||||||
getSrcFromControl(msg.OOB, ep)
|
|
||||||
eps[i] = ep
|
|
||||||
}
|
|
||||||
return numMsgs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
||||||
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||||
defer s.ipv6MsgsPool.Put(msgs)
|
msgs := s.ipv4MsgsPool.Get().(*[]ipv6.Message)
|
||||||
for i := range buffs {
|
defer s.ipv4MsgsPool.Put(msgs)
|
||||||
(*msgs)[i].Buffers[0] = buffs[i]
|
for i := range buffs {
|
||||||
|
(*msgs)[i].Buffers[0] = buffs[i]
|
||||||
|
}
|
||||||
|
var numMsgs int
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
msg := &(*msgs)[0]
|
||||||
|
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < numMsgs; i++ {
|
||||||
|
msg := &(*msgs)[i]
|
||||||
|
sizes[i] = msg.N
|
||||||
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
ep := asEndpoint(addrPort)
|
||||||
|
getSrcFromControl(msg.OOB, ep)
|
||||||
|
eps[i] = ep
|
||||||
|
}
|
||||||
|
return numMsgs, nil
|
||||||
}
|
}
|
||||||
numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for i := 0; i < numMsgs; i++ {
|
|
||||||
msg := &(*msgs)[i]
|
|
||||||
sizes[i] = msg.N
|
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
||||||
ep := asEndpoint(addrPort)
|
|
||||||
getSrcFromControl(msg.OOB, ep)
|
|
||||||
eps[i] = ep
|
|
||||||
}
|
|
||||||
return numMsgs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||||
@ -246,10 +283,12 @@ func (s *StdNetBind) Close() error {
|
|||||||
if s.ipv4 != nil {
|
if s.ipv4 != nil {
|
||||||
err1 = s.ipv4.Close()
|
err1 = s.ipv4.Close()
|
||||||
s.ipv4 = nil
|
s.ipv4 = nil
|
||||||
|
s.ipv4PC = nil
|
||||||
}
|
}
|
||||||
if s.ipv6 != nil {
|
if s.ipv6 != nil {
|
||||||
err2 = s.ipv6.Close()
|
err2 = s.ipv6.Close()
|
||||||
s.ipv6 = nil
|
s.ipv6 = nil
|
||||||
|
s.ipv6PC = nil
|
||||||
}
|
}
|
||||||
s.blackhole4 = false
|
s.blackhole4 = false
|
||||||
s.blackhole6 = false
|
s.blackhole6 = false
|
||||||
@ -263,11 +302,18 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
blackhole := s.blackhole4
|
blackhole := s.blackhole4
|
||||||
conn := s.ipv4
|
conn := s.ipv4
|
||||||
|
var (
|
||||||
|
pc4 *ipv4.PacketConn
|
||||||
|
pc6 *ipv6.PacketConn
|
||||||
|
)
|
||||||
is6 := false
|
is6 := false
|
||||||
if endpoint.DstIP().Is6() {
|
if endpoint.DstIP().Is6() {
|
||||||
blackhole = s.blackhole6
|
blackhole = s.blackhole6
|
||||||
conn = s.ipv6
|
conn = s.ipv6
|
||||||
|
pc6 = s.ipv6PC
|
||||||
is6 = true
|
is6 = true
|
||||||
|
} else {
|
||||||
|
pc4 = s.ipv4PC
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
@ -278,13 +324,13 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
|
|||||||
return syscall.EAFNOSUPPORT
|
return syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
if is6 {
|
if is6 {
|
||||||
return s.send6(s.ipv6PC, endpoint, buffs)
|
return s.send6(conn, pc6, endpoint, buffs)
|
||||||
} else {
|
} else {
|
||||||
return s.send4(s.ipv4PC, endpoint, buffs)
|
return s.send4(conn, pc4, endpoint, buffs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
|
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
|
||||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||||
as4 := ep.DstIP().As4()
|
as4 := ep.DstIP().As4()
|
||||||
copy(ua.IP, as4[:])
|
copy(ua.IP, as4[:])
|
||||||
@ -301,19 +347,28 @@ func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) e
|
|||||||
err error
|
err error
|
||||||
start int
|
start int
|
||||||
)
|
)
|
||||||
for {
|
if runtime.GOOS == "linux" {
|
||||||
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
|
for {
|
||||||
if err != nil || n == len((*msgs)[start:len(buffs)]) {
|
n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
|
||||||
break
|
if err != nil || n == len((*msgs)[start:len(buffs)]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
start += n
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i, buff := range buffs {
|
||||||
|
_, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
start += n
|
|
||||||
}
|
}
|
||||||
s.udpAddrPool.Put(ua)
|
s.udpAddrPool.Put(ua)
|
||||||
s.ipv4MsgsPool.Put(msgs)
|
s.ipv4MsgsPool.Put(msgs)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
|
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
|
||||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||||
as16 := ep.DstIP().As16()
|
as16 := ep.DstIP().As16()
|
||||||
copy(ua.IP, as16[:])
|
copy(ua.IP, as16[:])
|
||||||
@ -330,12 +385,21 @@ func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) e
|
|||||||
err error
|
err error
|
||||||
start int
|
start int
|
||||||
)
|
)
|
||||||
for {
|
if runtime.GOOS == "linux" {
|
||||||
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
|
for {
|
||||||
if err != nil || n == len((*msgs)[start:len(buffs)]) {
|
n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
|
||||||
break
|
if err != nil || n == len((*msgs)[start:len(buffs)]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
start += n
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i, buff := range buffs {
|
||||||
|
_, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
start += n
|
|
||||||
}
|
}
|
||||||
s.udpAddrPool.Put(ua)
|
s.udpAddrPool.Put(ua)
|
||||||
s.ipv6MsgsPool.Put(msgs)
|
s.ipv6MsgsPool.Put(msgs)
|
||||||
|
22
conn/bind_std_test.go
Normal file
22
conn/bind_std_test.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package conn
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
|
||||||
|
bind := NewStdNetBind().(*StdNetBind)
|
||||||
|
fns, _, err := bind.Open(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
bind.Close()
|
||||||
|
buffs := make([][]byte, 1)
|
||||||
|
buffs[0] = make([]byte, 1)
|
||||||
|
sizes := make([]int, 1)
|
||||||
|
eps := make([]Endpoint, 1)
|
||||||
|
for _, fn := range fns {
|
||||||
|
// The ReceiveFuncs must not access conn-related fields on StdNetBind
|
||||||
|
// unguarded. Close() nils the conn-related fields resulting in a panic
|
||||||
|
// if they violate the mutex.
|
||||||
|
fn(buffs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user