1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00
wireguard-go/conn/bind_std.go

197 lines
4.3 KiB
Go
Raw Normal View History

2019-01-02 01:55:51 +01:00
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package conn
2017-08-25 14:53:23 +02:00
import (
"errors"
2017-08-25 14:53:23 +02:00
"net"
"net/netip"
"sync"
2018-06-11 19:04:38 +02:00
"syscall"
2017-08-25 14:53:23 +02:00
)
// StdNetBind is meant to be a temporary solution on platforms for which
// the sticky socket / source caching behavior has not yet been implemented.
// It uses Go's net package to implement networking.
// See LinuxSocketBind for a proper implementation on the Linux platform.
type StdNetBind struct {
	mu sync.Mutex // protects following fields

	ipv4       *net.UDPConn // IPv4 socket; nil while the bind is closed or IPv4 is unavailable
	ipv6       *net.UDPConn // IPv6 socket; nil while the bind is closed or IPv6 is unavailable
	blackhole4 bool         // when set, Send silently drops packets for IPv4 destinations
	blackhole6 bool         // when set, Send silently drops packets for IPv6 destinations
}
// NewStdNetBind returns a fresh, not yet opened StdNetBind as a Bind.
func NewStdNetBind() Bind {
	return new(StdNetBind)
}
2017-11-19 00:21:58 +01:00
// StdNetEndpoint is the Endpoint implementation used by StdNetBind. It is a
// netip.AddrPort holding the destination address and port; source
// addresses are not supported.
type StdNetEndpoint netip.AddrPort
2017-11-19 00:21:58 +01:00
// Compile-time checks that the concrete types satisfy the package interfaces.
var (
	_ Bind     = (*StdNetBind)(nil)
	_ Endpoint = (*StdNetEndpoint)(nil)
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
e, err := netip.ParseAddrPort(s)
return (*StdNetEndpoint)(&e), err
2017-11-19 00:21:58 +01:00
}
// ClearSrc does nothing: this endpoint type carries no source address.
func (e *StdNetEndpoint) ClearSrc() {}
2017-11-19 00:21:58 +01:00
func (e *StdNetEndpoint) DstIP() netip.Addr {
return (*netip.AddrPort)(e).Addr()
2017-11-19 00:21:58 +01:00
}
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
2017-11-19 00:21:58 +01:00
}
func (e *StdNetEndpoint) DstToBytes() []byte {
addr := (*netip.AddrPort)(e)
out := addr.Addr().AsSlice()
out = append(out, byte(addr.Port()&0xff))
out = append(out, byte((addr.Port()>>8)&0xff))
2017-11-19 00:21:58 +01:00
return out
}
func (e *StdNetEndpoint) DstToString() string {
return (*netip.AddrPort)(e).String()
2017-11-19 00:21:58 +01:00
}
func (e *StdNetEndpoint) SrcToString() string {
2017-11-19 00:21:58 +01:00
return ""
}
// listenNet opens a UDP listener for the given network ("udp4" or "udp6")
// on the given port, bound to the unspecified address. A port of 0 lets the
// system choose one.
//
// It returns the connection and the port actually bound. On any failure it
// returns a nil connection, port 0 and the error; a socket that was already
// opened is closed first so that it does not leak.
func listenNet(network string, port int) (*net.UDPConn, int, error) {
	conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
	if err != nil {
		return nil, 0, err
	}

	// Retrieve port.
	laddr := conn.LocalAddr()
	uaddr, err := net.ResolveUDPAddr(
		laddr.Network(),
		laddr.String(),
	)
	if err != nil {
		// The resolve error is the one worth reporting; a failure to
		// close the half-set-up socket adds nothing actionable.
		_ = conn.Close()
		return nil, 0, err
	}
	return conn, uaddr.Port, nil
}
// Open creates the IPv4 and IPv6 UDP listeners on the same port and returns
// one ReceiveFunc per listener that could be created, plus the bound port.
//
// If uport is 0 the system picks the port; should that port then be taken
// for IPv6 (EADDRINUSE), the whole attempt is retried up to 100 times. An
// address family the system does not support (EAFNOSUPPORT) is skipped; if
// neither family is available, EAFNOSUPPORT is returned. Calling Open on an
// already open bind returns ErrBindAlreadyOpen.
func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
	bind.mu.Lock()
	defer bind.mu.Unlock()

	if bind.ipv4 != nil || bind.ipv6 != nil {
		return nil, 0, ErrBindAlreadyOpen
	}

	const maxRetries = 100

	var (
		ipv4, ipv6 *net.UDPConn
		port       int
	)

	// closeIPv4 discards the IPv4 listener of a failed attempt. IPv4 may
	// be unsupported, in which case there is nothing to close.
	closeIPv4 := func() {
		if ipv4 != nil {
			_ = ipv4.Close() // best effort; the listen error is what gets reported
		}
	}

	// Attempt to open ipv4 and ipv6 listeners on the same port.
	// If uport is 0, we can retry on failure.
	for tries := 0; ; tries++ {
		ipv4, ipv6 = nil, nil
		port = int(uport)

		v4, v4port, err := listenNet("udp4", port)
		switch {
		case err == nil:
			ipv4, port = v4, v4port
		case !errors.Is(err, syscall.EAFNOSUPPORT):
			return nil, 0, err
		}
		// On EAFNOSUPPORT port keeps the requested value, so IPv6 is
		// still bound to the port the caller asked for.

		// Listen on the same port as we're using for ipv4.
		v6, v6port, err := listenNet("udp6", port)
		if err == nil {
			ipv6, port = v6, v6port
			break
		}
		if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < maxRetries {
			closeIPv4()
			continue
		}
		if !errors.Is(err, syscall.EAFNOSUPPORT) {
			closeIPv4()
			return nil, 0, err
		}
		// IPv6 is unsupported: carry on with IPv4 only, keeping its port.
		break
	}

	var fns []ReceiveFunc
	if ipv4 != nil {
		fns = append(fns, bind.makeReceiveIPv4(ipv4))
		bind.ipv4 = ipv4
	}
	if ipv6 != nil {
		fns = append(fns, bind.makeReceiveIPv6(ipv6))
		bind.ipv6 = ipv6
	}
	if len(fns) == 0 {
		return nil, 0, syscall.EAFNOSUPPORT
	}
	return fns, uint16(port), nil
}
func (bind *StdNetBind) Close() error {
bind.mu.Lock()
defer bind.mu.Unlock()
2018-06-11 19:04:38 +02:00
var err1, err2 error
if bind.ipv4 != nil {
err1 = bind.ipv4.Close()
bind.ipv4 = nil
2018-06-11 19:04:38 +02:00
}
if bind.ipv6 != nil {
err2 = bind.ipv6.Close()
bind.ipv6 = nil
2018-06-11 19:04:38 +02:00
}
bind.blackhole4 = false
bind.blackhole6 = false
if err1 != nil {
return err1
}
return err2
}
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) {
n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
return n, (*StdNetEndpoint)(&endpoint), err
2018-06-11 19:04:38 +02:00
}
}
// makeReceiveIPv6 builds the ReceiveFunc for the IPv6 socket: each call
// reads a single datagram from conn into buff and hands back its length,
// the sender wrapped as a *StdNetEndpoint, and the read error if any.
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
	return func(buff []byte) (int, Endpoint, error) {
		size, src, readErr := conn.ReadFromUDPAddrPort(buff)
		sender := (*StdNetEndpoint)(&src)
		return size, sender, readErr
	}
}
// Send writes buff to the destination held by endpoint, using the socket
// that matches the destination's address family.
//
// It returns ErrWrongEndpointType if endpoint is not a *StdNetEndpoint, nil
// without sending when the family is blackholed, and EAFNOSUPPORT when no
// socket exists for the family.
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
	nend, ok := endpoint.(*StdNetEndpoint)
	if !ok {
		return ErrWrongEndpointType
	}
	dst := (*netip.AddrPort)(nend)
	ip := dst.Addr()

	var (
		sock *net.UDPConn
		drop bool
	)

	// Hold the lock only while snapshotting the bind state; the write
	// itself happens unlocked.
	bind.mu.Lock()
	switch {
	case ip.Is4():
		sock, drop = bind.ipv4, bind.blackhole4
	case ip.Is6():
		sock, drop = bind.ipv6, bind.blackhole6
	}
	bind.mu.Unlock()

	switch {
	case drop:
		return nil
	case sock == nil:
		return syscall.EAFNOSUPPORT
	}

	_, err := sock.WriteToUDPAddrPort(buff, *dst)
	return err
}