1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

conn: reconstruct v4 vs v6 receive function based on symtab

This is kind of gross but it's better than the alternatives.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-04-09 17:21:35 -06:00
parent d2fd0c0cc0
commit 54dbe2471f
4 changed files with 69 additions and 15 deletions

View File

@ -148,11 +148,11 @@ again:
var fns []ReceiveFunc var fns []ReceiveFunc
if sock4 != -1 { if sock4 != -1 {
fns = append(fns, makeReceiveIPv4(sock4)) fns = append(fns, bind.makeReceiveIPv4(sock4))
bind.sock4 = sock4 bind.sock4 = sock4
} }
if sock6 != -1 { if sock6 != -1 {
fns = append(fns, makeReceiveIPv6(sock6)) fns = append(fns, bind.makeReceiveIPv6(sock6))
bind.sock6 = sock6 bind.sock6 = sock6
} }
if len(fns) == 0 { if len(fns) == 0 {
@ -224,7 +224,7 @@ func (bind *LinuxSocketBind) Close() error {
return err2 return err2
} }
func makeReceiveIPv6(sock int) ReceiveFunc { func (*LinuxSocketBind) makeReceiveIPv6(sock int) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) { return func(buff []byte) (int, Endpoint, error) {
var end LinuxSocketEndpoint var end LinuxSocketEndpoint
n, err := receive6(sock, buff, &end) n, err := receive6(sock, buff, &end)
@ -232,7 +232,7 @@ func makeReceiveIPv6(sock int) ReceiveFunc {
} }
} }
func makeReceiveIPv4(sock int) ReceiveFunc { func (*LinuxSocketBind) makeReceiveIPv4(sock int) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) { return func(buff []byte) (int, Endpoint, error) {
var end LinuxSocketEndpoint var end LinuxSocketEndpoint
n, err := receive4(sock, buff, &end) n, err := receive4(sock, buff, &end)

View File

@ -118,11 +118,11 @@ again:
} }
var fns []ReceiveFunc var fns []ReceiveFunc
if ipv4 != nil { if ipv4 != nil {
fns = append(fns, makeReceiveFunc(ipv4, true)) fns = append(fns, bind.makeReceiveIPv4(ipv4))
bind.ipv4 = ipv4 bind.ipv4 = ipv4
} }
if ipv6 != nil { if ipv6 != nil {
fns = append(fns, makeReceiveFunc(ipv6, false)) fns = append(fns, bind.makeReceiveIPv6(ipv6))
bind.ipv6 = ipv6 bind.ipv6 = ipv6
} }
if len(fns) == 0 { if len(fns) == 0 {
@ -152,16 +152,23 @@ func (bind *StdNetBind) Close() error {
return err2 return err2
} }
func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc { func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) { return func(buff []byte) (int, Endpoint, error) {
n, endpoint, err := conn.ReadFromUDP(buff) n, endpoint, err := conn.ReadFromUDP(buff)
if isIPv4 && endpoint != nil { if endpoint != nil {
endpoint.IP = endpoint.IP.To4() endpoint.IP = endpoint.IP.To4()
} }
return n, (*StdNetEndpoint)(endpoint), err return n, (*StdNetEndpoint)(endpoint), err
} }
} }
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) {
n, endpoint, err := conn.ReadFromUDP(buff)
return n, (*StdNetEndpoint)(endpoint), err
}
}
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
var err error var err error
nend, ok := endpoint.(*StdNetEndpoint) nend, ok := endpoint.(*StdNetEndpoint)

View File

@ -8,7 +8,10 @@ package conn
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"reflect"
"runtime"
"strings" "strings"
) )
@ -69,6 +72,54 @@ type Endpoint interface {
SrcIP() net.IP SrcIP() net.IP
} }
var (
ErrBindAlreadyOpen = errors.New("bind is already open")
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)
func (fn ReceiveFunc) PrettyName() string {
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
// 0. cheese/taco.beansIPv6.func12.func21218-fm
name = strings.TrimSuffix(name, "-fm")
// 1. cheese/taco.beansIPv6.func12.func21218
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
name = name[idx+1:]
// 2. taco.beansIPv6.func12.func21218
}
for {
var idx int
for idx = len(name) - 1; idx >= 0; idx-- {
if name[idx] < '0' || name[idx] > '9' {
break
}
}
if idx == len(name)-1 {
break
}
const dotFunc = ".func"
if !strings.HasSuffix(name[:idx+1], dotFunc) {
break
}
name = name[:idx+1-len(dotFunc)]
// 3. taco.beansIPv6.func12
// 4. taco.beansIPv6
}
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
name = name[idx+1:]
// 5. beansIPv6
}
if name == "" {
return fmt.Sprintf("%p", fn)
}
if strings.HasSuffix(name, "IPv4") {
return "v4"
}
if strings.HasSuffix(name, "IPv6") {
return "v6"
}
return name
}
func parseEndpoint(s string) (*net.UDPAddr, error) { func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address // ensure that the host is an IP address
@ -98,8 +149,3 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
} }
return addr, err return addr, err
} }
var (
ErrBindAlreadyOpen = errors.New("bind is already open")
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)

View File

@ -69,14 +69,15 @@ func (peer *Peer) keepKeyFreshReceiving() {
* IPv4 and IPv6 (separately) * IPv4 and IPv6 (separately)
*/ */
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() { defer func() {
device.log.Verbosef("Routine: receive incoming %p - stopped", recv) device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
device.queue.decryption.wg.Done() device.queue.decryption.wg.Done()
device.queue.handshake.wg.Done() device.queue.handshake.wg.Done()
device.net.stopping.Done() device.net.stopping.Done()
}() }()
device.log.Verbosef("Routine: receive incoming %p - started", recv) device.log.Verbosef("Routine: receive incoming %s - started", recvName)
// receive datagrams until conn is closed // receive datagrams until conn is closed