1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00
wireguard-go/conn/bind_windows.go
Jordan Whited 3bb8fec7e4 conn, device, tun: implement vectorized I/O plumbing
Accept packet vectors for reading and writing in the tun.Device and
conn.Bind interfaces, so that the internal plumbing between these
interfaces now passes a vector of packets. Vectors move untouched
between these interfaces, i.e. if 128 packets are received from
conn.Bind.Read(), 128 packets are passed to tun.Device.Write(). There is
no internal buffering.

Currently, existing implementations are only adjusted to have vectors
of length one. Subsequent patches will improve that.

Also, as a related fixup, use the unix and windows packages rather than
the syscall package when possible.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:13 +01:00

600 lines
15 KiB
Go

/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"encoding/binary"
"io"
"net"
"net/netip"
"strconv"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn/winrio"
)
const (
packetsPerRing = 1024
bytesPerPacket = 2048 - 32
receiveSpins = 15
)
type ringPacket struct {
addr WinRingEndpoint
data [bytesPerPacket]byte
}
type ringBuffer struct {
packets uintptr
head, tail uint32
id winrio.BufferId
iocp windows.Handle
isFull bool
cq winrio.Cq
mu sync.Mutex
overlapped windows.Overlapped
}
func (rb *ringBuffer) Push() *ringPacket {
for rb.isFull {
panic("ring is full")
}
ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
rb.tail += 1
if rb.tail%packetsPerRing == rb.head%packetsPerRing {
rb.isFull = true
}
return ret
}
func (rb *ringBuffer) Return(count uint32) {
if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
return
}
rb.head += count
rb.isFull = false
}
type afWinRingBind struct {
sock windows.Handle
rx, tx ringBuffer
rq winrio.Rq
mu sync.Mutex
blackhole bool
}
// WinRingBind uses Windows registered I/O for fast ring buffered networking.
type WinRingBind struct {
v4, v6 afWinRingBind
mu sync.RWMutex
isOpen atomic.Uint32 // 0, 1, or 2
}
func NewDefaultBind() Bind { return NewWinRingBind() }
func NewWinRingBind() Bind {
if !winrio.Initialize() {
return NewStdNetBind()
}
return new(WinRingBind)
}
type WinRingEndpoint struct {
family uint16
data [30]byte
}
var (
_ Bind = (*WinRingBind)(nil)
_ Endpoint = (*WinRingEndpoint)(nil)
)
func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
host, port, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
host16, err := windows.UTF16PtrFromString(host)
if err != nil {
return nil, err
}
port16, err := windows.UTF16PtrFromString(port)
if err != nil {
return nil, err
}
hints := windows.AddrinfoW{
Flags: windows.AI_NUMERICHOST,
Family: windows.AF_UNSPEC,
Socktype: windows.SOCK_DGRAM,
Protocol: windows.IPPROTO_UDP,
}
var addrinfo *windows.AddrinfoW
err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
if err != nil {
return nil, err
}
defer windows.FreeAddrInfoW(addrinfo)
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
return nil, windows.ERROR_INVALID_ADDRESS
}
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
}
func (*WinRingEndpoint) ClearSrc() {}
func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family {
case windows.AF_INET:
return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6:
return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
}
return netip.Addr{}
}
func (e *WinRingEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
}
func (e *WinRingEndpoint) DstToBytes() []byte {
switch e.family {
case windows.AF_INET:
b := make([]byte, 0, 6)
b = append(b, e.data[2:6]...)
b = append(b, e.data[1], e.data[0])
return b
case windows.AF_INET6:
b := make([]byte, 0, 18)
b = append(b, e.data[6:22]...)
b = append(b, e.data[1], e.data[0])
return b
}
return nil
}
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10)
}
return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
}
return ""
}
func (e *WinRingEndpoint) SrcToString() string {
return ""
}
func (ring *ringBuffer) CloseAndZero() {
if ring.cq != 0 {
winrio.CloseCompletionQueue(ring.cq)
ring.cq = 0
}
if ring.iocp != 0 {
windows.CloseHandle(ring.iocp)
ring.iocp = 0
}
if ring.id != 0 {
winrio.DeregisterBuffer(ring.id)
ring.id = 0
}
if ring.packets != 0 {
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
ring.packets = 0
}
ring.head = 0
ring.tail = 0
ring.isFull = false
}
func (bind *afWinRingBind) CloseAndZero() {
bind.rx.CloseAndZero()
bind.tx.CloseAndZero()
if bind.sock != 0 {
windows.CloseHandle(bind.sock)
bind.sock = 0
}
bind.blackhole = false
}
func (bind *WinRingBind) closeAndZero() {
bind.isOpen.Store(0)
bind.v4.CloseAndZero()
bind.v6.CloseAndZero()
}
func (ring *ringBuffer) Open() error {
var err error
packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return err
}
ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
if err != nil {
return err
}
ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
return err
}
ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
if err != nil {
return err
}
return nil
}
func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
var err error
bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
if err != nil {
return nil, err
}
err = bind.rx.Open()
if err != nil {
return nil, err
}
err = bind.tx.Open()
if err != nil {
return nil, err
}
bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
if err != nil {
return nil, err
}
err = windows.Bind(bind.sock, sa)
if err != nil {
return nil, err
}
sa, err = windows.Getsockname(bind.sock)
if err != nil {
return nil, err
}
return sa, nil
}
func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
bind.mu.Lock()
defer bind.mu.Unlock()
defer func() {
if err != nil {
bind.closeAndZero()
}
}()
if bind.isOpen.Load() != 0 {
return nil, 0, ErrBindAlreadyOpen
}
var sa windows.Sockaddr
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
if err != nil {
return nil, 0, err
}
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
if err != nil {
return nil, 0, err
}
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
for i := 0; i < packetsPerRing; i++ {
err = bind.v4.InsertReceiveRequest()
if err != nil {
return nil, 0, err
}
err = bind.v6.InsertReceiveRequest()
if err != nil {
return nil, 0, err
}
}
bind.isOpen.Store(1)
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
}
func (bind *WinRingBind) Close() error {
bind.mu.RLock()
if bind.isOpen.Load() != 1 {
bind.mu.RUnlock()
return nil
}
bind.isOpen.Store(2)
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
bind.mu.RUnlock()
bind.mu.Lock()
defer bind.mu.Unlock()
bind.closeAndZero()
return nil
}
func (bind *WinRingBind) BatchSize() int {
// TODO: implement batching in and out of the ring
return 1
}
func (bind *WinRingBind) SetMark(mark uint32) error {
return nil
}
func (bind *afWinRingBind) InsertReceiveRequest() error {
packet := bind.rx.Push()
dataBuffer := &winrio.Buffer{
Id: bind.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
Length: uint32(len(packet.data)),
}
addressBuffer := &winrio.Buffer{
Id: bind.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
bind.mu.Lock()
defer bind.mu.Unlock()
return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
}
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
bind.rx.mu.Lock()
defer bind.rx.mu.Unlock()
var err error
var count uint32
var results [1]winrio.Result
retry:
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
procyield(1)
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
}
if count == 0 {
err = winrio.Notify(bind.rx.cq)
if err != nil {
return 0, nil, err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return 0, nil, err
}
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
if count == 0 {
return 0, nil, io.ErrNoProgress
}
}
bind.rx.Return(1)
err = bind.InsertReceiveRequest()
if err != nil {
return 0, nil, err
}
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
goto retry
}
if results[0].Status != 0 {
return 0, nil, windows.Errno(results[0].Status)
}
packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
ep := packet.addr
n := copy(buf, packet.data[:results[0].BytesTransferred])
return n, &ep, nil
}
func (bind *WinRingBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v4.Receive(buffs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
}
func (bind *WinRingBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v6.Receive(buffs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
}
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
if isOpen.Load() != 1 {
return net.ErrClosed
}
if len(buf) > bytesPerPacket {
return io.ErrShortBuffer
}
bind.tx.mu.Lock()
defer bind.tx.mu.Unlock()
var results [packetsPerRing]winrio.Result
count := winrio.DequeueCompletion(bind.tx.cq, results[:])
if count == 0 && bind.tx.isFull {
err := winrio.Notify(bind.tx.cq)
if err != nil {
return err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return err
}
if isOpen.Load() != 1 {
return net.ErrClosed
}
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
if count == 0 {
return io.ErrNoProgress
}
}
if count > 0 {
bind.tx.Return(count)
}
packet := bind.tx.Push()
packet.addr = *nend
copy(packet.data[:], buf)
dataBuffer := &winrio.Buffer{
Id: bind.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
Length: uint32(len(buf)),
}
addressBuffer := &winrio.Buffer{
Id: bind.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
bind.mu.Lock()
defer bind.mu.Unlock()
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
func (bind *WinRingBind) Send(buffs [][]byte, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
for _, buf := range buffs {
switch nend.family {
case windows.AF_INET:
if bind.v4.blackhole {
continue
}
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
return err
}
case windows.AF_INET6:
if bind.v6.blackhole {
continue
}
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
return err
}
}
}
return nil
}
func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
s.mu.Lock()
defer s.mu.Unlock()
sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
s.blackhole4 = blackhole
return nil
}
func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
s.mu.Lock()
defer s.mu.Unlock()
sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
s.blackhole6 = blackhole
return nil
}
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
if err != nil {
return err
}
bind.v4.blackhole = blackhole
return nil
}
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
if err != nil {
return err
}
bind.v6.blackhole = blackhole
return nil
}
func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
const IP_UNICAST_IF = 31
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
var bytes [4]byte
binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
if err != nil {
return err
}
return nil
}
func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
const IPV6_UNICAST_IF = 31
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
}