mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
Begin incorporating new src cache into receive
This commit is contained in:
parent
c70f0c5da2
commit
2d856045a0
104
src/conn.go
104
src/conn.go
@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||
@ -27,63 +26,96 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||
return addr, err
|
||||
}
|
||||
|
||||
func updateUDPConn(device *Device) error {
|
||||
func ListenerClose(l *Listener) (err error) {
|
||||
if l.active {
|
||||
err = CloseIPv4Socket(l.sock)
|
||||
l.active = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Init() {
|
||||
l.update = make(chan struct{}, 1)
|
||||
ListenerClose(l)
|
||||
}
|
||||
|
||||
func ListeningUpdate(device *Device) error {
|
||||
netc := &device.net
|
||||
netc.mutex.Lock()
|
||||
defer netc.mutex.Unlock()
|
||||
|
||||
// close existing connection
|
||||
// close existing sockets
|
||||
|
||||
if netc.conn != nil {
|
||||
netc.conn.Close()
|
||||
netc.conn = nil
|
||||
|
||||
// We need for that fd to be closed in all other go routines, which
|
||||
// means we have to wait. TODO: find less horrible way of doing this.
|
||||
time.Sleep(time.Second / 2)
|
||||
if err := ListenerClose(&netc.ipv4); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// open new connection
|
||||
if err := ListenerClose(&netc.ipv6); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// open new sockets
|
||||
|
||||
if device.tun.isUp.Get() {
|
||||
|
||||
// listen on new address
|
||||
// listen on IPv4
|
||||
|
||||
conn, err := net.ListenUDP("udp", netc.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
{
|
||||
list := &netc.ipv6
|
||||
sock, port, err := CreateIPv4Socket(netc.port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
netc.port = port
|
||||
list.sock = sock
|
||||
list.active = true
|
||||
|
||||
if err := SetMark(list.sock, netc.fwmark); err != nil {
|
||||
ListenerClose(list)
|
||||
return err
|
||||
}
|
||||
signalSend(list.update)
|
||||
}
|
||||
|
||||
// set fwmark
|
||||
// listen on IPv6
|
||||
|
||||
err = SetMark(netc.conn, netc.fwmark)
|
||||
if err != nil {
|
||||
return err
|
||||
{
|
||||
list := &netc.ipv6
|
||||
sock, port, err := CreateIPv6Socket(netc.port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
netc.port = port
|
||||
list.sock = sock
|
||||
list.active = true
|
||||
|
||||
if err := SetMark(list.sock, netc.fwmark); err != nil {
|
||||
ListenerClose(list)
|
||||
return err
|
||||
}
|
||||
signalSend(list.update)
|
||||
}
|
||||
|
||||
// retrieve port (may have been chosen by kernel)
|
||||
|
||||
addr := conn.LocalAddr()
|
||||
netc.conn = conn
|
||||
netc.addr, _ = net.ResolveUDPAddr(
|
||||
addr.Network(),
|
||||
addr.String(),
|
||||
)
|
||||
|
||||
// notify goroutines
|
||||
|
||||
signalSend(device.signal.newUDPConn)
|
||||
// TODO: clear endpoint caches
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeUDPConn(device *Device) {
|
||||
func ListeningClose(device *Device) error {
|
||||
netc := &device.net
|
||||
netc.mutex.Lock()
|
||||
if netc.conn != nil {
|
||||
netc.conn.Close()
|
||||
defer netc.mutex.Unlock()
|
||||
|
||||
if err := ListenerClose(&netc.ipv4); err != nil {
|
||||
return err
|
||||
}
|
||||
netc.mutex.Unlock()
|
||||
signalSend(device.signal.newUDPConn)
|
||||
signalSend(netc.ipv4.update)
|
||||
|
||||
if err := ListenerClose(&netc.ipv6); err != nil {
|
||||
return err
|
||||
}
|
||||
signalSend(netc.ipv6.update)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -28,6 +28,7 @@ import "fmt"
|
||||
type Endpoint struct {
|
||||
// source (selected based on dst type)
|
||||
// (could use RawSockaddrAny and unsafe)
|
||||
// TODO: Merge
|
||||
src6 unix.RawSockaddrInet6
|
||||
src4 unix.RawSockaddrInet4
|
||||
src4if int32
|
||||
@ -35,8 +36,14 @@ type Endpoint struct {
|
||||
dst unix.RawSockaddrAny
|
||||
}
|
||||
|
||||
type IPv4Socket int
|
||||
type IPv6Socket int
|
||||
type Socket int
|
||||
|
||||
/* Returns a byte representation of the source field(s)
|
||||
* for use in "under load" cookie computations.
|
||||
*/
|
||||
func (endpoint *Endpoint) Source() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func zoneToUint32(zone string) (uint32, error) {
|
||||
if zone == "" {
|
||||
@ -49,7 +56,7 @@ func zoneToUint32(zone string) (uint32, error) {
|
||||
return uint32(n), err
|
||||
}
|
||||
|
||||
func CreateIPv4Socket(port int) (IPv4Socket, error) {
|
||||
func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
|
||||
|
||||
// create socket
|
||||
|
||||
@ -60,13 +67,16 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) {
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return -1, err
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
addr := unix.SockaddrInet4{
|
||||
Port: int(port),
|
||||
}
|
||||
|
||||
// set sockopts and bind
|
||||
|
||||
if err := func() error {
|
||||
|
||||
if err := unix.SetsockoptInt(
|
||||
fd,
|
||||
unix.SOL_SOCKET,
|
||||
@ -85,19 +95,23 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) {
|
||||
return err
|
||||
}
|
||||
|
||||
addr := unix.SockaddrInet4{
|
||||
Port: port,
|
||||
}
|
||||
return unix.Bind(fd, &addr)
|
||||
|
||||
}(); err != nil {
|
||||
unix.Close(fd)
|
||||
}
|
||||
|
||||
return IPv4Socket(fd), err
|
||||
return Socket(fd), uint16(addr.Port), err
|
||||
}
|
||||
|
||||
func CreateIPv6Socket(port int) (IPv6Socket, error) {
|
||||
func CloseIPv4Socket(sock Socket) error {
|
||||
return unix.Close(int(sock))
|
||||
}
|
||||
|
||||
func CloseIPv6Socket(sock Socket) error {
|
||||
return unix.Close(int(sock))
|
||||
}
|
||||
|
||||
func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
|
||||
|
||||
// create socket
|
||||
|
||||
@ -108,11 +122,15 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) {
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return -1, err
|
||||
return -1, 0, err
|
||||
}
|
||||
|
||||
// set sockopts and bind
|
||||
|
||||
addr := unix.SockaddrInet6{
|
||||
Port: int(port),
|
||||
}
|
||||
|
||||
if err := func() error {
|
||||
|
||||
if err := unix.SetsockoptInt(
|
||||
@ -142,16 +160,13 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) {
|
||||
return err
|
||||
}
|
||||
|
||||
addr := unix.SockaddrInet6{
|
||||
Port: port,
|
||||
}
|
||||
return unix.Bind(fd, &addr)
|
||||
|
||||
}(); err != nil {
|
||||
unix.Close(fd)
|
||||
}
|
||||
|
||||
return IPv6Socket(fd), err
|
||||
return Socket(fd), uint16(addr.Port), err
|
||||
}
|
||||
|
||||
func (end *Endpoint) ClearSrc() {
|
||||
@ -311,7 +326,7 @@ func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error {
|
||||
return errors.New("Unknown address family of source")
|
||||
}
|
||||
|
||||
func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
|
||||
func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
|
||||
|
||||
// contruct message header
|
||||
|
||||
@ -360,7 +375,7 @@ func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
|
||||
return int(size), nil
|
||||
}
|
||||
|
||||
func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
|
||||
func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
|
||||
|
||||
// contruct message header
|
||||
|
||||
@ -383,7 +398,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
|
||||
|
||||
// recvmsg(sock, &mskhdr, 0)
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
size, _, errno := unix.Syscall(
|
||||
unix.SYS_RECVMSG,
|
||||
uintptr(sock),
|
||||
uintptr(unsafe.Pointer(&msg)),
|
||||
@ -391,7 +406,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return errno
|
||||
return 0, errno
|
||||
}
|
||||
|
||||
// update source cache
|
||||
@ -403,21 +418,12 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
|
||||
end.src6.Scope_id = cmsg.pktinfo.Ifindex
|
||||
}
|
||||
|
||||
return nil
|
||||
return int(size), nil
|
||||
}
|
||||
|
||||
func SetMark(conn *net.UDPConn, value uint32) error {
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
file, err := conn.File()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func SetMark(sock Socket, value uint32) error {
|
||||
return unix.SetsockoptInt(
|
||||
int(file.Fd()),
|
||||
int(sock),
|
||||
unix.SOL_SOCKET,
|
||||
unix.SO_MARK,
|
||||
int(value),
|
||||
|
@ -1,13 +1,18 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
sock Socket
|
||||
active bool
|
||||
update chan struct{}
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
log *Logger // collection of loggers for levels
|
||||
idCounter uint // for assigning debug ids to peers
|
||||
@ -22,8 +27,9 @@ type Device struct {
|
||||
}
|
||||
net struct {
|
||||
mutex sync.RWMutex
|
||||
addr *net.UDPAddr // UDP source address
|
||||
conn *net.UDPConn // UDP "connection"
|
||||
ipv4 Listener
|
||||
ipv6 Listener
|
||||
port uint16
|
||||
fwmark uint32
|
||||
}
|
||||
mutex sync.RWMutex
|
||||
@ -37,8 +43,9 @@ type Device struct {
|
||||
handshake chan QueueHandshakeElement
|
||||
}
|
||||
signal struct {
|
||||
stop chan struct{} // halts all go routines
|
||||
newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine)
|
||||
stop chan struct{} // halts all go routines
|
||||
updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
|
||||
updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
|
||||
}
|
||||
underLoadUntil atomic.Value
|
||||
ratelimiter Ratelimiter
|
||||
@ -137,12 +144,16 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
||||
device.log = NewLogger(logLevel, "("+tun.Name()+") ")
|
||||
device.peers = make(map[NoisePublicKey]*Peer)
|
||||
device.tun.device = tun
|
||||
|
||||
device.indices.Init()
|
||||
device.net.ipv4.Init()
|
||||
device.net.ipv6.Init()
|
||||
device.ratelimiter.Init()
|
||||
|
||||
device.routingTable.Reset()
|
||||
device.underLoadUntil.Store(time.Time{})
|
||||
|
||||
// setup pools
|
||||
// setup buffer pool
|
||||
|
||||
device.pool.messageBuffers = sync.Pool{
|
||||
New: func() interface{} {
|
||||
@ -159,7 +170,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
||||
// prepare signals
|
||||
|
||||
device.signal.stop = make(chan struct{})
|
||||
device.signal.newUDPConn = make(chan struct{}, 1)
|
||||
|
||||
// start workers
|
||||
|
||||
@ -168,12 +178,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
||||
go device.RoutineDecryption()
|
||||
go device.RoutineHandshake()
|
||||
}
|
||||
|
||||
go device.RoutineReadFromTUN()
|
||||
go device.RoutineTUNEventReader()
|
||||
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
||||
go device.RoutineReadFromTUN()
|
||||
go device.RoutineReceiveIncomming()
|
||||
|
||||
go device.RoutineReceiveIncomming(&device.net.ipv4)
|
||||
go device.RoutineReceiveIncomming(&device.net.ipv6)
|
||||
return device
|
||||
}
|
||||
|
||||
@ -204,7 +213,7 @@ func (device *Device) RemoveAllPeers() {
|
||||
func (device *Device) Close() {
|
||||
device.RemoveAllPeers()
|
||||
close(device.signal.stop)
|
||||
closeUDPConn(device)
|
||||
ListeningClose(device)
|
||||
}
|
||||
|
||||
func (device *Device) WaitChannel() chan struct{} {
|
||||
|
@ -14,6 +14,7 @@ func printUsage() {
|
||||
}
|
||||
|
||||
func main() {
|
||||
test()
|
||||
|
||||
// parse arguments
|
||||
|
||||
|
@ -13,10 +13,10 @@ import (
|
||||
)
|
||||
|
||||
type QueueHandshakeElement struct {
|
||||
msgType uint32
|
||||
packet []byte
|
||||
buffer *[MaxMessageSize]byte
|
||||
source *net.UDPAddr
|
||||
msgType uint32
|
||||
packet []byte
|
||||
endpoint Endpoint
|
||||
buffer *[MaxMessageSize]byte
|
||||
}
|
||||
|
||||
type QueueInboundElement struct {
|
||||
@ -92,11 +92,22 @@ func (device *Device) addToHandshakeQueue(
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) RoutineReceiveIncomming() {
|
||||
func (device *Device) RoutineReceiveIncomming(IPVersion int) {
|
||||
|
||||
logDebug := device.log.Debug
|
||||
logDebug.Println("Routine, receive incomming, started")
|
||||
|
||||
var listener *Listener
|
||||
|
||||
switch IPVersion {
|
||||
case ipv4.Version:
|
||||
listener = &device.net.ipv4
|
||||
case ipv6.Version:
|
||||
listener = &device.net.ipv6
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
|
||||
// wait for new conn
|
||||
@ -107,14 +118,15 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
case <-device.signal.stop:
|
||||
return
|
||||
|
||||
case <-device.signal.newUDPConn:
|
||||
case <-listener.update:
|
||||
|
||||
// fetch connection
|
||||
// fetch new socket
|
||||
|
||||
device.net.mutex.RLock()
|
||||
conn := device.net.conn
|
||||
sock := listener.sock
|
||||
okay := listener.active
|
||||
device.net.mutex.RUnlock()
|
||||
if conn == nil {
|
||||
if !okay {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -124,11 +136,20 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
|
||||
buffer := device.GetMessageBuffer()
|
||||
|
||||
var size int
|
||||
var err error
|
||||
|
||||
for {
|
||||
|
||||
// read next datagram
|
||||
|
||||
size, raddr, err := conn.ReadFromUDP(buffer[:])
|
||||
var endpoint Endpoint
|
||||
|
||||
if IPVersion == ipv6.Version {
|
||||
size, err = endpoint.ReceiveIPv4(sock, buffer[:])
|
||||
} else {
|
||||
size, err = endpoint.ReceiveIPv6(sock, buffer[:])
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
@ -192,7 +213,7 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
buffer = device.GetMessageBuffer()
|
||||
continue
|
||||
|
||||
// otherwise it is a handshake related packet
|
||||
// otherwise it is a fixed size & handshake related packet
|
||||
|
||||
case MessageInitiationType:
|
||||
okay = len(packet) == MessageInitiationSize
|
||||
@ -208,10 +229,10 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
device.addToHandshakeQueue(
|
||||
device.queue.handshake,
|
||||
QueueHandshakeElement{
|
||||
msgType: msgType,
|
||||
buffer: buffer,
|
||||
packet: packet,
|
||||
source: raddr,
|
||||
msgType: msgType,
|
||||
buffer: buffer,
|
||||
packet: packet,
|
||||
endpoint: endpoint,
|
||||
},
|
||||
)
|
||||
buffer = device.GetMessageBuffer()
|
||||
@ -293,8 +314,6 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
// unmarshal packet
|
||||
|
||||
logDebug.Println("Process cookie reply from:", elem.source.String())
|
||||
|
||||
var reply MessageCookieReply
|
||||
reader := bytes.NewReader(elem.packet)
|
||||
err := binary.Read(reader, binary.LittleEndian, &reply)
|
||||
|
Loading…
Reference in New Issue
Block a user