mirror of
https://git.zx2c4.com/wireguard-go
synced 2024-11-15 09:15:14 +01:00
9cd8909df2
The existing test would occasionally flake out with: --- FAIL: TestRatelimiter (0.12s) ratelimiter_test.go:99: Test failed for 127.0.0.1 , on: 7 ( not having refilled enough ) expected: false got: true FAIL FAIL golang.zx2c4.com/wireguard/ratelimiter 0.171s The fake clock also means the tests run much faster, so testing this package with -count=1000 now takes < 100ms. While here, several style cleanups. The most significant one is unembedding the sync.Mutex fields in the rate limiter objects. Embedded as they were, the lock methods were accessible outside the ratelimiter package. As they aren't needed externally, keep them internal to make them easier to reason about. Passes `go test -race -count=10000 ./ratelimiter` Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
173 lines
3.3 KiB
Go
173 lines
3.3 KiB
Go
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
 */
|
|
|
|
package ratelimiter
|
|
|
|
import (
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Tunable limits applied to each source address.
const (
	// packetsPerSecond is the sustained number of packets admitted per second.
	packetsPerSecond = 20
	// packetsBurstable is the number of packet costs a full bucket holds.
	packetsBurstable = 5
	// garbageCollectTime is how long an entry may go untouched before
	// cleanup removes it.
	garbageCollectTime = time.Second
)

// Values derived from the limits above. Tokens are counted in nanoseconds.
const (
	// packetCost is the number of tokens charged for one packet.
	packetCost = 1000000000 / packetsPerSecond
	// maxTokens is the cap on the tokens an entry can accumulate.
	maxTokens = packetCost * packetsBurstable
)
|
|
|
|
// RatelimiterEntry is the token bucket kept for a single source address.
type RatelimiterEntry struct {
	// mu guards lastTime and tokens.
	mu sync.Mutex
	// lastTime is the instant at which tokens was last brought up to date.
	lastTime time.Time
	// tokens is the current balance, in nanoseconds of accumulated time.
	tokens int64
}
|
|
|
|
type Ratelimiter struct {
|
|
mu sync.RWMutex
|
|
timeNow func() time.Time
|
|
|
|
stopReset chan struct{} // send to reset, close to stop
|
|
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
|
|
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
|
|
}
|
|
|
|
func (rate *Ratelimiter) Close() {
|
|
rate.mu.Lock()
|
|
defer rate.mu.Unlock()
|
|
|
|
if rate.stopReset != nil {
|
|
close(rate.stopReset)
|
|
}
|
|
}
|
|
|
|
func (rate *Ratelimiter) Init() {
|
|
rate.mu.Lock()
|
|
defer rate.mu.Unlock()
|
|
|
|
if rate.timeNow == nil {
|
|
rate.timeNow = time.Now
|
|
}
|
|
|
|
// stop any ongoing garbage collection routine
|
|
if rate.stopReset != nil {
|
|
close(rate.stopReset)
|
|
}
|
|
|
|
rate.stopReset = make(chan struct{})
|
|
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
|
|
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
|
|
|
|
stopReset := rate.stopReset // store in case Init is called again.
|
|
|
|
// Start garbage collection routine.
|
|
go func() {
|
|
ticker := time.NewTicker(time.Second)
|
|
ticker.Stop()
|
|
for {
|
|
select {
|
|
case _, ok := <-stopReset:
|
|
ticker.Stop()
|
|
if !ok {
|
|
return
|
|
}
|
|
ticker = time.NewTicker(time.Second)
|
|
case <-ticker.C:
|
|
if rate.cleanup() {
|
|
ticker.Stop()
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (rate *Ratelimiter) cleanup() (empty bool) {
|
|
rate.mu.Lock()
|
|
defer rate.mu.Unlock()
|
|
|
|
for key, entry := range rate.tableIPv4 {
|
|
entry.mu.Lock()
|
|
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
|
delete(rate.tableIPv4, key)
|
|
}
|
|
entry.mu.Unlock()
|
|
}
|
|
|
|
for key, entry := range rate.tableIPv6 {
|
|
entry.mu.Lock()
|
|
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
|
delete(rate.tableIPv6, key)
|
|
}
|
|
entry.mu.Unlock()
|
|
}
|
|
|
|
return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
|
|
}
|
|
|
|
func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
|
var entry *RatelimiterEntry
|
|
var keyIPv4 [net.IPv4len]byte
|
|
var keyIPv6 [net.IPv6len]byte
|
|
|
|
// lookup entry
|
|
|
|
IPv4 := ip.To4()
|
|
IPv6 := ip.To16()
|
|
|
|
rate.mu.RLock()
|
|
|
|
if IPv4 != nil {
|
|
copy(keyIPv4[:], IPv4)
|
|
entry = rate.tableIPv4[keyIPv4]
|
|
} else {
|
|
copy(keyIPv6[:], IPv6)
|
|
entry = rate.tableIPv6[keyIPv6]
|
|
}
|
|
|
|
rate.mu.RUnlock()
|
|
|
|
// make new entry if not found
|
|
|
|
if entry == nil {
|
|
entry = new(RatelimiterEntry)
|
|
entry.tokens = maxTokens - packetCost
|
|
entry.lastTime = rate.timeNow()
|
|
rate.mu.Lock()
|
|
if IPv4 != nil {
|
|
rate.tableIPv4[keyIPv4] = entry
|
|
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
|
|
rate.stopReset <- struct{}{}
|
|
}
|
|
} else {
|
|
rate.tableIPv6[keyIPv6] = entry
|
|
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
|
|
rate.stopReset <- struct{}{}
|
|
}
|
|
}
|
|
rate.mu.Unlock()
|
|
return true
|
|
}
|
|
|
|
// add tokens to entry
|
|
|
|
entry.mu.Lock()
|
|
now := rate.timeNow()
|
|
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
|
entry.lastTime = now
|
|
if entry.tokens > maxTokens {
|
|
entry.tokens = maxTokens
|
|
}
|
|
|
|
// subtract cost of packet
|
|
|
|
if entry.tokens > packetCost {
|
|
entry.tokens -= packetCost
|
|
entry.mu.Unlock()
|
|
return true
|
|
}
|
|
entry.mu.Unlock()
|
|
return false
|
|
}
|