mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-14 16:55:15 +01:00
conn, device, tun: implement vectorized I/O plumbing
Accept packet vectors for reading and writing in the tun.Device and conn.Bind interfaces, so that the internal plumbing between these interfaces now passes a vector of packets. Vectors move untouched between these interfaces, i.e. if 128 packets are received from conn.Bind.Read(), 128 packets are passed to tun.Device.Write(). There is no internal buffering. Currently, existing implementations are only adjusted to have vectors of length one. Subsequent patches will improve that. Also, as a related fixup, use the unix and windows packages rather than the syscall package when possible. Co-authored-by: James Tucker <james@tailscale.com> Signed-off-by: James Tucker <james@tailscale.com> Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
21636207a6
commit
3bb8fec7e4
@ -193,6 +193,10 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) Close() error {
|
||||
// Take a readlock to shut down the sockets...
|
||||
bind.mu.RLock()
|
||||
@ -223,29 +227,39 @@ func (bind *LinuxSocketBind) Close() error {
|
||||
return err2
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
|
||||
func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
if bind.sock4 == -1 {
|
||||
return 0, nil, net.ErrClosed
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
var end LinuxSocketEndpoint
|
||||
n, err := receive4(bind.sock4, buf, &end)
|
||||
return n, &end, err
|
||||
n, err := receive4(bind.sock4, buffs[0], &end)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
eps[0] = &end
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
|
||||
func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
if bind.sock6 == -1 {
|
||||
return 0, nil, net.ErrClosed
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
var end LinuxSocketEndpoint
|
||||
n, err := receive6(bind.sock6, buf, &end)
|
||||
return n, &end, err
|
||||
n, err := receive6(bind.sock6, buffs[0], &end)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
eps[0] = &end
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
|
||||
func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error {
|
||||
nend, ok := end.(*LinuxSocketEndpoint)
|
||||
if !ok {
|
||||
return ErrWrongEndpointType
|
||||
@ -256,13 +270,24 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
|
||||
if bind.sock4 == -1 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return send4(bind.sock4, nend, buff)
|
||||
for _, buff := range buffs {
|
||||
err := send4(bind.sock4, nend, buff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if bind.sock6 == -1 {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return send6(bind.sock6, nend, buff)
|
||||
for _, buff := range buffs {
|
||||
err := send6(bind.sock6, nend, buff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
|
||||
|
@ -128,6 +128,10 @@ again:
|
||||
return fns, uint16(port), nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) Close() error {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
@ -150,20 +154,30 @@ func (bind *StdNetBind) Close() error {
|
||||
}
|
||||
|
||||
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
|
||||
return func(buff []byte) (int, Endpoint, error) {
|
||||
n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
|
||||
return n, asEndpoint(endpoint), err
|
||||
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||
size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
|
||||
if err == nil {
|
||||
sizes[0] = size
|
||||
eps[0] = asEndpoint(endpoint)
|
||||
return 1, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
|
||||
return func(buff []byte) (int, Endpoint, error) {
|
||||
n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
|
||||
return n, asEndpoint(endpoint), err
|
||||
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||
size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
|
||||
if err == nil {
|
||||
sizes[0] = size
|
||||
eps[0] = asEndpoint(endpoint)
|
||||
return 1, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
||||
func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
|
||||
var err error
|
||||
nend, ok := endpoint.(StdNetEndpoint)
|
||||
if !ok {
|
||||
@ -186,8 +200,13 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
||||
if conn == nil {
|
||||
return syscall.EAFNOSUPPORT
|
||||
}
|
||||
_, err = conn.WriteToUDPAddrPort(buff, addrPort)
|
||||
return err
|
||||
for _, buff := range buffs {
|
||||
_, err = conn.WriteToUDPAddrPort(buff, addrPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
|
||||
|
@ -321,6 +321,11 @@ func (bind *WinRingBind) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) BatchSize() int {
|
||||
// TODO: implement batching in and out of the ring
|
||||
return 1
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
||||
@ -409,16 +414,22 @@ retry:
|
||||
return n, &ep, nil
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
|
||||
func (bind *WinRingBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
return bind.v4.Receive(buf, &bind.isOpen)
|
||||
n, ep, err := bind.v4.Receive(buffs[0], &bind.isOpen)
|
||||
sizes[0] = n
|
||||
eps[0] = ep
|
||||
return 1, err
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
|
||||
func (bind *WinRingBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
return bind.v6.Receive(buf, &bind.isOpen)
|
||||
n, ep, err := bind.v6.Receive(buffs[0], &bind.isOpen)
|
||||
sizes[0] = n
|
||||
eps[0] = ep
|
||||
return 1, err
|
||||
}
|
||||
|
||||
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
|
||||
@ -473,32 +484,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi
|
||||
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||
}
|
||||
|
||||
func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
|
||||
func (bind *WinRingBind) Send(buffs [][]byte, endpoint Endpoint) error {
|
||||
nend, ok := endpoint.(*WinRingEndpoint)
|
||||
if !ok {
|
||||
return ErrWrongEndpointType
|
||||
}
|
||||
bind.mu.RLock()
|
||||
defer bind.mu.RUnlock()
|
||||
switch nend.family {
|
||||
case windows.AF_INET:
|
||||
if bind.v4.blackhole {
|
||||
return nil
|
||||
for _, buf := range buffs {
|
||||
switch nend.family {
|
||||
case windows.AF_INET:
|
||||
if bind.v4.blackhole {
|
||||
continue
|
||||
}
|
||||
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
|
||||
return err
|
||||
}
|
||||
case windows.AF_INET6:
|
||||
if bind.v6.blackhole {
|
||||
continue
|
||||
}
|
||||
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return bind.v4.Send(buf, nend, &bind.isOpen)
|
||||
case windows.AF_INET6:
|
||||
if bind.v6.blackhole {
|
||||
return nil
|
||||
}
|
||||
return bind.v6.Send(buf, nend, &bind.isOpen)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
sysconn, err := bind.ipv4.SyscallConn()
|
||||
func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sysconn, err := s.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -511,14 +528,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bind.blackhole4 = blackhole
|
||||
s.blackhole4 = blackhole
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||
bind.mu.Lock()
|
||||
defer bind.mu.Unlock()
|
||||
sysconn, err := bind.ipv6.SyscallConn()
|
||||
func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sysconn, err := s.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -531,7 +548,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bind.blackhole6 = blackhole
|
||||
s.blackhole6 = blackhole
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -89,32 +89,39 @@ func (c *ChannelBind) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChannelBind) BatchSize() int { return 1 }
|
||||
|
||||
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
||||
|
||||
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
|
||||
return func(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||
return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
return 0, nil, net.ErrClosed
|
||||
return 0, net.ErrClosed
|
||||
case rx := <-ch:
|
||||
return copy(b, rx), c.target6, nil
|
||||
copied := copy(buffs[0], rx)
|
||||
sizes[0] = copied
|
||||
eps[0] = c.target6
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
bc := make([]byte, len(b))
|
||||
copy(bc, b)
|
||||
if ep.(ChannelEndpoint) == c.target4 {
|
||||
*c.tx4 <- bc
|
||||
} else if ep.(ChannelEndpoint) == c.target6 {
|
||||
*c.tx6 <- bc
|
||||
} else {
|
||||
return os.ErrInvalid
|
||||
func (c *ChannelBind) Send(buffs [][]byte, ep conn.Endpoint) error {
|
||||
for _, b := range buffs {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
bc := make([]byte, len(b))
|
||||
copy(bc, b)
|
||||
if ep.(ChannelEndpoint) == c.target4 {
|
||||
*c.tx4 <- bc
|
||||
} else if ep.(ChannelEndpoint) == c.target6 {
|
||||
*c.tx6 <- bc
|
||||
} else {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
24
conn/conn.go
24
conn/conn.go
@ -15,10 +15,17 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A ReceiveFunc receives a single inbound packet from the network.
|
||||
// It writes the data into b. n is the length of the packet.
|
||||
// ep is the remote endpoint.
|
||||
type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
|
||||
const (
|
||||
DefaultBatchSize = 1 // maximum number of packets handled per read and write
|
||||
)
|
||||
|
||||
// A ReceiveFunc receives at least one packet from the network and writes them
|
||||
// into packets. On a successful read it returns the number of elements of
|
||||
// sizes, packets, and endpoints that should be evaluated. Some elements of
|
||||
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
|
||||
// and eps slice with a length greater than or equal to the length of packets.
|
||||
// These lengths must not exceed the length of the associated Bind.BatchSize().
|
||||
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
|
||||
|
||||
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||
//
|
||||
@ -38,11 +45,16 @@ type Bind interface {
|
||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||
SetMark(mark uint32) error
|
||||
|
||||
// Send writes a packet b to address ep.
|
||||
Send(b []byte, ep Endpoint) error
|
||||
// Send writes one or more packets in buffs to address ep. The length of
|
||||
// buffs must not exceed BatchSize().
|
||||
Send(buffs [][]byte, ep Endpoint) error
|
||||
|
||||
// ParseEndpoint creates a new endpoint from a string.
|
||||
ParseEndpoint(s string) (Endpoint, error)
|
||||
|
||||
// BatchSize is the number of buffers expected to be passed to
|
||||
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
|
||||
BatchSize() int
|
||||
}
|
||||
|
||||
// BindSocketToInterface is implemented by Bind objects that support being
|
||||
|
24
conn/conn_test.go
Normal file
24
conn/conn_test.go
Normal file
@ -0,0 +1,24 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPrettyName(t *testing.T) {
|
||||
var (
|
||||
recvFunc ReceiveFunc = func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
|
||||
)
|
||||
|
||||
const want = "TestPrettyName"
|
||||
|
||||
t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
|
||||
if got := recvFunc.PrettyName(); got != want {
|
||||
t.Errorf("PrettyName() = %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
|
||||
}
|
||||
|
||||
type autodrainingInboundQueue struct {
|
||||
c chan *QueueInboundElement
|
||||
c chan *[]*QueueInboundElement
|
||||
}
|
||||
|
||||
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
|
||||
@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
|
||||
// some other means, such as sending a sentinel nil values.
|
||||
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
||||
q := &autodrainingInboundQueue{
|
||||
c: make(chan *QueueInboundElement, QueueInboundSize),
|
||||
c: make(chan *[]*QueueInboundElement, QueueInboundSize),
|
||||
}
|
||||
runtime.SetFinalizer(q, device.flushInboundQueue)
|
||||
return q
|
||||
@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
||||
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
||||
for {
|
||||
select {
|
||||
case elem := <-q.c:
|
||||
elem.Lock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
case elems := <-q.c:
|
||||
for _, elem := range *elems {
|
||||
elem.Lock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
device.PutInboundElementsSlice(elems)
|
||||
default:
|
||||
return
|
||||
}
|
||||
@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
||||
}
|
||||
|
||||
type autodrainingOutboundQueue struct {
|
||||
c chan *QueueOutboundElement
|
||||
c chan *[]*QueueOutboundElement
|
||||
}
|
||||
|
||||
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
|
||||
@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct {
|
||||
// All sends to the channel must be best-effort, because there may be no receivers.
|
||||
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
||||
q := &autodrainingOutboundQueue{
|
||||
c: make(chan *QueueOutboundElement, QueueOutboundSize),
|
||||
c: make(chan *[]*QueueOutboundElement, QueueOutboundSize),
|
||||
}
|
||||
runtime.SetFinalizer(q, device.flushOutboundQueue)
|
||||
return q
|
||||
@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
||||
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
|
||||
for {
|
||||
select {
|
||||
case elem := <-q.c:
|
||||
elem.Lock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
case elems := <-q.c:
|
||||
for _, elem := range *elems {
|
||||
elem.Lock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsSlice(elems)
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
@ -68,9 +68,11 @@ type Device struct {
|
||||
cookieChecker CookieChecker
|
||||
|
||||
pool struct {
|
||||
messageBuffers *WaitPool
|
||||
inboundElements *WaitPool
|
||||
outboundElements *WaitPool
|
||||
outboundElementsSlice *WaitPool
|
||||
inboundElementsSlice *WaitPool
|
||||
messageBuffers *WaitPool
|
||||
inboundElements *WaitPool
|
||||
outboundElements *WaitPool
|
||||
}
|
||||
|
||||
queue struct {
|
||||
@ -295,6 +297,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
||||
device.rate.limiter.Init()
|
||||
device.indexTable.Init()
|
||||
|
||||
device.PopulatePools()
|
||||
|
||||
// create queues
|
||||
@ -322,6 +325,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||
return device
|
||||
}
|
||||
|
||||
// BatchSize returns the BatchSize for the device as a whole which is the max of
|
||||
// the bind batch size and the tun batch size. The batch size reported by device
|
||||
// is the size used to construct memory pools, and is the allowed batch size for
|
||||
// the lifetime of the device.
|
||||
func (device *Device) BatchSize() int {
|
||||
size := device.net.bind.BatchSize()
|
||||
dSize := device.tun.device.BatchSize()
|
||||
if size < dSize {
|
||||
size = dSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
||||
device.peers.RLock()
|
||||
defer device.peers.RUnlock()
|
||||
@ -472,11 +488,13 @@ func (device *Device) BindUpdate() error {
|
||||
var err error
|
||||
var recvFns []conn.ReceiveFunc
|
||||
netc := &device.net
|
||||
|
||||
recvFns, netc.port, err = netc.bind.Open(netc.port)
|
||||
if err != nil {
|
||||
netc.port = 0
|
||||
return err
|
||||
}
|
||||
|
||||
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
||||
if err != nil {
|
||||
netc.bind.Close()
|
||||
@ -507,8 +525,9 @@ func (device *Device) BindUpdate() error {
|
||||
device.net.stopping.Add(len(recvFns))
|
||||
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
||||
batchSize := netc.bind.BatchSize()
|
||||
for _, fn := range recvFns {
|
||||
go device.RoutineReceiveIncoming(fn)
|
||||
go device.RoutineReceiveIncoming(batchSize, fn)
|
||||
}
|
||||
|
||||
device.log.Verbosef("UDP bind has been updated")
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sync"
|
||||
@ -21,6 +22,7 @@ import (
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/conn/bindtest"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||
)
|
||||
|
||||
@ -307,6 +309,17 @@ func TestConcurrencySafety(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Perform bind updates and keepalive sends concurrently with tunnel use.
|
||||
t.Run("bindUpdate and keepalive", func(t *testing.T) {
|
||||
const iters = 10
|
||||
for i := 0; i < iters; i++ {
|
||||
for _, peer := range pair {
|
||||
peer.dev.BindUpdate()
|
||||
peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
close(done)
|
||||
}
|
||||
|
||||
@ -405,3 +418,59 @@ func goroutineLeakCheck(t *testing.T) {
|
||||
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
|
||||
})
|
||||
}
|
||||
|
||||
type fakeBindSized struct {
|
||||
size int
|
||||
}
|
||||
|
||||
func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
func (b *fakeBindSized) Close() error { return nil }
|
||||
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
|
||||
func (b *fakeBindSized) Send(buffs [][]byte, ep conn.Endpoint) error { return nil }
|
||||
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
|
||||
func (b *fakeBindSized) BatchSize() int { return b.size }
|
||||
|
||||
type fakeTUNDeviceSized struct {
|
||||
size int
|
||||
}
|
||||
|
||||
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
|
||||
func (t *fakeTUNDeviceSized) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (t *fakeTUNDeviceSized) Write(buffs [][]byte, offset int) (int, error) { return 0, nil }
|
||||
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
|
||||
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
|
||||
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
|
||||
func (t *fakeTUNDeviceSized) Close() error { return nil }
|
||||
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
|
||||
|
||||
func TestBatchSize(t *testing.T) {
|
||||
d := Device{}
|
||||
|
||||
d.net.bind = &fakeBindSized{1}
|
||||
d.tun.device = &fakeTUNDeviceSized{1}
|
||||
if want, got := 1, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
|
||||
d.net.bind = &fakeBindSized{1}
|
||||
d.tun.device = &fakeTUNDeviceSized{128}
|
||||
if want, got := 128, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
|
||||
d.net.bind = &fakeBindSized{128}
|
||||
d.tun.device = &fakeTUNDeviceSized{1}
|
||||
if want, got := 128, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
|
||||
d.net.bind = &fakeBindSized{128}
|
||||
d.tun.device = &fakeTUNDeviceSized{128}
|
||||
if want, got := 128, d.BatchSize(); got != want {
|
||||
t.Errorf("expected batch size %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
@ -45,9 +45,9 @@ type Peer struct {
|
||||
}
|
||||
|
||||
queue struct {
|
||||
staged chan *QueueOutboundElement // staged packets before a handshake is available
|
||||
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
|
||||
inbound *autodrainingInboundQueue // sequential ordering of tun writing
|
||||
staged chan *[]*QueueOutboundElement // staged packets before a handshake is available
|
||||
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
|
||||
inbound *autodrainingInboundQueue // sequential ordering of tun writing
|
||||
}
|
||||
|
||||
cookieGenerator CookieGenerator
|
||||
@ -81,7 +81,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
peer.device = device
|
||||
peer.queue.outbound = newAutodrainingOutboundQueue(device)
|
||||
peer.queue.inbound = newAutodrainingInboundQueue(device)
|
||||
peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
|
||||
peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize)
|
||||
|
||||
// map public key
|
||||
_, ok := device.peers.keyMap[pk]
|
||||
@ -108,7 +108,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
func (peer *Peer) SendBuffer(buffer []byte) error {
|
||||
func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||
peer.device.net.RLock()
|
||||
defer peer.device.net.RUnlock()
|
||||
|
||||
@ -123,9 +123,13 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
|
||||
return errors.New("no known endpoint for peer")
|
||||
}
|
||||
|
||||
err := peer.device.net.bind.Send(buffer, peer.endpoint)
|
||||
err := peer.device.net.bind.Send(buffers, peer.endpoint)
|
||||
if err == nil {
|
||||
peer.txBytes.Add(uint64(len(buffer)))
|
||||
var totalLen uint64
|
||||
for _, b := range buffers {
|
||||
totalLen += uint64(len(b))
|
||||
}
|
||||
peer.txBytes.Add(totalLen)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -187,8 +191,12 @@ func (peer *Peer) Start() {
|
||||
|
||||
device.flushInboundQueue(peer.queue.inbound)
|
||||
device.flushOutboundQueue(peer.queue.outbound)
|
||||
go peer.RoutineSequentialSender()
|
||||
go peer.RoutineSequentialReceiver()
|
||||
|
||||
// Use the device batch size, not the bind batch size, as the device size is
|
||||
// the size of the batch pools.
|
||||
batchSize := peer.device.BatchSize()
|
||||
go peer.RoutineSequentialSender(batchSize)
|
||||
go peer.RoutineSequentialReceiver(batchSize)
|
||||
|
||||
peer.isRunning.Store(true)
|
||||
}
|
||||
|
@ -46,6 +46,14 @@ func (p *WaitPool) Put(x any) {
|
||||
}
|
||||
|
||||
func (device *Device) PopulatePools() {
|
||||
device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||
s := make([]*QueueOutboundElement, 0, device.BatchSize())
|
||||
return &s
|
||||
})
|
||||
device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||
s := make([]*QueueInboundElement, 0, device.BatchSize())
|
||||
return &s
|
||||
})
|
||||
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||
return new([MaxMessageSize]byte)
|
||||
})
|
||||
@ -57,6 +65,30 @@ func (device *Device) PopulatePools() {
|
||||
})
|
||||
}
|
||||
|
||||
func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement {
|
||||
return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement)
|
||||
}
|
||||
|
||||
func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) {
|
||||
for i := range *s {
|
||||
(*s)[i] = nil
|
||||
}
|
||||
*s = (*s)[:0]
|
||||
device.pool.outboundElementsSlice.Put(s)
|
||||
}
|
||||
|
||||
func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement {
|
||||
return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement)
|
||||
}
|
||||
|
||||
func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) {
|
||||
for i := range *s {
|
||||
(*s)[i] = nil
|
||||
}
|
||||
*s = (*s)[:0]
|
||||
device.pool.inboundElementsSlice.Put(s)
|
||||
}
|
||||
|
||||
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
||||
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
|
||||
}
|
||||
|
@ -89,3 +89,51 @@ func BenchmarkWaitPool(b *testing.B) {
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkWaitPoolEmpty(b *testing.B) {
|
||||
var wg sync.WaitGroup
|
||||
var trials atomic.Int32
|
||||
trials.Store(int32(b.N))
|
||||
workers := runtime.NumCPU() + 2
|
||||
if workers-4 <= 0 {
|
||||
b.Skip("Not enough cores")
|
||||
}
|
||||
p := NewWaitPool(0, func() any { return make([]byte, 16) })
|
||||
wg.Add(workers)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < workers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for trials.Add(-1) > 0 {
|
||||
x := p.Get()
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||
p.Put(x)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkSyncPool(b *testing.B) {
|
||||
var wg sync.WaitGroup
|
||||
var trials atomic.Int32
|
||||
trials.Store(int32(b.N))
|
||||
workers := runtime.NumCPU() + 2
|
||||
if workers-4 <= 0 {
|
||||
b.Skip("Not enough cores")
|
||||
}
|
||||
p := sync.Pool{New: func() any { return make([]byte, 16) }}
|
||||
wg.Add(workers)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < workers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for trials.Add(-1) > 0 {
|
||||
x := p.Get()
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||
p.Put(x)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
||||
* Every time the bind is updated a new routine is started for
|
||||
* IPv4 and IPv6 (separately)
|
||||
*/
|
||||
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
||||
func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
|
||||
recvName := recv.PrettyName()
|
||||
defer func() {
|
||||
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
|
||||
@ -79,20 +79,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
||||
|
||||
// receive datagrams until conn is closed
|
||||
|
||||
buffer := device.GetMessageBuffer()
|
||||
|
||||
var (
|
||||
buffsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
|
||||
buffs = make([][]byte, maxBatchSize)
|
||||
err error
|
||||
size int
|
||||
endpoint conn.Endpoint
|
||||
sizes = make([]int, maxBatchSize)
|
||||
count int
|
||||
endpoints = make([]conn.Endpoint, maxBatchSize)
|
||||
deathSpiral int
|
||||
elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
|
||||
)
|
||||
|
||||
for {
|
||||
size, endpoint, err = recv(buffer[:])
|
||||
for i := range buffsArrs {
|
||||
buffsArrs[i] = device.GetMessageBuffer()
|
||||
buffs[i] = buffsArrs[i][:]
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for i := 0; i < maxBatchSize; i++ {
|
||||
if buffsArrs[i] != nil {
|
||||
device.PutMessageBuffer(buffsArrs[i])
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
count, err = recv(buffs, sizes, endpoints)
|
||||
if err != nil {
|
||||
device.PutMessageBuffer(buffer)
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
@ -103,101 +116,122 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
||||
if deathSpiral < 10 {
|
||||
deathSpiral++
|
||||
time.Sleep(time.Second / 3)
|
||||
buffer = device.GetMessageBuffer()
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
deathSpiral = 0
|
||||
|
||||
if size < MinMessageSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// check size of packet
|
||||
|
||||
packet := buffer[:size]
|
||||
msgType := binary.LittleEndian.Uint32(packet[:4])
|
||||
|
||||
var okay bool
|
||||
|
||||
switch msgType {
|
||||
|
||||
// check if transport
|
||||
|
||||
case MessageTransportType:
|
||||
|
||||
// check size
|
||||
|
||||
if len(packet) < MessageTransportSize {
|
||||
// handle each packet in the batch
|
||||
for i, size := range sizes[:count] {
|
||||
if size < MinMessageSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// lookup key pair
|
||||
// check size of packet
|
||||
|
||||
receiver := binary.LittleEndian.Uint32(
|
||||
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
||||
)
|
||||
value := device.indexTable.Lookup(receiver)
|
||||
keypair := value.keypair
|
||||
if keypair == nil {
|
||||
packet := buffsArrs[i][:size]
|
||||
msgType := binary.LittleEndian.Uint32(packet[:4])
|
||||
|
||||
switch msgType {
|
||||
|
||||
// check if transport
|
||||
|
||||
case MessageTransportType:
|
||||
|
||||
// check size
|
||||
|
||||
if len(packet) < MessageTransportSize {
|
||||
continue
|
||||
}
|
||||
|
||||
// lookup key pair
|
||||
|
||||
receiver := binary.LittleEndian.Uint32(
|
||||
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
||||
)
|
||||
value := device.indexTable.Lookup(receiver)
|
||||
keypair := value.keypair
|
||||
if keypair == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// check keypair expiry
|
||||
|
||||
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
||||
continue
|
||||
}
|
||||
|
||||
// create work element
|
||||
peer := value.peer
|
||||
elem := device.GetInboundElement()
|
||||
elem.packet = packet
|
||||
elem.buffer = buffsArrs[i]
|
||||
elem.keypair = keypair
|
||||
elem.endpoint = endpoints[i]
|
||||
elem.counter = 0
|
||||
elem.Mutex = sync.Mutex{}
|
||||
elem.Lock()
|
||||
|
||||
elemsForPeer, ok := elemsByPeer[peer]
|
||||
if !ok {
|
||||
elemsForPeer = device.GetInboundElementsSlice()
|
||||
elemsByPeer[peer] = elemsForPeer
|
||||
}
|
||||
*elemsForPeer = append(*elemsForPeer, elem)
|
||||
buffsArrs[i] = device.GetMessageBuffer()
|
||||
buffs[i] = buffsArrs[i][:]
|
||||
continue
|
||||
|
||||
// otherwise it is a fixed size & handshake related packet
|
||||
|
||||
case MessageInitiationType:
|
||||
if len(packet) != MessageInitiationSize {
|
||||
continue
|
||||
}
|
||||
|
||||
case MessageResponseType:
|
||||
if len(packet) != MessageResponseSize {
|
||||
continue
|
||||
}
|
||||
|
||||
case MessageCookieReplyType:
|
||||
if len(packet) != MessageCookieReplySize {
|
||||
continue
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Received message with unknown type")
|
||||
continue
|
||||
}
|
||||
|
||||
// check keypair expiry
|
||||
|
||||
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
||||
continue
|
||||
}
|
||||
|
||||
// create work element
|
||||
peer := value.peer
|
||||
elem := device.GetInboundElement()
|
||||
elem.packet = packet
|
||||
elem.buffer = buffer
|
||||
elem.keypair = keypair
|
||||
elem.endpoint = endpoint
|
||||
elem.counter = 0
|
||||
elem.Mutex = sync.Mutex{}
|
||||
elem.Lock()
|
||||
|
||||
// add to decryption queues
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.inbound.c <- elem
|
||||
device.queue.decryption.c <- elem
|
||||
buffer = device.GetMessageBuffer()
|
||||
} else {
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
continue
|
||||
|
||||
// otherwise it is a fixed size & handshake related packet
|
||||
|
||||
case MessageInitiationType:
|
||||
okay = len(packet) == MessageInitiationSize
|
||||
|
||||
case MessageResponseType:
|
||||
okay = len(packet) == MessageResponseSize
|
||||
|
||||
case MessageCookieReplyType:
|
||||
okay = len(packet) == MessageCookieReplySize
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Received message with unknown type")
|
||||
}
|
||||
|
||||
if okay {
|
||||
select {
|
||||
case device.queue.handshake.c <- QueueHandshakeElement{
|
||||
msgType: msgType,
|
||||
buffer: buffer,
|
||||
buffer: buffsArrs[i],
|
||||
packet: packet,
|
||||
endpoint: endpoint,
|
||||
endpoint: endpoints[i],
|
||||
}:
|
||||
buffer = device.GetMessageBuffer()
|
||||
buffsArrs[i] = device.GetMessageBuffer()
|
||||
buffs[i] = buffsArrs[i][:]
|
||||
default:
|
||||
}
|
||||
}
|
||||
for peer, elems := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.inbound.c <- elems
|
||||
for _, elem := range *elems {
|
||||
device.queue.decryption.c <- elem
|
||||
}
|
||||
} else {
|
||||
for _, elem := range *elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
device.PutInboundElementsSlice(elems)
|
||||
}
|
||||
delete(elemsByPeer, peer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -393,7 +427,7 @@ func (device *Device) RoutineHandshake(id int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) RoutineSequentialReceiver() {
|
||||
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
||||
device := peer.device
|
||||
defer func() {
|
||||
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
|
||||
@ -401,89 +435,91 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
}()
|
||||
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
|
||||
|
||||
for elem := range peer.queue.inbound.c {
|
||||
if elem == nil {
|
||||
buffs := make([][]byte, 0, maxBatchSize)
|
||||
|
||||
for elems := range peer.queue.inbound.c {
|
||||
if elems == nil {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
elem.Lock()
|
||||
if elem.packet == nil {
|
||||
// decryption failed
|
||||
goto skip
|
||||
}
|
||||
|
||||
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
||||
goto skip
|
||||
}
|
||||
|
||||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||
peer.timersHandshakeComplete()
|
||||
peer.SendStagedPackets()
|
||||
}
|
||||
|
||||
peer.keepKeyFreshReceiving()
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketReceived()
|
||||
peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
|
||||
|
||||
if len(elem.packet) == 0 {
|
||||
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
||||
goto skip
|
||||
}
|
||||
peer.timersDataReceived()
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case ipv4.Version:
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
goto skip
|
||||
}
|
||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
||||
goto skip
|
||||
}
|
||||
elem.packet = elem.packet[:length]
|
||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||
if device.allowedips.Lookup(src) != peer {
|
||||
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
|
||||
goto skip
|
||||
for _, elem := range *elems {
|
||||
elem.Lock()
|
||||
if elem.packet == nil {
|
||||
// decryption failed
|
||||
continue
|
||||
}
|
||||
|
||||
case ipv6.Version:
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
goto skip
|
||||
}
|
||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
length += ipv6.HeaderLen
|
||||
if int(length) > len(elem.packet) {
|
||||
goto skip
|
||||
}
|
||||
elem.packet = elem.packet[:length]
|
||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||
if device.allowedips.Lookup(src) != peer {
|
||||
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
|
||||
goto skip
|
||||
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
||||
continue
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Packet with invalid IP version from %v", peer)
|
||||
goto skip
|
||||
}
|
||||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||
peer.timersHandshakeComplete()
|
||||
peer.SendStagedPackets()
|
||||
}
|
||||
peer.keepKeyFreshReceiving()
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketReceived()
|
||||
peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
|
||||
|
||||
_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
|
||||
if err != nil && !device.isClosed() {
|
||||
device.log.Errorf("Failed to write packet to TUN device: %v", err)
|
||||
if len(elem.packet) == 0 {
|
||||
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
||||
continue
|
||||
}
|
||||
peer.timersDataReceived()
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case 4:
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
elem.packet = elem.packet[:length]
|
||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||
if device.allowedips.Lookup(src) != peer {
|
||||
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
|
||||
continue
|
||||
}
|
||||
|
||||
case 6:
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
}
|
||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
length += ipv6.HeaderLen
|
||||
if int(length) > len(elem.packet) {
|
||||
continue
|
||||
}
|
||||
elem.packet = elem.packet[:length]
|
||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||
if device.allowedips.Lookup(src) != peer {
|
||||
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
|
||||
continue
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Packet with invalid IP version from %v", peer)
|
||||
continue
|
||||
}
|
||||
|
||||
buffs = append(buffs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
|
||||
}
|
||||
if len(peer.queue.inbound.c) == 0 {
|
||||
err = device.tun.device.Flush()
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("Unable to flush packets: %v", err)
|
||||
if len(buffs) > 0 {
|
||||
_, err := device.tun.device.Write(buffs, MessageTransportOffsetContent)
|
||||
if err != nil && !device.isClosed() {
|
||||
device.log.Errorf("Failed to write packets to TUN device: %v", err)
|
||||
}
|
||||
}
|
||||
skip:
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
for _, elem := range *elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutInboundElement(elem)
|
||||
}
|
||||
buffs = buffs[:0]
|
||||
device.PutInboundElementsSlice(elems)
|
||||
}
|
||||
}
|
||||
|
280
device/send.go
280
device/send.go
@ -17,6 +17,7 @@ import (
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
/* Outbound flow
|
||||
@ -77,12 +78,15 @@ func (elem *QueueOutboundElement) clearPointers() {
|
||||
func (peer *Peer) SendKeepalive() {
|
||||
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
|
||||
elem := peer.device.NewOutboundElement()
|
||||
elems := peer.device.GetOutboundElementsSlice()
|
||||
*elems = append(*elems, elem)
|
||||
select {
|
||||
case peer.queue.staged <- elem:
|
||||
case peer.queue.staged <- elems:
|
||||
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
|
||||
default:
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
peer.device.PutOutboundElementsSlice(elems)
|
||||
}
|
||||
}
|
||||
peer.SendStagedPackets()
|
||||
@ -125,7 +129,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
err = peer.SendBuffer(packet)
|
||||
err = peer.SendBuffers([][]byte{packet})
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
||||
}
|
||||
@ -163,7 +167,8 @@ func (peer *Peer) SendHandshakeResponse() error {
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
err = peer.SendBuffer(packet)
|
||||
// TODO: allocation could be avoided
|
||||
err = peer.SendBuffers([][]byte{packet})
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
||||
}
|
||||
@ -183,7 +188,8 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
|
||||
var buff [MessageCookieReplySize]byte
|
||||
writer := bytes.NewBuffer(buff[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
|
||||
// TODO: allocation could be avoided
|
||||
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -198,11 +204,6 @@ func (peer *Peer) keepKeyFreshSending() {
|
||||
}
|
||||
}
|
||||
|
||||
/* Reads packets from the TUN and inserts
|
||||
* into staged queue for peer
|
||||
*
|
||||
* Obs. Single instance per TUN device
|
||||
*/
|
||||
func (device *Device) RoutineReadFromTUN() {
|
||||
defer func() {
|
||||
device.log.Verbosef("Routine: TUN reader - stopped")
|
||||
@ -212,81 +213,123 @@ func (device *Device) RoutineReadFromTUN() {
|
||||
|
||||
device.log.Verbosef("Routine: TUN reader - started")
|
||||
|
||||
var elem *QueueOutboundElement
|
||||
var (
|
||||
batchSize = device.BatchSize()
|
||||
readErr error
|
||||
elems = make([]*QueueOutboundElement, batchSize)
|
||||
buffs = make([][]byte, batchSize)
|
||||
elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize)
|
||||
count = 0
|
||||
sizes = make([]int, batchSize)
|
||||
offset = MessageTransportHeaderSize
|
||||
)
|
||||
|
||||
for i := range elems {
|
||||
elems[i] = device.NewOutboundElement()
|
||||
buffs[i] = elems[i].buffer[:]
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for _, elem := range elems {
|
||||
if elem != nil {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if elem != nil {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
// read packets
|
||||
count, readErr = device.tun.device.Read(buffs, sizes, offset)
|
||||
for i := 0; i < count; i++ {
|
||||
if sizes[i] < 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
elem := elems[i]
|
||||
elem.packet = buffs[i][offset : offset+sizes[i]]
|
||||
|
||||
// lookup peer
|
||||
var peer *Peer
|
||||
switch elem.packet[0] >> 4 {
|
||||
case 4:
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||
peer = device.allowedips.Lookup(dst)
|
||||
|
||||
case 6:
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
}
|
||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||
peer = device.allowedips.Lookup(dst)
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Received packet with unknown IP version")
|
||||
}
|
||||
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
elemsForPeer, ok := elemsByPeer[peer]
|
||||
if !ok {
|
||||
elemsForPeer = device.GetOutboundElementsSlice()
|
||||
elemsByPeer[peer] = elemsForPeer
|
||||
}
|
||||
*elemsForPeer = append(*elemsForPeer, elem)
|
||||
elems[i] = device.NewOutboundElement()
|
||||
buffs[i] = elems[i].buffer[:]
|
||||
}
|
||||
elem = device.NewOutboundElement()
|
||||
|
||||
// read packet
|
||||
for peer, elemsForPeer := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.StagePackets(elemsForPeer)
|
||||
peer.SendStagedPackets()
|
||||
} else {
|
||||
for _, elem := range *elemsForPeer {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsSlice(elemsForPeer)
|
||||
}
|
||||
delete(elemsByPeer, peer)
|
||||
}
|
||||
|
||||
offset := MessageTransportHeaderSize
|
||||
size, err := device.tun.device.Read(elem.buffer[:], offset)
|
||||
if err != nil {
|
||||
if readErr != nil {
|
||||
if errors.Is(readErr, tun.ErrTooManySegments) {
|
||||
// TODO: record stat for this
|
||||
// This will happen if MSS is surprisingly small (< 576)
|
||||
// coincident with reasonably high throughput.
|
||||
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
|
||||
continue
|
||||
}
|
||||
if !device.isClosed() {
|
||||
if !errors.Is(err, os.ErrClosed) {
|
||||
device.log.Errorf("Failed to read packet from TUN device: %v", err)
|
||||
if !errors.Is(readErr, os.ErrClosed) {
|
||||
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
|
||||
}
|
||||
go device.Close()
|
||||
}
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
return
|
||||
}
|
||||
|
||||
if size == 0 || size > MaxContentSize {
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = elem.buffer[offset : offset+size]
|
||||
|
||||
// lookup peer
|
||||
|
||||
var peer *Peer
|
||||
switch elem.packet[0] >> 4 {
|
||||
case ipv4.Version:
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||
peer = device.allowedips.Lookup(dst)
|
||||
|
||||
case ipv6.Version:
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
}
|
||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||
peer = device.allowedips.Lookup(dst)
|
||||
|
||||
default:
|
||||
device.log.Verbosef("Received packet with unknown IP version")
|
||||
}
|
||||
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
if peer.isRunning.Load() {
|
||||
peer.StagePacket(elem)
|
||||
elem = nil
|
||||
peer.SendStagedPackets()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
|
||||
func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
|
||||
for {
|
||||
select {
|
||||
case peer.queue.staged <- elem:
|
||||
case peer.queue.staged <- elems:
|
||||
return
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case tooOld := <-peer.queue.staged:
|
||||
peer.device.PutMessageBuffer(tooOld.buffer)
|
||||
peer.device.PutOutboundElement(tooOld)
|
||||
for _, elem := range *tooOld {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
}
|
||||
peer.device.PutOutboundElementsSlice(tooOld)
|
||||
default:
|
||||
}
|
||||
}
|
||||
@ -305,26 +348,55 @@ top:
|
||||
}
|
||||
|
||||
for {
|
||||
var elemsOOO *[]*QueueOutboundElement
|
||||
select {
|
||||
case elem := <-peer.queue.staged:
|
||||
elem.peer = peer
|
||||
elem.nonce = keypair.sendNonce.Add(1) - 1
|
||||
if elem.nonce >= RejectAfterMessages {
|
||||
keypair.sendNonce.Store(RejectAfterMessages)
|
||||
peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
|
||||
case elems := <-peer.queue.staged:
|
||||
i := 0
|
||||
for _, elem := range *elems {
|
||||
elem.peer = peer
|
||||
elem.nonce = keypair.sendNonce.Add(1) - 1
|
||||
if elem.nonce >= RejectAfterMessages {
|
||||
keypair.sendNonce.Store(RejectAfterMessages)
|
||||
if elemsOOO == nil {
|
||||
elemsOOO = peer.device.GetOutboundElementsSlice()
|
||||
}
|
||||
*elemsOOO = append(*elemsOOO, elem)
|
||||
continue
|
||||
} else {
|
||||
(*elems)[i] = elem
|
||||
i++
|
||||
}
|
||||
|
||||
elem.keypair = keypair
|
||||
elem.Lock()
|
||||
}
|
||||
*elems = (*elems)[:i]
|
||||
|
||||
if elemsOOO != nil {
|
||||
peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans
|
||||
}
|
||||
|
||||
if len(*elems) == 0 {
|
||||
peer.device.PutOutboundElementsSlice(elems)
|
||||
goto top
|
||||
}
|
||||
|
||||
elem.keypair = keypair
|
||||
elem.Lock()
|
||||
|
||||
// add to parallel and sequential queue
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.outbound.c <- elem
|
||||
peer.device.queue.encryption.c <- elem
|
||||
peer.queue.outbound.c <- elems
|
||||
for _, elem := range *elems {
|
||||
peer.device.queue.encryption.c <- elem
|
||||
}
|
||||
} else {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
for _, elem := range *elems {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
}
|
||||
peer.device.PutOutboundElementsSlice(elems)
|
||||
}
|
||||
|
||||
if elemsOOO != nil {
|
||||
goto top
|
||||
}
|
||||
default:
|
||||
return
|
||||
@ -335,9 +407,12 @@ top:
|
||||
func (peer *Peer) FlushStagedPackets() {
|
||||
for {
|
||||
select {
|
||||
case elem := <-peer.queue.staged:
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
case elems := <-peer.queue.staged:
|
||||
for _, elem := range *elems {
|
||||
peer.device.PutMessageBuffer(elem.buffer)
|
||||
peer.device.PutOutboundElement(elem)
|
||||
}
|
||||
peer.device.PutOutboundElementsSlice(elems)
|
||||
default:
|
||||
return
|
||||
}
|
||||
@ -400,12 +475,7 @@ func (device *Device) RoutineEncryption(id int) {
|
||||
}
|
||||
}
|
||||
|
||||
/* Sequentially reads packets from queue and sends to endpoint
|
||||
*
|
||||
* Obs. Single instance per peer.
|
||||
* The routine terminates then the outbound queue is closed.
|
||||
*/
|
||||
func (peer *Peer) RoutineSequentialSender() {
|
||||
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
||||
device := peer.device
|
||||
defer func() {
|
||||
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
|
||||
@ -413,36 +483,50 @@ func (peer *Peer) RoutineSequentialSender() {
|
||||
}()
|
||||
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
|
||||
|
||||
for elem := range peer.queue.outbound.c {
|
||||
if elem == nil {
|
||||
buffs := make([][]byte, 0, maxBatchSize)
|
||||
|
||||
for elems := range peer.queue.outbound.c {
|
||||
buffs = buffs[:0]
|
||||
if elems == nil {
|
||||
return
|
||||
}
|
||||
elem.Lock()
|
||||
if !peer.isRunning.Load() {
|
||||
// peer has been stopped; return re-usable elems to the shared pool.
|
||||
// This is an optimization only. It is possible for the peer to be stopped
|
||||
// immediately after this check, in which case, elem will get processed.
|
||||
// The timers and SendBuffer code are resilient to a few stragglers.
|
||||
// The timers and SendBuffers code are resilient to a few stragglers.
|
||||
// TODO: rework peer shutdown order to ensure
|
||||
// that we never accidentally keep timers alive longer than necessary.
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
for _, elem := range *elems {
|
||||
elem.Lock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
continue
|
||||
}
|
||||
dataSent := false
|
||||
for _, elem := range *elems {
|
||||
elem.Lock()
|
||||
if len(elem.packet) != MessageKeepaliveSize {
|
||||
dataSent = true
|
||||
}
|
||||
buffs = append(buffs, elem.packet)
|
||||
}
|
||||
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
// send message and return buffer to pool
|
||||
|
||||
err := peer.SendBuffer(elem.packet)
|
||||
if len(elem.packet) != MessageKeepaliveSize {
|
||||
err := peer.SendBuffers(buffs)
|
||||
if dataSent {
|
||||
peer.timersDataSent()
|
||||
}
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
for _, elem := range *elems {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
device.PutOutboundElement(elem)
|
||||
}
|
||||
device.PutOutboundElementsSlice(elems)
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - Failed to send data packet: %v", peer, err)
|
||||
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
14
main.go
14
main.go
@ -13,8 +13,8 @@ import (
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
@ -111,7 +111,7 @@ func main() {
|
||||
|
||||
// open TUN device (or use supplied fd)
|
||||
|
||||
tun, err := func() (tun.Device, error) {
|
||||
tdev, err := func() (tun.Device, error) {
|
||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||
if tunFdStr == "" {
|
||||
return tun.CreateTUN(interfaceName, device.DefaultMTU)
|
||||
@ -124,7 +124,7 @@ func main() {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = syscall.SetNonblock(int(fd), true)
|
||||
err = unix.SetNonblock(int(fd), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -134,7 +134,7 @@ func main() {
|
||||
}()
|
||||
|
||||
if err == nil {
|
||||
realInterfaceName, err2 := tun.Name()
|
||||
realInterfaceName, err2 := tdev.Name()
|
||||
if err2 == nil {
|
||||
interfaceName = realInterfaceName
|
||||
}
|
||||
@ -196,7 +196,7 @@ func main() {
|
||||
files[0], // stdin
|
||||
files[1], // stdout
|
||||
files[2], // stderr
|
||||
tun.File(),
|
||||
tdev.File(),
|
||||
fileUAPI,
|
||||
},
|
||||
Dir: ".",
|
||||
@ -222,7 +222,7 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
|
||||
device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
|
||||
|
||||
logger.Verbosef("Device started")
|
||||
|
||||
@ -250,7 +250,7 @@ func main() {
|
||||
|
||||
// wait for program to terminate
|
||||
|
||||
signal.Notify(term, syscall.SIGTERM)
|
||||
signal.Notify(term, unix.SIGTERM)
|
||||
signal.Notify(term, os.Interrupt)
|
||||
|
||||
select {
|
||||
|
@ -9,7 +9,8 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
@ -81,7 +82,7 @@ func main() {
|
||||
|
||||
signal.Notify(term, os.Interrupt)
|
||||
signal.Notify(term, os.Kill)
|
||||
signal.Notify(term, syscall.SIGTERM)
|
||||
signal.Notify(term, windows.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-term:
|
||||
|
60
tun/errors.go
Normal file
60
tun/errors.go
Normal file
@ -0,0 +1,60 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrTooManySegments is returned by Device.Read() when segmentation
|
||||
// overflows the length of supplied buffers. This error should not cause
|
||||
// reads to cease.
|
||||
ErrTooManySegments = errors.New("too many segments")
|
||||
)
|
||||
|
||||
type errorBatch []error
|
||||
|
||||
// ErrorBatch takes a possibly nil or empty list of errors, and if the list is
|
||||
// non-nil returns an error type that wraps all of the errors. Expected usage is
|
||||
// to append to an []errors and coerce the set to an error using this method.
|
||||
func ErrorBatch(errs []error) error {
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return errorBatch(errs)
|
||||
}
|
||||
|
||||
func (e errorBatch) Error() string {
|
||||
if len(e) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(e) == 1 {
|
||||
return e[0].Error()
|
||||
}
|
||||
return fmt.Sprintf("batch operation: %v (and %d more errors)", e[0], len(e)-1)
|
||||
}
|
||||
|
||||
func (e errorBatch) Is(target error) bool {
|
||||
for _, err := range e {
|
||||
if errors.Is(err, target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e errorBatch) As(target interface{}) bool {
|
||||
for _, err := range e {
|
||||
if errors.As(err, target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e errorBatch) Unwrap() error {
|
||||
if len(e) == 0 {
|
||||
return nil
|
||||
}
|
||||
return e[0]
|
||||
}
|
@ -19,6 +19,7 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@ -113,29 +114,37 @@ func (tun *netTun) Events() <-chan tun.Event {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
func (tun *netTun) Read(buf []byte, offset int) (int, error) {
|
||||
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
|
||||
view, ok := <-tun.incomingPacket
|
||||
if !ok {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
return view.Read(buf[offset:])
|
||||
n, err := view.Read(buf[0][offset:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (tun *netTun) Write(buf []byte, offset int) (int, error) {
|
||||
packet := buf[offset:]
|
||||
if len(packet) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
|
||||
for _, buf := range buf {
|
||||
packet := buf[offset:]
|
||||
if len(packet) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
|
||||
switch packet[0] >> 4 {
|
||||
case 4:
|
||||
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||
case 6:
|
||||
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
|
||||
switch packet[0] >> 4 {
|
||||
case 4:
|
||||
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||
case 6:
|
||||
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
|
||||
default:
|
||||
return 0, syscall.EAFNOSUPPORT
|
||||
}
|
||||
}
|
||||
|
||||
return len(buf), nil
|
||||
}
|
||||
|
||||
@ -151,10 +160,6 @@ func (tun *netTun) WriteNotify() {
|
||||
tun.incomingPacket <- view
|
||||
}
|
||||
|
||||
func (tun *netTun) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *netTun) Close() error {
|
||||
tun.stack.RemoveNIC(1)
|
||||
|
||||
@ -175,6 +180,10 @@ func (tun *netTun) MTU() (int, error) {
|
||||
return tun.mtu, nil
|
||||
}
|
||||
|
||||
func (tun *netTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
||||
var protoNumber tcpip.NetworkProtocolNumber
|
||||
if endpoint.Addr().Is4() {
|
||||
|
40
tun/tun.go
40
tun/tun.go
@ -18,12 +18,36 @@ const (
|
||||
)
|
||||
|
||||
type Device interface {
|
||||
File() *os.File // returns the file descriptor of the device
|
||||
Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
|
||||
Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
|
||||
Flush() error // flush all previous writes to the device
|
||||
MTU() (int, error) // returns the MTU of the device
|
||||
Name() (string, error) // fetches and returns the current name
|
||||
Events() <-chan Event // returns a constant channel of events related to the device
|
||||
Close() error // stops the device and closes the event channel
|
||||
// File returns the file descriptor of the device.
|
||||
File() *os.File
|
||||
|
||||
// Read one or more packets from the Device (without any additional headers).
|
||||
// On a successful read it returns the number of packets read, and sets
|
||||
// packet lengths within the sizes slice. len(sizes) must be >= len(buffs).
|
||||
// A nonzero offset can be used to instruct the Device on where to begin
|
||||
// reading into each element of the buffs slice.
|
||||
Read(buffs [][]byte, sizes []int, offset int) (n int, err error)
|
||||
|
||||
// Write one or more packets to the device (without any additional headers).
|
||||
// On a successful write it returns the number of packets written. A nonzero
|
||||
// offset can be used to instruct the Device on where to begin writing from
|
||||
// each packet contained within the buffs slice.
|
||||
Write(buffs [][]byte, offset int) (int, error)
|
||||
|
||||
// MTU returns the MTU of the Device.
|
||||
MTU() (int, error)
|
||||
|
||||
// Name returns the current name of the Device.
|
||||
Name() (string, error)
|
||||
|
||||
// Events returns a channel of type Event, which is fed Device events.
|
||||
Events() <-chan Event
|
||||
|
||||
// Close stops the Device and closes the Event channel.
|
||||
Close() error
|
||||
|
||||
// BatchSize returns the preferred/max number of packets that can be read or
|
||||
// written in a single read/write call. BatchSize must not change over the
|
||||
// lifetime of a Device.
|
||||
BatchSize() int
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ package tun
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
@ -15,7 +16,6 @@ import (
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
@ -33,7 +33,7 @@ type NativeTun struct {
|
||||
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
|
||||
for i := 0; i < 20; i++ {
|
||||
iface, err = net.InterfaceByIndex(index)
|
||||
if err != nil && errors.Is(err, syscall.ENOMEM) {
|
||||
if err != nil && errors.Is(err, unix.ENOMEM) {
|
||||
time.Sleep(time.Duration(i) * time.Second / 3)
|
||||
continue
|
||||
}
|
||||
@ -55,7 +55,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
||||
retry:
|
||||
n, err := unix.Read(tun.routeSocket, data)
|
||||
if err != nil {
|
||||
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
|
||||
if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR {
|
||||
goto retry
|
||||
}
|
||||
tun.errors <- err
|
||||
@ -217,45 +217,46 @@ func (tun *NativeTun) Events() <-chan Event {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||
func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
|
||||
// TODO: the BSDs look very similar in Read() and Write(). They should be
|
||||
// collapsed, with platform-specific files containing the varying parts of
|
||||
// their implementations.
|
||||
select {
|
||||
case err := <-tun.errors:
|
||||
return 0, err
|
||||
default:
|
||||
buff := buff[offset-4:]
|
||||
buff := buffs[0][offset-4:]
|
||||
n, err := tun.tunFile.Read(buff[:])
|
||||
if n < 4 {
|
||||
return 0, err
|
||||
}
|
||||
return n - 4, err
|
||||
sizes[0] = n - 4
|
||||
return 1, err
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
// reserve space for header
|
||||
|
||||
buff = buff[offset-4:]
|
||||
|
||||
// add packet information header
|
||||
|
||||
buff[0] = 0x00
|
||||
buff[1] = 0x00
|
||||
buff[2] = 0x00
|
||||
|
||||
if buff[4]>>4 == ipv6.Version {
|
||||
buff[3] = unix.AF_INET6
|
||||
} else {
|
||||
buff[3] = unix.AF_INET
|
||||
func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
|
||||
if offset < 4 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
// write
|
||||
|
||||
return tun.tunFile.Write(buff)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
for i, buf := range buffs {
|
||||
buf = buf[offset-4:]
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
buf[2] = 0x00
|
||||
switch buf[4] >> 4 {
|
||||
case 4:
|
||||
buf[3] = unix.AF_INET
|
||||
case 6:
|
||||
buf[3] = unix.AF_INET6
|
||||
default:
|
||||
return i, unix.EAFNOSUPPORT
|
||||
}
|
||||
if _, err := tun.tunFile.Write(buf); err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(buffs), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
@ -318,6 +319,10 @@ func (tun *NativeTun) MTU() (int, error) {
|
||||
return int(ifr.MTU), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func socketCloexec(family, sotype, proto int) (fd int, err error) {
|
||||
// See go/src/net/sys_cloexec.go for background.
|
||||
syscall.ForkLock.RLock()
|
||||
|
@ -333,45 +333,46 @@ func (tun *NativeTun) Events() <-chan Event {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||
func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
|
||||
select {
|
||||
case err := <-tun.errors:
|
||||
return 0, err
|
||||
default:
|
||||
buff := buff[offset-4:]
|
||||
buff := buffs[0][offset-4:]
|
||||
n, err := tun.tunFile.Read(buff[:])
|
||||
if n < 4 {
|
||||
return 0, err
|
||||
}
|
||||
return n - 4, err
|
||||
sizes[0] = n - 4
|
||||
return 1, err
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
|
||||
func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
|
||||
if offset < 4 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
buf = buf[offset-4:]
|
||||
if len(buf) < 5 {
|
||||
return 0, io.ErrShortBuffer
|
||||
for i, buf := range buffs {
|
||||
buf = buf[offset-4:]
|
||||
if len(buf) < 5 {
|
||||
return i, io.ErrShortBuffer
|
||||
}
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
buf[2] = 0x00
|
||||
switch buf[4] >> 4 {
|
||||
case 4:
|
||||
buf[3] = unix.AF_INET
|
||||
case 6:
|
||||
buf[3] = unix.AF_INET6
|
||||
default:
|
||||
return i, unix.EAFNOSUPPORT
|
||||
}
|
||||
if _, err := tun.tunFile.Write(buf); err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
buf[2] = 0x00
|
||||
switch buf[4] >> 4 {
|
||||
case 4:
|
||||
buf[3] = unix.AF_INET
|
||||
case 6:
|
||||
buf[3] = unix.AF_INET6
|
||||
default:
|
||||
return 0, unix.EAFNOSUPPORT
|
||||
}
|
||||
return tun.tunFile.Write(buf)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
return len(buffs), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
@ -428,3 +429,7 @@ func (tun *NativeTun) MTU() (int, error) {
|
||||
}
|
||||
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
@ -323,12 +323,13 @@ func (tun *NativeTun) nameSlow() (string, error) {
|
||||
return unix.ByteSliceToString(ifr[:]), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
|
||||
func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) {
|
||||
var buf []byte
|
||||
if tun.nopi {
|
||||
buf = buf[offset:]
|
||||
buf = buffs[0][offset:]
|
||||
} else {
|
||||
// reserve space for header
|
||||
buf = buf[offset-4:]
|
||||
buf = buffs[0][offset-4:]
|
||||
|
||||
// add packet information header
|
||||
buf[0] = 0x00
|
||||
@ -342,34 +343,36 @@ func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
n, err := tun.tunFile.Write(buf)
|
||||
_, err = tun.tunFile.Write(buf)
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
} else if err == nil {
|
||||
n = 1
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
|
||||
func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
select {
|
||||
case err = <-tun.errors:
|
||||
default:
|
||||
if tun.nopi {
|
||||
n, err = tun.tunFile.Read(buf[offset:])
|
||||
sizes[0], err = tun.tunFile.Read(buffs[0][offset:])
|
||||
if err == nil {
|
||||
n = 1
|
||||
}
|
||||
} else {
|
||||
buff := buf[offset-4:]
|
||||
n, err = tun.tunFile.Read(buff[:])
|
||||
buff := buffs[0][offset-4:]
|
||||
sizes[0], err = tun.tunFile.Read(buff[:])
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
} else if err == nil {
|
||||
n = 1
|
||||
}
|
||||
if n < 4 {
|
||||
n = 0
|
||||
if sizes[0] < 4 {
|
||||
sizes[0] = 0
|
||||
} else {
|
||||
n -= 4
|
||||
sizes[0] -= 4
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -399,6 +402,10 @@ func (tun *NativeTun) Close() error {
|
||||
return err2
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func CreateTUN(name string, mtu int) (Device, error) {
|
||||
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
|
||||
if err != nil {
|
||||
|
@ -8,13 +8,13 @@ package tun
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
@ -204,45 +204,43 @@ func (tun *NativeTun) Events() <-chan Event {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||
func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
|
||||
select {
|
||||
case err := <-tun.errors:
|
||||
return 0, err
|
||||
default:
|
||||
buff := buff[offset-4:]
|
||||
buff := buffs[0][offset-4:]
|
||||
n, err := tun.tunFile.Read(buff[:])
|
||||
if n < 4 {
|
||||
return 0, err
|
||||
}
|
||||
return n - 4, err
|
||||
sizes[0] = n - 4
|
||||
return 1, err
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
// reserve space for header
|
||||
|
||||
buff = buff[offset-4:]
|
||||
|
||||
// add packet information header
|
||||
|
||||
buff[0] = 0x00
|
||||
buff[1] = 0x00
|
||||
buff[2] = 0x00
|
||||
|
||||
if buff[4]>>4 == ipv6.Version {
|
||||
buff[3] = unix.AF_INET6
|
||||
} else {
|
||||
buff[3] = unix.AF_INET
|
||||
func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
|
||||
if offset < 4 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
// write
|
||||
|
||||
return tun.tunFile.Write(buff)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
for i, buf := range buffs {
|
||||
buf = buf[offset-4:]
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
buf[2] = 0x00
|
||||
switch buf[4] >> 4 {
|
||||
case 4:
|
||||
buf[3] = unix.AF_INET
|
||||
case 6:
|
||||
buf[3] = unix.AF_INET6
|
||||
default:
|
||||
return i, unix.EAFNOSUPPORT
|
||||
}
|
||||
if _, err := tun.tunFile.Write(buf); err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(buffs), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
@ -329,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) {
|
||||
|
||||
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
_ "unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"golang.zx2c4.com/wintun"
|
||||
)
|
||||
|
||||
@ -44,6 +43,7 @@ type NativeTun struct {
|
||||
closeOnce sync.Once
|
||||
close atomic.Bool
|
||||
forcedMTU int
|
||||
outSizes []int
|
||||
}
|
||||
|
||||
var (
|
||||
@ -134,9 +134,14 @@ func (tun *NativeTun) ForceMTU(mtu int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
// TODO: implement batching with wintun
|
||||
return 1
|
||||
}
|
||||
|
||||
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
|
||||
|
||||
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||
func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
|
||||
tun.running.Add(1)
|
||||
defer tun.running.Done()
|
||||
retry:
|
||||
@ -153,10 +158,11 @@ retry:
|
||||
switch err {
|
||||
case nil:
|
||||
packetSize := len(packet)
|
||||
copy(buff[offset:], packet)
|
||||
copy(buffs[0][offset:], packet)
|
||||
sizes[0] = packetSize
|
||||
tun.session.ReleaseReceivePacket(packet)
|
||||
tun.rate.update(uint64(packetSize))
|
||||
return packetSize, nil
|
||||
return 1, nil
|
||||
case windows.ERROR_NO_MORE_ITEMS:
|
||||
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
|
||||
windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
|
||||
@ -173,33 +179,33 @@ retry:
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
|
||||
tun.running.Add(1)
|
||||
defer tun.running.Done()
|
||||
if tun.close.Load() {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
packetSize := len(buff) - offset
|
||||
tun.rate.update(uint64(packetSize))
|
||||
for i, buff := range buffs {
|
||||
packetSize := len(buff) - offset
|
||||
tun.rate.update(uint64(packetSize))
|
||||
|
||||
packet, err := tun.session.AllocateSendPacket(packetSize)
|
||||
if err == nil {
|
||||
copy(packet, buff[offset:])
|
||||
tun.session.SendPacket(packet)
|
||||
return packetSize, nil
|
||||
packet, err := tun.session.AllocateSendPacket(packetSize)
|
||||
switch err {
|
||||
case nil:
|
||||
// TODO: Explore options to eliminate this copy.
|
||||
copy(packet, buff[offset:])
|
||||
tun.session.SendPacket(packet)
|
||||
continue
|
||||
case windows.ERROR_HANDLE_EOF:
|
||||
return i, os.ErrClosed
|
||||
case windows.ERROR_BUFFER_OVERFLOW:
|
||||
continue // Dropping when ring is full.
|
||||
default:
|
||||
return i, fmt.Errorf("Write failed: %w", err)
|
||||
}
|
||||
}
|
||||
switch err {
|
||||
case windows.ERROR_HANDLE_EOF:
|
||||
return 0, os.ErrClosed
|
||||
case windows.ERROR_BUFFER_OVERFLOW:
|
||||
return 0, nil // Dropping when ring is full.
|
||||
}
|
||||
return 0, fmt.Errorf("Write failed: %w", err)
|
||||
return len(buffs), nil
|
||||
}
|
||||
|
||||
// LUID returns Windows interface instance ID.
|
||||
|
@ -110,35 +110,42 @@ type chTun struct {
|
||||
|
||||
func (t *chTun) File() *os.File { return nil }
|
||||
|
||||
func (t *chTun) Read(data []byte, offset int) (int, error) {
|
||||
func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) {
|
||||
select {
|
||||
case <-t.c.closed:
|
||||
return 0, os.ErrClosed
|
||||
case msg := <-t.c.Outbound:
|
||||
return copy(data[offset:], msg), nil
|
||||
n := copy(packets[0][offset:], msg)
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Write is called by the wireguard device to deliver a packet for routing.
|
||||
func (t *chTun) Write(data []byte, offset int) (int, error) {
|
||||
func (t *chTun) Write(packets [][]byte, offset int) (int, error) {
|
||||
if offset == -1 {
|
||||
close(t.c.closed)
|
||||
close(t.c.events)
|
||||
return 0, io.EOF
|
||||
}
|
||||
msg := make([]byte, len(data)-offset)
|
||||
copy(msg, data[offset:])
|
||||
select {
|
||||
case <-t.c.closed:
|
||||
return 0, os.ErrClosed
|
||||
case t.c.Inbound <- msg:
|
||||
return len(data) - offset, nil
|
||||
for i, data := range packets {
|
||||
msg := make([]byte, len(data)-offset)
|
||||
copy(msg, data[offset:])
|
||||
select {
|
||||
case <-t.c.closed:
|
||||
return i, os.ErrClosed
|
||||
case t.c.Inbound <- msg:
|
||||
}
|
||||
}
|
||||
return len(packets), nil
|
||||
}
|
||||
|
||||
func (t *chTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
const DefaultMTU = 1420
|
||||
|
||||
func (t *chTun) Flush() error { return nil }
|
||||
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
|
||||
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
|
||||
func (t *chTun) Events() <-chan tun.Event { return t.c.events }
|
||||
|
Loading…
Reference in New Issue
Block a user