Skip to content

Commit

Permalink
Refactor OTP attempts and add separate counters for generate/validate.
Browse files Browse the repository at this point in the history
  • Loading branch information
shridarpatil authored Jun 7, 2024
1 parent 3e5e12c commit 604d9c9
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 24 deletions.
32 changes: 24 additions & 8 deletions cmd/otpgateway/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func handleSetOTP(w http.ResponseWriter, r *http.Request) {
addressDesc = r.FormValue("address_description")
rawTTL = r.FormValue("ttl")
rawMaxAttempts = r.FormValue("max_attempts")
rawMaxGenerate = r.FormValue("max_generate")
extra = []byte(r.FormValue("extra"))
to = r.FormValue("to")
otpVal = r.FormValue("otp")
Expand Down Expand Up @@ -160,6 +161,16 @@ func handleSetOTP(w http.ResponseWriter, r *http.Request) {
maxAttempts = v
}

maxGenerate := app.constants.OtpMaxGenerate
if rawMaxGenerate != "" {
v, err := strconv.Atoi(rawMaxGenerate)
if err != nil || v < 1 {
sendErrorResponse(w, "Invalid `max_generate` value.", http.StatusBadRequest, nil)
return
}
maxGenerate = v
}

// If there's extra data, make sure it's JSON.
if len(extra) > 0 {
var tmp interface{}
Expand Down Expand Up @@ -194,7 +205,7 @@ func handleSetOTP(w http.ResponseWriter, r *http.Request) {
}

// Check if the OTP attempts have exceeded the quota.
otp, err := app.store.Check(namespace, id, false)
otp, err := app.store.Check(namespace, id, store.CounterNil)
if err != nil && err != store.ErrNotExist {
app.lo.Error("error checking OTP status", "error", err)
sendErrorResponse(w, "Error checking OTP status.", http.StatusBadRequest, nil)
Expand Down Expand Up @@ -224,6 +235,7 @@ func handleSetOTP(w http.ResponseWriter, r *http.Request) {
Provider: provider,
TTL: ttl,
MaxAttempts: maxAttempts,
MaxGenerate: maxGenerate,
})
if err != nil {
app.lo.Error("error setting OTP", "error", err)
Expand Down Expand Up @@ -258,7 +270,7 @@ func handleCheckOTPStatus(w http.ResponseWriter, r *http.Request) {
}

// Check the OTP status.
out, err := app.store.Check(namespace, id, false)
out, err := app.store.Check(namespace, id, store.CounterNil)
if err != nil {
if err == store.ErrNotExist {
sendErrorResponse(w, err.Error(), http.StatusBadRequest, nil)
Expand Down Expand Up @@ -335,10 +347,10 @@ func handleOTPView(w http.ResponseWriter, r *http.Request) {

if action == "" {
// Render the view without incrementing attempts.
out, otpErr = app.store.Check(namespace, id, false)
out, otpErr = app.store.Check(namespace, id, store.CounterNil)
} else if action == actResend {
// Fetch the OTP for resending.
out, otpErr = app.store.Check(namespace, id, true)
out, otpErr = app.store.Check(namespace, id, store.CounterGenerate)
} else {
// Validate the attempt.
out, otpErr = verifyOTP(namespace, id, otp, false, app)
Expand Down Expand Up @@ -425,7 +437,7 @@ func handleGetOTPClosed(w http.ResponseWriter, r *http.Request) {
id = chi.URLParam(r, "id")
)

out, err := app.store.Check(namespace, id, false)
out, err := app.store.Check(namespace, id, store.CounterNil)
if err != nil {
if err == store.ErrNotExist {
sendErrorResponse(w, "Session expired.", http.StatusBadRequest, nil)
Expand All @@ -451,7 +463,7 @@ func handleAddressView(w http.ResponseWriter, r *http.Request) {
to = r.FormValue("to")
)

out, err := app.store.Check(namespace, id, false)
out, err := app.store.Check(namespace, id, store.CounterNil)
if err != nil {
if err == store.ErrNotExist {
app.tpl.ExecuteTemplate(w, "message", webviewTpl{App: app.constants,
Expand Down Expand Up @@ -520,7 +532,7 @@ func handleAddressView(w http.ResponseWriter, r *http.Request) {
// verifyOTP validates an OTP against user input.
func verifyOTP(namespace, id, otp string, deleteOnVerify bool, app *App) (models.OTP, error) {
// Check the OTP.
out, err := app.store.Check(namespace, id, true)
out, err := app.store.Check(namespace, id, store.CounterAttempts)
if err != nil {
if err != store.ErrNotExist {
app.lo.Error("error checking OTP", "error", err)
Expand Down Expand Up @@ -599,7 +611,11 @@ func generateRandomString(totalLen int, chars string) (string, error) {

// isLocked tells if an OTP is locked after exceeding attempts.
func isLocked(otp models.OTP) bool {
if otp.Attempts >= otp.MaxAttempts {
if otp.Attempts > otp.MaxAttempts {
return true
}

if otp.Generate > otp.MaxGenerate {
return true
}
return false
Expand Down
3 changes: 3 additions & 0 deletions cmd/otpgateway/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func init() {
constants: constants{
OtpTTL: 10 * time.Second,
OtpMaxAttempts: 10,
OtpMaxGenerate: 10,
},
store: redis.New(redis.Conf{
Host: rd.Host(),
Expand Down Expand Up @@ -248,6 +249,7 @@ func TestCheckOTPAttempts(t *testing.T) {
p.Set("id", dummyOTPID)
p.Set("otp", dummyOTP)
p.Set("max_attempts", "5")
p.Set("max_generate", "5")
p.Set("to", dummyToAddress)
p.Set("provider", dummyProvider)

Expand Down Expand Up @@ -287,6 +289,7 @@ func TestDeleteOnOTPCheck(t *testing.T) {
p.Set("id", dummyOTPID)
p.Set("otp", dummyOTP)
p.Set("max_attempts", "5")
p.Set("max_generate", "5")
p.Set("to", dummyToAddress)
p.Set("provider", dummyProvider)

Expand Down
1 change: 1 addition & 0 deletions cmd/otpgateway/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func initLogger(debug bool) logf.Logger {
type constants struct {
OtpTTL time.Duration
OtpMaxAttempts int
OtpMaxGenerate int

// Exported to templates.
RootURL string
Expand Down
1 change: 1 addition & 0 deletions cmd/otpgateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func main() {
constants: constants{
OtpTTL: ko.MustDuration("app.otp_ttl") * time.Second,
OtpMaxAttempts: ko.MustInt("app.otp_max_attempts"),
OtpMaxGenerate: ko.MustInt("app.otp_max_generate"),
RootURL: strings.TrimRight(ko.String("app.root_url"), "/"),
LogoURL: ko.String("app.logo_url"),
FaviconURL: ko.String("app.favicon_url"),
Expand Down
32 changes: 24 additions & 8 deletions internal/store/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ func (r *Redis) Ping() error {
}

// Check checks the attempt count and TTL duration against an ID.
// Passing count=true increments the attempt counter.
func (r *Redis) Check(namespace, id string, counter bool) (models.OTP, error) {
// Passing counterKey increments the attempt counter.
func (r *Redis) Check(namespace, id string, counterKey string) (models.OTP, error) {
// Retrieve the OTP information.
out, err := r.get(namespace, id)
if err != nil {
return out, err
}
if !counter {
if counterKey == store.CounterNil {
return out, err
}

Expand All @@ -88,14 +88,22 @@ func (r *Redis) Check(namespace, id string, counter bool) (models.OTP, error) {

// Increment attempts and get TTL.
pipe := r.client.TxPipeline()
attempts := pipe.HIncrBy(ctx, key, "attempts", 1)
attempts := pipe.HIncrBy(ctx, key, counterKey, 1)
ttl := pipe.TTL(ctx, key)
_, err = pipe.Exec(ctx)
if err != nil {
return out, err
}

out.Attempts = int(attempts.Val())
switch counterKey {
case store.CounterAttempts:
out.Attempts = int(attempts.Val())
case store.CounterGenerate:
out.Generate = int(attempts.Val())
default:
return out, store.ErrNotExist
}
// out.Attempts = int(attempts.Val())
out.TTL = ttl.Val()

// If there's a configured PublishKey, publish the event.
Expand Down Expand Up @@ -132,9 +140,11 @@ func (r *Redis) Set(namespace, id string, otp models.OTP) (models.OTP, error) {
"extra", string(otp.Extra),
"provider", otp.Provider,
"closed", false,
"max_attempts", otp.MaxAttempts)
"max_attempts", otp.MaxAttempts,
"max_generate", otp.MaxGenerate)

pipe.HIncrBy(ctx, key, "attempts", 1)
pipe.HIncrBy(ctx, key, store.CounterAttempts, 1)
pipe.HIncrBy(ctx, key, store.CounterGenerate, 1)
pipe.PExpire(ctx, key, time.Duration(exp)*time.Millisecond)
return nil
})
Expand All @@ -149,12 +159,18 @@ func (r *Redis) Set(namespace, id string, otp models.OTP) (models.OTP, error) {
}

// Retrieve the updated attempts count to update the OTP struct.
attempts, err := r.client.HGet(ctx, key, "attempts").Int()
generate, err := r.client.HGet(ctx, key, store.CounterGenerate).Int()
if err != nil {
return otp, err
}

attempts, err := r.client.HGet(ctx, key, store.CounterAttempts).Int()
if err != nil {
return otp, err
}

otp.Attempts = attempts
otp.Generate = generate
otp.TTLSeconds = otp.TTL.Seconds()
otp.Namespace = namespace
otp.ID = id
Expand Down
21 changes: 15 additions & 6 deletions internal/store/redis/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func TestStoreSet(t *testing.T) {
cmp := mockOTP
// Override dynamic values.
cmp.Attempts = resp.Attempts
cmp.Generate = resp.Generate
cmp.TTL = resp.TTL
cmp.TTLSeconds = resp.TTLSeconds
assert.Equal(t, cmp, resp, "Returned OTP doesn't match expected OTP")
Expand All @@ -74,26 +75,34 @@ func TestStoreCheck(t *testing.T) {
rStore := setup(t)

t.Run("no increment", func(t *testing.T) {
o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, false)
o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, store.CounterNil)
assert.NoError(t, err, "Error checking OTP without increment")
assert.Equal(t, 1, o.Attempts, "Unexpected attempt count")
})

t.Run("with increment", func(t *testing.T) {
o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, true)
o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, store.CounterAttempts)
assert.NoError(t, err, "Error checking OTP with increment")
assert.Equal(t, 2, o.Attempts, "Unexpected attempt count after first increment")

o, err = rStore.Check(mockOTP.Namespace, mockOTP.ID, true)
o, err = rStore.Check(mockOTP.Namespace, mockOTP.ID, store.CounterAttempts)
assert.NoError(t, err, "Error checking OTP with second increment")
assert.Equal(t, 3, o.Attempts, "Unexpected attempt count after second increment")

o, err = rStore.Check(mockOTP.Namespace, mockOTP.ID, store.CounterGenerate)
assert.NoError(t, err, "Error checking generate OTP with increment")
assert.Equal(t, 2, o.Generate, "Unexpected generate count after first increment")

o, err = rStore.Check(mockOTP.Namespace, mockOTP.ID, store.CounterGenerate)
assert.NoError(t, err, "Error checking generate OTP with second increment")
assert.Equal(t, 3, o.Generate, "Unexpected generate count after second increment")
})
}

func TestStoreTTL(t *testing.T) {
rStore := setup(t)

o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, false)
o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, store.CounterNil)
assert.NoError(t, err, "Error checking OTP")
assert.Equal(t, mockOTP.TTL, o.TTL, "Returned OTP TTL doesn't match expected TTL")
}
Expand All @@ -104,7 +113,7 @@ func TestStoreClose(t *testing.T) {
err := rStore.Close(mockOTP.Namespace, mockOTP.ID)
assert.NoError(t, err, "Error closing OTP")

o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, false)
o, err := rStore.Check(mockOTP.Namespace, mockOTP.ID, store.CounterNil)
assert.NoError(t, err, "Error checking closed OTP")
assert.True(t, o.Closed, "OTP should be closed but isn't")
}
Expand All @@ -115,6 +124,6 @@ func TestStoreDelete(t *testing.T) {
err := rStore.Delete(mockOTP.Namespace, mockOTP.ID)
assert.NoError(t, err, "Error deleting OTP")

_, err = rStore.Check(mockOTP.Namespace, mockOTP.ID, false)
_, err = rStore.Check(mockOTP.Namespace, mockOTP.ID, store.CounterNil)
assert.Equal(t, store.ErrNotExist, err, "OTP should not exist but it does")
}
8 changes: 7 additions & 1 deletion internal/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import (
// does not exist.
var ErrNotExist = errors.New("the OTP does not exist")

const (
CounterAttempts = "attempts"
CounterGenerate = "generate"
CounterNil = ""
)

// Store represents a storage backend where OTP data is stored.
type Store interface {
// Set sets an OTP against an ID. Every Set() increments the attempts
Expand All @@ -21,7 +27,7 @@ type Store interface {

// Check checks the attempt count and TTL duration against an ID.
// Passing counter=true increments the attempt counter.
Check(namespace, id string, counter bool) (models.OTP, error)
Check(namespace, id string, counterKey string) (models.OTP, error)

// Close closes an OTP and marks it as done (verified).
// After this, the OTP has to expire after a TTL or be deleted.
Expand Down
2 changes: 2 additions & 0 deletions pkg/models/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ type OTP struct {
OTP string `redis:"otp" json:"otp"`
MaxAttempts int `redis:"max_attempts" json:"max_attempts"`
Attempts int `redis:"attempts" json:"attempts"`
Generate int `redis:"generate" json:"generate"`
MaxGenerate int `redis:"max_generate" json:"max_generate"`
Closed bool `redis:"closed" json:"closed"`
TTL time.Duration `redis:"-" json:"-"`
TTLSeconds float64 `redis:"-" json:"ttl"`
Expand Down
1 change: 0 additions & 1 deletion static/otp.html
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ <h1>{{ .ChannelName }} verification</h1>
{{ if .Message }}
<p class="error">{{ .Message }}</p>
{{ end }}

<div class="stats">
<span class="attempts">
<span class="pulse">{{ .OTP.Attempts }}</span> / {{ .OTP.MaxAttempts }} attempts
Expand Down

0 comments on commit 604d9c9

Please sign in to comment.