Skip to content

Commit

Permalink
Merge pull request #57 from NathanBaulch/master
Browse files Browse the repository at this point in the history
Simplify recently added context support
  • Loading branch information
hjr265 authored Sep 20, 2020
2 parents cb454cf + c3d6845 commit 6c7240c
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 71 deletions.
8 changes: 4 additions & 4 deletions mutex.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (m *Mutex) valid(ctx context.Context, pool redis.Pool) (bool, error) {
return false, err
}
defer conn.Close()
reply, err := conn.Get(ctx, m.name)
reply, err := conn.Get(m.name)
if err != nil {
return false, err
}
Expand All @@ -144,7 +144,7 @@ func (m *Mutex) acquire(ctx context.Context, pool redis.Pool, value string) (boo
return false, err
}
defer conn.Close()
reply, err := conn.SetNX(ctx, m.name, value, m.expiry)
reply, err := conn.SetNX(m.name, value, m.expiry)
if err != nil {
return false, err
}
Expand All @@ -165,7 +165,7 @@ func (m *Mutex) release(ctx context.Context, pool redis.Pool, value string) (boo
return false, err
}
defer conn.Close()
status, err := conn.Eval(ctx, deleteScript, m.name, value)
status, err := conn.Eval(deleteScript, m.name, value)
if err != nil {
return false, err
}
Expand All @@ -186,7 +186,7 @@ func (m *Mutex) touch(ctx context.Context, pool redis.Pool, value string, expiry
return false, err
}
defer conn.Close()
status, err := conn.Eval(ctx, touchScript, m.name, value, expiry)
status, err := conn.Eval(touchScript, m.name, value, expiry)
if err != nil {
return false, err
}
Expand Down
13 changes: 6 additions & 7 deletions mutex_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package redsync

import (
"context"
"strconv"
"testing"
"time"
Expand Down Expand Up @@ -127,11 +126,11 @@ func TestValid(t *testing.T) {
func getPoolValues(pools []redis.Pool, name string) []string {
values := make([]string, len(pools))
for i, pool := range pools {
conn, err := pool.Get(context.TODO())
conn, err := pool.Get(nil)
if err != nil {
panic(err)
}
value, err := conn.Get(context.TODO(), name)
value, err := conn.Get(name)
if err != nil {
panic(err)
}
Expand All @@ -144,11 +143,11 @@ func getPoolValues(pools []redis.Pool, name string) []string {
func getPoolExpiries(pools []redis.Pool, name string) []int {
expiries := make([]int, len(pools))
for i, pool := range pools {
conn, err := pool.Get(context.TODO())
conn, err := pool.Get(nil)
if err != nil {
panic(err)
}
expiry, err := conn.PTTL(context.TODO(), name)
expiry, err := conn.PTTL(name)
if err != nil {
panic(err)
}
Expand All @@ -165,11 +164,11 @@ func clogPools(pools []redis.Pool, mask int, mutex *Mutex) int {
n++
continue
}
conn, err := pool.Get(context.TODO())
conn, err := pool.Get(nil)
if err != nil {
panic(err)
}
_, err = conn.Set(context.TODO(), mutex.name, "foobar")
_, err = conn.Set(mutex.name, "foobar")
if err != nil {
panic(err)
}
Expand Down
36 changes: 16 additions & 20 deletions redis/goredis/goredis.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ type pool struct {
}

func (p *pool) Get(ctx context.Context) (redsyncredis.Conn, error) {
return &conn{p.delegate}, nil
c := p.delegate
if ctx != nil {
c = c.WithContext(ctx)
}
return &conn{c}, nil
}

func NewPool(delegate *redis.Client) redsyncredis.Pool {
Expand All @@ -25,27 +29,27 @@ type conn struct {
delegate *redis.Client
}

func (c *conn) Get(ctx context.Context, name string) (string, error) {
value, err := c.client(ctx).Get(name).Result()
func (c *conn) Get(name string) (string, error) {
value, err := c.delegate.Get(name).Result()
return value, noErrNil(err)
}

func (c *conn) Set(ctx context.Context, name string, value string) (bool, error) {
reply, err := c.client(ctx).Set(name, value, 0).Result()
func (c *conn) Set(name string, value string) (bool, error) {
reply, err := c.delegate.Set(name, value, 0).Result()
return reply == "OK", noErrNil(err)
}

func (c *conn) SetNX(ctx context.Context, name string, value string, expiry time.Duration) (bool, error) {
ok, err := c.client(ctx).SetNX(name, value, expiry).Result()
func (c *conn) SetNX(name string, value string, expiry time.Duration) (bool, error) {
ok, err := c.delegate.SetNX(name, value, expiry).Result()
return ok, noErrNil(err)
}

func (c *conn) PTTL(ctx context.Context, name string) (time.Duration, error) {
expiry, err := c.client(ctx).PTTL(name).Result()
func (c *conn) PTTL(name string) (time.Duration, error) {
expiry, err := c.delegate.PTTL(name).Result()
return expiry, noErrNil(err)
}

func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) {
func (c *conn) Eval(script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) {
keys := make([]string, script.KeyCount)
args := keysAndArgs

Expand All @@ -57,10 +61,9 @@ func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArg
args = keysAndArgs[script.KeyCount:]
}

cli := c.client(ctx)
v, err := cli.EvalSha(script.Hash, keys, args...).Result()
v, err := c.delegate.EvalSha(script.Hash, keys, args...).Result()
if err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") {
v, err = cli.Eval(script.Src, keys, args...).Result()
v, err = c.delegate.Eval(script.Src, keys, args...).Result()
}
return v, noErrNil(err)
}
Expand All @@ -70,13 +73,6 @@ func (c *conn) Close() error {
return nil
}

func (c *conn) client(ctx context.Context) *redis.Client {
if ctx != nil {
return c.delegate.WithContext(ctx)
}
return c.delegate
}

func noErrNil(err error) error {
if err == redis.Nil {
return nil
Expand Down
29 changes: 16 additions & 13 deletions redis/goredis/v7/goredis.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ type pool struct {
}

func (p *pool) Get(ctx context.Context) (redsyncredis.Conn, error) {
return &conn{p.delegate}, nil
c := p.delegate
if ctx != nil {
c = c.WithContext(ctx)
}
return &conn{c}, nil
}

func NewPool(delegate *redis.Client) redsyncredis.Pool {
Expand All @@ -25,25 +29,25 @@ type conn struct {
delegate *redis.Client
}

func (c *conn) Get(ctx context.Context, name string) (string, error) {
value, err := c.client(ctx).Get(name).Result()
func (c *conn) Get(name string) (string, error) {
value, err := c.delegate.Get(name).Result()
return value, noErrNil(err)
}

func (c *conn) Set(ctx context.Context, name string, value string) (bool, error) {
reply, err := c.client(ctx).Set(name, value, 0).Result()
func (c *conn) Set(name string, value string) (bool, error) {
reply, err := c.delegate.Set(name, value, 0).Result()
return reply == "OK", err
}

func (c *conn) SetNX(ctx context.Context, name string, value string, expiry time.Duration) (bool, error) {
return c.client(ctx).SetNX(name, value, expiry).Result()
func (c *conn) SetNX(name string, value string, expiry time.Duration) (bool, error) {
return c.delegate.SetNX(name, value, expiry).Result()
}

func (c *conn) PTTL(ctx context.Context, name string) (time.Duration, error) {
return c.client(ctx).PTTL(name).Result()
func (c *conn) PTTL(name string) (time.Duration, error) {
return c.delegate.PTTL(name).Result()
}

func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) {
func (c *conn) Eval(script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) {
keys := make([]string, script.KeyCount)
args := keysAndArgs

Expand All @@ -54,10 +58,9 @@ func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArg
args = keysAndArgs[script.KeyCount:]
}

cli := c.client(ctx)
v, err := cli.EvalSha(script.Hash, keys, args...).Result()
v, err := c.delegate.EvalSha(script.Hash, keys, args...).Result()
if err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") {
v, err = cli.Eval(script.Src, keys, args...).Result()
v, err = c.delegate.Eval(script.Src, keys, args...).Result()
}
return v, noErrNil(err)
}
Expand Down
29 changes: 16 additions & 13 deletions redis/goredis/v8/goredis.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ type pool struct {
}

func (p *pool) Get(ctx context.Context) (redsyncredis.Conn, error) {
return &conn{p.delegate}, nil
if ctx == nil {
ctx = p.delegate.Context()
}
return &conn{p.delegate, ctx}, nil
}

func NewPool(delegate *redis.Client) redsyncredis.Pool {
Expand All @@ -23,27 +26,28 @@ func NewPool(delegate *redis.Client) redsyncredis.Pool {

type conn struct {
delegate *redis.Client
ctx context.Context
}

func (c *conn) Get(ctx context.Context, name string) (string, error) {
value, err := c.delegate.Get(c._context(ctx), name).Result()
func (c *conn) Get(name string) (string, error) {
value, err := c.delegate.Get(c.ctx, name).Result()
return value, noErrNil(err)
}

func (c *conn) Set(ctx context.Context, name string, value string) (bool, error) {
reply, err := c.delegate.Set(c._context(ctx), name, value, 0).Result()
func (c *conn) Set(name string, value string) (bool, error) {
reply, err := c.delegate.Set(c.ctx, name, value, 0).Result()
return reply == "OK", err
}

func (c *conn) SetNX(ctx context.Context, name string, value string, expiry time.Duration) (bool, error) {
return c.delegate.SetNX(c._context(ctx), name, value, expiry).Result()
func (c *conn) SetNX(name string, value string, expiry time.Duration) (bool, error) {
return c.delegate.SetNX(c.ctx, name, value, expiry).Result()
}

func (c *conn) PTTL(ctx context.Context, name string) (time.Duration, error) {
return c.delegate.PTTL(c._context(ctx), name).Result()
func (c *conn) PTTL(name string) (time.Duration, error) {
return c.delegate.PTTL(c.ctx, name).Result()
}

func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) {
func (c *conn) Eval(script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) {
keys := make([]string, script.KeyCount)
args := keysAndArgs

Expand All @@ -54,10 +58,9 @@ func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArg
args = keysAndArgs[script.KeyCount:]
}

ctx = c._context(ctx)
v, err := c.delegate.EvalSha(ctx, script.Hash, keys, args...).Result()
v, err := c.delegate.EvalSha(c.ctx, script.Hash, keys, args...).Result()
if err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") {
v, err = c.delegate.Eval(ctx, script.Src, keys, args...).Result()
v, err = c.delegate.Eval(c.ctx, script.Src, keys, args...).Result()
}
return v, noErrNil(err)
}
Expand Down
21 changes: 12 additions & 9 deletions redis/redigo/redigo.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ type pool struct {
}

func (p *pool) Get(ctx context.Context) (redsyncredis.Conn, error) {
c, err := p.delegate.GetContext(ctx)
if err != nil {
return nil, err
if ctx != nil {
c, err := p.delegate.GetContext(ctx)
if err != nil {
return nil, err
}
return &conn{c}, nil
}
return &conn{c}, nil
return &conn{p.delegate.Get()}, nil
}

func NewPool(delegate *redis.Pool) redsyncredis.Pool {
Expand All @@ -29,27 +32,27 @@ type conn struct {
delegate redis.Conn
}

func (c *conn) Get(_ context.Context, name string) (string, error) {
func (c *conn) Get(name string) (string, error) {
value, err := redis.String(c.delegate.Do("GET", name))
return value, noErrNil(err)
}

func (c *conn) Set(_ context.Context, name string, value string) (bool, error) {
func (c *conn) Set(name string, value string) (bool, error) {
reply, err := redis.String(c.delegate.Do("SET", name, value))
return reply == "OK", noErrNil(err)
}

func (c *conn) SetNX(_ context.Context, name string, value string, expiry time.Duration) (bool, error) {
func (c *conn) SetNX(name string, value string, expiry time.Duration) (bool, error) {
reply, err := redis.String(c.delegate.Do("SET", name, value, "NX", "PX", int(expiry/time.Millisecond)))
return reply == "OK", noErrNil(err)
}

func (c *conn) PTTL(_ context.Context, name string) (time.Duration, error) {
func (c *conn) PTTL(name string) (time.Duration, error) {
expiry, err := redis.Int64(c.delegate.Do("PTTL", name))
return time.Duration(expiry) * time.Millisecond, noErrNil(err)
}

func (c *conn) Eval(_ context.Context, script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) {
func (c *conn) Eval(script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) {
v, err := c.delegate.Do("EVALSHA", args(script, script.Hash, keysAndArgs)...)
if e, ok := err.(redis.Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") {
v, err = c.delegate.Do("EVAL", args(script, script.Src, keysAndArgs)...)
Expand Down
10 changes: 5 additions & 5 deletions redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ type Pool interface {
}

type Conn interface {
Get(ctx context.Context, name string) (string, error)
Set(ctx context.Context, name string, value string) (bool, error)
SetNX(ctx context.Context, name string, value string, expiry time.Duration) (bool, error)
Eval(ctx context.Context, script *Script, keysAndArgs ...interface{}) (interface{}, error)
PTTL(ctx context.Context, name string) (time.Duration, error)
Get(name string) (string, error)
Set(name string, value string) (bool, error)
SetNX(name string, value string, expiry time.Duration) (bool, error)
Eval(script *Script, keysAndArgs ...interface{}) (interface{}, error)
PTTL(name string) (time.Duration, error)
Close() error
}

Expand Down

0 comments on commit 6c7240c

Please sign in to comment.