1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2025-09-18 20:57:50 +02:00
wireguard-go/tun/netstack/tun.go
tqbf 90300178a6 Handle whole ICMP headers
Theoretically, this change would allow us to send ICMP replies from
these sockets. In practice, netstack filters anything but EchoRequest,
so this change mostly just adds annoyance right now. But in the
future, netstack might fix this, and this API would then work for
more than one use case.

Signed-off-by: Thomas Ptacek <thomas@sockpuppet.org>
2022-02-02 13:38:49 -06:00

1063 lines
26 KiB
Go

/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package netstack
import (
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"regexp"
"strconv"
"strings"
"sync"
"time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
// netTun holds the state shared by the tun.Device handed back from
// CreateNetTUN and by the endpoint and Net views of the same value.
type netTun struct {
	// stack is the network stack created in CreateNetTUN.
	stack *stack.Stack
	// dispatcher is stored by Attach and used by Write to deliver packets.
	dispatcher stack.NetworkDispatcher
	// events is the channel returned by Events.
	events chan tun.Event
	// incomingPacket carries packets from WritePacket to Read.
	incomingPacket chan buffer.VectorisedView
	// mtu is the value reported by MTU.
	mtu int
	// dnsServers are the resolvers queried by the lookup helpers.
	dnsServers []netip.Addr
	// hasV4 is set when at least one IPv4 local address was added.
	hasV4 bool
	// hasV6 is set when at least one IPv6 local address was added.
	hasV6 bool
}
// endpoint is netTun seen through the link-endpoint methods that the
// network stack calls (Attach, MTU, WritePacket, ...).
type endpoint netTun

// Net is netTun seen through the dialing, listening and DNS lookup methods.
type Net netTun
// Attach stores dispatcher; Write later delivers inbound packets through it.
func (ep *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
	ep.dispatcher = dispatcher
}
// IsAttached reports whether a dispatcher has been stored by Attach.
func (ep *endpoint) IsAttached() bool {
	attached := ep.dispatcher != nil
	return attached
}
// MTU returns the MTU of the underlying netTun as a uint32.
// It panics if the underlying MTU call reports an error.
func (ep *endpoint) MTU() uint32 {
	size, err := (*netTun)(ep).MTU()
	if err == nil {
		return uint32(size)
	}
	panic(err)
}
// Capabilities reports that this endpoint advertises no link capabilities.
func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
	var caps stack.LinkEndpointCapabilities = stack.CapabilityNone
	return caps
}
// MaxHeaderLength always returns 0.
func (*endpoint) MaxHeaderLength() uint16 {
	const headerLen = 0
	return headerLen
}
// LinkAddress always returns the empty link address.
func (*endpoint) LinkAddress() tcpip.LinkAddress {
	return tcpip.LinkAddress("")
}
// Wait is a no-op.
func (*endpoint) Wait() {
}
// WritePacket queues the packet's views on incomingPacket, where Read picks
// them up. The send blocks until a reader receives it, because the channel is
// created unbuffered in CreateNetTUN. It always returns nil.
func (ep *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
	outbound := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
	ep.incomingPacket <- outbound
	return nil
}
// WritePackets is not supported and panics when called.
func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
	panic("not implemented")
}
// WriteRawPacket is not supported and panics when called.
func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
	panic("not implemented")
}
// ARPHardwareType reports header.ARPHardwareNone.
func (*endpoint) ARPHardwareType() header.ARPHardwareType {
	var hw header.ARPHardwareType = header.ARPHardwareNone
	return hw
}
// AddHeader is a no-op: nothing is added to the packet.
func (*endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
// CreateNetTUN builds a userspace network stack with IPv4, IPv6, TCP, UDP and
// ICMP enabled, attaches it to NIC 1, and assigns localAddresses to that NIC.
//
// A default route is installed for each address family that has at least one
// local address. dnsServers and mtu are stored for later use by the lookup
// helpers and MTU. On success it returns the same underlying value twice: as
// a tun.Device and as a *Net. An EventUp is queued on the events channel
// before returning.
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
	const nicID = 1

	dev := &netTun{
		stack: stack.New(stack.Options{
			NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
			TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
			HandleLocal:        true,
		}),
		events:         make(chan tun.Event, 10),
		incomingPacket: make(chan buffer.VectorisedView),
		dnsServers:     dnsServers,
		mtu:            mtu,
	}

	if nicErr := dev.stack.CreateNIC(nicID, (*endpoint)(dev)); nicErr != nil {
		return nil, nil, fmt.Errorf("CreateNIC: %v", nicErr)
	}

	for _, ip := range localAddresses {
		// An address that is neither v4 nor v6 keeps the zero protocol number.
		var proto tcpip.NetworkProtocolNumber
		switch {
		case ip.Is4():
			proto = ipv4.ProtocolNumber
		case ip.Is6():
			proto = ipv6.ProtocolNumber
		}

		assigned := tcpip.ProtocolAddress{
			Protocol:          proto,
			AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
		}
		if addErr := dev.stack.AddProtocolAddress(nicID, assigned, stack.AddressProperties{}); addErr != nil {
			return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, addErr)
		}

		switch {
		case ip.Is4():
			dev.hasV4 = true
		case ip.Is6():
			dev.hasV6 = true
		}
	}

	if dev.hasV4 {
		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: nicID})
	}
	if dev.hasV6 {
		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: nicID})
	}

	dev.events <- tun.EventUp
	return dev, (*Net)(dev), nil
}
// Name returns the fixed device name "go".
func (t *netTun) Name() (string, error) {
	const name = "go"
	return name, nil
}
// File always returns nil.
func (t *netTun) File() *os.File {
	var none *os.File
	return none
}
// Events returns the channel on which device events are published.
func (t *netTun) Events() chan tun.Event {
	return t.events
}
// Read blocks until a packet arrives on incomingPacket and copies it into
// buf starting at offset. It returns os.ErrClosed once the channel has been
// closed.
func (t *netTun) Read(buf []byte, offset int) (int, error) {
	if view, open := <-t.incomingPacket; open {
		return view.Read(buf[offset:])
	}
	return 0, os.ErrClosed
}
// Write hands the IP packet found at buf[offset:] to the attached dispatcher.
//
// The IP version is read from the high nibble of the first byte; packets that
// are neither version 4 nor 6 are not delivered. An empty packet returns
// (0, nil); otherwise the return value is len(buf), whether or not the packet
// was delivered.
func (t *netTun) Write(buf []byte, offset int) (int, error) {
	payload := buf[offset:]
	if len(payload) == 0 {
		return 0, nil
	}

	data := buffer.NewVectorisedView(len(payload), []buffer.View{buffer.NewViewFromBytes(payload)})
	pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Data: data})

	var proto tcpip.NetworkProtocolNumber
	switch version := payload[0] >> 4; version {
	case 4:
		proto = ipv4.ProtocolNumber
	case 6:
		proto = ipv6.ProtocolNumber
	default:
		return len(buf), nil
	}
	t.dispatcher.DeliverNetworkPacket("", "", proto, pkb)
	return len(buf), nil
}
// Flush is a no-op and always returns nil.
func (t *netTun) Flush() error {
	var err error
	return err
}
// Close removes NIC 1 from the stack and closes the events and
// incomingPacket channels when they are non-nil. It always returns nil.
func (t *netTun) Close() error {
	t.stack.RemoveNIC(1)
	if ch := t.events; ch != nil {
		close(ch)
	}
	if ch := t.incomingPacket; ch != nil {
		close(ch)
	}
	return nil
}
// MTU returns the MTU supplied to CreateNetTUN; the error is always nil.
func (t *netTun) MTU() (int, error) {
	size := t.mtu
	return size, nil
}
// convertToFullAddr turns endpoint into a tcpip.FullAddress on NIC 1 and
// reports the matching network protocol number. Anything that is not an IPv4
// address, including the zero AddrPort, is reported as IPv6.
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
	ip := endpoint.Addr()

	proto := tcpip.NetworkProtocolNumber(ipv6.ProtocolNumber)
	if ip.Is4() {
		proto = ipv4.ProtocolNumber
	}

	full := tcpip.FullAddress{
		NIC:  1,
		Addr: tcpip.Address(ip.AsSlice()),
		Port: endpoint.Port(),
	}
	return full, proto
}
// DialContextTCPAddrPort opens a TCP connection to addr over the userspace
// stack, passing ctx through to gonet.DialContextTCP.
func (tnet *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
	fullAddr, proto := convertToFullAddr(addr)
	return gonet.DialContextTCP(ctx, tnet.stack, fullAddr, proto)
}
// DialContextTCP is the *net.TCPAddr form of DialContextTCPAddrPort.
// A nil addr is forwarded as the zero netip.AddrPort.
func (tnet *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
	var target netip.AddrPort
	if addr != nil {
		target = netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))
	}
	return tnet.DialContextTCPAddrPort(ctx, target)
}
// DialTCPAddrPort opens a TCP connection to addr over the userspace stack.
func (tnet *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
	fullAddr, proto := convertToFullAddr(addr)
	return gonet.DialTCP(tnet.stack, fullAddr, proto)
}
// DialTCP is the *net.TCPAddr form of DialTCPAddrPort.
// A nil addr is forwarded as the zero netip.AddrPort.
func (tnet *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
	var target netip.AddrPort
	if addr != nil {
		target = netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))
	}
	return tnet.DialTCPAddrPort(target)
}
// ListenTCPAddrPort starts a TCP listener on addr within the userspace stack.
func (tnet *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
	fullAddr, proto := convertToFullAddr(addr)
	return gonet.ListenTCP(tnet.stack, fullAddr, proto)
}
// ListenTCP is the *net.TCPAddr form of ListenTCPAddrPort.
// A nil addr is forwarded as the zero netip.AddrPort.
func (tnet *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
	var local netip.AddrPort
	if addr != nil {
		local = netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))
	}
	return tnet.ListenTCPAddrPort(local)
}
// DialUDPAddrPort creates a UDP connection over the userspace stack.
//
// laddr and raddr are each optional: one is passed to gonet.DialUDP only when
// it is a valid AddrPort or carries a non-zero port, otherwise nil is passed
// in its place. The protocol number comes from raddr when raddr is used,
// otherwise from laddr, and is left at zero when neither is used.
func (tnet *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
	var (
		local, remote *tcpip.FullAddress
		proto         tcpip.NetworkProtocolNumber
	)
	if laddr.IsValid() || laddr.Port() > 0 {
		fa, pn := convertToFullAddr(laddr)
		local, proto = &fa, pn
	}
	if raddr.IsValid() || raddr.Port() > 0 {
		fa, pn := convertToFullAddr(raddr)
		remote, proto = &fa, pn
	}
	return gonet.DialUDP(tnet.stack, local, remote, proto)
}
// DialUDP is the *net.UDPAddr form of DialUDPAddrPort. A nil laddr or raddr
// is forwarded as the zero netip.AddrPort.
func (tnet *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
	toAddrPort := func(a *net.UDPAddr) netip.AddrPort {
		if a == nil {
			return netip.AddrPort{}
		}
		return netip.AddrPortFrom(netip.AddrFromSlice(a.IP), uint16(a.Port))
	}
	return tnet.DialUDPAddrPort(toAddrPort(laddr), toAddrPort(raddr))
}
// PingConn is an ICMP connection created by DialPingAddr.
type PingConn struct {
	// laddr and raddr are the addresses reported by LocalAddr and RemoteAddr.
	laddr PingAddr
	raddr PingAddr
	// wq is the waiter queue the endpoint was created with; ReadFrom
	// registers on it to learn when data is readable.
	wq waiter.Queue
	// ep is the underlying stack endpoint.
	ep tcpip.Endpoint

	// mu guards deadline and deadlineBreaker.
	mu sync.RWMutex
	// deadline is the read deadline; the zero value means no deadline.
	deadline time.Time
	// deadlineBreaker is closed to wake a blocked ReadFrom so that it
	// re-reads the deadline.
	deadlineBreaker chan struct{}
}
// PingAddr is the net.Addr implementation used by PingConn; it wraps a
// single IP address.
type PingAddr struct {
	addr netip.Addr
}
// String returns the textual form of the wrapped address.
func (ia PingAddr) String() string {
	text := ia.addr.String()
	return text
}
// Network returns "ping4" for an IPv4 address, "ping6" for an IPv6 address
// and "ping" otherwise.
func (ia PingAddr) Network() string {
	switch {
	case ia.addr.Is4():
		return "ping4"
	case ia.addr.Is6():
		return "ping6"
	default:
		return "ping"
	}
}
// PingAddrFromAddr converts addr to a PingAddr. A PingAddr is returned
// unchanged and a *net.IPAddr is converted from its IP; any other type yields
// a "wrong address format" error.
func PingAddrFromAddr(addr net.Addr) (PingAddr, error) {
	if pa, ok := addr.(PingAddr); ok {
		return pa, nil
	}
	if ipa, ok := addr.(*net.IPAddr); ok {
		return PingAddr{netip.AddrFromSlice(ipa.IP)}, nil
	}
	return PingAddr{}, errors.New("wrong address format")
}
// DialPingAddr creates an ICMP endpoint on the userspace stack.
//
// The connection is IPv6 when either laddr or raddr is an IPv6 address,
// otherwise IPv4. A valid laddr is bound to; an invalid laddr is replaced by
// the unspecified address of the chosen family for LocalAddr and no bind is
// performed. A valid raddr is connected to and recorded as the remote
// address.
//
// If binding or connecting fails, the endpoint is closed before the error is
// returned, so that a failed dial does not leave the endpoint open.
func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
	v6 := laddr.Is6() || raddr.Is6()
	bind := laddr.IsValid()
	if !bind {
		if v6 {
			laddr = netip.IPv6Unspecified()
		} else {
			laddr = netip.IPv4Unspecified()
		}
	}

	tn := icmp.ProtocolNumber4
	pn := ipv4.ProtocolNumber
	if v6 {
		tn = icmp.ProtocolNumber6
		pn = ipv6.ProtocolNumber
	}

	pc := &PingConn{
		laddr:           PingAddr{laddr},
		deadlineBreaker: make(chan struct{}, 1),
	}

	ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
	if tcpipErr != nil {
		return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
	}
	pc.ep = ep

	if bind {
		fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
		if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
			// The caller never sees pc, so release the endpoint here.
			pc.ep.Close()
			return nil, fmt.Errorf("ping bind: %s", tcpipErr)
		}
	}

	if raddr.IsValid() {
		pc.raddr = PingAddr{raddr}
		fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
		if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
			// The caller never sees pc, so release the endpoint here.
			pc.ep.Close()
			return nil, fmt.Errorf("ping connect: %s", tcpipErr)
		}
	}

	return pc, nil
}
// LocalAddr returns the local address recorded when the connection was made.
func (pc *PingConn) LocalAddr() net.Addr {
	var local net.Addr = pc.laddr
	return local
}
// RemoteAddr returns the remote address recorded when the connection was
// made; it wraps the zero address if the connection is unconnected.
func (pc *PingConn) RemoteAddr() net.Addr {
	var remote net.Addr = pc.raddr
	return remote
}
// Close wakes any reader blocked in ReadFrom and closes the underlying
// endpoint. It always returns nil.
//
// The breaker channel is swapped under pc.mu, the same way SetReadDeadline
// does it. Closing it without the lock races with SetReadDeadline, and
// leaving the closed channel installed makes a later SetReadDeadline or a
// second Close panic with "close of closed channel".
func (pc *PingConn) Close() error {
	pc.mu.Lock()
	close(pc.deadlineBreaker)
	pc.deadlineBreaker = make(chan struct{}, 1)
	pc.mu.Unlock()

	pc.ep.Close()
	return nil
}
// SetWriteDeadline is unsupported and always returns a "not implemented"
// error.
func (pc *PingConn) SetWriteDeadline(time.Time) error {
	err := errors.New("not implemented")
	return err
}
// WriteTo sends p to addr through the endpoint.
//
// addr must be convertible by PingAddrFromAddr and must belong to the same
// address family as the connection's local address; otherwise an error is
// returned and nothing is written. It returns the byte count reported by the
// endpoint.
func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
	dst, err := PingAddrFromAddr(addr)
	if err != nil {
		return 0, fmt.Errorf("ping write: %w", err)
	}

	local := pc.laddr.addr
	bothV4 := dst.addr.Is4() && local.Is4()
	bothV6 := dst.addr.Is6() && local.Is6()
	if !bothV4 && !bothV6 {
		return 0, fmt.Errorf("ping write: mismatched protocols")
	}

	payload := buffer.NewViewFromBytes(p)
	rdr := payload.Reader()
	to, _ := convertToFullAddr(netip.AddrPortFrom(dst.addr, 0))

	// won't block, no deadlines
	written, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{To: &to})
	if tcpipErr != nil {
		return int(written), fmt.Errorf("ping write: %s", tcpipErr)
	}
	return int(written), nil
}
// Write sends p to the connection's recorded remote address via WriteTo.
func (pc *PingConn) Write(p []byte) (n int, err error) {
	n, err = pc.WriteTo(p, pc.raddr)
	return n, err
}
// ReadFrom waits until the endpoint signals readable data, then reads one
// message into p and returns its length and the sender's address.
//
// While waiting it honours the read deadline: if the deadline passes first,
// os.ErrDeadlineExceeded is returned. Whenever deadlineBreaker is closed the
// wait restarts with freshly loaded deadline and breaker values.
//
// Each iteration's timer is stopped before the next iteration begins, rather
// than deferred to function return, so that repeated wake-ups during one
// blocked read do not accumulate armed timers.
func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	e, notifyCh := waiter.NewChannelEntry(nil)
	pc.wq.EventRegister(&e, waiter.EventIn)
	defer pc.wq.EventUnregister(&e)

	for ready := false; !ready; {
		// Snapshot both values under the lock; SetReadDeadline replaces them
		// together.
		pc.mu.RLock()
		deadlineBreaker := pc.deadlineBreaker
		deadline := pc.deadline
		pc.mu.RUnlock()

		if deadline.IsZero() {
			select {
			case <-deadlineBreaker:
			case <-notifyCh:
				ready = true
			}
			continue
		}

		t := time.NewTimer(time.Until(deadline))
		select {
		case <-t.C:
			// The timer has already fired, so there is nothing to stop.
			return 0, nil, os.ErrDeadlineExceeded
		case <-deadlineBreaker:
		case <-notifyCh:
			ready = true
		}
		t.Stop()
	}

	w := tcpip.SliceWriter(p)
	res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
		NeedRemoteAddr: true,
	})
	if tcpipErr != nil {
		return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
	}

	addr = PingAddr{netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))}
	return res.Count, addr, nil
}
// Read behaves like ReadFrom but discards the sender's address.
func (pc *PingConn) Read(p []byte) (n int, err error) {
	count, _, readErr := pc.ReadFrom(p)
	return count, readErr
}
// SetDeadline sets only the read deadline, because SetWriteDeadline is
// unimplemented.
func (pc *PingConn) SetDeadline(t time.Time) error {
	err := pc.SetReadDeadline(t)
	return err
}
// SetReadDeadline stores t as the read deadline and closes the current
// breaker channel so that a blocked ReadFrom wakes up and reloads the
// deadline. A fresh breaker channel is installed for later calls. It always
// returns nil.
func (pc *PingConn) SetReadDeadline(t time.Time) error {
	pc.mu.Lock()
	defer pc.mu.Unlock()

	stale := pc.deadlineBreaker
	pc.deadlineBreaker = make(chan struct{}, 1)
	pc.deadline = t
	close(stale)
	return nil
}
// Sentinel errors used by the DNS lookup and dial helpers in this file.
var (
	// DNS resolution outcomes.
	errNoSuchHost                   = errors.New("no such host")
	errLameReferral                 = errors.New("lame referral")
	errCannotUnmarshalDNSMessage    = errors.New("cannot unmarshal DNS message")
	errCannotMarshalDNSMessage      = errors.New("cannot marshal DNS message")
	errServerMisbehaving            = errors.New("server misbehaving")
	errInvalidDNSResponse           = errors.New("invalid DNS response")
	errNoAnswerFromDNSServer        = errors.New("no answer from DNS server")
	// Same text as errServerMisbehaving; kept distinct so callers can tell a
	// server failure apart by identity.
	errServerTemporarilyMisbehaving = errors.New("server misbehaving")

	// Context and dialing outcomes.
	errCanceled          = errors.New("operation was canceled")
	errTimeout           = errors.New("i/o timeout")
	errNumericPort       = errors.New("port must be numeric")
	errNoSuitableAddress = errors.New("no suitable address found")
	errMissingAddress    = errors.New("missing address")
)
// LookupHost resolves host using a background context; see
// LookupContextHost.
func (tnet *Net) LookupHost(host string) (addrs []string, err error) {
	addrs, err = tnet.LookupContextHost(context.Background(), host)
	return addrs, err
}
// isDomainName reports whether s is a syntactically acceptable domain name.
//
// The name must be non-empty and at most 254 bytes, where 254 is only
// allowed with a trailing dot. Labels are 1 to 63 bytes of letters, digits,
// '_' and '-'. A label may not start or end with '-', and the name must
// contain at least one non-digit character other than '.'.
func isDomainName(s string) bool {
	const (
		maxNameLen  = 254
		maxLabelLen = 63
	)

	switch n := len(s); {
	case n == 0, n > maxNameLen:
		return false
	case n == maxNameLen && s[n-1] != '.':
		return false
	}

	var (
		prev          = byte('.') // pretend a dot precedes the first label
		sawNonNumeric bool
		labelLen      int
	)
	for _, c := range []byte(s) {
		switch {
		case c == '.':
			if prev == '.' || prev == '-' || labelLen == 0 || labelLen > maxLabelLen {
				return false
			}
			labelLen = 0
		case c == '-':
			if prev == '.' {
				return false
			}
			labelLen++
			sawNonNumeric = true
		case '0' <= c && c <= '9':
			labelLen++
		case c == '_' || 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z':
			labelLen++
			sawNonNumeric = true
		default:
			return false
		}
		prev = c
	}

	return prev != '-' && labelLen <= maxLabelLen && sawNonNumeric
}
// randU16 returns a 16-bit value read from crypto/rand.
// It panics if the random source reports an error.
func randU16() uint16 {
	buf := make([]byte, 2)
	if _, err := rand.Read(buf); err != nil {
		panic(err)
	}
	return binary.LittleEndian.Uint16(buf)
}
// newRequest builds a DNS query for q with a random ID and recursion
// desired.
//
// It returns the ID and two encodings that share one buffer: tcpReq is the
// message prefixed with its 2-byte big-endian length, and udpReq is the same
// message without that prefix.
//
// The error from Finish is checked before the result is sliced. A failed
// Finish may return a nil slice, and slicing that with [2:] would panic
// instead of reporting the error.
func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
	id = randU16()
	// Reserve the first two bytes for the TCP length prefix.
	b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
	b.EnableCompression()
	if err := b.StartQuestions(); err != nil {
		return 0, nil, nil, err
	}
	if err := b.Question(q); err != nil {
		return 0, nil, nil, err
	}
	tcpReq, err = b.Finish()
	if err != nil {
		return 0, nil, nil, err
	}
	udpReq = tcpReq[2:]
	l := len(tcpReq) - 2
	tcpReq[0] = byte(l >> 8)
	tcpReq[1] = byte(l)
	return id, udpReq, tcpReq, nil
}
// equalASCIIName reports whether x and y have the same length and the same
// bytes, treating ASCII letters case-insensitively.
func equalASCIIName(x, y dnsmessage.Name) bool {
	if x.Length != y.Length {
		return false
	}
	lower := func(c byte) byte {
		if 'A' <= c && c <= 'Z' {
			return c + 0x20
		}
		return c
	}
	for i, n := 0, int(x.Length); i < n; i++ {
		if lower(x.Data[i]) != lower(y.Data[i]) {
			return false
		}
	}
	return true
}
// checkResponse reports whether a response header and question belong to the
// request: the header must be flagged as a response, carry the request ID,
// and echo the request's type, class and name (compared case-insensitively).
func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
	if !respHdr.Response || reqID != respHdr.ID {
		return false
	}
	return reqQues.Type == respQues.Type &&
		reqQues.Class == respQues.Class &&
		equalASCIIName(reqQues.Name, respQues.Name)
}
// dnsPacketRoundTrip writes the query b to c and reads datagrams of up to
// 512 bytes until one parses and matches the query according to
// checkResponse. Datagrams that fail to parse or do not match are ignored.
// Write and read errors are returned as-is.
func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
	if _, err := c.Write(b); err != nil {
		return dnsmessage.Parser{}, dnsmessage.Header{}, err
	}

	resp := make([]byte, 512)
	for {
		n, err := c.Read(resp)
		if err != nil {
			return dnsmessage.Parser{}, dnsmessage.Header{}, err
		}

		var p dnsmessage.Parser
		h, err := p.Start(resp[:n])
		if err != nil {
			continue
		}
		if q, err := p.Question(); err == nil && checkResponse(id, query, h, q) {
			return p, h, nil
		}
	}
}
// dnsStreamRoundTrip writes the length-prefixed query b to c, reads one
// length-prefixed response, and parses it.
//
// I/O errors are returned as-is. A response that cannot be parsed yields
// errCannotUnmarshalDNSMessage, and one that does not match the query yields
// errInvalidDNSResponse.
func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
	fail := func(err error) (dnsmessage.Parser, dnsmessage.Header, error) {
		return dnsmessage.Parser{}, dnsmessage.Header{}, err
	}

	if _, err := c.Write(b); err != nil {
		return fail(err)
	}

	buf := make([]byte, 1280)
	if _, err := io.ReadFull(c, buf[:2]); err != nil {
		return fail(err)
	}
	// The 2-byte prefix is the big-endian message length; grow the buffer
	// only when the message does not fit.
	size := int(buf[0])<<8 | int(buf[1])
	if size > len(buf) {
		buf = make([]byte, size)
	}
	n, err := io.ReadFull(c, buf[:size])
	if err != nil {
		return fail(err)
	}

	var p dnsmessage.Parser
	h, err := p.Start(buf[:n])
	if err != nil {
		return fail(errCannotUnmarshalDNSMessage)
	}
	q, err := p.Question()
	if err != nil {
		return fail(errCannotUnmarshalDNSMessage)
	}
	if !checkResponse(id, query, h, q) {
		return fail(errInvalidDNSResponse)
	}
	return p, h, nil
}
// exchange sends question q to server on port 53, first over UDP and then,
// if the UDP answer is truncated, over TCP.
//
// Each attempt runs under its own deadline of now+timeout derived from ctx,
// and that deadline is also set on the connection. context.Canceled and
// context.DeadlineExceeded from the round trip are mapped to errCanceled and
// errTimeout. On success the returned parser is positioned after the
// question section.
func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
	fail := func(err error) (dnsmessage.Parser, dnsmessage.Header, error) {
		return dnsmessage.Parser{}, dnsmessage.Header{}, err
	}

	q.Class = dnsmessage.ClassINET
	id, udpReq, tcpReq, err := newRequest(q)
	if err != nil {
		return fail(errCannotMarshalDNSMessage)
	}

	target := netip.AddrPortFrom(server, 53)
	for _, useUDP := range []bool{true, false} {
		attemptCtx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
		defer cancel()

		var (
			c   net.Conn
			err error
		)
		if useUDP {
			c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, target)
		} else {
			c, err = tnet.DialContextTCPAddrPort(attemptCtx, target)
		}
		if err != nil {
			return fail(err)
		}
		if d, ok := attemptCtx.Deadline(); ok && !d.IsZero() {
			c.SetDeadline(d)
		}

		var (
			p dnsmessage.Parser
			h dnsmessage.Header
		)
		if useUDP {
			p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
		} else {
			p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
		}
		c.Close()

		switch err {
		case nil:
		case context.Canceled:
			return fail(errCanceled)
		case context.DeadlineExceeded:
			return fail(errTimeout)
		default:
			return fail(err)
		}

		if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
			return fail(errInvalidDNSResponse)
		}
		if h.Truncated {
			// Retry over TCP on the next iteration.
			continue
		}
		return p, h, nil
	}
	return fail(errNoAnswerFromDNSServer)
}
// checkHeader classifies a response from its header and the first answer
// header, which it reads from p.
//
// It returns errNoSuchHost for a name-error response,
// errCannotUnmarshalDNSMessage when the answer header cannot be parsed,
// errLameReferral for a successful non-authoritative response with no
// recursion available and no answers, errServerTemporarilyMisbehaving for a
// server failure, errServerMisbehaving for any other non-success code, and
// nil otherwise.
func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
	if h.RCode == dnsmessage.RCodeNameError {
		return errNoSuchHost
	}

	_, err := p.AnswerHeader()
	noAnswers := err == dnsmessage.ErrSectionDone
	if err != nil && !noAnswers {
		return errCannotUnmarshalDNSMessage
	}

	if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && noAnswers {
		return errLameReferral
	}

	switch h.RCode {
	case dnsmessage.RCodeSuccess, dnsmessage.RCodeNameError:
		return nil
	case dnsmessage.RCodeServerFailure:
		return errServerTemporarilyMisbehaving
	default:
		return errServerMisbehaving
	}
}
// skipToAnswer advances p past answers until the next answer header has type
// qtype. It returns errNoSuchHost when the answer section ends first and
// errCannotUnmarshalDNSMessage on any parse failure.
func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
	for {
		hdr, err := p.AnswerHeader()
		switch {
		case err == dnsmessage.ErrSectionDone:
			return errNoSuchHost
		case err != nil:
			return errCannotUnmarshalDNSMessage
		case hdr.Type == qtype:
			return nil
		}
		if p.SkipAnswer() != nil {
			return errCannotUnmarshalDNSMessage
		}
	}
}
// tryOneName queries the configured DNS servers for name with type qtype.
//
// Every server is tried in order, and the whole list is walked up to two
// times, with a 5 second timeout per exchange. The first response that
// contains an answer of the requested type is returned with the parser
// positioned at it, together with the answering server's address. A
// definitive "no such host" result is returned immediately as a
// *net.DNSError with IsNotFound set. Otherwise the most recent error is
// returned.
//
// When no DNS servers are configured, no exchange takes place. In that case
// an explicit "no answer from DNS server" error is returned, rather than a
// zero parser with a nil error, which callers would take for success.
func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
	var lastErr error

	n, err := dnsmessage.NewName(name)
	if err != nil {
		return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
	}
	q := dnsmessage.Question{
		Name:  n,
		Type:  qtype,
		Class: dnsmessage.ClassINET,
	}

	for i := 0; i < 2; i++ {
		for _, server := range tnet.dnsServers {
			p, h, err := tnet.exchange(ctx, server, q, time.Second*5)
			if err != nil {
				dnsErr := &net.DNSError{
					Err:    err.Error(),
					Name:   name,
					Server: server.String(),
				}
				if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
					dnsErr.IsTimeout = true
				}
				if _, ok := err.(*net.OpError); ok {
					dnsErr.IsTemporary = true
				}
				lastErr = dnsErr
				continue
			}

			if err := checkHeader(&p, h); err != nil {
				dnsErr := &net.DNSError{
					Err:    err.Error(),
					Name:   name,
					Server: server.String(),
				}
				if err == errServerTemporarilyMisbehaving {
					dnsErr.IsTemporary = true
				}
				if err == errNoSuchHost {
					// Definitive answer: do not ask the remaining servers.
					dnsErr.IsNotFound = true
					return p, server.String(), dnsErr
				}
				lastErr = dnsErr
				continue
			}

			err = skipToAnswer(&p, qtype)
			if err == nil {
				return p, server.String(), nil
			}
			lastErr = &net.DNSError{
				Err:    err.Error(),
				Name:   name,
				Server: server.String(),
			}
			if err == errNoSuchHost {
				lastErr.(*net.DNSError).IsNotFound = true
				return p, server.String(), lastErr
			}
		}
	}

	if lastErr == nil {
		// The loops never ran, so there were no servers to ask.
		lastErr = &net.DNSError{
			Err:  errNoAnswerFromDNSServer.Error(),
			Name: name,
		}
	}
	return dnsmessage.Parser{}, "", lastErr
}
// LookupContextHost resolves host to a list of textual IP addresses.
//
// A literal IP address, optionally with an IPv6 zone suffix, is returned
// directly without its zone. Otherwise host must pass isDomainName. A and
// AAAA queries are issued concurrently for whichever address families the
// stack has local addresses for. IPv6 results are listed first when the
// stack has an IPv6 address, IPv4 results first otherwise.
//
// An error is returned only when no address was found. It is a *net.DNSError
// with IsNotFound set when host is empty, is not a valid name, or the stack
// has no local addresses at all.
//
// The query name is rooted by adding a trailing dot only when host does not
// already end in one, so that a fully qualified "example.com." is not turned
// into "example.com..", which cannot be encoded. Failures while parsing
// answers are reported as errCannotUnmarshalDNSMessage.
func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) {
	if host == "" || (!tnet.hasV6 && !tnet.hasV4) {
		return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
	}

	// Strip an IPv6 zone ("%eth0") before trying to parse a literal address.
	zlen := len(host)
	if strings.IndexByte(host, ':') != -1 {
		if zidx := strings.LastIndexByte(host, '%'); zidx != -1 {
			zlen = zidx
		}
	}
	if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
		return []string{ip.String()}, nil
	}

	if !isDomainName(host) {
		return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
	}

	fqdn := host
	if !strings.HasSuffix(fqdn, ".") {
		fqdn += "."
	}

	type result struct {
		p      dnsmessage.Parser
		server string
		error
	}
	// parseErr builds the error recorded when an answer cannot be decoded.
	parseErr := func(server string) error {
		return &net.DNSError{
			Err:    errCannotUnmarshalDNSMessage.Error(),
			Name:   host,
			Server: server,
		}
	}

	var addrsV4, addrsV6 []netip.Addr
	lanes := 0
	if tnet.hasV4 {
		lanes++
	}
	if tnet.hasV6 {
		lanes++
	}
	// Buffered to the number of senders so neither goroutine can block.
	lane := make(chan result, lanes)
	var lastErr error
	if tnet.hasV4 {
		go func() {
			p, server, err := tnet.tryOneName(ctx, fqdn, dnsmessage.TypeA)
			lane <- result{p, server, err}
		}()
	}
	if tnet.hasV6 {
		go func() {
			p, server, err := tnet.tryOneName(ctx, fqdn, dnsmessage.TypeAAAA)
			lane <- result{p, server, err}
		}()
	}

	for l := 0; l < lanes; l++ {
		result := <-lane
		if result.error != nil {
			if lastErr == nil {
				lastErr = result.error
			}
			continue
		}

	loop:
		for {
			h, err := result.p.AnswerHeader()
			if err != nil && err != dnsmessage.ErrSectionDone {
				lastErr = parseErr(result.server)
			}
			if err != nil {
				break
			}
			switch h.Type {
			case dnsmessage.TypeA:
				a, err := result.p.AResource()
				if err != nil {
					lastErr = parseErr(result.server)
					break loop
				}
				addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))

			case dnsmessage.TypeAAAA:
				aaaa, err := result.p.AAAAResource()
				if err != nil {
					lastErr = parseErr(result.server)
					break loop
				}
				addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))

			default:
				if err := result.p.SkipAnswer(); err != nil {
					lastErr = parseErr(result.server)
					break loop
				}
				continue
			}
		}
	}

	// We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
	var addrs []netip.Addr
	if tnet.hasV6 {
		addrs = append(addrsV6, addrsV4...)
	} else {
		addrs = append(addrsV4, addrsV6...)
	}

	if len(addrs) == 0 && lastErr != nil {
		return nil, lastErr
	}
	saddrs := make([]string, 0, len(addrs))
	for _, ip := range addrs {
		saddrs = append(saddrs, ip.String())
	}
	return saddrs, nil
}
// partialDeadline returns the deadline to use for one dial attempt when
// addrsRemaining addresses still have to be tried before deadline.
//
// A zero deadline is returned unchanged. If deadline is not after now,
// errTimeout is returned. Otherwise the remaining time is divided evenly
// among the remaining addresses, but each attempt gets at least 2 seconds, or
// all the remaining time when less than 2 seconds is left.
func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
	const saneMinimum = 2 * time.Second

	if deadline.IsZero() {
		return deadline, nil
	}
	remaining := deadline.Sub(now)
	if remaining <= 0 {
		return time.Time{}, errTimeout
	}

	share := remaining / time.Duration(addrsRemaining)
	switch {
	case share >= saneMinimum:
		// The even share is already long enough.
	case remaining < saneMinimum:
		share = remaining
	default:
		share = saneMinimum
	}
	return now.Add(share), nil
}
// protoSplitter splits a network name accepted by DialContext into its
// protocol ("tcp", "udp" or "ping") and an optional address-family suffix
// ("4" or "6").
var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
// DialContext resolves address and connects to it over the userspace stack.
//
// network must match protoSplitter: "tcp", "udp" or "ping", optionally
// followed by "4" or "6" to restrict the address family. For "ping" the whole
// of address is the host; otherwise address must be "host:port" with a
// numeric port in the range 0-65535.
//
// The resolved addresses are tried in order and the first successful
// connection is returned. If every attempt fails, the first error seen is
// returned. It panics if ctx is nil.
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if ctx == nil {
panic("nil context")
}
// Work out which address families the network name allows.
var acceptV4, acceptV6 bool
matches := protoSplitter.FindStringSubmatch(network)
if matches == nil {
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
} else if len(matches[2]) == 0 {
acceptV4 = true
acceptV6 = true
} else {
acceptV4 = matches[2][0] == '4'
acceptV6 = !acceptV4
}
// Split host and port; "ping" has no port, so port stays 0.
var host string
var port int
if matches[1] == "ping" {
host = address
} else {
var sport string
var err error
host, sport, err = net.SplitHostPort(address)
if err != nil {
return nil, &net.OpError{Op: "dial", Err: err}
}
port, err = strconv.Atoi(sport)
if err != nil || port < 0 || port > 65535 {
return nil, &net.OpError{Op: "dial", Err: errNumericPort}
}
}
allAddr, err := tnet.LookupContextHost(ctx, host)
if err != nil {
return nil, &net.OpError{Op: "dial", Err: err}
}
// Keep only the addresses of an accepted family.
var addrs []netip.AddrPort
for _, addr := range allAddr {
ip, err := netip.ParseAddr(addr)
if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
}
}
if len(addrs) == 0 && len(allAddr) != 0 {
return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
}
var firstErr error
for i, addr := range addrs {
// Give up as soon as the caller's context is done.
select {
case <-ctx.Done():
err := ctx.Err()
if err == context.Canceled {
err = errCanceled
} else if err == context.DeadlineExceeded {
err = errTimeout
}
return nil, &net.OpError{Op: "dial", Err: err}
default:
}
// When ctx has a deadline, give this attempt only its share of the
// remaining time. Only the TCP dial below consumes dialCtx.
dialCtx := ctx
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
if err != nil {
if firstErr == nil {
firstErr = &net.OpError{Op: "dial", Err: err}
}
// Out of time: leave the address loop.
break
}
if partialDeadline.Before(deadline) {
var cancel context.CancelFunc
dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
defer cancel()
}
}
var c net.Conn
switch matches[1] {
case "tcp":
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
case "udp":
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
case "ping":
c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
}
if err == nil {
return c, nil
}
if firstErr == nil {
firstErr = err
}
}
// Reached with a nil firstErr only when there was no address to try.
if firstErr == nil {
firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
}
return nil, firstErr
}
// Dial is DialContext with a background context.
func (tnet *Net) Dial(network, address string) (net.Conn, error) {
	ctx := context.Background()
	return tnet.DialContext(ctx, network, address)
}