diff --git a/src/conn.go b/src/conn.go index 61be3bf..db4020d 100644 --- a/src/conn.go +++ b/src/conn.go @@ -5,6 +5,14 @@ import ( "net" ) +type UDPBind interface { + SetMark(value uint32) error + ReceiveIPv6(buff []byte, end *Endpoint) (int, error) + ReceiveIPv4(buff []byte, end *Endpoint) (int, error) + Send(buff []byte, end *Endpoint) error + Close() error +} + func parseEndpoint(s string) (*net.UDPAddr, error) { // ensure that the host is an IP address @@ -26,19 +34,6 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { return addr, err } -func ListenerClose(l *Listener) (err error) { - if l.active { - err = CloseIPv4Socket(l.sock) - l.active = false - } - return -} - -func (l *Listener) Init() { - l.update = make(chan struct{}, 1) - ListenerClose(l) -} - func ListeningUpdate(device *Device) error { netc := &device.net netc.mutex.Lock() @@ -46,11 +41,7 @@ func ListeningUpdate(device *Device) error { // close existing sockets - if err := ListenerClose(&netc.ipv4); err != nil { - return err - } - - if err := ListenerClose(&netc.ipv6); err != nil { + if err := device.net.bind.Close(); err != nil { return err } @@ -58,45 +49,22 @@ func ListeningUpdate(device *Device) error { if device.tun.isUp.Get() { - // listen on IPv4 + // bind to new port - { - list := &netc.ipv6 - sock, port, err := CreateIPv4Socket(netc.port) - if err != nil { - return err - } - netc.port = port - list.sock = sock - list.active = true - - if err := SetMark(list.sock, netc.fwmark); err != nil { - ListenerClose(list) - return err - } - signalSend(list.update) + var err error + netc.bind, netc.port, err = CreateUDPBind(netc.port) + if err != nil { + return err } - // listen on IPv6 + // set mark - { - list := &netc.ipv6 - sock, port, err := CreateIPv6Socket(netc.port) - if err != nil { - return err - } - netc.port = port - list.sock = sock - list.active = true - - if err := SetMark(list.sock, netc.fwmark); err != nil { - ListenerClose(list) - return err - } - signalSend(list.update) + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err } - // TODO: clear endpoint caches + // TODO: clear endpoint (src) caches } return nil @@ -106,16 +74,5 @@ func ListeningClose(device *Device) error { netc := &device.net netc.mutex.Lock() defer netc.mutex.Unlock() - - if err := ListenerClose(&netc.ipv4); err != nil { - return err - } - signalSend(netc.ipv4.update) - - if err := ListenerClose(&netc.ipv6); err != nil { - return err - } - signalSend(netc.ipv6.update) - - return nil + return netc.bind.Close() } diff --git a/src/conn_linux.go b/src/conn_linux.go index 034fb8b..8942b03 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -14,35 +14,158 @@ import ( "unsafe" ) -import "fmt" - /* Supports source address caching * * Currently there is no way to achieve this within the net package: * See e.g. https://github.com/golang/go/issues/17930 - * So this code is platform dependent. - * - * It is important that the endpoint is only updated after the packet content has been authenticated! + * So this code is remains platform dependent. */ type Endpoint struct { - // source (selected based on dst type) - // (could use RawSockaddrAny and unsafe) - // TODO: Merge - src6 unix.RawSockaddrInet6 - src4 unix.RawSockaddrInet4 - src4if int32 - - dst unix.RawSockaddrAny + src unix.RawSockaddrInet6 + dst unix.RawSockaddrInet6 } -type Socket int +type IPv4Source struct { + src unix.RawSockaddrInet4 + Ifindex int32 +} -/* Returns a byte representation of the source field(s) - * for use in "under load" cookie computations. - */ -func (endpoint *Endpoint) Source() []byte { - return nil +type Bind struct { + sock4 int + sock6 int +} + +func CreateUDPBind(port uint16) (UDPBind, uint16, error) { + var err error + var bind Bind + + bind.sock6, port, err = create6(port) + if err != nil { + return nil, port, err + } + + bind.sock4, port, err = create4(port) + if err != nil { + unix.Close(bind.sock6) + } + return &bind, port, err +} + +func (bind *Bind) SetMark(value uint32) error { + err := unix.SetsockoptInt( + bind.sock6, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) + + if err != nil { + return err + } + + return unix.SetsockoptInt( + bind.sock4, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) +} + +func (bind *Bind) Close() error { + err1 := unix.Close(bind.sock6) + err2 := unix.Close(bind.sock4) + if err1 != nil { + return err1 + } + return err2 +} + +func (bind *Bind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { + return receive6( + bind.sock6, + buff, + end, + ) +} + +func (bind *Bind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { + return receive4( + bind.sock4, + buff, + end, + ) +} + +func (bind *Bind) Send(buff []byte, end *Endpoint) error { + switch end.src.Family { + case unix.AF_INET6: + return send6(bind.sock6, end, buff) + case unix.AF_INET: + return send4(bind.sock4, end, buff) + default: + return errors.New("Unknown address family of source") + } +} + +func sockaddrToString(addr unix.RawSockaddrInet6) string { + var udpAddr net.UDPAddr + + switch addr.Family { + case unix.AF_INET6: + udpAddr.Port = int(addr.Port) + udpAddr.IP = addr.Addr[:] + return udpAddr.String() + + case unix.AF_INET: + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) + udpAddr.Port = int(ptr.Port) + udpAddr.IP = net.IPv4( + ptr.Addr[0], + ptr.Addr[1], + ptr.Addr[2], + ptr.Addr[3], + ) + return udpAddr.String() + + default: + return "" + } +} + +func (end *Endpoint) DestinationIP() net.IP { + switch end.dst.Family { + case unix.AF_INET6: + return end.dst.Addr[:] + case unix.AF_INET: + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + return net.IPv4( + ptr.Addr[0], + ptr.Addr[1], + ptr.Addr[2], + ptr.Addr[3], + ) + default: + return nil + } +} + +func (end *Endpoint) SourceToBytes() []byte { + ptr := unsafe.Pointer(&end.src) + arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) + return arr[:] +} + +func (end *Endpoint) SourceToString() string { + return sockaddrToString(end.src) +} + +func (end *Endpoint) DestinationToString() string { + return sockaddrToString(end.dst) +} + +func (end *Endpoint) ClearSrc() { + end.src = unix.RawSockaddrInet6{} } func zoneToUint32(zone string) (uint32, error) { @@ -56,7 +179,7 @@ func zoneToUint32(zone string) (uint32, error) { return uint32(n), err } -func CreateIPv4Socket(port uint16) (Socket, uint16, error) { +func create4(port uint16) (int, uint16, error) { // create socket @@ -100,18 +223,10 @@ func CreateIPv4Socket(port uint16) (Socket, uint16, error) { unix.Close(fd) } - return Socket(fd), uint16(addr.Port), err + return fd, uint16(addr.Port), err } -func CloseIPv4Socket(sock Socket) error { - return unix.Close(int(sock)) -} - -func CloseIPv6Socket(sock Socket) error { - return unix.Close(int(sock)) -} - -func CreateIPv6Socket(port uint16) (Socket, uint16, error) { +func create6(port uint16) (int, uint16, error) { // create socket @@ -166,13 +281,7 @@ func CreateIPv6Socket(port uint16) (Socket, uint16, error) { unix.Close(fd) } - return Socket(fd), uint16(addr.Port), err -} - -func (end *Endpoint) ClearSrc() { - end.src4if = 0 - end.src4 = unix.RawSockaddrInet4{} - end.src6 = unix.RawSockaddrInet6{} + return fd, uint16(addr.Port), err } func (end *Endpoint) Set(s string) error { @@ -187,23 +296,23 @@ func (end *Endpoint) Set(s string) error { if err != nil { return err } - ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst)) - ptr.Family = unix.AF_INET6 - ptr.Port = uint16(addr.Port) - ptr.Flowinfo = 0 - ptr.Scope_id = zone - copy(ptr.Addr[:], ipv6[:]) + dst := &end.dst + dst.Family = unix.AF_INET6 + dst.Port = uint16(addr.Port) + dst.Flowinfo = 0 + dst.Scope_id = zone + copy(dst.Addr[:], ipv6[:]) end.ClearSrc() return nil } ipv4 := addr.IP.To4() if ipv4 != nil { - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) - ptr.Family = unix.AF_INET - ptr.Port = uint16(addr.Port) - ptr.Zero = [8]byte{} - copy(ptr.Addr[:], ipv4) + dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + dst.Family = unix.AF_INET + dst.Port = uint16(addr.Port) + dst.Zero = [8]byte{} + copy(dst.Addr[:], ipv4) end.ClearSrc() return nil } @@ -211,7 +320,7 @@ func (end *Endpoint) Set(s string) error { return errors.New("Failed to recognize IP address format") } -func send6(sock uintptr, end *Endpoint, buff []byte) error { +func send6(sock int, end *Endpoint, buff []byte) error { // construct message header @@ -229,8 +338,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { Len: unix.SizeofInet6Pktinfo, }, unix.Inet6Pktinfo{ - Addr: end.src6.Addr, - Ifindex: end.src6.Scope_id, + Addr: end.src.Addr, + Ifindex: end.src.Scope_id, }, } @@ -248,7 +357,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { _, _, errno := unix.Syscall( unix.SYS_SENDMSG, - sock, + uintptr(sock), uintptr(unsafe.Pointer(&msghdr)), 0, ) @@ -258,7 +367,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { return errno } -func send4(sock uintptr, end *Endpoint, buff []byte) error { +func send4(sock int, end *Endpoint, buff []byte) error { // construct message header @@ -266,6 +375,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.SetLen(len(buff)) + src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) + cmsg := struct { cmsghdr unix.Cmsghdr pktinfo unix.Inet4Pktinfo @@ -276,8 +387,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { Len: unix.SizeofInet4Pktinfo, }, unix.Inet4Pktinfo{ - Spec_dst: end.src4.Addr, - Ifindex: end.src4if, + Spec_dst: src4.src.Addr, + Ifindex: src4.Ifindex, }, } @@ -295,7 +406,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { _, _, errno := unix.Syscall( unix.SYS_SENDMSG, - sock, + uintptr(sock), uintptr(unsafe.Pointer(&msghdr)), 0, ) @@ -305,28 +416,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { return errno } -func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error { - - // extract underlying file descriptor - - file, err := c.File() - if err != nil { - return err - } - sock := file.Fd() - - // send depending on address family of dst - - family := *((*uint16)(unsafe.Pointer(&end.dst))) - if family == unix.AF_INET { - return send4(sock, end, buff) - } else if family == unix.AF_INET6 { - return send6(sock, end, buff) - } - return errors.New("Unknown address family of source") -} - -func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) { +func receive4(sock int, buff []byte, end *Endpoint) (int, error) { // contruct message header @@ -360,22 +450,21 @@ func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) { return 0, errno } - fmt.Println(msghdr) - fmt.Println(cmsg) - // update source cache if cmsg.cmsghdr.Level == unix.IPPROTO_IP && cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4.Addr = cmsg.pktinfo.Spec_dst - end.src4if = cmsg.pktinfo.Ifindex + src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) + src4.src.Family = unix.AF_INET + src4.src.Addr = cmsg.pktinfo.Spec_dst + src4.Ifindex = cmsg.pktinfo.Ifindex } return int(size), nil } -func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) { +func receive6(sock int, buff []byte, end *Endpoint) (int, error) { // contruct message header @@ -414,18 +503,10 @@ func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) { if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src6.Addr = cmsg.pktinfo.Addr - end.src6.Scope_id = cmsg.pktinfo.Ifindex + end.src.Family = unix.AF_INET6 + end.src.Addr = cmsg.pktinfo.Addr + end.src.Scope_id = cmsg.pktinfo.Ifindex } return int(size), nil } - -func SetMark(sock Socket, value uint32) error { - return unix.SetsockoptInt( - int(sock), - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) -} diff --git a/src/cookie.go b/src/cookie.go index a81819b..a13ad49 100644 --- a/src/cookie.go +++ b/src/cookie.go @@ -5,10 +5,8 @@ import ( "crypto/rand" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" - "net" "sync" "time" - "unsafe" ) type CookieChecker struct { @@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool { return hmac.Equal(mac1[:], msg[smac1:smac2]) } -func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool { +func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { st.mutex.RLock() defer st.mutex.RUnlock() @@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool { var cookie [blake2s.Size128]byte func() { mac, _ := blake2s.New128(st.mac2.secret[:]) - mac.Write(src.IP) - mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:]) + mac.Write(src) mac.Sum(cookie[:0]) }() @@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool { func (st *CookieChecker) CreateReply( msg []byte, recv uint32, - src *net.UDPAddr, + src []byte, ) (*MessageCookieReply, error) { st.mutex.RLock() @@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply( var cookie [blake2s.Size128]byte func() { mac, _ := blake2s.New128(st.mac2.secret[:]) - mac.Write(src.IP) - mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:]) + mac.Write(src) mac.Sum(cookie[:0]) }() diff --git a/src/device.go b/src/device.go index 509e6a7..d1e0685 100644 --- a/src/device.go +++ b/src/device.go @@ -1,18 +1,14 @@ package main import ( + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "runtime" "sync" "sync/atomic" "time" ) -type Listener struct { - sock Socket - active bool - update chan struct{} -} - type Device struct { log *Logger // collection of loggers for levels idCounter uint // for assigning debug ids to peers @@ -27,8 +23,7 @@ type Device struct { } net struct { mutex sync.RWMutex - ipv4 Listener - ipv6 Listener + bind UDPBind port uint16 fwmark uint32 } @@ -43,9 +38,8 @@ type Device struct { handshake chan QueueHandshakeElement } signal struct { - stop chan struct{} // halts all go routines - updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine) - updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine) + stop chan struct{} + updateBind chan struct{} } underLoadUntil atomic.Value ratelimiter Ratelimiter @@ -146,8 +140,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.tun.device = tun device.indices.Init() - device.net.ipv4.Init() - device.net.ipv6.Init() device.ratelimiter.Init() device.routingTable.Reset() @@ -181,8 +173,8 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) - go device.RoutineReceiveIncomming(&device.net.ipv4) - go device.RoutineReceiveIncomming(&device.net.ipv6) + go device.RoutineReceiveIncomming(ipv4.Version) + go device.RoutineReceiveIncomming(ipv6.Version) return device } diff --git a/src/peer.go b/src/peer.go index 6fea829..791c091 100644 --- a/src/peer.go +++ b/src/peer.go @@ -4,7 +4,6 @@ import ( "encoding/base64" "errors" "fmt" - "net" "sync" "time" ) @@ -15,8 +14,8 @@ type Peer struct { persistentKeepaliveInterval uint64 keyPairs KeyPairs handshake Handshake + endpoint Endpoint device *Device - endpoint *net.UDPAddr stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer @@ -134,7 +133,7 @@ func (peer *Peer) String() string { return fmt.Sprintf( "peer(%d %s %s)", peer.id, - peer.endpoint.String(), + peer.endpoint.DestinationToString(), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } diff --git a/src/receive.go b/src/receive.go index 60c0f2c..664f1ba 100644 --- a/src/receive.go +++ b/src/receive.go @@ -97,17 +97,6 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) { logDebug := device.log.Debug logDebug.Println("Routine, receive incomming, started") - var listener *Listener - - switch IPVersion { - case ipv4.Version: - listener = &device.net.ipv4 - case ipv6.Version: - listener = &device.net.ipv6 - default: - return - } - for { // wait for new conn @@ -118,15 +107,14 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) { case <-device.signal.stop: return - case <-listener.update: + case <-device.signal.updateBind: // fetch new socket device.net.mutex.RLock() - sock := listener.sock - okay := listener.active + bind := device.net.bind device.net.mutex.RUnlock() - if !okay { + if bind == nil { continue } @@ -145,10 +133,13 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) { var endpoint Endpoint - if IPVersion == ipv6.Version { - size, err = endpoint.ReceiveIPv4(sock, buffer[:]) - } else { - size, err = endpoint.ReceiveIPv6(sock, buffer[:]) + switch IPVersion { + case ipv4.Version: + size, err = bind.ReceiveIPv4(buffer[:], &endpoint) + case ipv6.Version: + size, err = bind.ReceiveIPv6(buffer[:], &endpoint) + default: + return } if err != nil { @@ -340,15 +331,19 @@ func (device *Device) RoutineHandshake() { return } + srcBytes := elem.endpoint.SourceToBytes() if device.IsUnderLoad() { - if !device.mac.CheckMAC2(elem.packet, elem.source) { + + // verify MAC2 field + + if !device.mac.CheckMAC2(elem.packet, srcBytes) { // construct cookie reply - logDebug.Println("Sending cookie reply to:", elem.source.String()) + logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString()) sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" - reply, err := device.mac.CreateReply(elem.packet, sender, elem.source) + reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) if err != nil { logError.Println("Failed to create cookie reply:", err) return @@ -358,9 +353,9 @@ func (device *Device) RoutineHandshake() { writer := bytes.NewBuffer(temp[:0]) binary.Write(writer, binary.LittleEndian, reply) - _, err = device.net.conn.WriteToUDP( + device.net.bind.Send( writer.Bytes(), - elem.source, + &elem.endpoint, ) if err != nil { logDebug.Println("Failed to send cookie reply:", err) @@ -368,7 +363,11 @@ func (device *Device) RoutineHandshake() { continue } - if !device.ratelimiter.Allow(elem.source.IP) { + // check ratelimiter + + if !device.ratelimiter.Allow( + elem.endpoint.DestinationIP(), + ) { continue } } @@ -399,8 +398,7 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid initiation message from", - elem.source.IP.String(), - elem.source.Port, + elem.endpoint.DestinationToString(), ) continue } @@ -414,7 +412,7 @@ func (device *Device) RoutineHandshake() { // TODO: Discover destination address also, only update on change peer.mutex.Lock() - peer.endpoint = elem.source + peer.endpoint = elem.endpoint peer.mutex.Unlock() // create response @@ -460,8 +458,7 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid response message from", - elem.source.IP.String(), - elem.source.Port, + elem.endpoint.DestinationToString(), ) continue }