1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00
wireguard-go/conn_linux.go

509 lines
9.4 KiB
Go
Raw Normal View History

2018-04-20 04:05:11 +02:00
/* Copyright 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
2018-04-20 04:05:11 +02:00
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code remains platform dependent.
*/
2017-08-25 14:53:23 +02:00
package main
import (
"errors"
2017-08-25 14:53:23 +02:00
"golang.org/x/sys/unix"
"net"
"strconv"
"unsafe"
2017-08-25 14:53:23 +02:00
)
2018-04-20 04:05:11 +02:00
// IPv4Source is the cached local source of an IPv4 endpoint. It is not stored
// as a field; NativeEndpoint.src4 reinterprets the raw NativeEndpoint.src
// bytes as this type. receive4 fills it from the IP_PKTINFO control message of
// an incoming datagram, and send4 passes it back in IP_PKTINFO so that replies
// use the same local address and interface.
// Its size must not exceed IPv6Source's, because NativeEndpoint.src is sized
// by IPv6Source.
type IPv4Source struct {
src [4]byte // local address (Inet4Pktinfo.Spec_dst)
ifindex int32 // interface index (Inet4Pktinfo.Ifindex)
}
// IPv6Source is the cached local source of an IPv6 endpoint. It is not stored
// as a field; NativeEndpoint.src6 reinterprets the raw NativeEndpoint.src
// bytes as this type. receive6 fills it from the IPV6_PKTINFO control message
// of an incoming datagram, and send6 passes it back in IPV6_PKTINFO.
// Its size determines the size of NativeEndpoint.src.
type IPv6Source struct {
src [16]byte // local address (Inet6Pktinfo.Addr)
// The interface index is not kept here: receive6 stores it in dst.ZoneId
// and send6 reads it back from there.
}
type NativeEndpoint struct {
2018-04-20 04:05:11 +02:00
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(IPv6Source{})]byte
isV6 bool
2017-10-08 22:03:32 +02:00
}
2018-04-20 04:05:11 +02:00
func (endpoint *NativeEndpoint) src4() *IPv4Source {
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
2017-10-08 22:03:32 +02:00
}
2018-04-20 04:05:11 +02:00
// src6 returns the endpoint's src buffer viewed as an IPv6Source; the view is
// only meaningful for IPv6 endpoints (isV6 == true).
func (end *NativeEndpoint) src6() *IPv6Source {
	raw := unsafe.Pointer(&end.src)
	return (*IPv6Source)(raw)
}
2018-04-20 04:05:11 +02:00
// dst4 returns the endpoint's dst buffer viewed as a unix.SockaddrInet4; the
// view is only meaningful for IPv4 endpoints (isV6 == false).
func (end *NativeEndpoint) dst4() *unix.SockaddrInet4 {
	raw := unsafe.Pointer(&end.dst)
	return (*unix.SockaddrInet4)(raw)
}
2018-04-20 04:05:11 +02:00
func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
2017-11-11 15:43:55 +01:00
}
2018-04-20 04:05:11 +02:00
// NativeBind holds the raw file descriptors of the IPv4 (sock4) and IPv6
// (sock6) UDP sockets used for sticky-socket I/O.
type NativeBind struct {
	sock4, sock6 int
}
2018-04-20 04:05:11 +02:00
// Compile-time checks that the native types implement the package interfaces.
var (
	_ Endpoint = (*NativeEndpoint)(nil)
	_ Bind     = NativeBind{}
)
2017-11-19 00:21:58 +01:00
func CreateEndpoint(s string) (Endpoint, error) {
var end NativeEndpoint
addr, err := parseEndpoint(s)
if err != nil {
return nil, err
}
ipv4 := addr.IP.To4()
if ipv4 != nil {
2018-04-20 04:05:11 +02:00
dst := end.dst4()
end.isV6 = false
dst.Port = addr.Port
2017-11-19 00:21:58 +01:00
copy(dst.Addr[:], ipv4)
end.ClearSrc()
return &end, nil
}
ipv6 := addr.IP.To16()
if ipv6 != nil {
zone, err := zoneToUint32(addr.Zone)
if err != nil {
return nil, err
}
2018-04-20 04:05:11 +02:00
dst := end.dst6()
end.isV6 = true
dst.Port = addr.Port
dst.ZoneId = zone
2017-11-19 00:21:58 +01:00
copy(dst.Addr[:], ipv6[:])
end.ClearSrc()
return &end, nil
}
2018-04-20 04:05:11 +02:00
return nil, errors.New("Invalid IP address")
}
2017-11-19 00:21:58 +01:00
func CreateBind(port uint16) (Bind, uint16, error) {
2017-10-08 22:03:32 +02:00
var err error
var bind NativeBind
2017-10-08 22:03:32 +02:00
bind.sock6, port, err = create6(port)
if err != nil {
return nil, port, err
}
bind.sock4, port, err = create4(port)
if err != nil {
unix.Close(bind.sock6)
}
return bind, port, err
2017-10-08 22:03:32 +02:00
}
func (bind NativeBind) SetMark(value uint32) error {
2017-10-08 22:03:32 +02:00
err := unix.SetsockoptInt(
bind.sock6,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
return unix.SetsockoptInt(
bind.sock4,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
}
2017-11-11 23:26:44 +01:00
// closeUnblock shuts down the read side of fd, to unblock readers, and then
// closes it. The shutdown is best effort; only the close error is returned.
func closeUnblock(fd int) error {
	_ = unix.Shutdown(fd, unix.SHUT_RD)
	return unix.Close(fd)
}
func (bind NativeBind) Close() error {
2017-11-11 23:26:44 +01:00
err1 := closeUnblock(bind.sock6)
err2 := closeUnblock(bind.sock4)
2017-10-08 22:03:32 +02:00
if err1 != nil {
return err1
}
return err2
}
func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
n, err := receive6(
2017-10-08 22:03:32 +02:00
bind.sock6,
buff,
&end,
2017-10-08 22:03:32 +02:00
)
return n, &end, err
2017-10-08 22:03:32 +02:00
}
func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
n, err := receive4(
2017-10-08 22:03:32 +02:00
bind.sock4,
buff,
&end,
2017-10-08 22:03:32 +02:00
)
return n, &end, err
2017-10-08 22:03:32 +02:00
}
func (bind NativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint)
2018-04-20 04:05:11 +02:00
if !nend.isV6 {
return send4(bind.sock4, nend, buff)
2018-04-20 04:05:11 +02:00
} else {
return send6(bind.sock6, nend, buff)
2017-10-08 22:03:32 +02:00
}
}
2018-04-20 04:05:11 +02:00
// rawAddrToIP4 returns the address in addr as a new (16-byte form) net.IP.
func rawAddrToIP4(addr *unix.SockaddrInet4) net.IP {
	a := addr.Addr
	return net.IPv4(a[0], a[1], a[2], a[3])
}
2017-10-08 22:03:32 +02:00
2018-04-20 04:05:11 +02:00
func rawAddrToIP6(addr *unix.SockaddrInet6) net.IP {
return addr.Addr[:]
2017-10-08 22:03:32 +02:00
}
2018-04-20 04:05:11 +02:00
func (end *NativeEndpoint) SrcIP() net.IP {
if !end.isV6 {
2017-10-08 22:03:32 +02:00
return net.IPv4(
2018-04-20 04:05:11 +02:00
end.src4().src[0],
end.src4().src[1],
end.src4().src[2],
end.src4().src[3],
2017-10-08 22:03:32 +02:00
)
2018-04-20 04:05:11 +02:00
} else {
return end.src6().src[:]
2017-10-08 22:03:32 +02:00
}
}
func (end *NativeEndpoint) DstIP() net.IP {
2018-04-20 04:05:11 +02:00
if !end.isV6 {
return net.IPv4(
end.dst4().Addr[0],
end.dst4().Addr[1],
end.dst4().Addr[2],
end.dst4().Addr[3],
)
} else {
return end.dst6().Addr[:]
}
}
func (end *NativeEndpoint) DstToBytes() []byte {
2018-04-20 04:05:11 +02:00
if !end.isV6 {
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
} else {
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
}
2017-10-08 22:03:32 +02:00
}
func (end *NativeEndpoint) SrcToString() string {
2018-04-20 04:05:11 +02:00
return end.SrcIP().String()
2017-10-08 22:03:32 +02:00
}
func (end *NativeEndpoint) DstToString() string {
2018-04-20 04:05:11 +02:00
var udpAddr net.UDPAddr
udpAddr.IP = end.DstIP()
if !end.isV6 {
udpAddr.Port = end.dst4().Port
} else {
udpAddr.Port = end.dst6().Port
}
return udpAddr.String()
2017-10-08 22:03:32 +02:00
}
func (end *NativeEndpoint) ClearDst() {
2018-04-20 04:05:11 +02:00
for i := range end.dst {
end.dst[i] = 0
}
}
func (end *NativeEndpoint) ClearSrc() {
2018-04-20 04:05:11 +02:00
for i := range end.src {
end.src[i] = 0
}
}
// zoneToUint32 converts an IPv6 zone string to a numeric scope ID.
// An empty zone maps to 0. Otherwise the zone is first looked up as an
// interface name; if that lookup fails it is parsed as a decimal 32-bit
// index, and any parse error is returned together with uint32 of the value
// strconv.ParseUint produced.
func zoneToUint32(zone string) (uint32, error) {
	if len(zone) == 0 {
		return 0, nil
	}
	iface, lookupErr := net.InterfaceByName(zone)
	if lookupErr == nil {
		return uint32(iface.Index), nil
	}
	idx, err := strconv.ParseUint(zone, 10, 32)
	return uint32(idx), err
}
2017-10-08 22:03:32 +02:00
func create4(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return -1, 0, err
}
addr := unix.SockaddrInet4{
Port: int(port),
}
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IP,
unix.IP_PKTINFO,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
2018-04-20 04:05:11 +02:00
return -1, 0, err
}
2017-10-08 22:03:32 +02:00
return fd, uint16(addr.Port), err
}
2017-10-08 22:03:32 +02:00
func create6(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return -1, 0, err
}
// set sockopts and bind
addr := unix.SockaddrInet6{
Port: int(port),
}
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_RECVPKTINFO,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_V6ONLY,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
2018-04-20 04:05:11 +02:00
return -1, 0, err
}
2017-10-08 22:03:32 +02:00
return fd, uint16(addr.Port), err
}
2018-04-20 04:05:11 +02:00
func send4(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
2018-04-20 04:05:11 +02:00
pktinfo unix.Inet4Pktinfo
}{
unix.Cmsghdr{
2018-04-20 04:05:11 +02:00
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
2018-04-20 04:05:11 +02:00
unix.Inet4Pktinfo{
Spec_dst: end.src4().src,
Ifindex: end.src4().ifindex,
},
}
2018-04-20 04:05:11 +02:00
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
2018-04-20 04:05:11 +02:00
if err == nil {
return nil
}
// clear src and retry
2018-04-20 04:05:11 +02:00
if err == unix.EINVAL {
end.ClearSrc()
2018-04-20 04:05:11 +02:00
cmsg.pktinfo = unix.Inet4Pktinfo{}
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
}
2018-04-20 04:05:11 +02:00
return err
}
2018-04-20 04:05:11 +02:00
func send6(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
2018-04-20 04:05:11 +02:00
pktinfo unix.Inet6Pktinfo
}{
unix.Cmsghdr{
2018-04-20 04:05:11 +02:00
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
2018-04-20 04:05:11 +02:00
unix.Inet6Pktinfo{
Addr: end.src6().src,
Ifindex: end.dst6().ZoneId,
},
}
2018-04-20 04:05:11 +02:00
if cmsg.pktinfo.Addr == [16]byte{} {
cmsg.pktinfo.Ifindex = 0
}
2018-04-20 04:05:11 +02:00
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
2018-04-20 04:05:11 +02:00
if err == nil {
return nil
}
2018-04-20 04:05:11 +02:00
// clear src and retry
2017-11-11 15:43:55 +01:00
2018-04-20 04:05:11 +02:00
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
2017-11-11 15:43:55 +01:00
}
2018-04-20 04:05:11 +02:00
return err
}
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}
2018-04-20 04:05:11 +02:00
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
2018-04-20 04:05:11 +02:00
if err != nil {
return 0, err
}
end.isV6 = false
2018-04-20 04:05:11 +02:00
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
*end.dst4() = *newDst4
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
2018-04-20 04:05:11 +02:00
end.src4().src = cmsg.pktinfo.Spec_dst
end.src4().ifindex = cmsg.pktinfo.Ifindex
}
2018-04-20 04:05:11 +02:00
return size, nil
}
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}
2018-04-20 04:05:11 +02:00
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
2018-04-20 04:05:11 +02:00
if err != nil {
return 0, err
}
end.isV6 = true
2018-04-20 04:05:11 +02:00
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
*end.dst6() = *newDst6
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
2018-04-20 04:05:11 +02:00
end.src6().src = cmsg.pktinfo.Addr
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
}
2018-04-20 04:05:11 +02:00
return size, nil
}