1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2025-09-18 20:57:50 +02:00

tun/netstack: implement ICMP ping

Provide a PacketConn interface for netstack's ICMP endpoint; netstack
currently only provides EchoRequest/EchoResponse ICMP support, so this
code exposes only an interface for doing ping.

Currently is missing:
- Write deadlines
- Context support

Signed-off-by: Thomas Ptacek <thomas@sockpuppet.org>
[Jason: rework structure, match std go interfaces, add example code]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Thomas H. Ptacek 2022-01-31 16:55:36 -06:00 committed by Jason A. Donenfeld
parent e0b8f11489
commit a702597e22
2 changed files with 264 additions and 24 deletions

View File

@ -0,0 +1,57 @@
//go:build ignore
// +build ignore
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"log"
"time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
)
func main() {
tun, tnet, err := netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
log.Panic(err)
}
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
endpoint=163.172.161.0:12912
allowed_ip=0.0.0.0/0
`)
err = dev.Up()
if err != nil {
log.Panic(err)
}
socket, err := tnet.Dial("ping4", "zx2c4.com")
if err != nil {
log.Panic(err)
}
const payload = "gopher burrow"
socket.SetReadDeadline(time.Now().Add(time.Second * 10))
start := time.Now()
_, err = socket.Write([]byte(payload))
if err != nil {
log.Panic(err)
}
var reply [len(payload)]byte
n, err := socket.Read(reply[:])
if err != nil || string(reply[:n]) != payload {
log.Panic(err)
}
log.Printf("Ping latency: %v", time.Since(start))
}

View File

@ -14,6 +14,7 @@ import (
"io"
"net"
"os"
"regexp"
"strconv"
"strings"
"time"
@ -29,8 +30,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
type netTun struct {
@ -101,7 +104,7 @@ func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.Network
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
HandleLocal: true,
}
dev := &netTun{
@ -281,6 +284,178 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
return net.DialUDPAddrPort(la, ra)
}
type PingConn struct {
laddr PingAddr
raddr PingAddr
wq waiter.Queue
ep tcpip.Endpoint
deadline time.Time
}
type PingAddr struct{ addr netip.Addr }
func (ia PingAddr) String() string {
return ia.addr.String()
}
func (ia PingAddr) Network() string {
if ia.addr.Is4() {
return "ping4"
} else if ia.addr.Is6() {
return "ping6"
}
return "ping"
}
func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
v6 := laddr.Is6() || raddr.Is6()
bind := laddr.IsValid()
if !bind {
if v6 {
laddr = netip.IPv6Unspecified()
} else {
laddr = netip.IPv4Unspecified()
}
}
tn := icmp.ProtocolNumber4
pn := ipv4.ProtocolNumber
if v6 {
tn = icmp.ProtocolNumber6
pn = ipv6.ProtocolNumber
}
pc := &PingConn{laddr: PingAddr{laddr}}
ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
if tcpipErr != nil {
return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
}
pc.ep = ep
if bind {
fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
return nil, fmt.Errorf("ping bind: %s", tcpipErr)
}
}
if raddr.IsValid() {
pc.raddr = PingAddr{raddr}
fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
return nil, fmt.Errorf("ping connect: %s", tcpipErr)
}
}
return pc, nil
}
func (pc *PingConn) LocalAddr() net.Addr {
return pc.laddr
}
func (pc *PingConn) RemoteAddr() net.Addr {
return pc.raddr
}
func (pc *PingConn) Close() error {
pc.ep.Close()
return nil
}
func (pc *PingConn) SetWriteDeadline(t time.Time) error {
return errors.New("not implemented")
}
func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
ia, ok := addr.(PingAddr)
if !ok || !((ia.addr.Is4() && pc.laddr.addr.Is4()) || (ia.addr.Is6() && pc.laddr.addr.Is6())) {
return 0, fmt.Errorf("ping write: mismatched protocols")
}
var buf buffer.View
if ia.addr.Is4() {
buf = buffer.NewView(header.ICMPv4MinimumSize + len(p))
copy(buf[header.ICMPv4MinimumSize:], p)
icmp := header.ICMPv4(buf)
icmp.SetType(header.ICMPv4Echo)
} else if ia.addr.Is6() {
buf = buffer.NewView(header.ICMPv6MinimumSize + len(p))
copy(buf[header.ICMPv6MinimumSize:], p)
icmp := header.ICMPv6(buf)
icmp.SetType(header.ICMPv6EchoRequest)
}
rdr := buf.Reader()
rfa, _ := convertToFullAddr(netip.AddrPortFrom(ia.addr, 0))
// won't block, no deadlines
n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{
To: &rfa,
})
if tcpipErr != nil {
return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
}
return int(n64), nil
}
func (pc *PingConn) Write(p []byte) (n int, err error) {
return pc.WriteTo(p, pc.raddr)
}
func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
e, notifyCh := waiter.NewChannelEntry(nil)
pc.wq.EventRegister(&e, waiter.EventIn)
defer pc.wq.EventUnregister(&e)
deadline := pc.deadline
if deadline.IsZero() {
<-notifyCh
} else {
select {
case <-time.NewTimer(deadline.Sub(time.Now())).C:
return 0, nil, os.ErrDeadlineExceeded
case <-notifyCh:
}
}
min := header.ICMPv6MinimumSize
if pc.laddr.addr.Is4() {
min = header.ICMPv4MinimumSize
}
reply := make([]byte, min+len(p))
w := tcpip.SliceWriter(reply)
res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
NeedRemoteAddr: true,
})
if tcpipErr != nil {
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
}
addr = PingAddr{netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))}
copy(p, reply[min:res.Count])
return res.Count - min, addr, nil
}
func (pc *PingConn) Read(p []byte) (n int, err error) {
n, _, err = pc.ReadFrom(p)
return
}
func (pc *PingConn) SetDeadline(t time.Time) error {
// pc.SetWriteDeadline is unimplemented
return pc.SetReadDeadline(t)
}
func (pc *PingConn) SetReadDeadline(t time.Time) error {
pc.deadline = t
return nil
}
var (
errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral")
@ -755,34 +930,39 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er
return now.Add(timeout), nil
}
var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if ctx == nil {
panic("nil context")
}
var acceptV4, acceptV6, useUDP bool
if len(network) == 3 {
var acceptV4, acceptV6 bool
matches := protoSplitter.FindStringSubmatch(network)
if matches == nil {
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
} else if len(matches[2]) == 0 {
acceptV4 = true
acceptV6 = true
} else if len(network) == 4 {
acceptV4 = network[3] == '4'
acceptV6 = network[3] == '6'
} else {
acceptV4 = matches[2][0] == '4'
acceptV6 = !acceptV4
}
if !acceptV4 && !acceptV6 {
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
}
if network[:3] == "udp" {
useUDP = true
} else if network[:3] != "tcp" {
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
}
host, sport, err := net.SplitHostPort(address)
var host string
var port int
if matches[1] == "ping" {
host = address
} else {
var sport string
var err error
host, sport, err = net.SplitHostPort(address)
if err != nil {
return nil, &net.OpError{Op: "dial", Err: err}
}
port, err := strconv.Atoi(sport)
port, err = strconv.Atoi(sport)
if err != nil || port < 0 || port > 65535 {
return nil, &net.OpError{Op: "dial", Err: errNumericPort}
}
}
allAddr, err := tnet.LookupContextHost(ctx, host)
if err != nil {
return nil, &net.OpError{Op: "dial", Err: err}
@ -829,10 +1009,13 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
}
var c net.Conn
if useUDP {
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
} else {
switch matches[1] {
case "tcp":
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
case "udp":
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
case "ping":
c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
}
if err == nil {
return c, nil