diff --git a/drivers/middleware/fasthttp/middleware.go b/drivers/middleware/fasthttp/middleware.go index e80204e..156941e 100644 --- a/drivers/middleware/fasthttp/middleware.go +++ b/drivers/middleware/fasthttp/middleware.go @@ -12,6 +12,7 @@ type Middleware struct { OnError ErrorHandler OnLimitReached LimitReachedHandler KeyGetter KeyGetter + ExcludedKey func(string) bool } // NewMiddleware return a new instance of a fasthttp middleware. @@ -21,6 +22,7 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware { OnError: DefaultErrorHandler, OnLimitReached: DefaultLimitReachedHandler, KeyGetter: DefaultKeyGetter, + ExcludedKey: nil, } for _, option := range options { @@ -34,6 +36,11 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware { func (middleware *Middleware) Handle(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { key := middleware.KeyGetter(ctx) + if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) { + next(ctx) + return + } + context, err := middleware.Limiter.Get(ctx, key) if err != nil { middleware.OnError(ctx, err) diff --git a/drivers/middleware/fasthttp/middleware_test.go b/drivers/middleware/fasthttp/middleware_test.go index ab821f8..bb1fff6 100644 --- a/drivers/middleware/fasthttp/middleware_test.go +++ b/drivers/middleware/fasthttp/middleware_test.go @@ -16,6 +16,7 @@ import ( "github.com/ulule/limiter/v3/drivers/store/memory" ) +// nolint: gocyclo func TestFasthttpMiddleware(t *testing.T) { is := require.New(t) @@ -133,6 +134,49 @@ func TestFasthttpMiddleware(t *testing.T) { is.NoError(err) is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10)) } + + // + // Test ExcludedKey + // + + store = memory.NewStore() + is.NotZero(store) + + counter = int64(0) + keyGetterHandler := func(c *libfasthttp.RequestCtx) string { + v := atomic.AddInt64(&counter, 1) + return strconv.FormatInt(v%2, 10) + } + excludedKeyHandler := func(key string) bool { + return key == "1" + } + + middleware = fasthttp.NewMiddleware(limiter.New(store, rate), + fasthttp.WithKeyGetter(keyGetterHandler), fasthttp.WithExcludedKey(excludedKeyHandler)) + is.NotZero(middleware) + + requestHandler = func(ctx *libfasthttp.RequestCtx) { + switch string(ctx.Path()) { + case "/": + ctx.SetStatusCode(libfasthttp.StatusOK) + ctx.SetBodyString("hello") + } + } + + success = 20 + for i := int64(1); i <= clients; i++ { + resp := libfasthttp.AcquireResponse() + req := libfasthttp.AcquireRequest() + req.Header.SetHost("localhost:8081") + req.Header.SetRequestURI("/") + err := serve(middleware.Handle(requestHandler), req, resp) + is.NoError(err) + if i <= success || i%2 == 1 { + is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10)) + } else { + is.Equal(libfasthttp.StatusTooManyRequests, resp.StatusCode(), strconv.FormatInt(i, 10)) + } + } } func serve(handler libfasthttp.RequestHandler, req *libfasthttp.Request, res *libfasthttp.Response) error { diff --git a/drivers/middleware/fasthttp/options.go b/drivers/middleware/fasthttp/options.go index eb319f7..65382f2 100644 --- a/drivers/middleware/fasthttp/options.go +++ b/drivers/middleware/fasthttp/options.go @@ -61,3 +61,10 @@ func WithKeyGetter(KeyGetter KeyGetter) Option { func DefaultKeyGetter(ctx *fasthttp.RequestCtx) string { return ctx.RemoteIP().String() } + +// WithExcludedKey will configure the Middleware to ignore key(s) using the given function. +func WithExcludedKey(handler func(string) bool) Option { + return option(func(middleware *Middleware) { + middleware.ExcludedKey = handler + }) +} diff --git a/drivers/middleware/gin/middleware.go b/drivers/middleware/gin/middleware.go index fafb1a2..23bad41 100644 --- a/drivers/middleware/gin/middleware.go +++ b/drivers/middleware/gin/middleware.go @@ -24,6 +24,7 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) gin.HandlerFunc OnError: DefaultErrorHandler, OnLimitReached: DefaultLimitReachedHandler, KeyGetter: DefaultKeyGetter, + ExcludedKey: nil, } for _, option := range options { diff --git a/drivers/middleware/gin/options.go b/drivers/middleware/gin/options.go index 944157f..604c6bc 100644 --- a/drivers/middleware/gin/options.go +++ b/drivers/middleware/gin/options.go @@ -51,9 +51,9 @@ func DefaultLimitReachedHandler(c *gin.Context) { type KeyGetter func(c *gin.Context) string // WithKeyGetter will configure the Middleware to use the given KeyGetter. -func WithKeyGetter(KeyGetter KeyGetter) Option { +func WithKeyGetter(handler KeyGetter) Option { return option(func(middleware *Middleware) { - middleware.KeyGetter = KeyGetter + middleware.KeyGetter = handler }) } @@ -63,9 +63,9 @@ func DefaultKeyGetter(c *gin.Context) string { return c.ClientIP() } -// WithExcludedKey will configure the Middleware to use the given function. -func WithExcludedKey(fn func(string) bool) Option { +// WithExcludedKey will configure the Middleware to ignore key(s) using the given function. +func WithExcludedKey(handler func(string) bool) Option { return option(func(middleware *Middleware) { - middleware.ExcludedKey = fn + middleware.ExcludedKey = handler }) } diff --git a/drivers/middleware/stdlib/middleware.go b/drivers/middleware/stdlib/middleware.go index 0377cac..fa21d61 100644 --- a/drivers/middleware/stdlib/middleware.go +++ b/drivers/middleware/stdlib/middleware.go @@ -9,10 +9,10 @@ import ( // Middleware is the middleware for basic http.Handler. type Middleware struct { - Limiter *limiter.Limiter - OnError ErrorHandler - OnLimitReached LimitReachedHandler - TrustForwardHeader bool + Limiter *limiter.Limiter + OnError ErrorHandler + OnLimitReached LimitReachedHandler + ExcludedKey func(string) bool } // NewMiddleware return a new instance of a basic HTTP middleware. @@ -21,6 +21,7 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware { Limiter: limiter, OnError: DefaultErrorHandler, OnLimitReached: DefaultLimitReachedHandler, + ExcludedKey: nil, } for _, option := range options { @@ -33,7 +34,13 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware { // Handler handles a HTTP request. func (middleware *Middleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - context, err := middleware.Limiter.Get(r.Context(), middleware.Limiter.GetIPKey(r)) + key := middleware.Limiter.GetIPKey(r) + if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) { + h.ServeHTTP(w, r) + return + } + + context, err := middleware.Limiter.Get(r.Context(), key) if err != nil { middleware.OnError(w, r, err) return diff --git a/drivers/middleware/stdlib/options.go b/drivers/middleware/stdlib/options.go index fd99eaa..fd724cd 100644 --- a/drivers/middleware/stdlib/options.go +++ b/drivers/middleware/stdlib/options.go @@ -44,3 +44,10 @@ func WithLimitReachedHandler(handler LimitReachedHandler) Option { func DefaultLimitReachedHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Limit exceeded", http.StatusTooManyRequests) } + +// WithExcludedKey will configure the Middleware to ignore key(s) using the given function. +func WithExcludedKey(handler func(string) bool) Option { + return option(func(middleware *Middleware) { + middleware.ExcludedKey = handler + }) +}