diff --git a/conn.go b/conn.go index b19a9c2..b8970e7 100644 --- a/conn.go +++ b/conn.go @@ -22,7 +22,7 @@ type Bind interface { SetMark(value uint32) error ReceiveIPv6(buff []byte) (int, Endpoint, error) ReceiveIPv4(buff []byte) (int, Endpoint, error) - Send(buff []byte, end Endpoint) error + Send(buff []byte, end Endpoint, tos byte) error Close() error } diff --git a/conn_linux.go b/conn_linux.go index 9ebbeb1..83cf1a2 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -258,18 +258,18 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { return n, &end, err } -func (bind *NativeBind) Send(buff []byte, end Endpoint) error { +func (bind *NativeBind) Send(buff []byte, end Endpoint, tos byte) error { nend := end.(*NativeEndpoint) if !nend.isV6 { if bind.sock4 == -1 { return syscall.EAFNOSUPPORT } - return send4(bind.sock4, nend, buff) + return send4(bind.sock4, nend, buff, tos) } else { if bind.sock6 == -1 { return syscall.EAFNOSUPPORT } - return send6(bind.sock6, nend, buff) + return send6(bind.sock6, nend, buff, tos) } } @@ -452,13 +452,18 @@ func create6(port uint16) (int, uint16, error) { return fd, uint16(addr.Port), err } -func send4(sock int, end *NativeEndpoint, buff []byte) error { +func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error { // construct message header + type ipTos struct { + tos byte + } cmsg := struct { cmsghdr unix.Cmsghdr pktinfo unix.Inet4Pktinfo + cmsghdr2 unix.Cmsghdr + iptos ipTos }{ unix.Cmsghdr{ Level: unix.IPPROTO_IP, @@ -469,6 +474,15 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error { Spec_dst: end.src4().src, Ifindex: end.src4().ifindex, }, + unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_TOS, + Len: 1 + unix.SizeofCmsghdr, + }, + ipTos{ + tos: tos, + }, + } _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) @@ -488,13 +502,18 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error { return err } -func send6(sock int, end *NativeEndpoint, buff []byte) error { +func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error { // construct message header + type ipTos struct { + tos byte + } cmsg := struct { cmsghdr unix.Cmsghdr pktinfo unix.Inet6Pktinfo + cmsghdr2 unix.Cmsghdr + tclass ipTos }{ unix.Cmsghdr{ Level: unix.IPPROTO_IPV6, @@ -505,6 +524,14 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error { Addr: end.src6().src, Ifindex: end.dst6().ZoneId, }, + unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_TCLASS, + Len: 1 + unix.SizeofCmsghdr, + }, + ipTos{ + tos: tos, + }, } if cmsg.pktinfo.Addr == [16]byte{} { diff --git a/peer.go b/peer.go index f021565..96cfa61 100644 --- a/peer.go +++ b/peer.go @@ -125,7 +125,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } -func (peer *Peer) SendBuffer(buffer []byte) error { +func (peer *Peer) SendBuffer(buffer []byte, tos byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() @@ -140,7 +140,7 @@ func (peer *Peer) SendBuffer(buffer []byte) error { return errors.New("no known endpoint for peer") } - return peer.device.net.bind.Send(buffer, peer.endpoint) + return peer.device.net.bind.Send(buffer, peer.endpoint, tos) } func (peer *Peer) String() string { diff --git a/send.go b/send.go index b7cac04..57bb67b 100644 --- a/send.go +++ b/send.go @@ -41,6 +41,10 @@ import ( * (to allow the construction of transport messages in-place) */ +const ( + HandshakeDSCP = 0x88 // AF41, plus 00 ECN +) + type QueueOutboundElement struct { dropped int32 sync.Mutex @@ -49,6 +53,7 @@ type QueueOutboundElement struct { nonce uint64 // nonce for encryption keypair *Keypair // keypair for encryption peer *Peer // related peer + tos byte // Type of Service (DSCP + ECN bits) } func (device *Device) NewOutboundElement() *QueueOutboundElement { @@ -159,7 +164,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + err = peer.SendBuffer(packet, HandshakeDSCP) if err != nil { peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err) } @@ -197,7 +202,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + err = peer.SendBuffer(packet, HandshakeDSCP) if err != nil { peer.device.log.Error.Println(peer, "- Failed to send handshake response", err) } @@ -218,7 +223,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) var buff [MessageCookieReplySize]byte writer := bytes.NewBuffer(buff[:0]) binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) + device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint, HandshakeDSCP) if err != nil { device.log.Error.Println("Failed to send cookie reply:", err) } @@ -294,14 +299,14 @@ func (device *Device) RoutineReadFromTUN() { } dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] peer = device.allowedips.LookupIPv4(dst) - + elem.tos = elem.packet[1]; case ipv6.Version: if len(elem.packet) < ipv6.HeaderLen { continue } dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] peer = device.allowedips.LookupIPv6(dst) - + elem.tos = elem.packet[1]; default: logDebug.Println("Received packet with unknown IP version") } @@ -600,7 +605,7 @@ func (peer *Peer) RoutineSequentialSender() { // send message and return buffer to pool length := uint64(len(elem.packet)) - err := peer.SendBuffer(elem.packet) + err := peer.SendBuffer(elem.packet, elem.tos) device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) if err != nil {