mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
conn: reconstruct v4 vs v6 receive function based on symtab
This is kind of gross but it's better than the alternatives. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
d2fd0c0cc0
commit
54dbe2471f
@ -148,11 +148,11 @@ again:
|
|||||||
|
|
||||||
var fns []ReceiveFunc
|
var fns []ReceiveFunc
|
||||||
if sock4 != -1 {
|
if sock4 != -1 {
|
||||||
fns = append(fns, makeReceiveIPv4(sock4))
|
fns = append(fns, bind.makeReceiveIPv4(sock4))
|
||||||
bind.sock4 = sock4
|
bind.sock4 = sock4
|
||||||
}
|
}
|
||||||
if sock6 != -1 {
|
if sock6 != -1 {
|
||||||
fns = append(fns, makeReceiveIPv6(sock6))
|
fns = append(fns, bind.makeReceiveIPv6(sock6))
|
||||||
bind.sock6 = sock6
|
bind.sock6 = sock6
|
||||||
}
|
}
|
||||||
if len(fns) == 0 {
|
if len(fns) == 0 {
|
||||||
@ -224,7 +224,7 @@ func (bind *LinuxSocketBind) Close() error {
|
|||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeReceiveIPv6(sock int) ReceiveFunc {
|
func (*LinuxSocketBind) makeReceiveIPv6(sock int) ReceiveFunc {
|
||||||
return func(buff []byte) (int, Endpoint, error) {
|
return func(buff []byte) (int, Endpoint, error) {
|
||||||
var end LinuxSocketEndpoint
|
var end LinuxSocketEndpoint
|
||||||
n, err := receive6(sock, buff, &end)
|
n, err := receive6(sock, buff, &end)
|
||||||
@ -232,7 +232,7 @@ func makeReceiveIPv6(sock int) ReceiveFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeReceiveIPv4(sock int) ReceiveFunc {
|
func (*LinuxSocketBind) makeReceiveIPv4(sock int) ReceiveFunc {
|
||||||
return func(buff []byte) (int, Endpoint, error) {
|
return func(buff []byte) (int, Endpoint, error) {
|
||||||
var end LinuxSocketEndpoint
|
var end LinuxSocketEndpoint
|
||||||
n, err := receive4(sock, buff, &end)
|
n, err := receive4(sock, buff, &end)
|
||||||
|
@ -118,11 +118,11 @@ again:
|
|||||||
}
|
}
|
||||||
var fns []ReceiveFunc
|
var fns []ReceiveFunc
|
||||||
if ipv4 != nil {
|
if ipv4 != nil {
|
||||||
fns = append(fns, makeReceiveFunc(ipv4, true))
|
fns = append(fns, bind.makeReceiveIPv4(ipv4))
|
||||||
bind.ipv4 = ipv4
|
bind.ipv4 = ipv4
|
||||||
}
|
}
|
||||||
if ipv6 != nil {
|
if ipv6 != nil {
|
||||||
fns = append(fns, makeReceiveFunc(ipv6, false))
|
fns = append(fns, bind.makeReceiveIPv6(ipv6))
|
||||||
bind.ipv6 = ipv6
|
bind.ipv6 = ipv6
|
||||||
}
|
}
|
||||||
if len(fns) == 0 {
|
if len(fns) == 0 {
|
||||||
@ -152,16 +152,23 @@ func (bind *StdNetBind) Close() error {
|
|||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc {
|
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
|
||||||
return func(buff []byte) (int, Endpoint, error) {
|
return func(buff []byte) (int, Endpoint, error) {
|
||||||
n, endpoint, err := conn.ReadFromUDP(buff)
|
n, endpoint, err := conn.ReadFromUDP(buff)
|
||||||
if isIPv4 && endpoint != nil {
|
if endpoint != nil {
|
||||||
endpoint.IP = endpoint.IP.To4()
|
endpoint.IP = endpoint.IP.To4()
|
||||||
}
|
}
|
||||||
return n, (*StdNetEndpoint)(endpoint), err
|
return n, (*StdNetEndpoint)(endpoint), err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
|
||||||
|
return func(buff []byte) (int, Endpoint, error) {
|
||||||
|
n, endpoint, err := conn.ReadFromUDP(buff)
|
||||||
|
return n, (*StdNetEndpoint)(endpoint), err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
||||||
var err error
|
var err error
|
||||||
nend, ok := endpoint.(*StdNetEndpoint)
|
nend, ok := endpoint.(*StdNetEndpoint)
|
||||||
|
56
conn/conn.go
56
conn/conn.go
@ -8,7 +8,10 @@ package conn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -69,6 +72,54 @@ type Endpoint interface {
|
|||||||
SrcIP() net.IP
|
SrcIP() net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBindAlreadyOpen = errors.New("bind is already open")
|
||||||
|
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (fn ReceiveFunc) PrettyName() string {
|
||||||
|
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
|
||||||
|
// 0. cheese/taco.beansIPv6.func12.func21218-fm
|
||||||
|
name = strings.TrimSuffix(name, "-fm")
|
||||||
|
// 1. cheese/taco.beansIPv6.func12.func21218
|
||||||
|
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
|
||||||
|
name = name[idx+1:]
|
||||||
|
// 2. taco.beansIPv6.func12.func21218
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
var idx int
|
||||||
|
for idx = len(name) - 1; idx >= 0; idx-- {
|
||||||
|
if name[idx] < '0' || name[idx] > '9' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx == len(name)-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
const dotFunc = ".func"
|
||||||
|
if !strings.HasSuffix(name[:idx+1], dotFunc) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
name = name[:idx+1-len(dotFunc)]
|
||||||
|
// 3. taco.beansIPv6.func12
|
||||||
|
// 4. taco.beansIPv6
|
||||||
|
}
|
||||||
|
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
|
||||||
|
name = name[idx+1:]
|
||||||
|
// 5. beansIPv6
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
return fmt.Sprintf("%p", fn)
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(name, "IPv4") {
|
||||||
|
return "v4"
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(name, "IPv6") {
|
||||||
|
return "v6"
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||||
// ensure that the host is an IP address
|
// ensure that the host is an IP address
|
||||||
|
|
||||||
@ -98,8 +149,3 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
|
|||||||
}
|
}
|
||||||
return addr, err
|
return addr, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrBindAlreadyOpen = errors.New("bind is already open")
|
|
||||||
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
|
|
||||||
)
|
|
||||||
|
@ -69,14 +69,15 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
|||||||
* IPv4 and IPv6 (separately)
|
* IPv4 and IPv6 (separately)
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
||||||
|
recvName := recv.PrettyName()
|
||||||
defer func() {
|
defer func() {
|
||||||
device.log.Verbosef("Routine: receive incoming %p - stopped", recv)
|
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
|
||||||
device.queue.decryption.wg.Done()
|
device.queue.decryption.wg.Done()
|
||||||
device.queue.handshake.wg.Done()
|
device.queue.handshake.wg.Done()
|
||||||
device.net.stopping.Done()
|
device.net.stopping.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
device.log.Verbosef("Routine: receive incoming %p - started", recv)
|
device.log.Verbosef("Routine: receive incoming %s - started", recvName)
|
||||||
|
|
||||||
// receive datagrams until conn is closed
|
// receive datagrams until conn is closed
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user