1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

conn: fix StdNetBind fallback on Windows

If RIO is unavailable, NewWinRingBind() falls back to StdNetBind.
StdNetBind uses x/net/ipv{4,6}.PacketConn for sending and receiving
datagrams, specifically via the {Read,Write}Batch methods.
These methods are unimplemented on Windows and will return runtime
errors as a result. Additionally, only Linux benefits from these
x/net types for reading and writing, so we update StdNetBind to fall
back to the standard library net package for all platforms other than
Linux.

Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jordan Whited 2023-03-06 15:58:32 -08:00 committed by Jason A. Donenfeld
parent dbd949307e
commit 2fcdaf9799
2 changed files with 150 additions and 64 deletions

View File

@ -10,6 +10,7 @@ import (
"errors" "errors"
"net" "net"
"net/netip" "net/netip"
"runtime"
"strconv" "strconv"
"sync" "sync"
"syscall" "syscall"
@ -22,16 +23,21 @@ var (
_ Bind = (*StdNetBind)(nil) _ Bind = (*StdNetBind)(nil)
) )
// StdNetBind implements Bind for all platforms except Windows. // StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct { type StdNetBind struct {
mu sync.Mutex // protects following fields mu sync.Mutex // protects following fields
ipv4 *net.UDPConn ipv4 *net.UDPConn
ipv6 *net.UDPConn ipv6 *net.UDPConn
blackhole4 bool blackhole4 bool
blackhole6 bool blackhole6 bool
ipv4PC *ipv4.PacketConn ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn ipv6PC *ipv6.PacketConn // will be nil on non-Linux
udpAddrPool sync.Pool
udpAddrPool sync.Pool // following fields are not guarded by mu
ipv4MsgsPool sync.Pool ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool ipv6MsgsPool sync.Pool
} }
@ -154,6 +160,8 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
again: again:
port := int(uport) port := int(uport)
var v4conn, v6conn *net.UDPConn var v4conn, v6conn *net.UDPConn
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn
v4conn, port, err = listenNet("udp4", port) v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
@ -173,63 +181,92 @@ again:
} }
var fns []ReceiveFunc var fns []ReceiveFunc
if v4conn != nil { if v4conn != nil {
fns = append(fns, s.receiveIPv4) if runtime.GOOS == "linux" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
s.ipv4 = v4conn s.ipv4 = v4conn
} }
if v6conn != nil { if v6conn != nil {
fns = append(fns, s.receiveIPv6) if runtime.GOOS == "linux" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
s.ipv6 = v6conn s.ipv6 = v6conn
} }
if len(fns) == 0 { if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT return nil, 0, syscall.EAFNOSUPPORT
} }
s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
return fns, uint16(port), nil return fns, uint16(port), nil
} }
func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
defer s.ipv4MsgsPool.Put(msgs) msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
for i := range buffs { defer s.ipv4MsgsPool.Put(msgs)
(*msgs)[i].Buffers[0] = buffs[i] for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
}
return numMsgs, nil
} }
numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
}
return numMsgs, nil
} }
func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
defer s.ipv6MsgsPool.Put(msgs) msgs := s.ipv4MsgsPool.Get().(*[]ipv6.Message)
for i := range buffs { defer s.ipv4MsgsPool.Put(msgs)
(*msgs)[i].Buffers[0] = buffs[i] for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
}
return numMsgs, nil
} }
numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
}
return numMsgs, nil
} }
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
@ -246,10 +283,12 @@ func (s *StdNetBind) Close() error {
if s.ipv4 != nil { if s.ipv4 != nil {
err1 = s.ipv4.Close() err1 = s.ipv4.Close()
s.ipv4 = nil s.ipv4 = nil
s.ipv4PC = nil
} }
if s.ipv6 != nil { if s.ipv6 != nil {
err2 = s.ipv6.Close() err2 = s.ipv6.Close()
s.ipv6 = nil s.ipv6 = nil
s.ipv6PC = nil
} }
s.blackhole4 = false s.blackhole4 = false
s.blackhole6 = false s.blackhole6 = false
@ -263,11 +302,18 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
s.mu.Lock() s.mu.Lock()
blackhole := s.blackhole4 blackhole := s.blackhole4
conn := s.ipv4 conn := s.ipv4
var (
pc4 *ipv4.PacketConn
pc6 *ipv6.PacketConn
)
is6 := false is6 := false
if endpoint.DstIP().Is6() { if endpoint.DstIP().Is6() {
blackhole = s.blackhole6 blackhole = s.blackhole6
conn = s.ipv6 conn = s.ipv6
pc6 = s.ipv6PC
is6 = true is6 = true
} else {
pc4 = s.ipv4PC
} }
s.mu.Unlock() s.mu.Unlock()
@ -278,13 +324,13 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
if is6 { if is6 {
return s.send6(s.ipv6PC, endpoint, buffs) return s.send6(conn, pc6, endpoint, buffs)
} else { } else {
return s.send4(s.ipv4PC, endpoint, buffs) return s.send4(conn, pc4, endpoint, buffs)
} }
} }
func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error { func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr) ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4() as4 := ep.DstIP().As4()
copy(ua.IP, as4[:]) copy(ua.IP, as4[:])
@ -301,19 +347,28 @@ func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) e
err error err error
start int start int
) )
for { if runtime.GOOS == "linux" {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) for {
if err != nil || n == len((*msgs)[start:len(buffs)]) { n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
break if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
} else {
for i, buff := range buffs {
_, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
if err != nil {
break
}
} }
start += n
} }
s.udpAddrPool.Put(ua) s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs) s.ipv4MsgsPool.Put(msgs)
return err return err
} }
func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error { func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr) ua := s.udpAddrPool.Get().(*net.UDPAddr)
as16 := ep.DstIP().As16() as16 := ep.DstIP().As16()
copy(ua.IP, as16[:]) copy(ua.IP, as16[:])
@ -330,12 +385,21 @@ func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) e
err error err error
start int start int
) )
for { if runtime.GOOS == "linux" {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) for {
if err != nil || n == len((*msgs)[start:len(buffs)]) { n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
break if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
} else {
for i, buff := range buffs {
_, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
if err != nil {
break
}
} }
start += n
} }
s.udpAddrPool.Put(ua) s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs) s.ipv6MsgsPool.Put(msgs)

22
conn/bind_std_test.go Normal file
View File

@ -0,0 +1,22 @@
package conn
import "testing"
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
bind := NewStdNetBind().(*StdNetBind)
fns, _, err := bind.Open(0)
if err != nil {
t.Fatal(err)
}
bind.Close()
buffs := make([][]byte, 1)
buffs[0] = make([]byte, 1)
sizes := make([]int, 1)
eps := make([]Endpoint, 1)
for _, fn := range fns {
// The ReceiveFuncs must not access conn-related fields on StdNetBind
// unguarded. Close() nils the conn-related fields resulting in a panic
// if they violate the mutex.
fn(buffs, sizes, eps)
}
}