mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
device: test up/down using virtual conn
This prevents port clashing bugs. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
6603c05a4a
commit
9a29ae267c
136
conn/bindtest/bindtest.go
Normal file
136
conn/bindtest/bindtest.go
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package bindtest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChannelBind struct {
|
||||||
|
rx4, tx4 *chan []byte
|
||||||
|
rx6, tx6 *chan []byte
|
||||||
|
closeSignal chan bool
|
||||||
|
source4, source6 ChannelEndpoint
|
||||||
|
target4, target6 ChannelEndpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChannelEndpoint uint16
|
||||||
|
|
||||||
|
var _ conn.Bind = (*ChannelBind)(nil)
|
||||||
|
var _ conn.Endpoint = (*ChannelEndpoint)(nil)
|
||||||
|
|
||||||
|
func NewChannelBinds() [2]conn.Bind {
|
||||||
|
arx4 := make(chan []byte, 8192)
|
||||||
|
brx4 := make(chan []byte, 8192)
|
||||||
|
arx6 := make(chan []byte, 8192)
|
||||||
|
brx6 := make(chan []byte, 8192)
|
||||||
|
var binds [2]ChannelBind
|
||||||
|
binds[0].rx4 = &arx4
|
||||||
|
binds[0].tx4 = &brx4
|
||||||
|
binds[1].rx4 = &brx4
|
||||||
|
binds[1].tx4 = &arx4
|
||||||
|
binds[0].rx6 = &arx6
|
||||||
|
binds[0].tx6 = &brx6
|
||||||
|
binds[1].rx6 = &brx6
|
||||||
|
binds[1].tx6 = &arx6
|
||||||
|
binds[0].target4 = ChannelEndpoint(1)
|
||||||
|
binds[1].target4 = ChannelEndpoint(2)
|
||||||
|
binds[0].target6 = ChannelEndpoint(3)
|
||||||
|
binds[1].target6 = ChannelEndpoint(4)
|
||||||
|
binds[0].source4 = binds[1].target4
|
||||||
|
binds[0].source6 = binds[1].target6
|
||||||
|
binds[1].source4 = binds[0].target4
|
||||||
|
binds[1].source6 = binds[0].target6
|
||||||
|
return [2]conn.Bind{&binds[0], &binds[1]}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) SrcToString() string { return "" }
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) SrcIP() net.IP { return nil }
|
||||||
|
|
||||||
|
func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) {
|
||||||
|
c.closeSignal = make(chan bool)
|
||||||
|
if rand.Uint32()&1 == 0 {
|
||||||
|
return uint16(c.source4), nil
|
||||||
|
} else {
|
||||||
|
return uint16(c.source6), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelBind) Close() error {
|
||||||
|
if c.closeSignal != nil {
|
||||||
|
select {
|
||||||
|
case <-c.closeSignal:
|
||||||
|
default:
|
||||||
|
close(c.closeSignal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
||||||
|
|
||||||
|
func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||||
|
select {
|
||||||
|
case <-c.closeSignal:
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
case rx := <-*c.rx6:
|
||||||
|
return copy(b, rx), c.target6, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||||
|
select {
|
||||||
|
case <-c.closeSignal:
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
case rx := <-*c.rx4:
|
||||||
|
return copy(b, rx), c.target4, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
|
||||||
|
select {
|
||||||
|
case <-c.closeSignal:
|
||||||
|
return net.ErrClosed
|
||||||
|
default:
|
||||||
|
bc := make([]byte, len(b))
|
||||||
|
copy(bc, b)
|
||||||
|
if ep.(ChannelEndpoint) == c.target4 {
|
||||||
|
*c.tx4 <- bc
|
||||||
|
} else if ep.(ChannelEndpoint) == c.target6 {
|
||||||
|
*c.tx6 <- bc
|
||||||
|
} else {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||||
|
_, port, err := net.SplitHostPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
i, err := strconv.ParseUint(port, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ChannelEndpoint(i), nil
|
||||||
|
}
|
@ -8,7 +8,6 @@ package device
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@ -17,11 +16,11 @@ import (
|
|||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
"golang.zx2c4.com/wireguard/conn/bindtest"
|
||||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -148,8 +147,14 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}
|
|||||||
}
|
}
|
||||||
|
|
||||||
// genTestPair creates a testPair.
|
// genTestPair creates a testPair.
|
||||||
func genTestPair(tb testing.TB) (pair testPair) {
|
func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
|
||||||
cfg, endpointCfg := genConfigs(tb)
|
cfg, endpointCfg := genConfigs(tb)
|
||||||
|
var binds [2]conn.Bind
|
||||||
|
if realSocket {
|
||||||
|
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
|
||||||
|
} else {
|
||||||
|
binds = bindtest.NewChannelBinds()
|
||||||
|
}
|
||||||
// Bring up a ChannelTun for each config.
|
// Bring up a ChannelTun for each config.
|
||||||
for i := range pair {
|
for i := range pair {
|
||||||
p := &pair[i]
|
p := &pair[i]
|
||||||
@ -159,7 +164,7 @@ func genTestPair(tb testing.TB) (pair testPair) {
|
|||||||
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
||||||
level = LogLevelError
|
level = LogLevelError
|
||||||
}
|
}
|
||||||
p.dev = NewDevice(p.tun.TUN(), conn.NewDefaultBind(), NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
||||||
if err := p.dev.IpcSet(cfg[i]); err != nil {
|
if err := p.dev.IpcSet(cfg[i]); err != nil {
|
||||||
tb.Errorf("failed to configure device %d: %v", i, err)
|
tb.Errorf("failed to configure device %d: %v", i, err)
|
||||||
p.dev.Close()
|
p.dev.Close()
|
||||||
@ -187,7 +192,7 @@ func genTestPair(tb testing.TB) (pair testPair) {
|
|||||||
|
|
||||||
func TestTwoDevicePing(t *testing.T) {
|
func TestTwoDevicePing(t *testing.T) {
|
||||||
goroutineLeakCheck(t)
|
goroutineLeakCheck(t)
|
||||||
pair := genTestPair(t)
|
pair := genTestPair(t, true)
|
||||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
pair.Send(t, Ping, nil)
|
pair.Send(t, Ping, nil)
|
||||||
})
|
})
|
||||||
@ -198,11 +203,11 @@ func TestTwoDevicePing(t *testing.T) {
|
|||||||
|
|
||||||
func TestUpDown(t *testing.T) {
|
func TestUpDown(t *testing.T) {
|
||||||
goroutineLeakCheck(t)
|
goroutineLeakCheck(t)
|
||||||
const itrials = 20
|
const itrials = 50
|
||||||
const otrials = 1
|
const otrials = 10
|
||||||
|
|
||||||
for n := 0; n < otrials; n++ {
|
for n := 0; n < otrials; n++ {
|
||||||
pair := genTestPair(t)
|
pair := genTestPair(t, false)
|
||||||
for i := range pair {
|
for i := range pair {
|
||||||
for k := range pair[i].dev.peers.keyMap {
|
for k := range pair[i].dev.peers.keyMap {
|
||||||
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
||||||
@ -214,17 +219,8 @@ func TestUpDown(t *testing.T) {
|
|||||||
go func(d *Device) {
|
go func(d *Device) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for i := 0; i < itrials; i++ {
|
for i := 0; i < itrials; i++ {
|
||||||
start := time.Now()
|
if err := d.Up(); err != nil {
|
||||||
for {
|
t.Errorf("failed up bring up device: %v", err)
|
||||||
if err := d.Up(); err != nil {
|
|
||||||
if errors.Is(err, syscall.EADDRINUSE) && time.Now().Sub(start) < time.Second*4 {
|
|
||||||
// Some other test process is racing with us, so try again.
|
|
||||||
time.Sleep(time.Millisecond * 10)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
t.Errorf("failed up bring up device: %v", err)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
|
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
|
||||||
if err := d.Down(); err != nil {
|
if err := d.Down(); err != nil {
|
||||||
@ -245,7 +241,7 @@ func TestUpDown(t *testing.T) {
|
|||||||
// TestConcurrencySafety does other things concurrently with tunnel use.
|
// TestConcurrencySafety does other things concurrently with tunnel use.
|
||||||
// It is intended to be used with the race detector to catch data races.
|
// It is intended to be used with the race detector to catch data races.
|
||||||
func TestConcurrencySafety(t *testing.T) {
|
func TestConcurrencySafety(t *testing.T) {
|
||||||
pair := genTestPair(t)
|
pair := genTestPair(t, true)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
||||||
const warmupIters = 10
|
const warmupIters = 10
|
||||||
@ -315,7 +311,7 @@ func TestConcurrencySafety(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLatency(b *testing.B) {
|
func BenchmarkLatency(b *testing.B) {
|
||||||
pair := genTestPair(b)
|
pair := genTestPair(b, true)
|
||||||
|
|
||||||
// Establish a connection.
|
// Establish a connection.
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
@ -329,7 +325,7 @@ func BenchmarkLatency(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkThroughput(b *testing.B) {
|
func BenchmarkThroughput(b *testing.B) {
|
||||||
pair := genTestPair(b)
|
pair := genTestPair(b, true)
|
||||||
|
|
||||||
// Establish a connection.
|
// Establish a connection.
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
@ -373,7 +369,7 @@ func BenchmarkThroughput(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkUAPIGet(b *testing.B) {
|
func BenchmarkUAPIGet(b *testing.B) {
|
||||||
pair := genTestPair(b)
|
pair := genTestPair(b, true)
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
pair.Send(b, Pong, nil)
|
pair.Send(b, Pong, nil)
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
@ -79,7 +79,6 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
|
|||||||
return pkt
|
return pkt
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(crawshaw): find a reusable home for this. package devicetest?
|
|
||||||
type ChannelTUN struct {
|
type ChannelTUN struct {
|
||||||
Inbound chan []byte // incoming packets, closed on TUN close
|
Inbound chan []byte // incoming packets, closed on TUN close
|
||||||
Outbound chan []byte // outbound packets, blocks forever on TUN close
|
Outbound chan []byte // outbound packets, blocks forever on TUN close
|
||||||
|
Loading…
Reference in New Issue
Block a user