mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
receive: implement flush semantics
This commit is contained in:
parent
49ea0c9b1a
commit
6440f010ee
@ -41,4 +41,4 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -53,4 +53,4 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -177,4 +177,4 @@ func (device *Device) BindClose() error {
|
||||
err := unsafeCloseBind(device)
|
||||
device.net.Unlock()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -13,4 +13,4 @@ const (
|
||||
QueueHandshakeSize = 1024
|
||||
MaxSegmentSize = 2200
|
||||
PreallocatedBuffersPerPool = 4096
|
||||
)
|
||||
)
|
||||
|
@ -482,6 +482,33 @@ func (device *Device) RoutineHandshake() {
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) elementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, elem *QueueInboundElement) {
|
||||
if !*shouldFlush {
|
||||
select {
|
||||
case <-peer.routines.stop:
|
||||
stop = true
|
||||
return
|
||||
case elem, elemOk = <-peer.queue.inbound:
|
||||
return
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case <-peer.routines.stop:
|
||||
stop = true
|
||||
return
|
||||
case elem, elemOk = <-peer.queue.inbound:
|
||||
return
|
||||
default:
|
||||
*shouldFlush = false
|
||||
err := peer.device.tun.device.Flush()
|
||||
if err != nil {
|
||||
peer.device.log.Error.Printf("Unable to flush packets: %v", err)
|
||||
}
|
||||
return peer.elementStopOrFlush(shouldFlush)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) RoutineSequentialReceiver() {
|
||||
|
||||
device := peer.device
|
||||
@ -491,6 +518,9 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
|
||||
var elem *QueueInboundElement
|
||||
var ok bool
|
||||
var stop bool
|
||||
|
||||
shouldFlush := false
|
||||
|
||||
defer func() {
|
||||
logDebug.Println(peer, "- Routine: sequential receiver - stopped")
|
||||
@ -516,126 +546,122 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
elem = nil
|
||||
}
|
||||
|
||||
select {
|
||||
|
||||
case <-peer.routines.stop:
|
||||
stop, ok, elem = peer.elementStopOrFlush(&shouldFlush)
|
||||
if stop || !ok {
|
||||
return
|
||||
}
|
||||
|
||||
case elem, ok = <-peer.queue.inbound:
|
||||
// wait for decryption
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
elem.Lock()
|
||||
|
||||
// wait for decryption
|
||||
if elem.IsDropped() {
|
||||
continue
|
||||
}
|
||||
|
||||
elem.Lock()
|
||||
// check for replay
|
||||
|
||||
if elem.IsDropped() {
|
||||
continue
|
||||
}
|
||||
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
||||
continue
|
||||
}
|
||||
|
||||
// check for replay
|
||||
|
||||
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
||||
continue
|
||||
}
|
||||
|
||||
// update endpoint
|
||||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
|
||||
// check if using new keypair
|
||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||
peer.timersHandshakeComplete()
|
||||
select {
|
||||
case peer.signals.newKeypairArrived <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
peer.keepKeyFreshReceiving()
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketReceived()
|
||||
|
||||
// check for keepalive
|
||||
|
||||
if len(elem.packet) == 0 {
|
||||
logDebug.Println(peer, "- Receiving keepalive packet")
|
||||
continue
|
||||
}
|
||||
peer.timersDataReceived()
|
||||
|
||||
// verify source and strip padding
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case ipv4.Version:
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
|
||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = elem.packet[:length]
|
||||
|
||||
// verify IPv4 source
|
||||
|
||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||
if device.allowedips.LookupIPv4(src) != peer {
|
||||
logInfo.Println(
|
||||
"IPv4 packet with disallowed source address from",
|
||||
peer,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
case ipv6.Version:
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
}
|
||||
|
||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
length += ipv6.HeaderLen
|
||||
if int(length) > len(elem.packet) {
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = elem.packet[:length]
|
||||
|
||||
// verify IPv6 source
|
||||
|
||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||
if device.allowedips.LookupIPv6(src) != peer {
|
||||
logInfo.Println(
|
||||
peer,
|
||||
"sent packet with disallowed IPv6 source",
|
||||
)
|
||||
continue
|
||||
}
|
||||
// update endpoint
|
||||
peer.SetEndpointFromPacket(elem.endpoint)
|
||||
|
||||
// check if using new keypair
|
||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||
peer.timersHandshakeComplete()
|
||||
select {
|
||||
case peer.signals.newKeypairArrived <- struct{}{}:
|
||||
default:
|
||||
logInfo.Println("Packet with invalid IP version from", peer)
|
||||
}
|
||||
}
|
||||
|
||||
peer.keepKeyFreshReceiving()
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketReceived()
|
||||
|
||||
// check for keepalive
|
||||
|
||||
if len(elem.packet) == 0 {
|
||||
logDebug.Println(peer, "- Receiving keepalive packet")
|
||||
continue
|
||||
}
|
||||
peer.timersDataReceived()
|
||||
|
||||
// verify source and strip padding
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case ipv4.Version:
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
|
||||
// write to tun device
|
||||
|
||||
offset := MessageTransportOffsetContent
|
||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
||||
_, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
|
||||
if err != nil && !device.isClosed.Get() {
|
||||
logError.Println("Failed to write packet to TUN device:", err)
|
||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = elem.packet[:length]
|
||||
|
||||
// verify IPv4 source
|
||||
|
||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||
if device.allowedips.LookupIPv4(src) != peer {
|
||||
logInfo.Println(
|
||||
"IPv4 packet with disallowed source address from",
|
||||
peer,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
case ipv6.Version:
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
}
|
||||
|
||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
length += ipv6.HeaderLen
|
||||
if int(length) > len(elem.packet) {
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = elem.packet[:length]
|
||||
|
||||
// verify IPv6 source
|
||||
|
||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||
if device.allowedips.LookupIPv6(src) != peer {
|
||||
logInfo.Println(
|
||||
peer,
|
||||
"sent packet with disallowed IPv6 source",
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
default:
|
||||
logInfo.Println("Packet with invalid IP version from", peer)
|
||||
continue
|
||||
}
|
||||
|
||||
// write to tun device
|
||||
|
||||
offset := MessageTransportOffsetContent
|
||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
||||
_, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
|
||||
if err == nil {
|
||||
shouldFlush = true
|
||||
}
|
||||
if err != nil && !device.isClosed.Get() {
|
||||
logError.Println("Failed to write packet to TUN device:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ type TUNDevice interface {
|
||||
File() *os.File // returns the file descriptor of the device
|
||||
Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
|
||||
Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
|
||||
Flush() error // flush all previous writes to the device
|
||||
MTU() (int, error) // returns the MTU of the device
|
||||
Name() (string, error) // fetches and returns the current name
|
||||
Events() chan TUNEvent // returns a constant channel of events related to the device
|
||||
|
@ -281,6 +281,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
return tun.tunFile.Write(buff)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
//TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
var err2 error
|
||||
err1 := tun.tunFile.Close()
|
||||
|
@ -406,6 +406,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
return tun.tunFile.Write(buff)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
//TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
var err3 error
|
||||
err1 := tun.tunFile.Close()
|
||||
|
@ -318,6 +318,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
return tun.tunFile.Write(buff)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
//TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||
select {
|
||||
case err := <-tun.errors:
|
||||
|
@ -237,6 +237,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
return tun.tunFile.Write(buff)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
//TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
var err2 error
|
||||
err1 := tun.tunFile.Close()
|
||||
|
@ -281,7 +281,11 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||
|
||||
// Note: flush() and putTunPacket() assume the caller comes only from a single thread; there's no locking.
|
||||
|
||||
func (tun *NativeTun) flush() error {
|
||||
func (tun *NativeTun) Flush() error {
|
||||
if tun.wrBuff.offset == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get TUN data pipe.
|
||||
file, err := tun.getTUN()
|
||||
if err != nil {
|
||||
@ -322,7 +326,7 @@ func (tun *NativeTun) putTunPacket(buff []byte) error {
|
||||
|
||||
if tun.wrBuff.packetNum >= packetExchangeMax || tun.wrBuff.offset+pSize >= packetExchangeSize {
|
||||
// Exchange buffer is full -> flush first.
|
||||
err := tun.flush()
|
||||
err := tun.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -345,9 +349,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Flush write buffer.
|
||||
return len(buff) - offset, tun.flush()
|
||||
return len(buff) - offset, nil
|
||||
}
|
||||
|
||||
//
|
||||
|
Loading…
Reference in New Issue
Block a user