diff --git a/conn.go b/conn.go index b8970e7..e38160a 100644 --- a/conn.go +++ b/conn.go @@ -20,8 +20,8 @@ const ( */ type Bind interface { SetMark(value uint32) error - ReceiveIPv6(buff []byte) (int, Endpoint, error) - ReceiveIPv4(buff []byte) (int, Endpoint, error) + ReceiveIPv6(buff []byte) (int, Endpoint, byte, error) + ReceiveIPv4(buff []byte) (int, Endpoint, byte, error) Send(buff []byte, end Endpoint, tos byte) error Close() error } diff --git a/conn_default.go b/conn_default.go index 6f17de5..1b25863 100644 --- a/conn_default.go +++ b/conn_default.go @@ -133,26 +133,29 @@ func (bind *NativeBind) Close() error { return err2 } -func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { +// TODO: implement TOS +func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, byte, error) { if bind.ipv4 == nil { - return 0, nil, syscall.EAFNOSUPPORT + return 0, nil, 0, syscall.EAFNOSUPPORT } n, endpoint, err := bind.ipv4.ReadFromUDP(buff) if endpoint != nil { endpoint.IP = endpoint.IP.To4() } - return n, (*NativeEndpoint)(endpoint), err + return n, (*NativeEndpoint)(endpoint), 0, err } -func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { +// TODO: implement TOS +func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, byte, error) { if bind.ipv6 == nil { - return 0, nil, syscall.EAFNOSUPPORT + return 0, nil, 0, syscall.EAFNOSUPPORT } n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err + return n, (*NativeEndpoint)(endpoint), 0, err } -func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error { +// TODO: implement TOS +func (bind *NativeBind) Send(buff []byte, endpoint Endpoint, tos byte) error { var err error nend := endpoint.(*NativeEndpoint) if nend.IP.To4() != nil { diff --git a/conn_linux.go b/conn_linux.go index 83cf1a2..cc1ce2e 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -232,30 +232,32 @@ func (bind *NativeBind) Close() error { return err3 } -func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { +func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, byte, error) { var end NativeEndpoint + var tos byte if bind.sock6 == -1 { - return 0, nil, syscall.EAFNOSUPPORT + return 0, nil, tos, syscall.EAFNOSUPPORT } - n, err := receive6( + n, tos, err := receive6( bind.sock6, buff, &end, ) - return n, &end, err + return n, &end, tos, err } -func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { +func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, byte, error) { var end NativeEndpoint + var tos byte if bind.sock4 == -1 { - return 0, nil, syscall.EAFNOSUPPORT + return 0, nil, tos, syscall.EAFNOSUPPORT } - n, err := receive4( + n, tos, err := receive4( bind.sock4, buff, &end, ) - return n, &end, err + return n, &end, tos, err } func (bind *NativeBind) Send(buff []byte, end Endpoint, tos byte) error { @@ -384,6 +386,15 @@ func create4(port uint16) (int, uint16, error) { return err } + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IP, + unix.IP_RECVTOS, + 1, + ); err != nil { + return err + } + return unix.Bind(fd, &addr) }(); err != nil { unix.Close(fd) @@ -442,6 +453,15 @@ func create6(port uint16) (int, uint16, error) { return err } + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_RECVTCLASS, + 1, + ); err != nil { + return err + } + return unix.Bind(fd, &addr) }(); err != nil { @@ -452,12 +472,13 @@ func create6(port uint16) (int, uint16, error) { return fd, uint16(addr.Port), err } +type ipTos struct { + tos byte +} + func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error { // construct message header - type ipTos struct { - tos byte - } cmsg := struct { cmsghdr unix.Cmsghdr @@ -505,9 +526,6 @@ func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error { func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error { // construct message header - type ipTos struct { - tos byte - } cmsg := struct { cmsghdr unix.Cmsghdr @@ -555,19 +573,21 @@ func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error { return err } -func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { +func receive4(sock int, buff []byte, end *NativeEndpoint) (int, byte, error) { // contruct message header var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo + cmsghdr unix.Cmsghdr + pktinfo unix.Inet4Pktinfo + cmsghdr2 unix.Cmsghdr + iptos ipTos } size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) if err != nil { - return 0, err + return 0, 0, err } end.isV6 = false @@ -576,7 +596,6 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { } // update source cache - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { @@ -584,22 +603,31 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { end.src4().ifindex = cmsg.pktinfo.Ifindex } - return size, nil + tos := byte(0) + if cmsg.cmsghdr2.Level == unix.IPPROTO_IP && + cmsg.cmsghdr2.Type == unix.IP_TOS && + cmsg.cmsghdr2.Len >= 1 { + tos = cmsg.iptos.tos + } + + return size, tos, nil } -func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { +func receive6(sock int, buff []byte, end *NativeEndpoint) (int, byte, error) { // contruct message header var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo + cmsghdr2 unix.Cmsghdr + iptos ipTos } size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) if err != nil { - return 0, err + return 0, 0, err } end.isV6 = true @@ -616,7 +644,14 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { end.dst6().ZoneId = cmsg.pktinfo.Ifindex } - return size, nil + tos := byte(0) + if cmsg.cmsghdr2.Level == unix.IPPROTO_IPV6 && + cmsg.cmsghdr2.Type == unix.IPV6_TCLASS && + cmsg.cmsghdr2.Len >= 1 { + tos = cmsg.iptos.tos + } + + return size, tos, nil } func (bind *NativeBind) routineRouteListener(device *Device) { diff --git a/misc.go b/misc.go index 6786cb5..e5688a5 100644 --- a/misc.go +++ b/misc.go @@ -46,3 +46,62 @@ func min(a, b uint) uint { } return a } + +// called from receive +func ecn_rfc6040_egress(inner byte, outer byte) (byte, bool) { + /* + +---------+------------------------------------------------+ + |Arriving | Arriving Outer Header | + | Inner +---------+------------+------------+------------+ + | Header | Not-ECT | ECT(0) | ECT(1) | CE | + +---------+---------+------------+------------+------------+ + | Not-ECT | Not-ECT |Not-ECT(!!!)|Not-ECT(!!!)| (!!!)| + | ECT(0) | ECT(0) | ECT(0) | ECT(1) | CE | + | ECT(1) | ECT(1) | ECT(1) (!) | ECT(1) | CE | + | CE | CE | CE | CE(!!!)| CE | + +---------+---------+------------+------------+------------+ + */ + innerECN := CongestionExperienced & inner + outerECN := CongestionExperienced & outer + + switch outerECN { + case CongestionExperienced: + switch innerECN { + case NotECNTransport: + return 0, true + } + return (inner & (CongestionExperienced ^ 255)) | CongestionExperienced, false + case ECNTransport1: + switch innerECN { + case ECNTransport0: + return (inner & (CongestionExperienced ^ 255)) | ECNTransport1, false + } + } + return inner, false +} + +// called from send +func ecn_rfc6040_ingress(inner byte, useNormalMode bool) byte { + /* + +-----------------+-------------------------------+ + | Incoming Header | Departing Outer Header | + | (also equal to +---------------+---------------+ + | departing Inner | Compatibility | Normal | + | Header) | Mode | Mode | + +-----------------+---------------+---------------+ + | Not-ECT | Not-ECT | Not-ECT | + | ECT(0) | Not-ECT | ECT(0) | + | ECT(1) | Not-ECT | ECT(1) | + | CE | Not-ECT | CE | + +-----------------+---------------+---------------+ + */ + if !useNormalMode { + inner &= (CongestionExperienced ^ 255) + } + + return inner +} + +func ecn_rfc6040_enabled(tos byte) bool { + return (CongestionExperienced & tos) == ECNTransport0 +} diff --git a/peer.go b/peer.go index 96cfa61..642a0ee 100644 --- a/peer.go +++ b/peer.go @@ -15,6 +15,14 @@ import ( const ( PeerRoutineNumber = 3 + + DiffServAF41 = 0x88 // AF41 + NotECNTransport = 0x00 // Not-ECT (Not ECN-Capable Transport) + ECNTransport1 = 0x01 // ECT(1) (ECN-Capable Transport(1)) + ECNTransport0 = 0x02 // ECT(0) (ECN-Capable Transport(0)) + CongestionExperienced = 0x03 // CE (Congestion Experienced) + + HandshakeDSCP = DiffServAF41 | ECNTransport0 // AF41, plus 10 ECN ) type Peer struct { @@ -25,6 +33,7 @@ type Peer struct { device *Device endpoint Endpoint persistentKeepaliveInterval uint16 + isECNConfirmed AtomicBool // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly stats struct { diff --git a/receive.go b/receive.go index fb848eb..03dbd4b 100644 --- a/receive.go +++ b/receive.go @@ -23,6 +23,7 @@ type QueueHandshakeElement struct { packet []byte endpoint Endpoint buffer *[MaxMessageSize]byte + isECNCompatible bool } type QueueInboundElement struct { @@ -33,6 +34,7 @@ type QueueInboundElement struct { counter uint64 keypair *Keypair endpoint Endpoint + tos byte } func (elem *QueueInboundElement) Drop() { @@ -108,6 +110,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { err error size int endpoint Endpoint + outerTOS byte ) for { @@ -116,9 +119,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { switch IP { case ipv4.Version: - size, endpoint, err = bind.ReceiveIPv4(buffer[:]) + size, endpoint, outerTOS, err = bind.ReceiveIPv4(buffer[:]) case ipv6.Version: - size, endpoint, err = bind.ReceiveIPv6(buffer[:]) + size, endpoint, outerTOS, err = bind.ReceiveIPv6(buffer[:]) default: panic("invalid IP version") } @@ -178,6 +181,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { elem.endpoint = endpoint elem.counter = 0 elem.Mutex = sync.Mutex{} + elem.tos = outerTOS elem.Lock() // add to decryption queues @@ -213,6 +217,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { buffer: buffer, packet: packet, endpoint: endpoint, + isECNCompatible: ecn_rfc6040_enabled(outerTOS), }, )) { buffer = device.GetMessageBuffer() @@ -426,7 +431,7 @@ func (device *Device) RoutineHandshake() { peer.SetEndpointFromPacket(elem.endpoint) logDebug.Println(peer, "- Received handshake initiation") - + peer.isECNConfirmed.Set(elem.isECNCompatible) peer.SendHandshakeResponse() case MessageResponseType: @@ -473,6 +478,7 @@ func (device *Device) RoutineHandshake() { peer.timersSessionDerived() peer.timersHandshakeComplete() + peer.isECNConfirmed.Set(elem.isECNCompatible) peer.SendKeepalive() select { case peer.signals.newKeypairArrived <- struct{}{}: @@ -565,6 +571,7 @@ func (peer *Peer) RoutineSequentialReceiver() { } peer.timersDataReceived() + var shouldDrop bool // verify source and strip padding switch elem.packet[0] >> 4 { @@ -595,6 +602,7 @@ func (peer *Peer) RoutineSequentialReceiver() { continue } + elem.tos, shouldDrop = ecn_rfc6040_egress(elem.packet[1], elem.tos) case ipv6.Version: // strip padding @@ -623,10 +631,15 @@ func (peer *Peer) RoutineSequentialReceiver() { continue } + elem.tos, shouldDrop = ecn_rfc6040_egress(elem.packet[1], elem.tos); default: logInfo.Println("Packet with invalid IP version from", peer) continue } + if shouldDrop { + logInfo.Println("ECN/Congestion detected, dropping packet from", peer) + continue + } // write to tun device diff --git a/send.go b/send.go index 57bb67b..f787027 100644 --- a/send.go +++ b/send.go @@ -41,10 +41,6 @@ import ( * (to allow the construction of transport messages in-place) */ -const ( - HandshakeDSCP = 0x88 // AF41, plus 00 ECN -) - type QueueOutboundElement struct { dropped int32 sync.Mutex @@ -299,14 +295,20 @@ func (device *Device) RoutineReadFromTUN() { } dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] peer = device.allowedips.LookupIPv4(dst) - elem.tos = elem.packet[1]; + if peer == nil { + continue + } + elem.tos = ecn_rfc6040_ingress(elem.packet[1], peer.isECNConfirmed.Get()) case ipv6.Version: if len(elem.packet) < ipv6.HeaderLen { continue } dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] peer = device.allowedips.LookupIPv6(dst) - elem.tos = elem.packet[1]; + if peer == nil { + continue + } + elem.tos = ecn_rfc6040_ingress(elem.packet[1], peer.isECNConfirmed.Get()) default: logDebug.Println("Received packet with unknown IP version") }