initial import
This commit is contained in:
109
core/limit/periodlimit.go
Normal file
109
core/limit/periodlimit.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package limit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"zero/core/stores/redis"
|
||||
)
|
||||
|
||||
const (
|
||||
// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
|
||||
periodScript = `local limit = tonumber(ARGV[1])
|
||||
local window = tonumber(ARGV[2])
|
||||
local current = redis.call("INCRBY", KEYS[1], 1)
|
||||
if current == 1 then
|
||||
redis.call("expire", KEYS[1], window)
|
||||
return 1
|
||||
elseif current < limit then
|
||||
return 1
|
||||
elseif current == limit then
|
||||
return 2
|
||||
else
|
||||
return 0
|
||||
end`
|
||||
zoneDiff = 3600 * 8 // GMT+8 for our services
|
||||
)
|
||||
|
||||
const (
|
||||
Unknown = iota
|
||||
Allowed
|
||||
HitQuota
|
||||
OverQuota
|
||||
|
||||
internalOverQuota = 0
|
||||
internalAllowed = 1
|
||||
internalHitQuota = 2
|
||||
)
|
||||
|
||||
var ErrUnknownCode = errors.New("unknown status code")
|
||||
|
||||
type (
|
||||
LimitOption func(l *PeriodLimit)
|
||||
|
||||
PeriodLimit struct {
|
||||
period int
|
||||
quota int
|
||||
limitStore *redis.Redis
|
||||
keyPrefix string
|
||||
align bool
|
||||
}
|
||||
)
|
||||
|
||||
func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string,
|
||||
opts ...LimitOption) *PeriodLimit {
|
||||
limiter := &PeriodLimit{
|
||||
period: period,
|
||||
quota: quota,
|
||||
limitStore: limitStore,
|
||||
keyPrefix: keyPrefix,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(limiter)
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (h *PeriodLimit) Take(key string) (int, error) {
|
||||
resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{
|
||||
strconv.Itoa(h.quota),
|
||||
strconv.Itoa(h.calcExpireSeconds()),
|
||||
})
|
||||
if err != nil {
|
||||
return Unknown, err
|
||||
}
|
||||
|
||||
code, ok := resp.(int64)
|
||||
if !ok {
|
||||
return Unknown, ErrUnknownCode
|
||||
}
|
||||
|
||||
switch code {
|
||||
case internalOverQuota:
|
||||
return OverQuota, nil
|
||||
case internalAllowed:
|
||||
return Allowed, nil
|
||||
case internalHitQuota:
|
||||
return HitQuota, nil
|
||||
default:
|
||||
return Unknown, ErrUnknownCode
|
||||
}
|
||||
}
|
||||
|
||||
func (h *PeriodLimit) calcExpireSeconds() int {
|
||||
if h.align {
|
||||
unix := time.Now().Unix() + zoneDiff
|
||||
return h.period - int(unix%int64(h.period))
|
||||
} else {
|
||||
return h.period
|
||||
}
|
||||
}
|
||||
|
||||
func Align() LimitOption {
|
||||
return func(l *PeriodLimit) {
|
||||
l.align = true
|
||||
}
|
||||
}
|
||||
68
core/limit/periodlimit_test.go
Normal file
68
core/limit/periodlimit_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package limit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"zero/core/stores/redis"
|
||||
|
||||
"github.com/alicebob/miniredis"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPeriodLimit_Take(t *testing.T) {
|
||||
testPeriodLimit(t)
|
||||
}
|
||||
|
||||
func TestPeriodLimit_TakeWithAlign(t *testing.T) {
|
||||
testPeriodLimit(t, Align())
|
||||
}
|
||||
|
||||
func TestPeriodLimit_RedisUnavailable(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
|
||||
const (
|
||||
seconds = 1
|
||||
total = 100
|
||||
quota = 5
|
||||
)
|
||||
l := NewPeriodLimit(seconds, quota, redis.NewRedis(s.Addr(), redis.NodeType), "periodlimit")
|
||||
s.Close()
|
||||
val, err := l.Take("first")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, 0, val)
|
||||
}
|
||||
|
||||
func testPeriodLimit(t *testing.T, opts ...LimitOption) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
defer s.Close()
|
||||
|
||||
const (
|
||||
seconds = 1
|
||||
total = 100
|
||||
quota = 5
|
||||
)
|
||||
l := NewPeriodLimit(seconds, quota, redis.NewRedis(s.Addr(), redis.NodeType), "periodlimit", opts...)
|
||||
var allowed, hitQuota, overQuota int
|
||||
for i := 0; i < total; i++ {
|
||||
val, err := l.Take("first")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
switch val {
|
||||
case Allowed:
|
||||
allowed++
|
||||
case HitQuota:
|
||||
hitQuota++
|
||||
case OverQuota:
|
||||
overQuota++
|
||||
default:
|
||||
t.Error("unknown status")
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, quota-1, allowed)
|
||||
assert.Equal(t, 1, hitQuota)
|
||||
assert.Equal(t, total-quota, overQuota)
|
||||
}
|
||||
166
core/limit/tokenlimit.go
Normal file
166
core/limit/tokenlimit.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package limit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/stores/redis"
|
||||
|
||||
xrate "golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const (
|
||||
// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
|
||||
// KEYS[1] as tokens_key
|
||||
// KEYS[2] as timestamp_key
|
||||
script = `local rate = tonumber(ARGV[1])
|
||||
local capacity = tonumber(ARGV[2])
|
||||
local now = tonumber(ARGV[3])
|
||||
local requested = tonumber(ARGV[4])
|
||||
local fill_time = capacity/rate
|
||||
local ttl = math.floor(fill_time*2)
|
||||
local last_tokens = tonumber(redis.call("get", KEYS[1]))
|
||||
if last_tokens == nil then
|
||||
last_tokens = capacity
|
||||
end
|
||||
|
||||
local last_refreshed = tonumber(redis.call("get", KEYS[2]))
|
||||
if last_refreshed == nil then
|
||||
last_refreshed = 0
|
||||
end
|
||||
|
||||
local delta = math.max(0, now-last_refreshed)
|
||||
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
|
||||
local allowed = filled_tokens >= requested
|
||||
local new_tokens = filled_tokens
|
||||
if allowed then
|
||||
new_tokens = filled_tokens - requested
|
||||
end
|
||||
|
||||
redis.call("setex", KEYS[1], ttl, new_tokens)
|
||||
redis.call("setex", KEYS[2], ttl, now)
|
||||
|
||||
return allowed`
|
||||
tokenFormat = "{%s}.tokens"
|
||||
timestampFormat = "{%s}.ts"
|
||||
pingInterval = time.Millisecond * 100
|
||||
)
|
||||
|
||||
// A TokenLimiter controls how frequently events are allowed to happen with in one second.
|
||||
type TokenLimiter struct {
|
||||
rate int
|
||||
burst int
|
||||
store *redis.Redis
|
||||
tokenKey string
|
||||
timestampKey string
|
||||
rescueLock sync.Mutex
|
||||
redisAlive uint32
|
||||
rescueLimiter *xrate.Limiter
|
||||
monitorStarted bool
|
||||
}
|
||||
|
||||
// NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits
|
||||
// bursts of at most burst tokens.
|
||||
func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter {
|
||||
tokenKey := fmt.Sprintf(tokenFormat, key)
|
||||
timestampKey := fmt.Sprintf(timestampFormat, key)
|
||||
|
||||
return &TokenLimiter{
|
||||
rate: rate,
|
||||
burst: burst,
|
||||
store: store,
|
||||
tokenKey: tokenKey,
|
||||
timestampKey: timestampKey,
|
||||
redisAlive: 1,
|
||||
rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow is shorthand for AllowN(time.Now(), 1).
|
||||
func (lim *TokenLimiter) Allow() bool {
|
||||
return lim.AllowN(time.Now(), 1)
|
||||
}
|
||||
|
||||
// AllowN reports whether n events may happen at time now.
|
||||
// Use this method if you intend to drop / skip events that exceed the rate rate.
|
||||
// Otherwise use Reserve or Wait.
|
||||
func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
|
||||
return lim.reserveN(now, n)
|
||||
}
|
||||
|
||||
func (lim *TokenLimiter) reserveN(now time.Time, n int) bool {
|
||||
if atomic.LoadUint32(&lim.redisAlive) == 0 {
|
||||
return lim.rescueLimiter.AllowN(now, n)
|
||||
}
|
||||
|
||||
resp, err := lim.store.Eval(
|
||||
script,
|
||||
[]string{
|
||||
lim.tokenKey,
|
||||
lim.timestampKey,
|
||||
},
|
||||
[]string{
|
||||
strconv.Itoa(lim.rate),
|
||||
strconv.Itoa(lim.burst),
|
||||
strconv.FormatInt(now.Unix(), 10),
|
||||
strconv.Itoa(n),
|
||||
})
|
||||
// redis allowed == false
|
||||
// Lua boolean false -> r Nil bulk reply
|
||||
if err == redis.Nil {
|
||||
return false
|
||||
} else if err != nil {
|
||||
logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
|
||||
lim.startMonitor()
|
||||
return lim.rescueLimiter.AllowN(now, n)
|
||||
}
|
||||
|
||||
code, ok := resp.(int64)
|
||||
if !ok {
|
||||
logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp)
|
||||
lim.startMonitor()
|
||||
return lim.rescueLimiter.AllowN(now, n)
|
||||
}
|
||||
|
||||
// redis allowed == true
|
||||
// Lua boolean true -> r integer reply with value of 1
|
||||
return code == 1
|
||||
}
|
||||
|
||||
func (lim *TokenLimiter) startMonitor() {
|
||||
lim.rescueLock.Lock()
|
||||
defer lim.rescueLock.Unlock()
|
||||
|
||||
if lim.monitorStarted {
|
||||
return
|
||||
}
|
||||
|
||||
lim.monitorStarted = true
|
||||
atomic.StoreUint32(&lim.redisAlive, 0)
|
||||
|
||||
go lim.waitForRedis()
|
||||
}
|
||||
|
||||
func (lim *TokenLimiter) waitForRedis() {
|
||||
ticker := time.NewTicker(pingInterval)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
lim.rescueLock.Lock()
|
||||
lim.monitorStarted = false
|
||||
lim.rescueLock.Unlock()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if lim.store.Ping() {
|
||||
atomic.StoreUint32(&lim.redisAlive, 1)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
88
core/limit/tokenlimit_test.go
Normal file
88
core/limit/tokenlimit_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package limit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/stores/redis"
|
||||
|
||||
"github.com/alicebob/miniredis"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logx.Disable()
|
||||
}
|
||||
|
||||
func TestTokenLimit_Rescue(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
|
||||
const (
|
||||
total = 100
|
||||
rate = 5
|
||||
burst = 10
|
||||
)
|
||||
l := NewTokenLimiter(rate, burst, redis.NewRedis(s.Addr(), redis.NodeType), "tokenlimit")
|
||||
s.Close()
|
||||
|
||||
var allowed int
|
||||
for i := 0; i < total; i++ {
|
||||
time.Sleep(time.Second / time.Duration(total))
|
||||
if i == total>>1 {
|
||||
assert.Nil(t, s.Restart())
|
||||
}
|
||||
if l.Allow() {
|
||||
allowed++
|
||||
}
|
||||
|
||||
// make sure start monitor more than once doesn't matter
|
||||
l.startMonitor()
|
||||
}
|
||||
|
||||
assert.True(t, allowed >= burst+rate)
|
||||
}
|
||||
|
||||
func TestTokenLimit_Take(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
defer s.Close()
|
||||
|
||||
const (
|
||||
total = 100
|
||||
rate = 5
|
||||
burst = 10
|
||||
)
|
||||
l := NewTokenLimiter(rate, burst, redis.NewRedis(s.Addr(), redis.NodeType), "tokenlimit")
|
||||
var allowed int
|
||||
for i := 0; i < total; i++ {
|
||||
time.Sleep(time.Second / time.Duration(total))
|
||||
if l.Allow() {
|
||||
allowed++
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, allowed >= burst+rate)
|
||||
}
|
||||
|
||||
func TestTokenLimit_TakeBurst(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
defer s.Close()
|
||||
|
||||
const (
|
||||
total = 100
|
||||
rate = 5
|
||||
burst = 10
|
||||
)
|
||||
l := NewTokenLimiter(rate, burst, redis.NewRedis(s.Addr(), redis.NodeType), "tokenlimit")
|
||||
var allowed int
|
||||
for i := 0; i < total; i++ {
|
||||
if l.Allow() {
|
||||
allowed++
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, allowed >= burst)
|
||||
}
|
||||
Reference in New Issue
Block a user