1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

namedpipe: rename from winpipe to keep in sync with CL299009

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-10-30 02:39:56 +02:00
parent eb6302c7eb
commit c07dd60cdb
7 changed files with 132 additions and 475 deletions

View File

@ -1,12 +1,12 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows //go:build windows
// +build windows
/* SPDX-License-Identifier: MIT package namedpipe
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import ( import (
"io" "io"

View File

@ -1,13 +1,13 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows //go:build windows
// +build windows
/* SPDX-License-Identifier: MIT // Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
* package namedpipe
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
// Package winpipe implements a net.Conn and net.Listener around Windows named pipes.
package winpipe
import ( import (
"context" "context"
@ -15,6 +15,7 @@ import (
"net" "net"
"os" "os"
"runtime" "runtime"
"sync/atomic"
"time" "time"
"unsafe" "unsafe"
@ -28,7 +29,7 @@ type pipe struct {
type messageBytePipe struct { type messageBytePipe struct {
pipe pipe
writeClosed bool writeClosed int32
readEOF bool readEOF bool
} }
@ -50,25 +51,26 @@ func (f *pipe) SetDeadline(t time.Time) error {
// CloseWrite closes the write side of a message pipe in byte mode. // CloseWrite closes the write side of a message pipe in byte mode.
func (f *messageBytePipe) CloseWrite() error { func (f *messageBytePipe) CloseWrite() error {
if f.writeClosed { if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
err := f.file.Flush() err := f.file.Flush()
if err != nil { if err != nil {
atomic.StoreInt32(&f.writeClosed, 0)
return err return err
} }
_, err = f.file.Write(nil) _, err = f.file.Write(nil)
if err != nil { if err != nil {
atomic.StoreInt32(&f.writeClosed, 0)
return err return err
} }
f.writeClosed = true
return nil return nil
} }
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since // Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite. // they are used to implement CloseWrite.
func (f *messageBytePipe) Write(b []byte) (int, error) { func (f *messageBytePipe) Write(b []byte) (int, error) {
if f.writeClosed { if atomic.LoadInt32(&f.writeClosed) != 0 {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
} }
if len(b) == 0 { if len(b) == 0 {
@ -142,30 +144,24 @@ type DialConfig struct {
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
} }
// Dial connects to the specified named pipe by path, timing out if the connection // DialTimeout connects to the specified named pipe by path, timing out if the
// takes longer than the specified duration. If timeout is nil, then we use // connection takes longer than the specified duration. If timeout is zero, then
// a default timeout of 2 seconds. // we use a default timeout of 2 seconds.
func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) { func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
var absTimeout time.Time if timeout == 0 {
if timeout != nil { timeout = time.Second * 2
absTimeout = time.Now().Add(*timeout)
} else {
absTimeout = time.Now().Add(2 * time.Second)
} }
absTimeout := time.Now().Add(timeout)
ctx, _ := context.WithDeadline(context.Background(), absTimeout) ctx, _ := context.WithDeadline(context.Background(), absTimeout)
conn, err := DialContext(ctx, path, config) conn, err := config.DialContext(ctx, path)
if err == context.DeadlineExceeded { if err == context.DeadlineExceeded {
return nil, os.ErrDeadlineExceeded return nil, os.ErrDeadlineExceeded
} }
return conn, err return conn, err
} }
// DialContext attempts to connect to the specified named pipe by path // DialContext attempts to connect to the specified named pipe by path.
// cancellation or timeout. func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) {
if config == nil {
config = &DialConfig{}
}
var err error var err error
var h windows.Handle var h windows.Handle
h, err = tryDialPipe(ctx, &path) h, err = tryDialPipe(ctx, &path)
@ -213,6 +209,18 @@ func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn
return &pipe{file: f, path: path}, nil return &pipe{file: f, path: path}, nil
} }
var defaultDialer DialConfig
// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
return defaultDialer.DialTimeout(path, timeout)
}
// DialContext calls DialConfig.DialContext using an empty configuration.
func DialContext(ctx context.Context, path string) (net.Conn, error) {
return defaultDialer.DialContext(ctx, path)
}
type acceptResponse struct { type acceptResponse struct {
f *file f *file
err error err error
@ -222,12 +230,12 @@ type pipeListener struct {
firstHandle windows.Handle firstHandle windows.Handle
path string path string
config ListenConfig config ListenConfig
acceptCh chan (chan acceptResponse) acceptCh chan chan acceptResponse
closeCh chan int closeCh chan int
doneCh chan int doneCh chan int
} }
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) { func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
path16, err := windows.UTF16PtrFromString(path) path16, err := windows.UTF16PtrFromString(path)
if err != nil { if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
@ -247,7 +255,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
oa.ObjectName = &ntPath oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe. // The security descriptor is only needed for the first pipe.
if first { if isFirstPipe {
if sd != nil { if sd != nil {
oa.SecurityDescriptor = sd oa.SecurityDescriptor = sd
} else { } else {
@ -257,7 +265,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
return 0, err return 0, err
} }
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
sd, err := windows.NewSecurityDescriptor() sd, err = windows.NewSecurityDescriptor()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -275,11 +283,11 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste
disposition := uint32(windows.FILE_OPEN) disposition := uint32(windows.FILE_OPEN)
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
if first { if isFirstPipe {
disposition = windows.FILE_CREATE disposition = windows.FILE_CREATE
// By not asking for read or write access, the named pipe file system // By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking // will put this pipe into an initially disconnected state, blocking
// client connections until the next call with first == false. // client connections until the next call with isFirstPipe == false.
access = windows.SYNCHRONIZE access = windows.SYNCHRONIZE
} }
@ -395,10 +403,7 @@ type ListenConfig struct {
// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe. // Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
// The pipe must not already exist. // The pipe must not already exist.
func Listen(path string, c *ListenConfig) (net.Listener, error) { func (c *ListenConfig) Listen(path string) (net.Listener, error) {
if c == nil {
c = &ListenConfig{}
}
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
if err != nil { if err != nil {
return nil, err return nil, err
@ -407,12 +412,12 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
firstHandle: h, firstHandle: h,
path: path, path: path,
config: *c, config: *c,
acceptCh: make(chan (chan acceptResponse)), acceptCh: make(chan chan acceptResponse),
closeCh: make(chan int), closeCh: make(chan int),
doneCh: make(chan int), doneCh: make(chan int),
} }
// The first connection is swallowed on Windows 7 & 8, so synthesize it. // The first connection is swallowed on Windows 7 & 8, so synthesize it.
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
path16, err := windows.UTF16PtrFromString(path) path16, err := windows.UTF16PtrFromString(path)
if err == nil { if err == nil {
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
@ -425,6 +430,13 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) {
return l, nil return l, nil
} }
var defaultListener ListenConfig
// Listen calls ListenConfig.Listen using an empty configuration.
func Listen(path string) (net.Listener, error) {
return defaultListener.Listen(path)
}
func connectPipe(p *file) error { func connectPipe(p *file) error {
c, err := p.prepareIo() c, err := p.prepareIo()
if err != nil { if err != nil {

View File

@ -1,12 +1,12 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows //go:build windows
// +build windows
/* SPDX-License-Identifier: MIT package namedpipe_test
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package winpipe_test
import ( import (
"bufio" "bufio"
@ -22,7 +22,7 @@ import (
"time" "time"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/winpipe" "golang.zx2c4.com/wireguard/ipc/namedpipe"
) )
func randomPipePath() string { func randomPipePath() string {
@ -30,7 +30,7 @@ func randomPipePath() string {
if err != nil { if err != nil {
panic(err) panic(err)
} }
return `\\.\PIPE\go-winpipe-test-` + guid.String() return `\\.\PIPE\go-namedpipe-test-` + guid.String()
} }
func TestPingPong(t *testing.T) { func TestPingPong(t *testing.T) {
@ -39,7 +39,7 @@ func TestPingPong(t *testing.T) {
pong = 24 pong = 24
) )
pipePath := randomPipePath() pipePath := randomPipePath()
listener, err := winpipe.Listen(pipePath, nil) listener, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatalf("unable to listen on pipe: %v", err) t.Fatalf("unable to listen on pipe: %v", err)
} }
@ -64,11 +64,12 @@ func TestPingPong(t *testing.T) {
t.Fatalf("unable to write pong to pipe: %v", err) t.Fatalf("unable to write pong to pipe: %v", err)
} }
}() }()
client, err := winpipe.Dial(pipePath, nil, nil) client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatalf("unable to dial pipe: %v", err) t.Fatalf("unable to dial pipe: %v", err)
} }
defer client.Close() defer client.Close()
client.SetDeadline(time.Now().Add(time.Second * 5))
var data [1]byte var data [1]byte
data[0] = ping data[0] = ping
_, err = client.Write(data[:]) _, err = client.Write(data[:])
@ -85,7 +86,7 @@ func TestPingPong(t *testing.T) {
} }
func TestDialUnknownFailsImmediately(t *testing.T) { func TestDialUnknownFailsImmediately(t *testing.T) {
_, err := winpipe.Dial(randomPipePath(), nil, nil) _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
if !errors.Is(err, syscall.ENOENT) { if !errors.Is(err, syscall.ENOENT) {
t.Fatalf("expected ENOENT got %v", err) t.Fatalf("expected ENOENT got %v", err)
} }
@ -93,13 +94,15 @@ func TestDialUnknownFailsImmediately(t *testing.T) {
func TestDialListenerTimesOut(t *testing.T) { func TestDialListenerTimesOut(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
d := 10 * time.Millisecond pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
_, err = winpipe.Dial(pipePath, &d, nil) if err == nil {
pipe.Close()
}
if err != os.ErrDeadlineExceeded { if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
} }
@ -107,14 +110,17 @@ func TestDialListenerTimesOut(t *testing.T) {
func TestDialContextListenerTimesOut(t *testing.T) { func TestDialContextListenerTimesOut(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
d := 10 * time.Millisecond d := 10 * time.Millisecond
ctx, _ := context.WithTimeout(context.Background(), d) ctx, _ := context.WithTimeout(context.Background(), d)
_, err = winpipe.DialContext(ctx, pipePath, nil) pipe, err := namedpipe.DialContext(ctx, pipePath)
if err == nil {
pipe.Close()
}
if err != context.DeadlineExceeded { if err != context.DeadlineExceeded {
t.Fatalf("expected context.DeadlineExceeded, got %v", err) t.Fatalf("expected context.DeadlineExceeded, got %v", err)
} }
@ -123,14 +129,14 @@ func TestDialContextListenerTimesOut(t *testing.T) {
func TestDialListenerGetsCancelled(t *testing.T) { func TestDialListenerGetsCancelled(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ch := make(chan error)
defer l.Close() defer l.Close()
ch := make(chan error)
go func(ctx context.Context, ch chan error) { go func(ctx context.Context, ch chan error) {
_, err := winpipe.DialContext(ctx, pipePath, nil) _, err := namedpipe.DialContext(ctx, pipePath)
ch <- err ch <- err
}(ctx, ch) }(ctx, ch)
time.Sleep(time.Millisecond * 30) time.Sleep(time.Millisecond * 30)
@ -147,23 +153,28 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
} }
pipePath := randomPipePath() pipePath := randomPipePath()
sd, _ := windows.SecurityDescriptorFromString("D:") sd, _ := windows.SecurityDescriptorFromString("D:")
c := winpipe.ListenConfig{ l, err := (&namedpipe.ListenConfig{
SecurityDescriptor: sd, SecurityDescriptor: sd,
} }).Listen(pipePath)
l, err := winpipe.Listen(pipePath, &c)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
_, err = winpipe.Dial(pipePath, nil, nil) pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil {
pipe.Close()
}
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
} }
} }
func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) { func getConnection(cfg *namedpipe.ListenConfig) (client net.Conn, server net.Conn, err error) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, cfg) if cfg == nil {
cfg = &namedpipe.ListenConfig{}
}
l, err := cfg.Listen(pipePath)
if err != nil { if err != nil {
return return
} }
@ -179,7 +190,7 @@ func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn,
ch <- response{c, err} ch <- response{c, err}
}() }()
c, err := winpipe.Dial(pipePath, nil, nil) c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
return return
} }
@ -236,7 +247,7 @@ func server(l net.Listener, ch chan int) {
func TestFullListenDialReadWrite(t *testing.T) { func TestFullListenDialReadWrite(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -245,7 +256,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
ch := make(chan int) ch := make(chan int)
go server(l, ch) go server(l, ch)
c, err := winpipe.Dial(pipePath, nil, nil) c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -275,7 +286,7 @@ func TestFullListenDialReadWrite(t *testing.T) {
func TestCloseAbortsListen(t *testing.T) { func TestCloseAbortsListen(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -328,7 +339,7 @@ func TestCloseServerEOFClient(t *testing.T) {
} }
func TestCloseWriteEOF(t *testing.T) { func TestCloseWriteEOF(t *testing.T) {
cfg := &winpipe.ListenConfig{ cfg := &namedpipe.ListenConfig{
MessageMode: true, MessageMode: true,
} }
c, s, err := getConnection(cfg) c, s, err := getConnection(cfg)
@ -356,7 +367,7 @@ func TestCloseWriteEOF(t *testing.T) {
func TestAcceptAfterCloseFails(t *testing.T) { func TestAcceptAfterCloseFails(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -369,12 +380,15 @@ func TestAcceptAfterCloseFails(t *testing.T) {
func TestDialTimesOutByDefault(t *testing.T) { func TestDialTimesOutByDefault(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer l.Close() defer l.Close()
_, err = winpipe.Dial(pipePath, nil, nil) pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
if err == nil {
pipe.Close()
}
if err != os.ErrDeadlineExceeded { if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
} }
@ -382,7 +396,7 @@ func TestDialTimesOutByDefault(t *testing.T) {
func TestTimeoutPendingRead(t *testing.T) { func TestTimeoutPendingRead(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -400,7 +414,7 @@ func TestTimeoutPendingRead(t *testing.T) {
close(serverDone) close(serverDone)
}() }()
client, err := winpipe.Dial(pipePath, nil, nil) client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -430,7 +444,7 @@ func TestTimeoutPendingRead(t *testing.T) {
func TestTimeoutPendingWrite(t *testing.T) { func TestTimeoutPendingWrite(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -448,7 +462,7 @@ func TestTimeoutPendingWrite(t *testing.T) {
close(serverDone) close(serverDone)
}() }()
client, err := winpipe.Dial(pipePath, nil, nil) client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -480,13 +494,12 @@ type CloseWriter interface {
} }
func TestEchoWithMessaging(t *testing.T) { func TestEchoWithMessaging(t *testing.T) {
c := winpipe.ListenConfig{ pipePath := randomPipePath()
l, err := (&namedpipe.ListenConfig{
MessageMode: true, // Use message mode so that CloseWrite() is supported MessageMode: true, // Use message mode so that CloseWrite() is supported
InputBufferSize: 65536, // Use 64KB buffers to improve performance InputBufferSize: 65536, // Use 64KB buffers to improve performance
OutputBufferSize: 65536, OutputBufferSize: 65536,
} }).Listen(pipePath)
pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, &c)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -496,19 +509,21 @@ func TestEchoWithMessaging(t *testing.T) {
clientDone := make(chan bool) clientDone := make(chan bool)
go func() { go func() {
// server echo // server echo
conn, e := l.Accept() conn, err := l.Accept()
if e != nil { if err != nil {
t.Fatal(e) t.Fatal(err)
} }
defer conn.Close() defer conn.Close()
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
io.Copy(conn, conn) _, err = io.Copy(conn, conn)
if err != nil {
t.Fatal(err)
}
conn.(CloseWriter).CloseWrite() conn.(CloseWriter).CloseWrite()
close(listenerDone) close(listenerDone)
}() }()
timeout := 1 * time.Second client, err := namedpipe.DialTimeout(pipePath, time.Second)
client, err := winpipe.Dial(pipePath, &timeout, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -521,7 +536,7 @@ func TestEchoWithMessaging(t *testing.T) {
if e != nil { if e != nil {
t.Fatal(e) t.Fatal(e)
} }
if n != 2 { if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
t.Fatalf("expected 2 bytes, got %v", n) t.Fatalf("expected 2 bytes, got %v", n)
} }
close(clientDone) close(clientDone)
@ -545,7 +560,7 @@ func TestEchoWithMessaging(t *testing.T) {
func TestConnectRace(t *testing.T) { func TestConnectRace(t *testing.T) {
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, nil) l, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -565,7 +580,7 @@ func TestConnectRace(t *testing.T) {
}() }()
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
c, err := winpipe.Dial(pipePath, nil, nil) c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -580,7 +595,7 @@ func TestMessageReadMode(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
defer wg.Wait() defer wg.Wait()
pipePath := randomPipePath() pipePath := randomPipePath()
l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true}) l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -602,7 +617,7 @@ func TestMessageReadMode(t *testing.T) {
s.Close() s.Close()
}() }()
c, err := winpipe.Dial(pipePath, nil, nil) c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -643,13 +658,13 @@ func TestListenConnectRace(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
c, err := winpipe.Dial(pipePath, nil, nil) c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil { if err == nil {
c.Close() c.Close()
} }
wg.Done() wg.Done()
}() }()
s, err := winpipe.Listen(pipePath, nil) s, err := namedpipe.Listen(pipePath)
if err != nil { if err != nil {
t.Error(i, err) t.Error(i, err)
} else { } else {

View File

@ -9,8 +9,7 @@ import (
"net" "net"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/namedpipe"
"golang.zx2c4.com/wireguard/ipc/winpipe"
) )
// TODO: replace these with actual standard windows error numbers from the win package // TODO: replace these with actual standard windows error numbers from the win package
@ -61,10 +60,9 @@ func init() {
} }
func UAPIListen(name string) (net.Listener, error) { func UAPIListen(name string) (net.Listener, error) {
config := winpipe.ListenConfig{ listener, err := (&namedpipe.ListenConfig{
SecurityDescriptor: UAPISecurityDescriptor, SecurityDescriptor: UAPISecurityDescriptor,
} }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,128 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
)
func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
return &lazyDLL{Name: name, onLoad: onLoad}
}
func (d *lazyDLL) NewProc(name string) *lazyProc {
return &lazyProc{dll: d, Name: name}
}
type lazyProc struct {
Name string
mu sync.Mutex
dll *lazyDLL
addr uintptr
}
func (p *lazyProc) Find() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
if p.addr != 0 {
return nil
}
err := p.dll.Load()
if err != nil {
return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
}
addr, err := p.nameToAddr()
if err != nil {
return fmt.Errorf("Error getting %v address: %w", p.Name, err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
return nil
}
func (p *lazyProc) Addr() uintptr {
err := p.Find()
if err != nil {
panic(err)
}
return p.addr
}
type lazyDLL struct {
Name string
mu sync.Mutex
module windows.Handle
onLoad func(d *lazyDLL)
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != 0 {
return nil
}
const (
LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
)
module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return windows.GetProcAddress(p.dll.module, p.Name)
}
// Version returns the version of the Wintun DLL.
func Version() string {
if modwintun.Load() != nil {
return "unknown"
}
resInfo, err := windows.FindResource(modwintun.module, windows.ResourceID(1), windows.RT_VERSION)
if err != nil {
return "unknown"
}
data, err := windows.LoadResourceData(modwintun.module, resInfo)
if err != nil {
return "unknown"
}
var fixedInfo *windows.VS_FIXEDFILEINFO
fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo))
err = windows.VerQueryValue(unsafe.Pointer(&data[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen)
if err != nil {
return "unknown"
}
version := fmt.Sprintf("%d.%d", (fixedInfo.FileVersionMS>>16)&0xff, (fixedInfo.FileVersionMS>>0)&0xff)
if nextNibble := (fixedInfo.FileVersionLS >> 16) & 0xff; nextNibble != 0 {
version += fmt.Sprintf(".%d", nextNibble)
}
if nextNibble := (fixedInfo.FileVersionLS >> 0) & 0xff; nextNibble != 0 {
version += fmt.Sprintf(".%d", nextNibble)
}
return version
}

View File

@ -1,90 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type Session struct {
handle uintptr
}
const (
PacketSizeMax = 0xffff // Maximum packet size
RingCapacityMin = 0x20000 // Minimum ring capacity (128 kiB)
RingCapacityMax = 0x4000000 // Maximum ring capacity (64 MiB)
)
// Packet with data
type Packet struct {
Next *Packet // Pointer to next packet in queue
Size uint32 // Size of packet (max WINTUN_MAX_IP_PACKET_SIZE)
Data *[PacketSizeMax]byte // Pointer to layer 3 IPv4 or IPv6 packet
}
var (
procWintunAllocateSendPacket = modwintun.NewProc("WintunAllocateSendPacket")
procWintunEndSession = modwintun.NewProc("WintunEndSession")
procWintunGetReadWaitEvent = modwintun.NewProc("WintunGetReadWaitEvent")
procWintunReceivePacket = modwintun.NewProc("WintunReceivePacket")
procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket")
procWintunSendPacket = modwintun.NewProc("WintunSendPacket")
procWintunStartSession = modwintun.NewProc("WintunStartSession")
)
func (wintun *Adapter) StartSession(capacity uint32) (session Session, err error) {
r0, _, e1 := syscall.Syscall(procWintunStartSession.Addr(), 2, uintptr(wintun.handle), uintptr(capacity), 0)
if r0 == 0 {
err = e1
} else {
session = Session{r0}
}
return
}
func (session Session) End() {
syscall.Syscall(procWintunEndSession.Addr(), 1, session.handle, 0, 0)
session.handle = 0
}
func (session Session) ReadWaitEvent() (handle windows.Handle) {
r0, _, _ := syscall.Syscall(procWintunGetReadWaitEvent.Addr(), 1, session.handle, 0, 0)
handle = windows.Handle(r0)
return
}
func (session Session) ReceivePacket() (packet []byte, err error) {
var packetSize uint32
r0, _, e1 := syscall.Syscall(procWintunReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0)
if r0 == 0 {
err = e1
return
}
packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize)
return
}
func (session Session) ReleaseReceivePacket(packet []byte) {
syscall.Syscall(procWintunReleaseReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
}
func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) {
r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket.Addr(), 2, session.handle, uintptr(packetSize), 0)
if r0 == 0 {
err = e1
return
}
packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize)
return
}
func (session Session) SendPacket(packet []byte) {
syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
}

View File

@ -1,150 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"log"
"runtime"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type loggerLevel int
const (
logInfo loggerLevel = iota
logWarn
logErr
)
const AdapterNameMax = 128
type Adapter struct {
handle uintptr
}
var (
modwintun = newLazyDLL("wintun.dll", setupLogger)
procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter")
procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter")
procWintunCloseAdapter = modwintun.NewProc("WintunCloseAdapter")
procWintunDeleteDriver = modwintun.NewProc("WintunDeleteDriver")
procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID")
procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
)
type TimestampedWriter interface {
WriteWithTimestamp(p []byte, ts int64) (n int, err error)
}
func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int {
if tw, ok := log.Default().Writer().(TimestampedWriter); ok {
tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100)
} else {
log.Println(windows.UTF16PtrToString(msg))
}
return 0
}
func setupLogger(dll *lazyDLL) {
var callback uintptr
if runtime.GOARCH == "386" {
callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int {
return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
})
} else if runtime.GOARCH == "arm" {
callback = windows.NewCallback(func(level loggerLevel, _, timestampLow, timestampHigh uint32, msg *uint16) int {
return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
})
} else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" {
callback = windows.NewCallback(logMessage)
}
syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, callback, 0, 0)
}
func closeAdapter(wintun *Adapter) {
syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
}
// CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter.
// tunnelType represents the type of adapter and should be "Wintun". requestedGUID is
// the GUID of the created network adapter, which then influences NLA generation
// deterministically. If it is set to nil, the GUID is chosen by the system at random,
// and hence a new NLA entry is created for each new adapter.
func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil {
return
}
var tunnelType16 *uint16
tunnelType16, err = windows.UTF16PtrFromString(tunnelType)
if err != nil {
return
}
r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID)))
if r0 == 0 {
err = e1
return
}
wintun = &Adapter{handle: r0}
runtime.SetFinalizer(wintun, closeAdapter)
return
}
// OpenAdapter opens an existing Wintun adapter by name.
func OpenAdapter(name string) (wintun *Adapter, err error) {
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil {
return
}
r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0)
if r0 == 0 {
err = e1
return
}
wintun = &Adapter{handle: r0}
runtime.SetFinalizer(wintun, closeAdapter)
return
}
// Close closes a Wintun adapter.
func (wintun *Adapter) Close() (err error) {
runtime.SetFinalizer(wintun, nil)
r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
if r1 == 0 {
err = e1
}
return
}
// Uninstall removes the driver from the system if no drivers are currently in use.
func Uninstall() (err error) {
r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0)
if r1 == 0 {
err = e1
}
return
}
// RunningVersion returns the version of the loaded driver.
func RunningVersion() (version uint32, err error) {
r0, _, e1 := syscall.Syscall(procWintunGetRunningDriverVersion.Addr(), 0, 0, 0, 0)
version = uint32(r0)
if version == 0 {
err = e1
}
return
}
// LUID returns the LUID of the adapter.
func (wintun *Adapter) LUID() (luid uint64) {
syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0)
return
}