Skip to content

Commit

Permalink
feat: rate limiter persistence
Browse files Browse the repository at this point in the history
The goal with this change is to persist the rate limiters across
reloads with the least invasive changes possible. To do this I
first added a LimiterOptions structure to hold the current set
of rate limiters. I then identified each call of tollbooth.New
and moved the construction of the limiter without modifications
to options.go. I assigned each limiter a distinct field to be
referenced during the API object creation.

Next I needed to add an optional parameter to the NewAPI methods
so I could store the new LimiterOptions onto the API object for
reference during route construction. To do this without breaking
all existing calls to NewAPI I used the options pattern. This
makes the method accept a parametric set of values implementing
a common interface, so future problems of similar nature can
also be added as options.

I then replaced all the local anonymous limiter.Limiters made
inline within the API route construction with each corresponding
newly added field within the LimiterOptions struct.
  • Loading branch information
Chris Stockton committed Oct 8, 2024
1 parent 56e3d33 commit 211df61
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 101 deletions.
8 changes: 6 additions & 2 deletions cmd/serve_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ func serve(ctx context.Context) {
addr := net.JoinHostPort(config.API.Host, config.API.Port)
logrus.Infof("GoTrue API started on: %s", addr)

a := api.NewAPIWithVersion(config, db, utilities.Version)
opts := []api.Option{
api.NewLimiterOptions(config),
}
a := api.NewAPIWithVersion(config, db, utilities.Version, opts...)
ah := reloader.NewAtomicHandler(a)

baseCtx, baseCancel := context.WithCancel(context.Background())
Expand All @@ -74,7 +77,8 @@ func serve(ctx context.Context) {

fn := func(latestCfg *conf.GlobalConfiguration) {
log.Info("reloading api with new configuration")
latestAPI := api.NewAPIWithVersion(latestCfg, db, utilities.Version)
latestAPI := api.NewAPIWithVersion(
latestCfg, db, utilities.Version, opts...)
ah.Store(latestAPI)
}

Expand Down
121 changes: 40 additions & 81 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"regexp"
"time"

"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
"github.com/rs/cors"
"github.com/sebest/xff"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -37,6 +35,8 @@ type API struct {

// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
overrideTime func() time.Time

limiterOpts *LimiterOptions
}

func (a *API) Now() time.Time {
Expand All @@ -48,8 +48,8 @@ func (a *API) Now() time.Time {
}

// NewAPI instantiates a new REST API
func NewAPI(globalConfig *conf.GlobalConfiguration, db *storage.Connection) *API {
return NewAPIWithVersion(globalConfig, db, defaultVersion)
func NewAPI(globalConfig *conf.GlobalConfiguration, db *storage.Connection, opt ...Option) *API {
return NewAPIWithVersion(globalConfig, db, defaultVersion, opt...)
}

func (a *API) deprecationNotices() {
Expand All @@ -67,9 +67,15 @@ func (a *API) deprecationNotices() {
}

// NewAPIWithVersion creates a new REST API using the specified version
func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string) *API {
func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string, opt ...Option) *API {
api := &API{config: globalConfig, db: db, version: version}

for _, o := range opt {
o.apply(api)
}
if api.limiterOpts == nil {
api.limiterOpts = NewLimiterOptions(globalConfig)
}
if api.config.Password.HIBP.Enabled {
httpClient := &http.Client{
// all HIBP API requests should finish quickly to avoid
Expand Down Expand Up @@ -134,18 +140,12 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.Get("/authorize", api.ExternalProviderRedirect)

sharedLimiter := api.limitEmailOrPhoneSentHandler()
sharedLimiter := api.limitEmailOrPhoneSentHandler(api.limiterOpts)
r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite)
r.With(sharedLimiter).With(api.verifyCaptcha).Route("/signup", func(r *router) {
// rate limit per hour
limitAnonymousSignIns := tollbooth.NewLimiter(api.config.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(api.config.RateLimitAnonymousUsers)).SetMethods([]string{"POST"})

limitSignups := tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

limitAnonymousSignIns := api.limiterOpts.AnonymousSignIns
limitSignups := api.limiterOpts.Signups
r.Post("/", func(w http.ResponseWriter, r *http.Request) error {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
Expand All @@ -172,47 +172,22 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
return api.Signup(w, r)
})
})
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.RateLimitTokenRefresh/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(api.verifyCaptcha).Post("/token", api.Token)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.RateLimitVerify/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).Route("/verify", func(r *router) {
r.With(api.limitHandler(api.limiterOpts.Recover)).
With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)

r.With(api.limitHandler(api.limiterOpts.Resend)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend)

r.With(api.limitHandler(api.limiterOpts.MagicLink)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)

r.With(api.limitHandler(api.limiterOpts.Otp)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp)

r.With(api.limitHandler(api.limiterOpts.Token)).
With(api.verifyCaptcha).Post("/token", api.Token)

r.With(api.limitHandler(api.limiterOpts.Verify)).Route("/verify", func(r *router) {
r.Get("/", api.Verify)
r.Post("/", api.Verify)
})
Expand All @@ -225,12 +200,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.With(api.requireAuthentication).Route("/user", func(r *router) {
r.Get("/", api.UserGet)
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).Put("/", api.UserUpdate)
r.With(api.limitHandler(api.limiterOpts.User)).
With(sharedLimiter).Put("/", api.UserUpdate)

r.Route("/identities", func(r *router) {
r.Use(api.requireManualLinkingEnabled)
Expand All @@ -245,37 +216,25 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
r.Route("/{factor_id}", func(r *router) {
r.Use(api.loadFactor)

r.With(api.limitHandler(
tollbooth.NewLimiter(api.config.MFA.RateLimitChallengeAndVerify/60, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30))).Post("/verify", api.VerifyFactor)
r.With(api.limitHandler(
tollbooth.NewLimiter(api.config.MFA.RateLimitChallengeAndVerify/60, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30))).Post("/challenge", api.ChallengeFactor)
r.With(api.limitHandler(api.limiterOpts.FactorVerify)).
Post("/verify", api.VerifyFactor)
r.With(api.limitHandler(api.limiterOpts.FactorChallenge)).
Post("/challenge", api.ChallengeFactor)
r.Delete("/", api.UnenrollFactor)

})
})

r.Route("/sso", func(r *router) {
r.Use(api.requireSAMLEnabled)
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.RateLimitSso/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(api.verifyCaptcha).Post("/", api.SingleSignOn)
r.With(api.limitHandler(api.limiterOpts.SSO)).
With(api.verifyCaptcha).Post("/", api.SingleSignOn)

r.Route("/saml", func(r *router) {
r.Get("/metadata", api.SAMLMetadata)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.SAML.RateLimitAssertion/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).Post("/acs", api.SamlAcs)
r.With(api.limitHandler(api.limiterOpts.SAMLAssertion)).
Post("/acs", api.SamlAcs)
})
})

Expand Down
3 changes: 2 additions & 1 deletion internal/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func setupAPIForTestWithCallback(cb func(*conf.GlobalConfiguration, *storage.Con
cb(nil, conn)
}

return NewAPIWithVersion(config, conn, apiTestVersion), config, nil
limiterOpts := NewLimiterOptions(config)
return NewAPIWithVersion(config, conn, apiTestVersion, limiterOpts), config, nil
}

func TestEmailEnabledByDefault(t *testing.T) {
Expand Down
18 changes: 3 additions & 15 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,7 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
}
}

func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {
// limit per hour
emailFreq := a.config.RateLimitEmailSent / (60 * 60)
smsFreq := a.config.RateLimitSmsSent / (60 * 60)

emailLimiter := tollbooth.NewLimiter(emailFreq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(a.config.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"})

phoneLimiter := tollbooth.NewLimiter(smsFreq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(a.config.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"})

func (a *API) limitEmailOrPhoneSentHandler(limiterOptions *LimiterOptions) middlewareHandler {
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
c := req.Context()
config := a.config
Expand All @@ -100,8 +88,8 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {
if req.Method == "PUT" || req.Method == "POST" {
// store rate limiter in request context
c = withLimiter(c, &SharedLimiter{
EmailLimiter: emailLimiter,
PhoneLimiter: phoneLimiter,
EmailLimiter: limiterOptions.Email,
PhoneLimiter: limiterOptions.Phone,
})
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() {
},
}

limiter := ts.API.limitEmailOrPhoneSentHandler()
limiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config))
for _, c := range cases {
ts.Run(c.desc, func() {
var buffer bytes.Buffer
Expand Down Expand Up @@ -484,7 +484,7 @@ func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() {
ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent
ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent
lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig))
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler()
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config))

// get the minimum amount to reach the threshold just before the rate limit is exceeded
threshold := min(c.sharedLimiterConfig.RateLimitEmailSent, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig)
Expand Down
111 changes: 111 additions & 0 deletions internal/api/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package api

import (
"time"

"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
"github.com/supabase/auth/internal/conf"
)

type Option interface {
apply(*API)
}

type optionFunc func(*API)

Check failure on line 15 in internal/api/options.go

View workflow job for this annotation

GitHub Actions / test (1.22.x)

type optionFunc is unused (U1000)

func (fn optionFunc) apply(a *API) { fn(a) }

Check failure on line 17 in internal/api/options.go

View workflow job for this annotation

GitHub Actions / test (1.22.x)

func optionFunc.apply is unused (U1000)

type LimiterOptions struct {
Email *limiter.Limiter
Phone *limiter.Limiter
Signups *limiter.Limiter
AnonymousSignIns *limiter.Limiter
Recover *limiter.Limiter
Resend *limiter.Limiter
MagicLink *limiter.Limiter
Otp *limiter.Limiter
Token *limiter.Limiter
Verify *limiter.Limiter
User *limiter.Limiter
FactorVerify *limiter.Limiter
FactorChallenge *limiter.Limiter
SSO *limiter.Limiter
SAMLAssertion *limiter.Limiter
}

func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo }

func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions {
o := &LimiterOptions{}

o.Email = tollbooth.NewLimiter(gc.RateLimitEmailSent/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"})

o.Phone = tollbooth.NewLimiter(gc.RateLimitSmsSent/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"})

o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitAnonymousUsers)).SetMethods([]string{"POST"})

o.Token = tollbooth.NewLimiter(gc.RateLimitTokenRefresh/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.Verify = tollbooth.NewLimiter(gc.RateLimitVerify/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.User = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.FactorVerify = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60,
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30)

o.FactorChallenge = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60,
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30)

o.SSO = tollbooth.NewLimiter(gc.RateLimitSso/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.SAMLAssertion = tollbooth.NewLimiter(gc.SAML.RateLimitAssertion/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.Signups = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

// These all use the OTP limit per 5 min with 1hour ttl and burst of 30.
o.Recover = newLimiterPer5mOver1h(gc.RateLimitOtp)
o.Resend = newLimiterPer5mOver1h(gc.RateLimitOtp)
o.MagicLink = newLimiterPer5mOver1h(gc.RateLimitOtp)
o.Otp = newLimiterPer5mOver1h(gc.RateLimitOtp)
return o
}

func newLimiterPer5mOver1h(rate float64) *limiter.Limiter {
freq := rate / (60 * 5)
lim := tollbooth.NewLimiter(freq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)
return lim
}

0 comments on commit 211df61

Please sign in to comment.