1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 09:15:14 +01:00

Definition of platform specific socket bind

This commit is contained in:
Mathias Hall-Andersen 2017-10-06 22:56:01 +02:00
parent 32d8932d1b
commit c70f0c5da2
4 changed files with 198 additions and 40 deletions

View File

@ -56,7 +56,7 @@ func updateUDPConn(device *Device) error {
// set fwmark // set fwmark
err = setMark(netc.conn, netc.fwmark) err = SetMark(netc.conn, netc.fwmark)
if err != nil { if err != nil {
return err return err
} }

View File

@ -6,6 +6,6 @@ import (
"net" "net"
) )
func setMark(conn *net.UDPConn, value uint32) error { func SetMark(conn *net.UDPConn, value uint32) error {
return nil return nil
} }

View File

@ -14,23 +14,30 @@ import (
"unsafe" "unsafe"
) )
import "fmt"
/* Supports source address caching /* Supports source address caching
*
* It is important that the endpoint is only updated after the packet content has been authenticated.
* *
* Currently there is no way to achieve this within the net package: * Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930 * See e.g. https://github.com/golang/go/issues/17930
* So this code is platform dependent.
*
* It is important that the endpoint is only updated after the packet content has been authenticated!
*/ */
type Endpoint struct { type Endpoint struct {
// source (selected based on dst type) // source (selected based on dst type)
// (could use RawSockaddrAny and unsafe) // (could use RawSockaddrAny and unsafe)
srcIPv6 unix.RawSockaddrInet6 src6 unix.RawSockaddrInet6
srcIPv4 unix.RawSockaddrInet4 src4 unix.RawSockaddrInet4
srcIf4 int32 src4if int32
dst unix.RawSockaddrAny dst unix.RawSockaddrAny
} }
type IPv4Socket int
type IPv6Socket int
func zoneToUint32(zone string) (uint32, error) { func zoneToUint32(zone string) (uint32, error) {
if zone == "" { if zone == "" {
return 0, nil return 0, nil
@ -42,10 +49,115 @@ func zoneToUint32(zone string) (uint32, error) {
return uint32(n), err return uint32(n), err
} }
func CreateIPv4Socket(port int) (IPv4Socket, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return -1, err
}
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IP,
unix.IP_PKTINFO,
1,
); err != nil {
return err
}
addr := unix.SockaddrInet4{
Port: port,
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
}
return IPv4Socket(fd), err
}
func CreateIPv6Socket(port int) (IPv6Socket, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return -1, err
}
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_RECVPKTINFO,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_V6ONLY,
1,
); err != nil {
return err
}
addr := unix.SockaddrInet6{
Port: port,
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
}
return IPv6Socket(fd), err
}
func (end *Endpoint) ClearSrc() { func (end *Endpoint) ClearSrc() {
end.srcIf4 = 0 end.src4if = 0
end.srcIPv4 = unix.RawSockaddrInet4{} end.src4 = unix.RawSockaddrInet4{}
end.srcIPv6 = unix.RawSockaddrInet6{} end.src6 = unix.RawSockaddrInet6{}
} }
func (end *Endpoint) Set(s string) error { func (end *Endpoint) Set(s string) error {
@ -85,8 +197,10 @@ func (end *Endpoint) Set(s string) error {
} }
func send6(sock uintptr, end *Endpoint, buff []byte) error { func send6(sock uintptr, end *Endpoint, buff []byte) error {
var iovec unix.Iovec
// construct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff)) iovec.SetLen(len(buff))
@ -100,8 +214,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
Len: unix.SizeofInet6Pktinfo, Len: unix.SizeofInet6Pktinfo,
}, },
unix.Inet6Pktinfo{ unix.Inet6Pktinfo{
Addr: end.srcIPv6.Addr, Addr: end.src6.Addr,
Ifindex: end.srcIPv6.Scope_id, Ifindex: end.src6.Scope_id,
}, },
} }
@ -130,8 +244,10 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
} }
func send4(sock uintptr, end *Endpoint, buff []byte) error { func send4(sock uintptr, end *Endpoint, buff []byte) error {
var iovec unix.Iovec
// construct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff)) iovec.SetLen(len(buff))
@ -142,11 +258,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
unix.Cmsghdr{ unix.Cmsghdr{
Level: unix.IPPROTO_IP, Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO, Type: unix.IP_PKTINFO,
Len: unix.SizeofInet6Pktinfo, Len: unix.SizeofInet4Pktinfo,
}, },
unix.Inet4Pktinfo{ unix.Inet4Pktinfo{
Spec_dst: end.srcIPv4.Addr, Spec_dst: end.src4.Addr,
Ifindex: end.srcIf4, Ifindex: end.src4if,
}, },
} }
@ -174,7 +290,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
return errno return errno
} }
func send(c *net.UDPConn, end *Endpoint, buff []byte) error { func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error {
// extract underlying file descriptor // extract underlying file descriptor
@ -195,12 +311,9 @@ func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
return errors.New("Unknown address family of source") return errors.New("Unknown address family of source")
} }
func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) { func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
file, err := c.File() // contruct message header
if err != nil {
return err, nil, nil
}
var iovec unix.Iovec var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
@ -208,47 +321,92 @@ func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAdd
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo // big enough pktinfo unix.Inet4Pktinfo
}
var msghdr unix.Msghdr
msghdr.Iov = &iovec
msghdr.Iovlen = 1
msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
msghdr.Namelen = unix.SizeofSockaddrInet4
msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
// recvmsg(sock, &mskhdr, 0)
size, _, errno := unix.Syscall(
unix.SYS_RECVMSG,
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
if errno != 0 {
return 0, errno
}
fmt.Println(msghdr)
fmt.Println(cmsg)
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
end.src4.Addr = cmsg.pktinfo.Spec_dst
end.src4if = cmsg.pktinfo.Ifindex
}
return int(size), nil
}
func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
// contruct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
} }
var msg unix.Msghdr var msg unix.Msghdr
msg.Iov = &iovec msg.Iov = &iovec
msg.Iovlen = 1 msg.Iovlen = 1
msg.Name = (*byte)(unsafe.Pointer(&end.dst)) msg.Name = (*byte)(unsafe.Pointer(&end.dst))
msg.Namelen = uint32(unix.SizeofSockaddrAny) msg.Namelen = uint32(unix.SizeofSockaddrInet6)
msg.Control = (*byte)(unsafe.Pointer(&cmsg)) msg.Control = (*byte)(unsafe.Pointer(&cmsg))
msg.SetControllen(int(unsafe.Sizeof(cmsg))) msg.SetControllen(int(unsafe.Sizeof(cmsg)))
// recvmsg(sock, &mskhdr, 0)
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
unix.SYS_RECVMSG, unix.SYS_RECVMSG,
file.Fd(), uintptr(sock),
uintptr(unsafe.Pointer(&msg)), uintptr(unsafe.Pointer(&msg)),
0, 0,
) )
if errno != 0 { if errno != 0 {
return errno, nil, nil return errno
} }
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
end.src6.Addr = cmsg.pktinfo.Addr
end.src6.Scope_id = cmsg.pktinfo.Ifindex
} }
if cmsg.cmsghdr.Level == unix.IPPROTO_IP && return nil
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
println(info)
}
return nil, nil, nil
} }
func setMark(conn *net.UDPConn, value uint32) error { func SetMark(conn *net.UDPConn, value uint32) error {
if conn == nil { if conn == nil {
return nil return nil
} }

View File

@ -166,7 +166,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.net.mutex.Lock() device.net.mutex.Lock()
if fwmark > 0 || device.net.fwmark > 0 { if fwmark > 0 || device.net.fwmark > 0 {
device.net.fwmark = uint32(fwmark) device.net.fwmark = uint32(fwmark)
err := setMark( err := SetMark(
device.net.conn, device.net.conn,
device.net.fwmark, device.net.fwmark,
) )