From 92f869c1e211b79e50d362bf9a1b5d04e964b108 Mon Sep 17 00:00:00 2001 From: Chris Stockton Date: Tue, 24 Sep 2024 11:29:21 -0700 Subject: [PATCH] fix: few fixes to the existing email rate limiter changes by Stojan Summary of changes: - I replaced the existing rate limitier made by Stojan with the `x/rate/limit` package from golang.org while trying to preserve the same behavior. - Fixed the tests that are failing with a small change in the helper function `setupAPIForTestWithCallback`. - Updated the call sites using limiters (mail.go, phone.go) - Added some basic test cases along with an example to help visualize rate limits. Some small notes: - Setting the "Burst" value a little higher could be a consideration if the default of 1 is too restrictive. Adding Burst to conf.Rate for better control of the Burst is another option. - Using a value such as 100/24h is equivelant in functionality to the expression 1/14m, though slightly less clear. If the intent is to not limit the rate, but impose a _quota_ of 100 per 24 hours we may want to add some additional changes. --- go.mod | 2 +- go.sum | 2 + internal/api/api.go | 12 ++- internal/api/api_test.go | 5 +- internal/api/mail.go | 2 +- internal/api/phone.go | 2 +- internal/api/ratelimits.go | 86 ++++----------------- internal/api/ratelimits_test.go | 128 ++++++++++++++++++++++++++++++++ 8 files changed, 158 insertions(+), 81 deletions(-) create mode 100644 internal/api/ratelimits_test.go diff --git a/go.mod b/go.mod index 3e4af443d..7d4e6229e 100644 --- a/go.mod +++ b/go.mod @@ -146,7 +146,7 @@ require ( golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect golang.org/x/text v0.16.0 // indirect - golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect + golang.org/x/time v0.6.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/grpc v1.63.2 // indirect google.golang.org/protobuf v1.33.0 // indirect diff --git a/go.sum b/go.sum index b30c1baaf..97d1828dd 100644 --- a/go.sum +++ b/go.sum @@ -490,6 +490,8 @@ golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.0.0-20160926182426-711ca1cb8763/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20220411224347-583f2d630306 h1:+gHMid33q6pen7kv9xvT+JRinntgeXO2AeZVd0AWD3w= golang.org/x/time v0.0.0-20220411224347-583f2d630306/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= diff --git a/internal/api/api.go b/internal/api/api.go index 14659f091..275d000f8 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -17,6 +17,7 @@ import ( "github.com/supabase/auth/internal/storage" "github.com/supabase/auth/internal/utilities" "github.com/supabase/hibp" + "golang.org/x/time/rate" ) const ( @@ -38,8 +39,8 @@ type API struct { // overrideTime can be used to override the clock used by handlers. Should only be used in tests! overrideTime func() time.Time - emailRateLimiter *SimpleRateLimiter - smsRateLimiter *SimpleRateLimiter + emailRateLimiter *rate.Limiter + smsRateLimiter *rate.Limiter } func (a *API) Now() time.Time { @@ -72,11 +73,8 @@ func (a *API) deprecationNotices() { // NewAPIWithVersion creates a new REST API using the specified version func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string) *API { api := &API{config: globalConfig, db: db, version: version} - - now := time.Now() - - api.emailRateLimiter = NewSimpleRateLimiter(now, globalConfig.RateLimitEmailSent.DefaultOverTime(time.Hour)) - api.smsRateLimiter = NewSimpleRateLimiter(now, globalConfig.RateLimitSmsSent.DefaultOverTime(time.Hour)) + api.emailRateLimiter = newRateLimiter(globalConfig.RateLimitEmailSent) + api.smsRateLimiter = newRateLimiter(globalConfig.RateLimitSmsSent) if api.config.Password.HIBP.Enabled { httpClient := &http.Client{ diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 87639a09c..8c7a26c28 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -45,7 +45,10 @@ func setupAPIForTestWithCallback(cb func(*conf.GlobalConfiguration, *storage.Con cb(nil, conn) } - return NewAPIWithVersion(config, conn, apiTestVersion), config, nil + a := NewAPIWithVersion(config, conn, apiTestVersion) + a.smsRateLimiter = newUnlimitedLimiter() + a.emailRateLimiter = newUnlimitedLimiter() + return a, config, nil } func TestEmailEnabledByDefault(t *testing.T) { diff --git a/internal/api/mail.go b/internal/api/mail.go index 280ea0e22..3fe7cdd05 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -594,7 +594,7 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, referrerURL := utilities.GetReferrer(r, config) externalURL := getExternalHost(ctx) - if ok := a.emailRateLimiter.Increment(1); !ok { + if ok := a.emailRateLimiter.Allow(); !ok { emailRateLimitCounter.Add( ctx, 1, diff --git a/internal/api/phone.go b/internal/api/phone.go index 683a429d4..2a1659440 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -86,7 +86,7 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use // not using test OTPs if otp == "" { - if ok := a.smsRateLimiter.Increment(1); !ok { + if ok := a.smsRateLimiter.Allow(); !ok { return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") } diff --git a/internal/api/ratelimits.go b/internal/api/ratelimits.go index 77904a492..fa1747e7d 100644 --- a/internal/api/ratelimits.go +++ b/internal/api/ratelimits.go @@ -1,78 +1,24 @@ package api import ( - "fmt" - "sync/atomic" - "time" - "github.com/supabase/auth/internal/conf" + "golang.org/x/time/rate" ) -// SimpleRateLimiter holds a rate limiter that implements a token-bucket -// algorithm. Rate.OverTime is the duration at which the bucket is filled, and -// Rate.Events is the number of tokens in the bucket. -// -// Internally it uses an atomically increasing counter that resets to 0 on -// every OverTime tick. -// -// You should always use NewSimpleRateLimiter to create a new one. -type SimpleRateLimiter struct { - Rate conf.Rate - - ticker *time.Ticker - counter uint64 -} - -// NewSimpleRateLimiter creates a new rate limiter starting at the specified -// time and with the specified Rate. +// newRateLimiter returns a rate limiter configured using the given conf.Rate. // -// Initially the bucket is filled with a proprotion of the Rate.Events -// depending on how close to the Rate.OverTime tick it has been crated. This is -// one way of making sure that server restarts do not give out a too big of a -// rate limit, as the counter is reset. -func NewSimpleRateLimiter(now time.Time, rate conf.Rate) *SimpleRateLimiter { - r := &SimpleRateLimiter{ - Rate: rate, - } - - counterStartedAt := now.Truncate(rate.OverTime) - counterResetsAt := counterStartedAt.Add(rate.OverTime) - - proRate := float64(counterStartedAt.Sub(now).Milliseconds()) / float64(rate.OverTime.Milliseconds()) - - r.counter = uint64(rate.Events * proRate) - r.ticker = time.NewTicker(counterResetsAt.Sub(now)) - - go r.fillBucket() - - return r -} - -func (r *SimpleRateLimiter) Increment(events uint64) bool { - fmt.Printf("@@@@@@@@@@@@@@@@@@@@@@@ %d %f\n", r.counter, r.Rate.Events) - return atomic.AddUint64(&r.counter, events) < uint64(r.Rate.Events) -} - -func (r *SimpleRateLimiter) fillBucket() { - if _, ok := <-r.ticker.C; !ok { - return - } - - // reset ticker to start ticking at the OverTime rate, as it was - // initially set up to tick at the next aligned OverTime event - r.ticker.Reset(r.Rate.OverTime) - - // reset counter - atomic.StoreUint64(&r.counter, 0) - - // then keep resetting at regular OverTime intervals - for range r.ticker.C { - atomic.StoreUint64(&r.counter, 0) - } -} - -func (r *SimpleRateLimiter) Close() { - if r.ticker != nil { - r.ticker.Stop() - } +// The returned *rate.Limiter will be configured with a token bucket containing +// a single token, which will fill up at a rate of r. For example to allow 100 +// events every 24 hours. This will fill a token bucket approximately once every +// 864 seconds (14.4 minutes). See Example_newRateLimiter for a visualization. +func newRateLimiter(r conf.Rate) *rate.Limiter { + // The rate limiter deals in events per second. + eps := r.EventsPerSecond() + const burst = 1 + + // NewLimiter will have an initial token bucket of size `burst`. It will + // be refilled at a rate of `eps` indefinitely. Note that the expression + // 100 / 24h is roughly equivelant to the expression 1 / 15m. The 100 is + // a rate, not a quota. + return rate.NewLimiter(rate.Limit(eps), burst) } diff --git a/internal/api/ratelimits_test.go b/internal/api/ratelimits_test.go new file mode 100644 index 000000000..d125872bf --- /dev/null +++ b/internal/api/ratelimits_test.go @@ -0,0 +1,128 @@ +package api + +import ( + "fmt" + "testing" + "time" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/time/rate" +) + +func newUnlimitedLimiter() *rate.Limiter { + return rate.NewLimiter(rate.Inf, 0) +} + +func Example_newRateLimiter() { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + { + cfg := conf.Rate{Events: 100, OverTime: time.Hour * 24} + rl := newRateLimiter(cfg) + cur := now + for i := 0; i < 61; i++ { + fmt.Printf("%-5v @ %v\n", rl.AllowN(cur, 1), cur) + cur = cur.Add(time.Minute) + } + } + + // Output: + // true @ 2024-09-24 10:00:00 +0000 UTC + // false @ 2024-09-24 10:01:00 +0000 UTC + // false @ 2024-09-24 10:02:00 +0000 UTC + // false @ 2024-09-24 10:03:00 +0000 UTC + // false @ 2024-09-24 10:04:00 +0000 UTC + // false @ 2024-09-24 10:05:00 +0000 UTC + // false @ 2024-09-24 10:06:00 +0000 UTC + // false @ 2024-09-24 10:07:00 +0000 UTC + // false @ 2024-09-24 10:08:00 +0000 UTC + // false @ 2024-09-24 10:09:00 +0000 UTC + // false @ 2024-09-24 10:10:00 +0000 UTC + // false @ 2024-09-24 10:11:00 +0000 UTC + // false @ 2024-09-24 10:12:00 +0000 UTC + // false @ 2024-09-24 10:13:00 +0000 UTC + // false @ 2024-09-24 10:14:00 +0000 UTC + // true @ 2024-09-24 10:15:00 +0000 UTC + // false @ 2024-09-24 10:16:00 +0000 UTC + // false @ 2024-09-24 10:17:00 +0000 UTC + // false @ 2024-09-24 10:18:00 +0000 UTC + // false @ 2024-09-24 10:19:00 +0000 UTC + // false @ 2024-09-24 10:20:00 +0000 UTC + // false @ 2024-09-24 10:21:00 +0000 UTC + // false @ 2024-09-24 10:22:00 +0000 UTC + // false @ 2024-09-24 10:23:00 +0000 UTC + // false @ 2024-09-24 10:24:00 +0000 UTC + // false @ 2024-09-24 10:25:00 +0000 UTC + // false @ 2024-09-24 10:26:00 +0000 UTC + // false @ 2024-09-24 10:27:00 +0000 UTC + // false @ 2024-09-24 10:28:00 +0000 UTC + // false @ 2024-09-24 10:29:00 +0000 UTC + // true @ 2024-09-24 10:30:00 +0000 UTC + // false @ 2024-09-24 10:31:00 +0000 UTC + // false @ 2024-09-24 10:32:00 +0000 UTC + // false @ 2024-09-24 10:33:00 +0000 UTC + // false @ 2024-09-24 10:34:00 +0000 UTC + // false @ 2024-09-24 10:35:00 +0000 UTC + // false @ 2024-09-24 10:36:00 +0000 UTC + // false @ 2024-09-24 10:37:00 +0000 UTC + // false @ 2024-09-24 10:38:00 +0000 UTC + // false @ 2024-09-24 10:39:00 +0000 UTC + // false @ 2024-09-24 10:40:00 +0000 UTC + // false @ 2024-09-24 10:41:00 +0000 UTC + // false @ 2024-09-24 10:42:00 +0000 UTC + // false @ 2024-09-24 10:43:00 +0000 UTC + // false @ 2024-09-24 10:44:00 +0000 UTC + // true @ 2024-09-24 10:45:00 +0000 UTC + // false @ 2024-09-24 10:46:00 +0000 UTC + // false @ 2024-09-24 10:47:00 +0000 UTC + // false @ 2024-09-24 10:48:00 +0000 UTC + // false @ 2024-09-24 10:49:00 +0000 UTC + // false @ 2024-09-24 10:50:00 +0000 UTC + // false @ 2024-09-24 10:51:00 +0000 UTC + // false @ 2024-09-24 10:52:00 +0000 UTC + // false @ 2024-09-24 10:53:00 +0000 UTC + // false @ 2024-09-24 10:54:00 +0000 UTC + // false @ 2024-09-24 10:55:00 +0000 UTC + // false @ 2024-09-24 10:56:00 +0000 UTC + // false @ 2024-09-24 10:57:00 +0000 UTC + // false @ 2024-09-24 10:58:00 +0000 UTC + // false @ 2024-09-24 10:59:00 +0000 UTC + // true @ 2024-09-24 11:00:00 +0000 UTC + +} + +func TestNewRateLimiter(t *testing.T) { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + + type event struct { + ok bool + at time.Time + } + cases := []struct { + cfg conf.Rate + now time.Time + evts []event + }{ + { + cfg: conf.Rate{Events: 100, OverTime: time.Hour * 24}, + now: now, + evts: []event{ + {true, now}, + {false, now.Add(time.Minute)}, + {false, now.Add(time.Minute)}, + {false, now.Add(time.Minute * 14)}, + {true, now.Add(time.Minute * 15)}, + {false, now.Add(time.Minute * 16)}, + {false, now.Add(time.Minute * 17)}, + {true, now.Add(time.Minute * 30)}, + }, + }, + } + for _, tc := range cases { + rl := newRateLimiter(tc.cfg) + for _, evt := range tc.evts { + if exp, got := evt.ok, rl.AllowN(evt.at, 1); exp != got { + t.Fatalf("exp AllowN(%v, 1) to be %v; got %v", evt.at, exp, got) + } + } + } +}