1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Added new UDPBind interface

This commit is contained in:
Mathias Hall-Andersen 2017-10-08 22:03:32 +02:00
parent 2d856045a0
commit a72b0f7ae5
6 changed files with 236 additions and 214 deletions

View File

@ -5,6 +5,14 @@ import (
"net" "net"
) )
type UDPBind interface {
SetMark(value uint32) error
ReceiveIPv6(buff []byte, end *Endpoint) (int, error)
ReceiveIPv4(buff []byte, end *Endpoint) (int, error)
Send(buff []byte, end *Endpoint) error
Close() error
}
func parseEndpoint(s string) (*net.UDPAddr, error) { func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address // ensure that the host is an IP address
@ -26,19 +34,6 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
return addr, err return addr, err
} }
func ListenerClose(l *Listener) (err error) {
if l.active {
err = CloseIPv4Socket(l.sock)
l.active = false
}
return
}
func (l *Listener) Init() {
l.update = make(chan struct{}, 1)
ListenerClose(l)
}
func ListeningUpdate(device *Device) error { func ListeningUpdate(device *Device) error {
netc := &device.net netc := &device.net
netc.mutex.Lock() netc.mutex.Lock()
@ -46,11 +41,7 @@ func ListeningUpdate(device *Device) error {
// close existing sockets // close existing sockets
if err := ListenerClose(&netc.ipv4); err != nil { if err := device.net.bind.Close(); err != nil {
return err
}
if err := ListenerClose(&netc.ipv6); err != nil {
return err return err
} }
@ -58,45 +49,22 @@ func ListeningUpdate(device *Device) error {
if device.tun.isUp.Get() { if device.tun.isUp.Get() {
// listen on IPv4 // bind to new port
{ var err error
list := &netc.ipv6 netc.bind, netc.port, err = CreateUDPBind(netc.port)
sock, port, err := CreateIPv4Socket(netc.port)
if err != nil { if err != nil {
return err return err
} }
netc.port = port
list.sock = sock
list.active = true
if err := SetMark(list.sock, netc.fwmark); err != nil { // set mark
ListenerClose(list)
return err
}
signalSend(list.update)
}
// listen on IPv6 err = netc.bind.SetMark(netc.fwmark)
{
list := &netc.ipv6
sock, port, err := CreateIPv6Socket(netc.port)
if err != nil { if err != nil {
return err return err
} }
netc.port = port
list.sock = sock
list.active = true
if err := SetMark(list.sock, netc.fwmark); err != nil { // TODO: clear endpoint (src) caches
ListenerClose(list)
return err
}
signalSend(list.update)
}
// TODO: clear endpoint caches
} }
return nil return nil
@ -106,16 +74,5 @@ func ListeningClose(device *Device) error {
netc := &device.net netc := &device.net
netc.mutex.Lock() netc.mutex.Lock()
defer netc.mutex.Unlock() defer netc.mutex.Unlock()
return netc.bind.Close()
if err := ListenerClose(&netc.ipv4); err != nil {
return err
}
signalSend(netc.ipv4.update)
if err := ListenerClose(&netc.ipv6); err != nil {
return err
}
signalSend(netc.ipv6.update)
return nil
} }

View File

@ -14,35 +14,158 @@ import (
"unsafe" "unsafe"
) )
import "fmt"
/* Supports source address caching /* Supports source address caching
* *
* Currently there is no way to achieve this within the net package: * Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930 * See e.g. https://github.com/golang/go/issues/17930
* So this code is platform dependent. * So this code is remains platform dependent.
*
* It is important that the endpoint is only updated after the packet content has been authenticated!
*/ */
type Endpoint struct { type Endpoint struct {
// source (selected based on dst type) src unix.RawSockaddrInet6
// (could use RawSockaddrAny and unsafe) dst unix.RawSockaddrInet6
// TODO: Merge
src6 unix.RawSockaddrInet6
src4 unix.RawSockaddrInet4
src4if int32
dst unix.RawSockaddrAny
} }
type Socket int type IPv4Source struct {
src unix.RawSockaddrInet4
Ifindex int32
}
/* Returns a byte representation of the source field(s) type Bind struct {
* for use in "under load" cookie computations. sock4 int
*/ sock6 int
func (endpoint *Endpoint) Source() []byte { }
func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
var err error
var bind Bind
bind.sock6, port, err = create6(port)
if err != nil {
return nil, port, err
}
bind.sock4, port, err = create4(port)
if err != nil {
unix.Close(bind.sock6)
}
return &bind, port, err
}
func (bind *Bind) SetMark(value uint32) error {
err := unix.SetsockoptInt(
bind.sock6,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
return unix.SetsockoptInt(
bind.sock4,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
}
func (bind *Bind) Close() error {
err1 := unix.Close(bind.sock6)
err2 := unix.Close(bind.sock4)
if err1 != nil {
return err1
}
return err2
}
func (bind *Bind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) {
return receive6(
bind.sock6,
buff,
end,
)
}
func (bind *Bind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) {
return receive4(
bind.sock4,
buff,
end,
)
}
func (bind *Bind) Send(buff []byte, end *Endpoint) error {
switch end.src.Family {
case unix.AF_INET6:
return send6(bind.sock6, end, buff)
case unix.AF_INET:
return send4(bind.sock4, end, buff)
default:
return errors.New("Unknown address family of source")
}
}
func sockaddrToString(addr unix.RawSockaddrInet6) string {
var udpAddr net.UDPAddr
switch addr.Family {
case unix.AF_INET6:
udpAddr.Port = int(addr.Port)
udpAddr.IP = addr.Addr[:]
return udpAddr.String()
case unix.AF_INET:
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
udpAddr.Port = int(ptr.Port)
udpAddr.IP = net.IPv4(
ptr.Addr[0],
ptr.Addr[1],
ptr.Addr[2],
ptr.Addr[3],
)
return udpAddr.String()
default:
return "<unknown address family>"
}
}
func (end *Endpoint) DestinationIP() net.IP {
switch end.dst.Family {
case unix.AF_INET6:
return end.dst.Addr[:]
case unix.AF_INET:
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
return net.IPv4(
ptr.Addr[0],
ptr.Addr[1],
ptr.Addr[2],
ptr.Addr[3],
)
default:
return nil return nil
}
}
func (end *Endpoint) SourceToBytes() []byte {
ptr := unsafe.Pointer(&end.src)
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
return arr[:]
}
func (end *Endpoint) SourceToString() string {
return sockaddrToString(end.src)
}
func (end *Endpoint) DestinationToString() string {
return sockaddrToString(end.dst)
}
func (end *Endpoint) ClearSrc() {
end.src = unix.RawSockaddrInet6{}
} }
func zoneToUint32(zone string) (uint32, error) { func zoneToUint32(zone string) (uint32, error) {
@ -56,7 +179,7 @@ func zoneToUint32(zone string) (uint32, error) {
return uint32(n), err return uint32(n), err
} }
func CreateIPv4Socket(port uint16) (Socket, uint16, error) { func create4(port uint16) (int, uint16, error) {
// create socket // create socket
@ -100,18 +223,10 @@ func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
unix.Close(fd) unix.Close(fd)
} }
return Socket(fd), uint16(addr.Port), err return fd, uint16(addr.Port), err
} }
func CloseIPv4Socket(sock Socket) error { func create6(port uint16) (int, uint16, error) {
return unix.Close(int(sock))
}
func CloseIPv6Socket(sock Socket) error {
return unix.Close(int(sock))
}
func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
// create socket // create socket
@ -166,13 +281,7 @@ func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
unix.Close(fd) unix.Close(fd)
} }
return Socket(fd), uint16(addr.Port), err return fd, uint16(addr.Port), err
}
func (end *Endpoint) ClearSrc() {
end.src4if = 0
end.src4 = unix.RawSockaddrInet4{}
end.src6 = unix.RawSockaddrInet6{}
} }
func (end *Endpoint) Set(s string) error { func (end *Endpoint) Set(s string) error {
@ -187,23 +296,23 @@ func (end *Endpoint) Set(s string) error {
if err != nil { if err != nil {
return err return err
} }
ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst)) dst := &end.dst
ptr.Family = unix.AF_INET6 dst.Family = unix.AF_INET6
ptr.Port = uint16(addr.Port) dst.Port = uint16(addr.Port)
ptr.Flowinfo = 0 dst.Flowinfo = 0
ptr.Scope_id = zone dst.Scope_id = zone
copy(ptr.Addr[:], ipv6[:]) copy(dst.Addr[:], ipv6[:])
end.ClearSrc() end.ClearSrc()
return nil return nil
} }
ipv4 := addr.IP.To4() ipv4 := addr.IP.To4()
if ipv4 != nil { if ipv4 != nil {
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
ptr.Family = unix.AF_INET dst.Family = unix.AF_INET
ptr.Port = uint16(addr.Port) dst.Port = uint16(addr.Port)
ptr.Zero = [8]byte{} dst.Zero = [8]byte{}
copy(ptr.Addr[:], ipv4) copy(dst.Addr[:], ipv4)
end.ClearSrc() end.ClearSrc()
return nil return nil
} }
@ -211,7 +320,7 @@ func (end *Endpoint) Set(s string) error {
return errors.New("Failed to recognize IP address format") return errors.New("Failed to recognize IP address format")
} }
func send6(sock uintptr, end *Endpoint, buff []byte) error { func send6(sock int, end *Endpoint, buff []byte) error {
// construct message header // construct message header
@ -229,8 +338,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
Len: unix.SizeofInet6Pktinfo, Len: unix.SizeofInet6Pktinfo,
}, },
unix.Inet6Pktinfo{ unix.Inet6Pktinfo{
Addr: end.src6.Addr, Addr: end.src.Addr,
Ifindex: end.src6.Scope_id, Ifindex: end.src.Scope_id,
}, },
} }
@ -248,7 +357,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
unix.SYS_SENDMSG, unix.SYS_SENDMSG,
sock, uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)), uintptr(unsafe.Pointer(&msghdr)),
0, 0,
) )
@ -258,7 +367,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
return errno return errno
} }
func send4(sock uintptr, end *Endpoint, buff []byte) error { func send4(sock int, end *Endpoint, buff []byte) error {
// construct message header // construct message header
@ -266,6 +375,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff)) iovec.SetLen(len(buff))
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
cmsg := struct { cmsg := struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo pktinfo unix.Inet4Pktinfo
@ -276,8 +387,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
Len: unix.SizeofInet4Pktinfo, Len: unix.SizeofInet4Pktinfo,
}, },
unix.Inet4Pktinfo{ unix.Inet4Pktinfo{
Spec_dst: end.src4.Addr, Spec_dst: src4.src.Addr,
Ifindex: end.src4if, Ifindex: src4.Ifindex,
}, },
} }
@ -295,7 +406,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
unix.SYS_SENDMSG, unix.SYS_SENDMSG,
sock, uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)), uintptr(unsafe.Pointer(&msghdr)),
0, 0,
) )
@ -305,28 +416,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
return errno return errno
} }
func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error { func receive4(sock int, buff []byte, end *Endpoint) (int, error) {
// extract underlying file descriptor
file, err := c.File()
if err != nil {
return err
}
sock := file.Fd()
// send depending on address family of dst
family := *((*uint16)(unsafe.Pointer(&end.dst)))
if family == unix.AF_INET {
return send4(sock, end, buff)
} else if family == unix.AF_INET6 {
return send6(sock, end, buff)
}
return errors.New("Unknown address family of source")
}
func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
// contruct message header // contruct message header
@ -360,22 +450,21 @@ func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
return 0, errno return 0, errno
} }
fmt.Println(msghdr)
fmt.Println(cmsg)
// update source cache // update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP && if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
end.src4.Addr = cmsg.pktinfo.Spec_dst src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
end.src4if = cmsg.pktinfo.Ifindex src4.src.Family = unix.AF_INET
src4.src.Addr = cmsg.pktinfo.Spec_dst
src4.Ifindex = cmsg.pktinfo.Ifindex
} }
return int(size), nil return int(size), nil
} }
func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) { func receive6(sock int, buff []byte, end *Endpoint) (int, error) {
// contruct message header // contruct message header
@ -414,18 +503,10 @@ func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
end.src6.Addr = cmsg.pktinfo.Addr end.src.Family = unix.AF_INET6
end.src6.Scope_id = cmsg.pktinfo.Ifindex end.src.Addr = cmsg.pktinfo.Addr
end.src.Scope_id = cmsg.pktinfo.Ifindex
} }
return int(size), nil return int(size), nil
} }
func SetMark(sock Socket, value uint32) error {
return unix.SetsockoptInt(
int(sock),
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
}

View File

@ -5,10 +5,8 @@ import (
"crypto/rand" "crypto/rand"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"net"
"sync" "sync"
"time" "time"
"unsafe"
) )
type CookieChecker struct { type CookieChecker struct {
@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2]) return hmac.Equal(mac1[:], msg[smac1:smac2])
} }
func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool { func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
st.mutex.RLock() st.mutex.RLock()
defer st.mutex.RUnlock() defer st.mutex.RUnlock()
@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
var cookie [blake2s.Size128]byte var cookie [blake2s.Size128]byte
func() { func() {
mac, _ := blake2s.New128(st.mac2.secret[:]) mac, _ := blake2s.New128(st.mac2.secret[:])
mac.Write(src.IP) mac.Write(src)
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
mac.Sum(cookie[:0]) mac.Sum(cookie[:0])
}() }()
@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
func (st *CookieChecker) CreateReply( func (st *CookieChecker) CreateReply(
msg []byte, msg []byte,
recv uint32, recv uint32,
src *net.UDPAddr, src []byte,
) (*MessageCookieReply, error) { ) (*MessageCookieReply, error) {
st.mutex.RLock() st.mutex.RLock()
@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply(
var cookie [blake2s.Size128]byte var cookie [blake2s.Size128]byte
func() { func() {
mac, _ := blake2s.New128(st.mac2.secret[:]) mac, _ := blake2s.New128(st.mac2.secret[:])
mac.Write(src.IP) mac.Write(src)
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
mac.Sum(cookie[:0]) mac.Sum(cookie[:0])
}() }()

View File

@ -1,18 +1,14 @@
package main package main
import ( import (
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
) )
type Listener struct {
sock Socket
active bool
update chan struct{}
}
type Device struct { type Device struct {
log *Logger // collection of loggers for levels log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers idCounter uint // for assigning debug ids to peers
@ -27,8 +23,7 @@ type Device struct {
} }
net struct { net struct {
mutex sync.RWMutex mutex sync.RWMutex
ipv4 Listener bind UDPBind
ipv6 Listener
port uint16 port uint16
fwmark uint32 fwmark uint32
} }
@ -43,9 +38,8 @@ type Device struct {
handshake chan QueueHandshakeElement handshake chan QueueHandshakeElement
} }
signal struct { signal struct {
stop chan struct{} // halts all go routines stop chan struct{}
updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine) updateBind chan struct{}
updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
} }
underLoadUntil atomic.Value underLoadUntil atomic.Value
ratelimiter Ratelimiter ratelimiter Ratelimiter
@ -146,8 +140,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.tun.device = tun device.tun.device = tun
device.indices.Init() device.indices.Init()
device.net.ipv4.Init()
device.net.ipv6.Init()
device.ratelimiter.Init() device.ratelimiter.Init()
device.routingTable.Reset() device.routingTable.Reset()
@ -181,8 +173,8 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
go device.RoutineReadFromTUN() go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader() go device.RoutineTUNEventReader()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
go device.RoutineReceiveIncomming(&device.net.ipv4) go device.RoutineReceiveIncomming(ipv4.Version)
go device.RoutineReceiveIncomming(&device.net.ipv6) go device.RoutineReceiveIncomming(ipv6.Version)
return device return device
} }

View File

@ -4,7 +4,6 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"net"
"sync" "sync"
"time" "time"
) )
@ -15,8 +14,8 @@ type Peer struct {
persistentKeepaliveInterval uint64 persistentKeepaliveInterval uint64
keyPairs KeyPairs keyPairs KeyPairs
handshake Handshake handshake Handshake
endpoint Endpoint
device *Device device *Device
endpoint *net.UDPAddr
stats struct { stats struct {
txBytes uint64 // bytes send to peer (endpoint) txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer rxBytes uint64 // bytes received from peer
@ -134,7 +133,7 @@ func (peer *Peer) String() string {
return fmt.Sprintf( return fmt.Sprintf(
"peer(%d %s %s)", "peer(%d %s %s)",
peer.id, peer.id,
peer.endpoint.String(), peer.endpoint.DestinationToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
) )
} }

View File

@ -97,17 +97,6 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, started") logDebug.Println("Routine, receive incomming, started")
var listener *Listener
switch IPVersion {
case ipv4.Version:
listener = &device.net.ipv4
case ipv6.Version:
listener = &device.net.ipv6
default:
return
}
for { for {
// wait for new conn // wait for new conn
@ -118,15 +107,14 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
case <-device.signal.stop: case <-device.signal.stop:
return return
case <-listener.update: case <-device.signal.updateBind:
// fetch new socket // fetch new socket
device.net.mutex.RLock() device.net.mutex.RLock()
sock := listener.sock bind := device.net.bind
okay := listener.active
device.net.mutex.RUnlock() device.net.mutex.RUnlock()
if !okay { if bind == nil {
continue continue
} }
@ -145,10 +133,13 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
var endpoint Endpoint var endpoint Endpoint
if IPVersion == ipv6.Version { switch IPVersion {
size, err = endpoint.ReceiveIPv4(sock, buffer[:]) case ipv4.Version:
} else { size, err = bind.ReceiveIPv4(buffer[:], &endpoint)
size, err = endpoint.ReceiveIPv6(sock, buffer[:]) case ipv6.Version:
size, err = bind.ReceiveIPv6(buffer[:], &endpoint)
default:
return
} }
if err != nil { if err != nil {
@ -340,15 +331,19 @@ func (device *Device) RoutineHandshake() {
return return
} }
srcBytes := elem.endpoint.SourceToBytes()
if device.IsUnderLoad() { if device.IsUnderLoad() {
if !device.mac.CheckMAC2(elem.packet, elem.source) {
// verify MAC2 field
if !device.mac.CheckMAC2(elem.packet, srcBytes) {
// construct cookie reply // construct cookie reply
logDebug.Println("Sending cookie reply to:", elem.source.String()) logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString())
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
reply, err := device.mac.CreateReply(elem.packet, sender, elem.source) reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
if err != nil { if err != nil {
logError.Println("Failed to create cookie reply:", err) logError.Println("Failed to create cookie reply:", err)
return return
@ -358,9 +353,9 @@ func (device *Device) RoutineHandshake() {
writer := bytes.NewBuffer(temp[:0]) writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, reply) binary.Write(writer, binary.LittleEndian, reply)
_, err = device.net.conn.WriteToUDP( device.net.bind.Send(
writer.Bytes(), writer.Bytes(),
elem.source, &elem.endpoint,
) )
if err != nil { if err != nil {
logDebug.Println("Failed to send cookie reply:", err) logDebug.Println("Failed to send cookie reply:", err)
@ -368,7 +363,11 @@ func (device *Device) RoutineHandshake() {
continue continue
} }
if !device.ratelimiter.Allow(elem.source.IP) { // check ratelimiter
if !device.ratelimiter.Allow(
elem.endpoint.DestinationIP(),
) {
continue continue
} }
} }
@ -399,8 +398,7 @@ func (device *Device) RoutineHandshake() {
if peer == nil { if peer == nil {
logInfo.Println( logInfo.Println(
"Recieved invalid initiation message from", "Recieved invalid initiation message from",
elem.source.IP.String(), elem.endpoint.DestinationToString(),
elem.source.Port,
) )
continue continue
} }
@ -414,7 +412,7 @@ func (device *Device) RoutineHandshake() {
// TODO: Discover destination address also, only update on change // TODO: Discover destination address also, only update on change
peer.mutex.Lock() peer.mutex.Lock()
peer.endpoint = elem.source peer.endpoint = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
// create response // create response
@ -460,8 +458,7 @@ func (device *Device) RoutineHandshake() {
if peer == nil { if peer == nil {
logInfo.Println( logInfo.Println(
"Recieved invalid response message from", "Recieved invalid response message from",
elem.source.IP.String(), elem.endpoint.DestinationToString(),
elem.source.Port,
) )
continue continue
} }