mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 01:05:15 +01:00
tun: windows: protect reads from closing
The code previously used the old errors channel for checking, rather than the simpler boolean, which caused issues on shutdown, since the errors channel was meaningless. However, looking at this exposed a more basic problem: Close() and all the other functions that check the closed boolean can race. So protect with a basic RW lock, to ensure that Close() waits for all pending operations to complete. Reported-by: Joshua Sjoding <joshua.sjoding@scjalliance.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
8246d251ea
commit
097af6e135
@ -37,8 +37,8 @@ type NativeTun struct {
|
|||||||
wt *wintun.Adapter
|
wt *wintun.Adapter
|
||||||
handle windows.Handle
|
handle windows.Handle
|
||||||
close bool
|
close bool
|
||||||
|
closing sync.RWMutex
|
||||||
events chan Event
|
events chan Event
|
||||||
errors chan error
|
|
||||||
forcedMTU int
|
forcedMTU int
|
||||||
rate rateJuggler
|
rate rateJuggler
|
||||||
session wintun.Session
|
session wintun.Session
|
||||||
@ -97,7 +97,6 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
|
|||||||
wt: wt,
|
wt: wt,
|
||||||
handle: windows.InvalidHandle,
|
handle: windows.InvalidHandle,
|
||||||
events: make(chan Event, 10),
|
events: make(chan Event, 10),
|
||||||
errors: make(chan error, 1),
|
|
||||||
forcedMTU: forcedMTU,
|
forcedMTU: forcedMTU,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,6 +111,11 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Name() (string, error) {
|
func (tun *NativeTun) Name() (string, error) {
|
||||||
|
tun.closing.RLock()
|
||||||
|
defer tun.closing.RUnlock()
|
||||||
|
if tun.close {
|
||||||
|
return "", os.ErrClosed
|
||||||
|
}
|
||||||
return tun.wt.Name()
|
return tun.wt.Name()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -126,6 +130,8 @@ func (tun *NativeTun) Events() chan Event {
|
|||||||
func (tun *NativeTun) Close() error {
|
func (tun *NativeTun) Close() error {
|
||||||
var err error
|
var err error
|
||||||
tun.closeOnce.Do(func() {
|
tun.closeOnce.Do(func() {
|
||||||
|
tun.closing.Lock()
|
||||||
|
defer tun.closing.Unlock()
|
||||||
tun.close = true
|
tun.close = true
|
||||||
tun.session.End()
|
tun.session.End()
|
||||||
if tun.wt != nil {
|
if tun.wt != nil {
|
||||||
@ -148,11 +154,11 @@ func (tun *NativeTun) ForceMTU(mtu int) {
|
|||||||
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
|
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
|
||||||
|
|
||||||
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||||
|
tun.closing.RLock()
|
||||||
|
defer tun.closing.RUnlock()
|
||||||
retry:
|
retry:
|
||||||
select {
|
if tun.close {
|
||||||
case err := <-tun.errors:
|
return 0, os.ErrClosed
|
||||||
return 0, err
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
start := nanotime()
|
start := nanotime()
|
||||||
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
|
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
|
||||||
@ -189,6 +195,8 @@ func (tun *NativeTun) Flush() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||||
|
tun.closing.RLock()
|
||||||
|
defer tun.closing.RUnlock()
|
||||||
if tun.close {
|
if tun.close {
|
||||||
return 0, os.ErrClosed
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
@ -213,6 +221,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
|||||||
|
|
||||||
// LUID returns Windows interface instance ID.
|
// LUID returns Windows interface instance ID.
|
||||||
func (tun *NativeTun) LUID() uint64 {
|
func (tun *NativeTun) LUID() uint64 {
|
||||||
|
tun.closing.RLock()
|
||||||
|
defer tun.closing.RUnlock()
|
||||||
|
if tun.close {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
return tun.wt.LUID()
|
return tun.wt.LUID()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user