Skip to content

Commit

Permalink
Merge pull request #93 from ulule/feat-excluded-key
Browse files Browse the repository at this point in the history
Add excluded keys for stdlib and fasthttp middlewares
  • Loading branch information
novln authored Apr 7, 2020
2 parents c3ad534 + 8e8b502 commit cd35126
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 10 deletions.
7 changes: 7 additions & 0 deletions drivers/middleware/fasthttp/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type Middleware struct {
OnError ErrorHandler
OnLimitReached LimitReachedHandler
KeyGetter KeyGetter
ExcludedKey func(string) bool
}

// NewMiddleware return a new instance of a fasthttp middleware.
Expand All @@ -21,6 +22,7 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware {
OnError: DefaultErrorHandler,
OnLimitReached: DefaultLimitReachedHandler,
KeyGetter: DefaultKeyGetter,
ExcludedKey: nil,
}

for _, option := range options {
Expand All @@ -34,6 +36,11 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware {
func (middleware *Middleware) Handle(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
key := middleware.KeyGetter(ctx)
if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) {
next(ctx)
return
}

context, err := middleware.Limiter.Get(ctx, key)
if err != nil {
middleware.OnError(ctx, err)
Expand Down
44 changes: 44 additions & 0 deletions drivers/middleware/fasthttp/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/ulule/limiter/v3/drivers/store/memory"
)

// nolint: gocyclo
func TestFasthttpMiddleware(t *testing.T) {
is := require.New(t)

Expand Down Expand Up @@ -133,6 +134,49 @@ func TestFasthttpMiddleware(t *testing.T) {
is.NoError(err)
is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10))
}

//
// Test ExcludedKey
//

store = memory.NewStore()
is.NotZero(store)

counter = int64(0)
keyGetterHandler := func(c *libfasthttp.RequestCtx) string {
v := atomic.AddInt64(&counter, 1)
return strconv.FormatInt(v%2, 10)
}
excludedKeyHandler := func(key string) bool {
return key == "1"
}

middleware = fasthttp.NewMiddleware(limiter.New(store, rate),
fasthttp.WithKeyGetter(keyGetterHandler), fasthttp.WithExcludedKey(excludedKeyHandler))
is.NotZero(middleware)

requestHandler = func(ctx *libfasthttp.RequestCtx) {
switch string(ctx.Path()) {
case "/":
ctx.SetStatusCode(libfasthttp.StatusOK)
ctx.SetBodyString("hello")
}
}

success = 20
for i := int64(1); i <= clients; i++ {
resp := libfasthttp.AcquireResponse()
req := libfasthttp.AcquireRequest()
req.Header.SetHost("localhost:8081")
req.Header.SetRequestURI("/")
err := serve(middleware.Handle(requestHandler), req, resp)
is.NoError(err)
if i <= success || i%2 == 1 {
is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10))
} else {
is.Equal(libfasthttp.StatusTooManyRequests, resp.StatusCode(), strconv.FormatInt(i, 10))
}
}
}

func serve(handler libfasthttp.RequestHandler, req *libfasthttp.Request, res *libfasthttp.Response) error {
Expand Down
7 changes: 7 additions & 0 deletions drivers/middleware/fasthttp/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,10 @@ func WithKeyGetter(KeyGetter KeyGetter) Option {
func DefaultKeyGetter(ctx *fasthttp.RequestCtx) string {
return ctx.RemoteIP().String()
}

// WithExcludedKey will configure the Middleware to ignore key(s) using the given function.
func WithExcludedKey(handler func(string) bool) Option {
return option(func(middleware *Middleware) {
middleware.ExcludedKey = handler
})
}
1 change: 1 addition & 0 deletions drivers/middleware/gin/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) gin.HandlerFunc
OnError: DefaultErrorHandler,
OnLimitReached: DefaultLimitReachedHandler,
KeyGetter: DefaultKeyGetter,
ExcludedKey: nil,
}

for _, option := range options {
Expand Down
10 changes: 5 additions & 5 deletions drivers/middleware/gin/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ func DefaultLimitReachedHandler(c *gin.Context) {
type KeyGetter func(c *gin.Context) string

// WithKeyGetter will configure the Middleware to use the given KeyGetter.
func WithKeyGetter(KeyGetter KeyGetter) Option {
func WithKeyGetter(handler KeyGetter) Option {
return option(func(middleware *Middleware) {
middleware.KeyGetter = KeyGetter
middleware.KeyGetter = handler
})
}

Expand All @@ -63,9 +63,9 @@ func DefaultKeyGetter(c *gin.Context) string {
return c.ClientIP()
}

// WithExcludedKey will configure the Middleware to use the given function.
func WithExcludedKey(fn func(string) bool) Option {
// WithExcludedKey will configure the Middleware to ignore key(s) using the given function.
func WithExcludedKey(handler func(string) bool) Option {
return option(func(middleware *Middleware) {
middleware.ExcludedKey = fn
middleware.ExcludedKey = handler
})
}
17 changes: 12 additions & 5 deletions drivers/middleware/stdlib/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (

// Middleware is the middleware for basic http.Handler.
type Middleware struct {
Limiter *limiter.Limiter
OnError ErrorHandler
OnLimitReached LimitReachedHandler
TrustForwardHeader bool
Limiter *limiter.Limiter
OnError ErrorHandler
OnLimitReached LimitReachedHandler
ExcludedKey func(string) bool
}

// NewMiddleware return a new instance of a basic HTTP middleware.
Expand All @@ -21,6 +21,7 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware {
Limiter: limiter,
OnError: DefaultErrorHandler,
OnLimitReached: DefaultLimitReachedHandler,
ExcludedKey: nil,
}

for _, option := range options {
Expand All @@ -33,7 +34,13 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware {
// Handler handles a HTTP request.
func (middleware *Middleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
context, err := middleware.Limiter.Get(r.Context(), middleware.Limiter.GetIPKey(r))
key := middleware.Limiter.GetIPKey(r)
if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) {
h.ServeHTTP(w, r)
return
}

context, err := middleware.Limiter.Get(r.Context(), key)
if err != nil {
middleware.OnError(w, r, err)
return
Expand Down
7 changes: 7 additions & 0 deletions drivers/middleware/stdlib/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,10 @@ func WithLimitReachedHandler(handler LimitReachedHandler) Option {
func DefaultLimitReachedHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Limit exceeded", http.StatusTooManyRequests)
}

// WithExcludedKey will configure the Middleware to ignore key(s) using the given function.
func WithExcludedKey(handler func(string) bool) Option {
return option(func(middleware *Middleware) {
middleware.ExcludedKey = handler
})
}

0 comments on commit cd35126

Please sign in to comment.