diff --git a/drivers/middleware/fasthttp/middleware_test.go b/drivers/middleware/fasthttp/middleware_test.go index ab821f8..5172dbc 100644 --- a/drivers/middleware/fasthttp/middleware_test.go +++ b/drivers/middleware/fasthttp/middleware_test.go @@ -133,6 +133,49 @@ func TestFasthttpMiddleware(t *testing.T) { is.NoError(err) is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10)) } + + // + // Test ExcludedKey + // + + store = memory.NewStore() + is.NotZero(store) + + counter = int64(0) + keyGetterHandler := func(c *libfasthttp.RequestCtx) string { + v := atomic.AddInt64(&counter, 1) + return strconv.FormatInt(v%2, 10) + } + excludedKeyHandler := func(key string) bool { + return key == "1" + } + + middleware = fasthttp.NewMiddleware(limiter.New(store, rate), + fasthttp.WithKeyGetter(keyGetterHandler), fasthttp.WithExcludedKey(excludedKeyHandler)) + is.NotZero(middleware) + + requestHandler = func(ctx *libfasthttp.RequestCtx) { + switch string(ctx.Path()) { + case "/": + ctx.SetStatusCode(libfasthttp.StatusOK) + ctx.SetBodyString("hello") + } + } + + success = 20 + for i := int64(1); i <= clients; i++ { + resp := libfasthttp.AcquireResponse() + req := libfasthttp.AcquireRequest() + req.Header.SetHost("localhost:8081") + req.Header.SetRequestURI("/") + err := serve(middleware.Handle(requestHandler), req, resp) + is.NoError(err) + if i <= success || i%2 == 1 { + is.Equal(libfasthttp.StatusOK, resp.StatusCode(), strconv.FormatInt(i, 10)) + } else { + is.Equal(libfasthttp.StatusTooManyRequests, resp.StatusCode(), strconv.FormatInt(i, 10)) + } + } } func serve(handler libfasthttp.RequestHandler, req *libfasthttp.Request, res *libfasthttp.Response) error {