Skip to content

Commit

Permalink
feat: preserve rate limiters in memory across configuration reloads (#…
Browse files Browse the repository at this point in the history
…1792)

The goal with this change is to preserve the rate limiters across reloads
with the least invasive changes possible. To do this I first added a
LimiterOptions structure to hold the current set of rate limiters. I
then identified each call of tollbooth.New and moved the construction of
the limiter without modifications to options.go. I assigned each limiter
a distinct field to be referenced during the API object creation.

Next I needed to add an optional parameter to the NewAPI methods so I
could store the new LimiterOptions onto the API object for reference
during route construction. To do this without breaking all existing
calls to NewAPI I used the options pattern. This makes the method accept
a parametric set of values implementing a common interface, so future
problems of similar nature can also be added as options.

I then replaced all the local anonymous limiter.Limiters made inline
within the API route construction with each corresponding newly added
field within the LimiterOptions struct.
  • Loading branch information
cstockton authored Oct 10, 2024
1 parent 7f006b6 commit 0a3968b
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 101 deletions.
8 changes: 6 additions & 2 deletions cmd/serve_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ func serve(ctx context.Context) {
addr := net.JoinHostPort(config.API.Host, config.API.Port)
logrus.Infof("GoTrue API started on: %s", addr)

a := api.NewAPIWithVersion(config, db, utilities.Version)
opts := []api.Option{
api.NewLimiterOptions(config),
}
a := api.NewAPIWithVersion(config, db, utilities.Version, opts...)
ah := reloader.NewAtomicHandler(a)

baseCtx, baseCancel := context.WithCancel(context.Background())
Expand All @@ -74,7 +77,8 @@ func serve(ctx context.Context) {

fn := func(latestCfg *conf.GlobalConfiguration) {
log.Info("reloading api with new configuration")
latestAPI := api.NewAPIWithVersion(latestCfg, db, utilities.Version)
latestAPI := api.NewAPIWithVersion(
latestCfg, db, utilities.Version, opts...)
ah.Store(latestAPI)
}

Expand Down
121 changes: 40 additions & 81 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"regexp"
"time"

"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
"github.com/rs/cors"
"github.com/sebest/xff"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -37,6 +35,8 @@ type API struct {

// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
overrideTime func() time.Time

limiterOpts *LimiterOptions
}

func (a *API) Now() time.Time {
Expand All @@ -48,8 +48,8 @@ func (a *API) Now() time.Time {
}

// NewAPI instantiates a new REST API
func NewAPI(globalConfig *conf.GlobalConfiguration, db *storage.Connection) *API {
return NewAPIWithVersion(globalConfig, db, defaultVersion)
func NewAPI(globalConfig *conf.GlobalConfiguration, db *storage.Connection, opt ...Option) *API {
return NewAPIWithVersion(globalConfig, db, defaultVersion, opt...)
}

func (a *API) deprecationNotices() {
Expand All @@ -67,9 +67,15 @@ func (a *API) deprecationNotices() {
}

// NewAPIWithVersion creates a new REST API using the specified version
func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string) *API {
func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string, opt ...Option) *API {
api := &API{config: globalConfig, db: db, version: version}

for _, o := range opt {
o.apply(api)
}
if api.limiterOpts == nil {
api.limiterOpts = NewLimiterOptions(globalConfig)
}
if api.config.Password.HIBP.Enabled {
httpClient := &http.Client{
// all HIBP API requests should finish quickly to avoid
Expand Down Expand Up @@ -134,18 +140,12 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.Get("/authorize", api.ExternalProviderRedirect)

sharedLimiter := api.limitEmailOrPhoneSentHandler()
sharedLimiter := api.limitEmailOrPhoneSentHandler(api.limiterOpts)
r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite)
r.With(sharedLimiter).With(api.verifyCaptcha).Route("/signup", func(r *router) {
// rate limit per hour
limitAnonymousSignIns := tollbooth.NewLimiter(api.config.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(api.config.RateLimitAnonymousUsers)).SetMethods([]string{"POST"})

limitSignups := tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

limitAnonymousSignIns := api.limiterOpts.AnonymousSignIns
limitSignups := api.limiterOpts.Signups
r.Post("/", func(w http.ResponseWriter, r *http.Request) error {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
Expand All @@ -172,47 +172,22 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
return api.Signup(w, r)
})
})
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.RateLimitTokenRefresh/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(api.verifyCaptcha).Post("/token", api.Token)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.RateLimitVerify/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).Route("/verify", func(r *router) {
r.With(api.limitHandler(api.limiterOpts.Recover)).
With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)

r.With(api.limitHandler(api.limiterOpts.Resend)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend)

r.With(api.limitHandler(api.limiterOpts.MagicLink)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)

r.With(api.limitHandler(api.limiterOpts.Otp)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp)

r.With(api.limitHandler(api.limiterOpts.Token)).
With(api.verifyCaptcha).Post("/token", api.Token)

r.With(api.limitHandler(api.limiterOpts.Verify)).Route("/verify", func(r *router) {
r.Get("/", api.Verify)
r.Post("/", api.Verify)
})
Expand All @@ -225,12 +200,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.With(api.requireAuthentication).Route("/user", func(r *router) {
r.Get("/", api.UserGet)
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).Put("/", api.UserUpdate)
r.With(api.limitHandler(api.limiterOpts.User)).
With(sharedLimiter).Put("/", api.UserUpdate)

r.Route("/identities", func(r *router) {
r.Use(api.requireManualLinkingEnabled)
Expand All @@ -245,37 +216,25 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
r.Route("/{factor_id}", func(r *router) {
r.Use(api.loadFactor)

r.With(api.limitHandler(
tollbooth.NewLimiter(api.config.MFA.RateLimitChallengeAndVerify/60, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30))).Post("/verify", api.VerifyFactor)
r.With(api.limitHandler(
tollbooth.NewLimiter(api.config.MFA.RateLimitChallengeAndVerify/60, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30))).Post("/challenge", api.ChallengeFactor)
r.With(api.limitHandler(api.limiterOpts.FactorVerify)).
Post("/verify", api.VerifyFactor)
r.With(api.limitHandler(api.limiterOpts.FactorChallenge)).
Post("/challenge", api.ChallengeFactor)
r.Delete("/", api.UnenrollFactor)

})
})

r.Route("/sso", func(r *router) {
r.Use(api.requireSAMLEnabled)
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.RateLimitSso/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(api.verifyCaptcha).Post("/", api.SingleSignOn)
r.With(api.limitHandler(api.limiterOpts.SSO)).
With(api.verifyCaptcha).Post("/", api.SingleSignOn)

r.Route("/saml", func(r *router) {
r.Get("/metadata", api.SAMLMetadata)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.SAML.RateLimitAssertion/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).Post("/acs", api.SamlAcs)
r.With(api.limitHandler(api.limiterOpts.SAMLAssertion)).
Post("/acs", api.SamlAcs)
})
})

Expand Down
3 changes: 2 additions & 1 deletion internal/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func setupAPIForTestWithCallback(cb func(*conf.GlobalConfiguration, *storage.Con
cb(nil, conn)
}

return NewAPIWithVersion(config, conn, apiTestVersion), config, nil
limiterOpts := NewLimiterOptions(config)
return NewAPIWithVersion(config, conn, apiTestVersion, limiterOpts), config, nil
}

func TestEmailEnabledByDefault(t *testing.T) {
Expand Down
18 changes: 3 additions & 15 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,7 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
}
}

func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {
// limit per hour
emailFreq := a.config.RateLimitEmailSent / (60 * 60)
smsFreq := a.config.RateLimitSmsSent / (60 * 60)

emailLimiter := tollbooth.NewLimiter(emailFreq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(a.config.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"})

phoneLimiter := tollbooth.NewLimiter(smsFreq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(a.config.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"})

func (a *API) limitEmailOrPhoneSentHandler(limiterOptions *LimiterOptions) middlewareHandler {
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
c := req.Context()
config := a.config
Expand All @@ -100,8 +88,8 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {
if req.Method == "PUT" || req.Method == "POST" {
// store rate limiter in request context
c = withLimiter(c, &SharedLimiter{
EmailLimiter: emailLimiter,
PhoneLimiter: phoneLimiter,
EmailLimiter: limiterOptions.Email,
PhoneLimiter: limiterOptions.Phone,
})
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() {
},
}

limiter := ts.API.limitEmailOrPhoneSentHandler()
limiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config))
for _, c := range cases {
ts.Run(c.desc, func() {
var buffer bytes.Buffer
Expand Down Expand Up @@ -484,7 +484,7 @@ func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() {
ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent
ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent
lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig))
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler()
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config))

// get the minimum amount to reach the threshold just before the rate limit is exceeded
threshold := min(c.sharedLimiterConfig.RateLimitEmailSent, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig)
Expand Down
107 changes: 107 additions & 0 deletions internal/api/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package api

import (
"time"

"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
"github.com/supabase/auth/internal/conf"
)

type Option interface {
apply(*API)
}

type LimiterOptions struct {
Email *limiter.Limiter
Phone *limiter.Limiter
Signups *limiter.Limiter
AnonymousSignIns *limiter.Limiter
Recover *limiter.Limiter
Resend *limiter.Limiter
MagicLink *limiter.Limiter
Otp *limiter.Limiter
Token *limiter.Limiter
Verify *limiter.Limiter
User *limiter.Limiter
FactorVerify *limiter.Limiter
FactorChallenge *limiter.Limiter
SSO *limiter.Limiter
SAMLAssertion *limiter.Limiter
}

func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo }

func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions {
o := &LimiterOptions{}

o.Email = tollbooth.NewLimiter(gc.RateLimitEmailSent/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"})

o.Phone = tollbooth.NewLimiter(gc.RateLimitSmsSent/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"})

o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitAnonymousUsers)).SetMethods([]string{"POST"})

o.Token = tollbooth.NewLimiter(gc.RateLimitTokenRefresh/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.Verify = tollbooth.NewLimiter(gc.RateLimitVerify/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.User = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.FactorVerify = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60,
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30)

o.FactorChallenge = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60,
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30)

o.SSO = tollbooth.NewLimiter(gc.RateLimitSso/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.SAMLAssertion = tollbooth.NewLimiter(gc.SAML.RateLimitAssertion/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.Signups = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

// These all use the OTP limit per 5 min with 1hour ttl and burst of 30.
o.Recover = newLimiterPer5mOver1h(gc.RateLimitOtp)
o.Resend = newLimiterPer5mOver1h(gc.RateLimitOtp)
o.MagicLink = newLimiterPer5mOver1h(gc.RateLimitOtp)
o.Otp = newLimiterPer5mOver1h(gc.RateLimitOtp)
return o
}

func newLimiterPer5mOver1h(rate float64) *limiter.Limiter {
freq := rate / (60 * 5)
lim := tollbooth.NewLimiter(freq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)
return lim
}

0 comments on commit 0a3968b

Please sign in to comment.