diff --git a/src/conn.go b/src/conn.go index 60cd789..61be3bf 100644 --- a/src/conn.go +++ b/src/conn.go @@ -3,7 +3,6 @@ package main import ( "errors" "net" - "time" ) func parseEndpoint(s string) (*net.UDPAddr, error) { @@ -27,63 +26,96 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { return addr, err } -func updateUDPConn(device *Device) error { +func ListenerClose(l *Listener) (err error) { + if l.active { + err = CloseIPv4Socket(l.sock) + l.active = false + } + return +} + +func (l *Listener) Init() { + l.update = make(chan struct{}, 1) + ListenerClose(l) +} + +func ListeningUpdate(device *Device) error { netc := &device.net netc.mutex.Lock() defer netc.mutex.Unlock() - // close existing connection + // close existing sockets - if netc.conn != nil { - netc.conn.Close() - netc.conn = nil - - // We need for that fd to be closed in all other go routines, which - // means we have to wait. TODO: find less horrible way of doing this. - time.Sleep(time.Second / 2) + if err := ListenerClose(&netc.ipv4); err != nil { + return err } - // open new connection + if err := ListenerClose(&netc.ipv6); err != nil { + return err + } + + // open new sockets if device.tun.isUp.Get() { - // listen on new address + // listen on IPv4 - conn, err := net.ListenUDP("udp", netc.addr) - if err != nil { - return err + { + list := &netc.ipv6 + sock, port, err := CreateIPv4Socket(netc.port) + if err != nil { + return err + } + netc.port = port + list.sock = sock + list.active = true + + if err := SetMark(list.sock, netc.fwmark); err != nil { + ListenerClose(list) + return err + } + signalSend(list.update) } - // set fwmark + // listen on IPv6 - err = SetMark(netc.conn, netc.fwmark) - if err != nil { - return err + { + list := &netc.ipv6 + sock, port, err := CreateIPv6Socket(netc.port) + if err != nil { + return err + } + netc.port = port + list.sock = sock + list.active = true + + if err := SetMark(list.sock, netc.fwmark); err != nil { + ListenerClose(list) + return err + } + signalSend(list.update) } - // retrieve port (may have been chosen by kernel) - - addr := conn.LocalAddr() - netc.conn = conn - netc.addr, _ = net.ResolveUDPAddr( - addr.Network(), - addr.String(), - ) - - // notify goroutines - - signalSend(device.signal.newUDPConn) + // TODO: clear endpoint caches } return nil } -func closeUDPConn(device *Device) { +func ListeningClose(device *Device) error { netc := &device.net netc.mutex.Lock() - if netc.conn != nil { - netc.conn.Close() + defer netc.mutex.Unlock() + + if err := ListenerClose(&netc.ipv4); err != nil { + return err } - netc.mutex.Unlock() - signalSend(device.signal.newUDPConn) + signalSend(netc.ipv4.update) + + if err := ListenerClose(&netc.ipv6); err != nil { + return err + } + signalSend(netc.ipv6.update) + + return nil } diff --git a/src/conn_linux.go b/src/conn_linux.go index 64447a5..034fb8b 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -28,6 +28,7 @@ import "fmt" type Endpoint struct { // source (selected based on dst type) // (could use RawSockaddrAny and unsafe) + // TODO: Merge src6 unix.RawSockaddrInet6 src4 unix.RawSockaddrInet4 src4if int32 @@ -35,8 +36,14 @@ type Endpoint struct { dst unix.RawSockaddrAny } -type IPv4Socket int -type IPv6Socket int +type Socket int + +/* Returns a byte representation of the source field(s) + * for use in "under load" cookie computations. + */ +func (endpoint *Endpoint) Source() []byte { + return nil +} func zoneToUint32(zone string) (uint32, error) { if zone == "" { @@ -49,7 +56,7 @@ func zoneToUint32(zone string) (uint32, error) { return uint32(n), err } -func CreateIPv4Socket(port int) (IPv4Socket, error) { +func CreateIPv4Socket(port uint16) (Socket, uint16, error) { // create socket @@ -60,13 +67,16 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) { ) if err != nil { - return -1, err + return -1, 0, err + } + + addr := unix.SockaddrInet4{ + Port: int(port), } // set sockopts and bind if err := func() error { - if err := unix.SetsockoptInt( fd, unix.SOL_SOCKET, @@ -85,19 +95,23 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) { return err } - addr := unix.SockaddrInet4{ - Port: port, - } return unix.Bind(fd, &addr) - }(); err != nil { unix.Close(fd) } - return IPv4Socket(fd), err + return Socket(fd), uint16(addr.Port), err } -func CreateIPv6Socket(port int) (IPv6Socket, error) { +func CloseIPv4Socket(sock Socket) error { + return unix.Close(int(sock)) +} + +func CloseIPv6Socket(sock Socket) error { + return unix.Close(int(sock)) +} + +func CreateIPv6Socket(port uint16) (Socket, uint16, error) { // create socket @@ -108,11 +122,15 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) { ) if err != nil { - return -1, err + return -1, 0, err } // set sockopts and bind + addr := unix.SockaddrInet6{ + Port: int(port), + } + if err := func() error { if err := unix.SetsockoptInt( @@ -142,16 +160,13 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) { return err } - addr := unix.SockaddrInet6{ - Port: port, - } return unix.Bind(fd, &addr) }(); err != nil { unix.Close(fd) } - return IPv6Socket(fd), err + return Socket(fd), uint16(addr.Port), err } func (end *Endpoint) ClearSrc() { @@ -311,7 +326,7 @@ func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error { return errors.New("Unknown address family of source") } -func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) { +func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) { // contruct message header @@ -360,7 +375,7 @@ func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) { return int(size), nil } -func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error { +func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) { // contruct message header @@ -383,7 +398,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error { // recvmsg(sock, &mskhdr, 0) - _, _, errno := unix.Syscall( + size, _, errno := unix.Syscall( unix.SYS_RECVMSG, uintptr(sock), uintptr(unsafe.Pointer(&msg)), @@ -391,7 +406,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error { ) if errno != 0 { - return errno + return 0, errno } // update source cache @@ -403,21 +418,12 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error { end.src6.Scope_id = cmsg.pktinfo.Ifindex } - return nil + return int(size), nil } -func SetMark(conn *net.UDPConn, value uint32) error { - if conn == nil { - return nil - } - - file, err := conn.File() - if err != nil { - return err - } - +func SetMark(sock Socket, value uint32) error { return unix.SetsockoptInt( - int(file.Fd()), + int(sock), unix.SOL_SOCKET, unix.SO_MARK, int(value), diff --git a/src/device.go b/src/device.go index 61c87bc..509e6a7 100644 --- a/src/device.go +++ b/src/device.go @@ -1,13 +1,18 @@ package main import ( - "net" "runtime" "sync" "sync/atomic" "time" ) +type Listener struct { + sock Socket + active bool + update chan struct{} +} + type Device struct { log *Logger // collection of loggers for levels idCounter uint // for assigning debug ids to peers @@ -22,8 +27,9 @@ type Device struct { } net struct { mutex sync.RWMutex - addr *net.UDPAddr // UDP source address - conn *net.UDPConn // UDP "connection" + ipv4 Listener + ipv6 Listener + port uint16 fwmark uint32 } mutex sync.RWMutex @@ -37,8 +43,9 @@ type Device struct { handshake chan QueueHandshakeElement } signal struct { - stop chan struct{} // halts all go routines - newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine) + stop chan struct{} // halts all go routines + updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine) + updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine) } underLoadUntil atomic.Value ratelimiter Ratelimiter @@ -137,12 +144,16 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.log = NewLogger(logLevel, "("+tun.Name()+") ") device.peers = make(map[NoisePublicKey]*Peer) device.tun.device = tun + device.indices.Init() + device.net.ipv4.Init() + device.net.ipv6.Init() device.ratelimiter.Init() + device.routingTable.Reset() device.underLoadUntil.Store(time.Time{}) - // setup pools + // setup buffer pool device.pool.messageBuffers = sync.Pool{ New: func() interface{} { @@ -159,7 +170,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { // prepare signals device.signal.stop = make(chan struct{}) - device.signal.newUDPConn = make(chan struct{}, 1) // start workers @@ -168,12 +178,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { go device.RoutineDecryption() go device.RoutineHandshake() } - + go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) - go device.RoutineReadFromTUN() - go device.RoutineReceiveIncomming() - + go device.RoutineReceiveIncomming(&device.net.ipv4) + go device.RoutineReceiveIncomming(&device.net.ipv6) return device } @@ -204,7 +213,7 @@ func (device *Device) RemoveAllPeers() { func (device *Device) Close() { device.RemoveAllPeers() close(device.signal.stop) - closeUDPConn(device) + ListeningClose(device) } func (device *Device) WaitChannel() chan struct{} { diff --git a/src/main.go b/src/main.go index 196a4c6..a05dbba 100644 --- a/src/main.go +++ b/src/main.go @@ -14,6 +14,7 @@ func printUsage() { } func main() { + test() // parse arguments diff --git a/src/receive.go b/src/receive.go index 52c2718..60c0f2c 100644 --- a/src/receive.go +++ b/src/receive.go @@ -13,10 +13,10 @@ import ( ) type QueueHandshakeElement struct { - msgType uint32 - packet []byte - buffer *[MaxMessageSize]byte - source *net.UDPAddr + msgType uint32 + packet []byte + endpoint Endpoint + buffer *[MaxMessageSize]byte } type QueueInboundElement struct { @@ -92,11 +92,22 @@ func (device *Device) addToHandshakeQueue( } } -func (device *Device) RoutineReceiveIncomming() { +func (device *Device) RoutineReceiveIncomming(IPVersion int) { logDebug := device.log.Debug logDebug.Println("Routine, receive incomming, started") + var listener *Listener + + switch IPVersion { + case ipv4.Version: + listener = &device.net.ipv4 + case ipv6.Version: + listener = &device.net.ipv6 + default: + return + } + for { // wait for new conn @@ -107,14 +118,15 @@ func (device *Device) RoutineReceiveIncomming() { case <-device.signal.stop: return - case <-device.signal.newUDPConn: + case <-listener.update: - // fetch connection + // fetch new socket device.net.mutex.RLock() - conn := device.net.conn + sock := listener.sock + okay := listener.active device.net.mutex.RUnlock() - if conn == nil { + if !okay { continue } @@ -124,11 +136,20 @@ func (device *Device) RoutineReceiveIncomming() { buffer := device.GetMessageBuffer() + var size int + var err error + for { // read next datagram - size, raddr, err := conn.ReadFromUDP(buffer[:]) + var endpoint Endpoint + + if IPVersion == ipv6.Version { + size, err = endpoint.ReceiveIPv4(sock, buffer[:]) + } else { + size, err = endpoint.ReceiveIPv6(sock, buffer[:]) + } if err != nil { break @@ -192,7 +213,7 @@ func (device *Device) RoutineReceiveIncomming() { buffer = device.GetMessageBuffer() continue - // otherwise it is a handshake related packet + // otherwise it is a fixed size & handshake related packet case MessageInitiationType: okay = len(packet) == MessageInitiationSize @@ -208,10 +229,10 @@ func (device *Device) RoutineReceiveIncomming() { device.addToHandshakeQueue( device.queue.handshake, QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - source: raddr, + msgType: msgType, + buffer: buffer, + packet: packet, + endpoint: endpoint, }, ) buffer = device.GetMessageBuffer() @@ -293,8 +314,6 @@ func (device *Device) RoutineHandshake() { // unmarshal packet - logDebug.Println("Process cookie reply from:", elem.source.String()) - var reply MessageCookieReply reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &reply)