diff --git a/internal/api/api.go b/internal/api/api.go index 49b810696..14659f091 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -37,6 +37,9 @@ type API struct { // overrideTime can be used to override the clock used by handlers. Should only be used in tests! overrideTime func() time.Time + + emailRateLimiter *SimpleRateLimiter + smsRateLimiter *SimpleRateLimiter } func (a *API) Now() time.Time { @@ -70,6 +73,11 @@ func (a *API) deprecationNotices() { func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string) *API { api := &API{config: globalConfig, db: db, version: version} + now := time.Now() + + api.emailRateLimiter = NewSimpleRateLimiter(now, globalConfig.RateLimitEmailSent.DefaultOverTime(time.Hour)) + api.smsRateLimiter = NewSimpleRateLimiter(now, globalConfig.RateLimitSmsSent.DefaultOverTime(time.Hour)) + if api.config.Password.HIBP.Enabled { httpClient := &http.Client{ // all HIBP API requests should finish quickly to avoid @@ -133,9 +141,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Get("/authorize", api.ExternalProviderRedirect) - sharedLimiter := api.limitEmailOrPhoneSentHandler() - r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite) - r.With(sharedLimiter).With(api.verifyCaptcha).Route("/signup", func(r *router) { + r.With(api.requireAdminCredentials).Post("/invite", api.Invite) + r.With(api.verifyCaptcha).Route("/signup", func(r *router) { // rate limit per hour limitAnonymousSignIns := tollbooth.NewLimiter(api.config.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{ DefaultExpirationTTL: time.Hour, @@ -164,10 +171,6 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne if _, err := api.limitHandler(limitSignups)(w, r); err != nil { return err } - // apply shared rate limiting on email / phone - if _, err := sharedLimiter(w, r); err != nil { - return err - } return api.Signup(w, r) }) }) @@ -176,28 +179,28 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ DefaultExpirationTTL: time.Hour, }).SetBurst(30), - )).With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover) + )).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover) r.With(api.limitHandler( // Allow requests at the specified rate per 5 minutes tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ DefaultExpirationTTL: time.Hour, }).SetBurst(30), - )).With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend) + )).With(api.verifyCaptcha).Post("/resend", api.Resend) r.With(api.limitHandler( // Allow requests at the specified rate per 5 minutes tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ DefaultExpirationTTL: time.Hour, }).SetBurst(30), - )).With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink) + )).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink) r.With(api.limitHandler( // Allow requests at the specified rate per 5 minutes tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ DefaultExpirationTTL: time.Hour, }).SetBurst(30), - )).With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp) + )).With(api.verifyCaptcha).Post("/otp", api.Otp) r.With(api.limitHandler( // Allow requests at the specified rate per 5 minutes. @@ -229,7 +232,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ DefaultExpirationTTL: time.Hour, }).SetBurst(30), - )).With(sharedLimiter).Put("/", api.UserUpdate) + )).Put("/", api.UserUpdate) r.Route("/identities", func(r *router) { r.Use(api.requireManualLinkingEnabled) diff --git a/internal/api/context.go b/internal/api/context.go index ff01e7120..bf566ad3e 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -32,7 +32,6 @@ const ( ssoProviderKey = contextKey("sso_provider") externalHostKey = contextKey("external_host") flowStateKey = contextKey("flow_state_id") - sharedLimiterKey = contextKey("shared_limiter") ) // withToken adds the JWT token to the context. @@ -248,15 +247,3 @@ type SharedLimiter struct { EmailLimiter *limiter.Limiter PhoneLimiter *limiter.Limiter } - -func withLimiter(ctx context.Context, limiter *SharedLimiter) context.Context { - return context.WithValue(ctx, sharedLimiterKey, limiter) -} - -func getLimiter(ctx context.Context) *SharedLimiter { - obj := ctx.Value(sharedLimiterKey) - if obj == nil { - return nil - } - return obj.(*SharedLimiter) -} diff --git a/internal/api/mail.go b/internal/api/mail.go index c82214918..280ea0e22 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -6,7 +6,6 @@ import ( "strings" "time" - "github.com/didip/tollbooth/v5" "github.com/supabase/auth/internal/hooks" mail "github.com/supabase/auth/internal/mailer" "go.opentelemetry.io/otel/attribute" @@ -595,16 +594,13 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, referrerURL := utilities.GetReferrer(r, config) externalURL := getExternalHost(ctx) - // apply rate limiting before the email is sent out - if limiter := getLimiter(ctx); limiter != nil { - if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil { - emailRateLimitCounter.Add( - ctx, - 1, - metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))), - ) - return EmailRateLimitExceeded - } + if ok := a.emailRateLimiter.Increment(1); !ok { + emailRateLimitCounter.Add( + ctx, + 1, + metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))), + ) + return EmailRateLimitExceeded } if config.Hook.SendEmail.Enabled { diff --git a/internal/api/middleware.go b/internal/api/middleware.go index d2070e3ac..7103c29a6 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -76,36 +76,6 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { } } -func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { - // limit per hour - smsFreq := a.config.RateLimitSmsSent / (60 * 60) - - emailLimiter := a.config.RateLimitEmailSent.DivideIfDefaultDuration(60 * 60).CreateLimiter().SetBurst(int(a.config.RateLimitEmailSent.Events)).SetMethods([]string{"PUT", "POST"}) - - phoneLimiter := tollbooth.NewLimiter(smsFreq, &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(int(a.config.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"}) - - return func(w http.ResponseWriter, req *http.Request) (context.Context, error) { - c := req.Context() - config := a.config - shouldRateLimitEmail := config.External.Email.Enabled && !config.Mailer.Autoconfirm - shouldRateLimitPhone := config.External.Phone.Enabled && !config.Sms.Autoconfirm - - if shouldRateLimitEmail || shouldRateLimitPhone { - if req.Method == "PUT" || req.Method == "POST" { - // store rate limiter in request context - c = withLimiter(c, &SharedLimiter{ - EmailLimiter: emailLimiter, - PhoneLimiter: phoneLimiter, - }) - } - } - - return c, nil - } -} - func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) { t, err := a.extractBearerToken(req) if err != nil || t == "" { diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index 068f0283e..ced249f6c 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -185,52 +185,6 @@ func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() { } } -func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() { - // Set up rate limit config for this test - ts.Config.RateLimitEmailSent = conf.Rate{Events: 5} - ts.Config.RateLimitSmsSent = 5 - ts.Config.External.Phone.Enabled = true - - cases := []struct { - desc string - expectedErrorMsg string - requestBody map[string]interface{} - }{ - { - desc: "Email rate limit exceeded", - expectedErrorMsg: "429: Email rate limit exceeded", - requestBody: map[string]interface{}{ - "email": "test@example.com", - }, - }, - { - desc: "SMS rate limit exceeded", - expectedErrorMsg: "429: SMS rate limit exceeded", - requestBody: map[string]interface{}{ - "phone": "+1233456789", - }, - }, - } - - limiter := ts.API.limitEmailOrPhoneSentHandler() - for _, c := range cases { - ts.Run(c.desc, func() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.requestBody)) - req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - - ctx, err := limiter(w, req) - require.NoError(ts.T(), err) - - // check that shared limiter is set in the request context - sharedLimiter := getLimiter(ctx) - require.NotNil(ts.T(), sharedLimiter) - }) - } -} - func (ts *MiddlewareTestSuite) TestIsValidExternalHost() { cases := []struct { desc string @@ -387,131 +341,3 @@ func (ts *MiddlewareTestSuite) TestLimitHandler() { ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req) require.Equal(ts.T(), http.StatusTooManyRequests, w.Code) } - -func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() { - // setup config for shared limiter and ip-based limiter to work - ts.Config.RateLimitHeader = "X-Rate-Limit" - ts.Config.External.Email.Enabled = true - ts.Config.External.Phone.Enabled = true - ts.Config.Mailer.Autoconfirm = false - ts.Config.Sms.Autoconfirm = false - - ipBasedLimiter := func(max float64) *limiter.Limiter { - return tollbooth.NewLimiter(max, &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }) - } - - okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - limiter := getLimiter(r.Context()) - if limiter != nil { - var requestBody struct { - Email string `json:"email"` - Phone string `json:"phone"` - } - err := retrieveRequestParams(r, &requestBody) - require.NoError(ts.T(), err) - - if requestBody.Email != "" { - if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil { - sendJSON(w, http.StatusTooManyRequests, HTTPError{ - HTTPStatus: http.StatusTooManyRequests, - ErrorCode: ErrorCodeOverEmailSendRateLimit, - Message: "Email rate limit exceeded", - }) - } - } - if requestBody.Phone != "" { - if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"phone_functions"}); err != nil { - sendJSON(w, http.StatusTooManyRequests, HTTPError{ - HTTPStatus: http.StatusTooManyRequests, - ErrorCode: ErrorCodeOverSMSSendRateLimit, - Message: "SMS rate limit exceeded", - }) - } - } - } - w.WriteHeader(http.StatusOK) - }) - - cases := []struct { - desc string - sharedLimiterConfig *conf.GlobalConfiguration - ipBasedLimiterConfig float64 - body map[string]interface{} - expectedErrorCode string - }{ - { - desc: "Exceed ip-based rate limit before shared limiter", - sharedLimiterConfig: &conf.GlobalConfiguration{ - RateLimitEmailSent: conf.Rate{Events: 10}, - RateLimitSmsSent: 10, - }, - ipBasedLimiterConfig: 1, - body: map[string]interface{}{ - "email": "foo@example.com", - }, - expectedErrorCode: ErrorCodeOverRequestRateLimit, - }, - { - desc: "Exceed email shared limiter", - sharedLimiterConfig: &conf.GlobalConfiguration{ - RateLimitEmailSent: conf.Rate{Events: 1}, - RateLimitSmsSent: 1, - }, - ipBasedLimiterConfig: 10, - body: map[string]interface{}{ - "email": "foo@example.com", - }, - expectedErrorCode: ErrorCodeOverEmailSendRateLimit, - }, - { - desc: "Exceed sms shared limiter", - sharedLimiterConfig: &conf.GlobalConfiguration{ - RateLimitEmailSent: conf.Rate{Events: 1}, - RateLimitSmsSent: 1, - }, - ipBasedLimiterConfig: 10, - body: map[string]interface{}{ - "phone": "123456789", - }, - expectedErrorCode: ErrorCodeOverSMSSendRateLimit, - }, - } - - for _, c := range cases { - ts.Run(c.desc, func() { - ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent - ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent - lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig)) - sharedLimiter := ts.API.limitEmailOrPhoneSentHandler() - - // get the minimum amount to reach the threshold just before the rate limit is exceeded - threshold := min(c.sharedLimiterConfig.RateLimitEmailSent.Events, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig) - for i := 0; i < int(threshold); i++ { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) - req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") - - w := httptest.NewRecorder() - lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - } - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) - req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") - - // check if the rate limit is exceeded with the expected error code - w := httptest.NewRecorder() - lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusTooManyRequests, w.Code) - - var data map[string]interface{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - require.Equal(ts.T(), c.expectedErrorCode, data["error_code"]) - }) - } -} diff --git a/internal/api/phone.go b/internal/api/phone.go index ce11c5a3f..683a429d4 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -8,7 +8,6 @@ import ( "text/template" "time" - "github.com/didip/tollbooth/v5" "github.com/supabase/auth/internal/hooks" "github.com/pkg/errors" @@ -45,7 +44,6 @@ func formatPhoneNumber(phone string) string { // sendPhoneConfirmation sends an otp to the user's phone number func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) { - ctx := r.Context() config := a.config var token *string @@ -88,13 +86,10 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use // not using test OTPs if otp == "" { - // apply rate limiting before the sms is sent out - limiter := getLimiter(ctx) - if limiter != nil { - if err := tollbooth.LimitByKeys(limiter.PhoneLimiter, []string{"phone_functions"}); err != nil { - return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") - } + if ok := a.smsRateLimiter.Increment(1); !ok { + return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") } + otp, err = crypto.GenerateOtp(config.Sms.OtpLength) if err != nil { return "", internalServerError("error generating otp").WithInternalError(err) diff --git a/internal/api/ratelimits.go b/internal/api/ratelimits.go new file mode 100644 index 000000000..77904a492 --- /dev/null +++ b/internal/api/ratelimits.go @@ -0,0 +1,78 @@ +package api + +import ( + "fmt" + "sync/atomic" + "time" + + "github.com/supabase/auth/internal/conf" +) + +// SimpleRateLimiter holds a rate limiter that implements a token-bucket +// algorithm. Rate.OverTime is the duration at which the bucket is filled, and +// Rate.Events is the number of tokens in the bucket. +// +// Internally it uses an atomically increasing counter that resets to 0 on +// every OverTime tick. +// +// You should always use NewSimpleRateLimiter to create a new one. +type SimpleRateLimiter struct { + Rate conf.Rate + + ticker *time.Ticker + counter uint64 +} + +// NewSimpleRateLimiter creates a new rate limiter starting at the specified +// time and with the specified Rate. +// +// Initially the bucket is filled with a proprotion of the Rate.Events +// depending on how close to the Rate.OverTime tick it has been crated. This is +// one way of making sure that server restarts do not give out a too big of a +// rate limit, as the counter is reset. +func NewSimpleRateLimiter(now time.Time, rate conf.Rate) *SimpleRateLimiter { + r := &SimpleRateLimiter{ + Rate: rate, + } + + counterStartedAt := now.Truncate(rate.OverTime) + counterResetsAt := counterStartedAt.Add(rate.OverTime) + + proRate := float64(counterStartedAt.Sub(now).Milliseconds()) / float64(rate.OverTime.Milliseconds()) + + r.counter = uint64(rate.Events * proRate) + r.ticker = time.NewTicker(counterResetsAt.Sub(now)) + + go r.fillBucket() + + return r +} + +func (r *SimpleRateLimiter) Increment(events uint64) bool { + fmt.Printf("@@@@@@@@@@@@@@@@@@@@@@@ %d %f\n", r.counter, r.Rate.Events) + return atomic.AddUint64(&r.counter, events) < uint64(r.Rate.Events) +} + +func (r *SimpleRateLimiter) fillBucket() { + if _, ok := <-r.ticker.C; !ok { + return + } + + // reset ticker to start ticking at the OverTime rate, as it was + // initially set up to tick at the next aligned OverTime event + r.ticker.Reset(r.Rate.OverTime) + + // reset counter + atomic.StoreUint64(&r.counter, 0) + + // then keep resetting at regular OverTime intervals + for range r.ticker.C { + atomic.StoreUint64(&r.counter, 0) + } +} + +func (r *SimpleRateLimiter) Close() { + if r.ticker != nil { + r.ticker.Stop() + } +} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index d330150e1..cb7aca153 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -242,7 +242,7 @@ type GlobalConfiguration struct { SMTP SMTPConfiguration RateLimitHeader string `split_words:"true"` RateLimitEmailSent Rate `split_words:"true" default:"30"` - RateLimitSmsSent float64 `split_words:"true" default:"30"` + RateLimitSmsSent Rate `split_words:"true" default:"30"` RateLimitVerify float64 `split_words:"true" default:"30"` RateLimitTokenRefresh float64 `split_words:"true" default:"150"` RateLimitSso float64 `split_words:"true" default:"30"` diff --git a/internal/conf/rate.go b/internal/conf/rate.go index 37076cbcc..c33d160ff 100644 --- a/internal/conf/rate.go +++ b/internal/conf/rate.go @@ -5,9 +5,6 @@ import ( "strconv" "strings" "time" - - "github.com/didip/tollbooth/v5" - "github.com/didip/tollbooth/v5/limiter" ) type Rate struct { @@ -23,28 +20,19 @@ func (r *Rate) EventsPerSecond() float64 { return r.Events / r.OverTime.Seconds() } -func (r *Rate) DivideIfDefaultDuration(div float64) *Rate { - if r.OverTime == time.Duration(0) { - return &Rate{ - Events: r.Events / div, +// DefaultOverTime sets the OverTime field to overTime if it is 0. +func (r *Rate) DefaultOverTime(overTime time.Duration) Rate { + if r.OverTime == 0 { + return Rate{ + Events: r.Events, + OverTime: time.Hour, } } - return r -} - -func (r *Rate) CreateLimiter() *limiter.Limiter { - overTime := r.OverTime - if int64(overTime) == 0 { - // if r.OverTime is not specified, i.e. the configuration specified just a single float64 number, the - overTime = time.Hour - } - - return tollbooth.NewLimiter(r.EventsPerSecond(), &limiter.ExpirableOptions{ - DefaultExpirationTTL: overTime, - }) + return *r } +// Decode is used by envconfig to parse the env-config string to a Rate value. func (r *Rate) Decode(value string) error { if f, err := strconv.ParseFloat(value, 64); err == nil { r.Events = f @@ -56,9 +44,9 @@ func (r *Rate) Decode(value string) error { return fmt.Errorf("rate: value does not match rate syntax %q", value) } - f, err := strconv.ParseFloat(parts[0], 64) + e, err := strconv.ParseUint(parts[0], 10, 52) // 52 because the uint needs to fit in a float64 if err != nil { - return fmt.Errorf("rate: events part of rate value %q failed to parse as float64: %w", value, err) + return fmt.Errorf("rate: events part of rate value %q failed to parse as uint64: %w", value, err) } d, err := time.ParseDuration(parts[1]) @@ -66,7 +54,7 @@ func (r *Rate) Decode(value string) error { return fmt.Errorf("rate: over-time part of rate value %q failed to parse as duration: %w", value, err) } - r.Events = f + r.Events = float64(e) r.OverTime = d return nil @@ -77,5 +65,5 @@ func (r *Rate) String() string { return fmt.Sprintf("%f", r.Events) } - return fmt.Sprintf("%f/%s", r.Events, r.OverTime.String()) + return fmt.Sprintf("%d/%s", uint64(r.Events), r.OverTime.String()) } diff --git a/internal/conf/rate_test.go b/internal/conf/rate_test.go index 43cb86bf0..2c6cf052a 100644 --- a/internal/conf/rate_test.go +++ b/internal/conf/rate_test.go @@ -15,12 +15,15 @@ func TestRateDecode(t *testing.T) { require.Equal(t, r, Rate{Events: 123.0, OverTime: 0}) r = Rate{} - require.NoError(t, r.Decode("123.0/24h")) + require.NoError(t, r.Decode("123/24h")) require.Equal(t, r, Rate{Events: 123.0, OverTime: 24 * time.Hour}) r = Rate{} require.Error(t, r.Decode("not a number")) + r = Rate{} + require.Error(t, r.Decode("123.0/24h")) // events are integers only + r = Rate{} require.Error(t, r.Decode("123/456/789"))