diff --git a/drivers/store/redis/store.go b/drivers/store/redis/store.go index e8e45c8..97cd778 100644 --- a/drivers/store/redis/store.go +++ b/drivers/store/redis/store.go @@ -3,6 +3,8 @@ package redis import ( "context" "fmt" + "strings" + "sync" "time" libredis "github.com/go-redis/redis/v8" @@ -12,6 +14,32 @@ import ( "github.com/ulule/limiter/v3/drivers/store/common" ) +const ( + luaIncrScript = ` +local key = KEYS[1] +local count = tonumber(ARGV[1]) +local ttl = tonumber(ARGV[2]) +local ret = redis.call("incrby", key, ARGV[1]) +if ret == count then + if ttl > 0 then + redis.call("pexpire", key, ARGV[2]) + end + return {ret, ttl} +end +ttl = redis.call("pttl", key) +return {ret, ttl} +` + luaPeekScript = ` +local key = KEYS[1] +local v = redis.call("get", key) +if v == false then + return {0, 0} +end +local ttl = redis.call("pttl", key) +return {tonumber(v), ttl} +` +) + // Client is an interface thats allows to use a redis cluster or a redis single client seamlessly. type Client interface { Get(ctx context.Context, key string) *libredis.StringCmd @@ -19,17 +47,26 @@ type Client interface { Watch(ctx context.Context, handler func(*libredis.Tx) error, keys ...string) error Del(ctx context.Context, keys ...string) *libredis.IntCmd SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *libredis.BoolCmd - Eval(ctx context.Context, script string, keys []string, args ...interface{}) *libredis.Cmd + EvalSha(ctx context.Context, sha string, keys []string, args ...interface{}) *libredis.Cmd + ScriptLoad(ctx context.Context, script string) *libredis.StringCmd } // Store is the redis store. type Store struct { // Prefix used for the key. Prefix string + // deprecated, this option make no sense when all operations were atomic // MaxRetry is the maximum number of retry under race conditions. MaxRetry int // client used to communicate with redis server. client Client + // luaIncrSHA is the SHA of increase and expire key script + luaIncrSHA string + // luaPeekSHA is the SHA of peek and expire key script + luaPeekSHA string + // hasLuaScriptLoaded was used to check whether the lua script was loaded or not + hasLuaScriptLoaded bool + mu sync.Mutex } // NewStore returns an instance of redis store with defaults. @@ -44,271 +81,125 @@ func NewStore(client Client) (limiter.Store, error) { // NewStoreWithOptions returns an instance of redis store with options. func NewStoreWithOptions(client Client, options limiter.StoreOptions) (limiter.Store, error) { store := &Store{ - client: client, - Prefix: options.Prefix, - MaxRetry: options.MaxRetry, + client: client, + Prefix: options.Prefix, + MaxRetry: options.MaxRetry, + hasLuaScriptLoaded: false, } if store.MaxRetry <= 0 { store.MaxRetry = 1 } - + if err := store.preloadLuaScripts(context.Background()); err != nil { + return nil, err + } return store, nil } -// Get returns the limit for given identifier. -func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { - key = fmt.Sprintf("%s:%s", store.Prefix, key) - now := time.Now() - - lctx := limiter.Context{} - onWatch := func(rtx *libredis.Tx) error { - - created, err := store.doSetValue(ctx, rtx, key, rate.Period) - if err != nil { - return err - } - - if created { - expiration := now.Add(rate.Period) - lctx = common.GetContextFromState(now, rate, expiration, 1) - return nil - } - - count, ttl, err := store.doUpdateValue(ctx, rtx, key, rate.Period) - if err != nil { - return err - } - - expiration := now.Add(rate.Period) - if ttl > 0 { - expiration = now.Add(ttl) - } - - lctx = common.GetContextFromState(now, rate, expiration, count) +// preloadLuaScripts would preload the incr and peek lua script +func (store *Store) preloadLuaScripts(ctx context.Context) error { + store.mu.Lock() + defer store.mu.Unlock() + if store.hasLuaScriptLoaded { return nil } + incrLuaSHA, err := store.client.ScriptLoad(ctx, luaIncrScript).Result() + if err != nil { + return errors.Wrap(err, "failed to load incr lua script") + } + peekLuaSHA, err := store.client.ScriptLoad(ctx, luaPeekScript).Result() + if err != nil { + return errors.Wrap(err, "failed to load peek lua script") + } + store.luaIncrSHA = incrLuaSHA + store.luaPeekSHA = peekLuaSHA + store.hasLuaScriptLoaded = true + return nil +} - err := store.client.Watch(ctx, onWatch, key) +// Get returns the limit for given identifier. +func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + key = fmt.Sprintf("%s:%s", store.Prefix, key) + cmd := store.evalSHA(ctx, store.luaIncrSHA, []string{key}, 1, rate.Period.Milliseconds()) + count, ttl, err := parseCountAndTTL(cmd) if err != nil { - err = errors.Wrapf(err, "limiter: cannot get value for %s", key) return limiter.Context{}, err } - - return lctx, nil + now := time.Now() + expiration := now.Add(rate.Period) + if ttl > 0 { + expiration = now.Add(time.Duration(ttl) * time.Millisecond) + } + return common.GetContextFromState(now, rate, expiration, count), nil } // Peek returns the limit for given identifier, without modification on current values. func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { key = fmt.Sprintf("%s:%s", store.Prefix, key) - now := time.Now() - - lctx := limiter.Context{} - onWatch := func(rtx *libredis.Tx) error { - count, ttl, err := store.doPeekValue(ctx, rtx, key) - if err != nil { - return err - } - - expiration := now.Add(rate.Period) - if ttl > 0 { - expiration = now.Add(ttl) - } - - lctx = common.GetContextFromState(now, rate, expiration, count) - return nil - } - - err := store.client.Watch(ctx, onWatch, key) + cmd := store.evalSHA(ctx, store.luaPeekSHA, []string{key}) + count, ttl, err := parseCountAndTTL(cmd) if err != nil { - err = errors.Wrapf(err, "limiter: cannot peek value for %s", key) return limiter.Context{}, err } - - return lctx, nil + now := time.Now() + expiration := now.Add(rate.Period) + if ttl > 0 { + expiration = now.Add(time.Duration(ttl) * time.Millisecond) + } + return common.GetContextFromState(now, rate, expiration, count), nil } // Reset returns the limit for given identifier which is set to zero. func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { key = fmt.Sprintf("%s:%s", store.Prefix, key) - now := time.Now() - - lctx := limiter.Context{} - onWatch := func(rtx *libredis.Tx) error { - - err := store.doResetValue(ctx, rtx, key) - if err != nil { - return err - } - - count := int64(0) - expiration := now.Add(rate.Period) - - lctx = common.GetContextFromState(now, rate, expiration, count) - return nil - } - - err := store.client.Watch(ctx, onWatch, key) - if err != nil { - err = errors.Wrapf(err, "limiter: cannot reset value for %s", key) + if _, err := store.client.Del(ctx, key).Result(); err != nil { return limiter.Context{}, err } - - return lctx, nil + count := int64(0) + now := time.Now() + expiration := now.Add(rate.Period) + return common.GetContextFromState(now, rate, expiration, count), nil } -// doPeekValue will execute peekValue with a retry mecanism (optimistic locking) until store.MaxRetry is reached. -func (store *Store) doPeekValue(ctx context.Context, rtx *libredis.Tx, key string) (int64, time.Duration, error) { - for i := 0; i < store.MaxRetry; i++ { - count, ttl, err := peekValue(ctx, rtx, key) - if err == nil { - return count, ttl, nil +// evalSHA eval the redis lua sha and load the script if missing +func (store *Store) evalSHA(ctx context.Context, sha string, keys []string, args ...interface{}) *libredis.Cmd { + cmd := store.client.EvalSha(ctx, sha, keys, args...) + if err := cmd.Err(); err != nil { + if !isLuaScriptGone(err) { + return cmd } - } - return 0, 0, errors.New("retry limit exceeded") -} - -// peekValue will retrieve the counter and its expiration for given key. -func peekValue(ctx context.Context, rtx *libredis.Tx, key string) (int64, time.Duration, error) { - pipe := rtx.TxPipeline() - value := pipe.Get(ctx, key) - expire := pipe.PTTL(ctx, key) - - _, err := pipe.Exec(ctx) - if err != nil && err != libredis.Nil { - return 0, 0, err - } - - count, err := value.Int64() - if err != nil && err != libredis.Nil { - return 0, 0, err - } - - ttl, err := expire.Result() - if err != nil { - return 0, 0, err - } - - return count, ttl, nil -} - -// doSetValue will execute setValue with a retry mecanism (optimistic locking) until store.MaxRetry is reached. -func (store *Store) doSetValue(ctx context.Context, rtx *libredis.Tx, - key string, expiration time.Duration) (bool, error) { - - for i := 0; i < store.MaxRetry; i++ { - created, err := setValue(ctx, rtx, key, expiration) - if err == nil { - return created, nil + store.mu.Lock() + store.hasLuaScriptLoaded = false + store.mu.Unlock() + if err := store.preloadLuaScripts(ctx); err != nil { + cmd = libredis.NewCmd(ctx) + cmd.SetErr(err) + return cmd } + cmd = store.client.EvalSha(ctx, sha, keys) } - return false, errors.New("retry limit exceeded") -} - -// setValue will try to initialize a new counter if given key doesn't exists. -func setValue(ctx context.Context, rtx *libredis.Tx, key string, expiration time.Duration) (bool, error) { - value := rtx.SetNX(ctx, key, 1, expiration) - - created, err := value.Result() - if err != nil { - return false, err - } - - return created, nil + return cmd } -// doUpdateValue will execute setValue with a retry mecanism (optimistic locking) until store.MaxRetry is reached. -func (store *Store) doUpdateValue(ctx context.Context, rtx *libredis.Tx, key string, - expiration time.Duration) (int64, time.Duration, error) { - - for i := 0; i < store.MaxRetry; i++ { - count, ttl, err := updateValue(ctx, rtx, key, expiration) - if err == nil { - return count, ttl, nil - } - - // If ttl is negative and there is an error, do not retry an update. - if ttl < 0 { - return 0, 0, err - } - } - return 0, 0, errors.New("retry limit exceeded") +// isLuaScriptGone check whether the error was no script or no +func isLuaScriptGone(err error) bool { + return strings.HasPrefix(err.Error(), "NOSCRIPT") } -// updateValue will try to increment the counter identified by given key. -func updateValue(ctx context.Context, rtx *libredis.Tx, - key string, expiration time.Duration) (int64, time.Duration, error) { - - pipe := rtx.TxPipeline() - value := pipe.Incr(ctx, key) - expire := pipe.PTTL(ctx, key) - - _, err := pipe.Exec(ctx) +// parseCountAndTTL parse count and ttl from lua script output +func parseCountAndTTL(cmd *libredis.Cmd) (int64, int64, error) { + ret, err := cmd.Result() if err != nil { return 0, 0, err } - - count, err := value.Result() - if err != nil { - return 0, 0, err - } - - ttl, err := expire.Result() - if err != nil { - return 0, 0, err + if fields, ok := ret.([]interface{}); !ok || len(fields) != 2 { + return 0, 0, errors.New("two elements in array was expected") } - - // If ttl is less than zero, we have to define key expiration. - // The PTTL command returns -2 if the key does not exist, and -1 if the key exists, but there is no expiry set. - // We shouldn't try to set an expiry on a key that doesn't exist. - if isExpirationRequired(ttl) { - expire := rtx.Expire(ctx, key, expiration) - - ok, err := expire.Result() - if err != nil { - return count, ttl, err - } - - if !ok { - return count, ttl, errors.New("cannot configure timeout on key") - } + fields := ret.([]interface{}) + count, ok1 := fields[0].(int64) + ttl, ok2 := fields[1].(int64) + if !ok1 || !ok2 { + return 0, 0, errors.New("type of the count and ttl should be number") } - return count, ttl, nil } - -// doResetValue will execute resetValue with a retry mecanism (optimistic locking) until store.MaxRetry is reached. -func (store *Store) doResetValue(ctx context.Context, rtx *libredis.Tx, key string) error { - for i := 0; i < store.MaxRetry; i++ { - err := resetValue(ctx, rtx, key) - if err == nil { - return nil - } - } - return errors.New("retry limit exceeded") -} - -// resetValue will try to reset the counter identified by given key. -func resetValue(ctx context.Context, rtx *libredis.Tx, key string) error { - deletion := rtx.Del(ctx, key) - - _, err := deletion.Result() - if err != nil { - return err - } - - return nil -} - -// isExpirationRequired returns if we should set an expiration on a key, using (error) result from PTTL command. -// The error code is -2 if the key does not exist, and -1 if the key exists. -// Usually, it should be returned in nanosecond, but some users have reported that it could be in millisecond as well. -// Better safe than sorry: we handle both. -func isExpirationRequired(ttl time.Duration) bool { - switch ttl { - case -1 * time.Nanosecond, -1 * time.Millisecond: - return true - default: - return false - } -} diff --git a/drivers/store/redis/store_test.go b/drivers/store/redis/store_test.go index 87cff5e..8a615dd 100644 --- a/drivers/store/redis/store_test.go +++ b/drivers/store/redis/store_test.go @@ -108,3 +108,26 @@ func newRedisClient() (*libredis.Client, error) { client := libredis.NewClient(opt) return client, nil } + +func BenchmarkGet(b *testing.B) { + is := require.New(b) + client, err := newRedisClient() + is.NoError(err) + is.NotNil(client) + store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{ + Prefix: "limiter:redis:benchmark", + MaxRetry: 3, + }) + is.NoError(err) + is.NotNil(store) + limiter := limiter.New(store, limiter.Rate{ + Limit: 100000, + Period: 10 * time.Second, + }) + + for i := 0; i < b.N; i++ { + lctx, err := limiter.Get(context.TODO(), "foo") + is.NoError(err) + is.NotZero(lctx) + } +}