diff --git a/go.mod b/go.mod index d84237976..77212e251 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/PuerkitoBio/rehttp v1.4.0 github.com/adjust/rmq/v5 v5.2.0 github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d - github.com/bsm/redislock v0.9.3 github.com/caarlos0/env/v10 v10.0.0 github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d github.com/dave/jennifer v1.4.1 diff --git a/go.sum b/go.sum index c8bfd3ba9..178f1e098 100644 --- a/go.sum +++ b/go.sum @@ -55,8 +55,6 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= -github.com/bsm/redislock v0.9.3 h1:osmvugkXGiLDEhzUPdM0EUtKpTEgLLuli4Ky2Z4vx38= -github.com/bsm/redislock v0.9.3/go.mod h1:Epf7AJLiSFwLCiZcfi6pWFO/8eAYrYpQXFxEDPoDeAk= github.com/caarlos0/env/v10 v10.0.0 h1:yIHUBZGsyqCnpTkbjk8asUlx6RFhhEs+h7TOBdgdzXA= github.com/caarlos0/env/v10 v10.0.0/go.mod h1:ZfulV76NvVPw3tm591U4SwL3Xx9ldzBP9aGxzeN7G18= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= diff --git a/maintenance/failover/failover.go b/maintenance/failover/failover.go index f57e342d5..f3b6185c6 100644 --- a/maintenance/failover/failover.go +++ b/maintenance/failover/failover.go @@ -5,6 +5,7 @@ package failover import ( "context" "fmt" + "github.com/pace/bricks/pkg/lock/redislock" "net/http" "strings" "sync" @@ -12,7 +13,6 @@ import ( "github.com/rs/zerolog" - "github.com/bsm/redislock" "github.com/pace/bricks/backend/k8sapi" "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/health" @@ -52,6 +52,7 @@ type ActivePassive struct { clusterName string timeToFailover time.Duration locker *redislock.Client + redisClient *redis.Client // access to the kubernetes api k8sClient *k8sapi.Client @@ -77,6 +78,7 @@ func NewActivePassive(clusterName string, timeToFailover time.Duration, redisCli clusterName: clusterName, timeToFailover: timeToFailover, locker: redislock.New(redisClient), + redisClient: redisClient, k8sClient: k8sClient, } health.SetCustomReadinessCheck(activePassive.Handler) @@ -92,6 +94,8 @@ func (a *ActivePassive) Run(ctx context.Context) error { defer errors.HandleWithCtx(ctx, "activepassive failover handler") lockName := "activepassive:lock:" + a.clusterName + token := "activepassive:token:" + a.clusterName + logger := log.Ctx(ctx).With().Str("failover", lockName).Logger() ctx = logger.WithContext(ctx) @@ -124,6 +128,7 @@ func (a *ActivePassive) Run(ctx context.Context) error { if a.getState() == ACTIVE { err := lock.Refresh(ctx, a.timeToFailover, &redislock.Options{ RetryStrategy: redislock.LimitRetry(redislock.LinearBackoff(a.timeToFailover/3), 3), + Token: token, }) if err != nil { logger.Info().Err(err).Msg("failed to refresh the lock; becoming undefined...") @@ -136,6 +141,7 @@ func (a *ActivePassive) Run(ctx context.Context) error { lock, err = a.locker.Obtain(ctx, lockName, a.timeToFailover, &redislock.Options{ RetryStrategy: redislock.LimitRetry(redislock.LinearBackoff(a.timeToFailover/3), 3), + Token: token, }) if err != nil { if a.getState() != PASSIVE { @@ -149,6 +155,31 @@ func (a *ActivePassive) Run(ctx context.Context) error { logger.Debug().Msg("lock acquired; becoming active...") a.becomeActive(ctx) + logger.Debug().Msg("check if lock exists") + + // Verify that key exists, then, retrieve the value + keyExists, err := a.redisClient.Exists(ctx, lockName).Result() + if err != nil { + logger.Error().Err(err).Msgf("Stefan: Failed to check that lock/key '%v' exists", lockName) + + continue + } + + if keyExists == 0 { + logger.Info().Msgf("Stefan: Lock/Key '%s' does not exist", lockName) + + continue + } + + lockValue, err := a.redisClient.Get(ctx, lockName).Result() + if err != nil { + logger.Error().Err(err).Msg("Error getting key value") + + continue + } + + logger.Info().Msgf("Stefan: Key value is: %v", lockValue) + // Check TTL of the newly acquired lock ttl, err := safeGetTTL(ctx, lock, logger) if err != nil { diff --git a/pkg/lock/redis/lock.go b/pkg/lock/redis/lock.go index 6fbe55978..d6ddabce3 100644 --- a/pkg/lock/redis/lock.go +++ b/pkg/lock/redis/lock.go @@ -7,13 +7,13 @@ import ( "context" "errors" "fmt" + "github.com/pace/bricks/pkg/lock/redislock" "sync" "time" redisbackend "github.com/pace/bricks/backend/redis" pberrors "github.com/pace/bricks/maintenance/errors" - "github.com/bsm/redislock" "github.com/redis/go-redis/v9" "github.com/rs/zerolog/log" ) diff --git a/pkg/lock/redislock/redislock.go b/pkg/lock/redislock/redislock.go new file mode 100644 index 000000000..8e7dbf021 --- /dev/null +++ b/pkg/lock/redislock/redislock.go @@ -0,0 +1,311 @@ +package redislock + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "github.com/pace/bricks/maintenance/log" + "io" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9" +) + +var ( + luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`) + luaRelease = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("del", KEYS[1]) else return 0 end`) + luaPTTL = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pttl", KEYS[1]) else return -3 end`) +) + +var ( + // ErrNotObtained is returned when a lock cannot be obtained. + ErrNotObtained = errors.New("redislock: not obtained") + + // ErrLockNotHeld is returned when trying to release an inactive lock. + ErrLockNotHeld = errors.New("redislock: lock not held") +) + +// RedisClient is a minimal client interface. +type RedisClient interface { + redis.Scripter + SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.BoolCmd +} + +// Client wraps a redis client. +type Client struct { + client RedisClient + tmp []byte + tmpMu sync.Mutex +} + +// New creates a new Client instance with a custom namespace. +func New(client RedisClient) *Client { + return &Client{client: client} +} + +// Obtain tries to obtain a new lock using a key with the given TTL. +// May return ErrNotObtained if not successful. +func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt *Options) (*Lock, error) { + token := opt.getToken() + + // Create a random token + if token == "" { + var err error + if token, err = c.randomToken(); err != nil { + return nil, err + } + } + + value := token + opt.getMetadata() + retry := opt.getRetryStrategy() + + // make sure we don't retry forever + if _, ok := ctx.Deadline(); !ok { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, time.Now().Add(ttl)) + defer cancel() + } + + var ticker *time.Ticker + for { + log.Ctx(ctx).Debug().Msgf("Stefan: Obtain(): Trying to obtain the lock: key=%v, value=%v, ttl=%v", key, value, ttl) + ok, err := c.obtain(ctx, key, value, ttl) + if err != nil { + return nil, err + } else if ok { + return &Lock{Client: c, key: key, value: value, tokenLen: len(token)}, nil + } + + backoff := retry.NextBackoff() + if backoff < 1 { + return nil, ErrNotObtained + } + + if ticker == nil { + ticker = time.NewTicker(backoff) + defer ticker.Stop() + } else { + ticker.Reset(backoff) + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + } + } +} + +func (c *Client) obtain(ctx context.Context, key, value string, ttl time.Duration) (bool, error) { + return c.client.SetNX(ctx, key, value, ttl).Result() +} + +func (c *Client) randomToken() (string, error) { + c.tmpMu.Lock() + defer c.tmpMu.Unlock() + + if len(c.tmp) == 0 { + c.tmp = make([]byte, 16) + } + + if _, err := io.ReadFull(rand.Reader, c.tmp); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(c.tmp), nil +} + +// -------------------------------------------------------------------- + +// Lock represents an obtained, distributed lock. +type Lock struct { + *Client + key string + value string + tokenLen int +} + +// Obtain is a short-cut for New(...).Obtain(...). +func Obtain(ctx context.Context, client RedisClient, key string, ttl time.Duration, opt *Options) (*Lock, error) { + return New(client).Obtain(ctx, key, ttl, opt) +} + +// Key returns the redis key used by the lock. +func (l *Lock) Key() string { + return l.key +} + +// Token returns the token value set by the lock. +func (l *Lock) Token() string { + return l.value[:l.tokenLen] +} + +// Metadata returns the metadata of the lock. +func (l *Lock) Metadata() string { + return l.value[l.tokenLen:] +} + +// TTL returns the remaining time-to-live. Returns 0 if the lock has expired. +func (l *Lock) TTL(ctx context.Context) (time.Duration, error) { + log.Ctx(ctx).Debug().Msgf("Stefan: redislock.TTL(): key=%v, value=%v", l.key, l.value) + res, err := luaPTTL.Run(ctx, l.client, []string{l.key}, l.value).Result() + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("Stefan: redislock.TTL(): luaPTTL returned an error") + } + log.Ctx(ctx).Debug().Msgf("Stefan: redislock.TTL(): res=%v", res) + if errors.Is(err, redis.Nil) { + return 0, nil + } else if err != nil { + return 0, err + } + + if num := res.(int64); num > 0 { + return time.Duration(num) * time.Millisecond, nil + } + return 0, nil +} + +// Refresh extends the lock with a new TTL. +// May return ErrNotObtained if refresh is unsuccessful. +func (l *Lock) Refresh(ctx context.Context, ttl time.Duration, opt *Options) error { + ttlVal := strconv.FormatInt(int64(ttl/time.Millisecond), 10) + status, err := luaRefresh.Run(ctx, l.client, []string{l.key}, l.value, ttlVal).Result() + if err != nil { + return err + } else if status == int64(1) { + return nil + } + return ErrNotObtained +} + +// Release manually releases the lock. +// May return ErrLockNotHeld. +func (l *Lock) Release(ctx context.Context) error { + if l == nil { + return ErrLockNotHeld + } + + res, err := luaRelease.Run(ctx, l.client, []string{l.key}, l.value).Result() + if errors.Is(err, redis.Nil) { + return ErrLockNotHeld + } else if err != nil { + return err + } + + if i, ok := res.(int64); !ok || i != 1 { + return ErrLockNotHeld + } + return nil +} + +// -------------------------------------------------------------------- + +// Options describe the options for the lock +type Options struct { + // RetryStrategy allows to customise the lock retry strategy. + // Default: do not retry + RetryStrategy RetryStrategy + + // Metadata string. + Metadata string + + // Token is a unique value that is used to identify the lock. By default, a random tokens are generated. Use this + // option to provide a custom token instead. + Token string +} + +func (o *Options) getMetadata() string { + if o != nil { + return o.Metadata + } + return "" +} + +func (o *Options) getToken() string { + if o != nil { + return o.Token + } + return "" +} + +func (o *Options) getRetryStrategy() RetryStrategy { + if o != nil && o.RetryStrategy != nil { + return o.RetryStrategy + } + return NoRetry() +} + +// -------------------------------------------------------------------- + +// RetryStrategy allows to customise the lock retry strategy. +type RetryStrategy interface { + // NextBackoff returns the next backoff duration. + NextBackoff() time.Duration +} + +type linearBackoff time.Duration + +// LinearBackoff allows retries regularly with customized intervals +func LinearBackoff(backoff time.Duration) RetryStrategy { + return linearBackoff(backoff) +} + +// NoRetry acquire the lock only once. +func NoRetry() RetryStrategy { + return linearBackoff(0) +} + +func (r linearBackoff) NextBackoff() time.Duration { + return time.Duration(r) +} + +type limitedRetry struct { + s RetryStrategy + cnt int64 + max int64 +} + +// LimitRetry limits the number of retries to max attempts. +func LimitRetry(s RetryStrategy, max int) RetryStrategy { + return &limitedRetry{s: s, max: int64(max)} +} + +func (r *limitedRetry) NextBackoff() time.Duration { + if atomic.LoadInt64(&r.cnt) >= r.max { + return 0 + } + atomic.AddInt64(&r.cnt, 1) + return r.s.NextBackoff() +} + +type exponentialBackoff struct { + cnt uint64 + + min, max time.Duration +} + +// ExponentialBackoff strategy is an optimization strategy with a retry time of 2**n milliseconds (n means number of times). +// You can set a minimum and maximum value, the recommended minimum value is not less than 16ms. +func ExponentialBackoff(min, max time.Duration) RetryStrategy { + return &exponentialBackoff{min: min, max: max} +} + +func (r *exponentialBackoff) NextBackoff() time.Duration { + cnt := atomic.AddUint64(&r.cnt, 1) + + ms := 2 << 25 + if cnt < 25 { + ms = 2 << cnt + } + + if d := time.Duration(ms) * time.Millisecond; d < r.min { + return r.min + } else if r.max != 0 && d > r.max { + return r.max + } else { + return d + } +}