Skip to content

Commit

Permalink
feat: simplified, pro-rated email and sms rate limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Sep 11, 2024
1 parent 9ddd38c commit f05a4b7
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 274 deletions.
27 changes: 15 additions & 12 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ type API struct {

// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
overrideTime func() time.Time

emailRateLimiter *SimpleRateLimiter
smsRateLimiter *SimpleRateLimiter
}

func (a *API) Now() time.Time {
Expand Down Expand Up @@ -70,6 +73,11 @@ func (a *API) deprecationNotices() {
func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string) *API {
api := &API{config: globalConfig, db: db, version: version}

now := time.Now()

api.emailRateLimiter = NewSimpleRateLimiter(now, globalConfig.RateLimitEmailSent.DefaultOverTime(time.Hour))
api.smsRateLimiter = NewSimpleRateLimiter(now, globalConfig.RateLimitSmsSent.DefaultOverTime(time.Hour))

if api.config.Password.HIBP.Enabled {
httpClient := &http.Client{
// all HIBP API requests should finish quickly to avoid
Expand Down Expand Up @@ -133,9 +141,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.Get("/authorize", api.ExternalProviderRedirect)

sharedLimiter := api.limitEmailOrPhoneSentHandler()
r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite)
r.With(sharedLimiter).With(api.verifyCaptcha).Route("/signup", func(r *router) {
r.With(api.requireAdminCredentials).Post("/invite", api.Invite)
r.With(api.verifyCaptcha).Route("/signup", func(r *router) {
// rate limit per hour
limitAnonymousSignIns := tollbooth.NewLimiter(api.config.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
Expand Down Expand Up @@ -164,10 +171,6 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
if _, err := api.limitHandler(limitSignups)(w, r); err != nil {
return err
}
// apply shared rate limiting on email / phone
if _, err := sharedLimiter(w, r); err != nil {
return err
}
return api.Signup(w, r)
})
})
Expand All @@ -176,28 +179,28 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)
)).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend)
)).With(api.verifyCaptcha).Post("/resend", api.Resend)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)
)).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp)
)).With(api.verifyCaptcha).Post("/otp", api.Otp)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
Expand Down Expand Up @@ -229,7 +232,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).Put("/", api.UserUpdate)
)).Put("/", api.UserUpdate)

r.Route("/identities", func(r *router) {
r.Use(api.requireManualLinkingEnabled)
Expand Down
13 changes: 0 additions & 13 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ const (
ssoProviderKey = contextKey("sso_provider")
externalHostKey = contextKey("external_host")
flowStateKey = contextKey("flow_state_id")
sharedLimiterKey = contextKey("shared_limiter")
)

// withToken adds the JWT token to the context.
Expand Down Expand Up @@ -248,15 +247,3 @@ type SharedLimiter struct {
EmailLimiter *limiter.Limiter
PhoneLimiter *limiter.Limiter
}

func withLimiter(ctx context.Context, limiter *SharedLimiter) context.Context {
return context.WithValue(ctx, sharedLimiterKey, limiter)
}

func getLimiter(ctx context.Context) *SharedLimiter {
obj := ctx.Value(sharedLimiterKey)
if obj == nil {
return nil
}
return obj.(*SharedLimiter)
}
18 changes: 7 additions & 11 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"strings"
"time"

"github.com/didip/tollbooth/v5"
"github.com/supabase/auth/internal/hooks"
mail "github.com/supabase/auth/internal/mailer"
"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -595,16 +594,13 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User,
referrerURL := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)

// apply rate limiting before the email is sent out
if limiter := getLimiter(ctx); limiter != nil {
if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil {
emailRateLimitCounter.Add(
ctx,
1,
metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))),
)
return EmailRateLimitExceeded
}
if ok := a.emailRateLimiter.Increment(1); !ok {
emailRateLimitCounter.Add(
ctx,
1,
metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))),
)
return EmailRateLimitExceeded
}

if config.Hook.SendEmail.Enabled {
Expand Down
30 changes: 0 additions & 30 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,36 +76,6 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
}
}

func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {
// limit per hour
smsFreq := a.config.RateLimitSmsSent / (60 * 60)

emailLimiter := a.config.RateLimitEmailSent.DivideIfDefaultDuration(60 * 60).CreateLimiter().SetBurst(int(a.config.RateLimitEmailSent.Events)).SetMethods([]string{"PUT", "POST"})

phoneLimiter := tollbooth.NewLimiter(smsFreq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(a.config.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"})

return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
c := req.Context()
config := a.config
shouldRateLimitEmail := config.External.Email.Enabled && !config.Mailer.Autoconfirm
shouldRateLimitPhone := config.External.Phone.Enabled && !config.Sms.Autoconfirm

if shouldRateLimitEmail || shouldRateLimitPhone {
if req.Method == "PUT" || req.Method == "POST" {
// store rate limiter in request context
c = withLimiter(c, &SharedLimiter{
EmailLimiter: emailLimiter,
PhoneLimiter: phoneLimiter,
})
}
}

return c, nil
}
}

func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) {
t, err := a.extractBearerToken(req)
if err != nil || t == "" {
Expand Down
174 changes: 0 additions & 174 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,52 +185,6 @@ func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() {
}
}

func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() {
// Set up rate limit config for this test
ts.Config.RateLimitEmailSent = conf.Rate{Events: 5}
ts.Config.RateLimitSmsSent = 5
ts.Config.External.Phone.Enabled = true

cases := []struct {
desc string
expectedErrorMsg string
requestBody map[string]interface{}
}{
{
desc: "Email rate limit exceeded",
expectedErrorMsg: "429: Email rate limit exceeded",
requestBody: map[string]interface{}{
"email": "test@example.com",
},
},
{
desc: "SMS rate limit exceeded",
expectedErrorMsg: "429: SMS rate limit exceeded",
requestBody: map[string]interface{}{
"phone": "+1233456789",
},
},
}

limiter := ts.API.limitEmailOrPhoneSentHandler()
for _, c := range cases {
ts.Run(c.desc, func() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.requestBody))
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

ctx, err := limiter(w, req)
require.NoError(ts.T(), err)

// check that shared limiter is set in the request context
sharedLimiter := getLimiter(ctx)
require.NotNil(ts.T(), sharedLimiter)
})
}
}

func (ts *MiddlewareTestSuite) TestIsValidExternalHost() {
cases := []struct {
desc string
Expand Down Expand Up @@ -387,131 +341,3 @@ func (ts *MiddlewareTestSuite) TestLimitHandler() {
ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusTooManyRequests, w.Code)
}

func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() {
// setup config for shared limiter and ip-based limiter to work
ts.Config.RateLimitHeader = "X-Rate-Limit"
ts.Config.External.Email.Enabled = true
ts.Config.External.Phone.Enabled = true
ts.Config.Mailer.Autoconfirm = false
ts.Config.Sms.Autoconfirm = false

ipBasedLimiter := func(max float64) *limiter.Limiter {
return tollbooth.NewLimiter(max, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
})
}

okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
limiter := getLimiter(r.Context())
if limiter != nil {
var requestBody struct {
Email string `json:"email"`
Phone string `json:"phone"`
}
err := retrieveRequestParams(r, &requestBody)
require.NoError(ts.T(), err)

if requestBody.Email != "" {
if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil {
sendJSON(w, http.StatusTooManyRequests, HTTPError{
HTTPStatus: http.StatusTooManyRequests,
ErrorCode: ErrorCodeOverEmailSendRateLimit,
Message: "Email rate limit exceeded",
})
}
}
if requestBody.Phone != "" {
if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"phone_functions"}); err != nil {
sendJSON(w, http.StatusTooManyRequests, HTTPError{
HTTPStatus: http.StatusTooManyRequests,
ErrorCode: ErrorCodeOverSMSSendRateLimit,
Message: "SMS rate limit exceeded",
})
}
}
}
w.WriteHeader(http.StatusOK)
})

cases := []struct {
desc string
sharedLimiterConfig *conf.GlobalConfiguration
ipBasedLimiterConfig float64
body map[string]interface{}
expectedErrorCode string
}{
{
desc: "Exceed ip-based rate limit before shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: conf.Rate{Events: 10},
RateLimitSmsSent: 10,
},
ipBasedLimiterConfig: 1,
body: map[string]interface{}{
"email": "foo@example.com",
},
expectedErrorCode: ErrorCodeOverRequestRateLimit,
},
{
desc: "Exceed email shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: conf.Rate{Events: 1},
RateLimitSmsSent: 1,
},
ipBasedLimiterConfig: 10,
body: map[string]interface{}{
"email": "foo@example.com",
},
expectedErrorCode: ErrorCodeOverEmailSendRateLimit,
},
{
desc: "Exceed sms shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: conf.Rate{Events: 1},
RateLimitSmsSent: 1,
},
ipBasedLimiterConfig: 10,
body: map[string]interface{}{
"phone": "123456789",
},
expectedErrorCode: ErrorCodeOverSMSSendRateLimit,
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent
ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent
lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig))
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler()

// get the minimum amount to reach the threshold just before the rate limit is exceeded
threshold := min(c.sharedLimiterConfig.RateLimitEmailSent.Events, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig)
for i := 0; i < int(threshold); i++ {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body))
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0")

w := httptest.NewRecorder()
lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
}

var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body))
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0")

// check if the rate limit is exceeded with the expected error code
w := httptest.NewRecorder()
lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusTooManyRequests, w.Code)

var data map[string]interface{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
require.Equal(ts.T(), c.expectedErrorCode, data["error_code"])
})
}
}
11 changes: 3 additions & 8 deletions internal/api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"text/template"
"time"

"github.com/didip/tollbooth/v5"
"github.com/supabase/auth/internal/hooks"

"github.com/pkg/errors"
Expand Down Expand Up @@ -45,7 +44,6 @@ func formatPhoneNumber(phone string) string {

// sendPhoneConfirmation sends an otp to the user's phone number
func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) {
ctx := r.Context()
config := a.config

var token *string
Expand Down Expand Up @@ -88,13 +86,10 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use

// not using test OTPs
if otp == "" {
// apply rate limiting before the sms is sent out
limiter := getLimiter(ctx)
if limiter != nil {
if err := tollbooth.LimitByKeys(limiter.PhoneLimiter, []string{"phone_functions"}); err != nil {
return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded")
}
if ok := a.smsRateLimiter.Increment(1); !ok {
return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded")
}

otp, err = crypto.GenerateOtp(config.Sms.OtpLength)
if err != nil {
return "", internalServerError("error generating otp").WithInternalError(err)
Expand Down
Loading

0 comments on commit f05a4b7

Please sign in to comment.