1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

Begin incorporating new src cache into receive

This commit is contained in:
Mathias Hall-Andersen 2017-10-07 22:35:23 +02:00
parent c70f0c5da2
commit 2d856045a0
5 changed files with 164 additions and 97 deletions

View File

@ -3,7 +3,6 @@ package main
import (
"errors"
"net"
"time"
)
func parseEndpoint(s string) (*net.UDPAddr, error) {
@ -27,63 +26,96 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
return addr, err
}
func updateUDPConn(device *Device) error {
func ListenerClose(l *Listener) (err error) {
if l.active {
err = CloseIPv4Socket(l.sock)
l.active = false
}
return
}
func (l *Listener) Init() {
l.update = make(chan struct{}, 1)
ListenerClose(l)
}
func ListeningUpdate(device *Device) error {
netc := &device.net
netc.mutex.Lock()
defer netc.mutex.Unlock()
// close existing connection
// close existing sockets
if netc.conn != nil {
netc.conn.Close()
netc.conn = nil
// We need for that fd to be closed in all other go routines, which
// means we have to wait. TODO: find less horrible way of doing this.
time.Sleep(time.Second / 2)
if err := ListenerClose(&netc.ipv4); err != nil {
return err
}
// open new connection
if err := ListenerClose(&netc.ipv6); err != nil {
return err
}
// open new sockets
if device.tun.isUp.Get() {
// listen on new address
// listen on IPv4
conn, err := net.ListenUDP("udp", netc.addr)
if err != nil {
return err
{
list := &netc.ipv6
sock, port, err := CreateIPv4Socket(netc.port)
if err != nil {
return err
}
netc.port = port
list.sock = sock
list.active = true
if err := SetMark(list.sock, netc.fwmark); err != nil {
ListenerClose(list)
return err
}
signalSend(list.update)
}
// set fwmark
// listen on IPv6
err = SetMark(netc.conn, netc.fwmark)
if err != nil {
return err
{
list := &netc.ipv6
sock, port, err := CreateIPv6Socket(netc.port)
if err != nil {
return err
}
netc.port = port
list.sock = sock
list.active = true
if err := SetMark(list.sock, netc.fwmark); err != nil {
ListenerClose(list)
return err
}
signalSend(list.update)
}
// retrieve port (may have been chosen by kernel)
addr := conn.LocalAddr()
netc.conn = conn
netc.addr, _ = net.ResolveUDPAddr(
addr.Network(),
addr.String(),
)
// notify goroutines
signalSend(device.signal.newUDPConn)
// TODO: clear endpoint caches
}
return nil
}
func closeUDPConn(device *Device) {
func ListeningClose(device *Device) error {
netc := &device.net
netc.mutex.Lock()
if netc.conn != nil {
netc.conn.Close()
defer netc.mutex.Unlock()
if err := ListenerClose(&netc.ipv4); err != nil {
return err
}
netc.mutex.Unlock()
signalSend(device.signal.newUDPConn)
signalSend(netc.ipv4.update)
if err := ListenerClose(&netc.ipv6); err != nil {
return err
}
signalSend(netc.ipv6.update)
return nil
}

View File

@ -28,6 +28,7 @@ import "fmt"
type Endpoint struct {
// source (selected based on dst type)
// (could use RawSockaddrAny and unsafe)
// TODO: Merge
src6 unix.RawSockaddrInet6
src4 unix.RawSockaddrInet4
src4if int32
@ -35,8 +36,14 @@ type Endpoint struct {
dst unix.RawSockaddrAny
}
type IPv4Socket int
type IPv6Socket int
type Socket int
/* Returns a byte representation of the source field(s)
* for use in "under load" cookie computations.
*/
func (endpoint *Endpoint) Source() []byte {
return nil
}
func zoneToUint32(zone string) (uint32, error) {
if zone == "" {
@ -49,7 +56,7 @@ func zoneToUint32(zone string) (uint32, error) {
return uint32(n), err
}
func CreateIPv4Socket(port int) (IPv4Socket, error) {
func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
// create socket
@ -60,13 +67,16 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) {
)
if err != nil {
return -1, err
return -1, 0, err
}
addr := unix.SockaddrInet4{
Port: int(port),
}
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
@ -85,19 +95,23 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) {
return err
}
addr := unix.SockaddrInet4{
Port: port,
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
}
return IPv4Socket(fd), err
return Socket(fd), uint16(addr.Port), err
}
func CreateIPv6Socket(port int) (IPv6Socket, error) {
func CloseIPv4Socket(sock Socket) error {
return unix.Close(int(sock))
}
func CloseIPv6Socket(sock Socket) error {
return unix.Close(int(sock))
}
func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
// create socket
@ -108,11 +122,15 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) {
)
if err != nil {
return -1, err
return -1, 0, err
}
// set sockopts and bind
addr := unix.SockaddrInet6{
Port: int(port),
}
if err := func() error {
if err := unix.SetsockoptInt(
@ -142,16 +160,13 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) {
return err
}
addr := unix.SockaddrInet6{
Port: port,
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
}
return IPv6Socket(fd), err
return Socket(fd), uint16(addr.Port), err
}
func (end *Endpoint) ClearSrc() {
@ -311,7 +326,7 @@ func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error {
return errors.New("Unknown address family of source")
}
func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
// contruct message header
@ -360,7 +375,7 @@ func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
return int(size), nil
}
func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
// contruct message header
@ -383,7 +398,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
// recvmsg(sock, &mskhdr, 0)
_, _, errno := unix.Syscall(
size, _, errno := unix.Syscall(
unix.SYS_RECVMSG,
uintptr(sock),
uintptr(unsafe.Pointer(&msg)),
@ -391,7 +406,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
)
if errno != 0 {
return errno
return 0, errno
}
// update source cache
@ -403,21 +418,12 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
end.src6.Scope_id = cmsg.pktinfo.Ifindex
}
return nil
return int(size), nil
}
func SetMark(conn *net.UDPConn, value uint32) error {
if conn == nil {
return nil
}
file, err := conn.File()
if err != nil {
return err
}
func SetMark(sock Socket, value uint32) error {
return unix.SetsockoptInt(
int(file.Fd()),
int(sock),
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),

View File

@ -1,13 +1,18 @@
package main
import (
"net"
"runtime"
"sync"
"sync/atomic"
"time"
)
type Listener struct {
sock Socket
active bool
update chan struct{}
}
type Device struct {
log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
@ -22,8 +27,9 @@ type Device struct {
}
net struct {
mutex sync.RWMutex
addr *net.UDPAddr // UDP source address
conn *net.UDPConn // UDP "connection"
ipv4 Listener
ipv6 Listener
port uint16
fwmark uint32
}
mutex sync.RWMutex
@ -37,8 +43,9 @@ type Device struct {
handshake chan QueueHandshakeElement
}
signal struct {
stop chan struct{} // halts all go routines
newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine)
stop chan struct{} // halts all go routines
updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
}
underLoadUntil atomic.Value
ratelimiter Ratelimiter
@ -137,12 +144,16 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.log = NewLogger(logLevel, "("+tun.Name()+") ")
device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun
device.indices.Init()
device.net.ipv4.Init()
device.net.ipv6.Init()
device.ratelimiter.Init()
device.routingTable.Reset()
device.underLoadUntil.Store(time.Time{})
// setup pools
// setup buffer pool
device.pool.messageBuffers = sync.Pool{
New: func() interface{} {
@ -159,7 +170,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
// prepare signals
device.signal.stop = make(chan struct{})
device.signal.newUDPConn = make(chan struct{}, 1)
// start workers
@ -168,12 +178,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
go device.RoutineDecryption()
go device.RoutineHandshake()
}
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
go device.RoutineReadFromTUN()
go device.RoutineReceiveIncomming()
go device.RoutineReceiveIncomming(&device.net.ipv4)
go device.RoutineReceiveIncomming(&device.net.ipv6)
return device
}
@ -204,7 +213,7 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() {
device.RemoveAllPeers()
close(device.signal.stop)
closeUDPConn(device)
ListeningClose(device)
}
func (device *Device) WaitChannel() chan struct{} {

View File

@ -14,6 +14,7 @@ func printUsage() {
}
func main() {
test()
// parse arguments

View File

@ -13,10 +13,10 @@ import (
)
type QueueHandshakeElement struct {
msgType uint32
packet []byte
buffer *[MaxMessageSize]byte
source *net.UDPAddr
msgType uint32
packet []byte
endpoint Endpoint
buffer *[MaxMessageSize]byte
}
type QueueInboundElement struct {
@ -92,11 +92,22 @@ func (device *Device) addToHandshakeQueue(
}
}
func (device *Device) RoutineReceiveIncomming() {
func (device *Device) RoutineReceiveIncomming(IPVersion int) {
logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, started")
var listener *Listener
switch IPVersion {
case ipv4.Version:
listener = &device.net.ipv4
case ipv6.Version:
listener = &device.net.ipv6
default:
return
}
for {
// wait for new conn
@ -107,14 +118,15 @@ func (device *Device) RoutineReceiveIncomming() {
case <-device.signal.stop:
return
case <-device.signal.newUDPConn:
case <-listener.update:
// fetch connection
// fetch new socket
device.net.mutex.RLock()
conn := device.net.conn
sock := listener.sock
okay := listener.active
device.net.mutex.RUnlock()
if conn == nil {
if !okay {
continue
}
@ -124,11 +136,20 @@ func (device *Device) RoutineReceiveIncomming() {
buffer := device.GetMessageBuffer()
var size int
var err error
for {
// read next datagram
size, raddr, err := conn.ReadFromUDP(buffer[:])
var endpoint Endpoint
if IPVersion == ipv6.Version {
size, err = endpoint.ReceiveIPv4(sock, buffer[:])
} else {
size, err = endpoint.ReceiveIPv6(sock, buffer[:])
}
if err != nil {
break
@ -192,7 +213,7 @@ func (device *Device) RoutineReceiveIncomming() {
buffer = device.GetMessageBuffer()
continue
// otherwise it is a handshake related packet
// otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
@ -208,10 +229,10 @@ func (device *Device) RoutineReceiveIncomming() {
device.addToHandshakeQueue(
device.queue.handshake,
QueueHandshakeElement{
msgType: msgType,
buffer: buffer,
packet: packet,
source: raddr,
msgType: msgType,
buffer: buffer,
packet: packet,
endpoint: endpoint,
},
)
buffer = device.GetMessageBuffer()
@ -293,8 +314,6 @@ func (device *Device) RoutineHandshake() {
// unmarshal packet
logDebug.Println("Process cookie reply from:", elem.source.String())
var reply MessageCookieReply
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply)