1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

tun: windows: set event before waiting

In 097af6e ("tun: windows: protect reads from closing") we made sure no
functions are running when End() is called, to avoid a UaF. But we still
need to kick that event somehow, so that Read() is allowed to exit, in
order to release the lock. So this commit calls SetEvent, while moving
the closing boolean to be atomic so it can be modified without locks,
and then moves to a WaitGroup for the RCU-like pattern.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-05-07 09:26:24 +02:00
parent db733ccd65
commit 69b39db0b4

View File

@ -40,10 +40,10 @@ type NativeTun struct {
session wintun.Session session wintun.Session
readWait windows.Handle readWait windows.Handle
events chan Event events chan Event
closing sync.RWMutex running sync.WaitGroup
closeOnce sync.Once closeOnce sync.Once
close int32
forcedMTU int forcedMTU int
close bool
} }
var WintunPool, _ = wintun.MakePool("WireGuard") var WintunPool, _ = wintun.MakePool("WireGuard")
@ -111,9 +111,9 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
} }
func (tun *NativeTun) Name() (string, error) { func (tun *NativeTun) Name() (string, error) {
tun.closing.RLock() tun.running.Add(1)
defer tun.closing.RUnlock() defer tun.running.Done()
if tun.close { if atomic.LoadInt32(&tun.close) == 1 {
return "", os.ErrClosed return "", os.ErrClosed
} }
return tun.wt.Name() return tun.wt.Name()
@ -130,9 +130,9 @@ func (tun *NativeTun) Events() chan Event {
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err error var err error
tun.closeOnce.Do(func() { tun.closeOnce.Do(func() {
tun.closing.Lock() atomic.StoreInt32(&tun.close, 1)
defer tun.closing.Unlock() windows.SetEvent(tun.readWait)
tun.close = true tun.running.Wait()
tun.session.End() tun.session.End()
if tun.wt != nil { if tun.wt != nil {
_, err = tun.wt.Delete(false) _, err = tun.wt.Delete(false)
@ -158,16 +158,16 @@ func (tun *NativeTun) ForceMTU(mtu int) {
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
tun.closing.RLock() tun.running.Add(1)
defer tun.closing.RUnlock() defer tun.running.Done()
retry: retry:
if tun.close { if atomic.LoadInt32(&tun.close) == 1 {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
start := nanotime() start := nanotime()
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
for { for {
if tun.close { if atomic.LoadInt32(&tun.close) == 1 {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
packet, err := tun.session.ReceivePacket() packet, err := tun.session.ReceivePacket()
@ -199,9 +199,9 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
tun.closing.RLock() tun.running.Add(1)
defer tun.closing.RUnlock() defer tun.running.Done()
if tun.close { if atomic.LoadInt32(&tun.close) == 1 {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
@ -225,9 +225,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
// LUID returns Windows interface instance ID. // LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 { func (tun *NativeTun) LUID() uint64 {
tun.closing.RLock() tun.running.Add(1)
defer tun.closing.RUnlock() defer tun.running.Done()
if tun.close { if atomic.LoadInt32(&tun.close) == 1 {
return 0 return 0
} }
return tun.wt.LUID() return tun.wt.LUID()