1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 09:15:14 +01:00

receive: implement flush semantics

This commit is contained in:
Jason A. Donenfeld 2019-03-21 14:43:04 -06:00
parent 49ea0c9b1a
commit 6440f010ee
12 changed files with 165 additions and 116 deletions

View File

@ -482,6 +482,33 @@ func (device *Device) RoutineHandshake() {
} }
} }
func (peer *Peer) elementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, elem *QueueInboundElement) {
if !*shouldFlush {
select {
case <-peer.routines.stop:
stop = true
return
case elem, elemOk = <-peer.queue.inbound:
return
}
} else {
select {
case <-peer.routines.stop:
stop = true
return
case elem, elemOk = <-peer.queue.inbound:
return
default:
*shouldFlush = false
err := peer.device.tun.device.Flush()
if err != nil {
peer.device.log.Error.Printf("Unable to flush packets: %v", err)
}
return peer.elementStopOrFlush(shouldFlush)
}
}
}
func (peer *Peer) RoutineSequentialReceiver() { func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device device := peer.device
@ -491,6 +518,9 @@ func (peer *Peer) RoutineSequentialReceiver() {
var elem *QueueInboundElement var elem *QueueInboundElement
var ok bool var ok bool
var stop bool
shouldFlush := false
defer func() { defer func() {
logDebug.Println(peer, "- Routine: sequential receiver - stopped") logDebug.Println(peer, "- Routine: sequential receiver - stopped")
@ -516,126 +546,122 @@ func (peer *Peer) RoutineSequentialReceiver() {
elem = nil elem = nil
} }
select { stop, ok, elem = peer.elementStopOrFlush(&shouldFlush)
if stop || !ok {
case <-peer.routines.stop:
return return
}
case elem, ok = <-peer.queue.inbound: // wait for decryption
if !ok { elem.Lock()
return
}
// wait for decryption if elem.IsDropped() {
continue
}
elem.Lock() // check for replay
if elem.IsDropped() { if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
continue continue
} }
// check for replay // update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
continue
}
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
// check if using new keypair
if peer.ReceivedWithKeypair(elem.keypair) {
peer.timersHandshakeComplete()
select {
case peer.signals.newKeypairArrived <- struct{}{}:
default:
}
}
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
// check for keepalive
if len(elem.packet) == 0 {
logDebug.Println(peer, "- Receiving keepalive packet")
continue
}
peer.timersDataReceived()
// verify source and strip padding
switch elem.packet[0] >> 4 {
case ipv4.Version:
// strip padding
if len(elem.packet) < ipv4.HeaderLen {
continue
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
continue
}
elem.packet = elem.packet[:length]
// verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.LookupIPv4(src) != peer {
logInfo.Println(
"IPv4 packet with disallowed source address from",
peer,
)
continue
}
case ipv6.Version:
// strip padding
if len(elem.packet) < ipv6.HeaderLen {
continue
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.LookupIPv6(src) != peer {
logInfo.Println(
peer,
"sent packet with disallowed IPv6 source",
)
continue
}
// check if using new keypair
if peer.ReceivedWithKeypair(elem.keypair) {
peer.timersHandshakeComplete()
select {
case peer.signals.newKeypairArrived <- struct{}{}:
default: default:
logInfo.Println("Packet with invalid IP version from", peer) }
}
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
// check for keepalive
if len(elem.packet) == 0 {
logDebug.Println(peer, "- Receiving keepalive packet")
continue
}
peer.timersDataReceived()
// verify source and strip padding
switch elem.packet[0] >> 4 {
case ipv4.Version:
// strip padding
if len(elem.packet) < ipv4.HeaderLen {
continue continue
} }
// write to tun device field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
offset := MessageTransportOffsetContent if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) continue
_, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
if err != nil && !device.isClosed.Get() {
logError.Println("Failed to write packet to TUN device:", err)
} }
elem.packet = elem.packet[:length]
// verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.LookupIPv4(src) != peer {
logInfo.Println(
"IPv4 packet with disallowed source address from",
peer,
)
continue
}
case ipv6.Version:
// strip padding
if len(elem.packet) < ipv6.HeaderLen {
continue
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.LookupIPv6(src) != peer {
logInfo.Println(
peer,
"sent packet with disallowed IPv6 source",
)
continue
}
default:
logInfo.Println("Packet with invalid IP version from", peer)
continue
}
// write to tun device
offset := MessageTransportOffsetContent
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
if err == nil {
shouldFlush = true
}
if err != nil && !device.isClosed.Get() {
logError.Println("Failed to write packet to TUN device:", err)
} }
} }
} }

View File

@ -21,6 +21,7 @@ type TUNDevice interface {
File() *os.File // returns the file descriptor of the device File() *os.File // returns the file descriptor of the device
Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
Flush() error // flush all previous writes to the device
MTU() (int, error) // returns the MTU of the device MTU() (int, error) // returns the MTU of the device
Name() (string, error) // fetches and returns the current name Name() (string, error) // fetches and returns the current name
Events() chan TUNEvent // returns a constant channel of events related to the device Events() chan TUNEvent // returns a constant channel of events related to the device

View File

@ -281,6 +281,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return tun.tunFile.Write(buff) return tun.tunFile.Write(buff)
} }
func (tun *NativeTun) Flush() error {
//TODO: can flushing be implemented by buffering and using sendmmsg?
return nil
}
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err2 error var err2 error
err1 := tun.tunFile.Close() err1 := tun.tunFile.Close()

View File

@ -406,6 +406,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return tun.tunFile.Write(buff) return tun.tunFile.Write(buff)
} }
func (tun *NativeTun) Flush() error {
//TODO: can flushing be implemented by buffering and using sendmmsg?
return nil
}
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err3 error var err3 error
err1 := tun.tunFile.Close() err1 := tun.tunFile.Close()

View File

@ -318,6 +318,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return tun.tunFile.Write(buff) return tun.tunFile.Write(buff)
} }
func (tun *NativeTun) Flush() error {
//TODO: can flushing be implemented by buffering and using sendmmsg?
return nil
}
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
select { select {
case err := <-tun.errors: case err := <-tun.errors:

View File

@ -237,6 +237,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return tun.tunFile.Write(buff) return tun.tunFile.Write(buff)
} }
func (tun *NativeTun) Flush() error {
//TODO: can flushing be implemented by buffering and using sendmmsg?
return nil
}
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err2 error var err2 error
err1 := tun.tunFile.Close() err1 := tun.tunFile.Close()

View File

@ -281,7 +281,11 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
// Note: flush() and putTunPacket() assume the caller comes only from a single thread; there's no locking. // Note: flush() and putTunPacket() assume the caller comes only from a single thread; there's no locking.
func (tun *NativeTun) flush() error { func (tun *NativeTun) Flush() error {
if tun.wrBuff.offset == 0 {
return nil
}
// Get TUN data pipe. // Get TUN data pipe.
file, err := tun.getTUN() file, err := tun.getTUN()
if err != nil { if err != nil {
@ -322,7 +326,7 @@ func (tun *NativeTun) putTunPacket(buff []byte) error {
if tun.wrBuff.packetNum >= packetExchangeMax || tun.wrBuff.offset+pSize >= packetExchangeSize { if tun.wrBuff.packetNum >= packetExchangeMax || tun.wrBuff.offset+pSize >= packetExchangeSize {
// Exchange buffer is full -> flush first. // Exchange buffer is full -> flush first.
err := tun.flush() err := tun.Flush()
if err != nil { if err != nil {
return err return err
} }
@ -345,9 +349,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
return len(buff) - offset, nil
// Flush write buffer.
return len(buff) - offset, tun.flush()
} }
// //