2019-01-02 01:55:51 +01:00
|
|
|
/* SPDX-License-Identifier: MIT
|
2018-05-03 15:04:00 +02:00
|
|
|
*
|
2021-01-28 17:52:15 +01:00
|
|
|
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
2018-05-03 15:04:00 +02:00
|
|
|
*/
|
2017-07-12 23:11:49 +02:00
|
|
|
|
2018-05-03 15:04:00 +02:00
|
|
|
package ratelimiter
|
2017-07-11 18:48:29 +02:00
|
|
|
|
|
|
|
import (
|
2021-11-05 01:52:54 +01:00
|
|
|
"net/netip"
|
2017-07-11 18:48:29 +02:00
|
|
|
"sync"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Rate-limiting parameters. Token amounts are measured in nanoseconds: an
// entry earns one token for every nanosecond that elapses between packets.
const (
	// packetsPerSecond is the sustained per-address packet rate.
	packetsPerSecond = 20
	// packetsBurstable is the bucket capacity expressed in packets.
	packetsBurstable = 5
	// garbageCollectTime is how long an entry may sit idle before cleanup
	// evicts it.
	garbageCollectTime = time.Second
	// packetCost is the number of tokens (nanoseconds) charged per accepted
	// packet.
	packetCost = 1000000000 / packetsPerSecond
	// maxTokens is the bucket capacity in tokens.
	maxTokens = packetCost * packetsBurstable
)

// RatelimiterEntry holds the token-bucket state for a single source address.
type RatelimiterEntry struct {
	mu       sync.Mutex // guards lastTime and tokens
	lastTime time.Time  // when tokens was last refilled
	tokens   int64      // available budget, in nanoseconds
}

// Ratelimiter limits the rate of packets accepted per source address.
//
// The zero value is not ready for use: Init must be called before Allow,
// because Allow writes into the table that Init allocates.
type Ratelimiter struct {
	mu      sync.RWMutex     // guards table, stopReset and timeNow updates
	timeNow func() time.Time // clock; Init sets it to time.Now if unset

	stopReset chan struct{} // send to reset, close to stop
	table     map[netip.Addr]*RatelimiterEntry
}

// Close stops the garbage collection goroutine started by Init.
//
// Close is safe to call more than once. Allow must not be called after
// Close: if it inserts into an empty table it would send on the closed
// stopReset channel, which panics.
func (rate *Ratelimiter) Close() {
	rate.mu.Lock()
	defer rate.mu.Unlock()

	rate.closeStopReset()
}

// closeStopReset closes rate.stopReset if it exists and has not already been
// closed. The caller must hold rate.mu for writing.
func (rate *Ratelimiter) closeStopReset() {
	if rate.stopReset == nil {
		return
	}
	// Every send on stopReset happens while rate.mu is held (see Allow), and
	// the caller holds rate.mu, so no send can be pending here. A receive
	// therefore succeeds only if the channel has already been closed.
	select {
	case <-rate.stopReset:
		// Already closed by an earlier Close or Init.
	default:
		close(rate.stopReset)
	}
}

// Init (re)initializes rate. It installs time.Now as the clock if none was
// set, discards all per-address state, stops the garbage collection goroutine
// from any previous Init, and starts a new one. Calling Init after Close is
// permitted.
func (rate *Ratelimiter) Init() {
	rate.mu.Lock()
	defer rate.mu.Unlock()

	if rate.timeNow == nil {
		rate.timeNow = time.Now
	}

	// Stop any ongoing garbage collection routine.
	rate.closeStopReset()

	rate.stopReset = make(chan struct{})
	rate.table = make(map[netip.Addr]*RatelimiterEntry)

	// Capture the channel: a later Init replaces rate.stopReset, and this
	// goroutine must keep listening on the channel that belongs to it.
	stopReset := rate.stopReset

	// Start garbage collection routine.
	go func() {
		// Begin with a stopped ticker. The first Allow that makes the table
		// non-empty sends on stopReset, which arms a fresh ticker.
		ticker := time.NewTicker(time.Second)
		ticker.Stop()
		for {
			select {
			case _, ok := <-stopReset:
				ticker.Stop()
				if !ok {
					return
				}
				ticker = time.NewTicker(time.Second)
			case <-ticker.C:
				if rate.cleanup() {
					ticker.Stop()
					// With buffered (pre-Go 1.23) timer channels, a tick that
					// arrived while cleanup ran survives Stop. If it were
					// selected later, this goroutine would block in cleanup on
					// rate.mu while Allow holds rate.mu and waits to send on
					// stopReset: a deadlock. Discard it.
					select {
					case <-ticker.C:
					default:
					}
				}
			}
		}
	}()
}

// cleanup evicts every entry that has been idle for longer than
// garbageCollectTime and reports whether the table is now empty.
func (rate *Ratelimiter) cleanup() (empty bool) {
	rate.mu.Lock()
	defer rate.mu.Unlock()

	now := rate.timeNow()
	for ip, entry := range rate.table {
		// Allow updates lastTime while holding only entry.mu, not rate.mu.
		entry.mu.Lock()
		if now.Sub(entry.lastTime) > garbageCollectTime {
			delete(rate.table, ip) // deleting during range is permitted
		}
		entry.mu.Unlock()
	}

	return len(rate.table) == 0
}

// Allow reports whether a packet from ip should be accepted.
//
// Each address has a token bucket. It refills by one token per elapsed
// nanosecond, up to maxTokens, and each accepted packet costs packetCost.
// The first packet from an address not in the table is always accepted.
func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
	// Fast path: look the entry up under the shared lock.
	rate.mu.RLock()
	entry := rate.table[ip]
	rate.mu.RUnlock()

	if entry == nil {
		rate.mu.Lock()
		// Another goroutine may have inserted an entry for ip between
		// RUnlock and Lock. Re-check so its entry, and the tokens it has
		// already spent, is not overwritten. Otherwise concurrent first
		// packets from one address would each be accepted unconditionally.
		entry = rate.table[ip]
		if entry == nil {
			// New address: accept this packet and start its bucket one
			// packet short of full.
			rate.table[ip] = &RatelimiterEntry{
				lastTime: rate.timeNow(),
				tokens:   maxTokens - packetCost,
			}
			// The table just went from empty to non-empty. Wake the garbage
			// collection goroutine, whose ticker is stopped while the table
			// is empty.
			if len(rate.table) == 1 {
				rate.stopReset <- struct{}{}
			}
			rate.mu.Unlock()
			return true
		}
		rate.mu.Unlock()
	}

	entry.mu.Lock()
	defer entry.mu.Unlock()

	// Refill for the time elapsed since the previous packet, capped at
	// maxTokens.
	now := rate.timeNow()
	entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
	entry.lastTime = now
	if entry.tokens > maxTokens {
		entry.tokens = maxTokens
	}

	// Charge for this packet only if the bucket holds strictly more than one
	// packet's worth of tokens.
	if entry.tokens > packetCost {
		entry.tokens -= packetCost
		return true
	}
	return false
}
|