mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
namedpipe: rename from winpipe to keep in sync with CL299009
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
eb6302c7eb
commit
c07dd60cdb
@ -1,12 +1,12 @@
|
|||||||
|
// Copyright 2021 The Go Authors. All rights reserved.
|
||||||
|
// Copyright 2015 Microsoft
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
//go:build windows
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
package namedpipe
|
||||||
*
|
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package winpipe
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
@ -1,13 +1,13 @@
|
|||||||
|
// Copyright 2021 The Go Authors. All rights reserved.
|
||||||
|
// Copyright 2015 Microsoft
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
//go:build windows
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
|
||||||
*
|
package namedpipe
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Package winpipe implements a net.Conn and net.Listener around Windows named pipes.
|
|
||||||
package winpipe
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -15,6 +15,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
@ -28,7 +29,7 @@ type pipe struct {
|
|||||||
|
|
||||||
type messageBytePipe struct {
|
type messageBytePipe struct {
|
||||||
pipe
|
pipe
|
||||||
writeClosed bool
|
writeClosed int32
|
||||||
readEOF bool
|
readEOF bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,25 +51,26 @@ func (f *pipe) SetDeadline(t time.Time) error {
|
|||||||
|
|
||||||
// CloseWrite closes the write side of a message pipe in byte mode.
|
// CloseWrite closes the write side of a message pipe in byte mode.
|
||||||
func (f *messageBytePipe) CloseWrite() error {
|
func (f *messageBytePipe) CloseWrite() error {
|
||||||
if f.writeClosed {
|
if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) {
|
||||||
return io.ErrClosedPipe
|
return io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
err := f.file.Flush()
|
err := f.file.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
atomic.StoreInt32(&f.writeClosed, 0)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = f.file.Write(nil)
|
_, err = f.file.Write(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
atomic.StoreInt32(&f.writeClosed, 0)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
f.writeClosed = true
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
||||||
// they are used to implement CloseWrite.
|
// they are used to implement CloseWrite.
|
||||||
func (f *messageBytePipe) Write(b []byte) (int, error) {
|
func (f *messageBytePipe) Write(b []byte) (int, error) {
|
||||||
if f.writeClosed {
|
if atomic.LoadInt32(&f.writeClosed) != 0 {
|
||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
@ -142,30 +144,24 @@ type DialConfig struct {
|
|||||||
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
|
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial connects to the specified named pipe by path, timing out if the connection
|
// DialTimeout connects to the specified named pipe by path, timing out if the
|
||||||
// takes longer than the specified duration. If timeout is nil, then we use
|
// connection takes longer than the specified duration. If timeout is zero, then
|
||||||
// a default timeout of 2 seconds.
|
// we use a default timeout of 2 seconds.
|
||||||
func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) {
|
func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
|
||||||
var absTimeout time.Time
|
if timeout == 0 {
|
||||||
if timeout != nil {
|
timeout = time.Second * 2
|
||||||
absTimeout = time.Now().Add(*timeout)
|
|
||||||
} else {
|
|
||||||
absTimeout = time.Now().Add(2 * time.Second)
|
|
||||||
}
|
}
|
||||||
|
absTimeout := time.Now().Add(timeout)
|
||||||
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
||||||
conn, err := DialContext(ctx, path, config)
|
conn, err := config.DialContext(ctx, path)
|
||||||
if err == context.DeadlineExceeded {
|
if err == context.DeadlineExceeded {
|
||||||
return nil, os.ErrDeadlineExceeded
|
return nil, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
return conn, err
|
return conn, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContext attempts to connect to the specified named pipe by path
|
// DialContext attempts to connect to the specified named pipe by path.
|
||||||
// cancellation or timeout.
|
func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
|
||||||
func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) {
|
|
||||||
if config == nil {
|
|
||||||
config = &DialConfig{}
|
|
||||||
}
|
|
||||||
var err error
|
var err error
|
||||||
var h windows.Handle
|
var h windows.Handle
|
||||||
h, err = tryDialPipe(ctx, &path)
|
h, err = tryDialPipe(ctx, &path)
|
||||||
@ -213,6 +209,18 @@ func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn
|
|||||||
return &pipe{file: f, path: path}, nil
|
return &pipe{file: f, path: path}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultDialer DialConfig
|
||||||
|
|
||||||
|
// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
|
||||||
|
func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
return defaultDialer.DialTimeout(path, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext calls DialConfig.DialContext using an empty configuration.
|
||||||
|
func DialContext(ctx context.Context, path string) (net.Conn, error) {
|
||||||
|
return defaultDialer.DialContext(ctx, path)
|
||||||
|
}
|
||||||
|
|
||||||
type acceptResponse struct {
|
type acceptResponse struct {
|
||||||
f *file
|
f *file
|
||||||
err error
|
err error
|
||||||
@ -222,12 +230,12 @@ type pipeListener struct {
|
|||||||
firstHandle windows.Handle
|
firstHandle windows.Handle
|
||||||
path string
|
path string
|
||||||
config ListenConfig
|
config ListenConfig
|
||||||
acceptCh chan (chan acceptResponse)
|
acceptCh chan chan acceptResponse
|
||||||
closeCh chan int
|
closeCh chan int
|
||||||
doneCh chan int
|
doneCh chan int
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) {
|
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
|
||||||
path16, err := windows.UTF16PtrFromString(path)
|
path16, err := windows.UTF16PtrFromString(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
@ -247,7 +255,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
|
|||||||
oa.ObjectName = &ntPath
|
oa.ObjectName = &ntPath
|
||||||
|
|
||||||
// The security descriptor is only needed for the first pipe.
|
// The security descriptor is only needed for the first pipe.
|
||||||
if first {
|
if isFirstPipe {
|
||||||
if sd != nil {
|
if sd != nil {
|
||||||
oa.SecurityDescriptor = sd
|
oa.SecurityDescriptor = sd
|
||||||
} else {
|
} else {
|
||||||
@ -257,7 +265,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
|
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
|
||||||
sd, err := windows.NewSecurityDescriptor()
|
sd, err = windows.NewSecurityDescriptor()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -275,11 +283,11 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
|
|||||||
|
|
||||||
disposition := uint32(windows.FILE_OPEN)
|
disposition := uint32(windows.FILE_OPEN)
|
||||||
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
|
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
|
||||||
if first {
|
if isFirstPipe {
|
||||||
disposition = windows.FILE_CREATE
|
disposition = windows.FILE_CREATE
|
||||||
// By not asking for read or write access, the named pipe file system
|
// By not asking for read or write access, the named pipe file system
|
||||||
// will put this pipe into an initially disconnected state, blocking
|
// will put this pipe into an initially disconnected state, blocking
|
||||||
// client connections until the next call with first == false.
|
// client connections until the next call with isFirstPipe == false.
|
||||||
access = windows.SYNCHRONIZE
|
access = windows.SYNCHRONIZE
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -395,10 +403,7 @@ type ListenConfig struct {
|
|||||||
|
|
||||||
// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
|
// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
|
||||||
// The pipe must not already exist.
|
// The pipe must not already exist.
|
||||||
func Listen(path string, c *ListenConfig) (net.Listener, error) {
|
func (c *ListenConfig) Listen(path string) (net.Listener, error) {
|
||||||
if c == nil {
|
|
||||||
c = &ListenConfig{}
|
|
||||||
}
|
|
||||||
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
|
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -407,12 +412,12 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
|
|||||||
firstHandle: h,
|
firstHandle: h,
|
||||||
path: path,
|
path: path,
|
||||||
config: *c,
|
config: *c,
|
||||||
acceptCh: make(chan (chan acceptResponse)),
|
acceptCh: make(chan chan acceptResponse),
|
||||||
closeCh: make(chan int),
|
closeCh: make(chan int),
|
||||||
doneCh: make(chan int),
|
doneCh: make(chan int),
|
||||||
}
|
}
|
||||||
// The first connection is swallowed on Windows 7 & 8, so synthesize it.
|
// The first connection is swallowed on Windows 7 & 8, so synthesize it.
|
||||||
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
|
if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
|
||||||
path16, err := windows.UTF16PtrFromString(path)
|
path16, err := windows.UTF16PtrFromString(path)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
|
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
|
||||||
@ -425,6 +430,13 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
|
|||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultListener ListenConfig
|
||||||
|
|
||||||
|
// Listen calls ListenConfig.Listen using an empty configuration.
|
||||||
|
func Listen(path string) (net.Listener, error) {
|
||||||
|
return defaultListener.Listen(path)
|
||||||
|
}
|
||||||
|
|
||||||
func connectPipe(p *file) error {
|
func connectPipe(p *file) error {
|
||||||
c, err := p.prepareIo()
|
c, err := p.prepareIo()
|
||||||
if err != nil {
|
if err != nil {
|
@ -1,12 +1,12 @@
|
|||||||
|
// Copyright 2021 The Go Authors. All rights reserved.
|
||||||
|
// Copyright 2015 Microsoft
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
//go:build windows
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
package namedpipe_test
|
||||||
*
|
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package winpipe_test
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
@ -22,7 +22,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"golang.zx2c4.com/wireguard/ipc/winpipe"
|
"golang.zx2c4.com/wireguard/ipc/namedpipe"
|
||||||
)
|
)
|
||||||
|
|
||||||
func randomPipePath() string {
|
func randomPipePath() string {
|
||||||
@ -30,7 +30,7 @@ func randomPipePath() string {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return `\\.\PIPE\go-winpipe-test-` + guid.String()
|
return `\\.\PIPE\go-namedpipe-test-` + guid.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPingPong(t *testing.T) {
|
func TestPingPong(t *testing.T) {
|
||||||
@ -39,7 +39,7 @@ func TestPingPong(t *testing.T) {
|
|||||||
pong = 24
|
pong = 24
|
||||||
)
|
)
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
listener, err := winpipe.Listen(pipePath, nil)
|
listener, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to listen on pipe: %v", err)
|
t.Fatalf("unable to listen on pipe: %v", err)
|
||||||
}
|
}
|
||||||
@ -64,11 +64,12 @@ func TestPingPong(t *testing.T) {
|
|||||||
t.Fatalf("unable to write pong to pipe: %v", err)
|
t.Fatalf("unable to write pong to pipe: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
client, err := winpipe.Dial(pipePath, nil, nil)
|
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to dial pipe: %v", err)
|
t.Fatalf("unable to dial pipe: %v", err)
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
client.SetDeadline(time.Now().Add(time.Second * 5))
|
||||||
var data [1]byte
|
var data [1]byte
|
||||||
data[0] = ping
|
data[0] = ping
|
||||||
_, err = client.Write(data[:])
|
_, err = client.Write(data[:])
|
||||||
@ -85,7 +86,7 @@ func TestPingPong(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDialUnknownFailsImmediately(t *testing.T) {
|
func TestDialUnknownFailsImmediately(t *testing.T) {
|
||||||
_, err := winpipe.Dial(randomPipePath(), nil, nil)
|
_, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
|
||||||
if !errors.Is(err, syscall.ENOENT) {
|
if !errors.Is(err, syscall.ENOENT) {
|
||||||
t.Fatalf("expected ENOENT got %v", err)
|
t.Fatalf("expected ENOENT got %v", err)
|
||||||
}
|
}
|
||||||
@ -93,13 +94,15 @@ func TestDialUnknownFailsImmediately(t *testing.T) {
|
|||||||
|
|
||||||
func TestDialListenerTimesOut(t *testing.T) {
|
func TestDialListenerTimesOut(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
d := 10 * time.Millisecond
|
pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
|
||||||
_, err = winpipe.Dial(pipePath, &d, nil)
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
if err != os.ErrDeadlineExceeded {
|
if err != os.ErrDeadlineExceeded {
|
||||||
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
}
|
}
|
||||||
@ -107,14 +110,17 @@ func TestDialListenerTimesOut(t *testing.T) {
|
|||||||
|
|
||||||
func TestDialContextListenerTimesOut(t *testing.T) {
|
func TestDialContextListenerTimesOut(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
d := 10 * time.Millisecond
|
d := 10 * time.Millisecond
|
||||||
ctx, _ := context.WithTimeout(context.Background(), d)
|
ctx, _ := context.WithTimeout(context.Background(), d)
|
||||||
_, err = winpipe.DialContext(ctx, pipePath, nil)
|
pipe, err := namedpipe.DialContext(ctx, pipePath)
|
||||||
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
if err != context.DeadlineExceeded {
|
if err != context.DeadlineExceeded {
|
||||||
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
|
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
|
||||||
}
|
}
|
||||||
@ -123,14 +129,14 @@ func TestDialContextListenerTimesOut(t *testing.T) {
|
|||||||
func TestDialListenerGetsCancelled(t *testing.T) {
|
func TestDialListenerGetsCancelled(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
ch := make(chan error)
|
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
ch := make(chan error)
|
||||||
go func(ctx context.Context, ch chan error) {
|
go func(ctx context.Context, ch chan error) {
|
||||||
_, err := winpipe.DialContext(ctx, pipePath, nil)
|
_, err := namedpipe.DialContext(ctx, pipePath)
|
||||||
ch <- err
|
ch <- err
|
||||||
}(ctx, ch)
|
}(ctx, ch)
|
||||||
time.Sleep(time.Millisecond * 30)
|
time.Sleep(time.Millisecond * 30)
|
||||||
@ -147,23 +153,28 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
|
|||||||
}
|
}
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
sd, _ := windows.SecurityDescriptorFromString("D:")
|
sd, _ := windows.SecurityDescriptorFromString("D:")
|
||||||
c := winpipe.ListenConfig{
|
l, err := (&namedpipe.ListenConfig{
|
||||||
SecurityDescriptor: sd,
|
SecurityDescriptor: sd,
|
||||||
}
|
}).Listen(pipePath)
|
||||||
l, err := winpipe.Listen(pipePath, &c)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
_, err = winpipe.Dial(pipePath, nil, nil)
|
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
|
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
|
||||||
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
|
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) {
|
func getConnection(cfg *namedpipe.ListenConfig) (client net.Conn, server net.Conn, err error) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, cfg)
|
if cfg == nil {
|
||||||
|
cfg = &namedpipe.ListenConfig{}
|
||||||
|
}
|
||||||
|
l, err := cfg.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -179,7 +190,7 @@ func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn,
|
|||||||
ch <- response{c, err}
|
ch <- response{c, err}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c, err := winpipe.Dial(pipePath, nil, nil)
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -236,7 +247,7 @@ func server(l net.Listener, ch chan int) {
|
|||||||
|
|
||||||
func TestFullListenDialReadWrite(t *testing.T) {
|
func TestFullListenDialReadWrite(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -245,7 +256,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
|
|||||||
ch := make(chan int)
|
ch := make(chan int)
|
||||||
go server(l, ch)
|
go server(l, ch)
|
||||||
|
|
||||||
c, err := winpipe.Dial(pipePath, nil, nil)
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -275,7 +286,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
|
|||||||
|
|
||||||
func TestCloseAbortsListen(t *testing.T) {
|
func TestCloseAbortsListen(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -328,7 +339,7 @@ func TestCloseServerEOFClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCloseWriteEOF(t *testing.T) {
|
func TestCloseWriteEOF(t *testing.T) {
|
||||||
cfg := &winpipe.ListenConfig{
|
cfg := &namedpipe.ListenConfig{
|
||||||
MessageMode: true,
|
MessageMode: true,
|
||||||
}
|
}
|
||||||
c, s, err := getConnection(cfg)
|
c, s, err := getConnection(cfg)
|
||||||
@ -356,7 +367,7 @@ func TestCloseWriteEOF(t *testing.T) {
|
|||||||
|
|
||||||
func TestAcceptAfterCloseFails(t *testing.T) {
|
func TestAcceptAfterCloseFails(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -369,12 +380,15 @@ func TestAcceptAfterCloseFails(t *testing.T) {
|
|||||||
|
|
||||||
func TestDialTimesOutByDefault(t *testing.T) {
|
func TestDialTimesOutByDefault(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
_, err = winpipe.Dial(pipePath, nil, nil)
|
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
|
||||||
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
if err != os.ErrDeadlineExceeded {
|
if err != os.ErrDeadlineExceeded {
|
||||||
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
}
|
}
|
||||||
@ -382,7 +396,7 @@ func TestDialTimesOutByDefault(t *testing.T) {
|
|||||||
|
|
||||||
func TestTimeoutPendingRead(t *testing.T) {
|
func TestTimeoutPendingRead(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -400,7 +414,7 @@ func TestTimeoutPendingRead(t *testing.T) {
|
|||||||
close(serverDone)
|
close(serverDone)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := winpipe.Dial(pipePath, nil, nil)
|
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -430,7 +444,7 @@ func TestTimeoutPendingRead(t *testing.T) {
|
|||||||
|
|
||||||
func TestTimeoutPendingWrite(t *testing.T) {
|
func TestTimeoutPendingWrite(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -448,7 +462,7 @@ func TestTimeoutPendingWrite(t *testing.T) {
|
|||||||
close(serverDone)
|
close(serverDone)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := winpipe.Dial(pipePath, nil, nil)
|
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -480,13 +494,12 @@ type CloseWriter interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEchoWithMessaging(t *testing.T) {
|
func TestEchoWithMessaging(t *testing.T) {
|
||||||
c := winpipe.ListenConfig{
|
pipePath := randomPipePath()
|
||||||
|
l, err := (&namedpipe.ListenConfig{
|
||||||
MessageMode: true, // Use message mode so that CloseWrite() is supported
|
MessageMode: true, // Use message mode so that CloseWrite() is supported
|
||||||
InputBufferSize: 65536, // Use 64KB buffers to improve performance
|
InputBufferSize: 65536, // Use 64KB buffers to improve performance
|
||||||
OutputBufferSize: 65536,
|
OutputBufferSize: 65536,
|
||||||
}
|
}).Listen(pipePath)
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := winpipe.Listen(pipePath, &c)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -496,19 +509,21 @@ func TestEchoWithMessaging(t *testing.T) {
|
|||||||
clientDone := make(chan bool)
|
clientDone := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
// server echo
|
// server echo
|
||||||
conn, e := l.Accept()
|
conn, err := l.Accept()
|
||||||
if e != nil {
|
if err != nil {
|
||||||
t.Fatal(e)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
|
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
|
||||||
io.Copy(conn, conn)
|
_, err = io.Copy(conn, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
conn.(CloseWriter).CloseWrite()
|
conn.(CloseWriter).CloseWrite()
|
||||||
close(listenerDone)
|
close(listenerDone)
|
||||||
}()
|
}()
|
||||||
timeout := 1 * time.Second
|
client, err := namedpipe.DialTimeout(pipePath, time.Second)
|
||||||
client, err := winpipe.Dial(pipePath, &timeout, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -521,7 +536,7 @@ func TestEchoWithMessaging(t *testing.T) {
|
|||||||
if e != nil {
|
if e != nil {
|
||||||
t.Fatal(e)
|
t.Fatal(e)
|
||||||
}
|
}
|
||||||
if n != 2 {
|
if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
|
||||||
t.Fatalf("expected 2 bytes, got %v", n)
|
t.Fatalf("expected 2 bytes, got %v", n)
|
||||||
}
|
}
|
||||||
close(clientDone)
|
close(clientDone)
|
||||||
@ -545,7 +560,7 @@ func TestEchoWithMessaging(t *testing.T) {
|
|||||||
|
|
||||||
func TestConnectRace(t *testing.T) {
|
func TestConnectRace(t *testing.T) {
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, nil)
|
l, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -565,7 +580,7 @@ func TestConnectRace(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
c, err := winpipe.Dial(pipePath, nil, nil)
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -580,7 +595,7 @@ func TestMessageReadMode(t *testing.T) {
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
defer wg.Wait()
|
defer wg.Wait()
|
||||||
pipePath := randomPipePath()
|
pipePath := randomPipePath()
|
||||||
l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true})
|
l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -602,7 +617,7 @@ func TestMessageReadMode(t *testing.T) {
|
|||||||
s.Close()
|
s.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c, err := winpipe.Dial(pipePath, nil, nil)
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -643,13 +658,13 @@ func TestListenConnectRace(t *testing.T) {
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
c, err := winpipe.Dial(pipePath, nil, nil)
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
s, err := winpipe.Listen(pipePath, nil)
|
s, err := namedpipe.Listen(pipePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(i, err)
|
t.Error(i, err)
|
||||||
} else {
|
} else {
|
@ -9,8 +9,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
"golang.zx2c4.com/wireguard/ipc/namedpipe"
|
||||||
"golang.zx2c4.com/wireguard/ipc/winpipe"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: replace these with actual standard windows error numbers from the win package
|
// TODO: replace these with actual standard windows error numbers from the win package
|
||||||
@ -61,10 +60,9 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UAPIListen(name string) (net.Listener, error) {
|
func UAPIListen(name string) (net.Listener, error) {
|
||||||
config := winpipe.ListenConfig{
|
listener, err := (&namedpipe.ListenConfig{
|
||||||
SecurityDescriptor: UAPISecurityDescriptor,
|
SecurityDescriptor: UAPISecurityDescriptor,
|
||||||
}
|
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
|
||||||
listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1,128 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package wintun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
|
|
||||||
return &lazyDLL{Name: name, onLoad: onLoad}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *lazyDLL) NewProc(name string) *lazyProc {
|
|
||||||
return &lazyProc{dll: d, Name: name}
|
|
||||||
}
|
|
||||||
|
|
||||||
type lazyProc struct {
|
|
||||||
Name string
|
|
||||||
mu sync.Mutex
|
|
||||||
dll *lazyDLL
|
|
||||||
addr uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *lazyProc) Find() error {
|
|
||||||
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
if p.addr != 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := p.dll.Load()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
|
|
||||||
}
|
|
||||||
addr, err := p.nameToAddr()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Error getting %v address: %w", p.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *lazyProc) Addr() uintptr {
|
|
||||||
err := p.Find()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return p.addr
|
|
||||||
}
|
|
||||||
|
|
||||||
type lazyDLL struct {
|
|
||||||
Name string
|
|
||||||
mu sync.Mutex
|
|
||||||
module windows.Handle
|
|
||||||
onLoad func(d *lazyDLL)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *lazyDLL) Load() error {
|
|
||||||
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
d.mu.Lock()
|
|
||||||
defer d.mu.Unlock()
|
|
||||||
if d.module != 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
|
|
||||||
LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
|
|
||||||
)
|
|
||||||
module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Unable to load library: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
|
|
||||||
if d.onLoad != nil {
|
|
||||||
d.onLoad(d)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *lazyProc) nameToAddr() (uintptr, error) {
|
|
||||||
return windows.GetProcAddress(p.dll.module, p.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Version returns the version of the Wintun DLL.
|
|
||||||
func Version() string {
|
|
||||||
if modwintun.Load() != nil {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
resInfo, err := windows.FindResource(modwintun.module, windows.ResourceID(1), windows.RT_VERSION)
|
|
||||||
if err != nil {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
data, err := windows.LoadResourceData(modwintun.module, resInfo)
|
|
||||||
if err != nil {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
|
|
||||||
var fixedInfo *windows.VS_FIXEDFILEINFO
|
|
||||||
fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo))
|
|
||||||
err = windows.VerQueryValue(unsafe.Pointer(&data[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen)
|
|
||||||
if err != nil {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
version := fmt.Sprintf("%d.%d", (fixedInfo.FileVersionMS>>16)&0xff, (fixedInfo.FileVersionMS>>0)&0xff)
|
|
||||||
if nextNibble := (fixedInfo.FileVersionLS >> 16) & 0xff; nextNibble != 0 {
|
|
||||||
version += fmt.Sprintf(".%d", nextNibble)
|
|
||||||
}
|
|
||||||
if nextNibble := (fixedInfo.FileVersionLS >> 0) & 0xff; nextNibble != 0 {
|
|
||||||
version += fmt.Sprintf(".%d", nextNibble)
|
|
||||||
}
|
|
||||||
return version
|
|
||||||
}
|
|
@ -1,90 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package wintun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Session struct {
|
|
||||||
handle uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
PacketSizeMax = 0xffff // Maximum packet size
|
|
||||||
RingCapacityMin = 0x20000 // Minimum ring capacity (128 kiB)
|
|
||||||
RingCapacityMax = 0x4000000 // Maximum ring capacity (64 MiB)
|
|
||||||
)
|
|
||||||
|
|
||||||
// Packet with data
|
|
||||||
type Packet struct {
|
|
||||||
Next *Packet // Pointer to next packet in queue
|
|
||||||
Size uint32 // Size of packet (max WINTUN_MAX_IP_PACKET_SIZE)
|
|
||||||
Data *[PacketSizeMax]byte // Pointer to layer 3 IPv4 or IPv6 packet
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
procWintunAllocateSendPacket = modwintun.NewProc("WintunAllocateSendPacket")
|
|
||||||
procWintunEndSession = modwintun.NewProc("WintunEndSession")
|
|
||||||
procWintunGetReadWaitEvent = modwintun.NewProc("WintunGetReadWaitEvent")
|
|
||||||
procWintunReceivePacket = modwintun.NewProc("WintunReceivePacket")
|
|
||||||
procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket")
|
|
||||||
procWintunSendPacket = modwintun.NewProc("WintunSendPacket")
|
|
||||||
procWintunStartSession = modwintun.NewProc("WintunStartSession")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (wintun *Adapter) StartSession(capacity uint32) (session Session, err error) {
|
|
||||||
r0, _, e1 := syscall.Syscall(procWintunStartSession.Addr(), 2, uintptr(wintun.handle), uintptr(capacity), 0)
|
|
||||||
if r0 == 0 {
|
|
||||||
err = e1
|
|
||||||
} else {
|
|
||||||
session = Session{r0}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session Session) End() {
|
|
||||||
syscall.Syscall(procWintunEndSession.Addr(), 1, session.handle, 0, 0)
|
|
||||||
session.handle = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session Session) ReadWaitEvent() (handle windows.Handle) {
|
|
||||||
r0, _, _ := syscall.Syscall(procWintunGetReadWaitEvent.Addr(), 1, session.handle, 0, 0)
|
|
||||||
handle = windows.Handle(r0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session Session) ReceivePacket() (packet []byte, err error) {
|
|
||||||
var packetSize uint32
|
|
||||||
r0, _, e1 := syscall.Syscall(procWintunReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0)
|
|
||||||
if r0 == 0 {
|
|
||||||
err = e1
|
|
||||||
return
|
|
||||||
}
|
|
||||||
packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session Session) ReleaseReceivePacket(packet []byte) {
|
|
||||||
syscall.Syscall(procWintunReleaseReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) {
|
|
||||||
r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket.Addr(), 2, session.handle, uintptr(packetSize), 0)
|
|
||||||
if r0 == 0 {
|
|
||||||
err = e1
|
|
||||||
return
|
|
||||||
}
|
|
||||||
packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session Session) SendPacket(packet []byte) {
|
|
||||||
syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
|
|
||||||
}
|
|
@ -1,150 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package wintun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"runtime"
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
type loggerLevel int
|
|
||||||
|
|
||||||
const (
|
|
||||||
logInfo loggerLevel = iota
|
|
||||||
logWarn
|
|
||||||
logErr
|
|
||||||
)
|
|
||||||
|
|
||||||
const AdapterNameMax = 128
|
|
||||||
|
|
||||||
type Adapter struct {
|
|
||||||
handle uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
modwintun = newLazyDLL("wintun.dll", setupLogger)
|
|
||||||
procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter")
|
|
||||||
procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter")
|
|
||||||
procWintunCloseAdapter = modwintun.NewProc("WintunCloseAdapter")
|
|
||||||
procWintunDeleteDriver = modwintun.NewProc("WintunDeleteDriver")
|
|
||||||
procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID")
|
|
||||||
procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
|
|
||||||
)
|
|
||||||
|
|
||||||
type TimestampedWriter interface {
|
|
||||||
WriteWithTimestamp(p []byte, ts int64) (n int, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int {
|
|
||||||
if tw, ok := log.Default().Writer().(TimestampedWriter); ok {
|
|
||||||
tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100)
|
|
||||||
} else {
|
|
||||||
log.Println(windows.UTF16PtrToString(msg))
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupLogger(dll *lazyDLL) {
|
|
||||||
var callback uintptr
|
|
||||||
if runtime.GOARCH == "386" {
|
|
||||||
callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int {
|
|
||||||
return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
|
|
||||||
})
|
|
||||||
} else if runtime.GOARCH == "arm" {
|
|
||||||
callback = windows.NewCallback(func(level loggerLevel, _, timestampLow, timestampHigh uint32, msg *uint16) int {
|
|
||||||
return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
|
|
||||||
})
|
|
||||||
} else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" {
|
|
||||||
callback = windows.NewCallback(logMessage)
|
|
||||||
}
|
|
||||||
syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, callback, 0, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeAdapter(wintun *Adapter) {
|
|
||||||
syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter.
|
|
||||||
// tunnelType represents the type of adapter and should be "Wintun". requestedGUID is
|
|
||||||
// the GUID of the created network adapter, which then influences NLA generation
|
|
||||||
// deterministically. If it is set to nil, the GUID is chosen by the system at random,
|
|
||||||
// and hence a new NLA entry is created for each new adapter.
|
|
||||||
func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
|
|
||||||
var name16 *uint16
|
|
||||||
name16, err = windows.UTF16PtrFromString(name)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var tunnelType16 *uint16
|
|
||||||
tunnelType16, err = windows.UTF16PtrFromString(tunnelType)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID)))
|
|
||||||
if r0 == 0 {
|
|
||||||
err = e1
|
|
||||||
return
|
|
||||||
}
|
|
||||||
wintun = &Adapter{handle: r0}
|
|
||||||
runtime.SetFinalizer(wintun, closeAdapter)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenAdapter opens an existing Wintun adapter by name.
|
|
||||||
func OpenAdapter(name string) (wintun *Adapter, err error) {
|
|
||||||
var name16 *uint16
|
|
||||||
name16, err = windows.UTF16PtrFromString(name)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0)
|
|
||||||
if r0 == 0 {
|
|
||||||
err = e1
|
|
||||||
return
|
|
||||||
}
|
|
||||||
wintun = &Adapter{handle: r0}
|
|
||||||
runtime.SetFinalizer(wintun, closeAdapter)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes a Wintun adapter.
|
|
||||||
func (wintun *Adapter) Close() (err error) {
|
|
||||||
runtime.SetFinalizer(wintun, nil)
|
|
||||||
r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
err = e1
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uninstall removes the driver from the system if no drivers are currently in use.
|
|
||||||
func Uninstall() (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
err = e1
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunningVersion returns the version of the loaded driver.
|
|
||||||
func RunningVersion() (version uint32, err error) {
|
|
||||||
r0, _, e1 := syscall.Syscall(procWintunGetRunningDriverVersion.Addr(), 0, 0, 0, 0)
|
|
||||||
version = uint32(r0)
|
|
||||||
if version == 0 {
|
|
||||||
err = e1
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// LUID returns the LUID of the adapter.
|
|
||||||
func (wintun *Adapter) LUID() (luid uint64) {
|
|
||||||
syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0)
|
|
||||||
return
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user