diff --git a/cmd/otpgateway/main.go b/cmd/otpgateway/main.go index 3c87666..248fedf 100644 --- a/cmd/otpgateway/main.go +++ b/cmd/otpgateway/main.go @@ -58,6 +58,11 @@ func main() { ko.UnmarshalWithConf("store.redis", &rc, koanf.UnmarshalConf{Tag: "json"}) app.store = redis.New(rc) + // Check if the Redis server is available by sending a Ping. + if err := app.store.Ping(); err != nil { + log.Fatalf("failed to connect to redis: %v", err) + } + // Compile static templates. tpl, err := stuffbin.ParseTemplatesGlob(nil, app.fs, "/static/*.html") if err != nil { diff --git a/internal/store/redis/redis.go b/internal/store/redis/redis.go index 5fe4105..8276f0f 100644 --- a/internal/store/redis/redis.go +++ b/internal/store/redis/redis.go @@ -1,21 +1,26 @@ package redis import ( + "context" "encoding/json" "fmt" "time" - redigo "github.com/gomodule/redigo/redis" "github.com/knadh/otpgateway/v3/internal/store" "github.com/knadh/otpgateway/v3/pkg/models" + "github.com/redis/go-redis/v9" ) -// Redis implements a Redis Store. +// Redis implements a Redis Store. type Redis struct { - pool *redigo.Pool - conf Conf + client *redis.Client + conf Conf } +var ( + ctx = context.Background() +) + // Conf contains Redis configuration fields. type Conf struct { Host string `json:"host"` @@ -27,7 +32,6 @@ type Conf struct { MaxIdle int `json:"max_idle"` Timeout time.Duration `json:"timeout"` KeyPrefix string `json:"key_prefix"` - // If this is set, 'check' and 'close' events will be PUBLISHed to // to this Redis key (Redis PubSub). PublishKey string `json:"publish_key"` @@ -45,45 +49,33 @@ func New(c Conf) *Redis { if c.KeyPrefix == "" { c.KeyPrefix = "OTP" } - pool := &redigo.Pool{ - Wait: true, - MaxActive: c.MaxActive, - MaxIdle: c.MaxIdle, - Dial: func() (redigo.Conn, error) { - c, err := redigo.Dial( - "tcp", - fmt.Sprintf("%s:%d", c.Host, c.Port), - redigo.DialPassword(c.Password), - redigo.DialConnectTimeout(c.Timeout), - redigo.DialReadTimeout(c.Timeout), - redigo.DialWriteTimeout(c.Timeout), - redigo.DialDatabase(c.DB), - ) - - return c, err - }, - } + + client := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", c.Host, c.Port), + Username: c.Username, + Password: c.Password, + DB: c.DB, + DialTimeout: c.Timeout, + WriteTimeout: c.Timeout, + ReadTimeout: c.Timeout, + }) + return &Redis{ - conf: c, - pool: pool, + conf: c, + client: client, } } // Ping checks if Redis server is reachable func (r *Redis) Ping() error { - c := r.pool.Get() - defer c.Close() - _, err := c.Do("PING") // Test redis connection - return err + return r.client.Ping(ctx).Err() } // Check checks the attempt count and TTL duration against an ID. // Passing count=true increments the attempt counter. func (r *Redis) Check(namespace, id string, counter bool) (models.OTP, error) { - c := r.pool.Get() - defer c.Close() - - out, err := r.get(namespace, id, c) + // Retrieve the OTP information. + out, err := r.get(namespace, id) if err != nil { return out, err } @@ -91,23 +83,22 @@ func (r *Redis) Check(namespace, id string, counter bool) (models.OTP, error) { return out, err } - // Increment attempts. + // Define the key.... key := r.makeKey(namespace, id) - r.begin(c) - c.Send("HINCRBY", key, "attempts", 1) - c.Send("TTL", key) - resp, err := r.end(c) + + // Increment attempts and get TTL. + pipe := r.client.TxPipeline() + attempts := pipe.HIncrBy(ctx, key, "attempts", 1) + ttl := pipe.TTL(ctx, key) + _, err = pipe.Exec(ctx) if err != nil { return out, err } - attempts, _ := redigo.Int(resp[0], nil) - out.Attempts = attempts - - ttl, _ := redigo.Int64(resp[1], nil) - out.TTL = time.Duration(ttl) * time.Second + out.Attempts = int(attempts.Val()) + out.TTL = ttl.Val() - // Publish? + // If there's a configured PublishKey, publish the event. if r.conf.PublishKey != "" { b, _ := json.Marshal(out) e, _ := json.Marshal(event{ @@ -116,63 +107,67 @@ func (r *Redis) Check(namespace, id string, counter bool) (models.OTP, error) { ID: id, Data: json.RawMessage(b), }) - _, _ = c.Do("PUBLISH", r.conf.PublishKey, e) + err := r.client.Publish(ctx, r.conf.PublishKey, e).Err() + if err != nil { + return out, err + } } - return out, err + return out, nil } -// Set sets an OTP in the store. func (r *Redis) Set(namespace, id string, otp models.OTP) (models.OTP, error) { - c := r.pool.Get() - defer c.Close() - // Set the OTP value. - var ( - key = r.makeKey(namespace, id) - exp = otp.TTL.Nanoseconds() / int64(time.Millisecond) - ) - - r.begin(c) - c.Send("HMSET", key, - "otp", otp.OTP, - "to", otp.To, - "channel_description", otp.ChannelDesc, - "address_description", otp.AddressDesc, - "extra", string(otp.Extra), - "provider", otp.Provider, - "closed", false, - "max_attempts", otp.MaxAttempts) - c.Send("HINCRBY", key, "attempts", 1) - c.Send("PEXPIRE", key, exp) - - // Flush the commands and get their responses. - // [1] is the number of attempts. - // [3] is the TTL. - resp, err := r.end(c) + key := r.makeKey(namespace, id) + exp := otp.TTL.Milliseconds() + + // Create a transaction to execute commands atomically. + txf := func(tx *redis.Tx) error { + _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.HMSet(ctx, key, + "otp", otp.OTP, + "to", otp.To, + "channel_description", otp.ChannelDesc, + "address_description", otp.AddressDesc, + "extra", string(otp.Extra), + "provider", otp.Provider, + "closed", false, + "max_attempts", otp.MaxAttempts) + + pipe.HIncrBy(ctx, key, "attempts", 1) + pipe.PExpire(ctx, key, time.Duration(exp)*time.Millisecond) + return nil + }) + return err + } + + // Watch the key for changes. If the key is modified externally between + // the time of watch and the transaction execution, the transaction will be aborted. + err := r.client.Watch(ctx, txf, key) if err != nil { return otp, err } - attempts, err := redigo.Int(resp[1], nil) + + // Retrieve the updated attempts count to update the OTP struct. + attempts, err := r.client.HGet(ctx, key, "attempts").Int() if err != nil { return otp, err } + otp.Attempts = attempts otp.TTLSeconds = otp.TTL.Seconds() otp.Namespace = namespace otp.ID = id + return otp, nil } // SetAddress sets (updates) the address on an existing OTP. func (r *Redis) SetAddress(namespace, id, address string) error { - c := r.pool.Get() - defer c.Close() - // Set the OTP value. - var key = r.makeKey(namespace, id) + key := r.makeKey(namespace, id) - if _, err := c.Do("HSET", key, "to", address); err != nil { + if err := r.client.HSet(ctx, key, "to", address).Err(); err != nil { return err } @@ -182,10 +177,10 @@ func (r *Redis) SetAddress(namespace, id, address string) error { // Close closes an OTP and marks it as done (verified). // After this, the OTP has to expire after a TTL or be deleted. func (r *Redis) Close(namespace, id string) error { - c := r.pool.Get() - defer c.Close() - - _, err := c.Do("HSET", r.makeKey(namespace, id), "closed", true) + // Set the OTP as closed. + if err := r.client.HSet(ctx, r.makeKey(namespace, id), "closed", true).Err(); err != nil { + return err + } // Publish? if r.conf.PublishKey != "" { @@ -195,36 +190,37 @@ func (r *Redis) Close(namespace, id string) error { ID: id, Data: json.RawMessage([]byte(`null`)), }) - _, _ = c.Do("PUBLISH", r.conf.PublishKey, e) + if err := r.client.Publish(ctx, r.conf.PublishKey, e).Err(); err != nil { + return err + } } - return err + return nil } // Delete deletes the OTP saved against a given ID. func (r *Redis) Delete(namespace, id string) error { - c := r.pool.Get() - defer c.Close() - - _, err := c.Do("DEL", r.makeKey(namespace, id)) - return err + if err := r.client.Del(ctx, r.makeKey(namespace, id)).Err(); err != nil { + return err + } + return nil } -// get begins a transaction. -func (r *Redis) get(namespace, id string, c redigo.Conn) (models.OTP, error) { - var ( - key = r.makeKey(namespace, id) - out = models.OTP{ - Namespace: namespace, - ID: id, - } - ) +// makeKey makes the Redis key for the OTP. +func (r *Redis) makeKey(namespace, id string) string { + return fmt.Sprintf("%s:%s:%s", r.conf.KeyPrefix, namespace, id) +} - resp, err := redigo.Values(c.Do("HGETALL", key)) - if err != nil { - return out, err +// get retrieves the OTP information from Redis based on the namespace and ID. +func (r *Redis) get(namespace, id string) (models.OTP, error) { + key := r.makeKey(namespace, id) + out := models.OTP{ + Namespace: namespace, + ID: id, } - if err := redigo.ScanStruct(resp, &out); err != nil { + + // Retrieve all fields of the hash. + if err := r.client.HGetAll(ctx, key).Scan(&out); err != nil { return out, err } @@ -233,35 +229,13 @@ func (r *Redis) get(namespace, id string, c redigo.Conn) (models.OTP, error) { return out, store.ErrNotExist } - ttl, err := redigo.Int64(c.Do("TTL", key)) + // Retrieve TTL. + ttl, err := r.client.TTL(ctx, key).Result() if err != nil { return out, err } - out.TTL = time.Duration(ttl) * time.Second - out.TTLSeconds = out.TTL.Seconds() + out.TTL = ttl + out.TTLSeconds = ttl.Seconds() return out, nil } - -// begin begins a transaction. -func (r *Redis) begin(c redigo.Conn) error { - return c.Send("MULTI") -} - -// end begins a transaction. -func (r *Redis) end(c redigo.Conn) ([]interface{}, error) { - rep, err := redigo.Values(c.Do("EXEC")) - - // Check if there are any errors. - for _, r := range rep { - if v, ok := r.(redigo.Error); ok { - return rep, v - } - } - return rep, err -} - -// makeKey makes the Redis key for the OTP. -func (r *Redis) makeKey(namespace, id string) string { - return fmt.Sprintf("%s:%s:%s", r.conf.KeyPrefix, namespace, id) -} diff --git a/internal/store/redis/redis_test.go b/internal/store/redis/redis_test.go index a054a5b..4a84b75 100644 --- a/internal/store/redis/redis_test.go +++ b/internal/store/redis/redis_test.go @@ -10,6 +10,7 @@ import ( "github.com/knadh/otpgateway/v3/internal/store" "github.com/knadh/otpgateway/v3/pkg/models" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -43,63 +44,77 @@ func init() { }) } -func reset(t *testing.T) { +func setup(t *testing.T) *Redis { rdis.FlushDB() _, err := rStore.Set(mockOTP.Namespace, mockOTP.ID, mockOTP) - assert.Equal(t, nil, err, "error setting OTP") + require.NoError(t, err, "Failed to set up test OTP") + + t.Cleanup(func() { + rdis.FlushDB() + }) + + return rStore } func TestStoreSet(t *testing.T) { - rdis.FlushDB() + rStore := setup(t) + resp, err := rStore.Set(mockOTP.Namespace, mockOTP.ID, mockOTP) - assert.Equal(t, nil, err, "error setting OTP") + assert.NoError(t, err, "Error setting OTP") cmp := mockOTP // Override dynamic values. cmp.Attempts = resp.Attempts cmp.TTL = resp.TTL cmp.TTLSeconds = resp.TTLSeconds - assert.Equal(t, cmp, resp, "OTP doesn't match") + assert.Equal(t, cmp, resp, "Returned OTP doesn't match expected OTP") } func TestStoreCheck(t *testing.T) { - reset(t) + rStore := setup(t) - // Don't increment. - o, _ := rStore.Check(mockOTP.Namespace, mockOTP.ID, false) - assert.Equal(t, 1, o.Attempts, "attempts incorrectly incremented") + t.Run("no increment", func(t *testing.T) { + o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, false) + assert.NoError(t, err, "Error checking OTP without increment") + assert.Equal(t, 1, o.Attempts, "Unexpected attempt count") + }) - // Increment. - o, _ = rStore.Check(mockOTP.Namespace, mockOTP.ID, true) - assert.Equal(t, 2, o.Attempts, "attempts didn't increment") + t.Run("with increment", func(t *testing.T) { + o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, true) + assert.NoError(t, err, "Error checking OTP with increment") + assert.Equal(t, 2, o.Attempts, "Unexpected attempt count after first increment") - o, _ = rStore.Check(mockOTP.Namespace, mockOTP.ID, true) - assert.Equal(t, 3, o.Attempts, "attempts didn't increment") + o, err = rStore.Check(mockOTP.Namespace, mockOTP.ID, true) + assert.NoError(t, err, "Error checking OTP with second increment") + assert.Equal(t, 3, o.Attempts, "Unexpected attempt count after second increment") + }) } func TestStoreTTL(t *testing.T) { - reset(t) + rStore := setup(t) - // Check if the OTP has expired. o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, false) - assert.Equal(t, nil, err, "error checking OTP") - assert.Equal(t, mockOTP.TTL, o.TTL, "TTL doesn't match") + assert.NoError(t, err, "Error checking OTP") + assert.Equal(t, mockOTP.TTL, o.TTL, "Returned OTP TTL doesn't match expected TTL") } + func TestStoreClose(t *testing.T) { - reset(t) + rStore := setup(t) err := rStore.Close(mockOTP.Namespace, mockOTP.ID) - assert.Equal(t, nil, err, "error closing OTP") + assert.NoError(t, err, "Error closing OTP") o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, false) - assert.Equal(t, true, o.Closed, "OTP didn't close") + assert.NoError(t, err, "Error checking closed OTP") + assert.True(t, o.Closed, "OTP should be closed but isn't") } + func TestStoreDelete(t *testing.T) { - reset(t) + rStore := setup(t) err := rStore.Delete(mockOTP.Namespace, mockOTP.ID) - assert.Equal(t, nil, err, "error deleting OTP") + assert.NoError(t, err, "Error deleting OTP") _, err = rStore.Check(mockOTP.Namespace, mockOTP.ID, false) - assert.Equal(t, store.ErrNotExist, err, "OTP wasn't deleted") + assert.Equal(t, store.ErrNotExist, err, "OTP should not exist but it does") }