mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18 comes out with support for it in the "net" package. Also, allowedips still uses slices internally, which might be suboptimal. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
de7c702ace
commit
ef8d6804d7
@ -14,6 +14,7 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ipv4Source struct {
|
type ipv4Source struct {
|
||||||
@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil)
|
|||||||
|
|
||||||
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
|
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||||
var end LinuxSocketEndpoint
|
var end LinuxSocketEndpoint
|
||||||
addr, err := parseEndpoint(s)
|
e, err := netip.ParseAddrPort(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ipv4 := addr.IP.To4()
|
if e.Addr().Is4() {
|
||||||
if ipv4 != nil {
|
|
||||||
dst := end.dst4()
|
dst := end.dst4()
|
||||||
end.isV6 = false
|
end.isV6 = false
|
||||||
dst.Port = addr.Port
|
dst.Port = int(e.Port())
|
||||||
copy(dst.Addr[:], ipv4)
|
dst.Addr = e.Addr().As4()
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
return &end, nil
|
return &end, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipv6 := addr.IP.To16()
|
if e.Addr().Is6() {
|
||||||
if ipv6 != nil {
|
zone, err := zoneToUint32(e.Addr().Zone())
|
||||||
zone, err := zoneToUint32(addr.Zone)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dst := end.dst6()
|
dst := end.dst6()
|
||||||
end.isV6 = true
|
end.isV6 = true
|
||||||
dst.Port = addr.Port
|
dst.Port = int(e.Port())
|
||||||
dst.ZoneId = zone
|
dst.ZoneId = zone
|
||||||
copy(dst.Addr[:], ipv6[:])
|
dst.Addr = e.Addr().As16()
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
return &end, nil
|
return &end, nil
|
||||||
}
|
}
|
||||||
@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (end *LinuxSocketEndpoint) SrcIP() net.IP {
|
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
|
||||||
if !end.isV6 {
|
if !end.isV6 {
|
||||||
return net.IPv4(
|
return netip.AddrFrom4(end.src4().Src)
|
||||||
end.src4().Src[0],
|
|
||||||
end.src4().Src[1],
|
|
||||||
end.src4().Src[2],
|
|
||||||
end.src4().Src[3],
|
|
||||||
)
|
|
||||||
} else {
|
} else {
|
||||||
return end.src6().src[:]
|
return netip.AddrFrom16(end.src6().src)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (end *LinuxSocketEndpoint) DstIP() net.IP {
|
func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
|
||||||
if !end.isV6 {
|
if !end.isV6 {
|
||||||
return net.IPv4(
|
return netip.AddrFrom4(end.dst4().Addr)
|
||||||
end.dst4().Addr[0],
|
|
||||||
end.dst4().Addr[1],
|
|
||||||
end.dst4().Addr[2],
|
|
||||||
end.dst4().Addr[3],
|
|
||||||
)
|
|
||||||
} else {
|
} else {
|
||||||
return end.dst6().Addr[:]
|
return netip.AddrFrom16(end.dst6().Addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (end *LinuxSocketEndpoint) DstToString() string {
|
func (end *LinuxSocketEndpoint) DstToString() string {
|
||||||
var udpAddr net.UDPAddr
|
var port int
|
||||||
udpAddr.IP = end.DstIP()
|
|
||||||
if !end.isV6 {
|
if !end.isV6 {
|
||||||
udpAddr.Port = end.dst4().Port
|
port = end.dst4().Port
|
||||||
} else {
|
} else {
|
||||||
udpAddr.Port = end.dst6().Port
|
port = end.dst6().Port
|
||||||
}
|
}
|
||||||
return udpAddr.String()
|
return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (end *LinuxSocketEndpoint) ClearDst() {
|
func (end *LinuxSocketEndpoint) ClearDst() {
|
||||||
|
@ -10,6 +10,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StdNetBind is meant to be a temporary solution on platforms for which
|
// StdNetBind is meant to be a temporary solution on platforms for which
|
||||||
@ -32,18 +34,23 @@ var _ Bind = (*StdNetBind)(nil)
|
|||||||
var _ Endpoint = (*StdNetEndpoint)(nil)
|
var _ Endpoint = (*StdNetEndpoint)(nil)
|
||||||
|
|
||||||
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||||
addr, err := parseEndpoint(s)
|
e, err := netip.ParseAddrPort(s)
|
||||||
return (*StdNetEndpoint)(addr), err
|
return (*StdNetEndpoint)(&net.UDPAddr{
|
||||||
|
IP: e.Addr().AsSlice(),
|
||||||
|
Port: int(e.Port()),
|
||||||
|
Zone: e.Addr().Zone(),
|
||||||
|
}), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*StdNetEndpoint) ClearSrc() {}
|
func (*StdNetEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
func (e *StdNetEndpoint) DstIP() net.IP {
|
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
||||||
return (*net.UDPAddr)(e).IP
|
a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP)
|
||||||
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcIP() net.IP {
|
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||||
return nil // not supported
|
return netip.Addr{} // not supported
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *StdNetEndpoint) DstToBytes() []byte {
|
func (e *StdNetEndpoint) DstToBytes() []byte {
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn/winrio"
|
"golang.zx2c4.com/wireguard/conn/winrio"
|
||||||
)
|
)
|
||||||
@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
|
|||||||
|
|
||||||
func (*WinRingEndpoint) ClearSrc() {}
|
func (*WinRingEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
func (e *WinRingEndpoint) DstIP() net.IP {
|
func (e *WinRingEndpoint) DstIP() netip.Addr {
|
||||||
switch e.family {
|
switch e.family {
|
||||||
case windows.AF_INET:
|
case windows.AF_INET:
|
||||||
return append([]byte{}, e.data[2:6]...)
|
return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
|
||||||
case windows.AF_INET6:
|
case windows.AF_INET6:
|
||||||
return append([]byte{}, e.data[6:22]...)
|
return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
|
||||||
}
|
}
|
||||||
return nil
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *WinRingEndpoint) SrcIP() net.IP {
|
func (e *WinRingEndpoint) SrcIP() netip.Addr {
|
||||||
return nil // not supported
|
return netip.Addr{} // not supported
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *WinRingEndpoint) DstToBytes() []byte {
|
func (e *WinRingEndpoint) DstToBytes() []byte {
|
||||||
@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
|
|||||||
func (e *WinRingEndpoint) DstToString() string {
|
func (e *WinRingEndpoint) DstToString() string {
|
||||||
switch e.family {
|
switch e.family {
|
||||||
case windows.AF_INET:
|
case windows.AF_INET:
|
||||||
addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
|
netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
|
||||||
return addr.String()
|
|
||||||
case windows.AF_INET6:
|
case windows.AF_INET6:
|
||||||
var zone string
|
var zone string
|
||||||
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
|
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
|
||||||
zone = strconv.FormatUint(uint64(scope), 10)
|
zone = strconv.FormatUint(uint64(scope), 10)
|
||||||
}
|
}
|
||||||
addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
|
return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
|
||||||
return addr.String()
|
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -10,8 +10,8 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
|
|||||||
|
|
||||||
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
|
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
|
||||||
|
|
||||||
func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
|
func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
|
||||||
|
|
||||||
func (c ChannelEndpoint) SrcIP() net.IP { return nil }
|
func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
|
||||||
|
|
||||||
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||||||
c.closeSignal = make(chan bool)
|
c.closeSignal = make(chan bool)
|
||||||
@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||||
_, port, err := net.SplitHostPort(s)
|
addr, err := netip.ParseAddrPort(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
i, err := strconv.ParseUint(port, 10, 16)
|
return ChannelEndpoint(addr.Port()), nil
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ChannelEndpoint(i), nil
|
|
||||||
}
|
}
|
||||||
|
37
conn/conn.go
37
conn/conn.go
@ -9,10 +9,11 @@ package conn
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A ReceiveFunc receives a single inbound packet from the network.
|
// A ReceiveFunc receives a single inbound packet from the network.
|
||||||
@ -68,8 +69,8 @@ type Endpoint interface {
|
|||||||
SrcToString() string // returns the local source address (ip:port)
|
SrcToString() string // returns the local source address (ip:port)
|
||||||
DstToString() string // returns the destination address (ip:port)
|
DstToString() string // returns the destination address (ip:port)
|
||||||
DstToBytes() []byte // used for mac2 cookie calculations
|
DstToBytes() []byte // used for mac2 cookie calculations
|
||||||
DstIP() net.IP
|
DstIP() netip.Addr
|
||||||
SrcIP() net.IP
|
SrcIP() netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string {
|
|||||||
}
|
}
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
|
||||||
// ensure that the host is an IP address
|
|
||||||
|
|
||||||
host, _, err := net.SplitHostPort(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
|
|
||||||
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
|
|
||||||
// trying to make sure with a small sanity test that this is a real IP address and
|
|
||||||
// not something that's likely to incur DNS lookups.
|
|
||||||
host = host[:i]
|
|
||||||
}
|
|
||||||
if ip := net.ParseIP(host); ip == nil {
|
|
||||||
return nil, errors.New("Failed to parse IP address: " + host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse address and port
|
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ip4 := addr.IP.To4()
|
|
||||||
if ip4 != nil {
|
|
||||||
addr.IP = ip4
|
|
||||||
}
|
|
||||||
return addr, err
|
|
||||||
}
|
|
||||||
|
@ -12,6 +12,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type parentIndirection struct {
|
type parentIndirection struct {
|
||||||
@ -26,7 +28,7 @@ type trieEntry struct {
|
|||||||
cidr uint8
|
cidr uint8
|
||||||
bitAtByte uint8
|
bitAtByte uint8
|
||||||
bitAtShift uint8
|
bitAtShift uint8
|
||||||
bits net.IP
|
bits []byte
|
||||||
perPeerElem *list.Element
|
perPeerElem *list.Element
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 {
|
|||||||
return bits.ReverseBytes64(i)
|
return bits.ReverseBytes64(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
|
func commonBits(ip1, ip2 []byte) uint8 {
|
||||||
size := len(ip1)
|
size := len(ip1)
|
||||||
if size == net.IPv4len {
|
if size == net.IPv4len {
|
||||||
a := (*uint32)(unsafe.Pointer(&ip1[0]))
|
a := (*uint32)(unsafe.Pointer(&ip1[0]))
|
||||||
@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) choose(ip net.IP) byte {
|
func (node *trieEntry) choose(ip []byte) byte {
|
||||||
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
|
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() {
|
|||||||
node.parent.parentBit = nil
|
node.parent.parentBit = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
|
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
|
||||||
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
|
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
|
||||||
parent = node
|
parent = node
|
||||||
if parent.cidr == cidr {
|
if parent.cidr == cidr {
|
||||||
@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
|
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
|
||||||
if *trie.parentBit == nil {
|
if *trie.parentBit == nil {
|
||||||
node := &trieEntry{
|
node := &trieEntry{
|
||||||
peer: peer,
|
peer: peer,
|
||||||
@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
func (node *trieEntry) lookup(ip []byte) *Peer {
|
||||||
var found *Peer
|
var found *Peer
|
||||||
size := uint8(len(ip))
|
size := uint8(len(ip))
|
||||||
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
||||||
@ -229,13 +231,14 @@ type AllowedIPs struct {
|
|||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
|
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
|
||||||
table.mutex.RLock()
|
table.mutex.RLock()
|
||||||
defer table.mutex.RUnlock()
|
defer table.mutex.RUnlock()
|
||||||
|
|
||||||
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
|
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
|
||||||
node := elem.Value.(*trieEntry)
|
node := elem.Value.(*trieEntry)
|
||||||
if !cb(node.bits, node.cidr) {
|
a, _ := netip.AddrFromSlice(node.bits)
|
||||||
|
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
|
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
defer table.mutex.Unlock()
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
switch len(ip) {
|
if prefix.Addr().Is6() {
|
||||||
case net.IPv6len:
|
ip := prefix.Addr().As16()
|
||||||
parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
|
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||||
case net.IPv4len:
|
} else if prefix.Addr().Is4() {
|
||||||
parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
|
ip := prefix.Addr().As4()
|
||||||
default:
|
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||||
|
} else {
|
||||||
panic(errors.New("inserting unknown address type"))
|
panic(errors.New("inserting unknown address type"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) Lookup(address []byte) *Peer {
|
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
|
||||||
table.mutex.RLock()
|
table.mutex.RLock()
|
||||||
defer table.mutex.RUnlock()
|
defer table.mutex.RUnlock()
|
||||||
switch len(address) {
|
switch len(ip) {
|
||||||
case net.IPv6len:
|
case net.IPv6len:
|
||||||
return table.IPv6.lookup(address)
|
return table.IPv6.lookup(ip)
|
||||||
case net.IPv4len:
|
case net.IPv4len:
|
||||||
return table.IPv4.lookup(address)
|
return table.IPv4.lookup(ip)
|
||||||
default:
|
default:
|
||||||
panic(errors.New("looking up unknown address type"))
|
panic(errors.New("looking up unknown address type"))
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) {
|
|||||||
rand.Read(addr4[:])
|
rand.Read(addr4[:])
|
||||||
cidr := uint8(rand.Intn(32) + 1)
|
cidr := uint8(rand.Intn(32) + 1)
|
||||||
index := rand.Intn(NumberOfPeers)
|
index := rand.Intn(NumberOfPeers)
|
||||||
allowedIPs.Insert(addr4[:], cidr, peers[index])
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
|
||||||
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
|
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
|
||||||
|
|
||||||
var addr6 [16]byte
|
var addr6 [16]byte
|
||||||
rand.Read(addr6[:])
|
rand.Read(addr6[:])
|
||||||
cidr = uint8(rand.Intn(128) + 1)
|
cidr = uint8(rand.Intn(128) + 1)
|
||||||
index = rand.Intn(NumberOfPeers)
|
index = rand.Intn(NumberOfPeers)
|
||||||
allowedIPs.Insert(addr6[:], cidr, peers[index])
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
|
||||||
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
|
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,6 +9,8 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testPairCommonBits struct {
|
type testPairCommonBits struct {
|
||||||
@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) {
|
|||||||
var allowedIPs AllowedIPs
|
var allowedIPs AllowedIPs
|
||||||
|
|
||||||
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
||||||
allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) {
|
|||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
addr = append(addr, expand(c)...)
|
addr = append(addr, expand(c)...)
|
||||||
addr = append(addr, expand(d)...)
|
addr = append(addr, expand(d)...)
|
||||||
allowedIPs.Insert(addr, cidr, peer)
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"sync"
|
"sync"
|
||||||
@ -19,6 +18,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/conn/bindtest"
|
"golang.zx2c4.com/wireguard/conn/bindtest"
|
||||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||||
@ -96,7 +96,7 @@ type testPair [2]testPeer
|
|||||||
type testPeer struct {
|
type testPeer struct {
|
||||||
tun *tuntest.ChannelTUN
|
tun *tuntest.ChannelTUN
|
||||||
dev *Device
|
dev *Device
|
||||||
ip net.IP
|
ip netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
type SendDirection bool
|
type SendDirection bool
|
||||||
@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
|
|||||||
for i := range pair {
|
for i := range pair {
|
||||||
p := &pair[i]
|
p := &pair[i]
|
||||||
p.tun = tuntest.NewChannelTUN()
|
p.tun = tuntest.NewChannelTUN()
|
||||||
p.ip = net.IPv4(1, 0, 0, byte(i+1))
|
p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
|
||||||
level := LogLevelVerbose
|
level := LogLevelVerbose
|
||||||
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
||||||
level = LogLevelError
|
level = LogLevelError
|
||||||
|
@ -7,47 +7,44 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DummyEndpoint struct {
|
type DummyEndpoint struct {
|
||||||
src [16]byte
|
src, dst netip.Addr
|
||||||
dst [16]byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateDummyEndpoint() (*DummyEndpoint, error) {
|
func CreateDummyEndpoint() (*DummyEndpoint, error) {
|
||||||
var end DummyEndpoint
|
var src, dst [16]byte
|
||||||
if _, err := rand.Read(end.src[:]); err != nil {
|
if _, err := rand.Read(src[:]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
_, err := rand.Read(end.dst[:])
|
_, err := rand.Read(dst[:])
|
||||||
return &end, err
|
return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) ClearSrc() {}
|
func (e *DummyEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
func (e *DummyEndpoint) SrcToString() string {
|
func (e *DummyEndpoint) SrcToString() string {
|
||||||
var addr net.UDPAddr
|
return netip.AddrPortFrom(e.SrcIP(), 1000).String()
|
||||||
addr.IP = e.SrcIP()
|
|
||||||
addr.Port = 1000
|
|
||||||
return addr.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) DstToString() string {
|
func (e *DummyEndpoint) DstToString() string {
|
||||||
var addr net.UDPAddr
|
return netip.AddrPortFrom(e.DstIP(), 1000).String()
|
||||||
addr.IP = e.DstIP()
|
|
||||||
addr.Port = 1000
|
|
||||||
return addr.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) SrcToBytes() []byte {
|
func (e *DummyEndpoint) DstToBytes() []byte {
|
||||||
return e.src[:]
|
out := e.DstIP().AsSlice()
|
||||||
|
out = append(out, byte(1000&0xff))
|
||||||
|
out = append(out, byte((1000>>8)&0xff))
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) DstIP() net.IP {
|
func (e *DummyEndpoint) DstIP() netip.Addr {
|
||||||
return e.dst[:]
|
return e.dst
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) SrcIP() net.IP {
|
func (e *DummyEndpoint) SrcIP() netip.Addr {
|
||||||
return e.src[:]
|
return e.src
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,6 @@ import (
|
|||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||||||
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
|
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
|
||||||
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
|
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
|
||||||
|
|
||||||
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
|
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
||||||
sendf("allowed_ip=%s/%d", ip.String(), cidr)
|
sendf("allowed_ip=%s", prefix.String())
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -374,16 +375,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
|||||||
|
|
||||||
case "allowed_ip":
|
case "allowed_ip":
|
||||||
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
||||||
|
prefix, err := netip.ParsePrefix(value)
|
||||||
_, network, err := net.ParseCIDR(value)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
||||||
}
|
}
|
||||||
if peer.dummy {
|
if peer.dummy {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
ones, _ := network.Mask.Size()
|
device.allowedips.Insert(prefix, peer.Peer)
|
||||||
device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
|
|
||||||
|
|
||||||
case "protocol_version":
|
case "protocol_version":
|
||||||
if value != "1" {
|
if value != "1" {
|
||||||
|
7
go.mod
7
go.mod
@ -3,8 +3,9 @@ module golang.zx2c4.com/wireguard
|
|||||||
go 1.17
|
go 1.17
|
||||||
|
|
||||||
require (
|
require (
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
|
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa
|
||||||
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3
|
golang.org/x/net v0.0.0-20211111083644-e5c967477495
|
||||||
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b
|
golang.org/x/sys v0.0.0-20211110154304-99a53858aa08
|
||||||
|
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d
|
||||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
|
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
|
||||||
)
|
)
|
||||||
|
13
go.sum
13
go.sum
@ -1,16 +1,19 @@
|
|||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
|
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa h1:idItI2DDfCokpg0N51B2VtiLdJ4vAuXC9fnCb2gACo4=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 h1:VrJZAjbekhoRn7n5FBujY31gboH+iB3pdLxn3gE9FjU=
|
golang.org/x/net v0.0.0-20211111083644-e5c967477495 h1:cjxxlQm6d4kYbhpZ2ghvmI8xnq0AG+jXmzrhzfkyu5A=
|
||||||
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
golang.org/x/net v0.0.0-20211111083644-e5c967477495/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4=
|
|
||||||
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20211110154304-99a53858aa08 h1:WecRHqgE09JBkh/584XIE6PMz5KKE/vER4izNUi30AQ=
|
||||||
|
golang.org/x/sys v0.0.0-20211110154304-99a53858aa08/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I=
|
||||||
|
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
|
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
|
@ -6,9 +6,10 @@
|
|||||||
package ratelimiter
|
package ratelimiter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -30,8 +31,7 @@ type Ratelimiter struct {
|
|||||||
timeNow func() time.Time
|
timeNow func() time.Time
|
||||||
|
|
||||||
stopReset chan struct{} // send to reset, close to stop
|
stopReset chan struct{} // send to reset, close to stop
|
||||||
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
|
table map[netip.Addr]*RatelimiterEntry
|
||||||
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) Close() {
|
func (rate *Ratelimiter) Close() {
|
||||||
@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rate.stopReset = make(chan struct{})
|
rate.stopReset = make(chan struct{})
|
||||||
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
|
rate.table = make(map[netip.Addr]*RatelimiterEntry)
|
||||||
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
|
|
||||||
|
|
||||||
stopReset := rate.stopReset // store in case Init is called again.
|
stopReset := rate.stopReset // store in case Init is called again.
|
||||||
|
|
||||||
@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
|
|||||||
rate.mu.Lock()
|
rate.mu.Lock()
|
||||||
defer rate.mu.Unlock()
|
defer rate.mu.Unlock()
|
||||||
|
|
||||||
for key, entry := range rate.tableIPv4 {
|
for key, entry := range rate.table {
|
||||||
entry.mu.Lock()
|
entry.mu.Lock()
|
||||||
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
||||||
delete(rate.tableIPv4, key)
|
delete(rate.table, key)
|
||||||
}
|
}
|
||||||
entry.mu.Unlock()
|
entry.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, entry := range rate.tableIPv6 {
|
return len(rate.table) == 0
|
||||||
entry.mu.Lock()
|
|
||||||
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
|
||||||
delete(rate.tableIPv6, key)
|
|
||||||
}
|
|
||||||
entry.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
|
func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
|
||||||
}
|
|
||||||
|
|
||||||
func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
|
||||||
var entry *RatelimiterEntry
|
var entry *RatelimiterEntry
|
||||||
var keyIPv4 [net.IPv4len]byte
|
|
||||||
var keyIPv6 [net.IPv6len]byte
|
|
||||||
|
|
||||||
// lookup entry
|
// lookup entry
|
||||||
|
|
||||||
IPv4 := ip.To4()
|
|
||||||
IPv6 := ip.To16()
|
|
||||||
|
|
||||||
rate.mu.RLock()
|
rate.mu.RLock()
|
||||||
|
entry = rate.table[ip]
|
||||||
if IPv4 != nil {
|
|
||||||
copy(keyIPv4[:], IPv4)
|
|
||||||
entry = rate.tableIPv4[keyIPv4]
|
|
||||||
} else {
|
|
||||||
copy(keyIPv6[:], IPv6)
|
|
||||||
entry = rate.tableIPv6[keyIPv6]
|
|
||||||
}
|
|
||||||
|
|
||||||
rate.mu.RUnlock()
|
rate.mu.RUnlock()
|
||||||
|
|
||||||
// make new entry if not found
|
// make new entry if not found
|
||||||
|
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
entry = new(RatelimiterEntry)
|
entry = new(RatelimiterEntry)
|
||||||
entry.tokens = maxTokens - packetCost
|
entry.tokens = maxTokens - packetCost
|
||||||
entry.lastTime = rate.timeNow()
|
entry.lastTime = rate.timeNow()
|
||||||
rate.mu.Lock()
|
rate.mu.Lock()
|
||||||
if IPv4 != nil {
|
rate.table[ip] = entry
|
||||||
rate.tableIPv4[keyIPv4] = entry
|
if len(rate.table) == 1 {
|
||||||
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
|
|
||||||
rate.stopReset <- struct{}{}
|
rate.stopReset <- struct{}{}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
rate.tableIPv6[keyIPv6] = entry
|
|
||||||
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
|
|
||||||
rate.stopReset <- struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rate.mu.Unlock()
|
rate.mu.Unlock()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// add tokens to entry
|
// add tokens to entry
|
||||||
|
|
||||||
entry.mu.Lock()
|
entry.mu.Lock()
|
||||||
now := rate.timeNow()
|
now := rate.timeNow()
|
||||||
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
||||||
@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// subtract cost of packet
|
// subtract cost of packet
|
||||||
|
|
||||||
if entry.tokens > packetCost {
|
if entry.tokens > packetCost {
|
||||||
entry.tokens -= packetCost
|
entry.tokens -= packetCost
|
||||||
entry.mu.Unlock()
|
entry.mu.Unlock()
|
||||||
|
@ -6,9 +6,10 @@
|
|||||||
package ratelimiter
|
package ratelimiter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type result struct {
|
type result struct {
|
||||||
@ -71,21 +72,21 @@ func TestRatelimiter(t *testing.T) {
|
|||||||
text: "packet following 2 packet burst",
|
text: "packet following 2 packet burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
ips := []net.IP{
|
ips := []netip.Addr{
|
||||||
net.ParseIP("127.0.0.1"),
|
netip.MustParseAddr("127.0.0.1"),
|
||||||
net.ParseIP("192.168.1.1"),
|
netip.MustParseAddr("192.168.1.1"),
|
||||||
net.ParseIP("172.167.2.3"),
|
netip.MustParseAddr("172.167.2.3"),
|
||||||
net.ParseIP("97.231.252.215"),
|
netip.MustParseAddr("97.231.252.215"),
|
||||||
net.ParseIP("248.97.91.167"),
|
netip.MustParseAddr("248.97.91.167"),
|
||||||
net.ParseIP("188.208.233.47"),
|
netip.MustParseAddr("188.208.233.47"),
|
||||||
net.ParseIP("104.2.183.179"),
|
netip.MustParseAddr("104.2.183.179"),
|
||||||
net.ParseIP("72.129.46.120"),
|
netip.MustParseAddr("72.129.46.120"),
|
||||||
net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
|
netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
|
||||||
net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
|
netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
|
||||||
net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
|
netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
|
||||||
net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
|
netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
|
||||||
net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
|
netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
|
||||||
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
|
netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
//go:build ignore
|
//go:build ignore
|
||||||
|
// +build ignore
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
@ -10,9 +11,9 @@ package main
|
|||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
@ -20,8 +21,8 @@ import (
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
tun, tnet, err := netstack.CreateNetTUN(
|
tun, tnet, err := netstack.CreateNetTUN(
|
||||||
[]net.IP{net.ParseIP("192.168.4.29")},
|
[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
|
||||||
[]net.IP{net.ParseIP("8.8.8.8")},
|
[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
|
||||||
1420)
|
1420)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic(err)
|
log.Panic(err)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
//go:build ignore
|
//go:build ignore
|
||||||
|
// +build ignore
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
@ -13,6 +14,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
@ -20,8 +22,8 @@ import (
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
tun, tnet, err := netstack.CreateNetTUN(
|
tun, tnet, err := netstack.CreateNetTUN(
|
||||||
[]net.IP{net.ParseIP("192.168.4.29")},
|
[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
|
||||||
[]net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")},
|
[]netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
|
||||||
1420,
|
1420,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -6,6 +6,7 @@ require (
|
|||||||
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
|
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
|
||||||
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect
|
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect
|
||||||
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
|
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
|
||||||
|
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22
|
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22
|
||||||
gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6
|
gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6
|
||||||
)
|
)
|
||||||
|
@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
|
|||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg=
|
||||||
|
golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
|
||||||
|
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E=
|
||||||
|
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw=
|
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
|
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
|
||||||
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
|
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
|
||||||
|
@ -18,6 +18,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
|
||||||
"golang.org/x/net/dns/dnsmessage"
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
@ -38,7 +39,7 @@ type netTun struct {
|
|||||||
events chan tun.Event
|
events chan tun.Event
|
||||||
incomingPacket chan buffer.VectorisedView
|
incomingPacket chan buffer.VectorisedView
|
||||||
mtu int
|
mtu int
|
||||||
dnsServers []net.IP
|
dnsServers []netip.Addr
|
||||||
hasV4, hasV6 bool
|
hasV4, hasV6 bool
|
||||||
}
|
}
|
||||||
type endpoint netTun
|
type endpoint netTun
|
||||||
@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
|
|||||||
func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
|
func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
|
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
|
||||||
opts := stack.Options{
|
opts := stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
|
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
|
||||||
@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
|
|||||||
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
||||||
}
|
}
|
||||||
for _, ip := range localAddresses {
|
for _, ip := range localAddresses {
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
var protoNumber tcpip.NetworkProtocolNumber
|
||||||
protoAddr := tcpip.ProtocolAddress{
|
if ip.Is4() {
|
||||||
Protocol: ipv4.ProtocolNumber,
|
protoNumber = ipv4.ProtocolNumber
|
||||||
AddressWithPrefix: tcpip.Address(ip4).WithPrefix(),
|
} else if ip.Is6() {
|
||||||
|
protoNumber = ipv6.ProtocolNumber
|
||||||
}
|
}
|
||||||
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
|
||||||
if tcpipErr != nil {
|
|
||||||
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr)
|
|
||||||
}
|
|
||||||
dev.hasV4 = true
|
|
||||||
} else {
|
|
||||||
protoAddr := tcpip.ProtocolAddress{
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
Protocol: ipv6.ProtocolNumber,
|
Protocol: protoNumber,
|
||||||
AddressWithPrefix: tcpip.Address(ip).WithPrefix(),
|
AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
|
||||||
}
|
}
|
||||||
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
||||||
if tcpipErr != nil {
|
if tcpipErr != nil {
|
||||||
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
|
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
|
||||||
}
|
}
|
||||||
|
if ip.Is4() {
|
||||||
|
dev.hasV4 = true
|
||||||
|
} else if ip.Is6() {
|
||||||
dev.hasV6 = true
|
dev.hasV6 = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) {
|
|||||||
return tun.mtu, nil
|
return tun.mtu, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
var protoNumber tcpip.NetworkProtocolNumber
|
||||||
return tcpip.FullAddress{
|
if endpoint.Addr().Is4() {
|
||||||
NIC: 1,
|
protoNumber = ipv4.ProtocolNumber
|
||||||
Addr: tcpip.Address(ip4),
|
|
||||||
Port: uint16(port),
|
|
||||||
}, ipv4.ProtocolNumber
|
|
||||||
} else {
|
} else {
|
||||||
|
protoNumber = ipv6.ProtocolNumber
|
||||||
|
}
|
||||||
return tcpip.FullAddress{
|
return tcpip.FullAddress{
|
||||||
NIC: 1,
|
NIC: 1,
|
||||||
Addr: tcpip.Address(ip),
|
Addr: tcpip.Address(endpoint.Addr().AsSlice()),
|
||||||
Port: uint16(port),
|
Port: endpoint.Port(),
|
||||||
}, ipv6.ProtocolNumber
|
}, protoNumber
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
|
||||||
|
fa, pn := convertToFullAddr(addr)
|
||||||
|
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
||||||
if addr == nil {
|
if addr == nil {
|
||||||
panic("todo: deal with auto addr semantics for nil addr")
|
return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
|
||||||
}
|
}
|
||||||
fa, pn := convertToFullAddr(addr.IP, addr.Port)
|
return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
|
||||||
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
|
}
|
||||||
|
|
||||||
|
func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
|
||||||
|
fa, pn := convertToFullAddr(addr)
|
||||||
|
return gonet.DialTCP(net.stack, fa, pn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
||||||
if addr == nil {
|
if addr == nil {
|
||||||
panic("todo: deal with auto addr semantics for nil addr")
|
return net.DialTCPAddrPort(netip.AddrPort{})
|
||||||
}
|
}
|
||||||
fa, pn := convertToFullAddr(addr.IP, addr.Port)
|
return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
|
||||||
return gonet.DialTCP(net.stack, fa, pn)
|
}
|
||||||
|
|
||||||
|
func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
|
||||||
|
fa, pn := convertToFullAddr(addr)
|
||||||
|
return gonet.ListenTCP(net.stack, fa, pn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
|
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
|
||||||
if addr == nil {
|
if addr == nil {
|
||||||
panic("todo: deal with auto addr semantics for nil addr")
|
return net.ListenTCPAddrPort(netip.AddrPort{})
|
||||||
}
|
}
|
||||||
fa, pn := convertToFullAddr(addr.IP, addr.Port)
|
return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
|
||||||
return gonet.ListenTCP(net.stack, fa, pn)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
|
||||||
var lfa, rfa *tcpip.FullAddress
|
var lfa, rfa *tcpip.FullAddress
|
||||||
var pn tcpip.NetworkProtocolNumber
|
var pn tcpip.NetworkProtocolNumber
|
||||||
if laddr != nil {
|
if laddr.IsValid() || laddr.Port() > 0 {
|
||||||
var addr tcpip.FullAddress
|
var addr tcpip.FullAddress
|
||||||
addr, pn = convertToFullAddr(laddr.IP, laddr.Port)
|
addr, pn = convertToFullAddr(laddr)
|
||||||
lfa = &addr
|
lfa = &addr
|
||||||
}
|
}
|
||||||
if raddr != nil {
|
if raddr.IsValid() || raddr.Port() > 0 {
|
||||||
var addr tcpip.FullAddress
|
var addr tcpip.FullAddress
|
||||||
addr, pn = convertToFullAddr(raddr.IP, raddr.Port)
|
addr, pn = convertToFullAddr(raddr)
|
||||||
rfa = &addr
|
rfa = &addr
|
||||||
}
|
}
|
||||||
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
||||||
|
var la, ra netip.AddrPort
|
||||||
|
if laddr != nil {
|
||||||
|
la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port))
|
||||||
|
}
|
||||||
|
if raddr != nil {
|
||||||
|
ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port))
|
||||||
|
}
|
||||||
|
return net.DialUDPAddrPort(la, ra)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errNoSuchHost = errors.New("no such host")
|
errNoSuchHost = errors.New("no such host")
|
||||||
errLameReferral = errors.New("lame referral")
|
errLameReferral = errors.New("lame referral")
|
||||||
@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
|
|||||||
return p, h, nil
|
return p, h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
|
func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
|
||||||
q.Class = dnsmessage.ClassINET
|
q.Class = dnsmessage.ClassINET
|
||||||
id, udpReq, tcpReq, err := newRequest(q)
|
id, udpReq, tcpReq, err := newRequest(q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
|
|||||||
var c net.Conn
|
var c net.Conn
|
||||||
var err error
|
var err error
|
||||||
if useUDP {
|
if useUDP {
|
||||||
c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53})
|
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
|
||||||
} else {
|
} else {
|
||||||
c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53})
|
c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
|
|||||||
zlen = zidx
|
zlen = zidx
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if ip := net.ParseIP(host[:zlen]); ip != nil {
|
if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
|
||||||
return []string{host[:zlen]}, nil
|
return []string{ip.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isDomainName(host) {
|
if !isDomainName(host) {
|
||||||
@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
|
|||||||
server string
|
server string
|
||||||
error
|
error
|
||||||
}
|
}
|
||||||
var addrsV4, addrsV6 []net.IP
|
var addrsV4, addrsV6 []netip.Addr
|
||||||
lanes := 0
|
lanes := 0
|
||||||
if tnet.hasV4 {
|
if tnet.hasV4 {
|
||||||
lanes++
|
lanes++
|
||||||
@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
|
|||||||
}
|
}
|
||||||
break loop
|
break loop
|
||||||
}
|
}
|
||||||
addrsV4 = append(addrsV4, net.IP(a.A[:]))
|
addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
|
||||||
|
|
||||||
case dnsmessage.TypeAAAA:
|
case dnsmessage.TypeAAAA:
|
||||||
aaaa, err := result.p.AAAAResource()
|
aaaa, err := result.p.AAAAResource()
|
||||||
@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
|
|||||||
}
|
}
|
||||||
break loop
|
break loop
|
||||||
}
|
}
|
||||||
addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:]))
|
addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
|
||||||
|
|
||||||
default:
|
default:
|
||||||
if err := result.p.SkipAnswer(); err != nil {
|
if err := result.p.SkipAnswer(); err != nil {
|
||||||
@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
|
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
|
||||||
var addrs []net.IP
|
var addrs []netip.Addr
|
||||||
if tnet.hasV6 {
|
if tnet.hasV6 {
|
||||||
addrs = append(addrsV6, addrsV4...)
|
addrs = append(addrsV6, addrsV4...)
|
||||||
} else {
|
} else {
|
||||||
@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &net.OpError{Op: "dial", Err: err}
|
return nil, &net.OpError{Op: "dial", Err: err}
|
||||||
}
|
}
|
||||||
var addrs []net.IP
|
var addrs []netip.AddrPort
|
||||||
for _, addr := range allAddr {
|
for _, addr := range allAddr {
|
||||||
if strings.IndexByte(addr, ':') != -1 && acceptV6 {
|
ip, err := netip.ParseAddr(addr)
|
||||||
addrs = append(addrs, net.ParseIP(addr))
|
if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
|
||||||
} else if strings.IndexByte(addr, '.') != -1 && acceptV4 {
|
addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
|
||||||
addrs = append(addrs, net.ParseIP(addr))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(addrs) == 0 && len(allAddr) != 0 {
|
if len(addrs) == 0 && len(allAddr) != 0 {
|
||||||
@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
|
|||||||
|
|
||||||
var c net.Conn
|
var c net.Conn
|
||||||
if useUDP {
|
if useUDP {
|
||||||
c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port})
|
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
|
||||||
} else {
|
} else {
|
||||||
c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port})
|
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
|
||||||
}
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return c, nil
|
return c, nil
|
||||||
|
@ -8,13 +8,13 @@ package tuntest
|
|||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/go118/netip"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Ping(dst, src net.IP) []byte {
|
func Ping(dst, src netip.Addr) []byte {
|
||||||
localPort := uint16(1337)
|
localPort := uint16(1337)
|
||||||
seq := uint16(0)
|
seq := uint16(0)
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 {
|
|||||||
return ^uint16(v)
|
return ^uint16(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func genICMPv4(payload []byte, dst, src net.IP) []byte {
|
func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
|
||||||
const (
|
const (
|
||||||
icmpv4ProtocolNumber = 1
|
icmpv4ProtocolNumber = 1
|
||||||
icmpv4Echo = 8
|
icmpv4Echo = 8
|
||||||
@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
|
|||||||
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
|
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
|
||||||
ip[8] = ttl
|
ip[8] = ttl
|
||||||
ip[9] = icmpv4ProtocolNumber
|
ip[9] = icmpv4ProtocolNumber
|
||||||
copy(ip[12:], src.To4())
|
copy(ip[12:], src.AsSlice())
|
||||||
copy(ip[16:], dst.To4())
|
copy(ip[16:], dst.AsSlice())
|
||||||
chksum = ^checksum(ip[:], 0)
|
chksum = ^checksum(ip[:], 0)
|
||||||
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
|
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user