diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index 772c45a..a6d0ea2 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -20,21 +20,23 @@ const ( ) type RatelimiterEntry struct { - sync.Mutex + mu sync.Mutex lastTime time.Time tokens int64 } type Ratelimiter struct { - sync.RWMutex - stopReset chan struct{} + mu sync.RWMutex + timeNow func() time.Time + + stopReset chan struct{} // send to reset, close to stop tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry } func (rate *Ratelimiter) Close() { - rate.Lock() - defer rate.Unlock() + rate.mu.Lock() + defer rate.mu.Unlock() if rate.stopReset != nil { close(rate.stopReset) @@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() { } func (rate *Ratelimiter) Init() { - rate.Lock() - defer rate.Unlock() + rate.mu.Lock() + defer rate.mu.Unlock() + + if rate.timeNow == nil { + rate.timeNow = time.Now + } // stop any ongoing garbage collection routine - if rate.stopReset != nil { close(rate.stopReset) } @@ -55,50 +60,52 @@ func (rate *Ratelimiter) Init() { rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) - // start garbage collection routine + stopReset := rate.stopReset // store in case Init is called again. + // Start garbage collection routine. go func() { ticker := time.NewTicker(time.Second) ticker.Stop() for { select { - case _, ok := <-rate.stopReset: + case _, ok := <-stopReset: ticker.Stop() - if ok { - ticker = time.NewTicker(time.Second) - } else { + if !ok { return } + ticker = time.NewTicker(time.Second) case <-ticker.C: - func() { - rate.Lock() - defer rate.Unlock() - - for key, entry := range rate.tableIPv4 { - entry.Lock() - if time.Since(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv4, key) - } - entry.Unlock() - } - - for key, entry := range rate.tableIPv6 { - entry.Lock() - if time.Since(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.Unlock() - } - - if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 { - ticker.Stop() - } - }() + if rate.cleanup() { + ticker.Stop() + } } } }() } +func (rate *Ratelimiter) cleanup() (empty bool) { + rate.mu.Lock() + defer rate.mu.Unlock() + + for key, entry := range rate.tableIPv4 { + entry.mu.Lock() + if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { + delete(rate.tableIPv4, key) + } + entry.mu.Unlock() + } + + for key, entry := range rate.tableIPv6 { + entry.mu.Lock() + if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { + delete(rate.tableIPv6, key) + } + entry.mu.Unlock() + } + + return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 +} + func (rate *Ratelimiter) Allow(ip net.IP) bool { var entry *RatelimiterEntry var keyIPv4 [net.IPv4len]byte @@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { IPv4 := ip.To4() IPv6 := ip.To16() - rate.RLock() + rate.mu.RLock() if IPv4 != nil { copy(keyIPv4[:], IPv4) @@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { entry = rate.tableIPv6[keyIPv6] } - rate.RUnlock() + rate.mu.RUnlock() // make new entry if not found if entry == nil { entry = new(RatelimiterEntry) entry.tokens = maxTokens - packetCost - entry.lastTime = time.Now() - rate.Lock() + entry.lastTime = rate.timeNow() + rate.mu.Lock() if IPv4 != nil { rate.tableIPv4[keyIPv4] = entry if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { @@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { rate.stopReset <- struct{}{} } } - rate.Unlock() + rate.mu.Unlock() return true } // add tokens to entry - entry.Lock() - now := time.Now() + entry.mu.Lock() + now := rate.timeNow() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.lastTime = now if entry.tokens > maxTokens { @@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { if entry.tokens > packetCost { entry.tokens -= packetCost - entry.Unlock() + entry.mu.Unlock() return true } - entry.Unlock() + entry.mu.Unlock() return false } diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go index 659bdfb..25d5d63 100644 --- a/ratelimiter/ratelimiter_test.go +++ b/ratelimiter/ratelimiter_test.go @@ -11,22 +11,21 @@ import ( "time" ) -type RatelimiterResult struct { +type result struct { allowed bool text string wait time.Duration } func TestRatelimiter(t *testing.T) { + var rate Ratelimiter + var expectedResults []result - var ratelimiter Ratelimiter - var expectedResults []RatelimiterResult - - Nano := func(nano int64) time.Duration { + nano := func(nano int64) time.Duration { return time.Nanosecond * time.Duration(nano) } - Add := func(res RatelimiterResult) { + add := func(res result) { expectedResults = append( expectedResults, res, @@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) { } for i := 0; i < packetsBurstable; i++ { - Add(RatelimiterResult{ + add(result{ allowed: true, text: "initial burst", }) } - Add(RatelimiterResult{ + add(result{ allowed: false, text: "after burst", }) - Add(RatelimiterResult{ + add(result{ allowed: true, - wait: Nano(time.Second.Nanoseconds() / packetsPerSecond), + wait: nano(time.Second.Nanoseconds() / packetsPerSecond), text: "filling tokens for single packet", }) - Add(RatelimiterResult{ + add(result{ allowed: false, text: "not having refilled enough", }) - Add(RatelimiterResult{ + add(result{ allowed: true, - wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)), + wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)), text: "filling tokens for two packet burst", }) - Add(RatelimiterResult{ + add(result{ allowed: true, text: "second packet in 2 packet burst", }) - Add(RatelimiterResult{ + add(result{ allowed: false, text: "packet following 2 packet burst", }) @@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) { net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), } - ratelimiter.Init() + now := time.Now() + rate.timeNow = func() time.Time { + return now + } + defer func() { + // Lock to avoid data race with cleanup goroutine from Init. + rate.mu.Lock() + defer rate.mu.Unlock() + + rate.timeNow = time.Now + }() + timeSleep := func(d time.Duration) { + now = now.Add(d + 1) + rate.cleanup() + } + + rate.Init() + defer rate.Close() for i, res := range expectedResults { - time.Sleep(res.wait) + timeSleep(res.wait) for _, ip := range ips { - allowed := ratelimiter.Allow(ip) + allowed := rate.Allow(ip) if allowed != res.allowed { - t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed) + t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed) } } }