mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
Added new UDPBind interface
This commit is contained in:
parent
2d856045a0
commit
a72b0f7ae5
75
src/conn.go
75
src/conn.go
@ -5,6 +5,14 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type UDPBind interface {
|
||||||
|
SetMark(value uint32) error
|
||||||
|
ReceiveIPv6(buff []byte, end *Endpoint) (int, error)
|
||||||
|
ReceiveIPv4(buff []byte, end *Endpoint) (int, error)
|
||||||
|
Send(buff []byte, end *Endpoint) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||||
|
|
||||||
// ensure that the host is an IP address
|
// ensure that the host is an IP address
|
||||||
@ -26,19 +34,6 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
|
|||||||
return addr, err
|
return addr, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenerClose(l *Listener) (err error) {
|
|
||||||
if l.active {
|
|
||||||
err = CloseIPv4Socket(l.sock)
|
|
||||||
l.active = false
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Listener) Init() {
|
|
||||||
l.update = make(chan struct{}, 1)
|
|
||||||
ListenerClose(l)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ListeningUpdate(device *Device) error {
|
func ListeningUpdate(device *Device) error {
|
||||||
netc := &device.net
|
netc := &device.net
|
||||||
netc.mutex.Lock()
|
netc.mutex.Lock()
|
||||||
@ -46,11 +41,7 @@ func ListeningUpdate(device *Device) error {
|
|||||||
|
|
||||||
// close existing sockets
|
// close existing sockets
|
||||||
|
|
||||||
if err := ListenerClose(&netc.ipv4); err != nil {
|
if err := device.net.bind.Close(); err != nil {
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ListenerClose(&netc.ipv6); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,45 +49,22 @@ func ListeningUpdate(device *Device) error {
|
|||||||
|
|
||||||
if device.tun.isUp.Get() {
|
if device.tun.isUp.Get() {
|
||||||
|
|
||||||
// listen on IPv4
|
// bind to new port
|
||||||
|
|
||||||
{
|
var err error
|
||||||
list := &netc.ipv6
|
netc.bind, netc.port, err = CreateUDPBind(netc.port)
|
||||||
sock, port, err := CreateIPv4Socket(netc.port)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
netc.port = port
|
|
||||||
list.sock = sock
|
|
||||||
list.active = true
|
|
||||||
|
|
||||||
if err := SetMark(list.sock, netc.fwmark); err != nil {
|
// set mark
|
||||||
ListenerClose(list)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
signalSend(list.update)
|
|
||||||
}
|
|
||||||
|
|
||||||
// listen on IPv6
|
err = netc.bind.SetMark(netc.fwmark)
|
||||||
|
|
||||||
{
|
|
||||||
list := &netc.ipv6
|
|
||||||
sock, port, err := CreateIPv6Socket(netc.port)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
netc.port = port
|
|
||||||
list.sock = sock
|
|
||||||
list.active = true
|
|
||||||
|
|
||||||
if err := SetMark(list.sock, netc.fwmark); err != nil {
|
// TODO: clear endpoint (src) caches
|
||||||
ListenerClose(list)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
signalSend(list.update)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: clear endpoint caches
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -106,16 +74,5 @@ func ListeningClose(device *Device) error {
|
|||||||
netc := &device.net
|
netc := &device.net
|
||||||
netc.mutex.Lock()
|
netc.mutex.Lock()
|
||||||
defer netc.mutex.Unlock()
|
defer netc.mutex.Unlock()
|
||||||
|
return netc.bind.Close()
|
||||||
if err := ListenerClose(&netc.ipv4); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
signalSend(netc.ipv4.update)
|
|
||||||
|
|
||||||
if err := ListenerClose(&netc.ipv6); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
signalSend(netc.ipv6.update)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -14,35 +14,158 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
/* Supports source address caching
|
/* Supports source address caching
|
||||||
*
|
*
|
||||||
* Currently there is no way to achieve this within the net package:
|
* Currently there is no way to achieve this within the net package:
|
||||||
* See e.g. https://github.com/golang/go/issues/17930
|
* See e.g. https://github.com/golang/go/issues/17930
|
||||||
* So this code is platform dependent.
|
* So this code is remains platform dependent.
|
||||||
*
|
|
||||||
* It is important that the endpoint is only updated after the packet content has been authenticated!
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type Endpoint struct {
|
type Endpoint struct {
|
||||||
// source (selected based on dst type)
|
src unix.RawSockaddrInet6
|
||||||
// (could use RawSockaddrAny and unsafe)
|
dst unix.RawSockaddrInet6
|
||||||
// TODO: Merge
|
|
||||||
src6 unix.RawSockaddrInet6
|
|
||||||
src4 unix.RawSockaddrInet4
|
|
||||||
src4if int32
|
|
||||||
|
|
||||||
dst unix.RawSockaddrAny
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Socket int
|
type IPv4Source struct {
|
||||||
|
src unix.RawSockaddrInet4
|
||||||
|
Ifindex int32
|
||||||
|
}
|
||||||
|
|
||||||
/* Returns a byte representation of the source field(s)
|
type Bind struct {
|
||||||
* for use in "under load" cookie computations.
|
sock4 int
|
||||||
*/
|
sock6 int
|
||||||
func (endpoint *Endpoint) Source() []byte {
|
}
|
||||||
|
|
||||||
|
func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
|
||||||
|
var err error
|
||||||
|
var bind Bind
|
||||||
|
|
||||||
|
bind.sock6, port, err = create6(port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, port, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.sock4, port, err = create4(port)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(bind.sock6)
|
||||||
|
}
|
||||||
|
return &bind, port, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *Bind) SetMark(value uint32) error {
|
||||||
|
err := unix.SetsockoptInt(
|
||||||
|
bind.sock6,
|
||||||
|
unix.SOL_SOCKET,
|
||||||
|
unix.SO_MARK,
|
||||||
|
int(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return unix.SetsockoptInt(
|
||||||
|
bind.sock4,
|
||||||
|
unix.SOL_SOCKET,
|
||||||
|
unix.SO_MARK,
|
||||||
|
int(value),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *Bind) Close() error {
|
||||||
|
err1 := unix.Close(bind.sock6)
|
||||||
|
err2 := unix.Close(bind.sock4)
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *Bind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) {
|
||||||
|
return receive6(
|
||||||
|
bind.sock6,
|
||||||
|
buff,
|
||||||
|
end,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *Bind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) {
|
||||||
|
return receive4(
|
||||||
|
bind.sock4,
|
||||||
|
buff,
|
||||||
|
end,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *Bind) Send(buff []byte, end *Endpoint) error {
|
||||||
|
switch end.src.Family {
|
||||||
|
case unix.AF_INET6:
|
||||||
|
return send6(bind.sock6, end, buff)
|
||||||
|
case unix.AF_INET:
|
||||||
|
return send4(bind.sock4, end, buff)
|
||||||
|
default:
|
||||||
|
return errors.New("Unknown address family of source")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sockaddrToString(addr unix.RawSockaddrInet6) string {
|
||||||
|
var udpAddr net.UDPAddr
|
||||||
|
|
||||||
|
switch addr.Family {
|
||||||
|
case unix.AF_INET6:
|
||||||
|
udpAddr.Port = int(addr.Port)
|
||||||
|
udpAddr.IP = addr.Addr[:]
|
||||||
|
return udpAddr.String()
|
||||||
|
|
||||||
|
case unix.AF_INET:
|
||||||
|
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
|
||||||
|
udpAddr.Port = int(ptr.Port)
|
||||||
|
udpAddr.IP = net.IPv4(
|
||||||
|
ptr.Addr[0],
|
||||||
|
ptr.Addr[1],
|
||||||
|
ptr.Addr[2],
|
||||||
|
ptr.Addr[3],
|
||||||
|
)
|
||||||
|
return udpAddr.String()
|
||||||
|
|
||||||
|
default:
|
||||||
|
return "<unknown address family>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *Endpoint) DestinationIP() net.IP {
|
||||||
|
switch end.dst.Family {
|
||||||
|
case unix.AF_INET6:
|
||||||
|
return end.dst.Addr[:]
|
||||||
|
case unix.AF_INET:
|
||||||
|
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
|
||||||
|
return net.IPv4(
|
||||||
|
ptr.Addr[0],
|
||||||
|
ptr.Addr[1],
|
||||||
|
ptr.Addr[2],
|
||||||
|
ptr.Addr[3],
|
||||||
|
)
|
||||||
|
default:
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *Endpoint) SourceToBytes() []byte {
|
||||||
|
ptr := unsafe.Pointer(&end.src)
|
||||||
|
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
|
||||||
|
return arr[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *Endpoint) SourceToString() string {
|
||||||
|
return sockaddrToString(end.src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *Endpoint) DestinationToString() string {
|
||||||
|
return sockaddrToString(end.dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *Endpoint) ClearSrc() {
|
||||||
|
end.src = unix.RawSockaddrInet6{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func zoneToUint32(zone string) (uint32, error) {
|
func zoneToUint32(zone string) (uint32, error) {
|
||||||
@ -56,7 +179,7 @@ func zoneToUint32(zone string) (uint32, error) {
|
|||||||
return uint32(n), err
|
return uint32(n), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
|
func create4(port uint16) (int, uint16, error) {
|
||||||
|
|
||||||
// create socket
|
// create socket
|
||||||
|
|
||||||
@ -100,18 +223,10 @@ func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
|
|||||||
unix.Close(fd)
|
unix.Close(fd)
|
||||||
}
|
}
|
||||||
|
|
||||||
return Socket(fd), uint16(addr.Port), err
|
return fd, uint16(addr.Port), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CloseIPv4Socket(sock Socket) error {
|
func create6(port uint16) (int, uint16, error) {
|
||||||
return unix.Close(int(sock))
|
|
||||||
}
|
|
||||||
|
|
||||||
func CloseIPv6Socket(sock Socket) error {
|
|
||||||
return unix.Close(int(sock))
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
|
|
||||||
|
|
||||||
// create socket
|
// create socket
|
||||||
|
|
||||||
@ -166,13 +281,7 @@ func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
|
|||||||
unix.Close(fd)
|
unix.Close(fd)
|
||||||
}
|
}
|
||||||
|
|
||||||
return Socket(fd), uint16(addr.Port), err
|
return fd, uint16(addr.Port), err
|
||||||
}
|
|
||||||
|
|
||||||
func (end *Endpoint) ClearSrc() {
|
|
||||||
end.src4if = 0
|
|
||||||
end.src4 = unix.RawSockaddrInet4{}
|
|
||||||
end.src6 = unix.RawSockaddrInet6{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (end *Endpoint) Set(s string) error {
|
func (end *Endpoint) Set(s string) error {
|
||||||
@ -187,23 +296,23 @@ func (end *Endpoint) Set(s string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
|
dst := &end.dst
|
||||||
ptr.Family = unix.AF_INET6
|
dst.Family = unix.AF_INET6
|
||||||
ptr.Port = uint16(addr.Port)
|
dst.Port = uint16(addr.Port)
|
||||||
ptr.Flowinfo = 0
|
dst.Flowinfo = 0
|
||||||
ptr.Scope_id = zone
|
dst.Scope_id = zone
|
||||||
copy(ptr.Addr[:], ipv6[:])
|
copy(dst.Addr[:], ipv6[:])
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipv4 := addr.IP.To4()
|
ipv4 := addr.IP.To4()
|
||||||
if ipv4 != nil {
|
if ipv4 != nil {
|
||||||
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
|
dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
|
||||||
ptr.Family = unix.AF_INET
|
dst.Family = unix.AF_INET
|
||||||
ptr.Port = uint16(addr.Port)
|
dst.Port = uint16(addr.Port)
|
||||||
ptr.Zero = [8]byte{}
|
dst.Zero = [8]byte{}
|
||||||
copy(ptr.Addr[:], ipv4)
|
copy(dst.Addr[:], ipv4)
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -211,7 +320,7 @@ func (end *Endpoint) Set(s string) error {
|
|||||||
return errors.New("Failed to recognize IP address format")
|
return errors.New("Failed to recognize IP address format")
|
||||||
}
|
}
|
||||||
|
|
||||||
func send6(sock uintptr, end *Endpoint, buff []byte) error {
|
func send6(sock int, end *Endpoint, buff []byte) error {
|
||||||
|
|
||||||
// construct message header
|
// construct message header
|
||||||
|
|
||||||
@ -229,8 +338,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
|
|||||||
Len: unix.SizeofInet6Pktinfo,
|
Len: unix.SizeofInet6Pktinfo,
|
||||||
},
|
},
|
||||||
unix.Inet6Pktinfo{
|
unix.Inet6Pktinfo{
|
||||||
Addr: end.src6.Addr,
|
Addr: end.src.Addr,
|
||||||
Ifindex: end.src6.Scope_id,
|
Ifindex: end.src.Scope_id,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,7 +357,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
|
|||||||
|
|
||||||
_, _, errno := unix.Syscall(
|
_, _, errno := unix.Syscall(
|
||||||
unix.SYS_SENDMSG,
|
unix.SYS_SENDMSG,
|
||||||
sock,
|
uintptr(sock),
|
||||||
uintptr(unsafe.Pointer(&msghdr)),
|
uintptr(unsafe.Pointer(&msghdr)),
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
@ -258,7 +367,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
|
|||||||
return errno
|
return errno
|
||||||
}
|
}
|
||||||
|
|
||||||
func send4(sock uintptr, end *Endpoint, buff []byte) error {
|
func send4(sock int, end *Endpoint, buff []byte) error {
|
||||||
|
|
||||||
// construct message header
|
// construct message header
|
||||||
|
|
||||||
@ -266,6 +375,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
|
|||||||
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
|
||||||
iovec.SetLen(len(buff))
|
iovec.SetLen(len(buff))
|
||||||
|
|
||||||
|
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
|
||||||
|
|
||||||
cmsg := struct {
|
cmsg := struct {
|
||||||
cmsghdr unix.Cmsghdr
|
cmsghdr unix.Cmsghdr
|
||||||
pktinfo unix.Inet4Pktinfo
|
pktinfo unix.Inet4Pktinfo
|
||||||
@ -276,8 +387,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
|
|||||||
Len: unix.SizeofInet4Pktinfo,
|
Len: unix.SizeofInet4Pktinfo,
|
||||||
},
|
},
|
||||||
unix.Inet4Pktinfo{
|
unix.Inet4Pktinfo{
|
||||||
Spec_dst: end.src4.Addr,
|
Spec_dst: src4.src.Addr,
|
||||||
Ifindex: end.src4if,
|
Ifindex: src4.Ifindex,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,7 +406,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
|
|||||||
|
|
||||||
_, _, errno := unix.Syscall(
|
_, _, errno := unix.Syscall(
|
||||||
unix.SYS_SENDMSG,
|
unix.SYS_SENDMSG,
|
||||||
sock,
|
uintptr(sock),
|
||||||
uintptr(unsafe.Pointer(&msghdr)),
|
uintptr(unsafe.Pointer(&msghdr)),
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
@ -305,28 +416,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
|
|||||||
return errno
|
return errno
|
||||||
}
|
}
|
||||||
|
|
||||||
func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error {
|
func receive4(sock int, buff []byte, end *Endpoint) (int, error) {
|
||||||
|
|
||||||
// extract underlying file descriptor
|
|
||||||
|
|
||||||
file, err := c.File()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
sock := file.Fd()
|
|
||||||
|
|
||||||
// send depending on address family of dst
|
|
||||||
|
|
||||||
family := *((*uint16)(unsafe.Pointer(&end.dst)))
|
|
||||||
if family == unix.AF_INET {
|
|
||||||
return send4(sock, end, buff)
|
|
||||||
} else if family == unix.AF_INET6 {
|
|
||||||
return send6(sock, end, buff)
|
|
||||||
}
|
|
||||||
return errors.New("Unknown address family of source")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
|
|
||||||
|
|
||||||
// contruct message header
|
// contruct message header
|
||||||
|
|
||||||
@ -360,22 +450,21 @@ func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
|
|||||||
return 0, errno
|
return 0, errno
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(msghdr)
|
|
||||||
fmt.Println(cmsg)
|
|
||||||
|
|
||||||
// update source cache
|
// update source cache
|
||||||
|
|
||||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
||||||
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
||||||
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
||||||
end.src4.Addr = cmsg.pktinfo.Spec_dst
|
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
|
||||||
end.src4if = cmsg.pktinfo.Ifindex
|
src4.src.Family = unix.AF_INET
|
||||||
|
src4.src.Addr = cmsg.pktinfo.Spec_dst
|
||||||
|
src4.Ifindex = cmsg.pktinfo.Ifindex
|
||||||
}
|
}
|
||||||
|
|
||||||
return int(size), nil
|
return int(size), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
|
func receive6(sock int, buff []byte, end *Endpoint) (int, error) {
|
||||||
|
|
||||||
// contruct message header
|
// contruct message header
|
||||||
|
|
||||||
@ -414,18 +503,10 @@ func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
|
|||||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
|
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
|
||||||
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
|
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
|
||||||
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
|
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
|
||||||
end.src6.Addr = cmsg.pktinfo.Addr
|
end.src.Family = unix.AF_INET6
|
||||||
end.src6.Scope_id = cmsg.pktinfo.Ifindex
|
end.src.Addr = cmsg.pktinfo.Addr
|
||||||
|
end.src.Scope_id = cmsg.pktinfo.Ifindex
|
||||||
}
|
}
|
||||||
|
|
||||||
return int(size), nil
|
return int(size), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetMark(sock Socket, value uint32) error {
|
|
||||||
return unix.SetsockoptInt(
|
|
||||||
int(sock),
|
|
||||||
unix.SOL_SOCKET,
|
|
||||||
unix.SO_MARK,
|
|
||||||
int(value),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
@ -5,10 +5,8 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type CookieChecker struct {
|
type CookieChecker struct {
|
||||||
@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
|
|||||||
return hmac.Equal(mac1[:], msg[smac1:smac2])
|
return hmac.Equal(mac1[:], msg[smac1:smac2])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
|
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
|
||||||
st.mutex.RLock()
|
st.mutex.RLock()
|
||||||
defer st.mutex.RUnlock()
|
defer st.mutex.RUnlock()
|
||||||
|
|
||||||
@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
|
|||||||
var cookie [blake2s.Size128]byte
|
var cookie [blake2s.Size128]byte
|
||||||
func() {
|
func() {
|
||||||
mac, _ := blake2s.New128(st.mac2.secret[:])
|
mac, _ := blake2s.New128(st.mac2.secret[:])
|
||||||
mac.Write(src.IP)
|
mac.Write(src)
|
||||||
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
|
|
||||||
mac.Sum(cookie[:0])
|
mac.Sum(cookie[:0])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
|
|||||||
func (st *CookieChecker) CreateReply(
|
func (st *CookieChecker) CreateReply(
|
||||||
msg []byte,
|
msg []byte,
|
||||||
recv uint32,
|
recv uint32,
|
||||||
src *net.UDPAddr,
|
src []byte,
|
||||||
) (*MessageCookieReply, error) {
|
) (*MessageCookieReply, error) {
|
||||||
|
|
||||||
st.mutex.RLock()
|
st.mutex.RLock()
|
||||||
@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply(
|
|||||||
var cookie [blake2s.Size128]byte
|
var cookie [blake2s.Size128]byte
|
||||||
func() {
|
func() {
|
||||||
mac, _ := blake2s.New128(st.mac2.secret[:])
|
mac, _ := blake2s.New128(st.mac2.secret[:])
|
||||||
mac.Write(src.IP)
|
mac.Write(src)
|
||||||
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
|
|
||||||
mac.Sum(cookie[:0])
|
mac.Sum(cookie[:0])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -1,18 +1,14 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Listener struct {
|
|
||||||
sock Socket
|
|
||||||
active bool
|
|
||||||
update chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
log *Logger // collection of loggers for levels
|
log *Logger // collection of loggers for levels
|
||||||
idCounter uint // for assigning debug ids to peers
|
idCounter uint // for assigning debug ids to peers
|
||||||
@ -27,8 +23,7 @@ type Device struct {
|
|||||||
}
|
}
|
||||||
net struct {
|
net struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
ipv4 Listener
|
bind UDPBind
|
||||||
ipv6 Listener
|
|
||||||
port uint16
|
port uint16
|
||||||
fwmark uint32
|
fwmark uint32
|
||||||
}
|
}
|
||||||
@ -43,9 +38,8 @@ type Device struct {
|
|||||||
handshake chan QueueHandshakeElement
|
handshake chan QueueHandshakeElement
|
||||||
}
|
}
|
||||||
signal struct {
|
signal struct {
|
||||||
stop chan struct{} // halts all go routines
|
stop chan struct{}
|
||||||
updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
|
updateBind chan struct{}
|
||||||
updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
|
|
||||||
}
|
}
|
||||||
underLoadUntil atomic.Value
|
underLoadUntil atomic.Value
|
||||||
ratelimiter Ratelimiter
|
ratelimiter Ratelimiter
|
||||||
@ -146,8 +140,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
|||||||
device.tun.device = tun
|
device.tun.device = tun
|
||||||
|
|
||||||
device.indices.Init()
|
device.indices.Init()
|
||||||
device.net.ipv4.Init()
|
|
||||||
device.net.ipv6.Init()
|
|
||||||
device.ratelimiter.Init()
|
device.ratelimiter.Init()
|
||||||
|
|
||||||
device.routingTable.Reset()
|
device.routingTable.Reset()
|
||||||
@ -181,8 +173,8 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
|||||||
go device.RoutineReadFromTUN()
|
go device.RoutineReadFromTUN()
|
||||||
go device.RoutineTUNEventReader()
|
go device.RoutineTUNEventReader()
|
||||||
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
||||||
go device.RoutineReceiveIncomming(&device.net.ipv4)
|
go device.RoutineReceiveIncomming(ipv4.Version)
|
||||||
go device.RoutineReceiveIncomming(&device.net.ipv6)
|
go device.RoutineReceiveIncomming(ipv6.Version)
|
||||||
return device
|
return device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -15,8 +14,8 @@ type Peer struct {
|
|||||||
persistentKeepaliveInterval uint64
|
persistentKeepaliveInterval uint64
|
||||||
keyPairs KeyPairs
|
keyPairs KeyPairs
|
||||||
handshake Handshake
|
handshake Handshake
|
||||||
|
endpoint Endpoint
|
||||||
device *Device
|
device *Device
|
||||||
endpoint *net.UDPAddr
|
|
||||||
stats struct {
|
stats struct {
|
||||||
txBytes uint64 // bytes send to peer (endpoint)
|
txBytes uint64 // bytes send to peer (endpoint)
|
||||||
rxBytes uint64 // bytes received from peer
|
rxBytes uint64 // bytes received from peer
|
||||||
@ -134,7 +133,7 @@ func (peer *Peer) String() string {
|
|||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"peer(%d %s %s)",
|
"peer(%d %s %s)",
|
||||||
peer.id,
|
peer.id,
|
||||||
peer.endpoint.String(),
|
peer.endpoint.DestinationToString(),
|
||||||
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
|
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -97,17 +97,6 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
|
|||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
logDebug.Println("Routine, receive incomming, started")
|
logDebug.Println("Routine, receive incomming, started")
|
||||||
|
|
||||||
var listener *Listener
|
|
||||||
|
|
||||||
switch IPVersion {
|
|
||||||
case ipv4.Version:
|
|
||||||
listener = &device.net.ipv4
|
|
||||||
case ipv6.Version:
|
|
||||||
listener = &device.net.ipv6
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
||||||
// wait for new conn
|
// wait for new conn
|
||||||
@ -118,15 +107,14 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
|
|||||||
case <-device.signal.stop:
|
case <-device.signal.stop:
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-listener.update:
|
case <-device.signal.updateBind:
|
||||||
|
|
||||||
// fetch new socket
|
// fetch new socket
|
||||||
|
|
||||||
device.net.mutex.RLock()
|
device.net.mutex.RLock()
|
||||||
sock := listener.sock
|
bind := device.net.bind
|
||||||
okay := listener.active
|
|
||||||
device.net.mutex.RUnlock()
|
device.net.mutex.RUnlock()
|
||||||
if !okay {
|
if bind == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,10 +133,13 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
|
|||||||
|
|
||||||
var endpoint Endpoint
|
var endpoint Endpoint
|
||||||
|
|
||||||
if IPVersion == ipv6.Version {
|
switch IPVersion {
|
||||||
size, err = endpoint.ReceiveIPv4(sock, buffer[:])
|
case ipv4.Version:
|
||||||
} else {
|
size, err = bind.ReceiveIPv4(buffer[:], &endpoint)
|
||||||
size, err = endpoint.ReceiveIPv6(sock, buffer[:])
|
case ipv6.Version:
|
||||||
|
size, err = bind.ReceiveIPv6(buffer[:], &endpoint)
|
||||||
|
default:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -340,15 +331,19 @@ func (device *Device) RoutineHandshake() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
srcBytes := elem.endpoint.SourceToBytes()
|
||||||
if device.IsUnderLoad() {
|
if device.IsUnderLoad() {
|
||||||
if !device.mac.CheckMAC2(elem.packet, elem.source) {
|
|
||||||
|
// verify MAC2 field
|
||||||
|
|
||||||
|
if !device.mac.CheckMAC2(elem.packet, srcBytes) {
|
||||||
|
|
||||||
// construct cookie reply
|
// construct cookie reply
|
||||||
|
|
||||||
logDebug.Println("Sending cookie reply to:", elem.source.String())
|
logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString())
|
||||||
|
|
||||||
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
|
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
|
||||||
reply, err := device.mac.CreateReply(elem.packet, sender, elem.source)
|
reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to create cookie reply:", err)
|
logError.Println("Failed to create cookie reply:", err)
|
||||||
return
|
return
|
||||||
@ -358,9 +353,9 @@ func (device *Device) RoutineHandshake() {
|
|||||||
|
|
||||||
writer := bytes.NewBuffer(temp[:0])
|
writer := bytes.NewBuffer(temp[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, reply)
|
binary.Write(writer, binary.LittleEndian, reply)
|
||||||
_, err = device.net.conn.WriteToUDP(
|
device.net.bind.Send(
|
||||||
writer.Bytes(),
|
writer.Bytes(),
|
||||||
elem.source,
|
&elem.endpoint,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logDebug.Println("Failed to send cookie reply:", err)
|
logDebug.Println("Failed to send cookie reply:", err)
|
||||||
@ -368,7 +363,11 @@ func (device *Device) RoutineHandshake() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !device.ratelimiter.Allow(elem.source.IP) {
|
// check ratelimiter
|
||||||
|
|
||||||
|
if !device.ratelimiter.Allow(
|
||||||
|
elem.endpoint.DestinationIP(),
|
||||||
|
) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -399,8 +398,7 @@ func (device *Device) RoutineHandshake() {
|
|||||||
if peer == nil {
|
if peer == nil {
|
||||||
logInfo.Println(
|
logInfo.Println(
|
||||||
"Recieved invalid initiation message from",
|
"Recieved invalid initiation message from",
|
||||||
elem.source.IP.String(),
|
elem.endpoint.DestinationToString(),
|
||||||
elem.source.Port,
|
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -414,7 +412,7 @@ func (device *Device) RoutineHandshake() {
|
|||||||
// TODO: Discover destination address also, only update on change
|
// TODO: Discover destination address also, only update on change
|
||||||
|
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
peer.endpoint = elem.source
|
peer.endpoint = elem.endpoint
|
||||||
peer.mutex.Unlock()
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
// create response
|
// create response
|
||||||
@ -460,8 +458,7 @@ func (device *Device) RoutineHandshake() {
|
|||||||
if peer == nil {
|
if peer == nil {
|
||||||
logInfo.Println(
|
logInfo.Println(
|
||||||
"Recieved invalid response message from",
|
"Recieved invalid response message from",
|
||||||
elem.source.IP.String(),
|
elem.endpoint.DestinationToString(),
|
||||||
elem.source.Port,
|
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user