1
0
mirror of https://git.zx2c4.com/wireguard-go synced 2024-11-15 01:05:15 +01:00

tun/netstack: simplify read timeout on ping socket

I'm not 100% sure this is correct, but it certainly is a lot simpler.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2022-02-02 23:30:31 +01:00
parent b9669b734e
commit 3b95c81cc1

View File

@ -17,7 +17,6 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"golang.zx2c4.com/go118/netip" "golang.zx2c4.com/go118/netip"
@ -294,13 +293,11 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
} }
type PingConn struct { type PingConn struct {
laddr PingAddr laddr PingAddr
raddr PingAddr raddr PingAddr
wq waiter.Queue wq waiter.Queue
ep tcpip.Endpoint ep tcpip.Endpoint
mu sync.RWMutex deadline *time.Timer
deadline time.Time
deadlineBreaker chan struct{}
} }
type PingAddr struct{ addr netip.Addr } type PingAddr struct{ addr netip.Addr }
@ -348,9 +345,10 @@ func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
} }
pc := &PingConn{ pc := &PingConn{
laddr: PingAddr{laddr}, laddr: PingAddr{laddr},
deadlineBreaker: make(chan struct{}, 1), deadline: time.NewTimer(time.Hour << 10),
} }
pc.deadline.Stop()
ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq) ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
if tcpipErr != nil { if tcpipErr != nil {
@ -408,7 +406,7 @@ func (pc *PingConn) RemoteAddr() net.Addr {
} }
func (pc *PingConn) Close() error { func (pc *PingConn) Close() error {
close(pc.deadlineBreaker) pc.deadline.Reset(0)
pc.ep.Close() pc.ep.Close()
return nil return nil
} }
@ -454,33 +452,10 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
pc.wq.EventRegister(&e, waiter.EventIn) pc.wq.EventRegister(&e, waiter.EventIn)
defer pc.wq.EventUnregister(&e) defer pc.wq.EventUnregister(&e)
ready := false select {
case <-pc.deadline.C:
for !ready { return 0, nil, os.ErrDeadlineExceeded
pc.mu.RLock() case <-notifyCh:
deadlineBreaker := pc.deadlineBreaker
deadline := pc.deadline
pc.mu.RUnlock()
if deadline.IsZero() {
select {
case <-deadlineBreaker:
case <-notifyCh:
ready = true
}
} else {
t := time.NewTimer(deadline.Sub(time.Now()))
defer t.Stop()
select {
case <-t.C:
return 0, nil, os.ErrDeadlineExceeded
case <-deadlineBreaker:
case <-notifyCh:
ready = true
}
}
} }
w := tcpip.SliceWriter(p) w := tcpip.SliceWriter(p)
@ -508,11 +483,7 @@ func (pc *PingConn) SetDeadline(t time.Time) error {
} }
func (pc *PingConn) SetReadDeadline(t time.Time) error { func (pc *PingConn) SetReadDeadline(t time.Time) error {
pc.mu.Lock() pc.deadline.Reset(t.Sub(time.Now()))
defer pc.mu.Unlock()
close(pc.deadlineBreaker)
pc.deadlineBreaker = make(chan struct{}, 1)
pc.deadline = t
return nil return nil
} }