1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

net: implement ECN handling, rfc6040 style

To decide whether we should use the compatibility mode or the normal
mode with a peer, we use the handshake messages as a signaling channel.

If we receive the expected ECN bits, it most likely means they're
running a compatible version.

Signed-off-by: Florent Daigniere <nextgens@freenetproject.org>
This commit is contained in:
Florent Daigniere 2019-02-23 21:50:04 +01:00
parent 9e686cd714
commit 0c2d06d8a5
No known key found for this signature in database
GPG Key ID: EAC5EBF07AA9C2A3
7 changed files with 164 additions and 43 deletions

View File

@ -20,8 +20,8 @@ const (
*/ */
type Bind interface { type Bind interface {
SetMark(value uint32) error SetMark(value uint32) error
ReceiveIPv6(buff []byte) (int, Endpoint, error) ReceiveIPv6(buff []byte) (int, Endpoint, byte, error)
ReceiveIPv4(buff []byte) (int, Endpoint, error) ReceiveIPv4(buff []byte) (int, Endpoint, byte, error)
Send(buff []byte, end Endpoint, tos byte) error Send(buff []byte, end Endpoint, tos byte) error
Close() error Close() error
} }

View File

@ -133,26 +133,29 @@ func (bind *NativeBind) Close() error {
return err2 return err2
} }
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { // TODO: implement TOS
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, byte, error) {
if bind.ipv4 == nil { if bind.ipv4 == nil {
return 0, nil, syscall.EAFNOSUPPORT return 0, nil, 0, syscall.EAFNOSUPPORT
} }
n, endpoint, err := bind.ipv4.ReadFromUDP(buff) n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
if endpoint != nil { if endpoint != nil {
endpoint.IP = endpoint.IP.To4() endpoint.IP = endpoint.IP.To4()
} }
return n, (*NativeEndpoint)(endpoint), err return n, (*NativeEndpoint)(endpoint), 0, err
} }
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { // TODO: implement TOS
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, byte, error) {
if bind.ipv6 == nil { if bind.ipv6 == nil {
return 0, nil, syscall.EAFNOSUPPORT return 0, nil, 0, syscall.EAFNOSUPPORT
} }
n, endpoint, err := bind.ipv6.ReadFromUDP(buff) n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
return n, (*NativeEndpoint)(endpoint), err return n, (*NativeEndpoint)(endpoint), 0, err
} }
func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error { // TODO: implement TOS
func (bind *NativeBind) Send(buff []byte, endpoint Endpoint, tos byte) error {
var err error var err error
nend := endpoint.(*NativeEndpoint) nend := endpoint.(*NativeEndpoint)
if nend.IP.To4() != nil { if nend.IP.To4() != nil {

View File

@ -232,30 +232,32 @@ func (bind *NativeBind) Close() error {
return err3 return err3
} }
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, byte, error) {
var end NativeEndpoint var end NativeEndpoint
var tos byte
if bind.sock6 == -1 { if bind.sock6 == -1 {
return 0, nil, syscall.EAFNOSUPPORT return 0, nil, tos, syscall.EAFNOSUPPORT
} }
n, err := receive6( n, tos, err := receive6(
bind.sock6, bind.sock6,
buff, buff,
&end, &end,
) )
return n, &end, err return n, &end, tos, err
} }
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, byte, error) {
var end NativeEndpoint var end NativeEndpoint
var tos byte
if bind.sock4 == -1 { if bind.sock4 == -1 {
return 0, nil, syscall.EAFNOSUPPORT return 0, nil, tos, syscall.EAFNOSUPPORT
} }
n, err := receive4( n, tos, err := receive4(
bind.sock4, bind.sock4,
buff, buff,
&end, &end,
) )
return n, &end, err return n, &end, tos, err
} }
func (bind *NativeBind) Send(buff []byte, end Endpoint, tos byte) error { func (bind *NativeBind) Send(buff []byte, end Endpoint, tos byte) error {
@ -384,6 +386,15 @@ func create4(port uint16) (int, uint16, error) {
return err return err
} }
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IP,
unix.IP_RECVTOS,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr) return unix.Bind(fd, &addr)
}(); err != nil { }(); err != nil {
unix.Close(fd) unix.Close(fd)
@ -442,6 +453,15 @@ func create6(port uint16) (int, uint16, error) {
return err return err
} }
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_RECVTCLASS,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr) return unix.Bind(fd, &addr)
}(); err != nil { }(); err != nil {
@ -452,12 +472,13 @@ func create6(port uint16) (int, uint16, error) {
return fd, uint16(addr.Port), err return fd, uint16(addr.Port), err
} }
type ipTos struct {
tos byte
}
func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error { func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
// construct message header // construct message header
type ipTos struct {
tos byte
}
cmsg := struct { cmsg := struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
@ -505,9 +526,6 @@ func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error { func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
// construct message header // construct message header
type ipTos struct {
tos byte
}
cmsg := struct { cmsg := struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
@ -555,19 +573,21 @@ func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
return err return err
} }
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive4(sock int, buff []byte, end *NativeEndpoint) (int, byte, error) {
// contruct message header // contruct message header
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo pktinfo unix.Inet4Pktinfo
cmsghdr2 unix.Cmsghdr
iptos ipTos
} }
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil { if err != nil {
return 0, err return 0, 0, err
} }
end.isV6 = false end.isV6 = false
@ -576,7 +596,6 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
} }
// update source cache // update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP && if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
@ -584,22 +603,31 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
end.src4().ifindex = cmsg.pktinfo.Ifindex end.src4().ifindex = cmsg.pktinfo.Ifindex
} }
return size, nil tos := byte(0)
if cmsg.cmsghdr2.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr2.Type == unix.IP_TOS &&
cmsg.cmsghdr2.Len >= 1 {
tos = cmsg.iptos.tos
}
return size, tos, nil
} }
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive6(sock int, buff []byte, end *NativeEndpoint) (int, byte, error) {
// contruct message header // contruct message header
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo pktinfo unix.Inet6Pktinfo
cmsghdr2 unix.Cmsghdr
iptos ipTos
} }
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil { if err != nil {
return 0, err return 0, 0, err
} }
end.isV6 = true end.isV6 = true
@ -616,7 +644,14 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
end.dst6().ZoneId = cmsg.pktinfo.Ifindex end.dst6().ZoneId = cmsg.pktinfo.Ifindex
} }
return size, nil tos := byte(0)
if cmsg.cmsghdr2.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr2.Type == unix.IPV6_TCLASS &&
cmsg.cmsghdr2.Len >= 1 {
tos = cmsg.iptos.tos
}
return size, tos, nil
} }
func (bind *NativeBind) routineRouteListener(device *Device) { func (bind *NativeBind) routineRouteListener(device *Device) {

59
misc.go
View File

@ -46,3 +46,62 @@ func min(a, b uint) uint {
} }
return a return a
} }
// called from receive
func ecn_rfc6040_egress(inner byte, outer byte) (byte, bool) {
/*
+---------+------------------------------------------------+
|Arriving | Arriving Outer Header |
| Inner +---------+------------+------------+------------+
| Header | Not-ECT | ECT(0) | ECT(1) | CE |
+---------+---------+------------+------------+------------+
| Not-ECT | Not-ECT |Not-ECT(!!!)|Not-ECT(!!!)| <drop>(!!!)|
| ECT(0) | ECT(0) | ECT(0) | ECT(1) | CE |
| ECT(1) | ECT(1) | ECT(1) (!) | ECT(1) | CE |
| CE | CE | CE | CE(!!!)| CE |
+---------+---------+------------+------------+------------+
*/
innerECN := CongestionExperienced & inner
outerECN := CongestionExperienced & outer
switch outerECN {
case CongestionExperienced:
switch innerECN {
case NotECNTransport:
return 0, true
}
return (inner & (CongestionExperienced ^ 255)) | CongestionExperienced, false
case ECNTransport1:
switch innerECN {
case ECNTransport0:
return (inner & (CongestionExperienced ^ 255)) | ECNTransport1, false
}
}
return inner, false
}
// called from send
func ecn_rfc6040_ingress(inner byte, useNormalMode bool) byte {
/*
+-----------------+-------------------------------+
| Incoming Header | Departing Outer Header |
| (also equal to +---------------+---------------+
| departing Inner | Compatibility | Normal |
| Header) | Mode | Mode |
+-----------------+---------------+---------------+
| Not-ECT | Not-ECT | Not-ECT |
| ECT(0) | Not-ECT | ECT(0) |
| ECT(1) | Not-ECT | ECT(1) |
| CE | Not-ECT | CE |
+-----------------+---------------+---------------+
*/
if !useNormalMode {
inner &= (CongestionExperienced ^ 255)
}
return inner
}
func ecn_rfc6040_enabled(tos byte) bool {
return (CongestionExperienced & tos) == ECNTransport0
}

View File

@ -15,6 +15,14 @@ import (
const ( const (
PeerRoutineNumber = 3 PeerRoutineNumber = 3
DiffServAF41 = 0x88 // AF41
NotECNTransport = 0x00 // Not-ECT (Not ECN-Capable Transport)
ECNTransport1 = 0x01 // ECT(1) (ECN-Capable Transport(1))
ECNTransport0 = 0x02 // ECT(0) (ECN-Capable Transport(0))
CongestionExperienced = 0x03 // CE (Congestion Experienced)
HandshakeDSCP = DiffServAF41 | ECNTransport0 // AF41, plus 10 ECN
) )
type Peer struct { type Peer struct {
@ -25,6 +33,7 @@ type Peer struct {
device *Device device *Device
endpoint Endpoint endpoint Endpoint
persistentKeepaliveInterval uint16 persistentKeepaliveInterval uint16
isECNConfirmed AtomicBool
// This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
stats struct { stats struct {

View File

@ -23,6 +23,7 @@ type QueueHandshakeElement struct {
packet []byte packet []byte
endpoint Endpoint endpoint Endpoint
buffer *[MaxMessageSize]byte buffer *[MaxMessageSize]byte
isECNCompatible bool
} }
type QueueInboundElement struct { type QueueInboundElement struct {
@ -33,6 +34,7 @@ type QueueInboundElement struct {
counter uint64 counter uint64
keypair *Keypair keypair *Keypair
endpoint Endpoint endpoint Endpoint
tos byte
} }
func (elem *QueueInboundElement) Drop() { func (elem *QueueInboundElement) Drop() {
@ -108,6 +110,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
err error err error
size int size int
endpoint Endpoint endpoint Endpoint
outerTOS byte
) )
for { for {
@ -116,9 +119,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
switch IP { switch IP {
case ipv4.Version: case ipv4.Version:
size, endpoint, err = bind.ReceiveIPv4(buffer[:]) size, endpoint, outerTOS, err = bind.ReceiveIPv4(buffer[:])
case ipv6.Version: case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:]) size, endpoint, outerTOS, err = bind.ReceiveIPv6(buffer[:])
default: default:
panic("invalid IP version") panic("invalid IP version")
} }
@ -178,6 +181,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
elem.endpoint = endpoint elem.endpoint = endpoint
elem.counter = 0 elem.counter = 0
elem.Mutex = sync.Mutex{} elem.Mutex = sync.Mutex{}
elem.tos = outerTOS
elem.Lock() elem.Lock()
// add to decryption queues // add to decryption queues
@ -213,6 +217,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
buffer: buffer, buffer: buffer,
packet: packet, packet: packet,
endpoint: endpoint, endpoint: endpoint,
isECNCompatible: ecn_rfc6040_enabled(outerTOS),
}, },
)) { )) {
buffer = device.GetMessageBuffer() buffer = device.GetMessageBuffer()
@ -426,7 +431,7 @@ func (device *Device) RoutineHandshake() {
peer.SetEndpointFromPacket(elem.endpoint) peer.SetEndpointFromPacket(elem.endpoint)
logDebug.Println(peer, "- Received handshake initiation") logDebug.Println(peer, "- Received handshake initiation")
peer.isECNConfirmed.Set(elem.isECNCompatible)
peer.SendHandshakeResponse() peer.SendHandshakeResponse()
case MessageResponseType: case MessageResponseType:
@ -473,6 +478,7 @@ func (device *Device) RoutineHandshake() {
peer.timersSessionDerived() peer.timersSessionDerived()
peer.timersHandshakeComplete() peer.timersHandshakeComplete()
peer.isECNConfirmed.Set(elem.isECNCompatible)
peer.SendKeepalive() peer.SendKeepalive()
select { select {
case peer.signals.newKeypairArrived <- struct{}{}: case peer.signals.newKeypairArrived <- struct{}{}:
@ -565,6 +571,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
} }
peer.timersDataReceived() peer.timersDataReceived()
var shouldDrop bool
// verify source and strip padding // verify source and strip padding
switch elem.packet[0] >> 4 { switch elem.packet[0] >> 4 {
@ -595,6 +602,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue continue
} }
elem.tos, shouldDrop = ecn_rfc6040_egress(elem.packet[1], elem.tos)
case ipv6.Version: case ipv6.Version:
// strip padding // strip padding
@ -623,10 +631,15 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue continue
} }
elem.tos, shouldDrop = ecn_rfc6040_egress(elem.packet[1], elem.tos);
default: default:
logInfo.Println("Packet with invalid IP version from", peer) logInfo.Println("Packet with invalid IP version from", peer)
continue continue
} }
if shouldDrop {
logInfo.Println("ECN/Congestion detected, dropping packet from", peer)
continue
}
// write to tun device // write to tun device

14
send.go
View File

@ -41,10 +41,6 @@ import (
* (to allow the construction of transport messages in-place) * (to allow the construction of transport messages in-place)
*/ */
const (
HandshakeDSCP = 0x88 // AF41, plus 00 ECN
)
type QueueOutboundElement struct { type QueueOutboundElement struct {
dropped int32 dropped int32
sync.Mutex sync.Mutex
@ -299,14 +295,20 @@ func (device *Device) RoutineReadFromTUN() {
} }
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.LookupIPv4(dst) peer = device.allowedips.LookupIPv4(dst)
elem.tos = elem.packet[1]; if peer == nil {
continue
}
elem.tos = ecn_rfc6040_ingress(elem.packet[1], peer.isECNConfirmed.Get())
case ipv6.Version: case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen { if len(elem.packet) < ipv6.HeaderLen {
continue continue
} }
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.LookupIPv6(dst) peer = device.allowedips.LookupIPv6(dst)
elem.tos = elem.packet[1]; if peer == nil {
continue
}
elem.tos = ecn_rfc6040_ingress(elem.packet[1], peer.isECNConfirmed.Get())
default: default:
logDebug.Println("Received packet with unknown IP version") logDebug.Println("Received packet with unknown IP version")
} }