mirror of
https://git.zx2c4.com/wireguard-go
synced 2025-09-18 20:57:50 +02:00
Compare commits
64 Commits
0.0.202302
...
master
Author | SHA1 | Date | |
---|---|---|---|
|
f333402bd9 | ||
|
c92064f1ce | ||
|
264889f0bb | ||
|
256bcbd70d | ||
|
1571e0fbae | ||
|
842888ac5c | ||
|
9e7529c3d2 | ||
|
436f7fdc16 | ||
|
0e4482a086 | ||
|
77b6c824a8 | ||
|
bc30fee374 | ||
|
b82c016264 | ||
|
45916071ba | ||
|
e3c1354d27 | ||
|
32546a15a8 | ||
|
9eb3221f1d | ||
|
867a4c4a3f | ||
|
113c8f1340 | ||
|
12269c2761 | ||
|
542e565baa | ||
|
7c20311b3d | ||
|
4ffa9c2032 | ||
|
d0bc03c707 | ||
|
1cf89f5339 | ||
|
2e0774f246 | ||
|
b3df23dcd4 | ||
|
f502ec3fad | ||
|
5d37bd24e1 | ||
|
24ea13351e | ||
|
177caa7e44 | ||
|
42ec952ead | ||
|
ec8f6f82c2 | ||
|
1ec454f253 | ||
|
8a015f7c76 | ||
|
895d6c23cd | ||
|
4201e08f1d | ||
|
6a84778f2c | ||
|
469159ecf7 | ||
|
6e755e132a | ||
|
1f25eac395 | ||
|
25eb973e00 | ||
|
b7cd547315 | ||
|
052af4a807 | ||
|
aad7fca9c5 | ||
|
6f895be10d | ||
|
6a07b2a355 | ||
|
334b605e72 | ||
|
3a9e75374f | ||
|
cc20c08c96 | ||
|
1417a47c8f | ||
|
7f511c3bb1 | ||
|
07a1e55270 | ||
|
fff53afca7 | ||
|
0ad14a89f5 | ||
|
7d327ed35a | ||
|
f41f474466 | ||
|
5819c6af28 | ||
|
6901984f6a | ||
|
2fcdaf9799 | ||
|
dbd949307e | ||
|
f26efb65f2 | ||
|
f67c862a2a | ||
|
9e2f386022 | ||
|
3bb8fec7e4 |
@ -46,7 +46,7 @@ This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapp
|
||||
|
||||
## Building
|
||||
|
||||
This requires an installation of [go](https://golang.org) ≥ 1.18.
|
||||
This requires an installation of the latest version of [Go](https://go.dev/).
|
||||
|
||||
```
|
||||
$ git clone https://git.zx2c4.com/wireguard-go
|
||||
@ -56,7 +56,7 @@ $ make
|
||||
|
||||
## License
|
||||
|
||||
Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
|
@ -1,562 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type ipv4Source struct {
|
||||
Src [4]byte
|
||||
Ifindex int32
|
||||
}
|
||||
|
||||
type ipv6Source struct {
|
||||
src [16]byte
|
||||
// ifindex belongs in dst.ZoneId
|
||||
}
|
||||
|
||||
type LinuxSocketEndpoint struct {
|
||||
mu sync.Mutex
|
||||
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
||||
src [unsafe.Sizeof(ipv6Source{})]byte
|
||||
isV6 bool
|
||||
}
|
||||
|
||||
func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
|
||||
func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
|
||||
func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
|
||||
|
||||
func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
|
||||
return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
|
||||
}
|
||||
|
||||
func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
|
||||
return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
|
||||
}
|
||||
|
||||
func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
|
||||
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
|
||||
}
|
||||
|
||||
func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
|
||||
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
|
||||
}
|
||||
|
||||
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
|
||||
type LinuxSocketBind struct {
|
||||
// mu guards sock4 and sock6 and the associated fds.
|
||||
// As long as someone holds mu (read or write), the associated fds are valid.
|
||||
mu sync.RWMutex
|
||||
sock4 int
|
||||
sock6 int
|
||||
}
|
||||
|
||||
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
|
||||
func NewDefaultBind() Bind { return NewLinuxSocketBind() }
|
||||
|
||||
var (
|
||||
_ Endpoint = (*LinuxSocketEndpoint)(nil)
|
||||
_ Bind = (*LinuxSocketBind)(nil)
|
||||
)
|
||||
|
||||
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||
var end LinuxSocketEndpoint
|
||||
e, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if e.Addr().Is4() {
|
||||
dst := end.dst4()
|
||||
end.isV6 = false
|
||||
dst.Port = int(e.Port())
|
||||
dst.Addr = e.Addr().As4()
|
||||
end.ClearSrc()
|
||||
return &end, nil
|
||||
}
|
||||
|
||||
if e.Addr().Is6() {
|
||||
zone, err := zoneToUint32(e.Addr().Zone())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dst := end.dst6()
|
||||
end.isV6 = true
|
||||
dst.Port = int(e.Port())
|
||||
dst.ZoneId = zone
|
||||
dst.Addr = e.Addr().As16()
|
||||
end.ClearSrc()
|
||||
return &end, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("invalid IP address")
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
|
||||
var err error
|
||||
var newPort uint16
|
||||
var tries int
|
||||
|
||||
if bind.sock4 != -1 || bind.sock6 != -1 {
|
||||
return nil, 0, ErrBindAlreadyOpen
|
||||
}
|
||||
|
||||
originalPort := port
|
||||
|
||||
again:
|
||||
port = originalPort
|
||||
var sock4, sock6 int
|
||||
// Attempt ipv6 bind, update port if successful.
|
||||
sock6, newPort, err = create6(port)
|
||||
if err != nil {
|
||||
if !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, 0, err
|
||||
}
|
||||
} else {
|
||||
port = newPort
|
||||
}
|
||||
|
||||
// Attempt ipv4 bind, update port if successful.
|
||||
sock4, newPort, err = create4(port)
|
||||
if err != nil {
|
||||
if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||
unix.Close(sock6)
|
||||
tries++
|
||||
goto again
|
||||
}
|
||||
if !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
unix.Close(sock6)
|
||||
return nil, 0, err
|
||||
}
|
||||
} else {
|
||||
port = newPort
|
||||
}
|
||||
|
||||
var fns []ReceiveFunc
|
||||
if sock4 != -1 {
|
||||
bind.sock4 = sock4
|
||||
fns = append(fns, bind.receiveIPv4)
|
||||
}
|
||||
if sock6 != -1 {
|
||||
bind.sock6 = sock6
|
||||
fns = append(fns, bind.receiveIPv6)
|
||||
}
|
||||
if len(fns) == 0 {
|
||||
return nil, 0, syscall.EAFNOSUPPORT
|
||||
}
|
||||
return fns, port, nil
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) SetMark(value uint32) error {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
|
||||
if bind.sock6 != -1 {
|
||||
err := unix.SetsockoptInt(
|
||||
bind.sock6,
|
||||
unix.SOL_SOCKET,
|
||||
unix.SO_MARK,
|
||||
int(value),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if bind.sock4 != -1 {
|
||||
err := unix.SetsockoptInt(
|
||||
bind.sock4,
|
||||
unix.SOL_SOCKET,
|
||||
unix.SO_MARK,
|
||||
int(value),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) Close() error {
|
||||
// Take a readlock to shut down the sockets...
|
||||
bind.mu.RLock()
|
||||
if bind.sock6 != -1 {
|
||||
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
|
||||
}
|
||||
if bind.sock4 != -1 {
|
||||
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
|
||||
}
|
||||
bind.mu.RUnlock()
|
||||
// ...and a write lock to close the fd.
|
||||
// This ensures that no one else is using the fd.
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
var err1, err2 error
|
||||
if bind.sock6 != -1 {
|
||||
err1 = unix.Close(bind.sock6)
|
||||
bind.sock6 = -1
|
||||
}
|
||||
if bind.sock4 != -1 {
|
||||
err2 = unix.Close(bind.sock4)
|
||||
bind.sock4 = -1
|
||||
}
|
||||
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
if bind.sock4 == -1 {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
var end LinuxSocketEndpoint
|
||||
n, err := receive4(bind.sock4, buf, &end)
|
||||
return n, &end, err
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
if bind.sock6 == -1 {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
var end LinuxSocketEndpoint
|
||||
n, err := receive6(bind.sock6, buf, &end)
|
||||
return n, &end, err
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
|
||||
nend, ok := end.(*LinuxSocketEndpoint)
|
||||
if !ok {
|
||||
return ErrWrongEndpointType
|
||||
}
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
if !nend.isV6 {
|
||||
if bind.sock4 == -1 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return send4(bind.sock4, nend, buff)
|
||||
} else {
|
||||
if bind.sock6 == -1 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return send6(bind.sock6, nend, buff)
|
||||
}
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
|
||||
if !end.isV6 {
|
||||
return netip.AddrFrom4(end.src4().Src)
|
||||
} else {
|
||||
return netip.AddrFrom16(end.src6().src)
|
||||
}
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
|
||||
if !end.isV6 {
|
||||
return netip.AddrFrom4(end.dst4().Addr)
|
||||
} else {
|
||||
return netip.AddrFrom16(end.dst6().Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) DstToBytes() []byte {
|
||||
if !end.isV6 {
|
||||
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
|
||||
} else {
|
||||
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
|
||||
}
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) SrcToString() string {
|
||||
return end.SrcIP().String()
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) DstToString() string {
|
||||
var port int
|
||||
if !end.isV6 {
|
||||
port = end.dst4().Port
|
||||
} else {
|
||||
port = end.dst6().Port
|
||||
}
|
||||
return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) ClearDst() {
|
||||
for i := range end.dst {
|
||||
end.dst[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) ClearSrc() {
|
||||
for i := range end.src {
|
||||
end.src[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
func zoneToUint32(zone string) (uint32, error) {
|
||||
if zone == "" {
|
||||
return 0, nil
|
||||
}
|
||||
if intr, err := net.InterfaceByName(zone); err == nil {
|
||||
return uint32(intr.Index), nil
|
||||
}
|
||||
n, err := strconv.ParseUint(zone, 10, 32)
|
||||
return uint32(n), err
|
||||
}
|
||||
|
||||
func create4(port uint16) (int, uint16, error) {
|
||||
// create socket
|
||||
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
addr := unix.SockaddrInet4{
|
||||
Port: int(port),
|
||||
}
|
||||
|
||||
// set sockopts and bind
|
||||
|
||||
if err := func() error {
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.IPPROTO_IP,
|
||||
unix.IP_PKTINFO,
|
||||
1,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return unix.Bind(fd, &addr)
|
||||
}(); err != nil {
|
||||
unix.Close(fd)
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
sa, err := unix.Getsockname(fd)
|
||||
if err == nil {
|
||||
addr.Port = sa.(*unix.SockaddrInet4).Port
|
||||
}
|
||||
|
||||
return fd, uint16(addr.Port), err
|
||||
}
|
||||
|
||||
func create6(port uint16) (int, uint16, error) {
|
||||
// create socket
|
||||
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET6,
|
||||
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
// set sockopts and bind
|
||||
|
||||
addr := unix.SockaddrInet6{
|
||||
Port: int(port),
|
||||
}
|
||||
|
||||
if err := func() error {
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.IPPROTO_IPV6,
|
||||
unix.IPV6_RECVPKTINFO,
|
||||
1,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.IPPROTO_IPV6,
|
||||
unix.IPV6_V6ONLY,
|
||||
1,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return unix.Bind(fd, &addr)
|
||||
}(); err != nil {
|
||||
unix.Close(fd)
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
sa, err := unix.Getsockname(fd)
|
||||
if err == nil {
|
||||
addr.Port = sa.(*unix.SockaddrInet6).Port
|
||||
}
|
||||
|
||||
return fd, uint16(addr.Port), err
|
||||
}
|
||||
|
||||
func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
|
||||
// construct message header
|
||||
|
||||
cmsg := struct {
|
||||
cmsghdr unix.Cmsghdr
|
||||
pktinfo unix.Inet4Pktinfo
|
||||
}{
|
||||
unix.Cmsghdr{
|
||||
Level: unix.IPPROTO_IP,
|
||||
Type: unix.IP_PKTINFO,
|
||||
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
||||
},
|
||||
unix.Inet4Pktinfo{
|
||||
Spec_dst: end.src4().Src,
|
||||
Ifindex: end.src4().Ifindex,
|
||||
},
|
||||
}
|
||||
|
||||
end.mu.Lock()
|
||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||
end.mu.Unlock()
|
||||
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// clear src and retry
|
||||
|
||||
if err == unix.EINVAL {
|
||||
end.ClearSrc()
|
||||
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
||||
end.mu.Lock()
|
||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||
end.mu.Unlock()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
|
||||
// construct message header
|
||||
|
||||
cmsg := struct {
|
||||
cmsghdr unix.Cmsghdr
|
||||
pktinfo unix.Inet6Pktinfo
|
||||
}{
|
||||
unix.Cmsghdr{
|
||||
Level: unix.IPPROTO_IPV6,
|
||||
Type: unix.IPV6_PKTINFO,
|
||||
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
|
||||
},
|
||||
unix.Inet6Pktinfo{
|
||||
Addr: end.src6().src,
|
||||
Ifindex: end.dst6().ZoneId,
|
||||
},
|
||||
}
|
||||
|
||||
if cmsg.pktinfo.Addr == [16]byte{} {
|
||||
cmsg.pktinfo.Ifindex = 0
|
||||
}
|
||||
|
||||
end.mu.Lock()
|
||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||
end.mu.Unlock()
|
||||
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// clear src and retry
|
||||
|
||||
if err == unix.EINVAL {
|
||||
end.ClearSrc()
|
||||
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
||||
end.mu.Lock()
|
||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||
end.mu.Unlock()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
|
||||
// construct message header
|
||||
|
||||
var cmsg struct {
|
||||
cmsghdr unix.Cmsghdr
|
||||
pktinfo unix.Inet4Pktinfo
|
||||
}
|
||||
|
||||
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
end.isV6 = false
|
||||
|
||||
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
|
||||
*end.dst4() = *newDst4
|
||||
}
|
||||
|
||||
// update source cache
|
||||
|
||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
||||
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
||||
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
||||
end.src4().Src = cmsg.pktinfo.Spec_dst
|
||||
end.src4().Ifindex = cmsg.pktinfo.Ifindex
|
||||
}
|
||||
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
|
||||
// construct message header
|
||||
|
||||
var cmsg struct {
|
||||
cmsghdr unix.Cmsghdr
|
||||
pktinfo unix.Inet6Pktinfo
|
||||
}
|
||||
|
||||
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
end.isV6 = true
|
||||
|
||||
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
|
||||
*end.dst6() = *newDst6
|
||||
}
|
||||
|
||||
// update source cache
|
||||
|
||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
|
||||
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
|
||||
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
|
||||
end.src6().src = cmsg.pktinfo.Addr
|
||||
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
|
||||
}
|
||||
|
||||
return size, nil
|
||||
}
|
522
conn/bind_std.go
522
conn/bind_std.go
@ -1,69 +1,126 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
// StdNetBind is meant to be a temporary solution on platforms for which
|
||||
// the sticky socket / source caching behavior has not yet been implemented.
|
||||
// It uses the Go's net package to implement networking.
|
||||
// See LinuxSocketBind for a proper implementation on the Linux platform.
|
||||
var (
|
||||
_ Bind = (*StdNetBind)(nil)
|
||||
)
|
||||
|
||||
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
|
||||
// (see bind_windows.go), it may fall back to StdNetBind.
|
||||
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
|
||||
// methods for sending and receiving multiple datagrams per-syscall. See the
|
||||
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
||||
type StdNetBind struct {
|
||||
mu sync.Mutex // protects following fields
|
||||
ipv4 *net.UDPConn
|
||||
ipv6 *net.UDPConn
|
||||
mu sync.Mutex // protects all fields except as specified
|
||||
ipv4 *net.UDPConn
|
||||
ipv6 *net.UDPConn
|
||||
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
||||
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
||||
ipv4TxOffload bool
|
||||
ipv4RxOffload bool
|
||||
ipv6TxOffload bool
|
||||
ipv6RxOffload bool
|
||||
|
||||
// these two fields are not guarded by mu
|
||||
udpAddrPool sync.Pool
|
||||
msgsPool sync.Pool
|
||||
|
||||
blackhole4 bool
|
||||
blackhole6 bool
|
||||
}
|
||||
|
||||
func NewStdNetBind() Bind { return &StdNetBind{} }
|
||||
func NewStdNetBind() Bind {
|
||||
return &StdNetBind{
|
||||
udpAddrPool: sync.Pool{
|
||||
New: func() any {
|
||||
return &net.UDPAddr{
|
||||
IP: make([]byte, 16),
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
type StdNetEndpoint netip.AddrPort
|
||||
msgsPool: sync.Pool{
|
||||
New: func() any {
|
||||
// ipv6.Message and ipv4.Message are interchangeable as they are
|
||||
// both aliases for x/net/internal/socket.Message.
|
||||
msgs := make([]ipv6.Message, IdealBatchSize)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type StdNetEndpoint struct {
|
||||
// AddrPort is the endpoint destination.
|
||||
netip.AddrPort
|
||||
// src is the current sticky source address and interface index, if
|
||||
// supported. Typically this is a PKTINFO structure from/for control
|
||||
// messages, see unix.PKTINFO for an example.
|
||||
src []byte
|
||||
}
|
||||
|
||||
var (
|
||||
_ Bind = (*StdNetBind)(nil)
|
||||
_ Endpoint = StdNetEndpoint{}
|
||||
_ Endpoint = &StdNetEndpoint{}
|
||||
)
|
||||
|
||||
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||
e, err := netip.ParseAddrPort(s)
|
||||
return asEndpoint(e), err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StdNetEndpoint{
|
||||
AddrPort: e,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (StdNetEndpoint) ClearSrc() {}
|
||||
|
||||
func (e StdNetEndpoint) DstIP() netip.Addr {
|
||||
return (netip.AddrPort)(e).Addr()
|
||||
func (e *StdNetEndpoint) ClearSrc() {
|
||||
if e.src != nil {
|
||||
// Truncate src, no need to reallocate.
|
||||
e.src = e.src[:0]
|
||||
}
|
||||
}
|
||||
|
||||
func (e StdNetEndpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{} // not supported
|
||||
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
||||
return e.AddrPort.Addr()
|
||||
}
|
||||
|
||||
func (e StdNetEndpoint) DstToBytes() []byte {
|
||||
b, _ := (netip.AddrPort)(e).MarshalBinary()
|
||||
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
|
||||
|
||||
func (e *StdNetEndpoint) DstToBytes() []byte {
|
||||
b, _ := e.AddrPort.MarshalBinary()
|
||||
return b
|
||||
}
|
||||
|
||||
func (e StdNetEndpoint) DstToString() string {
|
||||
return (netip.AddrPort)(e).String()
|
||||
}
|
||||
|
||||
func (e StdNetEndpoint) SrcToString() string {
|
||||
return ""
|
||||
func (e *StdNetEndpoint) DstToString() string {
|
||||
return e.AddrPort.String()
|
||||
}
|
||||
|
||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@ -77,17 +134,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return conn, uaddr.Port, nil
|
||||
return conn.(*net.UDPConn), uaddr.Port, nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var err error
|
||||
var tries int
|
||||
|
||||
if bind.ipv4 != nil || bind.ipv6 != nil {
|
||||
if s.ipv4 != nil || s.ipv6 != nil {
|
||||
return nil, 0, ErrBindAlreadyOpen
|
||||
}
|
||||
|
||||
@ -95,90 +152,207 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||
// If uport is 0, we can retry on failure.
|
||||
again:
|
||||
port := int(uport)
|
||||
var ipv4, ipv6 *net.UDPConn
|
||||
var v4conn, v6conn *net.UDPConn
|
||||
var v4pc *ipv4.PacketConn
|
||||
var v6pc *ipv6.PacketConn
|
||||
|
||||
ipv4, port, err = listenNet("udp4", port)
|
||||
v4conn, port, err = listenNet("udp4", port)
|
||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Listen on the same port as we're using for ipv4.
|
||||
ipv6, port, err = listenNet("udp6", port)
|
||||
v6conn, port, err = listenNet("udp6", port)
|
||||
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||
ipv4.Close()
|
||||
v4conn.Close()
|
||||
tries++
|
||||
goto again
|
||||
}
|
||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
ipv4.Close()
|
||||
v4conn.Close()
|
||||
return nil, 0, err
|
||||
}
|
||||
var fns []ReceiveFunc
|
||||
if ipv4 != nil {
|
||||
fns = append(fns, bind.makeReceiveIPv4(ipv4))
|
||||
bind.ipv4 = ipv4
|
||||
if v4conn != nil {
|
||||
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
v4pc = ipv4.NewPacketConn(v4conn)
|
||||
s.ipv4PC = v4pc
|
||||
}
|
||||
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
|
||||
s.ipv4 = v4conn
|
||||
}
|
||||
if ipv6 != nil {
|
||||
fns = append(fns, bind.makeReceiveIPv6(ipv6))
|
||||
bind.ipv6 = ipv6
|
||||
if v6conn != nil {
|
||||
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
v6pc = ipv6.NewPacketConn(v6conn)
|
||||
s.ipv6PC = v6pc
|
||||
}
|
||||
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
|
||||
s.ipv6 = v6conn
|
||||
}
|
||||
if len(fns) == 0 {
|
||||
return nil, 0, syscall.EAFNOSUPPORT
|
||||
}
|
||||
|
||||
return fns, uint16(port), nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) Close() error {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
|
||||
for i := range *msgs {
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||
}
|
||||
s.msgsPool.Put(msgs)
|
||||
}
|
||||
|
||||
func (s *StdNetBind) getMessages() *[]ipv6.Message {
|
||||
return s.msgsPool.Get().(*[]ipv6.Message)
|
||||
}
|
||||
|
||||
var (
|
||||
// If compilation fails here these are no longer the same underlying type.
|
||||
_ ipv6.Message = ipv4.Message{}
|
||||
)
|
||||
|
||||
type batchReader interface {
|
||||
ReadBatch([]ipv6.Message, int) (int, error)
|
||||
}
|
||||
|
||||
type batchWriter interface {
|
||||
WriteBatch([]ipv6.Message, int) (int, error)
|
||||
}
|
||||
|
||||
func (s *StdNetBind) receiveIP(
|
||||
br batchReader,
|
||||
conn *net.UDPConn,
|
||||
rxOffload bool,
|
||||
bufs [][]byte,
|
||||
sizes []int,
|
||||
eps []Endpoint,
|
||||
) (n int, err error) {
|
||||
msgs := s.getMessages()
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||
}
|
||||
defer s.putMessages(msgs)
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
if rxOffload {
|
||||
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
|
||||
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
numMsgs, err = br.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
sizes[i] = msg.N
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
return numMsgs, nil
|
||||
}
|
||||
|
||||
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||
// rename the IdealBatchSize constant to BatchSize.
|
||||
func (s *StdNetBind) BatchSize() int {
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
return IdealBatchSize
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *StdNetBind) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var err1, err2 error
|
||||
if bind.ipv4 != nil {
|
||||
err1 = bind.ipv4.Close()
|
||||
bind.ipv4 = nil
|
||||
if s.ipv4 != nil {
|
||||
err1 = s.ipv4.Close()
|
||||
s.ipv4 = nil
|
||||
s.ipv4PC = nil
|
||||
}
|
||||
if bind.ipv6 != nil {
|
||||
err2 = bind.ipv6.Close()
|
||||
bind.ipv6 = nil
|
||||
if s.ipv6 != nil {
|
||||
err2 = s.ipv6.Close()
|
||||
s.ipv6 = nil
|
||||
s.ipv6PC = nil
|
||||
}
|
||||
bind.blackhole4 = false
|
||||
bind.blackhole6 = false
|
||||
s.blackhole4 = false
|
||||
s.blackhole6 = false
|
||||
s.ipv4TxOffload = false
|
||||
s.ipv4RxOffload = false
|
||||
s.ipv6TxOffload = false
|
||||
s.ipv6RxOffload = false
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
|
||||
return func(buff []byte) (int, Endpoint, error) {
|
||||
n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
|
||||
return n, asEndpoint(endpoint), err
|
||||
}
|
||||
type ErrUDPGSODisabled struct {
|
||||
onLaddr string
|
||||
RetryErr error
|
||||
}
|
||||
|
||||
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
|
||||
return func(buff []byte) (int, Endpoint, error) {
|
||||
n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
|
||||
return n, asEndpoint(endpoint), err
|
||||
}
|
||||
func (e ErrUDPGSODisabled) Error() string {
|
||||
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
||||
var err error
|
||||
nend, ok := endpoint.(StdNetEndpoint)
|
||||
if !ok {
|
||||
return ErrWrongEndpointType
|
||||
}
|
||||
addrPort := netip.AddrPort(nend)
|
||||
func (e ErrUDPGSODisabled) Unwrap() error {
|
||||
return e.RetryErr
|
||||
}
|
||||
|
||||
bind.mu.Lock()
|
||||
blackhole := bind.blackhole4
|
||||
conn := bind.ipv4
|
||||
if addrPort.Addr().Is6() {
|
||||
blackhole = bind.blackhole6
|
||||
conn = bind.ipv6
|
||||
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||
s.mu.Lock()
|
||||
blackhole := s.blackhole4
|
||||
conn := s.ipv4
|
||||
offload := s.ipv4TxOffload
|
||||
br := batchWriter(s.ipv4PC)
|
||||
is6 := false
|
||||
if endpoint.DstIP().Is6() {
|
||||
blackhole = s.blackhole6
|
||||
conn = s.ipv6
|
||||
br = s.ipv6PC
|
||||
is6 = true
|
||||
offload = s.ipv6TxOffload
|
||||
}
|
||||
bind.mu.Unlock()
|
||||
s.mu.Unlock()
|
||||
|
||||
if blackhole {
|
||||
return nil
|
||||
@ -186,27 +360,185 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
||||
if conn == nil {
|
||||
return syscall.EAFNOSUPPORT
|
||||
}
|
||||
_, err = conn.WriteToUDPAddrPort(buff, addrPort)
|
||||
|
||||
msgs := s.getMessages()
|
||||
defer s.putMessages(msgs)
|
||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||
defer s.udpAddrPool.Put(ua)
|
||||
if is6 {
|
||||
as16 := endpoint.DstIP().As16()
|
||||
copy(ua.IP, as16[:])
|
||||
ua.IP = ua.IP[:16]
|
||||
} else {
|
||||
as4 := endpoint.DstIP().As4()
|
||||
copy(ua.IP, as4[:])
|
||||
ua.IP = ua.IP[:4]
|
||||
}
|
||||
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
|
||||
var (
|
||||
retried bool
|
||||
err error
|
||||
)
|
||||
retry:
|
||||
if offload {
|
||||
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
|
||||
err = s.send(conn, br, (*msgs)[:n])
|
||||
if err != nil && offload && errShouldDisableUDPGSO(err) {
|
||||
offload = false
|
||||
s.mu.Lock()
|
||||
if is6 {
|
||||
s.ipv6TxOffload = false
|
||||
} else {
|
||||
s.ipv4TxOffload = false
|
||||
}
|
||||
s.mu.Unlock()
|
||||
retried = true
|
||||
goto retry
|
||||
}
|
||||
} else {
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Addr = ua
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
|
||||
}
|
||||
err = s.send(conn, br, (*msgs)[:len(bufs)])
|
||||
}
|
||||
if retried {
|
||||
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
|
||||
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
|
||||
// but Endpoints are immutable, so we can re-use them.
|
||||
var endpointPool = sync.Pool{
|
||||
New: func() any {
|
||||
return make(map[netip.AddrPort]Endpoint)
|
||||
},
|
||||
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
start int
|
||||
)
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
for {
|
||||
n, err = pc.WriteBatch(msgs[start:], 0)
|
||||
if err != nil || n == len(msgs[start:]) {
|
||||
break
|
||||
}
|
||||
start += n
|
||||
}
|
||||
} else {
|
||||
for _, msg := range msgs {
|
||||
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// asEndpoint returns an Endpoint containing ap.
|
||||
func asEndpoint(ap netip.AddrPort) Endpoint {
|
||||
m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
|
||||
defer endpointPool.Put(m)
|
||||
e, ok := m[ap]
|
||||
if !ok {
|
||||
e = Endpoint(StdNetEndpoint(ap))
|
||||
m[ap] = e
|
||||
// Limits used when coalescing packets for UDP generic segmentation offload.
const (
	// Exceeding these values results in EMSGSIZE. They account for layer3 and
	// layer4 headers. IPv6 does not need to account for itself as the payload
	// length field is self excluding.
	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
	maxIPv6PayloadLen = 1<<16 - 1 - 8

	// This is a hard limit imposed by the kernel.
	udpSegmentMaxDatagrams = 64
)
|
||||
|
||||
type setGSOFunc func(control *[]byte, gsoSize uint16)
|
||||
|
||||
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
|
||||
var (
|
||||
base = -1 // index of msg we are currently coalescing into
|
||||
gsoSize int // segmentation size of msgs[base]
|
||||
dgramCnt int // number of dgrams coalesced into msgs[base]
|
||||
endBatch bool // tracking flag to start a new batch on next iteration of bufs
|
||||
)
|
||||
maxPayloadLen := maxIPv4PayloadLen
|
||||
if ep.DstIP().Is6() {
|
||||
maxPayloadLen = maxIPv6PayloadLen
|
||||
}
|
||||
return e
|
||||
for i, buf := range bufs {
|
||||
if i > 0 {
|
||||
msgLen := len(buf)
|
||||
baseLenBefore := len(msgs[base].Buffers[0])
|
||||
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
|
||||
if msgLen+baseLenBefore <= maxPayloadLen &&
|
||||
msgLen <= gsoSize &&
|
||||
msgLen <= freeBaseCap &&
|
||||
dgramCnt < udpSegmentMaxDatagrams &&
|
||||
!endBatch {
|
||||
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
|
||||
if i == len(bufs)-1 {
|
||||
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||
}
|
||||
dgramCnt++
|
||||
if msgLen < gsoSize {
|
||||
// A smaller than gsoSize packet on the tail is legal, but
|
||||
// it must end the batch.
|
||||
endBatch = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if dgramCnt > 1 {
|
||||
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||
}
|
||||
// Reset prior to incrementing base since we are preparing to start a
|
||||
// new potential batch.
|
||||
endBatch = false
|
||||
base++
|
||||
gsoSize = len(buf)
|
||||
setSrcControl(&msgs[base].OOB, ep)
|
||||
msgs[base].Buffers[0] = buf
|
||||
msgs[base].Addr = addr
|
||||
dgramCnt = 1
|
||||
}
|
||||
return base + 1
|
||||
}
|
||||
|
||||
// getGSOFunc parses a control message buffer and returns the GRO segment size
// it carries, or 0 when none is present.
type getGSOFunc func(control []byte) (int, error)

// splitCoalescedMessages splits each coalesced (GRO'd) message in
// msgs[firstMsgAt:] into individual single-segment messages, packing the
// results toward the front of msgs. It returns the number of leading elements
// of msgs the caller should evaluate.
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
	for i := firstMsgAt; i < len(msgs); i++ {
		msg := &msgs[i]
		if msg.N == 0 {
			// A zero-length message ends the received batch.
			return n, err
		}
		var (
			gsoSize    int
			start      int
			end        = msg.N
			numToSplit = 1
		)
		gsoSize, err = getGSO(msg.OOB[:msg.NN])
		if err != nil {
			return n, err
		}
		if gsoSize > 0 {
			// Coalesced super-packet: split into ceil(N/gsoSize) segments.
			numToSplit = (msg.N + gsoSize - 1) / gsoSize
			end = gsoSize
		}
		for j := 0; j < numToSplit; j++ {
			if n > i {
				// The write index has passed the read index: no room left in
				// msgs for the remaining segments.
				return n, errors.New("splitting coalesced packet resulted in overflow")
			}
			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
			msgs[n].N = copied
			msgs[n].Addr = msg.Addr
			start = end
			end += gsoSize
			if end > msg.N {
				end = msg.N
			}
			n++
		}
		if i != n-1 {
			// It is legal for bytes to move within msg.Buffers[0] as a result
			// of splitting, so we only zero the source msg len when it is not
			// the destination of the last split operation above.
			msg.N = 0
		}
	}
	return n, nil
}
|
||||
|
250
conn/bind_std_test.go
Normal file
250
conn/bind_std_test.go
Normal file
@ -0,0 +1,250 @@
|
||||
package conn
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
// TestStdNetBindReceiveFuncAfterClose verifies that the ReceiveFuncs returned
// by Open can be invoked after Close without panicking.
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
	bind := NewStdNetBind().(*StdNetBind)
	fns, _, err := bind.Open(0)
	if err != nil {
		t.Fatal(err)
	}
	bind.Close()
	bufs := make([][]byte, 1)
	bufs[0] = make([]byte, 1)
	sizes := make([]int, 1)
	eps := make([]Endpoint, 1)
	for _, fn := range fns {
		// The ReceiveFuncs must not access conn-related fields on StdNetBind
		// unguarded. Close() nils the conn-related fields resulting in a panic
		// if they violate the mutex.
		fn(bufs, sizes, eps)
	}
}
|
||||
|
||||
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
|
||||
*control = (*control)[:cap(*control)]
|
||||
binary.LittleEndian.PutUint16(*control, gsoSize)
|
||||
}
|
||||
|
||||
// Test_coalesceMessages exercises coalesceMessages with the mock GSO control
// helpers, checking the number of resulting messages, their coalesced buffer
// lengths, and the GSO size recorded in each message's control data.
func Test_coalesceMessages(t *testing.T) {
	cases := []struct {
		name     string
		buffs    [][]byte // input packets; cap() limits how much can coalesce
		wantLens []int    // expected coalesced buffer length per output message
		wantGSO  []int    // expected recorded GSO size per output message (0 = none)
	}{
		{
			name: "one message no coalesce",
			buffs: [][]byte{
				make([]byte, 1, 1),
			},
			wantLens: []int{1},
			wantGSO:  []int{0},
		},
		{
			name: "two messages equal len coalesce",
			buffs: [][]byte{
				make([]byte, 1, 2),
				make([]byte, 1, 1),
			},
			wantLens: []int{2},
			wantGSO:  []int{1},
		},
		{
			name: "two messages unequal len coalesce",
			buffs: [][]byte{
				make([]byte, 2, 3),
				make([]byte, 1, 1),
			},
			wantLens: []int{3},
			wantGSO:  []int{2},
		},
		{
			name: "three messages second unequal len coalesce",
			buffs: [][]byte{
				make([]byte, 2, 3),
				make([]byte, 1, 1),
				make([]byte, 2, 2),
			},
			wantLens: []int{3, 2},
			wantGSO:  []int{2, 0},
		},
		{
			name: "three messages limited cap coalesce",
			buffs: [][]byte{
				make([]byte, 2, 4),
				make([]byte, 2, 2),
				make([]byte, 2, 2),
			},
			wantLens: []int{4, 2},
			wantGSO:  []int{2, 0},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			addr := &net.UDPAddr{
				IP:   net.ParseIP("127.0.0.1").To4(),
				Port: 1,
			}
			// One single-buffer message slot per input packet, with room for
			// the 2-byte mock GSO control value.
			msgs := make([]ipv6.Message, len(tt.buffs))
			for i := range msgs {
				msgs[i].Buffers = make([][]byte, 1)
				msgs[i].OOB = make([]byte, 0, 2)
			}
			got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
			if got != len(tt.wantLens) {
				t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
			}
			for i := 0; i < got; i++ {
				if msgs[i].Addr != addr {
					t.Errorf("msgs[%d].Addr != passed addr", i)
				}
				gotLen := len(msgs[i].Buffers[0])
				if gotLen != tt.wantLens[i] {
					t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
				}
				gotGSO, err := mockGetGSOSize(msgs[i].OOB)
				if err != nil {
					t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
				}
				if gotGSO != tt.wantGSO[i] {
					t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
				}
			}
		})
	}
}
|
||||
|
||||
// mockGetGSOSize is a test double for getGSOSize: it reads a little-endian
// uint16 GSO size from the front of control, reporting 0 when control is too
// short to hold one. It never returns an error.
func mockGetGSOSize(control []byte) (int, error) {
	if len(control) >= 2 {
		return int(binary.LittleEndian.Uint16(control)), nil
	}
	return 0, nil
}
|
||||
|
||||
func Test_splitCoalescedMessages(t *testing.T) {
|
||||
newMsg := func(n, gso int) ipv6.Message {
|
||||
msg := ipv6.Message{
|
||||
Buffers: [][]byte{make([]byte, 1<<16-1)},
|
||||
N: n,
|
||||
OOB: make([]byte, 2),
|
||||
}
|
||||
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
|
||||
if gso > 0 {
|
||||
msg.NN = 2
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
msgs []ipv6.Message
|
||||
firstMsgAt int
|
||||
wantNumEval int
|
||||
wantMsgLens []int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "second last split last empty",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(3, 1),
|
||||
newMsg(0, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 3,
|
||||
wantMsgLens: []int{1, 1, 1, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last empty",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(0, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 1,
|
||||
wantMsgLens: []int{1, 0, 0, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last no split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(1, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 2,
|
||||
wantMsgLens: []int{1, 1, 0, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(3, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last split last split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(2, 1),
|
||||
newMsg(2, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last split overflow",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(4, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
|
||||
if err != nil && !tt.wantErr {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if got != tt.wantNumEval {
|
||||
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
|
||||
}
|
||||
for i, msg := range tt.msgs {
|
||||
if msg.N != tt.wantMsgLens[i] {
|
||||
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
@ -164,7 +164,7 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
|
||||
func (e *WinRingEndpoint) DstToString() string {
|
||||
switch e.family {
|
||||
case windows.AF_INET:
|
||||
netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
|
||||
return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
|
||||
case windows.AF_INET6:
|
||||
var zone string
|
||||
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
|
||||
@ -321,6 +321,13 @@ func (bind *WinRingBind) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||
// rename the IdealBatchSize constant to BatchSize.
|
||||
func (bind *WinRingBind) BatchSize() int {
|
||||
// TODO: implement batching in and out of the ring
|
||||
return 1
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
@ -409,16 +416,22 @@ retry:
|
||||
return n, &ep, nil
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
|
||||
func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
return bind.v4.Receive(buf, &bind.isOpen)
|
||||
n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
|
||||
sizes[0] = n
|
||||
eps[0] = ep
|
||||
return 1, err
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
|
||||
func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
return bind.v6.Receive(buf, &bind.isOpen)
|
||||
n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
|
||||
sizes[0] = n
|
||||
eps[0] = ep
|
||||
return 1, err
|
||||
}
|
||||
|
||||
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
|
||||
@ -473,32 +486,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi
|
||||
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
|
||||
func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||
nend, ok := endpoint.(*WinRingEndpoint)
|
||||
if !ok {
|
||||
return ErrWrongEndpointType
|
||||
}
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
switch nend.family {
|
||||
case windows.AF_INET:
|
||||
if bind.v4.blackhole {
|
||||
return nil
|
||||
for _, buf := range bufs {
|
||||
switch nend.family {
|
||||
case windows.AF_INET:
|
||||
if bind.v4.blackhole {
|
||||
continue
|
||||
}
|
||||
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
|
||||
return err
|
||||
}
|
||||
case windows.AF_INET6:
|
||||
if bind.v6.blackhole {
|
||||
continue
|
||||
}
|
||||
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return bind.v4.Send(buf, nend, &bind.isOpen)
|
||||
case windows.AF_INET6:
|
||||
if bind.v6.blackhole {
|
||||
return nil
|
||||
}
|
||||
return bind.v6.Send(buf, nend, &bind.isOpen)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
sysconn, err := bind.ipv4.SyscallConn()
|
||||
func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sysconn, err := s.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -511,14 +530,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bind.blackhole4 = blackhole
|
||||
s.blackhole4 = blackhole
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
sysconn, err := bind.ipv6.SyscallConn()
|
||||
func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sysconn, err := s.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -531,7 +550,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bind.blackhole6 = blackhole
|
||||
s.blackhole6 = blackhole
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bindtest
|
||||
@ -89,32 +89,39 @@ func (c *ChannelBind) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChannelBind) BatchSize() int { return 1 }
|
||||
|
||||
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
||||
|
||||
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
|
||||
return func(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||
return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
return 0, nil, net.ErrClosed
|
||||
return 0, net.ErrClosed
|
||||
case rx := <-ch:
|
||||
return copy(b, rx), c.target6, nil
|
||||
copied := copy(bufs[0], rx)
|
||||
sizes[0] = copied
|
||||
eps[0] = c.target6
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
bc := make([]byte, len(b))
|
||||
copy(bc, b)
|
||||
if ep.(ChannelEndpoint) == c.target4 {
|
||||
*c.tx4 <- bc
|
||||
} else if ep.(ChannelEndpoint) == c.target6 {
|
||||
*c.tx6 <- bc
|
||||
} else {
|
||||
return os.ErrInvalid
|
||||
func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||
for _, b := range bufs {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
bc := make([]byte, len(b))
|
||||
copy(bc, b)
|
||||
if ep.(ChannelEndpoint) == c.target4 {
|
||||
*c.tx4 <- bc
|
||||
} else if ep.(ChannelEndpoint) == c.target6 {
|
||||
*c.tx6 <- bc
|
||||
} else {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -1,12 +1,12 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
|
||||
sysconn, err := bind.ipv4.SyscallConn()
|
||||
func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
|
||||
sysconn, err := s.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
@ -19,8 +19,8 @@ func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
|
||||
sysconn, err := bind.ipv6.SyscallConn()
|
||||
func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
|
||||
sysconn, err := s.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
26
conn/conn.go
26
conn/conn.go
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package conn implements WireGuard's network connections.
|
||||
@ -15,10 +15,17 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A ReceiveFunc receives a single inbound packet from the network.
|
||||
// It writes the data into b. n is the length of the packet.
|
||||
// ep is the remote endpoint.
|
||||
type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
|
||||
const (
|
||||
IdealBatchSize = 128 // maximum number of packets handled per read and write
|
||||
)
|
||||
|
||||
// A ReceiveFunc receives at least one packet from the network and writes them
|
||||
// into packets. On a successful read it returns the number of elements of
|
||||
// sizes, packets, and endpoints that should be evaluated. Some elements of
|
||||
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
|
||||
// and eps slice with a length greater than or equal to the length of packets.
|
||||
// These lengths must not exceed the length of the associated Bind.BatchSize().
|
||||
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
|
||||
|
||||
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||
//
|
||||
@ -38,11 +45,16 @@ type Bind interface {
|
||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||
SetMark(mark uint32) error
|
||||
|
||||
// Send writes a packet b to address ep.
|
||||
Send(b []byte, ep Endpoint) error
|
||||
// Send writes one or more packets in bufs to address ep. The length of
|
||||
// bufs must not exceed BatchSize().
|
||||
Send(bufs [][]byte, ep Endpoint) error
|
||||
|
||||
// ParseEndpoint creates a new endpoint from a string.
|
||||
ParseEndpoint(s string) (Endpoint, error)
|
||||
|
||||
// BatchSize is the number of buffers expected to be passed to
|
||||
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
|
||||
BatchSize() int
|
||||
}
|
||||
|
||||
// BindSocketToInterface is implemented by Bind objects that support being
|
||||
|
24
conn/conn_test.go
Normal file
24
conn/conn_test.go
Normal file
@ -0,0 +1,24 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPrettyName checks that ReceiveFunc.PrettyName on a func declared inside
// this test reports the enclosing test function's name.
func TestPrettyName(t *testing.T) {
	var (
		recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
	)

	const want = "TestPrettyName"

	t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
		if got := recvFunc.PrettyName(); got != want {
			t.Errorf("PrettyName() = %v, want %v", got, want)
		}
	})
}
|
43
conn/controlfns.go
Normal file
43
conn/controlfns.go
Normal file
@ -0,0 +1,43 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
// the max supported by a default configuration of macOS. Some platforms will
// silently clamp the value to other maximums, such as linux clamping to
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
// around this limitation)
const socketBufferSize = 7 << 20

// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error

// controlFns is a list of functions that are called from the listen config
// that can apply socket options. Platform-specific files append to it from
// their init functions.
var controlFns = []controlFn{}
|
||||
|
||||
// listenConfig returns a net.ListenConfig that applies the controlFns to the
|
||||
// socket prior to bind. This is used to apply socket buffer sizing and packet
|
||||
// information OOB configuration for sticky sockets.
|
||||
func listenConfig() *net.ListenConfig {
|
||||
return &net.ListenConfig{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
for _, fn := range controlFns {
|
||||
if err := fn(network, address, c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
109
conn/controlfns_linux.go
Normal file
109
conn/controlfns_linux.go
Normal file
@ -0,0 +1,109 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Taken from go/src/internal/syscall/unix/kernel_version_linux.go
|
||||
func kernelVersion() (major, minor int) {
|
||||
var uname unix.Utsname
|
||||
if err := unix.Uname(&uname); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
values [2]int
|
||||
value, vi int
|
||||
)
|
||||
for _, c := range uname.Release {
|
||||
if '0' <= c && c <= '9' {
|
||||
value = (value * 10) + int(c-'0')
|
||||
} else {
|
||||
// Note that we're assuming N.N.N here.
|
||||
// If we see anything else, we are likely to mis-parse it.
|
||||
values[vi] = value
|
||||
vi++
|
||||
if vi >= len(values) {
|
||||
break
|
||||
}
|
||||
value = 0
|
||||
}
|
||||
}
|
||||
|
||||
return values[0], values[1]
|
||||
}
|
||||
|
||||
// init registers the Linux-specific socket control functions: buffer sizing
// (with SO_*BUFFORCE escalation), packet-info OOB for sticky sockets, and
// opportunistic UDP_GRO on new enough kernels.
func init() {
	controlFns = append(controlFns,

		// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
		// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
		// fail silently - the result of failure is lower performance on very fast
		// links or high latency links.
		func(network, address string, c syscall.RawConn) error {
			return c.Control(func(fd uintptr) {
				// Set up to *mem_max
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
				// Set beyond *mem_max if CAP_NET_ADMIN
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
			})
		},

		// Enable receiving of the packet information (IP_PKTINFO for IPv4,
		// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
		func(network, address string, c syscall.RawConn) error {
			var err error
			switch network {
			case "udp4":
				// Android skips IP_PKTINFO here.
				if runtime.GOOS != "android" {
					c.Control(func(fd uintptr) {
						err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
					})
				}
			case "udp6":
				c.Control(func(fd uintptr) {
					if runtime.GOOS != "android" {
						err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
						if err != nil {
							return
						}
					}
					// V6ONLY is applied on udp6 sockets on all targets,
					// including Android.
					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
				})
			default:
				err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
			}
			return err
		},

		// Attempt to enable UDP_GRO
		func(network, address string, c syscall.RawConn) error {
			// Kernels below 5.12 are missing 98184612aca0 ("net:
			// udp: Add support for getsockopt(..., ..., UDP_GRO,
			// ..., ...);"), which means we can't read this back
			// later. We could pipe the return value through to
			// the rest of the code, but UDP_GRO is kind of buggy
			// anyway, so just gate this here.
			major, minor := kernelVersion()
			if major < 5 || (major == 5 && minor < 12) {
				return nil
			}

			c.Control(func(fd uintptr) {
				_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
			})
			return nil
		},
	)
}
|
35
conn/controlfns_unix.go
Normal file
35
conn/controlfns_unix.go
Normal file
@ -0,0 +1,35 @@
|
||||
//go:build !windows && !linux && !wasm
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// init registers the control functions used on non-Linux, non-Windows Unix
// platforms: best-effort socket buffer sizing and IPV6_V6ONLY for v6 sockets.
func init() {
	controlFns = append(controlFns,
		// Size the read/write buffers; failures are deliberately ignored as
		// they reduce throughput, not correctness.
		func(network, address string, c syscall.RawConn) error {
			return c.Control(func(fd uintptr) {
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
			})
		},

		// Set IPV6_V6ONLY on udp6 sockets.
		func(network, address string, c syscall.RawConn) error {
			var err error
			if network == "udp6" {
				c.Control(func(fd uintptr) {
					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
				})
			}
			return err
		},
	)
}
|
23
conn/controlfns_windows.go
Normal file
23
conn/controlfns_windows.go
Normal file
@ -0,0 +1,23 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// init registers the Windows control function: best-effort socket buffer
// sizing (errors intentionally ignored).
func init() {
	controlFns = append(controlFns,
		func(network, address string, c syscall.RawConn) error {
			return c.Control(func(fd uintptr) {
				_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
				_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
			})
		},
	)
}
|
@ -1,8 +1,8 @@
|
||||
//go:build !linux && !windows
|
||||
//go:build !windows
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
12
conn/errors_default.go
Normal file
12
conn/errors_default.go
Normal file
@ -0,0 +1,12 @@
|
||||
//go:build !linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
// errShouldDisableUDPGSO reports whether err indicates that UDP generic
// segmentation offload should be disabled. On non-Linux platforms this is
// always false.
func errShouldDisableUDPGSO(_ error) bool {
	return false
}
|
26
conn/errors_linux.go
Normal file
26
conn/errors_linux.go
Normal file
@ -0,0 +1,26 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// errShouldDisableUDPGSO reports whether err is a syscall-level EIO, which
// the kernel returns for UDP_SEGMENT sends when the device driver lacks tx
// checksumming — the signal that GSO must be disabled for this socket.
func errShouldDisableUDPGSO(err error) bool {
	var serr *os.SyscallError
	if errors.As(err, &serr) {
		// EIO is returned by udp_send_skb() if the device driver does not have
		// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
		// See:
		// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
		// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
		return serr.Err == unix.EIO
	}
	return false
}
|
15
conn/features_default.go
Normal file
15
conn/features_default.go
Normal file
@ -0,0 +1,15 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import "net"
|
||||
|
||||
// supportsUDPOffload reports whether transmit and receive UDP offload are
// available for the connection. On non-Linux platforms both are always false.
func supportsUDPOffload(_ *net.UDPConn) (txOffload, rxOffload bool) {
	return
}
|
29
conn/features_linux.go
Normal file
29
conn/features_linux.go
Normal file
@ -0,0 +1,29 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||
rc, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = rc.Control(func(fd uintptr) {
|
||||
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
||||
txOffload = errSyscall == nil
|
||||
opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
|
||||
rxOffload = errSyscall == nil && opt == 1
|
||||
})
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
return txOffload, rxOffload
|
||||
}
|
21
conn/gso_default.go
Normal file
21
conn/gso_default.go
Normal file
@ -0,0 +1,21 @@
|
||||
//go:build !linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||
func getGSOSize(control []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
|
||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||
}
|
||||
|
||||
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
|
||||
// offloading control data.
|
||||
const gsoControlSize = 0
|
65
conn/gso_linux.go
Normal file
65
conn/gso_linux.go
Normal file
@ -0,0 +1,65 @@
|
||||
//go:build linux
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
sizeOfGSOData = 2
|
||||
)
|
||||
|
||||
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||
func getGSOSize(control []byte) (int, error) {
|
||||
var (
|
||||
hdr unix.Cmsghdr
|
||||
data []byte
|
||||
rem = control
|
||||
err error
|
||||
)
|
||||
|
||||
for len(rem) > unix.SizeofCmsghdr {
|
||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error parsing socket control message: %w", err)
|
||||
}
|
||||
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
|
||||
var gso uint16
|
||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
|
||||
return int(gso), nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
|
||||
// data in control untouched.
|
||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||
existingLen := len(*control)
|
||||
avail := cap(*control) - existingLen
|
||||
space := unix.CmsgSpace(sizeOfGSOData)
|
||||
if avail < space {
|
||||
return
|
||||
}
|
||||
*control = (*control)[:cap(*control)]
|
||||
gsoControl := (*control)[existingLen:]
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
|
||||
hdr.Level = unix.SOL_UDP
|
||||
hdr.Type = unix.UDP_SEGMENT
|
||||
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
|
||||
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
|
||||
*control = (*control)[:existingLen+space]
|
||||
}
|
||||
|
||||
// gsoControlSize returns the recommended buffer size for pooling UDP
|
||||
// offloading control data.
|
||||
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
|
@ -2,11 +2,11 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
func (bind *StdNetBind) SetMark(mark uint32) error {
|
||||
func (s *StdNetBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
@ -26,13 +26,13 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) SetMark(mark uint32) error {
|
||||
func (s *StdNetBind) SetMark(mark uint32) error {
|
||||
var operr error
|
||||
if fwmarkIoctl == 0 {
|
||||
return nil
|
||||
}
|
||||
if bind.ipv4 != nil {
|
||||
fd, err := bind.ipv4.SyscallConn()
|
||||
if s.ipv4 != nil {
|
||||
fd, err := s.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -46,8 +46,8 @@ func (bind *StdNetBind) SetMark(mark uint32) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if bind.ipv6 != nil {
|
||||
fd, err := bind.ipv6.SyscallConn()
|
||||
if s.ipv6 != nil {
|
||||
fd, err := s.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
42
conn/sticky_default.go
Normal file
42
conn/sticky_default.go
Normal file
@ -0,0 +1,42 @@
|
||||
//go:build !linux || android
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import "net/netip"
|
||||
|
||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcToString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
|
||||
// {get,set}srcControl feature set, but use alternatively named flags and need
|
||||
// ports and require testing.
|
||||
|
||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||
}
|
||||
|
||||
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||
}
|
||||
|
||||
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||
// offloading control data.
|
||||
const stickyControlSize = 0
|
||||
|
||||
const StdNetSupportsStickySockets = false
|
112
conn/sticky_linux.go
Normal file
112
conn/sticky_linux.go
Normal file
@ -0,0 +1,112 @@
|
||||
//go:build linux && !android
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||
switch len(e.src) {
|
||||
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||
return netip.AddrFrom4(info.Spec_dst)
|
||||
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||
// TODO: set zone. in order to do so we need to check if the address is
|
||||
// link local, and if it is perform a syscall to turn the ifindex into a
|
||||
// zone string because netip uses string zones.
|
||||
return netip.AddrFrom16(info.Addr)
|
||||
}
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||
switch len(e.src) {
|
||||
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||
return info.Ifindex
|
||||
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||
return int32(info.Ifindex)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *StdNetEndpoint) SrcToString() string {
|
||||
return e.SrcIP().String()
|
||||
}
|
||||
|
||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||
// the source information found.
|
||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||
ep.ClearSrc()
|
||||
|
||||
var (
|
||||
hdr unix.Cmsghdr
|
||||
data []byte
|
||||
rem []byte = control
|
||||
err error
|
||||
)
|
||||
|
||||
for len(rem) > unix.SizeofCmsghdr {
|
||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.Level == unix.IPPROTO_IP &&
|
||||
hdr.Type == unix.IP_PKTINFO {
|
||||
|
||||
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
|
||||
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||
}
|
||||
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
|
||||
|
||||
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
||||
copy(ep.src, hdrBuf)
|
||||
copy(ep.src[unix.CmsgLen(0):], data)
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.Level == unix.IPPROTO_IPV6 &&
|
||||
hdr.Type == unix.IPV6_PKTINFO {
|
||||
|
||||
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
|
||||
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
||||
}
|
||||
|
||||
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
|
||||
|
||||
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
||||
copy(ep.src, hdrBuf)
|
||||
copy(ep.src[unix.CmsgLen(0):], data)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
|
||||
// and source ifindex found in ep. control's len will be set to 0 in the event
|
||||
// that ep is a default value.
|
||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||
if cap(*control) < len(ep.src) {
|
||||
return
|
||||
}
|
||||
*control = (*control)[:0]
|
||||
*control = append(*control, ep.src...)
|
||||
}
|
||||
|
||||
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||
// offloading control data.
|
||||
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
|
||||
|
||||
const StdNetSupportsStickySockets = true
|
266
conn/sticky_linux_test.go
Normal file
266
conn/sticky_linux_test.go
Normal file
@ -0,0 +1,266 @@
|
||||
//go:build linux && !android
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
|
||||
var buf []byte
|
||||
if addr.Is4() {
|
||||
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||
hdr := unix.Cmsghdr{
|
||||
Level: unix.IPPROTO_IP,
|
||||
Type: unix.IP_PKTINFO,
|
||||
}
|
||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
|
||||
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
|
||||
|
||||
info := unix.Inet4Pktinfo{
|
||||
Ifindex: ifidx,
|
||||
Spec_dst: addr.As4(),
|
||||
}
|
||||
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
|
||||
} else {
|
||||
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
||||
hdr := unix.Cmsghdr{
|
||||
Level: unix.IPPROTO_IPV6,
|
||||
Type: unix.IPV6_PKTINFO,
|
||||
}
|
||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
|
||||
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
|
||||
|
||||
info := unix.Inet6Pktinfo{
|
||||
Ifindex: uint32(ifidx),
|
||||
Addr: addr.As16(),
|
||||
}
|
||||
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
|
||||
}
|
||||
|
||||
ep.src = buf
|
||||
}
|
||||
|
||||
func Test_setSrcControl(t *testing.T) {
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
ep := &StdNetEndpoint{
|
||||
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
|
||||
}
|
||||
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
|
||||
|
||||
control := make([]byte, stickyControlSize)
|
||||
|
||||
setSrcControl(&control, ep)
|
||||
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
if hdr.Level != unix.IPPROTO_IP {
|
||||
t.Errorf("unexpected level: %d", hdr.Level)
|
||||
}
|
||||
if hdr.Type != unix.IP_PKTINFO {
|
||||
t.Errorf("unexpected type: %d", hdr.Type)
|
||||
}
|
||||
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
|
||||
t.Errorf("unexpected length: %d", hdr.Len)
|
||||
}
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
|
||||
t.Errorf("unexpected address: %v", info.Spec_dst)
|
||||
}
|
||||
if info.Ifindex != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
ep := &StdNetEndpoint{
|
||||
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
|
||||
}
|
||||
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
||||
|
||||
control := make([]byte, stickyControlSize)
|
||||
|
||||
setSrcControl(&control, ep)
|
||||
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
if hdr.Level != unix.IPPROTO_IPV6 {
|
||||
t.Errorf("unexpected level: %d", hdr.Level)
|
||||
}
|
||||
if hdr.Type != unix.IPV6_PKTINFO {
|
||||
t.Errorf("unexpected type: %d", hdr.Type)
|
||||
}
|
||||
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
|
||||
t.Errorf("unexpected length: %d", hdr.Len)
|
||||
}
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
if info.Addr != ep.SrcIP().As16() {
|
||||
t.Errorf("unexpected address: %v", info.Addr)
|
||||
}
|
||||
if info.Ifindex != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ClearOnNoSrc", func(t *testing.T) {
|
||||
control := make([]byte, stickyControlSize)
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = 1
|
||||
hdr.Type = 2
|
||||
hdr.Len = 3
|
||||
|
||||
setSrcControl(&control, &StdNetEndpoint{})
|
||||
|
||||
if len(control) != 0 {
|
||||
t.Errorf("unexpected control: %v", control)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_getSrcFromControl(t *testing.T) {
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
control := make([]byte, stickyControlSize)
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = unix.IPPROTO_IP
|
||||
hdr.Type = unix.IP_PKTINFO
|
||||
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
||||
info.Ifindex = 5
|
||||
|
||||
ep := &StdNetEndpoint{}
|
||||
getSrcFromControl(control, ep)
|
||||
|
||||
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
|
||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||
}
|
||||
if ep.SrcIfidx() != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||
}
|
||||
})
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
control := make([]byte, stickyControlSize)
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = unix.IPPROTO_IPV6
|
||||
hdr.Type = unix.IPV6_PKTINFO
|
||||
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
|
||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
info.Ifindex = 5
|
||||
|
||||
ep := &StdNetEndpoint{}
|
||||
getSrcFromControl(control, ep)
|
||||
|
||||
if ep.SrcIP() != netip.MustParseAddr("::1") {
|
||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||
}
|
||||
if ep.SrcIfidx() != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||
}
|
||||
})
|
||||
t.Run("ClearOnEmpty", func(t *testing.T) {
|
||||
var control []byte
|
||||
ep := &StdNetEndpoint{}
|
||||
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
||||
|
||||
getSrcFromControl(control, ep)
|
||||
if ep.SrcIP().IsValid() {
|
||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||
}
|
||||
if ep.SrcIfidx() != 0 {
|
||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||
}
|
||||
})
|
||||
t.Run("Multiple", func(t *testing.T) {
|
||||
zeroControl := make([]byte, unix.CmsgSpace(0))
|
||||
zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
|
||||
zeroHdr.SetLen(unix.CmsgLen(0))
|
||||
|
||||
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = unix.IPPROTO_IP
|
||||
hdr.Type = unix.IP_PKTINFO
|
||||
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
||||
info.Ifindex = 5
|
||||
|
||||
combined := make([]byte, 0)
|
||||
combined = append(combined, zeroControl...)
|
||||
combined = append(combined, control...)
|
||||
|
||||
ep := &StdNetEndpoint{}
|
||||
getSrcFromControl(combined, ep)
|
||||
|
||||
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
|
||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||
}
|
||||
if ep.SrcIfidx() != 5 {
|
||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_listenConfig(t *testing.T) {
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
var i int
|
||||
sc.Control(func(fd uintptr) {
|
||||
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 1 {
|
||||
t.Error("IP_PKTINFO not set!")
|
||||
}
|
||||
} else {
|
||||
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||
}
|
||||
})
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
var i int
|
||||
sc.Control(func(fd uintptr) {
|
||||
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 1 {
|
||||
t.Error("IPV6_PKTINFO not set!")
|
||||
}
|
||||
} else {
|
||||
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||
}
|
||||
})
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package winrio
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -223,6 +223,60 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) remove() {
|
||||
node.removeFromPeerEntries()
|
||||
node.peer = nil
|
||||
if node.child[0] != nil && node.child[1] != nil {
|
||||
return
|
||||
}
|
||||
bit := 0
|
||||
if node.child[0] == nil {
|
||||
bit = 1
|
||||
}
|
||||
child := node.child[bit]
|
||||
if child != nil {
|
||||
child.parent = node.parent
|
||||
}
|
||||
*node.parent.parentBit = child
|
||||
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
|
||||
node.zeroizePointers()
|
||||
return
|
||||
}
|
||||
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
|
||||
if parent.peer != nil {
|
||||
node.zeroizePointers()
|
||||
return
|
||||
}
|
||||
child = parent.child[node.parent.parentBitType^1]
|
||||
if child != nil {
|
||||
child.parent = parent.parent
|
||||
}
|
||||
*parent.parent.parentBit = child
|
||||
node.zeroizePointers()
|
||||
parent.zeroizePointers()
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
var node *trieEntry
|
||||
var exact bool
|
||||
|
||||
if prefix.Addr().Is6() {
|
||||
ip := prefix.Addr().As16()
|
||||
node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
|
||||
} else if prefix.Addr().Is4() {
|
||||
ip := prefix.Addr().As4()
|
||||
node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
|
||||
} else {
|
||||
panic(errors.New("removing unknown address type"))
|
||||
}
|
||||
if !exact || node == nil || peer != node.peer {
|
||||
return
|
||||
}
|
||||
node.remove()
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
@ -230,38 +284,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
var next *list.Element
|
||||
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
||||
next = elem.Next()
|
||||
node := elem.Value.(*trieEntry)
|
||||
|
||||
node.removeFromPeerEntries()
|
||||
node.peer = nil
|
||||
if node.child[0] != nil && node.child[1] != nil {
|
||||
continue
|
||||
}
|
||||
bit := 0
|
||||
if node.child[0] == nil {
|
||||
bit = 1
|
||||
}
|
||||
child := node.child[bit]
|
||||
if child != nil {
|
||||
child.parent = node.parent
|
||||
}
|
||||
*node.parent.parentBit = child
|
||||
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
|
||||
node.zeroizePointers()
|
||||
continue
|
||||
}
|
||||
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
|
||||
if parent.peer != nil {
|
||||
node.zeroizePointers()
|
||||
continue
|
||||
}
|
||||
child = parent.child[node.parent.parentBitType^1]
|
||||
if child != nil {
|
||||
child.parent = parent.parent
|
||||
}
|
||||
*parent.parent.parentBit = child
|
||||
node.zeroizePointers()
|
||||
parent.zeroizePointers()
|
||||
elem.Value.(*trieEntry).remove()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -83,7 +83,7 @@ func TestTrieRandom(t *testing.T) {
|
||||
var peers []*Peer
|
||||
var allowedIPs AllowedIPs
|
||||
|
||||
rand.Seed(1)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
|
||||
for n := 0; n < NumberOfPeers; n++ {
|
||||
peers = append(peers, &Peer{})
|
||||
@ -91,14 +91,14 @@ func TestTrieRandom(t *testing.T) {
|
||||
|
||||
for n := 0; n < NumberOfAddresses; n++ {
|
||||
var addr4 [4]byte
|
||||
rand.Read(addr4[:])
|
||||
rng.Read(addr4[:])
|
||||
cidr := uint8(rand.Intn(32) + 1)
|
||||
index := rand.Intn(NumberOfPeers)
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
|
||||
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
|
||||
|
||||
var addr6 [16]byte
|
||||
rand.Read(addr6[:])
|
||||
rng.Read(addr6[:])
|
||||
cidr = uint8(rand.Intn(128) + 1)
|
||||
index = rand.Intn(NumberOfPeers)
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
|
||||
@ -109,7 +109,7 @@ func TestTrieRandom(t *testing.T) {
|
||||
for p = 0; ; p++ {
|
||||
for n := 0; n < NumberOfTests; n++ {
|
||||
var addr4 [4]byte
|
||||
rand.Read(addr4[:])
|
||||
rng.Read(addr4[:])
|
||||
peer1 := slow4.Lookup(addr4[:])
|
||||
peer2 := allowedIPs.Lookup(addr4[:])
|
||||
if peer1 != peer2 {
|
||||
@ -117,7 +117,7 @@ func TestTrieRandom(t *testing.T) {
|
||||
}
|
||||
|
||||
var addr6 [16]byte
|
||||
rand.Read(addr6[:])
|
||||
rng.Read(addr6[:])
|
||||
peer1 = slow6.Lookup(addr6[:])
|
||||
peer2 = allowedIPs.Lookup(addr6[:])
|
||||
if peer1 != peer2 {
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -39,12 +39,12 @@ func TestCommonBits(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
|
||||
func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) {
|
||||
var trie *trieEntry
|
||||
var peers []*Peer
|
||||
root := parentIndirection{&trie, 2}
|
||||
|
||||
rand.Seed(1)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
|
||||
const AddressLength = 4
|
||||
|
||||
@ -54,15 +54,15 @@ func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
|
||||
|
||||
for n := 0; n < addressNumber; n++ {
|
||||
var addr [AddressLength]byte
|
||||
rand.Read(addr[:])
|
||||
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||
index := rand.Int() % peerNumber
|
||||
rng.Read(addr[:])
|
||||
cidr := uint8(rng.Uint32() % (AddressLength * 8))
|
||||
index := rng.Int() % peerNumber
|
||||
root.insert(addr[:], cidr, peers[index])
|
||||
}
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var addr [AddressLength]byte
|
||||
rand.Read(addr[:])
|
||||
rng.Read(addr[:])
|
||||
trie.lookup(addr[:])
|
||||
}
|
||||
}
|
||||
@ -101,6 +101,10 @@ func TestTrieIPv4(t *testing.T) {
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||
}
|
||||
|
||||
remove := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
||||
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||
}
|
||||
|
||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
||||
if p != peer {
|
||||
@ -176,6 +180,21 @@ func TestTrieIPv4(t *testing.T) {
|
||||
allowedIPs.RemoveByPeer(a)
|
||||
|
||||
assertNEQ(a, 192, 168, 0, 1)
|
||||
|
||||
insert(a, 1, 0, 0, 0, 32)
|
||||
insert(a, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 1, 0, 0, 0)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 192, 0, 0, 0, 32)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(nil, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(b, 192, 0, 0, 0, 24)
|
||||
assertEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 192, 0, 0, 0, 24)
|
||||
assertNEQ(a, 192, 0, 0, 1)
|
||||
remove(a, 1, 0, 0, 0, 32)
|
||||
assertNEQ(a, 1, 0, 0, 0)
|
||||
}
|
||||
|
||||
/* Test ported from kernel implementation:
|
||||
@ -211,6 +230,15 @@ func TestTrieIPv6(t *testing.T) {
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||
}
|
||||
|
||||
remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
addr = append(addr, expand(b)...)
|
||||
addr = append(addr, expand(c)...)
|
||||
addr = append(addr, expand(d)...)
|
||||
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||
}
|
||||
|
||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
@ -223,6 +251,18 @@ func TestTrieIPv6(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
assertNEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||
var addr []byte
|
||||
addr = append(addr, expand(a)...)
|
||||
addr = append(addr, expand(b)...)
|
||||
addr = append(addr, expand(c)...)
|
||||
addr = append(addr, expand(d)...)
|
||||
p := allowedIPs.Lookup(addr)
|
||||
if p == peer {
|
||||
t.Error("Assert NEQ failed")
|
||||
}
|
||||
}
|
||||
|
||||
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
|
||||
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
|
||||
insert(e, 0, 0, 0, 0, 0)
|
||||
@ -244,4 +284,21 @@ func TestTrieIPv6(t *testing.T) {
|
||||
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
|
||||
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
|
||||
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
|
||||
|
||||
insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||
assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||
remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||
assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -26,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
|
||||
func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
|
||||
datagram, ok := <-b.in6
|
||||
if !ok {
|
||||
return 0, nil, errors.New("closed")
|
||||
}
|
||||
copy(buff, datagram.msg)
|
||||
copy(buf, datagram.msg)
|
||||
return len(datagram.msg), datagram.endpoint, nil
|
||||
}
|
||||
|
||||
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
|
||||
func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
|
||||
datagram, ok := <-b.in4
|
||||
if !ok {
|
||||
return 0, nil, errors.New("closed")
|
||||
}
|
||||
copy(buff, datagram.msg)
|
||||
copy(buf, datagram.msg)
|
||||
return len(datagram.msg), datagram.endpoint, nil
|
||||
}
|
||||
|
||||
@ -51,6 +51,6 @@ func (b *DummyBind) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
|
||||
func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -19,13 +19,13 @@ import (
|
||||
// call wg.Done to remove the initial reference.
|
||||
// When the refcount hits 0, the queue's channel is closed.
|
||||
type outboundQueue struct {
|
||||
c chan *QueueOutboundElement
|
||||
c chan *QueueOutboundElementsContainer
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func newOutboundQueue() *outboundQueue {
|
||||
q := &outboundQueue{
|
||||
c: make(chan *QueueOutboundElement, QueueOutboundSize),
|
||||
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
||||
}
|
||||
q.wg.Add(1)
|
||||
go func() {
|
||||
@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
|
||||
|
||||
// A inboundQueue is similar to an outboundQueue; see those docs.
|
||||
type inboundQueue struct {
|
||||
c chan *QueueInboundElement
|
||||
c chan *QueueInboundElementsContainer
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func newInboundQueue() *inboundQueue {
|
||||
q := &inboundQueue{
|
||||
c: make(chan *QueueInboundElement, QueueInboundSize),
|
||||
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
||||
}
|
||||
q.wg.Add(1)
|
||||
go func() {
|
||||
@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
|
||||
}
|
||||
|
||||
type autodrainingInboundQueue struct {
|
||||
c chan *QueueInboundElement
|
||||
c chan *QueueInboundElementsContainer
|
||||
}
|
||||
|
||||
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
|
||||
@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
|
||||
// some other means, such as sending a sentinel nil values.
|
||||
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
||||
q := &autodrainingInboundQueue{
|
||||
c: make(chan *QueueInboundElement, QueueInboundSize),
|
||||
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
||||
}
|
||||
runtime.SetFinalizer(q, device.flushInboundQueue)
|
||||
return q
|
||||
@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
||||
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
||||
for {
|
||||
select {
|
||||
case elem := <-q.c:
|
||||
elem.Lock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
case elemsContainer := <-q.c:
|
||||
elemsContainer.Lock()
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
device.PutInboundElementsContainer(elemsContainer)
|
||||
default:
|
||||
return
|
||||
}
|
||||
@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
||||
}
|
||||
|
||||
type autodrainingOutboundQueue struct {
|
||||
c chan *QueueOutboundElement
|
||||
c chan *QueueOutboundElementsContainer
|
||||
}
|
||||
|
||||
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
|
||||
@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct {
|
||||
// All sends to the channel must be best-effort, because there may be no receivers.
|
||||
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
||||
q := &autodrainingOutboundQueue{
|
||||
c: make(chan *QueueOutboundElement, QueueOutboundSize),
|
||||
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
||||
}
|
||||
runtime.SetFinalizer(q, device.flushOutboundQueue)
|
||||
return q
|
||||
@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
||||
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
|
||||
for {
|
||||
select {
|
||||
case elem := <-q.c:
|
||||
elem.Lock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
case elemsContainer := <-q.c:
|
||||
elemsContainer.Lock()
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsContainer(elemsContainer)
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -68,9 +68,11 @@ type Device struct {
|
||||
cookieChecker CookieChecker
|
||||
|
||||
pool struct {
|
||||
messageBuffers *WaitPool
|
||||
inboundElements *WaitPool
|
||||
outboundElements *WaitPool
|
||||
inboundElementsContainer *WaitPool
|
||||
outboundElementsContainer *WaitPool
|
||||
messageBuffers *WaitPool
|
||||
inboundElements *WaitPool
|
||||
outboundElements *WaitPool
|
||||
}
|
||||
|
||||
queue struct {
|
||||
@ -295,6 +297,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
||||
device.rate.limiter.Init()
|
||||
device.indexTable.Init()
|
||||
|
||||
device.PopulatePools()
|
||||
|
||||
// create queues
|
||||
@ -322,6 +325,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||
return device
|
||||
}
|
||||
|
||||
// BatchSize returns the BatchSize for the device as a whole which is the max of
|
||||
// the bind batch size and the tun batch size. The batch size reported by device
|
||||
// is the size used to construct memory pools, and is the allowed batch size for
|
||||
// the lifetime of the device.
|
||||
func (device *Device) BatchSize() int {
|
||||
size := device.net.bind.BatchSize()
|
||||
dSize := device.tun.device.BatchSize()
|
||||
if size < dSize {
|
||||
size = dSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
||||
device.peers.RLock()
|
||||
defer device.peers.RUnlock()
|
||||
@ -354,6 +370,8 @@ func (device *Device) RemoveAllPeers() {
|
||||
func (device *Device) Close() {
|
||||
device.state.Lock()
|
||||
defer device.state.Unlock()
|
||||
device.ipcMutex.Lock()
|
||||
defer device.ipcMutex.Unlock()
|
||||
if device.isClosed() {
|
||||
return
|
||||
}
|
||||
@ -443,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error {
|
||||
// clear cached source addresses
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
peer.markEndpointSrcForClearing()
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
|
||||
@ -472,11 +486,13 @@ func (device *Device) BindUpdate() error {
|
||||
var err error
|
||||
var recvFns []conn.ReceiveFunc
|
||||
netc := &device.net
|
||||
|
||||
recvFns, netc.port, err = netc.bind.Open(netc.port)
|
||||
if err != nil {
|
||||
netc.port = 0
|
||||
return err
|
||||
}
|
||||
|
||||
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
||||
if err != nil {
|
||||
netc.bind.Close()
|
||||
@ -495,11 +511,7 @@ func (device *Device) BindUpdate() error {
|
||||
// clear cached source addresses
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
peer.markEndpointSrcForClearing()
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
|
||||
@ -507,8 +519,9 @@ func (device *Device) BindUpdate() error {
|
||||
device.net.stopping.Add(len(recvFns))
|
||||
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
||||
batchSize := netc.bind.BatchSize()
|
||||
for _, fn := range recvFns {
|
||||
go device.RoutineReceiveIncoming(fn)
|
||||
go device.RoutineReceiveIncoming(batchSize, fn)
|
||||
}
|
||||
|
||||
device.log.Verbosef("UDP bind has been updated")
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sync"
|
||||
@ -21,6 +22,7 @@ import (
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/conn/bindtest"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||
)
|
||||
|
||||
@ -307,6 +309,17 @@ func TestConcurrencySafety(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Perform bind updates and keepalive sends concurrently with tunnel use.
|
||||
t.Run("bindUpdate and keepalive", func(t *testing.T) {
|
||||
const iters = 10
|
||||
for i := 0; i < iters; i++ {
|
||||
for _, peer := range pair {
|
||||
peer.dev.BindUpdate()
|
||||
peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
close(done)
|
||||
}
|
||||
|
||||
@ -405,3 +418,59 @@ func goroutineLeakCheck(t *testing.T) {
|
||||
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
|
||||
})
|
||||
}
|
||||
|
||||
type fakeBindSized struct {
|
||||
size int
|
||||
}
|
||||
|
||||
func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
func (b *fakeBindSized) Close() error { return nil }
|
||||
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
|
||||
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
|
||||
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
|
||||
func (b *fakeBindSized) BatchSize() int { return b.size }
|
||||
|
||||
type fakeTUNDeviceSized struct {
|
||||
size int
|
||||
}
|
||||
|
||||
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
|
||||
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
|
||||
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
|
||||
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
|
||||
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
|
||||
func (t *fakeTUNDeviceSized) Close() error { return nil }
|
||||
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
|
||||
|
||||
func TestBatchSize(t *testing.T) {
|
||||
d := Device{}
|
||||
|
||||
d.net.bind = &fakeBindSized{1}
|
||||
d.tun.device = &fakeTUNDeviceSized{1}
|
||||
if want, got := 1, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
|
||||
d.net.bind = &fakeBindSized{1}
|
||||
d.tun.device = &fakeTUNDeviceSized{128}
|
||||
if want, got := 128, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
|
||||
d.net.bind = &fakeBindSized{128}
|
||||
d.tun.device = &fakeTUNDeviceSized{1}
|
||||
if want, got := 128, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
|
||||
d.net.bind = &fakeBindSized{128}
|
||||
d.tun.device = &fakeTUNDeviceSized{128}
|
||||
if want, got := 128, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
|
||||
device.net.brokenRoaming = true
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Lock()
|
||||
peer.disableRoaming = peer.endpoint != nil
|
||||
peer.Unlock()
|
||||
peer.endpoint.Lock()
|
||||
peer.endpoint.disableRoaming = peer.endpoint.val != nil
|
||||
peer.endpoint.Unlock()
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,11 +1,12 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
@ -115,6 +116,98 @@ type MessageCookieReply struct {
|
||||
Cookie [blake2s.Size128 + poly1305.TagSize]byte
|
||||
}
|
||||
|
||||
var errMessageLengthMismatch = errors.New("message length mismatch")
|
||||
|
||||
func (msg *MessageInitiation) unmarshal(b []byte) error {
|
||||
if len(b) != MessageInitiationSize {
|
||||
return errMessageLengthMismatch
|
||||
}
|
||||
|
||||
msg.Type = binary.LittleEndian.Uint32(b)
|
||||
msg.Sender = binary.LittleEndian.Uint32(b[4:])
|
||||
copy(msg.Ephemeral[:], b[8:])
|
||||
copy(msg.Static[:], b[8+len(msg.Ephemeral):])
|
||||
copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):])
|
||||
copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):])
|
||||
copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (msg *MessageInitiation) marshal(b []byte) error {
|
||||
if len(b) != MessageInitiationSize {
|
||||
return errMessageLengthMismatch
|
||||
}
|
||||
|
||||
binary.LittleEndian.PutUint32(b, msg.Type)
|
||||
binary.LittleEndian.PutUint32(b[4:], msg.Sender)
|
||||
copy(b[8:], msg.Ephemeral[:])
|
||||
copy(b[8+len(msg.Ephemeral):], msg.Static[:])
|
||||
copy(b[8+len(msg.Ephemeral)+len(msg.Static):], msg.Timestamp[:])
|
||||
copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):], msg.MAC1[:])
|
||||
copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):], msg.MAC2[:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (msg *MessageResponse) unmarshal(b []byte) error {
|
||||
if len(b) != MessageResponseSize {
|
||||
return errMessageLengthMismatch
|
||||
}
|
||||
|
||||
msg.Type = binary.LittleEndian.Uint32(b)
|
||||
msg.Sender = binary.LittleEndian.Uint32(b[4:])
|
||||
msg.Receiver = binary.LittleEndian.Uint32(b[8:])
|
||||
copy(msg.Ephemeral[:], b[12:])
|
||||
copy(msg.Empty[:], b[12+len(msg.Ephemeral):])
|
||||
copy(msg.MAC1[:], b[12+len(msg.Ephemeral)+len(msg.Empty):])
|
||||
copy(msg.MAC2[:], b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (msg *MessageResponse) marshal(b []byte) error {
|
||||
if len(b) != MessageResponseSize {
|
||||
return errMessageLengthMismatch
|
||||
}
|
||||
|
||||
binary.LittleEndian.PutUint32(b, msg.Type)
|
||||
binary.LittleEndian.PutUint32(b[4:], msg.Sender)
|
||||
binary.LittleEndian.PutUint32(b[8:], msg.Receiver)
|
||||
copy(b[12:], msg.Ephemeral[:])
|
||||
copy(b[12+len(msg.Ephemeral):], msg.Empty[:])
|
||||
copy(b[12+len(msg.Ephemeral)+len(msg.Empty):], msg.MAC1[:])
|
||||
copy(b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):], msg.MAC2[:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (msg *MessageCookieReply) unmarshal(b []byte) error {
|
||||
if len(b) != MessageCookieReplySize {
|
||||
return errMessageLengthMismatch
|
||||
}
|
||||
|
||||
msg.Type = binary.LittleEndian.Uint32(b)
|
||||
msg.Receiver = binary.LittleEndian.Uint32(b[4:])
|
||||
copy(msg.Nonce[:], b[8:])
|
||||
copy(msg.Cookie[:], b[8+len(msg.Nonce):])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (msg *MessageCookieReply) marshal(b []byte) error {
|
||||
if len(b) != MessageCookieReplySize {
|
||||
return errMessageLengthMismatch
|
||||
}
|
||||
|
||||
binary.LittleEndian.PutUint32(b, msg.Type)
|
||||
binary.LittleEndian.PutUint32(b[4:], msg.Receiver)
|
||||
copy(b[8:], msg.Nonce[:])
|
||||
copy(b[8+len(msg.Nonce):], msg.Cookie[:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type Handshake struct {
|
||||
state handshakeState
|
||||
mutex sync.RWMutex
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -17,17 +17,20 @@ import (
|
||||
|
||||
type Peer struct {
|
||||
isRunning atomic.Bool
|
||||
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
|
||||
keypairs Keypairs
|
||||
handshake Handshake
|
||||
device *Device
|
||||
endpoint conn.Endpoint
|
||||
stopping sync.WaitGroup // routines pending stop
|
||||
txBytes atomic.Uint64 // bytes send to peer (endpoint)
|
||||
rxBytes atomic.Uint64 // bytes received from peer
|
||||
lastHandshakeNano atomic.Int64 // nano seconds since epoch
|
||||
|
||||
disableRoaming bool
|
||||
endpoint struct {
|
||||
sync.Mutex
|
||||
val conn.Endpoint
|
||||
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
|
||||
disableRoaming bool
|
||||
}
|
||||
|
||||
timers struct {
|
||||
retransmitHandshake *Timer
|
||||
@ -45,9 +48,9 @@ type Peer struct {
|
||||
}
|
||||
|
||||
queue struct {
|
||||
staged chan *QueueOutboundElement // staged packets before a handshake is available
|
||||
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
|
||||
inbound *autodrainingInboundQueue // sequential ordering of tun writing
|
||||
staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
|
||||
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
|
||||
inbound *autodrainingInboundQueue // sequential ordering of tun writing
|
||||
}
|
||||
|
||||
cookieGenerator CookieGenerator
|
||||
@ -74,14 +77,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
|
||||
// create peer
|
||||
peer := new(Peer)
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
|
||||
peer.cookieGenerator.Init(pk)
|
||||
peer.device = device
|
||||
peer.queue.outbound = newAutodrainingOutboundQueue(device)
|
||||
peer.queue.inbound = newAutodrainingInboundQueue(device)
|
||||
peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
|
||||
peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
|
||||
|
||||
// map public key
|
||||
_, ok := device.peers.keyMap[pk]
|
||||
@ -97,7 +98,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
handshake.mutex.Unlock()
|
||||
|
||||
// reset endpoint
|
||||
peer.endpoint = nil
|
||||
peer.endpoint.Lock()
|
||||
peer.endpoint.val = nil
|
||||
peer.endpoint.disableRoaming = false
|
||||
peer.endpoint.clearSrcOnTx = false
|
||||
peer.endpoint.Unlock()
|
||||
|
||||
// init timers
|
||||
peer.timersInit()
|
||||
@ -108,7 +113,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
func (peer *Peer) SendBuffer(buffer []byte) error {
|
||||
func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||
peer.device.net.RLock()
|
||||
defer peer.device.net.RUnlock()
|
||||
|
||||
@ -116,16 +121,25 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
peer.RLock()
|
||||
defer peer.RUnlock()
|
||||
|
||||
if peer.endpoint == nil {
|
||||
peer.endpoint.Lock()
|
||||
endpoint := peer.endpoint.val
|
||||
if endpoint == nil {
|
||||
peer.endpoint.Unlock()
|
||||
return errors.New("no known endpoint for peer")
|
||||
}
|
||||
if peer.endpoint.clearSrcOnTx {
|
||||
endpoint.ClearSrc()
|
||||
peer.endpoint.clearSrcOnTx = false
|
||||
}
|
||||
peer.endpoint.Unlock()
|
||||
|
||||
err := peer.device.net.bind.Send(buffer, peer.endpoint)
|
||||
err := peer.device.net.bind.Send(buffers, endpoint)
|
||||
if err == nil {
|
||||
peer.txBytes.Add(uint64(len(buffer)))
|
||||
var totalLen uint64
|
||||
for _, b := range buffers {
|
||||
totalLen += uint64(len(b))
|
||||
}
|
||||
peer.txBytes.Add(totalLen)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -187,8 +201,12 @@ func (peer *Peer) Start() {
|
||||
|
||||
device.flushInboundQueue(peer.queue.inbound)
|
||||
device.flushOutboundQueue(peer.queue.outbound)
|
||||
go peer.RoutineSequentialSender()
|
||||
go peer.RoutineSequentialReceiver()
|
||||
|
||||
// Use the device batch size, not the bind batch size, as the device size is
|
||||
// the size of the batch pools.
|
||||
batchSize := peer.device.BatchSize()
|
||||
go peer.RoutineSequentialSender(batchSize)
|
||||
go peer.RoutineSequentialReceiver(batchSize)
|
||||
|
||||
peer.isRunning.Store(true)
|
||||
}
|
||||
@ -259,10 +277,20 @@ func (peer *Peer) Stop() {
|
||||
}
|
||||
|
||||
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
||||
if peer.disableRoaming {
|
||||
peer.endpoint.Lock()
|
||||
defer peer.endpoint.Unlock()
|
||||
if peer.endpoint.disableRoaming {
|
||||
return
|
||||
}
|
||||
peer.Lock()
|
||||
peer.endpoint = endpoint
|
||||
peer.Unlock()
|
||||
peer.endpoint.clearSrcOnTx = false
|
||||
peer.endpoint.val = endpoint
|
||||
}
|
||||
|
||||
func (peer *Peer) markEndpointSrcForClearing() {
|
||||
peer.endpoint.Lock()
|
||||
defer peer.endpoint.Unlock()
|
||||
if peer.endpoint.val == nil {
|
||||
return
|
||||
}
|
||||
peer.endpoint.clearSrcOnTx = true
|
||||
}
|
||||
|
@ -1,20 +1,19 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type WaitPool struct {
|
||||
pool sync.Pool
|
||||
cond sync.Cond
|
||||
lock sync.Mutex
|
||||
count atomic.Uint32
|
||||
count uint32 // Get calls not yet Put back
|
||||
max uint32
|
||||
}
|
||||
|
||||
@ -27,10 +26,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
|
||||
func (p *WaitPool) Get() any {
|
||||
if p.max != 0 {
|
||||
p.lock.Lock()
|
||||
for p.count.Load() >= p.max {
|
||||
for p.count >= p.max {
|
||||
p.cond.Wait()
|
||||
}
|
||||
p.count.Add(1)
|
||||
p.count++
|
||||
p.lock.Unlock()
|
||||
}
|
||||
return p.pool.Get()
|
||||
@ -41,11 +40,21 @@ func (p *WaitPool) Put(x any) {
|
||||
if p.max == 0 {
|
||||
return
|
||||
}
|
||||
p.count.Add(^uint32(0))
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
p.count--
|
||||
p.cond.Signal()
|
||||
}
|
||||
|
||||
func (device *Device) PopulatePools() {
|
||||
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||
s := make([]*QueueInboundElement, 0, device.BatchSize())
|
||||
return &QueueInboundElementsContainer{elems: s}
|
||||
})
|
||||
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||
s := make([]*QueueOutboundElement, 0, device.BatchSize())
|
||||
return &QueueOutboundElementsContainer{elems: s}
|
||||
})
|
||||
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||
return new([MaxMessageSize]byte)
|
||||
})
|
||||
@ -57,6 +66,34 @@ func (device *Device) PopulatePools() {
|
||||
})
|
||||
}
|
||||
|
||||
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
|
||||
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
|
||||
c.Mutex = sync.Mutex{}
|
||||
return c
|
||||
}
|
||||
|
||||
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
|
||||
for i := range c.elems {
|
||||
c.elems[i] = nil
|
||||
}
|
||||
c.elems = c.elems[:0]
|
||||
device.pool.inboundElementsContainer.Put(c)
|
||||
}
|
||||
|
||||
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
|
||||
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
|
||||
c.Mutex = sync.Mutex{}
|
||||
return c
|
||||
}
|
||||
|
||||
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
|
||||
for i := range c.elems {
|
||||
c.elems[i] = nil
|
||||
}
|
||||
c.elems = c.elems[:0]
|
||||
device.pool.outboundElementsContainer.Put(c)
|
||||
}
|
||||
|
||||
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
||||
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -32,7 +32,9 @@ func TestWaitPool(t *testing.T) {
|
||||
wg.Add(workers)
|
||||
var max atomic.Uint32
|
||||
updateMax := func() {
|
||||
count := p.count.Load()
|
||||
p.lock.Lock()
|
||||
count := p.count
|
||||
p.lock.Unlock()
|
||||
if count > p.max {
|
||||
t.Errorf("count (%d) > max (%d)", count, p.max)
|
||||
}
|
||||
@ -89,3 +91,51 @@ func BenchmarkWaitPool(b *testing.B) {
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkWaitPoolEmpty(b *testing.B) {
|
||||
var wg sync.WaitGroup
|
||||
var trials atomic.Int32
|
||||
trials.Store(int32(b.N))
|
||||
workers := runtime.NumCPU() + 2
|
||||
if workers-4 <= 0 {
|
||||
b.Skip("Not enough cores")
|
||||
}
|
||||
p := NewWaitPool(0, func() any { return make([]byte, 16) })
|
||||
wg.Add(workers)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < workers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for trials.Add(-1) > 0 {
|
||||
x := p.Get()
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||
p.Put(x)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkSyncPool(b *testing.B) {
|
||||
var wg sync.WaitGroup
|
||||
var trials atomic.Int32
|
||||
trials.Store(int32(b.N))
|
||||
workers := runtime.NumCPU() + 2
|
||||
if workers-4 <= 0 {
|
||||
b.Skip("Not enough cores")
|
||||
}
|
||||
p := sync.Pool{New: func() any { return make([]byte, 16) }}
|
||||
wg.Add(workers)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < workers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for trials.Add(-1) > 0 {
|
||||
x := p.Get()
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||
p.Put(x)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
@ -1,17 +1,19 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
/* Reduce memory consumption for Android */
|
||||
|
||||
const (
|
||||
QueueStagedSize = 128
|
||||
QueueStagedSize = conn.IdealBatchSize
|
||||
QueueOutboundSize = 1024
|
||||
QueueInboundSize = 1024
|
||||
QueueHandshakeSize = 1024
|
||||
MaxSegmentSize = 2200
|
||||
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
|
||||
PreallocatedBuffersPerPool = 4096
|
||||
)
|
||||
|
@ -2,13 +2,15 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
const (
|
||||
QueueStagedSize = 128
|
||||
QueueStagedSize = conn.IdealBatchSize
|
||||
QueueOutboundSize = 1024
|
||||
QueueInboundSize = 1024
|
||||
QueueHandshakeSize = 1024
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,12 +1,11 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
@ -27,7 +26,6 @@ type QueueHandshakeElement struct {
|
||||
}
|
||||
|
||||
type QueueInboundElement struct {
|
||||
sync.Mutex
|
||||
buffer *[MaxMessageSize]byte
|
||||
packet []byte
|
||||
counter uint64
|
||||
@ -35,6 +33,11 @@ type QueueInboundElement struct {
|
||||
endpoint conn.Endpoint
|
||||
}
|
||||
|
||||
type QueueInboundElementsContainer struct {
|
||||
sync.Mutex
|
||||
elems []*QueueInboundElement
|
||||
}
|
||||
|
||||
// clearPointers clears elem fields that contain pointers.
|
||||
// This makes the garbage collector's life easier and
|
||||
// avoids accidentally keeping other objects around unnecessarily.
|
||||
@ -66,7 +69,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
||||
* Every time the bind is updated a new routine is started for
|
||||
* IPv4 and IPv6 (separately)
|
||||
*/
|
||||
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
||||
func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
|
||||
recvName := recv.PrettyName()
|
||||
defer func() {
|
||||
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
|
||||
@ -79,20 +82,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
||||
|
||||
// receive datagrams until conn is closed
|
||||
|
||||
buffer := device.GetMessageBuffer()
|
||||
|
||||
var (
|
||||
bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
|
||||
bufs = make([][]byte, maxBatchSize)
|
||||
err error
|
||||
size int
|
||||
endpoint conn.Endpoint
|
||||
sizes = make([]int, maxBatchSize)
|
||||
count int
|
||||
endpoints = make([]conn.Endpoint, maxBatchSize)
|
||||
deathSpiral int
|
||||
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
|
||||
)
|
||||
|
||||
for {
|
||||
size, endpoint, err = recv(buffer[:])
|
||||
for i := range bufsArrs {
|
||||
bufsArrs[i] = device.GetMessageBuffer()
|
||||
bufs[i] = bufsArrs[i][:]
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for i := 0; i < maxBatchSize; i++ {
|
||||
if bufsArrs[i] != nil {
|
||||
device.PutMessageBuffer(bufsArrs[i])
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
count, err = recv(bufs, sizes, endpoints)
|
||||
if err != nil {
|
||||
device.PutMessageBuffer(buffer)
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
@ -103,101 +119,119 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
||||
if deathSpiral < 10 {
|
||||
deathSpiral++
|
||||
time.Sleep(time.Second / 3)
|
||||
buffer = device.GetMessageBuffer()
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
deathSpiral = 0
|
||||
|
||||
if size < MinMessageSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// check size of packet
|
||||
|
||||
packet := buffer[:size]
|
||||
msgType := binary.LittleEndian.Uint32(packet[:4])
|
||||
|
||||
var okay bool
|
||||
|
||||
switch msgType {
|
||||
|
||||
// check if transport
|
||||
|
||||
case MessageTransportType:
|
||||
|
||||
// check size
|
||||
|
||||
if len(packet) < MessageTransportSize {
|
||||
// handle each packet in the batch
|
||||
for i, size := range sizes[:count] {
|
||||
if size < MinMessageSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// lookup key pair
|
||||
// check size of packet
|
||||
|
||||
receiver := binary.LittleEndian.Uint32(
|
||||
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
||||
)
|
||||
value := device.indexTable.Lookup(receiver)
|
||||
keypair := value.keypair
|
||||
if keypair == nil {
|
||||
packet := bufsArrs[i][:size]
|
||||
msgType := binary.LittleEndian.Uint32(packet[:4])
|
||||
|
||||
switch msgType {
|
||||
|
||||
// check if transport
|
||||
|
||||
case MessageTransportType:
|
||||
|
||||
// check size
|
||||
|
||||
if len(packet) < MessageTransportSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// lookup key pair
|
||||
|
||||
receiver := binary.LittleEndian.Uint32(
|
||||
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
||||
)
|
||||
value := device.indexTable.Lookup(receiver)
|
||||
keypair := value.keypair
|
||||
if keypair == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// check keypair expiry
|
||||
|
||||
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
||||
continue
|
||||
}
|
||||
|
||||
// create work element
|
||||
peer := value.peer
|
||||
elem := device.GetInboundElement()
|
||||
elem.packet = packet
|
||||
elem.buffer = bufsArrs[i]
|
||||
elem.keypair = keypair
|
||||
elem.endpoint = endpoints[i]
|
||||
elem.counter = 0
|
||||
|
||||
elemsForPeer, ok := elemsByPeer[peer]
|
||||
if !ok {
|
||||
elemsForPeer = device.GetInboundElementsContainer()
|
||||
elemsForPeer.Lock()
|
||||
elemsByPeer[peer] = elemsForPeer
|
||||
}
|
||||
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
||||
bufsArrs[i] = device.GetMessageBuffer()
|
||||
bufs[i] = bufsArrs[i][:]
|
||||
continue
|
||||
|
||||
// otherwise it is a fixed size & handshake related packet
|
||||
|
||||
case MessageInitiationType:
|
||||
if len(packet) != MessageInitiationSize {
|
||||
continue
|
||||
}
|
||||
|
||||
case MessageResponseType:
|
||||
if len(packet) != MessageResponseSize {
|
||||
continue
|
||||
}
|
||||
|
||||
case MessageCookieReplyType:
|
||||
if len(packet) != MessageCookieReplySize {
|
||||
continue
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Received message with unknown type")
|
||||
continue
|
||||
}
|
||||
|
||||
// check keypair expiry
|
||||
|
||||
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
||||
continue
|
||||
}
|
||||
|
||||
// create work element
|
||||
peer := value.peer
|
||||
elem := device.GetInboundElement()
|
||||
elem.packet = packet
|
||||
elem.buffer = buffer
|
||||
elem.keypair = keypair
|
||||
elem.endpoint = endpoint
|
||||
elem.counter = 0
|
||||
elem.Mutex = sync.Mutex{}
|
||||
elem.Lock()
|
||||
|
||||
// add to decryption queues
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.inbound.c <- elem
|
||||
device.queue.decryption.c <- elem
|
||||
buffer = device.GetMessageBuffer()
|
||||
} else {
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
continue
|
||||
|
||||
// otherwise it is a fixed size & handshake related packet
|
||||
|
||||
case MessageInitiationType:
|
||||
okay = len(packet) == MessageInitiationSize
|
||||
|
||||
case MessageResponseType:
|
||||
okay = len(packet) == MessageResponseSize
|
||||
|
||||
case MessageCookieReplyType:
|
||||
okay = len(packet) == MessageCookieReplySize
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Received message with unknown type")
|
||||
}
|
||||
|
||||
if okay {
|
||||
select {
|
||||
case device.queue.handshake.c <- QueueHandshakeElement{
|
||||
msgType: msgType,
|
||||
buffer: buffer,
|
||||
buffer: bufsArrs[i],
|
||||
packet: packet,
|
||||
endpoint: endpoint,
|
||||
endpoint: endpoints[i],
|
||||
}:
|
||||
buffer = device.GetMessageBuffer()
|
||||
bufsArrs[i] = device.GetMessageBuffer()
|
||||
bufs[i] = bufsArrs[i][:]
|
||||
default:
|
||||
}
|
||||
}
|
||||
for peer, elemsContainer := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.inbound.c <- elemsContainer
|
||||
device.queue.decryption.c <- elemsContainer
|
||||
} else {
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
device.PutInboundElementsContainer(elemsContainer)
|
||||
}
|
||||
delete(elemsByPeer, peer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -207,26 +241,28 @@ func (device *Device) RoutineDecryption(id int) {
|
||||
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
|
||||
device.log.Verbosef("Routine: decryption worker %d - started", id)
|
||||
|
||||
for elem := range device.queue.decryption.c {
|
||||
// split message into fields
|
||||
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
||||
content := elem.packet[MessageTransportOffsetContent:]
|
||||
for elemsContainer := range device.queue.decryption.c {
|
||||
for _, elem := range elemsContainer.elems {
|
||||
// split message into fields
|
||||
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
||||
content := elem.packet[MessageTransportOffsetContent:]
|
||||
|
||||
// decrypt and release to consumer
|
||||
var err error
|
||||
elem.counter = binary.LittleEndian.Uint64(counter)
|
||||
// copy counter to nonce
|
||||
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
|
||||
elem.packet, err = elem.keypair.receive.Open(
|
||||
content[:0],
|
||||
nonce[:],
|
||||
content,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
elem.packet = nil
|
||||
// decrypt and release to consumer
|
||||
var err error
|
||||
elem.counter = binary.LittleEndian.Uint64(counter)
|
||||
// copy counter to nonce
|
||||
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
|
||||
elem.packet, err = elem.keypair.receive.Open(
|
||||
content[:0],
|
||||
nonce[:],
|
||||
content,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
elem.packet = nil
|
||||
}
|
||||
}
|
||||
elem.Unlock()
|
||||
elemsContainer.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@ -250,8 +286,7 @@ func (device *Device) RoutineHandshake(id int) {
|
||||
// unmarshal packet
|
||||
|
||||
var reply MessageCookieReply
|
||||
reader := bytes.NewReader(elem.packet)
|
||||
err := binary.Read(reader, binary.LittleEndian, &reply)
|
||||
err := reply.unmarshal(elem.packet)
|
||||
if err != nil {
|
||||
device.log.Verbosef("Failed to decode cookie reply")
|
||||
goto skip
|
||||
@ -316,8 +351,7 @@ func (device *Device) RoutineHandshake(id int) {
|
||||
// unmarshal
|
||||
|
||||
var msg MessageInitiation
|
||||
reader := bytes.NewReader(elem.packet)
|
||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
||||
err := msg.unmarshal(elem.packet)
|
||||
if err != nil {
|
||||
device.log.Errorf("Failed to decode initiation message")
|
||||
goto skip
|
||||
@ -349,8 +383,7 @@ func (device *Device) RoutineHandshake(id int) {
|
||||
// unmarshal
|
||||
|
||||
var msg MessageResponse
|
||||
reader := bytes.NewReader(elem.packet)
|
||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
||||
err := msg.unmarshal(elem.packet)
|
||||
if err != nil {
|
||||
device.log.Errorf("Failed to decode response message")
|
||||
goto skip
|
||||
@ -393,7 +426,7 @@ func (device *Device) RoutineHandshake(id int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) RoutineSequentialReceiver() {
|
||||
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
||||
device := peer.device
|
||||
defer func() {
|
||||
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
|
||||
@ -401,89 +434,103 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
}()
|
||||
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
|
||||
|
||||
for elem := range peer.queue.inbound.c {
|
||||
if elem == nil {
|
||||
bufs := make([][]byte, 0, maxBatchSize)
|
||||
|
||||
for elemsContainer := range peer.queue.inbound.c {
|
||||
if elemsContainer == nil {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
elem.Lock()
|
||||
if elem.packet == nil {
|
||||
// decryption failed
|
||||
goto skip
|
||||
}
|
||||
|
||||
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
||||
goto skip
|
||||
}
|
||||
|
||||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||
peer.timersHandshakeComplete()
|
||||
peer.SendStagedPackets()
|
||||
}
|
||||
|
||||
peer.keepKeyFreshReceiving()
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketReceived()
|
||||
peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
|
||||
|
||||
if len(elem.packet) == 0 {
|
||||
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
||||
goto skip
|
||||
}
|
||||
peer.timersDataReceived()
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case ipv4.Version:
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
goto skip
|
||||
}
|
||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
||||
goto skip
|
||||
}
|
||||
elem.packet = elem.packet[:length]
|
||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||
if device.allowedips.Lookup(src) != peer {
|
||||
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
|
||||
goto skip
|
||||
elemsContainer.Lock()
|
||||
validTailPacket := -1
|
||||
dataPacketReceived := false
|
||||
rxBytesLen := uint64(0)
|
||||
for i, elem := range elemsContainer.elems {
|
||||
if elem.packet == nil {
|
||||
// decryption failed
|
||||
continue
|
||||
}
|
||||
|
||||
case ipv6.Version:
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
goto skip
|
||||
}
|
||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
length += ipv6.HeaderLen
|
||||
if int(length) > len(elem.packet) {
|
||||
goto skip
|
||||
}
|
||||
elem.packet = elem.packet[:length]
|
||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||
if device.allowedips.Lookup(src) != peer {
|
||||
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
|
||||
goto skip
|
||||
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
||||
continue
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Packet with invalid IP version from %v", peer)
|
||||
goto skip
|
||||
validTailPacket = i
|
||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
peer.timersHandshakeComplete()
|
||||
peer.SendStagedPackets()
|
||||
}
|
||||
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
|
||||
|
||||
if len(elem.packet) == 0 {
|
||||
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
||||
continue
|
||||
}
|
||||
dataPacketReceived = true
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case 4:
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
elem.packet = elem.packet[:length]
|
||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||
if device.allowedips.Lookup(src) != peer {
|
||||
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
|
||||
continue
|
||||
}
|
||||
|
||||
case 6:
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
}
|
||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
length += ipv6.HeaderLen
|
||||
if int(length) > len(elem.packet) {
|
||||
continue
|
||||
}
|
||||
elem.packet = elem.packet[:length]
|
||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||
if device.allowedips.Lookup(src) != peer {
|
||||
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
|
||||
continue
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Packet with invalid IP version from %v", peer)
|
||||
continue
|
||||
}
|
||||
|
||||
bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
|
||||
}
|
||||
|
||||
_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
|
||||
if err != nil && !device.isClosed() {
|
||||
device.log.Errorf("Failed to write packet to TUN device: %v", err)
|
||||
peer.rxBytes.Add(rxBytesLen)
|
||||
if validTailPacket >= 0 {
|
||||
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
|
||||
peer.keepKeyFreshReceiving()
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketReceived()
|
||||
}
|
||||
if len(peer.queue.inbound.c) == 0 {
|
||||
err = device.tun.device.Flush()
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("Unable to flush packets: %v", err)
|
||||
if dataPacketReceived {
|
||||
peer.timersDataReceived()
|
||||
}
|
||||
if len(bufs) > 0 {
|
||||
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
|
||||
if err != nil && !device.isClosed() {
|
||||
device.log.Errorf("Failed to write packets to TUN device: %v", err)
|
||||
}
|
||||
}
|
||||
skip:
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
bufs = bufs[:0]
|
||||
device.PutInboundElementsContainer(elemsContainer)
|
||||
}
|
||||
}
|
||||
|
359
device/send.go
359
device/send.go
@ -1,12 +1,11 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
@ -17,6 +16,8 @@ import (
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
/* Outbound flow
|
||||
@ -44,7 +45,6 @@ import (
|
||||
*/
|
||||
|
||||
type QueueOutboundElement struct {
|
||||
sync.Mutex
|
||||
buffer *[MaxMessageSize]byte // slice holding the packet data
|
||||
packet []byte // slice of "buffer" (always!)
|
||||
nonce uint64 // nonce for encryption
|
||||
@ -52,10 +52,14 @@ type QueueOutboundElement struct {
|
||||
peer *Peer // related peer
|
||||
}
|
||||
|
||||
type QueueOutboundElementsContainer struct {
|
||||
sync.Mutex
|
||||
elems []*QueueOutboundElement
|
||||
}
|
||||
|
||||
func (device *Device) NewOutboundElement() *QueueOutboundElement {
|
||||
elem := device.GetOutboundElement()
|
||||
elem.buffer = device.GetMessageBuffer()
|
||||
elem.Mutex = sync.Mutex{}
|
||||
elem.nonce = 0
|
||||
// keypair and peer were cleared (if necessary) by clearPointers.
|
||||
return elem
|
||||
@ -77,12 +81,15 @@ func (elem *QueueOutboundElement) clearPointers() {
|
||||
func (peer *Peer) SendKeepalive() {
|
||||
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
|
||||
elem := peer.device.NewOutboundElement()
|
||||
elemsContainer := peer.device.GetOutboundElementsContainer()
|
||||
elemsContainer.elems = append(elemsContainer.elems, elem)
|
||||
select {
|
||||
case peer.queue.staged <- elem:
|
||||
case peer.queue.staged <- elemsContainer:
|
||||
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
|
||||
default:
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||
}
|
||||
}
|
||||
peer.SendStagedPackets()
|
||||
@ -116,16 +123,14 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var buff [MessageInitiationSize]byte
|
||||
writer := bytes.NewBuffer(buff[:0])
|
||||
binary.Write(writer, binary.LittleEndian, msg)
|
||||
packet := writer.Bytes()
|
||||
packet := make([]byte, MessageInitiationSize)
|
||||
_ = msg.marshal(packet)
|
||||
peer.cookieGenerator.AddMacs(packet)
|
||||
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
err = peer.SendBuffer(packet)
|
||||
err = peer.SendBuffers([][]byte{packet})
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
||||
}
|
||||
@ -147,10 +152,8 @@ func (peer *Peer) SendHandshakeResponse() error {
|
||||
return err
|
||||
}
|
||||
|
||||
var buff [MessageResponseSize]byte
|
||||
writer := bytes.NewBuffer(buff[:0])
|
||||
binary.Write(writer, binary.LittleEndian, response)
|
||||
packet := writer.Bytes()
|
||||
packet := make([]byte, MessageResponseSize)
|
||||
_ = response.marshal(packet)
|
||||
peer.cookieGenerator.AddMacs(packet)
|
||||
|
||||
err = peer.BeginSymmetricSession()
|
||||
@ -163,7 +166,8 @@ func (peer *Peer) SendHandshakeResponse() error {
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
err = peer.SendBuffer(packet)
|
||||
// TODO: allocation could be avoided
|
||||
err = peer.SendBuffers([][]byte{packet})
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
||||
}
|
||||
@ -180,10 +184,11 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
|
||||
return err
|
||||
}
|
||||
|
||||
var buff [MessageCookieReplySize]byte
|
||||
writer := bytes.NewBuffer(buff[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
|
||||
packet := make([]byte, MessageCookieReplySize)
|
||||
_ = reply.marshal(packet)
|
||||
// TODO: allocation could be avoided
|
||||
device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -198,11 +203,6 @@ func (peer *Peer) keepKeyFreshSending() {
|
||||
}
|
||||
}
|
||||
|
||||
/* Reads packets from the TUN and inserts
|
||||
* into staged queue for peer
|
||||
*
|
||||
* Obs. Single instance per TUN device
|
||||
*/
|
||||
func (device *Device) RoutineReadFromTUN() {
|
||||
defer func() {
|
||||
device.log.Verbosef("Routine: TUN reader - stopped")
|
||||
@ -212,81 +212,123 @@ func (device *Device) RoutineReadFromTUN() {
|
||||
|
||||
device.log.Verbosef("Routine: TUN reader - started")
|
||||
|
||||
var elem *QueueOutboundElement
|
||||
var (
|
||||
batchSize = device.BatchSize()
|
||||
readErr error
|
||||
elems = make([]*QueueOutboundElement, batchSize)
|
||||
bufs = make([][]byte, batchSize)
|
||||
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
|
||||
count = 0
|
||||
sizes = make([]int, batchSize)
|
||||
offset = MessageTransportHeaderSize
|
||||
)
|
||||
|
||||
for i := range elems {
|
||||
elems[i] = device.NewOutboundElement()
|
||||
bufs[i] = elems[i].buffer[:]
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for _, elem := range elems {
|
||||
if elem != nil {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if elem != nil {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
// read packets
|
||||
count, readErr = device.tun.device.Read(bufs, sizes, offset)
|
||||
for i := 0; i < count; i++ {
|
||||
if sizes[i] < 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
elem := elems[i]
|
||||
elem.packet = bufs[i][offset : offset+sizes[i]]
|
||||
|
||||
// lookup peer
|
||||
var peer *Peer
|
||||
switch elem.packet[0] >> 4 {
|
||||
case 4:
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||
peer = device.allowedips.Lookup(dst)
|
||||
|
||||
case 6:
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
}
|
||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||
peer = device.allowedips.Lookup(dst)
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Received packet with unknown IP version")
|
||||
}
|
||||
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
elemsForPeer, ok := elemsByPeer[peer]
|
||||
if !ok {
|
||||
elemsForPeer = device.GetOutboundElementsContainer()
|
||||
elemsByPeer[peer] = elemsForPeer
|
||||
}
|
||||
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
||||
elems[i] = device.NewOutboundElement()
|
||||
bufs[i] = elems[i].buffer[:]
|
||||
}
|
||||
elem = device.NewOutboundElement()
|
||||
|
||||
// read packet
|
||||
for peer, elemsForPeer := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.StagePackets(elemsForPeer)
|
||||
peer.SendStagedPackets()
|
||||
} else {
|
||||
for _, elem := range elemsForPeer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsContainer(elemsForPeer)
|
||||
}
|
||||
delete(elemsByPeer, peer)
|
||||
}
|
||||
|
||||
offset := MessageTransportHeaderSize
|
||||
size, err := device.tun.device.Read(elem.buffer[:], offset)
|
||||
if err != nil {
|
||||
if readErr != nil {
|
||||
if errors.Is(readErr, tun.ErrTooManySegments) {
|
||||
// TODO: record stat for this
|
||||
// This will happen if MSS is surprisingly small (< 576)
|
||||
// coincident with reasonably high throughput.
|
||||
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
|
||||
continue
|
||||
}
|
||||
if !device.isClosed() {
|
||||
if !errors.Is(err, os.ErrClosed) {
|
||||
device.log.Errorf("Failed to read packet from TUN device: %v", err)
|
||||
if !errors.Is(readErr, os.ErrClosed) {
|
||||
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
|
||||
}
|
||||
go device.Close()
|
||||
}
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
return
|
||||
}
|
||||
|
||||
if size == 0 || size > MaxContentSize {
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = elem.buffer[offset : offset+size]
|
||||
|
||||
// lookup peer
|
||||
|
||||
var peer *Peer
|
||||
switch elem.packet[0] >> 4 {
|
||||
case ipv4.Version:
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||
peer = device.allowedips.Lookup(dst)
|
||||
|
||||
case ipv6.Version:
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
}
|
||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||
peer = device.allowedips.Lookup(dst)
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Received packet with unknown IP version")
|
||||
}
|
||||
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
if peer.isRunning.Load() {
|
||||
peer.StagePacket(elem)
|
||||
elem = nil
|
||||
peer.SendStagedPackets()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
|
||||
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
|
||||
for {
|
||||
select {
|
||||
case peer.queue.staged <- elem:
|
||||
case peer.queue.staged <- elems:
|
||||
return
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case tooOld := <-peer.queue.staged:
|
||||
peer.device.PutMessageBuffer(tooOld.buffer)
|
||||
peer.device.PutOutboundElement(tooOld)
|
||||
for _, elem := range tooOld.elems {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
}
|
||||
peer.device.PutOutboundElementsContainer(tooOld)
|
||||
default:
|
||||
}
|
||||
}
|
||||
@ -305,26 +347,53 @@ top:
|
||||
}
|
||||
|
||||
for {
|
||||
var elemsContainerOOO *QueueOutboundElementsContainer
|
||||
select {
|
||||
case elem := <-peer.queue.staged:
|
||||
elem.peer = peer
|
||||
elem.nonce = keypair.sendNonce.Add(1) - 1
|
||||
if elem.nonce >= RejectAfterMessages {
|
||||
keypair.sendNonce.Store(RejectAfterMessages)
|
||||
peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
|
||||
case elemsContainer := <-peer.queue.staged:
|
||||
i := 0
|
||||
for _, elem := range elemsContainer.elems {
|
||||
elem.peer = peer
|
||||
elem.nonce = keypair.sendNonce.Add(1) - 1
|
||||
if elem.nonce >= RejectAfterMessages {
|
||||
keypair.sendNonce.Store(RejectAfterMessages)
|
||||
if elemsContainerOOO == nil {
|
||||
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
|
||||
}
|
||||
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
|
||||
continue
|
||||
} else {
|
||||
elemsContainer.elems[i] = elem
|
||||
i++
|
||||
}
|
||||
|
||||
elem.keypair = keypair
|
||||
}
|
||||
elemsContainer.Lock()
|
||||
elemsContainer.elems = elemsContainer.elems[:i]
|
||||
|
||||
if elemsContainerOOO != nil {
|
||||
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
|
||||
}
|
||||
|
||||
if len(elemsContainer.elems) == 0 {
|
||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||
goto top
|
||||
}
|
||||
|
||||
elem.keypair = keypair
|
||||
elem.Lock()
|
||||
|
||||
// add to parallel and sequential queue
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.outbound.c <- elem
|
||||
peer.device.queue.encryption.c <- elem
|
||||
peer.queue.outbound.c <- elemsContainer
|
||||
peer.device.queue.encryption.c <- elemsContainer
|
||||
} else {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
for _, elem := range elemsContainer.elems {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
}
|
||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||
}
|
||||
|
||||
if elemsContainerOOO != nil {
|
||||
goto top
|
||||
}
|
||||
default:
|
||||
return
|
||||
@ -335,9 +404,12 @@ top:
|
||||
func (peer *Peer) FlushStagedPackets() {
|
||||
for {
|
||||
select {
|
||||
case elem := <-peer.queue.staged:
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
case elemsContainer := <-peer.queue.staged:
|
||||
for _, elem := range elemsContainer.elems {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
}
|
||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||
default:
|
||||
return
|
||||
}
|
||||
@ -371,41 +443,38 @@ func (device *Device) RoutineEncryption(id int) {
|
||||
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
|
||||
device.log.Verbosef("Routine: encryption worker %d - started", id)
|
||||
|
||||
for elem := range device.queue.encryption.c {
|
||||
// populate header fields
|
||||
header := elem.buffer[:MessageTransportHeaderSize]
|
||||
for elemsContainer := range device.queue.encryption.c {
|
||||
for _, elem := range elemsContainer.elems {
|
||||
// populate header fields
|
||||
header := elem.buffer[:MessageTransportHeaderSize]
|
||||
|
||||
fieldType := header[0:4]
|
||||
fieldReceiver := header[4:8]
|
||||
fieldNonce := header[8:16]
|
||||
fieldType := header[0:4]
|
||||
fieldReceiver := header[4:8]
|
||||
fieldNonce := header[8:16]
|
||||
|
||||
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
|
||||
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
|
||||
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
||||
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
|
||||
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
|
||||
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
||||
|
||||
// pad content to multiple of 16
|
||||
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
|
||||
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
|
||||
// pad content to multiple of 16
|
||||
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
|
||||
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
|
||||
|
||||
// encrypt content and release to consumer
|
||||
// encrypt content and release to consumer
|
||||
|
||||
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
|
||||
elem.packet = elem.keypair.send.Seal(
|
||||
header,
|
||||
nonce[:],
|
||||
elem.packet,
|
||||
nil,
|
||||
)
|
||||
elem.Unlock()
|
||||
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
|
||||
elem.packet = elem.keypair.send.Seal(
|
||||
header,
|
||||
nonce[:],
|
||||
elem.packet,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
elemsContainer.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
/* Sequentially reads packets from queue and sends to endpoint
|
||||
*
|
||||
* Obs. Single instance per peer.
|
||||
* The routine terminates then the outbound queue is closed.
|
||||
*/
|
||||
func (peer *Peer) RoutineSequentialSender() {
|
||||
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
||||
device := peer.device
|
||||
defer func() {
|
||||
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
|
||||
@ -413,36 +482,58 @@ func (peer *Peer) RoutineSequentialSender() {
|
||||
}()
|
||||
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
|
||||
|
||||
for elem := range peer.queue.outbound.c {
|
||||
if elem == nil {
|
||||
bufs := make([][]byte, 0, maxBatchSize)
|
||||
|
||||
for elemsContainer := range peer.queue.outbound.c {
|
||||
bufs = bufs[:0]
|
||||
if elemsContainer == nil {
|
||||
return
|
||||
}
|
||||
elem.Lock()
|
||||
if !peer.isRunning.Load() {
|
||||
// peer has been stopped; return re-usable elems to the shared pool.
|
||||
// This is an optimization only. It is possible for the peer to be stopped
|
||||
// immediately after this check, in which case, elem will get processed.
|
||||
// The timers and SendBuffer code are resilient to a few stragglers.
|
||||
// The timers and SendBuffers code are resilient to a few stragglers.
|
||||
// TODO: rework peer shutdown order to ensure
|
||||
// that we never accidentally keep timers alive longer than necessary.
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
elemsContainer.Lock()
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsContainer(elemsContainer)
|
||||
continue
|
||||
}
|
||||
dataSent := false
|
||||
elemsContainer.Lock()
|
||||
for _, elem := range elemsContainer.elems {
|
||||
if len(elem.packet) != MessageKeepaliveSize {
|
||||
dataSent = true
|
||||
}
|
||||
bufs = append(bufs, elem.packet)
|
||||
}
|
||||
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
// send message and return buffer to pool
|
||||
|
||||
err := peer.SendBuffer(elem.packet)
|
||||
if len(elem.packet) != MessageKeepaliveSize {
|
||||
err := peer.SendBuffers(bufs)
|
||||
if dataSent {
|
||||
peer.timersDataSent()
|
||||
}
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
for _, elem := range elemsContainer.elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsContainer(elemsContainer)
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - Failed to send data packet: %v", peer, err)
|
||||
var errGSO conn.ErrUDPGSODisabled
|
||||
if errors.As(err, &errGSO) {
|
||||
device.log.Verbosef(err.Error())
|
||||
err = errGSO.RetryErr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -7,6 +7,6 @@ import (
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
func (device *Device) startRouteListener(_ conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This implements userspace semantics of "sticky sockets", modeled after
|
||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||
@ -9,7 +9,7 @@
|
||||
*
|
||||
* Currently there is no way to achieve this within the net package:
|
||||
* See e.g. https://github.com/golang/go/issues/17930
|
||||
* So this code is remains platform dependent.
|
||||
* So this code remains platform dependent.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -25,7 +25,10 @@ import (
|
||||
)
|
||||
|
||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
if _, ok := bind.(*conn.LinuxSocketBind); !ok {
|
||||
if !conn.StdNetSupportsStickySockets {
|
||||
return nil, nil
|
||||
}
|
||||
if _, ok := bind.(*conn.StdNetBind); !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@ -44,7 +47,7 @@ func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, er
|
||||
return netlinkCancel, nil
|
||||
}
|
||||
|
||||
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||
func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||
type peerEndpointPtr struct {
|
||||
peer *Peer
|
||||
endpoint *conn.Endpoint
|
||||
@ -107,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
pePtr.peer.Lock()
|
||||
if &pePtr.peer.endpoint != pePtr.endpoint {
|
||||
pePtr.peer.Unlock()
|
||||
pePtr.peer.endpoint.Lock()
|
||||
if &pePtr.peer.endpoint.val != pePtr.endpoint {
|
||||
pePtr.peer.endpoint.Unlock()
|
||||
break
|
||||
}
|
||||
if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
|
||||
pePtr.peer.Unlock()
|
||||
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
|
||||
pePtr.peer.endpoint.Unlock()
|
||||
break
|
||||
}
|
||||
pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
|
||||
pePtr.peer.Unlock()
|
||||
pePtr.peer.endpoint.clearSrcOnTx = true
|
||||
pePtr.peer.endpoint.Unlock()
|
||||
}
|
||||
attr = attr[attrhdr.Len:]
|
||||
}
|
||||
@ -131,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||
device.peers.RLock()
|
||||
i := uint32(1)
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.RLock()
|
||||
if peer.endpoint == nil {
|
||||
peer.RUnlock()
|
||||
peer.endpoint.Lock()
|
||||
if peer.endpoint.val == nil {
|
||||
peer.endpoint.Unlock()
|
||||
continue
|
||||
}
|
||||
nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
|
||||
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
|
||||
if nativeEP == nil {
|
||||
peer.RUnlock()
|
||||
peer.endpoint.Unlock()
|
||||
continue
|
||||
}
|
||||
if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
|
||||
peer.RUnlock()
|
||||
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
|
||||
peer.endpoint.Unlock()
|
||||
break
|
||||
}
|
||||
nlmsg := struct {
|
||||
@ -169,12 +172,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||
Len: 8,
|
||||
Type: unix.RTA_DST,
|
||||
},
|
||||
nativeEP.Dst4().Addr,
|
||||
nativeEP.DstIP().As4(),
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: unix.RTA_SRC,
|
||||
},
|
||||
nativeEP.Src4().Src,
|
||||
nativeEP.SrcIP().As4(),
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: unix.RTA_MARK,
|
||||
@ -185,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||
reqPeerLock.Lock()
|
||||
reqPeer[i] = peerEndpointPtr{
|
||||
peer: peer,
|
||||
endpoint: &peer.endpoint,
|
||||
endpoint: &peer.endpoint.val,
|
||||
}
|
||||
reqPeerLock.Unlock()
|
||||
peer.RUnlock()
|
||||
peer.endpoint.Unlock()
|
||||
i++
|
||||
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||
if err != nil {
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This is based heavily on timers.c from the kernel implementation.
|
||||
*/
|
||||
@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) {
|
||||
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
|
||||
|
||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||
peer.Lock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
peer.Unlock()
|
||||
peer.markEndpointSrcForClearing()
|
||||
|
||||
peer.SendHandshakeInitiation(true)
|
||||
}
|
||||
@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) {
|
||||
func expiredNewHandshake(peer *Peer) {
|
||||
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
|
||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||
peer.Lock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
peer.Unlock()
|
||||
peer.markEndpointSrcForClearing()
|
||||
peer.SendHandshakeInitiation(false)
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@ -99,33 +99,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||
|
||||
for _, peer := range device.peers.keyMap {
|
||||
// Serialize peer state.
|
||||
// Do the work in an anonymous function so that we can use defer.
|
||||
func() {
|
||||
peer.RLock()
|
||||
defer peer.RUnlock()
|
||||
peer.handshake.mutex.RLock()
|
||||
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
|
||||
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
|
||||
peer.handshake.mutex.RUnlock()
|
||||
sendf("protocol_version=1")
|
||||
peer.endpoint.Lock()
|
||||
if peer.endpoint.val != nil {
|
||||
sendf("endpoint=%s", peer.endpoint.val.DstToString())
|
||||
}
|
||||
peer.endpoint.Unlock()
|
||||
|
||||
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
|
||||
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
|
||||
sendf("protocol_version=1")
|
||||
if peer.endpoint != nil {
|
||||
sendf("endpoint=%s", peer.endpoint.DstToString())
|
||||
}
|
||||
nano := peer.lastHandshakeNano.Load()
|
||||
secs := nano / time.Second.Nanoseconds()
|
||||
nano %= time.Second.Nanoseconds()
|
||||
|
||||
nano := peer.lastHandshakeNano.Load()
|
||||
secs := nano / time.Second.Nanoseconds()
|
||||
nano %= time.Second.Nanoseconds()
|
||||
sendf("last_handshake_time_sec=%d", secs)
|
||||
sendf("last_handshake_time_nsec=%d", nano)
|
||||
sendf("tx_bytes=%d", peer.txBytes.Load())
|
||||
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
||||
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
|
||||
|
||||
sendf("last_handshake_time_sec=%d", secs)
|
||||
sendf("last_handshake_time_nsec=%d", nano)
|
||||
sendf("tx_bytes=%d", peer.txBytes.Load())
|
||||
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
||||
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
|
||||
|
||||
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
||||
sendf("allowed_ip=%s", prefix.String())
|
||||
return true
|
||||
})
|
||||
}()
|
||||
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
||||
sendf("allowed_ip=%s", prefix.String())
|
||||
return true
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
@ -262,7 +260,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
|
||||
return
|
||||
}
|
||||
if peer.created {
|
||||
peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
|
||||
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
|
||||
}
|
||||
if peer.device.isUp() {
|
||||
peer.Start()
|
||||
@ -345,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
||||
}
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
peer.endpoint = endpoint
|
||||
peer.endpoint.Lock()
|
||||
defer peer.endpoint.Unlock()
|
||||
peer.endpoint.val = endpoint
|
||||
|
||||
case "persistent_keepalive_interval":
|
||||
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
|
||||
@ -373,7 +371,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||
device.allowedips.RemoveByPeer(peer.Peer)
|
||||
|
||||
case "allowed_ip":
|
||||
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
||||
add := true
|
||||
verb := "Adding"
|
||||
if len(value) > 0 && value[0] == '-' {
|
||||
add = false
|
||||
verb = "Removing"
|
||||
value = value[1:]
|
||||
}
|
||||
device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb)
|
||||
prefix, err := netip.ParsePrefix(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
||||
@ -381,7 +386,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||
if peer.dummy {
|
||||
return nil
|
||||
}
|
||||
device.allowedips.Insert(prefix, peer.Peer)
|
||||
if add {
|
||||
device.allowedips.Insert(prefix, peer.Peer)
|
||||
} else {
|
||||
device.allowedips.Remove(prefix, peer.Peer)
|
||||
}
|
||||
|
||||
case "protocol_version":
|
||||
if value != "1" {
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
package main
|
||||
|
||||
|
16
go.mod
16
go.mod
@ -1,16 +1,16 @@
|
||||
module golang.zx2c4.com/wireguard
|
||||
|
||||
go 1.19
|
||||
go 1.23.1
|
||||
|
||||
require (
|
||||
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
|
||||
golang.org/x/net v0.0.0-20220225172249-27dd8689420f
|
||||
golang.org/x/sys v0.2.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
|
||||
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0
|
||||
golang.org/x/crypto v0.37.0
|
||||
golang.org/x/net v0.39.0
|
||||
golang.org/x/sys v0.32.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
golang.org/x/time v0.7.0 // indirect
|
||||
)
|
||||
|
28
go.sum
28
go.sum
@ -1,14 +1,14 @@
|
||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38=
|
||||
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc=
|
||||
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY=
|
||||
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA=
|
||||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
|
||||
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
|
||||
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
|
||||
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||
|
@ -4,7 +4,6 @@
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package namedpipe
|
||||
|
||||
|
@ -4,7 +4,6 @@
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
|
||||
package namedpipe
|
||||
|
@ -4,7 +4,6 @@
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package namedpipe_test
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
@ -96,7 +96,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||
}
|
||||
|
||||
go func(l *UAPIListener) {
|
||||
var buff [0]byte
|
||||
var buf [0]byte
|
||||
for {
|
||||
defer uapi.inotifyRWCancel.Close()
|
||||
// start with lstat to avoid race condition
|
||||
@ -104,7 +104,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||
l.connErr <- err
|
||||
return
|
||||
}
|
||||
_, err := uapi.inotifyRWCancel.Read(buff[:])
|
||||
_, err := uapi.inotifyRWCancel.Read(buf[:])
|
||||
if err != nil {
|
||||
l.connErr <- err
|
||||
return
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
@ -1,11 +1,11 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
||||
// Made up sentinel error codes for the js/wasm platform.
|
||||
// Made up sentinel error codes for {js,wasip1}/wasm.
|
||||
const (
|
||||
IpcErrorIO = 1
|
||||
IpcErrorInvalid = 2
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
16
main.go
16
main.go
@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
@ -13,8 +13,8 @@ import (
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
@ -111,7 +111,7 @@ func main() {
|
||||
|
||||
// open TUN device (or use supplied fd)
|
||||
|
||||
tun, err := func() (tun.Device, error) {
|
||||
tdev, err := func() (tun.Device, error) {
|
||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||
if tunFdStr == "" {
|
||||
return tun.CreateTUN(interfaceName, device.DefaultMTU)
|
||||
@ -124,7 +124,7 @@ func main() {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = syscall.SetNonblock(int(fd), true)
|
||||
err = unix.SetNonblock(int(fd), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -134,7 +134,7 @@ func main() {
|
||||
}()
|
||||
|
||||
if err == nil {
|
||||
realInterfaceName, err2 := tun.Name()
|
||||
realInterfaceName, err2 := tdev.Name()
|
||||
if err2 == nil {
|
||||
interfaceName = realInterfaceName
|
||||
}
|
||||
@ -196,7 +196,7 @@ func main() {
|
||||
files[0], // stdin
|
||||
files[1], // stdout
|
||||
files[2], // stderr
|
||||
tun.File(),
|
||||
tdev.File(),
|
||||
fileUAPI,
|
||||
},
|
||||
Dir: ".",
|
||||
@ -222,7 +222,7 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
|
||||
device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
|
||||
|
||||
logger.Verbosef("Device started")
|
||||
|
||||
@ -250,7 +250,7 @@ func main() {
|
||||
|
||||
// wait for program to terminate
|
||||
|
||||
signal.Notify(term, syscall.SIGTERM)
|
||||
signal.Notify(term, unix.SIGTERM)
|
||||
signal.Notify(term, os.Interrupt)
|
||||
|
||||
select {
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
@ -9,7 +9,8 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
@ -81,7 +82,7 @@ func main() {
|
||||
|
||||
signal.Notify(term, os.Interrupt)
|
||||
signal.Notify(term, os.Kill)
|
||||
signal.Notify(term, syscall.SIGTERM)
|
||||
signal.Notify(term, windows.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-term:
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ratelimiter
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ratelimiter
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package replay
|
||||
|
@ -1,8 +1,8 @@
|
||||
//go:build !windows && !js
|
||||
//go:build !windows && !wasm
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package rwcancel implements cancelable read/write operations on
|
||||
@ -64,7 +64,7 @@ func (rw *RWCancel) ReadyRead() bool {
|
||||
|
||||
func (rw *RWCancel) ReadyWrite() bool {
|
||||
closeFd := int32(rw.closingReader.Fd())
|
||||
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
|
||||
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLIN}}
|
||||
var err error
|
||||
for {
|
||||
_, err = unix.Poll(pollFds, -1)
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build windows || js
|
||||
//go:build windows || wasm
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tai64n
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tai64n
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
102
tun/checksum.go
Normal file
102
tun/checksum.go
Normal file
@ -0,0 +1,102 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// TODO: Explore SIMD and/or other assembly optimizations.
|
||||
func checksumNoFold(b []byte, initial uint64) uint64 {
|
||||
tmp := make([]byte, 8)
|
||||
binary.NativeEndian.PutUint64(tmp, initial)
|
||||
ac := binary.BigEndian.Uint64(tmp)
|
||||
var carry uint64
|
||||
|
||||
for len(b) >= 128 {
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry)
|
||||
ac += carry
|
||||
b = b[128:]
|
||||
}
|
||||
if len(b) >= 64 {
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
|
||||
ac += carry
|
||||
b = b[64:]
|
||||
}
|
||||
if len(b) >= 32 {
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
|
||||
ac += carry
|
||||
b = b[32:]
|
||||
}
|
||||
if len(b) >= 16 {
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
|
||||
ac += carry
|
||||
b = b[16:]
|
||||
}
|
||||
if len(b) >= 8 {
|
||||
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
|
||||
ac += carry
|
||||
b = b[8:]
|
||||
}
|
||||
if len(b) >= 4 {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0)
|
||||
ac += carry
|
||||
b = b[4:]
|
||||
}
|
||||
if len(b) >= 2 {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0)
|
||||
ac += carry
|
||||
b = b[2:]
|
||||
}
|
||||
if len(b) == 1 {
|
||||
tmp := binary.NativeEndian.Uint16([]byte{b[0], 0})
|
||||
ac, carry = bits.Add64(ac, uint64(tmp), 0)
|
||||
ac += carry
|
||||
}
|
||||
|
||||
binary.NativeEndian.PutUint64(tmp, ac)
|
||||
return binary.BigEndian.Uint64(tmp)
|
||||
}
|
||||
|
||||
func checksum(b []byte, initial uint64) uint16 {
|
||||
ac := checksumNoFold(b, initial)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
return uint16(ac)
|
||||
}
|
||||
|
||||
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
|
||||
sum := checksumNoFold(srcAddr, 0)
|
||||
sum = checksumNoFold(dstAddr, sum)
|
||||
sum = checksumNoFold([]byte{0, protocol}, sum)
|
||||
tmp := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(tmp, totalLen)
|
||||
return checksumNoFold(tmp, sum)
|
||||
}
|
98
tun/checksum_test.go
Normal file
98
tun/checksum_test.go
Normal file
@ -0,0 +1,98 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func checksumRef(b []byte, initial uint16) uint16 {
|
||||
ac := uint64(initial)
|
||||
|
||||
for len(b) >= 2 {
|
||||
ac += uint64(binary.BigEndian.Uint16(b))
|
||||
b = b[2:]
|
||||
}
|
||||
if len(b) == 1 {
|
||||
ac += uint64(b[0]) << 8
|
||||
}
|
||||
|
||||
for (ac >> 16) > 0 {
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
}
|
||||
return uint16(ac)
|
||||
}
|
||||
|
||||
func pseudoHeaderChecksumRefNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
|
||||
sum := checksumRef(srcAddr, 0)
|
||||
sum = checksumRef(dstAddr, sum)
|
||||
sum = checksumRef([]byte{0, protocol}, sum)
|
||||
tmp := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(tmp, totalLen)
|
||||
return checksumRef(tmp, sum)
|
||||
}
|
||||
|
||||
func TestChecksum(t *testing.T) {
|
||||
for length := 0; length <= 9001; length++ {
|
||||
buf := make([]byte, length)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
rng.Read(buf)
|
||||
csum := checksum(buf, 0x1234)
|
||||
csumRef := checksumRef(buf, 0x1234)
|
||||
if csum != csumRef {
|
||||
t.Error("Expected checksum", csumRef, "got", csum)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPseudoHeaderChecksum(t *testing.T) {
|
||||
for _, addrLen := range []int{4, 16} {
|
||||
for length := 0; length <= 9001; length++ {
|
||||
srcAddr := make([]byte, addrLen)
|
||||
dstAddr := make([]byte, addrLen)
|
||||
buf := make([]byte, length)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
rng.Read(srcAddr)
|
||||
rng.Read(dstAddr)
|
||||
rng.Read(buf)
|
||||
phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
|
||||
csum := checksum(buf, phSum)
|
||||
phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
|
||||
csumRef := checksumRef(buf, phSumRef)
|
||||
if csum != csumRef {
|
||||
t.Error("Expected checksumRef", csumRef, "got", csum)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkChecksum(b *testing.B) {
|
||||
lengths := []int{
|
||||
64,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1500,
|
||||
2048,
|
||||
4096,
|
||||
8192,
|
||||
9000,
|
||||
9001,
|
||||
}
|
||||
|
||||
for _, length := range lengths {
|
||||
b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
|
||||
buf := make([]byte, length)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
rng.Read(buf)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
checksum(buf, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
12
tun/errors.go
Normal file
12
tun/errors.go
Normal file
@ -0,0 +1,12 @@
|
||||
package tun

import (
	"errors"
)

// Sentinel errors returned by tun device implementations; compare with
// errors.Is.
var (
	// ErrTooManySegments is returned by Device.Read() when segmentation
	// overflows the length of supplied buffers. This error should not cause
	// reads to cease.
	ErrTooManySegments = errors.New("too many segments")
)
|
@ -1,9 +1,8 @@
|
||||
//go:build ignore
|
||||
// +build ignore
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
@ -1,9 +1,8 @@
|
||||
//go:build ignore
|
||||
// +build ignore
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
@ -1,9 +1,8 @@
|
||||
//go:build ignore
|
||||
// +build ignore
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package netstack
|
||||
@ -19,12 +19,13 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"gvisor.dev/gvisor/pkg/bufferv2"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
@ -42,7 +43,8 @@ type netTun struct {
|
||||
ep *channel.Endpoint
|
||||
stack *stack.Stack
|
||||
events chan tun.Event
|
||||
incomingPacket chan *bufferv2.View
|
||||
notifyHandle *channel.NotificationHandle
|
||||
incomingPacket chan *buffer.View
|
||||
mtu int
|
||||
dnsServers []netip.Addr
|
||||
hasV4, hasV6 bool
|
||||
@ -60,12 +62,17 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
|
||||
ep: channel.New(1024, uint32(mtu), ""),
|
||||
stack: stack.New(opts),
|
||||
events: make(chan tun.Event, 10),
|
||||
incomingPacket: make(chan *bufferv2.View),
|
||||
incomingPacket: make(chan *buffer.View),
|
||||
dnsServers: dnsServers,
|
||||
mtu: mtu,
|
||||
}
|
||||
dev.ep.AddNotify(dev)
|
||||
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
|
||||
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
|
||||
tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
|
||||
if tcpipErr != nil {
|
||||
return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
|
||||
}
|
||||
dev.notifyHandle = dev.ep.AddNotify(dev)
|
||||
tcpipErr = dev.stack.CreateNIC(1, dev.ep)
|
||||
if tcpipErr != nil {
|
||||
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
||||
}
|
||||
@ -78,7 +85,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
|
||||
}
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: protoNumber,
|
||||
AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
|
||||
AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
|
||||
}
|
||||
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
||||
if tcpipErr != nil {
|
||||
@ -113,35 +120,43 @@ func (tun *netTun) Events() <-chan tun.Event {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
func (tun *netTun) Read(buf []byte, offset int) (int, error) {
|
||||
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
|
||||
view, ok := <-tun.incomingPacket
|
||||
if !ok {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
return view.Read(buf[offset:])
|
||||
n, err := view.Read(buf[0][offset:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (tun *netTun) Write(buf []byte, offset int) (int, error) {
|
||||
packet := buf[offset:]
|
||||
if len(packet) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
|
||||
for _, buf := range buf {
|
||||
packet := buf[offset:]
|
||||
if len(packet) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
|
||||
switch packet[0] >> 4 {
|
||||
case 4:
|
||||
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||
case 6:
|
||||
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
|
||||
switch packet[0] >> 4 {
|
||||
case 4:
|
||||
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||
case 6:
|
||||
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
|
||||
default:
|
||||
return 0, syscall.EAFNOSUPPORT
|
||||
}
|
||||
}
|
||||
|
||||
return len(buf), nil
|
||||
}
|
||||
|
||||
func (tun *netTun) WriteNotify() {
|
||||
pkt := tun.ep.Read()
|
||||
if pkt.IsNil() {
|
||||
if pkt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -151,19 +166,16 @@ func (tun *netTun) WriteNotify() {
|
||||
tun.incomingPacket <- view
|
||||
}
|
||||
|
||||
func (tun *netTun) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *netTun) Close() error {
|
||||
tun.stack.RemoveNIC(1)
|
||||
tun.stack.Close()
|
||||
tun.ep.RemoveNotify(tun.notifyHandle)
|
||||
tun.ep.Close()
|
||||
|
||||
if tun.events != nil {
|
||||
close(tun.events)
|
||||
}
|
||||
|
||||
tun.ep.Close()
|
||||
|
||||
if tun.incomingPacket != nil {
|
||||
close(tun.incomingPacket)
|
||||
}
|
||||
@ -175,6 +187,10 @@ func (tun *netTun) MTU() (int, error) {
|
||||
return tun.mtu, nil
|
||||
}
|
||||
|
||||
func (tun *netTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
||||
var protoNumber tcpip.NetworkProtocolNumber
|
||||
if endpoint.Addr().Is4() {
|
||||
@ -184,7 +200,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ
|
||||
}
|
||||
return tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.Address(endpoint.Addr().AsSlice()),
|
||||
Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
|
||||
Port: endpoint.Port(),
|
||||
}, protoNumber
|
||||
}
|
||||
@ -439,7 +455,7 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
|
||||
}
|
||||
|
||||
remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))
|
||||
remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
|
||||
return res.Count, &PingAddr{remoteAddr}, nil
|
||||
}
|
||||
|
||||
@ -898,7 +914,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
|
||||
}
|
||||
}
|
||||
}
|
||||
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
|
||||
// We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
|
||||
var addrs []netip.Addr
|
||||
if tnet.hasV6 {
|
||||
addrs = append(addrsV6, addrsV4...)
|
||||
|
993
tun/offload_linux.go
Normal file
993
tun/offload_linux.go
Normal file
@ -0,0 +1,993 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
// tcpFlagsOffset is the byte offset of the flags field within a TCP header.
const tcpFlagsOffset = 13

// TCP flag bits within the flags byte.
const (
	tcpFlagFIN uint8 = 0x01
	tcpFlagPSH uint8 = 0x08
	tcpFlagACK uint8 = 0x10
)
|
||||
|
||||
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
// kernel symbol is virtio_net_hdr. Field order and widths must match the C
// struct exactly, since decode/encode copy its raw memory representation.
type virtioNetHdr struct {
	flags      uint8
	gsoType    uint8
	hdrLen     uint16
	gsoSize    uint16
	csumStart  uint16
	csumOffset uint16
}

const (
	// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
	// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
	virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)

// raw exposes the header's in-memory representation as a byte slice,
// relying on the struct layout matching the kernel's C ABI.
func (v *virtioNetHdr) raw() []byte {
	return unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)
}

// decode populates v from the first virtioNetHdrLen bytes of b, returning
// io.ErrShortBuffer when b is too short.
func (v *virtioNetHdr) decode(b []byte) error {
	if len(b) < virtioNetHdrLen {
		return io.ErrShortBuffer
	}
	copy(v.raw(), b[:virtioNetHdrLen])
	return nil
}

// encode serializes v into the first virtioNetHdrLen bytes of b, returning
// io.ErrShortBuffer when b is too short.
func (v *virtioNetHdr) encode(b []byte) error {
	if len(b) < virtioNetHdrLen {
		return io.ErrShortBuffer
	}
	copy(b[:virtioNetHdrLen], v.raw())
	return nil
}
||||
|
||||
// tcpFlowKey represents the key for a TCP flow. It is comparable and used
// directly as a map key.
type tcpFlowKey struct {
	srcAddr, dstAddr [16]byte // IPv4 addresses occupy the first 4 bytes; the remainder stays zero
	srcPort, dstPort uint16
	rxAck            uint32 // varying ack values should not be coalesced. Treat them as separate flows.
	isV6             bool
}
|
||||
|
||||
// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
type tcpGROTable struct {
	itemsByFlow map[tcpFlowKey][]tcpGROItem
	itemsPool   [][]tcpGROItem // recycled zero-length item slices, handed out by newItems()
}
|
||||
|
||||
func newTCPGROTable() *tcpGROTable {
|
||||
t := &tcpGROTable{
|
||||
itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
|
||||
itemsPool: make([][]tcpGROItem, conn.IdealBatchSize),
|
||||
}
|
||||
for i := range t.itemsPool {
|
||||
t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
|
||||
key := tcpFlowKey{}
|
||||
addrSize := dstAddrOffset - srcAddrOffset
|
||||
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
|
||||
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
|
||||
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
|
||||
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
|
||||
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
|
||||
key.isV6 = addrSize == 16
|
||||
return key
|
||||
}
|
||||
|
||||
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
||||
// returning the packets found for the flow, or inserting a new one if none
|
||||
// is found.
|
||||
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
|
||||
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||
items, ok := t.itemsByFlow[key]
|
||||
if ok {
|
||||
return items, ok
|
||||
}
|
||||
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
||||
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// insert an item in the table for the provided packet and packet metadata.
|
||||
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
|
||||
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||
item := tcpGROItem{
|
||||
key: key,
|
||||
bufsIndex: uint16(bufsIndex),
|
||||
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
|
||||
iphLen: uint8(tcphOffset),
|
||||
tcphLen: uint8(tcphLen),
|
||||
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
|
||||
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
|
||||
}
|
||||
items, ok := t.itemsByFlow[key]
|
||||
if !ok {
|
||||
items = t.newItems()
|
||||
}
|
||||
items = append(items, item)
|
||||
t.itemsByFlow[key] = items
|
||||
}
|
||||
|
||||
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
|
||||
items, _ := t.itemsByFlow[item.key]
|
||||
items[i] = item
|
||||
}
|
||||
|
||||
func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
|
||||
items, _ := t.itemsByFlow[key]
|
||||
items = append(items[:i], items[i+1:]...)
|
||||
t.itemsByFlow[key] = items
|
||||
}
|
||||
|
||||
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets. Items are copied by value;
// mutations are written back via tcpGROTable.updateAt.
type tcpGROItem struct {
	key       tcpFlowKey
	sentSeq   uint32 // the sequence number
	bufsIndex uint16 // the index into the original bufs slice
	numMerged uint16 // the number of packets merged into this item
	gsoSize   uint16 // payload size
	iphLen    uint8  // ip header len
	tcphLen   uint8  // tcp header len
	pshSet    bool   // psh flag is set
}
|
||||
|
||||
func (t *tcpGROTable) newItems() []tcpGROItem {
|
||||
var items []tcpGROItem
|
||||
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
|
||||
return items
|
||||
}
|
||||
|
||||
func (t *tcpGROTable) reset() {
|
||||
for k, items := range t.itemsByFlow {
|
||||
items = items[:0]
|
||||
t.itemsPool = append(t.itemsPool, items)
|
||||
delete(t.itemsByFlow, k)
|
||||
}
|
||||
}
|
||||
|
||||
// udpFlowKey represents the key for a UDP flow. It is comparable and used
// directly as a map key.
type udpFlowKey struct {
	srcAddr, dstAddr [16]byte // IPv4 addresses occupy the first 4 bytes; the remainder stays zero
	srcPort, dstPort uint16
	isV6             bool
}
|
||||
|
||||
// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
type udpGROTable struct {
	itemsByFlow map[udpFlowKey][]udpGROItem
	itemsPool   [][]udpGROItem // recycled zero-length item slices, handed out by newItems()
}
|
||||
|
||||
func newUDPGROTable() *udpGROTable {
|
||||
u := &udpGROTable{
|
||||
itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
|
||||
itemsPool: make([][]udpGROItem, conn.IdealBatchSize),
|
||||
}
|
||||
for i := range u.itemsPool {
|
||||
u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
|
||||
key := udpFlowKey{}
|
||||
addrSize := dstAddrOffset - srcAddrOffset
|
||||
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
|
||||
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
|
||||
key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
|
||||
key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
|
||||
key.isV6 = addrSize == 16
|
||||
return key
|
||||
}
|
||||
|
||||
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
||||
// returning the packets found for the flow, or inserting a new one if none
|
||||
// is found.
|
||||
func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
|
||||
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
|
||||
items, ok := u.itemsByFlow[key]
|
||||
if ok {
|
||||
return items, ok
|
||||
}
|
||||
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
||||
u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// insert an item in the table for the provided packet and packet metadata.
|
||||
func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
|
||||
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
|
||||
item := udpGROItem{
|
||||
key: key,
|
||||
bufsIndex: uint16(bufsIndex),
|
||||
gsoSize: uint16(len(pkt[udphOffset+udphLen:])),
|
||||
iphLen: uint8(udphOffset),
|
||||
cSumKnownInvalid: cSumKnownInvalid,
|
||||
}
|
||||
items, ok := u.itemsByFlow[key]
|
||||
if !ok {
|
||||
items = u.newItems()
|
||||
}
|
||||
items = append(items, item)
|
||||
u.itemsByFlow[key] = items
|
||||
}
|
||||
|
||||
func (u *udpGROTable) updateAt(item udpGROItem, i int) {
|
||||
items, _ := u.itemsByFlow[item.key]
|
||||
items[i] = item
|
||||
}
|
||||
|
||||
// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
// of a GRO evaluation across a vector of packets. Items are copied by value;
// mutations are written back via udpGROTable.updateAt.
type udpGROItem struct {
	key              udpFlowKey
	bufsIndex        uint16 // the index into the original bufs slice
	numMerged        uint16 // the number of packets merged into this item
	gsoSize          uint16 // payload size
	iphLen           uint8  // ip header len
	cSumKnownInvalid bool   // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
}
|
||||
|
||||
func (u *udpGROTable) newItems() []udpGROItem {
|
||||
var items []udpGROItem
|
||||
items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
|
||||
return items
|
||||
}
|
||||
|
||||
func (u *udpGROTable) reset() {
|
||||
for k, items := range u.itemsByFlow {
|
||||
items = items[:0]
|
||||
u.itemsPool = append(u.itemsPool, items)
|
||||
delete(u.itemsByFlow, k)
|
||||
}
|
||||
}
|
||||
|
||||
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int

const (
	coalescePrepend     canCoalesce = -1 // pkt belongs in front of the existing item
	coalesceUnavailable canCoalesce = 0  // pkt cannot be merged with the existing item
	coalesceAppend      canCoalesce = 1  // pkt belongs after the existing item
)
|
||||
|
||||
// ipHeadersCanCoalesce reports whether the IP headers found in pktA and
// pktB meet all requirements to be merged as part of a GRO operation.
// For IPv6 this requires equal version/traffic-class bytes and hop limit;
// for IPv4 it requires equal ToS, DF/reserved flag bits, and TTL. The MF
// (fragmentation) bit is checked further up the stack, not here.
func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
	if len(pktA) < 9 || len(pktB) < 9 {
		return false
	}
	if pktA[0]>>4 == 6 {
		// IPv6: version + Traffic Class (byte 0 and high nibble of byte 1)
		// and Hop Limit (byte 7) must match.
		return pktA[0] == pktB[0] &&
			pktA[1]>>4 == pktB[1]>>4 &&
			pktA[7] == pktB[7]
	}
	// IPv4: ToS (byte 1), DF/reserved bits (top three bits of byte 6), and
	// TTL (byte 8) must match.
	return pktA[1] == pktB[1] &&
		pktA[6]>>5 == pktB[6]>>5 &&
		pktA[8] == pktB[8]
}
|
||||
|
||||
// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
|
||||
// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
|
||||
// packets involved in the current GRO evaluation. bufsOffset is the offset at
|
||||
// which packet data begins within bufs.
|
||||
func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
|
||||
pktTarget := bufs[item.bufsIndex][bufsOffset:]
|
||||
if !ipHeadersCanCoalesce(pkt, pktTarget) {
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
|
||||
// A smaller than gsoSize packet has been appended previously.
|
||||
// Nothing can come after a smaller packet on the end.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
if gsoSize > item.gsoSize {
|
||||
// We cannot have a larger packet following a smaller one.
|
||||
return coalesceUnavailable
|
||||
}
|
||||
return coalesceAppend
|
||||
}
|
||||
|
||||
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
// iphLen/tcphLen/seq/pshSet/gsoSize describe pkt; bufs and bufsOffset locate
// the packet data of the existing item. The result says whether pkt should be
// appended after, prepended before, or not merged with item.
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
	pktTarget := bufs[item.bufsIndex][bufsOffset:]
	if tcphLen != item.tcphLen {
		// cannot coalesce with unequal tcp options len
		return coalesceUnavailable
	}
	if tcphLen > 20 {
		// NOTE(review): the target slice bounds mix item.iphLen (start) with
		// iphLen (end); if the two IP header lengths ever differ (IPv4
		// options) the compared ranges are offset — confirm intended.
		if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
			// cannot coalesce with unequal tcp options
			return coalesceUnavailable
		}
	}
	if !ipHeadersCanCoalesce(pkt, pktTarget) {
		return coalesceUnavailable
	}
	// seq adjacency: the span already held by item covers lhsLen payload
	// bytes starting at item.sentSeq.
	lhsLen := item.gsoSize
	lhsLen += item.numMerged * item.gsoSize
	if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
		if item.pshSet {
			// We cannot append to a segment that has the PSH flag set, PSH
			// can only be set on the final segment in a reassembled group.
			return coalesceUnavailable
		}
		if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
			// A smaller than gsoSize packet has been appended previously.
			// Nothing can come after a smaller packet on the end.
			return coalesceUnavailable
		}
		if gsoSize > item.gsoSize {
			// We cannot have a larger packet following a smaller one.
			return coalesceUnavailable
		}
		return coalesceAppend
	} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
		if pshSet {
			// We cannot prepend with a segment that has the PSH flag set, PSH
			// can only be set on the final segment in a reassembled group.
			return coalesceUnavailable
		}
		if gsoSize < item.gsoSize {
			// A smaller segment cannot precede the existing run: only the
			// final segment of a reassembled group may be smaller.
			return coalesceUnavailable
		}
		if gsoSize > item.gsoSize && item.numMerged > 0 {
			// There's at least one previous merge, and we're larger than all
			// previous. This would put multiple smaller packets on the end.
			return coalesceUnavailable
		}
		return coalescePrepend
	}
	return coalesceUnavailable
}
|
||||
|
||||
func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
|
||||
srcAddrAt := ipv4SrcAddrOffset
|
||||
addrSize := 4
|
||||
if isV6 {
|
||||
srcAddrAt = ipv6SrcAddrOffset
|
||||
addrSize = 16
|
||||
}
|
||||
lenForPseudo := uint16(len(pkt) - int(iphLen))
|
||||
cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
|
||||
return ^checksum(pkt[iphLen:], cSum) == 0
|
||||
}
|
||||
|
||||
// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int

const (
	coalesceInsufficientCap coalesceResult = iota // target buffer lacks capacity; no merge attempted
	coalescePSHEnding                             // prepend refused: the incoming segment carries PSH
	coalesceItemInvalidCSum                       // existing item failed checksum validation
	coalescePktInvalidCSum                        // incoming packet failed checksum validation
	coalesceSuccess                               // packets were merged
)
|
||||
|
||||
// coalesceUDPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome. On success the payload of pkt is appended
// in place onto item's buffer within bufs, and item.numMerged is bumped;
// length/checksum header fixups happen later in the accounting pass.
func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
	pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
	headersLen := item.iphLen + udphLen
	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)

	if cap(pktHead)-bufsOffset < coalescedLen {
		// We don't want to allocate a new underlying array if capacity is
		// too small.
		return coalesceInsufficientCap
	}
	if item.numMerged == 0 {
		// First merge for this item: validate the existing packet's checksum
		// once, unless it is already known to be bad.
		if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
			return coalesceItemInvalidCSum
		}
	}
	if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
		return coalescePktInvalidCSum
	}
	// Grow item's buffer in place and append pkt's payload after the
	// existing data.
	extendBy := len(pkt) - int(headersLen)
	bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
	copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])

	item.numMerged++
	return coalesceSuccess
}
|
||||
|
||||
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device. mode selects append vs prepend; pktBuffsIndex is pkt's index
// in bufs; gsoSize/seq/pshSet describe pkt. Header length/checksum fixups
// happen later in the accounting pass.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
	var pktHead []byte // the packet that will end up at the front
	headersLen := item.iphLen + item.tcphLen
	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)

	// Copy data
	if mode == coalescePrepend {
		pktHead = pkt
		if cap(pkt)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if pshSet {
			// PSH may only terminate a reassembled group, so a PSH segment
			// cannot be placed in front.
			return coalescePSHEnding
		}
		if item.numMerged == 0 {
			// First merge for this item: validate the existing packet once.
			if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
			return coalescePktInvalidCSum
		}
		// pkt now starts the run, so the run's starting sequence number is
		// pkt's.
		item.sentSeq = seq
		extendBy := coalescedLen - len(pktHead)
		bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
		copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
		// Flip the slice headers in bufs as part of prepend. The index of item
		// is already being tracked for writing.
		bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
	} else {
		pktHead = bufs[item.bufsIndex][bufsOffset:]
		if cap(pktHead)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if item.numMerged == 0 {
			// First merge for this item: validate the existing packet once.
			if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
			return coalescePktInvalidCSum
		}
		if pshSet {
			// We are appending a segment with PSH set.
			item.pshSet = pshSet
			pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
		}
		extendBy := len(pkt) - int(headersLen)
		bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
		copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
	}

	if gsoSize > item.gsoSize {
		item.gsoSize = gsoSize
	}

	item.numMerged++
	return coalesceSuccess
}
|
||||
|
||||
const (
	// ipv4FlagMoreFragments is the MF bit within the IPv4 flags byte (byte 6).
	ipv4FlagMoreFragments uint8 = 0x20
)

const (
	ipv4SrcAddrOffset = 12        // offset of the source address within an IPv4 header
	ipv6SrcAddrOffset = 8         // offset of the source address within an IPv6 header
	maxUint16         = 1<<16 - 1 // upper bound on any valid IP packet length field
)
|
||||
|
||||
// groResult describes what a single GRO evaluation did with a packet.
type groResult int

const (
	groResultNoop        groResult = iota // packet not eligible for GRO; left untouched
	groResultTableInsert                  // packet recorded as a new flow item
	groResultCoalesced                    // packet merged into an existing item
)
|
||||
|
||||
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table. offset is where packet data begins
// within each bufs element; isV6 selects IPv4 vs IPv6 header parsing.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
	pkt := bufs[pktI][offset:]
	if len(pkt) > maxUint16 {
		// A valid IPv4 or IPv6 packet will never exceed this.
		return groResultNoop
	}
	// Parse and sanity-check the IP header; the packet length must agree
	// with the header's own length field.
	iphLen := int((pkt[0] & 0x0F) * 4)
	if isV6 {
		iphLen = 40
		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
		if ipv6HPayloadLen != len(pkt)-iphLen {
			return groResultNoop
		}
	} else {
		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
		if totalLen != len(pkt) {
			return groResultNoop
		}
	}
	if len(pkt) < iphLen {
		return groResultNoop
	}
	tcphLen := int((pkt[iphLen+12] >> 4) * 4)
	if tcphLen < 20 || tcphLen > 60 {
		return groResultNoop
	}
	if len(pkt) < iphLen+tcphLen {
		return groResultNoop
	}
	if !isV6 {
		// MF bit, fragment offset high bits (byte 6 low bits), or fragment
		// offset low byte set means a fragment.
		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
			// no GRO support for fragmented segments for now
			return groResultNoop
		}
	}
	tcpFlags := pkt[iphLen+tcpFlagsOffset]
	var pshSet bool
	// not a candidate if any non-ACK flags (except PSH+ACK) are set
	if tcpFlags != tcpFlagACK {
		if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
			return groResultNoop
		}
		pshSet = true
	}
	gsoSize := uint16(len(pkt) - tcphLen - iphLen)
	// not a candidate if payload len is 0
	if gsoSize < 1 {
		return groResultNoop
	}
	seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
	srcAddrOffset := ipv4SrcAddrOffset
	addrLen := 4
	if isV6 {
		srcAddrOffset = ipv6SrcAddrOffset
		addrLen = 16
	}
	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	if !existing {
		return groResultTableInsert
	}
	for i := len(items) - 1; i >= 0; i-- {
		// In the best case of packets arriving in order iterating in reverse is
		// more efficient if there are multiple items for a given flow. This
		// also enables a natural table.deleteAt() in the
		// coalesceItemInvalidCSum case without the need for index tracking.
		// This algorithm makes a best effort to coalesce in the event of
		// unordered packets, where pkt may land anywhere in items from a
		// sequence number perspective, however once an item is inserted into
		// the table it is never compared across other items later.
		item := items[i]
		can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
		if can != coalesceUnavailable {
			result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
			switch result {
			case coalesceSuccess:
				table.updateAt(item, i)
				return groResultCoalesced
			case coalesceItemInvalidCSum:
				// delete the item with an invalid csum
				table.deleteAt(item.key, i)
			case coalescePktInvalidCSum:
				// no point in inserting an item that we can't coalesce
				return groResultNoop
			default:
			}
		}
	}
	// failed to coalesce with any other packets; store the item in the flow
	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	return groResultTableInsert
}
|
||||
|
||||
// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table. For every flow item that absorbed other packets it
// rewrites the IP length (and IPv4 header checksum), encodes a virtio-net
// header describing the resulting GSO super-packet, and seeds the TCP checksum
// field with the pseudo-header checksum so downstream checksum offload can
// finish the job. Items that merged nothing get a zeroed virtio-net header.
func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
	for _, items := range table.itemsByFlow {
		for _, item := range items {
			if item.numMerged > 0 {
				hdr := virtioNetHdr{
					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
					hdrLen:     uint16(item.iphLen + item.tcphLen),
					gsoSize:    item.gsoSize,
					csumStart:  uint16(item.iphLen),
					csumOffset: 16, // tcp csum offset is 16 within the TCP header
				}
				pkt := bufs[item.bufsIndex][offset:]

				// Recalculate the total len (IPv4) or payload len (IPv6).
				// Recalculate the (IPv4) header checksum.
				if item.key.isV6 {
					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
				} else {
					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
					// Zero the checksum field before recomputing over the header.
					pkt[10], pkt[11] = 0, 0
					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
					iphCSum := ^checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
				}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}

				// Calculate the pseudo header checksum and place it at the TCP
				// checksum offset. Downstream checksum offloading will combine
				// this with computation of the tcp header and payload checksum.
				addrLen := 4
				addrOffset := ipv4SrcAddrOffset
				if item.key.isV6 {
					addrLen = 16
					addrOffset = ipv6SrcAddrOffset
				}
				// Source and destination addresses are adjacent in both IPv4
				// and IPv6 headers, so the destination follows immediately.
				srcAddrAt := offset + addrOffset
				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
				psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
			} else {
				// Nothing was merged into this packet; it still needs a valid
				// (all-zero) virtio-net header preceding it.
				hdr := virtioNetHdr{}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}
			}
		}
	}
	return nil
}
|
||||
|
||||
// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table. Mirrors applyTCPCoalesceAccounting for UDP: merged
// items get recalculated IP and UDP length fields, a virtio-net header with
// VIRTIO_NET_HDR_GSO_UDP_L4, and a pseudo-header checksum seeded at the UDP
// checksum offset; un-merged items get a zeroed virtio-net header.
func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
	for _, items := range table.itemsByFlow {
		for _, item := range items {
			if item.numMerged > 0 {
				hdr := virtioNetHdr{
					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
					hdrLen:     uint16(item.iphLen + udphLen),
					gsoSize:    item.gsoSize,
					csumStart:  uint16(item.iphLen),
					csumOffset: 6, // udp csum offset is 6 within the UDP header
				}
				pkt := bufs[item.bufsIndex][offset:]

				// Recalculate the total len (IPv4) or payload len (IPv6).
				// Recalculate the (IPv4) header checksum.
				hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
				if item.key.isV6 {
					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
				} else {
					// Zero the checksum field before recomputing over the header.
					pkt[10], pkt[11] = 0, 0
					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
					iphCSum := ^checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
				}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}

				// Recalculate the UDP len field value
				binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))

				// Calculate the pseudo header checksum and place it at the UDP
				// checksum offset. Downstream checksum offloading will combine
				// this with computation of the udp header and payload checksum.
				addrLen := 4
				addrOffset := ipv4SrcAddrOffset
				if item.key.isV6 {
					addrLen = 16
					addrOffset = ipv6SrcAddrOffset
				}
				// Destination address immediately follows the source address
				// in both IPv4 and IPv6 headers.
				srcAddrAt := offset + addrOffset
				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
				psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
			} else {
				// Nothing was merged into this packet; it still needs a valid
				// (all-zero) virtio-net header preceding it.
				hdr := virtioNetHdr{}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}
			}
		}
	}
	return nil
}
|
||||
|
||||
// groCandidateType classifies a packet's eligibility for GRO by transport
// protocol and IP version, as determined by packetIsGROCandidate.
type groCandidateType uint8

const (
	// notGROCandidate means the packet cannot be coalesced.
	notGROCandidate groCandidateType = iota
	// tcp4GROCandidate is a TCP-over-IPv4 packet eligible for GRO.
	tcp4GROCandidate
	// tcp6GROCandidate is a TCP-over-IPv6 packet eligible for GRO.
	tcp6GROCandidate
	// udp4GROCandidate is a UDP-over-IPv4 packet eligible for GRO.
	udp4GROCandidate
	// udp6GROCandidate is a UDP-over-IPv6 packet eligible for GRO.
	udp6GROCandidate
)
|
||||
|
||||
func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType {
|
||||
if len(b) < 28 {
|
||||
return notGROCandidate
|
||||
}
|
||||
if b[0]>>4 == 4 {
|
||||
if b[0]&0x0F != 5 {
|
||||
// IPv4 packets w/IP options do not coalesce
|
||||
return notGROCandidate
|
||||
}
|
||||
if b[9] == unix.IPPROTO_TCP && len(b) >= 40 {
|
||||
return tcp4GROCandidate
|
||||
}
|
||||
if b[9] == unix.IPPROTO_UDP && canUDPGRO {
|
||||
return udp4GROCandidate
|
||||
}
|
||||
} else if b[0]>>4 == 6 {
|
||||
if b[6] == unix.IPPROTO_TCP && len(b) >= 60 {
|
||||
return tcp6GROCandidate
|
||||
}
|
||||
if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO {
|
||||
return udp6GROCandidate
|
||||
}
|
||||
}
|
||||
return notGROCandidate
|
||||
}
|
||||
|
||||
const (
	// udphLen is the fixed size of a UDP header in bytes.
	udphLen = 8
)
|
||||
|
||||
// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
	pkt := bufs[pktI][offset:]
	if len(pkt) > maxUint16 {
		// A valid IPv4 or IPv6 packet will never exceed this.
		return groResultNoop
	}
	iphLen := int((pkt[0] & 0x0F) * 4)
	if isV6 {
		iphLen = 40
		// The IPv6 payload length field must agree with the packet size.
		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
		if ipv6HPayloadLen != len(pkt)-iphLen {
			return groResultNoop
		}
	} else {
		// The IPv4 total length field must agree with the packet size.
		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
		if totalLen != len(pkt) {
			return groResultNoop
		}
	}
	// Sanity-check that IP and UDP headers fit within the packet.
	if len(pkt) < iphLen {
		return groResultNoop
	}
	if len(pkt) < iphLen+udphLen {
		return groResultNoop
	}
	if !isV6 {
		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
			// no GRO support for fragmented segments for now
			return groResultNoop
		}
	}
	gsoSize := uint16(len(pkt) - udphLen - iphLen)
	// not a candidate if payload len is 0
	if gsoSize < 1 {
		return groResultNoop
	}
	srcAddrOffset := ipv4SrcAddrOffset
	addrLen := 4
	if isV6 {
		srcAddrOffset = ipv6SrcAddrOffset
		addrLen = 16
	}
	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
	if !existing {
		return groResultTableInsert
	}
	// With UDP we only check the last item, otherwise we could reorder packets
	// for a given flow. We must also always insert a new item, or successfully
	// coalesce with an existing item, for the same reason.
	item := items[len(items)-1]
	can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
	var pktCSumKnownInvalid bool
	if can == coalesceAppend {
		result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
		switch result {
		case coalesceSuccess:
			table.updateAt(item, len(items)-1)
			return groResultCoalesced
		case coalesceItemInvalidCSum:
			// If the existing item has an invalid csum we take no action. A new
			// item will be stored after it, and the existing item will never be
			// revisited as part of future coalescing candidacy checks.
		case coalescePktInvalidCSum:
			// We must insert a new item, but we also mark it as invalid csum
			// to prevent a repeat checksum validation.
			pktCSumKnownInvalid = true
		default:
		}
	}
	// failed to coalesce with any other packets; store the item in the flow
	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
	return groResultTableInsert
}
|
||||
|
||||
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
|
||||
// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
|
||||
// empty (but non-nil), and are passed in to save allocs as the caller may reset
|
||||
// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
|
||||
// supported.
|
||||
func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error {
|
||||
for i := range bufs {
|
||||
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
|
||||
return errors.New("invalid offset")
|
||||
}
|
||||
var result groResult
|
||||
switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) {
|
||||
case tcp4GROCandidate:
|
||||
result = tcpGRO(bufs, offset, i, tcpTable, false)
|
||||
case tcp6GROCandidate:
|
||||
result = tcpGRO(bufs, offset, i, tcpTable, true)
|
||||
case udp4GROCandidate:
|
||||
result = udpGRO(bufs, offset, i, udpTable, false)
|
||||
case udp6GROCandidate:
|
||||
result = udpGRO(bufs, offset, i, udpTable, true)
|
||||
}
|
||||
switch result {
|
||||
case groResultNoop:
|
||||
hdr := virtioNetHdr{}
|
||||
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fallthrough
|
||||
case groResultTableInsert:
|
||||
*toWrite = append(*toWrite, i)
|
||||
}
|
||||
}
|
||||
errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
|
||||
errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
|
||||
return errors.Join(errTCP, errUDP)
|
||||
}
|
||||
|
||||
// gsoSplit splits packets from in into outBuffs, writing the size of each
|
||||
// element into sizes. It returns the number of buffers populated, and/or an
|
||||
// error.
|
||||
func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
|
||||
iphLen := int(hdr.csumStart)
|
||||
srcAddrOffset := ipv6SrcAddrOffset
|
||||
addrLen := 16
|
||||
if !isV6 {
|
||||
in[10], in[11] = 0, 0 // clear ipv4 header checksum
|
||||
srcAddrOffset = ipv4SrcAddrOffset
|
||||
addrLen = 4
|
||||
}
|
||||
transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
|
||||
in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
|
||||
var firstTCPSeqNum uint32
|
||||
var protocol uint8
|
||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
||||
protocol = unix.IPPROTO_TCP
|
||||
firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
|
||||
} else {
|
||||
protocol = unix.IPPROTO_UDP
|
||||
}
|
||||
nextSegmentDataAt := int(hdr.hdrLen)
|
||||
i := 0
|
||||
for ; nextSegmentDataAt < len(in); i++ {
|
||||
if i == len(outBuffs) {
|
||||
return i - 1, ErrTooManySegments
|
||||
}
|
||||
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
|
||||
if nextSegmentEnd > len(in) {
|
||||
nextSegmentEnd = len(in)
|
||||
}
|
||||
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
|
||||
totalLen := int(hdr.hdrLen) + segmentDataLen
|
||||
sizes[i] = totalLen
|
||||
out := outBuffs[i][outOffset:]
|
||||
|
||||
copy(out, in[:iphLen])
|
||||
if !isV6 {
|
||||
// For IPv4 we are responsible for incrementing the ID field,
|
||||
// updating the total len field, and recalculating the header
|
||||
// checksum.
|
||||
if i > 0 {
|
||||
id := binary.BigEndian.Uint16(out[4:])
|
||||
id += uint16(i)
|
||||
binary.BigEndian.PutUint16(out[4:], id)
|
||||
}
|
||||
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
|
||||
ipv4CSum := ^checksum(out[:iphLen], 0)
|
||||
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
|
||||
} else {
|
||||
// For IPv6 we are responsible for updating the payload length field.
|
||||
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
|
||||
}
|
||||
|
||||
// copy transport header
|
||||
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
|
||||
|
||||
if protocol == unix.IPPROTO_TCP {
|
||||
// set TCP seq and adjust TCP flags
|
||||
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
|
||||
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
|
||||
if nextSegmentEnd != len(in) {
|
||||
// FIN and PSH should only be set on last segment
|
||||
clearFlags := tcpFlagFIN | tcpFlagPSH
|
||||
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
|
||||
}
|
||||
} else {
|
||||
// set UDP header len
|
||||
binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
|
||||
}
|
||||
|
||||
// payload
|
||||
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
|
||||
|
||||
// transport checksum
|
||||
transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
|
||||
lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
|
||||
transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
|
||||
transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold)
|
||||
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)
|
||||
|
||||
nextSegmentDataAt += int(hdr.gsoSize)
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
|
||||
cSumAt := cSumStart + cSumOffset
|
||||
// The initial value at the checksum offset should be summed with the
|
||||
// checksum we compute. This is typically the pseudo-header checksum.
|
||||
initial := binary.BigEndian.Uint16(in[cSumAt:])
|
||||
in[cSumAt], in[cSumAt+1] = 0, 0
|
||||
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
|
||||
return nil
|
||||
}
|
752
tun/offload_linux_test.go
Normal file
752
tun/offload_linux_test.go
Normal file
@ -0,0 +1,752 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
)
|
||||
|
||||
const (
	// offset reserves room at the front of every test packet buffer for the
	// virtio-net header that precedes the IP packet.
	offset = virtioNetHdrLen
)
|
||||
|
||||
// Endpoint address/port pairs used to construct distinct test flows.
// 192.0.2.0/24 and 2001:db8::/32 are reserved documentation prefixes.
var (
	ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
	ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
	ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
	ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
	ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
	ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
)
|
||||
|
||||
// udp4PacketMutateIPFields builds a valid IPv4/UDP packet (preceded by offset
// bytes of headroom for a virtio-net header) carrying payloadLen zero bytes.
// ipFn, if non-nil, may mutate the IPv4 fields before they are encoded and
// checksummed, letting tests produce deliberately divergent headers.
func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte {
	totalLen := 28 + payloadLen // 20-byte IPv4 header + 8-byte UDP header + payload
	b := make([]byte, offset+int(totalLen), 65535)
	ipv4H := header.IPv4(b[offset:])
	srcAs4 := srcIPPort.Addr().As4()
	dstAs4 := dstIPPort.Addr().As4()
	ipFields := &header.IPv4Fields{
		SrcAddr:     tcpip.AddrFromSlice(srcAs4[:]),
		DstAddr:     tcpip.AddrFromSlice(dstAs4[:]),
		Protocol:    unix.IPPROTO_UDP,
		TTL:         64,
		TotalLength: uint16(totalLen),
	}
	if ipFn != nil {
		ipFn(ipFields)
	}
	ipv4H.Encode(ipFields)
	udpH := header.UDP(b[offset+20:])
	udpH.Encode(&header.UDPFields{
		SrcPort: srcIPPort.Port(),
		DstPort: dstIPPort.Port(),
		Length:  uint16(payloadLen + udphLen),
	})
	// Checksums are computed last, over the final encoded headers.
	ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
	udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
	return b
}
|
||||
|
||||
// udp6Packet builds a valid IPv6/UDP test packet with default IP fields.
func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
	return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
}
|
||||
|
||||
// udp6PacketMutateIPFields builds a valid IPv6/UDP packet (preceded by offset
// bytes of headroom for a virtio-net header) carrying payloadLen zero bytes.
// ipFn, if non-nil, may mutate the IPv6 fields before encoding and
// checksumming, letting tests produce deliberately divergent headers.
func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte {
	totalLen := 48 + payloadLen // 40-byte IPv6 header + 8-byte UDP header + payload
	b := make([]byte, offset+int(totalLen), 65535)
	ipv6H := header.IPv6(b[offset:])
	srcAs16 := srcIPPort.Addr().As16()
	dstAs16 := dstIPPort.Addr().As16()
	ipFields := &header.IPv6Fields{
		SrcAddr:           tcpip.AddrFromSlice(srcAs16[:]),
		DstAddr:           tcpip.AddrFromSlice(dstAs16[:]),
		TransportProtocol: unix.IPPROTO_UDP,
		HopLimit:          64,
		PayloadLength:     uint16(payloadLen + udphLen),
	}
	if ipFn != nil {
		ipFn(ipFields)
	}
	ipv6H.Encode(ipFields)
	udpH := header.UDP(b[offset+40:])
	udpH.Encode(&header.UDPFields{
		SrcPort: srcIPPort.Port(),
		DstPort: dstIPPort.Port(),
		Length:  uint16(payloadLen + udphLen),
	})
	// IPv6 has no header checksum; only the UDP checksum is computed.
	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
	udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
	return b
}
|
||||
|
||||
// udp4Packet builds a valid IPv4/UDP test packet with default IP fields.
func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
	return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
}
|
||||
|
||||
// tcp4PacketMutateIPFields builds a valid IPv4/TCP packet (preceded by offset
// bytes of headroom for a virtio-net header) with the given flags, sequence
// number, and a segmentSize-byte zero payload. ipFn, if non-nil, may mutate
// the IPv4 fields before encoding and checksumming.
func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
	totalLen := 40 + segmentSize // 20-byte IPv4 header + 20-byte TCP header + payload
	b := make([]byte, offset+int(totalLen), 65535)
	ipv4H := header.IPv4(b[offset:])
	srcAs4 := srcIPPort.Addr().As4()
	dstAs4 := dstIPPort.Addr().As4()
	ipFields := &header.IPv4Fields{
		SrcAddr:     tcpip.AddrFromSlice(srcAs4[:]),
		DstAddr:     tcpip.AddrFromSlice(dstAs4[:]),
		Protocol:    unix.IPPROTO_TCP,
		TTL:         64,
		TotalLength: uint16(totalLen),
	}
	if ipFn != nil {
		ipFn(ipFields)
	}
	ipv4H.Encode(ipFields)
	tcpH := header.TCP(b[offset+20:])
	tcpH.Encode(&header.TCPFields{
		SrcPort:    srcIPPort.Port(),
		DstPort:    dstIPPort.Port(),
		SeqNum:     seq,
		AckNum:     1,
		DataOffset: 20, // no TCP options
		Flags:      flags,
		WindowSize: 3000,
	})
	// Checksums are computed last, over the final encoded headers.
	ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
	tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
	return b
}
|
||||
|
||||
// tcp4Packet builds a valid IPv4/TCP test packet with default IP fields.
func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
	return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
}
|
||||
|
||||
// tcp6PacketMutateIPFields builds a valid IPv6/TCP packet (preceded by offset
// bytes of headroom for a virtio-net header) with the given flags, sequence
// number, and a segmentSize-byte zero payload. ipFn, if non-nil, may mutate
// the IPv6 fields before encoding and checksumming.
func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
	totalLen := 60 + segmentSize // 40-byte IPv6 header + 20-byte TCP header + payload
	b := make([]byte, offset+int(totalLen), 65535)
	ipv6H := header.IPv6(b[offset:])
	srcAs16 := srcIPPort.Addr().As16()
	dstAs16 := dstIPPort.Addr().As16()
	ipFields := &header.IPv6Fields{
		SrcAddr:           tcpip.AddrFromSlice(srcAs16[:]),
		DstAddr:           tcpip.AddrFromSlice(dstAs16[:]),
		TransportProtocol: unix.IPPROTO_TCP,
		HopLimit:          64,
		PayloadLength:     uint16(segmentSize + 20),
	}
	if ipFn != nil {
		ipFn(ipFields)
	}
	ipv6H.Encode(ipFields)
	tcpH := header.TCP(b[offset+40:])
	tcpH.Encode(&header.TCPFields{
		SrcPort:    srcIPPort.Port(),
		DstPort:    dstIPPort.Port(),
		SeqNum:     seq,
		AckNum:     1,
		DataOffset: 20, // no TCP options
		Flags:      flags,
		WindowSize: 3000,
	})
	// IPv6 has no header checksum; only the TCP checksum is computed.
	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
	tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
	return b
}
|
||||
|
||||
// tcp6Packet builds a valid IPv6/TCP test packet with default IP fields.
func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
	return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
}
|
||||
|
||||
// Test_handleVirtioRead feeds single GSO super-packets (200-byte payload,
// 100-byte gsoSize) through handleVirtioRead and verifies they are split into
// the expected number of segments with the expected per-segment sizes.
func Test_handleVirtioRead(t *testing.T) {
	tests := []struct {
		name     string
		hdr      virtioNetHdr // virtio header describing the GSO input packet
		pktIn    []byte       // input packet with virtio header headroom
		wantLens []int        // expected size of each output segment
		wantErr  bool
	}{
		{
			"tcp4",
			virtioNetHdr{
				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
				gsoSize:    100,
				hdrLen:     40,
				csumStart:  20,
				csumOffset: 16,
			},
			tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
			[]int{140, 140},
			false,
		},
		{
			"tcp6",
			virtioNetHdr{
				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
				gsoSize:    100,
				hdrLen:     60,
				csumStart:  40,
				csumOffset: 16,
			},
			tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
			[]int{160, 160},
			false,
		},
		{
			"udp4",
			virtioNetHdr{
				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				gsoType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
				gsoSize:    100,
				hdrLen:     28,
				csumStart:  20,
				csumOffset: 6,
			},
			udp4Packet(ip4PortA, ip4PortB, 200),
			[]int{128, 128},
			false,
		},
		{
			"udp6",
			virtioNetHdr{
				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				gsoType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
				gsoSize:    100,
				hdrLen:     48,
				csumStart:  40,
				csumOffset: 6,
			},
			udp6Packet(ip6PortA, ip6PortB, 200),
			[]int{148, 148},
			false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			out := make([][]byte, conn.IdealBatchSize)
			sizes := make([]int, conn.IdealBatchSize)
			for i := range out {
				out[i] = make([]byte, 65535)
			}
			// Write the virtio header into the packet's headroom before reading.
			tt.hdr.encode(tt.pktIn)
			n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
			if err != nil {
				if tt.wantErr {
					return
				}
				t.Fatalf("got err: %v", err)
			}
			if n != len(tt.wantLens) {
				t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
			}
			for i := range tt.wantLens {
				if tt.wantLens[i] != sizes[i] {
					t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
				}
			}
		})
	}
}
|
||||
|
||||
func flipTCP4Checksum(b []byte) []byte {
|
||||
at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
|
||||
b[at] ^= 0xFF
|
||||
b[at+1] ^= 0xFF
|
||||
return b
|
||||
}
|
||||
|
||||
func flipUDP4Checksum(b []byte) []byte {
|
||||
at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
|
||||
b[at] ^= 0xFF
|
||||
b[at+1] ^= 0xFF
|
||||
return b
|
||||
}
|
||||
|
||||
// Fuzz_handleGRO feeds arbitrary byte slices through handleGRO and checks
// the structural invariants of toWrite: no more entries than input packets,
// all indices in bounds, and no duplicates. It does not validate packet
// contents — only that GRO bookkeeping never emits an invalid write plan.
func Fuzz_handleGRO(f *testing.F) {
	// Seed corpus: well-formed packets covering both protocols, both IP
	// versions, and multiple flows.
	pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
	pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
	pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
	pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
	pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
	pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
	pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
	pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
	pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
	pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
	pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
	pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
	f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset)
	f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
		pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
		toWrite := make([]int, 0, len(pkts))
		// Errors are intentionally ignored; only toWrite invariants matter here.
		handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite)
		if len(toWrite) > len(pkts) {
			t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
		}
		seenWriteI := make(map[int]bool)
		for _, writeI := range toWrite {
			if writeI < 0 || writeI > len(pkts)-1 {
				t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
			}
			if seenWriteI[writeI] {
				t.Errorf("duplicate toWrite value: %d", writeI)
			}
			seenWriteI[writeI] = true
		}
	})
}
|
||||
|
||||
// Test_handleGRO runs batches of crafted packets through handleGRO and
// verifies which input indices survive to be written (wantToWrite) and the
// resulting per-packet lengths after coalescing (wantLens). Cases cover
// multi-flow coalescing, the UDP-GRO-disabled path, PSH flush boundaries,
// invalid checksums, out-of-order TCP, and every IP header field mismatch
// that must prevent coalescing.
func Test_handleGRO(t *testing.T) {
	tests := []struct {
		name        string
		pktsIn      [][]byte // input batch, each with virtio header headroom
		canUDPGRO   bool     // whether UDP GRO is enabled for the batch
		wantToWrite []int    // indices of pktsIn expected in toWrite, in order
		wantLens    []int    // expected post-GRO length of each written packet
		wantErr     bool
	}{
		{
			"multiple protocols and flows",
			[][]byte{
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // tcp4 flow 1
				udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
				udp4Packet(ip4PortA, ip4PortC, 100),                         // udp4 flow 2
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
				tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),   // tcp6 flow 1
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
				tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
				udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
				udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
				udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
			},
			true,
			[]int{0, 1, 2, 4, 5, 7, 9},
			[]int{240, 228, 128, 140, 260, 160, 248},
			false,
		},
		{
			"multiple protocols and flows no UDP GRO",
			[][]byte{
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // tcp4 flow 1
				udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
				udp4Packet(ip4PortA, ip4PortC, 100),                         // udp4 flow 2
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
				tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),   // tcp6 flow 1
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
				tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
				udp4Packet(ip4PortA, ip4PortB, 100),                         // udp4 flow 1
				udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
				udp6Packet(ip6PortA, ip6PortB, 100),                         // udp6 flow 1
			},
			false,
			[]int{0, 1, 2, 4, 5, 7, 8, 9, 10},
			[]int{240, 128, 128, 140, 260, 160, 128, 148, 148},
			false,
		},
		{
			"PSH interleaved",
			[][]byte{
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),                     // v4 flow 1
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                   // v4 flow 1
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301),                   // v4 flow 1
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),                     // v6 flow 1
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201),                   // v6 flow 1
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301),                   // v6 flow 1
			},
			true,
			[]int{0, 2, 4, 6},
			[]int{240, 240, 260, 260},
			false,
		},
		{
			"coalesceItemInvalidCSum",
			[][]byte{
				flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101),                 // v4 flow 1 seq 101 len 100
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201),                 // v4 flow 1 seq 201 len 100
				flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
				udp4Packet(ip4PortA, ip4PortB, 100),
				udp4Packet(ip4PortA, ip4PortB, 100),
			},
			true,
			[]int{0, 1, 3, 4},
			[]int{140, 240, 128, 228},
			false,
		},
		{
			"out of order",
			[][]byte{
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),   // v4 flow 1 seq 1 len 100
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
			},
			true,
			[]int{0},
			[]int{340},
			false,
		},
		{
			"unequal TTL",
			[][]byte{
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
					fields.TTL++
				}),
				udp4Packet(ip4PortA, ip4PortB, 100),
				udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
					fields.TTL++
				}),
			},
			true,
			[]int{0, 1, 2, 3},
			[]int{140, 140, 128, 128},
			false,
		},
		{
			"unequal ToS",
			[][]byte{
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
					fields.TOS++
				}),
				udp4Packet(ip4PortA, ip4PortB, 100),
				udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
					fields.TOS++
				}),
			},
			true,
			[]int{0, 1, 2, 3},
			[]int{140, 140, 128, 128},
			false,
		},
		{
			"unequal flags more fragments set",
			[][]byte{
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
					fields.Flags = 1
				}),
				udp4Packet(ip4PortA, ip4PortB, 100),
				udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
					fields.Flags = 1
				}),
			},
			true,
			[]int{0, 1, 2, 3},
			[]int{140, 140, 128, 128},
			false,
		},
		{
			"unequal flags DF set",
			[][]byte{
				tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
				tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
					fields.Flags = 2
				}),
				udp4Packet(ip4PortA, ip4PortB, 100),
				udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
					fields.Flags = 2
				}),
			},
			true,
			[]int{0, 1, 2, 3},
			[]int{140, 140, 128, 128},
			false,
		},
		{
			"ipv6 unequal hop limit",
			[][]byte{
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
				tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
					fields.HopLimit++
				}),
				udp6Packet(ip6PortA, ip6PortB, 100),
				udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
					fields.HopLimit++
				}),
			},
			true,
			[]int{0, 1, 2, 3},
			[]int{160, 160, 148, 148},
			false,
		},
		{
			"ipv6 unequal traffic class",
			[][]byte{
				tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
				tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
					fields.TrafficClass++
				}),
				udp6Packet(ip6PortA, ip6PortB, 100),
				udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
					fields.TrafficClass++
				}),
			},
			true,
			[]int{0, 1, 2, 3},
			[]int{160, 160, 148, 148},
			false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			toWrite := make([]int, 0, len(tt.pktsIn))
			err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
			if err != nil {
				if tt.wantErr {
					return
				}
				t.Fatalf("got err: %v", err)
			}
			if len(toWrite) != len(tt.wantToWrite) {
				t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
			}
			for i, pktI := range tt.wantToWrite {
				if tt.wantToWrite[i] != toWrite[i] {
					t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
				}
				if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
					t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
				}
			}
		})
	}
}
|
||||
|
||||
func Test_packetIsGROCandidate(t *testing.T) {
|
||||
tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
|
||||
tcp4TooShort := tcp4[:39]
|
||||
ip4InvalidHeaderLen := make([]byte, len(tcp4))
|
||||
copy(ip4InvalidHeaderLen, tcp4)
|
||||
ip4InvalidHeaderLen[0] = 0x46
|
||||
ip4InvalidProtocol := make([]byte, len(tcp4))
|
||||
copy(ip4InvalidProtocol, tcp4)
|
||||
ip4InvalidProtocol[9] = unix.IPPROTO_GRE
|
||||
|
||||
tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
|
||||
tcp6TooShort := tcp6[:59]
|
||||
ip6InvalidProtocol := make([]byte, len(tcp6))
|
||||
copy(ip6InvalidProtocol, tcp6)
|
||||
ip6InvalidProtocol[6] = unix.IPPROTO_GRE
|
||||
|
||||
udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
|
||||
udp4TooShort := udp4[:27]
|
||||
|
||||
udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
|
||||
udp6TooShort := udp6[:47]
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
b []byte
|
||||
canUDPGRO bool
|
||||
want groCandidateType
|
||||
}{
|
||||
{
|
||||
"tcp4",
|
||||
tcp4,
|
||||
true,
|
||||
tcp4GROCandidate,
|
||||
},
|
||||
{
|
||||
"tcp6",
|
||||
tcp6,
|
||||
true,
|
||||
tcp6GROCandidate,
|
||||
},
|
||||
{
|
||||
"udp4",
|
||||
udp4,
|
||||
true,
|
||||
udp4GROCandidate,
|
||||
},
|
||||
{
|
||||
"udp4 no support",
|
||||
udp4,
|
||||
false,
|
||||
notGROCandidate,
|
||||
},
|
||||
{
|
||||
"udp6",
|
||||
udp6,
|
||||
true,
|
||||
udp6GROCandidate,
|
||||
},
|
||||
{
|
||||
"udp6 no support",
|
||||
udp6,
|
||||
false,
|
||||
notGROCandidate,
|
||||
},
|
||||
{
|
||||
"udp4 too short",
|
||||
udp4TooShort,
|
||||
true,
|
||||
notGROCandidate,
|
||||
},
|
||||
{
|
||||
"udp6 too short",
|
||||
udp6TooShort,
|
||||
true,
|
||||
notGROCandidate,
|
||||
},
|
||||
{
|
||||
"tcp4 too short",
|
||||
tcp4TooShort,
|
||||
true,
|
||||
notGROCandidate,
|
||||
},
|
||||
{
|
||||
"tcp6 too short",
|
||||
tcp6TooShort,
|
||||
true,
|
||||
notGROCandidate,
|
||||
},
|
||||
{
|
||||
"invalid IP version",
|
||||
[]byte{0x00},
|
||||
true,
|
||||
notGROCandidate,
|
||||
},
|
||||
{
|
||||
"invalid IP header len",
|
||||
ip4InvalidHeaderLen,
|
||||
true,
|
||||
notGROCandidate,
|
||||
},
|
||||
{
|
||||
"ip4 invalid protocol",
|
||||
ip4InvalidProtocol,
|
||||
true,
|
||||
notGROCandidate,
|
||||
},
|
||||
{
|
||||
"ip6 invalid protocol",
|
||||
ip6InvalidProtocol,
|
||||
true,
|
||||
notGROCandidate,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want {
|
||||
t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_udpPacketsCanCoalesce(t *testing.T) {
|
||||
udp4a := udp4Packet(ip4PortA, ip4PortB, 100)
|
||||
udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
|
||||
udp4c := udp4Packet(ip4PortA, ip4PortB, 110)
|
||||
|
||||
type args struct {
|
||||
pkt []byte
|
||||
iphLen uint8
|
||||
gsoSize uint16
|
||||
item udpGROItem
|
||||
bufs [][]byte
|
||||
bufsOffset int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want canCoalesce
|
||||
}{
|
||||
{
|
||||
"coalesceAppend equal gso",
|
||||
args{
|
||||
pkt: udp4a[offset:],
|
||||
iphLen: 20,
|
||||
gsoSize: 100,
|
||||
item: udpGROItem{
|
||||
gsoSize: 100,
|
||||
iphLen: 20,
|
||||
},
|
||||
bufs: [][]byte{
|
||||
udp4a,
|
||||
udp4b,
|
||||
},
|
||||
bufsOffset: offset,
|
||||
},
|
||||
coalesceAppend,
|
||||
},
|
||||
{
|
||||
"coalesceAppend smaller gso",
|
||||
args{
|
||||
pkt: udp4a[offset : len(udp4a)-90],
|
||||
iphLen: 20,
|
||||
gsoSize: 10,
|
||||
item: udpGROItem{
|
||||
gsoSize: 100,
|
||||
iphLen: 20,
|
||||
},
|
||||
bufs: [][]byte{
|
||||
udp4a,
|
||||
udp4b,
|
||||
},
|
||||
bufsOffset: offset,
|
||||
},
|
||||
coalesceAppend,
|
||||
},
|
||||
{
|
||||
"coalesceUnavailable smaller gso previously appended",
|
||||
args{
|
||||
pkt: udp4a[offset:],
|
||||
iphLen: 20,
|
||||
gsoSize: 100,
|
||||
item: udpGROItem{
|
||||
gsoSize: 100,
|
||||
iphLen: 20,
|
||||
},
|
||||
bufs: [][]byte{
|
||||
udp4c,
|
||||
udp4b,
|
||||
},
|
||||
bufsOffset: offset,
|
||||
},
|
||||
coalesceUnavailable,
|
||||
},
|
||||
{
|
||||
"coalesceUnavailable larger following smaller",
|
||||
args{
|
||||
pkt: udp4c[offset:],
|
||||
iphLen: 20,
|
||||
gsoSize: 110,
|
||||
item: udpGROItem{
|
||||
gsoSize: 100,
|
||||
iphLen: 20,
|
||||
},
|
||||
bufs: [][]byte{
|
||||
udp4a,
|
||||
udp4c,
|
||||
},
|
||||
bufsOffset: offset,
|
||||
},
|
||||
coalesceUnavailable,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want {
|
||||
t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
42
tun/tun.go
42
tun/tun.go
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
@ -18,12 +18,36 @@ const (
|
||||
)
|
||||
|
||||
type Device interface {
|
||||
File() *os.File // returns the file descriptor of the device
|
||||
Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
|
||||
Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
|
||||
Flush() error // flush all previous writes to the device
|
||||
MTU() (int, error) // returns the MTU of the device
|
||||
Name() (string, error) // fetches and returns the current name
|
||||
Events() <-chan Event // returns a constant channel of events related to the device
|
||||
Close() error // stops the device and closes the event channel
|
||||
// File returns the file descriptor of the device.
|
||||
File() *os.File
|
||||
|
||||
// Read one or more packets from the Device (without any additional headers).
|
||||
// On a successful read it returns the number of packets read, and sets
|
||||
// packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
|
||||
// A nonzero offset can be used to instruct the Device on where to begin
|
||||
// reading into each element of the bufs slice.
|
||||
Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
|
||||
|
||||
// Write one or more packets to the device (without any additional headers).
|
||||
// On a successful write it returns the number of packets written. A nonzero
|
||||
// offset can be used to instruct the Device on where to begin writing from
|
||||
// each packet contained within the bufs slice.
|
||||
Write(bufs [][]byte, offset int) (int, error)
|
||||
|
||||
// MTU returns the MTU of the Device.
|
||||
MTU() (int, error)
|
||||
|
||||
// Name returns the current name of the Device.
|
||||
Name() (string, error)
|
||||
|
||||
// Events returns a channel of type Event, which is fed Device events.
|
||||
Events() <-chan Event
|
||||
|
||||
// Close stops the Device and closes the Event channel.
|
||||
Close() error
|
||||
|
||||
// BatchSize returns the preferred/max number of packets that can be read or
|
||||
// written in a single read/write call. BatchSize must not change over the
|
||||
// lifetime of a Device.
|
||||
BatchSize() int
|
||||
}
|
||||
|
@ -1,21 +1,19 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
@ -30,18 +28,6 @@ type NativeTun struct {
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
|
||||
for i := 0; i < 20; i++ {
|
||||
iface, err = net.InterfaceByIndex(index)
|
||||
if err != nil && errors.Is(err, syscall.ENOMEM) {
|
||||
time.Sleep(time.Duration(i) * time.Second / 3)
|
||||
continue
|
||||
}
|
||||
return iface, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
||||
var (
|
||||
statusUp bool
|
||||
@ -55,33 +41,29 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
||||
retry:
|
||||
n, err := unix.Read(tun.routeSocket, data)
|
||||
if err != nil {
|
||||
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
|
||||
if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR {
|
||||
goto retry
|
||||
}
|
||||
tun.errors <- err
|
||||
return
|
||||
}
|
||||
|
||||
if n < 14 {
|
||||
if n < 28 {
|
||||
continue
|
||||
}
|
||||
|
||||
if data[3 /* type */] != unix.RTM_IFINFO {
|
||||
if data[3 /* ifm_type */] != unix.RTM_IFINFO {
|
||||
continue
|
||||
}
|
||||
ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */])))
|
||||
ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifm_index */])))
|
||||
if ifindex != tunIfindex {
|
||||
continue
|
||||
}
|
||||
|
||||
iface, err := retryInterfaceByIndex(ifindex)
|
||||
if err != nil {
|
||||
tun.errors <- err
|
||||
return
|
||||
}
|
||||
flags := int(*(*uint32)(unsafe.Pointer(&data[8 /* ifm_flags */])))
|
||||
|
||||
// Up / Down event
|
||||
up := (iface.Flags & net.FlagUp) != 0
|
||||
up := (flags & syscall.IFF_UP) != 0
|
||||
if up != statusUp && up {
|
||||
tun.events <- EventUp
|
||||
}
|
||||
@ -90,11 +72,13 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
||||
}
|
||||
statusUp = up
|
||||
|
||||
mtu := int(*(*uint32)(unsafe.Pointer(&data[24 /* ifm_data.ifi_mtu */])))
|
||||
|
||||
// MTU changes
|
||||
if iface.MTU != statusMTU {
|
||||
if mtu != statusMTU {
|
||||
tun.events <- EventMTUUpdate
|
||||
}
|
||||
statusMTU = iface.MTU
|
||||
statusMTU = mtu
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,45 +201,46 @@ func (tun *NativeTun) Events() <-chan Event {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
// TODO: the BSDs look very similar in Read() and Write(). They should be
|
||||
// collapsed, with platform-specific files containing the varying parts of
|
||||
// their implementations.
|
||||
select {
|
||||
case err := <-tun.errors:
|
||||
return 0, err
|
||||
default:
|
||||
buff := buff[offset-4:]
|
||||
n, err := tun.tunFile.Read(buff[:])
|
||||
buf := bufs[0][offset-4:]
|
||||
n, err := tun.tunFile.Read(buf[:])
|
||||
if n < 4 {
|
||||
return 0, err
|
||||
}
|
||||
return n - 4, err
|
||||
sizes[0] = n - 4
|
||||
return 1, err
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
// reserve space for header
|
||||
|
||||
buff = buff[offset-4:]
|
||||
|
||||
// add packet information header
|
||||
|
||||
buff[0] = 0x00
|
||||
buff[1] = 0x00
|
||||
buff[2] = 0x00
|
||||
|
||||
if buff[4]>>4 == ipv6.Version {
|
||||
buff[3] = unix.AF_INET6
|
||||
} else {
|
||||
buff[3] = unix.AF_INET
|
||||
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
|
||||
if offset < 4 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
// write
|
||||
|
||||
return tun.tunFile.Write(buff)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
for i, buf := range bufs {
|
||||
buf = buf[offset-4:]
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
buf[2] = 0x00
|
||||
switch buf[4] >> 4 {
|
||||
case 4:
|
||||
buf[3] = unix.AF_INET
|
||||
case 6:
|
||||
buf[3] = unix.AF_INET6
|
||||
default:
|
||||
return i, unix.EAFNOSUPPORT
|
||||
}
|
||||
if _, err := tun.tunFile.Write(buf); err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
@ -318,6 +303,10 @@ func (tun *NativeTun) MTU() (int, error) {
|
||||
return int(ifr.MTU), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func socketCloexec(family, sotype, proto int) (fd int, err error) {
|
||||
// See go/src/net/sys_cloexec.go for background.
|
||||
syscall.ForkLock.RLock()
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
@ -333,45 +333,46 @@ func (tun *NativeTun) Events() <-chan Event {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
select {
|
||||
case err := <-tun.errors:
|
||||
return 0, err
|
||||
default:
|
||||
buff := buff[offset-4:]
|
||||
n, err := tun.tunFile.Read(buff[:])
|
||||
buf := bufs[0][offset-4:]
|
||||
n, err := tun.tunFile.Read(buf[:])
|
||||
if n < 4 {
|
||||
return 0, err
|
||||
}
|
||||
return n - 4, err
|
||||
sizes[0] = n - 4
|
||||
return 1, err
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
|
||||
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
|
||||
if offset < 4 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
buf = buf[offset-4:]
|
||||
if len(buf) < 5 {
|
||||
return 0, io.ErrShortBuffer
|
||||
for i, buf := range bufs {
|
||||
buf = buf[offset-4:]
|
||||
if len(buf) < 5 {
|
||||
return i, io.ErrShortBuffer
|
||||
}
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
buf[2] = 0x00
|
||||
switch buf[4] >> 4 {
|
||||
case 4:
|
||||
buf[3] = unix.AF_INET
|
||||
case 6:
|
||||
buf[3] = unix.AF_INET6
|
||||
default:
|
||||
return i, unix.EAFNOSUPPORT
|
||||
}
|
||||
if _, err := tun.tunFile.Write(buf); err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
buf[2] = 0x00
|
||||
switch buf[4] >> 4 {
|
||||
case 4:
|
||||
buf[3] = unix.AF_INET
|
||||
case 6:
|
||||
buf[3] = unix.AF_INET6
|
||||
default:
|
||||
return 0, unix.EAFNOSUPPORT
|
||||
}
|
||||
return tun.tunFile.Write(buf)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
@ -428,3 +429,7 @@ func (tun *NativeTun) MTU() (int, error) {
|
||||
}
|
||||
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
293
tun/tun_linux.go
293
tun/tun_linux.go
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
@ -17,9 +17,8 @@ import (
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
@ -33,17 +32,27 @@ type NativeTun struct {
|
||||
index int32 // if index
|
||||
errors chan error // async error handling
|
||||
events chan Event // device related events
|
||||
nopi bool // the device was passed IFF_NO_PI
|
||||
netlinkSock int
|
||||
netlinkCancel *rwcancel.RWCancel
|
||||
hackListenerClosed sync.Mutex
|
||||
statusListenersShutdown chan struct{}
|
||||
batchSize int
|
||||
vnetHdr bool
|
||||
udpGSO bool
|
||||
|
||||
closeOnce sync.Once
|
||||
|
||||
nameOnce sync.Once // guards calling initNameCache, which sets following fields
|
||||
nameCache string // name of interface
|
||||
nameErr error
|
||||
|
||||
readOpMu sync.Mutex // readOpMu guards readBuff
|
||||
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
|
||||
|
||||
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable
|
||||
toWrite []int
|
||||
tcpGROTable *tcpGROTable
|
||||
udpGROTable *udpGROTable
|
||||
}
|
||||
|
||||
func (tun *NativeTun) File() *os.File {
|
||||
@ -323,57 +332,147 @@ func (tun *NativeTun) nameSlow() (string, error) {
|
||||
return unix.ByteSliceToString(ifr[:]), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
|
||||
if tun.nopi {
|
||||
buf = buf[offset:]
|
||||
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
|
||||
tun.writeOpMu.Lock()
|
||||
defer func() {
|
||||
tun.tcpGROTable.reset()
|
||||
tun.udpGROTable.reset()
|
||||
tun.writeOpMu.Unlock()
|
||||
}()
|
||||
var (
|
||||
errs error
|
||||
total int
|
||||
)
|
||||
tun.toWrite = tun.toWrite[:0]
|
||||
if tun.vnetHdr {
|
||||
err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
offset -= virtioNetHdrLen
|
||||
} else {
|
||||
// reserve space for header
|
||||
buf = buf[offset-4:]
|
||||
|
||||
// add packet information header
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
if buf[4]>>4 == ipv6.Version {
|
||||
buf[2] = 0x86
|
||||
buf[3] = 0xdd
|
||||
} else {
|
||||
buf[2] = 0x08
|
||||
buf[3] = 0x00
|
||||
for i := range bufs {
|
||||
tun.toWrite = append(tun.toWrite, i)
|
||||
}
|
||||
}
|
||||
|
||||
n, err := tun.tunFile.Write(buf)
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
for _, bufsI := range tun.toWrite {
|
||||
n, err := tun.tunFile.Write(bufs[bufsI][offset:])
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
return total, os.ErrClosed
|
||||
}
|
||||
if err != nil {
|
||||
errs = errors.Join(errs, err)
|
||||
} else {
|
||||
total += n
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
return total, errs
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
}
|
||||
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
|
||||
// each buffer. It mutates sizes to reflect the size of each element of bufs,
|
||||
// and returns the number of packets read.
|
||||
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
var hdr virtioNetHdr
|
||||
err := hdr.decode(in)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
in = in[virtioNetHdrLen:]
|
||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
|
||||
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
|
||||
// This means CHECKSUM_PARTIAL in skb context. We are responsible
|
||||
// for computing the checksum starting at hdr.csumStart and placing
|
||||
// at hdr.csumOffset.
|
||||
err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if len(in) > len(bufs[0][offset:]) {
|
||||
return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
|
||||
}
|
||||
n := copy(bufs[0][offset:], in)
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
|
||||
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
|
||||
select {
|
||||
case err = <-tun.errors:
|
||||
ipVersion := in[0] >> 4
|
||||
switch ipVersion {
|
||||
case 4:
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
|
||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||
}
|
||||
case 6:
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
|
||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||
}
|
||||
default:
|
||||
if tun.nopi {
|
||||
n, err = tun.tunFile.Read(buf[offset:])
|
||||
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
|
||||
}
|
||||
|
||||
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
|
||||
// of the entire first packet when the kernel is handling it as part of a
|
||||
// FORWARD path. Instead, parse the transport header length and add it onto
|
||||
// csumStart, which is synonymous for IP header length.
|
||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
|
||||
hdr.hdrLen = hdr.csumStart + 8
|
||||
} else {
|
||||
if len(in) <= int(hdr.csumStart+12) {
|
||||
return 0, errors.New("packet is too short")
|
||||
}
|
||||
|
||||
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
|
||||
if tcpHLen < 20 || tcpHLen > 60 {
|
||||
// A TCP header must be between 20 and 60 bytes in length.
|
||||
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
||||
}
|
||||
hdr.hdrLen = hdr.csumStart + tcpHLen
|
||||
}
|
||||
|
||||
if len(in) < int(hdr.hdrLen) {
|
||||
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
|
||||
}
|
||||
|
||||
if hdr.hdrLen < hdr.csumStart {
|
||||
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
|
||||
}
|
||||
cSumAt := int(hdr.csumStart + hdr.csumOffset)
|
||||
if cSumAt+1 >= len(in) {
|
||||
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
|
||||
}
|
||||
|
||||
return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
tun.readOpMu.Lock()
|
||||
defer tun.readOpMu.Unlock()
|
||||
select {
|
||||
case err := <-tun.errors:
|
||||
return 0, err
|
||||
default:
|
||||
readInto := bufs[0][offset:]
|
||||
if tun.vnetHdr {
|
||||
readInto = tun.readBuff[:]
|
||||
}
|
||||
n, err := tun.tunFile.Read(readInto)
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if tun.vnetHdr {
|
||||
return handleVirtioRead(readInto[:n], bufs, sizes, offset)
|
||||
} else {
|
||||
buff := buf[offset-4:]
|
||||
n, err = tun.tunFile.Read(buff[:])
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
}
|
||||
if n < 4 {
|
||||
n = 0
|
||||
} else {
|
||||
n -= 4
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Events() <-chan Event {
|
||||
@ -399,6 +498,56 @@ func (tun *NativeTun) Close() error {
|
||||
return err2
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
return tun.batchSize
|
||||
}
|
||||
|
||||
const (
|
||||
// TODO: support TSO with ECN bits
|
||||
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
||||
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
|
||||
)
|
||||
|
||||
func (tun *NativeTun) initFromFlags(name string) error {
|
||||
sc, err := tun.tunFile.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if e := sc.Control(func(fd uintptr) {
|
||||
var (
|
||||
ifr *unix.Ifreq
|
||||
)
|
||||
ifr, err = unix.NewIfreq(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
got := ifr.Uint16()
|
||||
if got&unix.IFF_VNET_HDR != 0 {
|
||||
// tunTCPOffloads were added in Linux v2.6. We require their support
|
||||
// if IFF_VNET_HDR is set.
|
||||
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tun.vnetHdr = true
|
||||
tun.batchSize = conn.IdealBatchSize
|
||||
// tunUDPOffloads were added in Linux v6.2. We do not return an
|
||||
// error if they are unsupported at runtime.
|
||||
tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
|
||||
} else {
|
||||
tun.batchSize = 1
|
||||
}
|
||||
}); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateTUN creates a Device with the provided name and MTU.
|
||||
func CreateTUN(name string, mtu int) (Device, error) {
|
||||
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
|
||||
if err != nil {
|
||||
@ -408,25 +557,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ifr [ifReqSize]byte
|
||||
var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
|
||||
nameBytes := []byte(name)
|
||||
if len(nameBytes) >= unix.IFNAMSIZ {
|
||||
unix.Close(nfd)
|
||||
return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
|
||||
ifr, err := unix.NewIfreq(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(ifr[:], nameBytes)
|
||||
*(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(nfd),
|
||||
uintptr(unix.TUNSETIFF),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
if errno != 0 {
|
||||
unix.Close(nfd)
|
||||
return nil, errno
|
||||
// IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
|
||||
// where a null write will return EINVAL indicating the TUN is up.
|
||||
ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
|
||||
err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(nfd, true)
|
||||
@ -441,13 +581,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
||||
return CreateTUNFromFile(fd, mtu)
|
||||
}
|
||||
|
||||
// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
|
||||
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||
tun := &NativeTun{
|
||||
tunFile: file,
|
||||
events: make(chan Event, 5),
|
||||
errors: make(chan error, 5),
|
||||
statusListenersShutdown: make(chan struct{}),
|
||||
nopi: false,
|
||||
tcpGROTable: newTCPGROTable(),
|
||||
udpGROTable: newUDPGROTable(),
|
||||
toWrite: make([]int, 0, conn.IdealBatchSize),
|
||||
}
|
||||
|
||||
name, err := tun.Name()
|
||||
@ -455,8 +598,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// start event listener
|
||||
err = tun.initFromFlags(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// start event listener
|
||||
tun.index, err = getIFIndex(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -485,6 +632,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||
return tun, nil
|
||||
}
|
||||
|
||||
// CreateUnmonitoredTUNFromFD creates a Device from the provided file
|
||||
// descriptor.
|
||||
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
|
||||
err := unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
@ -492,14 +641,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
|
||||
}
|
||||
file := os.NewFile(uintptr(fd), "/dev/tun")
|
||||
tun := &NativeTun{
|
||||
tunFile: file,
|
||||
events: make(chan Event, 5),
|
||||
errors: make(chan error, 5),
|
||||
nopi: true,
|
||||
tunFile: file,
|
||||
events: make(chan Event, 5),
|
||||
errors: make(chan error, 5),
|
||||
tcpGROTable: newTCPGROTable(),
|
||||
udpGROTable: newUDPGROTable(),
|
||||
toWrite: make([]int, 0, conn.IdealBatchSize),
|
||||
}
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return tun, name, nil
|
||||
err = tun.initFromFlags(name)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return tun, name, err
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
@ -8,13 +8,13 @@ package tun
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
@ -204,45 +204,43 @@ func (tun *NativeTun) Events() <-chan Event {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
select {
|
||||
case err := <-tun.errors:
|
||||
return 0, err
|
||||
default:
|
||||
buff := buff[offset-4:]
|
||||
n, err := tun.tunFile.Read(buff[:])
|
||||
buf := bufs[0][offset-4:]
|
||||
n, err := tun.tunFile.Read(buf[:])
|
||||
if n < 4 {
|
||||
return 0, err
|
||||
}
|
||||
return n - 4, err
|
||||
sizes[0] = n - 4
|
||||
return 1, err
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
// reserve space for header
|
||||
|
||||
buff = buff[offset-4:]
|
||||
|
||||
// add packet information header
|
||||
|
||||
buff[0] = 0x00
|
||||
buff[1] = 0x00
|
||||
buff[2] = 0x00
|
||||
|
||||
if buff[4]>>4 == ipv6.Version {
|
||||
buff[3] = unix.AF_INET6
|
||||
} else {
|
||||
buff[3] = unix.AF_INET
|
||||
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
|
||||
if offset < 4 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
// write
|
||||
|
||||
return tun.tunFile.Write(buff)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
for i, buf := range bufs {
|
||||
buf = buf[offset-4:]
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
buf[2] = 0x00
|
||||
switch buf[4] >> 4 {
|
||||
case 4:
|
||||
buf[3] = unix.AF_INET
|
||||
case 6:
|
||||
buf[3] = unix.AF_INET6
|
||||
default:
|
||||
return i, unix.EAFNOSUPPORT
|
||||
}
|
||||
if _, err := tun.tunFile.Write(buf); err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
@ -329,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) {
|
||||
|
||||
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user