1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 09:15:14 +01:00

Begin incorporating new src cache into receive

This commit is contained in:
Mathias Hall-Andersen 2017-10-07 22:35:23 +02:00
parent c70f0c5da2
commit 2d856045a0
5 changed files with 164 additions and 97 deletions

View File

@ -3,7 +3,6 @@ package main
import ( import (
"errors" "errors"
"net" "net"
"time"
) )
func parseEndpoint(s string) (*net.UDPAddr, error) { func parseEndpoint(s string) (*net.UDPAddr, error) {
@ -27,63 +26,96 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
return addr, err return addr, err
} }
func updateUDPConn(device *Device) error { func ListenerClose(l *Listener) (err error) {
if l.active {
err = CloseIPv4Socket(l.sock)
l.active = false
}
return
}
func (l *Listener) Init() {
l.update = make(chan struct{}, 1)
ListenerClose(l)
}
func ListeningUpdate(device *Device) error {
netc := &device.net netc := &device.net
netc.mutex.Lock() netc.mutex.Lock()
defer netc.mutex.Unlock() defer netc.mutex.Unlock()
// close existing connection // close existing sockets
if netc.conn != nil { if err := ListenerClose(&netc.ipv4); err != nil {
netc.conn.Close() return err
netc.conn = nil
// We need for that fd to be closed in all other go routines, which
// means we have to wait. TODO: find less horrible way of doing this.
time.Sleep(time.Second / 2)
} }
// open new connection if err := ListenerClose(&netc.ipv6); err != nil {
return err
}
// open new sockets
if device.tun.isUp.Get() { if device.tun.isUp.Get() {
// listen on new address // listen on IPv4
conn, err := net.ListenUDP("udp", netc.addr) {
list := &netc.ipv6
sock, port, err := CreateIPv4Socket(netc.port)
if err != nil { if err != nil {
return err return err
} }
netc.port = port
list.sock = sock
list.active = true
// set fwmark if err := SetMark(list.sock, netc.fwmark); err != nil {
ListenerClose(list)
return err
}
signalSend(list.update)
}
err = SetMark(netc.conn, netc.fwmark) // listen on IPv6
{
list := &netc.ipv6
sock, port, err := CreateIPv6Socket(netc.port)
if err != nil { if err != nil {
return err return err
} }
netc.port = port
list.sock = sock
list.active = true
// retrieve port (may have been chosen by kernel) if err := SetMark(list.sock, netc.fwmark); err != nil {
ListenerClose(list)
return err
}
signalSend(list.update)
}
addr := conn.LocalAddr() // TODO: clear endpoint caches
netc.conn = conn
netc.addr, _ = net.ResolveUDPAddr(
addr.Network(),
addr.String(),
)
// notify goroutines
signalSend(device.signal.newUDPConn)
} }
return nil return nil
} }
func closeUDPConn(device *Device) { func ListeningClose(device *Device) error {
netc := &device.net netc := &device.net
netc.mutex.Lock() netc.mutex.Lock()
if netc.conn != nil { defer netc.mutex.Unlock()
netc.conn.Close()
if err := ListenerClose(&netc.ipv4); err != nil {
return err
} }
netc.mutex.Unlock() signalSend(netc.ipv4.update)
signalSend(device.signal.newUDPConn)
if err := ListenerClose(&netc.ipv6); err != nil {
return err
}
signalSend(netc.ipv6.update)
return nil
} }

View File

@ -28,6 +28,7 @@ import "fmt"
type Endpoint struct { type Endpoint struct {
// source (selected based on dst type) // source (selected based on dst type)
// (could use RawSockaddrAny and unsafe) // (could use RawSockaddrAny and unsafe)
// TODO: Merge
src6 unix.RawSockaddrInet6 src6 unix.RawSockaddrInet6
src4 unix.RawSockaddrInet4 src4 unix.RawSockaddrInet4
src4if int32 src4if int32
@ -35,8 +36,14 @@ type Endpoint struct {
dst unix.RawSockaddrAny dst unix.RawSockaddrAny
} }
type IPv4Socket int type Socket int
type IPv6Socket int
/* Returns a byte representation of the source field(s)
* for use in "under load" cookie computations.
*/
func (endpoint *Endpoint) Source() []byte {
return nil
}
func zoneToUint32(zone string) (uint32, error) { func zoneToUint32(zone string) (uint32, error) {
if zone == "" { if zone == "" {
@ -49,7 +56,7 @@ func zoneToUint32(zone string) (uint32, error) {
return uint32(n), err return uint32(n), err
} }
func CreateIPv4Socket(port int) (IPv4Socket, error) { func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
// create socket // create socket
@ -60,13 +67,16 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) {
) )
if err != nil { if err != nil {
return -1, err return -1, 0, err
}
addr := unix.SockaddrInet4{
Port: int(port),
} }
// set sockopts and bind // set sockopts and bind
if err := func() error { if err := func() error {
if err := unix.SetsockoptInt( if err := unix.SetsockoptInt(
fd, fd,
unix.SOL_SOCKET, unix.SOL_SOCKET,
@ -85,19 +95,23 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) {
return err return err
} }
addr := unix.SockaddrInet4{
Port: port,
}
return unix.Bind(fd, &addr) return unix.Bind(fd, &addr)
}(); err != nil { }(); err != nil {
unix.Close(fd) unix.Close(fd)
} }
return IPv4Socket(fd), err return Socket(fd), uint16(addr.Port), err
} }
func CreateIPv6Socket(port int) (IPv6Socket, error) { func CloseIPv4Socket(sock Socket) error {
return unix.Close(int(sock))
}
func CloseIPv6Socket(sock Socket) error {
return unix.Close(int(sock))
}
func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
// create socket // create socket
@ -108,11 +122,15 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) {
) )
if err != nil { if err != nil {
return -1, err return -1, 0, err
} }
// set sockopts and bind // set sockopts and bind
addr := unix.SockaddrInet6{
Port: int(port),
}
if err := func() error { if err := func() error {
if err := unix.SetsockoptInt( if err := unix.SetsockoptInt(
@ -142,16 +160,13 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) {
return err return err
} }
addr := unix.SockaddrInet6{
Port: port,
}
return unix.Bind(fd, &addr) return unix.Bind(fd, &addr)
}(); err != nil { }(); err != nil {
unix.Close(fd) unix.Close(fd)
} }
return IPv6Socket(fd), err return Socket(fd), uint16(addr.Port), err
} }
func (end *Endpoint) ClearSrc() { func (end *Endpoint) ClearSrc() {
@ -311,7 +326,7 @@ func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error {
return errors.New("Unknown address family of source") return errors.New("Unknown address family of source")
} }
func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) { func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
// contruct message header // contruct message header
@ -360,7 +375,7 @@ func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
return int(size), nil return int(size), nil
} }
func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error { func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
// contruct message header // contruct message header
@ -383,7 +398,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
// recvmsg(sock, &mskhdr, 0) // recvmsg(sock, &mskhdr, 0)
_, _, errno := unix.Syscall( size, _, errno := unix.Syscall(
unix.SYS_RECVMSG, unix.SYS_RECVMSG,
uintptr(sock), uintptr(sock),
uintptr(unsafe.Pointer(&msg)), uintptr(unsafe.Pointer(&msg)),
@ -391,7 +406,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
) )
if errno != 0 { if errno != 0 {
return errno return 0, errno
} }
// update source cache // update source cache
@ -403,21 +418,12 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
end.src6.Scope_id = cmsg.pktinfo.Ifindex end.src6.Scope_id = cmsg.pktinfo.Ifindex
} }
return nil return int(size), nil
} }
func SetMark(conn *net.UDPConn, value uint32) error { func SetMark(sock Socket, value uint32) error {
if conn == nil {
return nil
}
file, err := conn.File()
if err != nil {
return err
}
return unix.SetsockoptInt( return unix.SetsockoptInt(
int(file.Fd()), int(sock),
unix.SOL_SOCKET, unix.SOL_SOCKET,
unix.SO_MARK, unix.SO_MARK,
int(value), int(value),

View File

@ -1,13 +1,18 @@
package main package main
import ( import (
"net"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
) )
type Listener struct {
sock Socket
active bool
update chan struct{}
}
type Device struct { type Device struct {
log *Logger // collection of loggers for levels log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers idCounter uint // for assigning debug ids to peers
@ -22,8 +27,9 @@ type Device struct {
} }
net struct { net struct {
mutex sync.RWMutex mutex sync.RWMutex
addr *net.UDPAddr // UDP source address ipv4 Listener
conn *net.UDPConn // UDP "connection" ipv6 Listener
port uint16
fwmark uint32 fwmark uint32
} }
mutex sync.RWMutex mutex sync.RWMutex
@ -38,7 +44,8 @@ type Device struct {
} }
signal struct { signal struct {
stop chan struct{} // halts all go routines stop chan struct{} // halts all go routines
newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine) updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
} }
underLoadUntil atomic.Value underLoadUntil atomic.Value
ratelimiter Ratelimiter ratelimiter Ratelimiter
@ -137,12 +144,16 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.log = NewLogger(logLevel, "("+tun.Name()+") ") device.log = NewLogger(logLevel, "("+tun.Name()+") ")
device.peers = make(map[NoisePublicKey]*Peer) device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun device.tun.device = tun
device.indices.Init() device.indices.Init()
device.net.ipv4.Init()
device.net.ipv6.Init()
device.ratelimiter.Init() device.ratelimiter.Init()
device.routingTable.Reset() device.routingTable.Reset()
device.underLoadUntil.Store(time.Time{}) device.underLoadUntil.Store(time.Time{})
// setup pools // setup buffer pool
device.pool.messageBuffers = sync.Pool{ device.pool.messageBuffers = sync.Pool{
New: func() interface{} { New: func() interface{} {
@ -159,7 +170,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
// prepare signals // prepare signals
device.signal.stop = make(chan struct{}) device.signal.stop = make(chan struct{})
device.signal.newUDPConn = make(chan struct{}, 1)
// start workers // start workers
@ -168,12 +178,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
go device.RoutineDecryption() go device.RoutineDecryption()
go device.RoutineHandshake() go device.RoutineHandshake()
} }
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader() go device.RoutineTUNEventReader()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
go device.RoutineReadFromTUN() go device.RoutineReceiveIncomming(&device.net.ipv4)
go device.RoutineReceiveIncomming() go device.RoutineReceiveIncomming(&device.net.ipv6)
return device return device
} }
@ -204,7 +213,7 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() { func (device *Device) Close() {
device.RemoveAllPeers() device.RemoveAllPeers()
close(device.signal.stop) close(device.signal.stop)
closeUDPConn(device) ListeningClose(device)
} }
func (device *Device) WaitChannel() chan struct{} { func (device *Device) WaitChannel() chan struct{} {

View File

@ -14,6 +14,7 @@ func printUsage() {
} }
func main() { func main() {
test()
// parse arguments // parse arguments

View File

@ -15,8 +15,8 @@ import (
type QueueHandshakeElement struct { type QueueHandshakeElement struct {
msgType uint32 msgType uint32
packet []byte packet []byte
endpoint Endpoint
buffer *[MaxMessageSize]byte buffer *[MaxMessageSize]byte
source *net.UDPAddr
} }
type QueueInboundElement struct { type QueueInboundElement struct {
@ -92,11 +92,22 @@ func (device *Device) addToHandshakeQueue(
} }
} }
func (device *Device) RoutineReceiveIncomming() { func (device *Device) RoutineReceiveIncomming(IPVersion int) {
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, started") logDebug.Println("Routine, receive incomming, started")
var listener *Listener
switch IPVersion {
case ipv4.Version:
listener = &device.net.ipv4
case ipv6.Version:
listener = &device.net.ipv6
default:
return
}
for { for {
// wait for new conn // wait for new conn
@ -107,14 +118,15 @@ func (device *Device) RoutineReceiveIncomming() {
case <-device.signal.stop: case <-device.signal.stop:
return return
case <-device.signal.newUDPConn: case <-listener.update:
// fetch connection // fetch new socket
device.net.mutex.RLock() device.net.mutex.RLock()
conn := device.net.conn sock := listener.sock
okay := listener.active
device.net.mutex.RUnlock() device.net.mutex.RUnlock()
if conn == nil { if !okay {
continue continue
} }
@ -124,11 +136,20 @@ func (device *Device) RoutineReceiveIncomming() {
buffer := device.GetMessageBuffer() buffer := device.GetMessageBuffer()
var size int
var err error
for { for {
// read next datagram // read next datagram
size, raddr, err := conn.ReadFromUDP(buffer[:]) var endpoint Endpoint
if IPVersion == ipv6.Version {
size, err = endpoint.ReceiveIPv4(sock, buffer[:])
} else {
size, err = endpoint.ReceiveIPv6(sock, buffer[:])
}
if err != nil { if err != nil {
break break
@ -192,7 +213,7 @@ func (device *Device) RoutineReceiveIncomming() {
buffer = device.GetMessageBuffer() buffer = device.GetMessageBuffer()
continue continue
// otherwise it is a handshake related packet // otherwise it is a fixed size & handshake related packet
case MessageInitiationType: case MessageInitiationType:
okay = len(packet) == MessageInitiationSize okay = len(packet) == MessageInitiationSize
@ -211,7 +232,7 @@ func (device *Device) RoutineReceiveIncomming() {
msgType: msgType, msgType: msgType,
buffer: buffer, buffer: buffer,
packet: packet, packet: packet,
source: raddr, endpoint: endpoint,
}, },
) )
buffer = device.GetMessageBuffer() buffer = device.GetMessageBuffer()
@ -293,8 +314,6 @@ func (device *Device) RoutineHandshake() {
// unmarshal packet // unmarshal packet
logDebug.Println("Process cookie reply from:", elem.source.String())
var reply MessageCookieReply var reply MessageCookieReply
reader := bytes.NewReader(elem.packet) reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply) err := binary.Read(reader, binary.LittleEndian, &reply)