Skip to content

Commit

Permalink
feat: Switch to go-redis (#37)
Browse files Browse the repository at this point in the history
* feat: fatal if redis is unreachable
* feat: replace redigo with go-redis
  • Loading branch information
mr-karan authored Sep 20, 2023
1 parent 7574dc8 commit 5ea3383
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 152 deletions.
5 changes: 5 additions & 0 deletions cmd/otpgateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ func main() {
ko.UnmarshalWithConf("store.redis", &rc, koanf.UnmarshalConf{Tag: "json"})
app.store = redis.New(rc)

// Check if the Redis server is available by sending a Ping.
if err := app.store.Ping(); err != nil {
log.Fatalf("failed to connect to redis: %v", err)
}

// Compile static templates.
tpl, err := stuffbin.ParseTemplatesGlob(nil, app.fs, "/static/*.html")
if err != nil {
Expand Down
230 changes: 102 additions & 128 deletions internal/store/redis/redis.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
package redis

import (
"context"
"encoding/json"
"fmt"
"time"

redigo "github.com/gomodule/redigo/redis"
"github.com/knadh/otpgateway/v3/internal/store"
"github.com/knadh/otpgateway/v3/pkg/models"
"github.com/redis/go-redis/v9"
)

// Redis implements a Redis Store.
// Redis implements a Redis Store.
type Redis struct {
pool *redigo.Pool
conf Conf
client *redis.Client
conf Conf
}

var (
ctx = context.Background()
)

// Conf contains Redis configuration fields.
type Conf struct {
Host string `json:"host"`
Expand All @@ -27,7 +32,6 @@ type Conf struct {
MaxIdle int `json:"max_idle"`
Timeout time.Duration `json:"timeout"`
KeyPrefix string `json:"key_prefix"`

// If this is set, 'check' and 'close' events will be PUBLISHed to
// to this Redis key (Redis PubSub).
PublishKey string `json:"publish_key"`
Expand All @@ -45,69 +49,56 @@ func New(c Conf) *Redis {
if c.KeyPrefix == "" {
c.KeyPrefix = "OTP"
}
pool := &redigo.Pool{
Wait: true,
MaxActive: c.MaxActive,
MaxIdle: c.MaxIdle,
Dial: func() (redigo.Conn, error) {
c, err := redigo.Dial(
"tcp",
fmt.Sprintf("%s:%d", c.Host, c.Port),
redigo.DialPassword(c.Password),
redigo.DialConnectTimeout(c.Timeout),
redigo.DialReadTimeout(c.Timeout),
redigo.DialWriteTimeout(c.Timeout),
redigo.DialDatabase(c.DB),
)

return c, err
},
}

client := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", c.Host, c.Port),
Username: c.Username,
Password: c.Password,
DB: c.DB,
DialTimeout: c.Timeout,
WriteTimeout: c.Timeout,
ReadTimeout: c.Timeout,
})

return &Redis{
conf: c,
pool: pool,
conf: c,
client: client,
}
}

// Ping checks if Redis server is reachable
func (r *Redis) Ping() error {
c := r.pool.Get()
defer c.Close()
_, err := c.Do("PING") // Test redis connection
return err
return r.client.Ping(ctx).Err()
}

// Check checks the attempt count and TTL duration against an ID.
// Passing count=true increments the attempt counter.
func (r *Redis) Check(namespace, id string, counter bool) (models.OTP, error) {
c := r.pool.Get()
defer c.Close()

out, err := r.get(namespace, id, c)
// Retrieve the OTP information.
out, err := r.get(namespace, id)
if err != nil {
return out, err
}
if !counter {
return out, err
}

// Increment attempts.
// Define the key....
key := r.makeKey(namespace, id)
r.begin(c)
c.Send("HINCRBY", key, "attempts", 1)
c.Send("TTL", key)
resp, err := r.end(c)

// Increment attempts and get TTL.
pipe := r.client.TxPipeline()
attempts := pipe.HIncrBy(ctx, key, "attempts", 1)
ttl := pipe.TTL(ctx, key)
_, err = pipe.Exec(ctx)
if err != nil {
return out, err
}

attempts, _ := redigo.Int(resp[0], nil)
out.Attempts = attempts

ttl, _ := redigo.Int64(resp[1], nil)
out.TTL = time.Duration(ttl) * time.Second
out.Attempts = int(attempts.Val())
out.TTL = ttl.Val()

// Publish?
// If there's a configured PublishKey, publish the event.
if r.conf.PublishKey != "" {
b, _ := json.Marshal(out)
e, _ := json.Marshal(event{
Expand All @@ -116,63 +107,67 @@ func (r *Redis) Check(namespace, id string, counter bool) (models.OTP, error) {
ID: id,
Data: json.RawMessage(b),
})
_, _ = c.Do("PUBLISH", r.conf.PublishKey, e)
err := r.client.Publish(ctx, r.conf.PublishKey, e).Err()
if err != nil {
return out, err
}
}

return out, err
return out, nil
}

// Set sets an OTP in the store.
func (r *Redis) Set(namespace, id string, otp models.OTP) (models.OTP, error) {
c := r.pool.Get()
defer c.Close()

// Set the OTP value.
var (
key = r.makeKey(namespace, id)
exp = otp.TTL.Nanoseconds() / int64(time.Millisecond)
)

r.begin(c)
c.Send("HMSET", key,
"otp", otp.OTP,
"to", otp.To,
"channel_description", otp.ChannelDesc,
"address_description", otp.AddressDesc,
"extra", string(otp.Extra),
"provider", otp.Provider,
"closed", false,
"max_attempts", otp.MaxAttempts)
c.Send("HINCRBY", key, "attempts", 1)
c.Send("PEXPIRE", key, exp)

// Flush the commands and get their responses.
// [1] is the number of attempts.
// [3] is the TTL.
resp, err := r.end(c)
key := r.makeKey(namespace, id)
exp := otp.TTL.Milliseconds()

// Create a transaction to execute commands atomically.
txf := func(tx *redis.Tx) error {
_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.HMSet(ctx, key,
"otp", otp.OTP,
"to", otp.To,
"channel_description", otp.ChannelDesc,
"address_description", otp.AddressDesc,
"extra", string(otp.Extra),
"provider", otp.Provider,
"closed", false,
"max_attempts", otp.MaxAttempts)

pipe.HIncrBy(ctx, key, "attempts", 1)
pipe.PExpire(ctx, key, time.Duration(exp)*time.Millisecond)
return nil
})
return err
}

// Watch the key for changes. If the key is modified externally between
// the time of watch and the transaction execution, the transaction will be aborted.
err := r.client.Watch(ctx, txf, key)
if err != nil {
return otp, err
}
attempts, err := redigo.Int(resp[1], nil)

// Retrieve the updated attempts count to update the OTP struct.
attempts, err := r.client.HGet(ctx, key, "attempts").Int()
if err != nil {
return otp, err
}

otp.Attempts = attempts
otp.TTLSeconds = otp.TTL.Seconds()
otp.Namespace = namespace
otp.ID = id

return otp, nil
}

// SetAddress sets (updates) the address on an existing OTP.
func (r *Redis) SetAddress(namespace, id, address string) error {
c := r.pool.Get()
defer c.Close()

// Set the OTP value.
var key = r.makeKey(namespace, id)
key := r.makeKey(namespace, id)

if _, err := c.Do("HSET", key, "to", address); err != nil {
if err := r.client.HSet(ctx, key, "to", address).Err(); err != nil {
return err
}

Expand All @@ -182,10 +177,10 @@ func (r *Redis) SetAddress(namespace, id, address string) error {
// Close closes an OTP and marks it as done (verified).
// After this, the OTP has to expire after a TTL or be deleted.
func (r *Redis) Close(namespace, id string) error {
c := r.pool.Get()
defer c.Close()

_, err := c.Do("HSET", r.makeKey(namespace, id), "closed", true)
// Set the OTP as closed.
if err := r.client.HSet(ctx, r.makeKey(namespace, id), "closed", true).Err(); err != nil {
return err
}

// Publish?
if r.conf.PublishKey != "" {
Expand All @@ -195,36 +190,37 @@ func (r *Redis) Close(namespace, id string) error {
ID: id,
Data: json.RawMessage([]byte(`null`)),
})
_, _ = c.Do("PUBLISH", r.conf.PublishKey, e)
if err := r.client.Publish(ctx, r.conf.PublishKey, e).Err(); err != nil {
return err
}
}

return err
return nil
}

// Delete deletes the OTP saved against a given ID.
func (r *Redis) Delete(namespace, id string) error {
c := r.pool.Get()
defer c.Close()

_, err := c.Do("DEL", r.makeKey(namespace, id))
return err
if err := r.client.Del(ctx, r.makeKey(namespace, id)).Err(); err != nil {
return err
}
return nil
}

// get begins a transaction.
func (r *Redis) get(namespace, id string, c redigo.Conn) (models.OTP, error) {
var (
key = r.makeKey(namespace, id)
out = models.OTP{
Namespace: namespace,
ID: id,
}
)
// makeKey makes the Redis key for the OTP.
func (r *Redis) makeKey(namespace, id string) string {
return fmt.Sprintf("%s:%s:%s", r.conf.KeyPrefix, namespace, id)
}

resp, err := redigo.Values(c.Do("HGETALL", key))
if err != nil {
return out, err
// get retrieves the OTP information from Redis based on the namespace and ID.
func (r *Redis) get(namespace, id string) (models.OTP, error) {
key := r.makeKey(namespace, id)
out := models.OTP{
Namespace: namespace,
ID: id,
}
if err := redigo.ScanStruct(resp, &out); err != nil {

// Retrieve all fields of the hash.
if err := r.client.HGetAll(ctx, key).Scan(&out); err != nil {
return out, err
}

Expand All @@ -233,35 +229,13 @@ func (r *Redis) get(namespace, id string, c redigo.Conn) (models.OTP, error) {
return out, store.ErrNotExist
}

ttl, err := redigo.Int64(c.Do("TTL", key))
// Retrieve TTL.
ttl, err := r.client.TTL(ctx, key).Result()
if err != nil {
return out, err
}

out.TTL = time.Duration(ttl) * time.Second
out.TTLSeconds = out.TTL.Seconds()
out.TTL = ttl
out.TTLSeconds = ttl.Seconds()
return out, nil
}

// begin begins a transaction.
func (r *Redis) begin(c redigo.Conn) error {
return c.Send("MULTI")
}

// end begins a transaction.
func (r *Redis) end(c redigo.Conn) ([]interface{}, error) {
rep, err := redigo.Values(c.Do("EXEC"))

// Check if there are any errors.
for _, r := range rep {
if v, ok := r.(redigo.Error); ok {
return rep, v
}
}
return rep, err
}

// makeKey makes the Redis key for the OTP.
func (r *Redis) makeKey(namespace, id string) string {
return fmt.Sprintf("%s:%s:%s", r.conf.KeyPrefix, namespace, id)
}
Loading

0 comments on commit 5ea3383

Please sign in to comment.