diff --git a/README.md b/README.md index 2963044..1d8974f 100644 --- a/README.md +++ b/README.md @@ -78,36 +78,64 @@ r.Use(httprate.Limit( )) ``` -### Send specific response for rate limited requests +### Rate limit by request payload +```go +// Rate-limiter for login endpoint. +loginRateLimiter := httprate.NewRateLimiter(5, time.Minute) + +r.Post("/login", func(w http.ResponseWriter, r *http.Request) { + var payload struct { + Username string `json:"username"` + Password string `json:"password"` + } + err := json.NewDecoder(r.Body).Decode(&payload) + if err != nil || payload.Username == "" || payload.Password == "" { + w.WriteHeader(400) + return + } + + // Rate-limit login at 5 req/min. + if loginRateLimiter.OnLimit(w, r, payload.Username) { + return + } + + w.Write([]byte("login at 5 req/min\n")) +}) +``` + +### Send specific response for rate-limited requests + +The default response is `HTTP 429` with `Too Many Requests` body. You can override it with: ```go r.Use(httprate.Limit( 10, time.Minute, httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, `{"error": "Rate limited. Please slow down."}`, http.StatusTooManyRequests) + http.Error(w, `{"error": "Rate-limited. Please, slow down."}`, http.StatusTooManyRequests) }), )) ``` -### Send specific response for backend errors +### Send specific response on errors + +An error can be returned by: +- A custom key function provided by `httprate.WithKeyFunc(customKeyFn)` +- A custom backend provided by `httprateredis.WithRedisLimitCounter(customBackend)` + - The default local in-memory counter is guaranteed not return any errors + - Backends that fall-back to the local in-memory counter (e.g. [httprate-redis](https://github.com/go-chi/httprate-redis)) can choose not to return any errors either ```go r.Use(httprate.Limit( 10, time.Minute, httprate.WithErrorHandler(func(w http.ResponseWriter, r *http.Request, err error) { - // NOTE: The local in-memory counter is guaranteed not return any errors. - // Other backends may return errors, depending on whether they have - // in-memory fallback mechanism implemented in case of network errors. - http.Error(w, fmt.Sprintf(`{"error": %q}`, err), http.StatusPreconditionRequired) }), httprate.WithLimitCounter(customBackend), )) ``` - ### Send custom response headers ```go diff --git a/_example/main.go b/_example/main.go index 70ebb8c..cf69e0a 100644 --- a/_example/main.go +++ b/_example/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/json" "log" "net/http" "time" @@ -15,52 +16,59 @@ func main() { r := chi.NewRouter() r.Use(middleware.Logger) + // Rate-limit all routes at 1000 req/min by IP address. + r.Use(httprate.LimitByIP(1000, time.Minute)) + r.Route("/admin", func(r chi.Router) { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Note: this is a mock middleware to set a userID on the request context + // Note: This is a mock middleware to set a userID on the request context next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "userID", "123"))) }) }) - // Here we set a specific rate limit by ip address and userID + // Rate-limit admin routes at 10 req/s by userID. r.Use(httprate.Limit( - 10, - time.Minute, - httprate.WithKeyFuncs(httprate.KeyByIP, func(r *http.Request) (string, error) { - token := r.Context().Value("userID").(string) + 10, time.Second, + httprate.WithKeyFuncs(func(r *http.Request) (string, error) { + token, _ := r.Context().Value("userID").(string) return token, nil }), - httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { - // We can send custom responses for the rate limited requests, e.g. a JSON message - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) - w.Write([]byte(`{"error": "Too many requests"}`)) - }), )) r.Get("/", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("10 req/min\n")) + w.Write([]byte("admin at 10 req/s\n")) }) }) - r.Group(func(r chi.Router) { - // Here we set another rate limit (3 req/min) for a group of handlers. - // - // Note: in practice you don't need to have so many layered rate-limiters, - // but the example here is to illustrate how to control the machinery. - r.Use(httprate.LimitByIP(3, time.Minute)) + // Rate-limiter for login endpoint. + loginRateLimiter := httprate.NewRateLimiter(5, time.Minute) - r.Get("/", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("3 req/min\n")) - }) + r.Post("/login", func(w http.ResponseWriter, r *http.Request) { + var payload struct { + Username string `json:"username"` + Password string `json:"password"` + } + err := json.NewDecoder(r.Body).Decode(&payload) + if err != nil || payload.Username == "" || payload.Password == "" { + w.WriteHeader(400) + return + } + + // Rate-limit login at 5 req/min. + if loginRateLimiter.OnLimit(w, r, payload.Username) { + return + } + + w.Write([]byte("login at 5 req/min\n")) }) log.Printf("Serving at localhost:3333") log.Println() log.Printf("Try running:") - log.Printf("curl -v http://localhost:3333") - log.Printf("curl -v http://localhost:3333/admin") + log.Printf(`curl -v http://localhost:3333?[0-1000]`) + log.Printf(`curl -v http://localhost:3333/admin?[1-12]`) + log.Printf(`curl -v http://localhost:3333/login\?[1-8] --data '{"username":"alice","password":"***"}'`) http.ListenAndServe(":3333", r) } diff --git a/limiter.go b/limiter.go index c97f086..bf4023f 100644 --- a/limiter.go +++ b/limiter.go @@ -66,6 +66,56 @@ type rateLimiter struct { mu sync.Mutex } +// OnLimit checks the rate limit for the given key. If the limit is reached, it returns true +// and automatically sends HTTP response. The caller should halt further request processing. +// If the limit is not reached, it increments the request count and returns false, allowing +// the request to proceed. +func (l *rateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool { + currentWindow := time.Now().UTC().Truncate(l.windowLength) + ctx := r.Context() + + limit := l.requestLimit + if val := getRequestLimit(ctx); val > 0 { + limit = val + } + setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit)) + setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix())) + + l.mu.Lock() + _, rateFloat, err := l.calculateRate(key, limit) + if err != nil { + l.mu.Unlock() + l.onError(w, r, err) + return true + } + rate := int(math.Round(rateFloat)) + + increment := getIncrement(r.Context()) + if increment > 1 { + setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment)) + } + + if rate+increment > limit { + setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate)) + + l.mu.Unlock() + setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585 + l.onRateLimited(w, r) + return true + } + + err = l.limitCounter.IncrementBy(key, currentWindow, increment) + if err != nil { + l.mu.Unlock() + l.onError(w, r, err) + return true + } + l.mu.Unlock() + + setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate-increment)) + return false +} + func (l *rateLimiter) Counter() LimitCounter { return l.limitCounter } @@ -82,49 +132,10 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler { return } - currentWindow := time.Now().UTC().Truncate(l.windowLength) - ctx := r.Context() - - limit := l.requestLimit - if val := getRequestLimit(ctx); val > 0 { - limit = val - } - setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit)) - setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix())) - - l.mu.Lock() - _, rateFloat, err := l.calculateRate(key, limit) - if err != nil { - l.mu.Unlock() - l.onError(w, r, err) - return - } - rate := int(math.Round(rateFloat)) - - increment := getIncrement(r.Context()) - if increment > 1 { - setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment)) - } - - if rate+increment > limit { - setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate)) - - l.mu.Unlock() - setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585 - l.onRateLimited(w, r) + if l.OnLimit(w, r, key) { return } - err = l.limitCounter.IncrementBy(key, currentWindow, increment) - if err != nil { - l.mu.Unlock() - l.onError(w, r, err) - return - } - l.mu.Unlock() - - setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate-increment)) - next.ServeHTTP(w, r) }) } diff --git a/limiter_test.go b/limiter_test.go index bcbb938..689074a 100644 --- a/limiter_test.go +++ b/limiter_test.go @@ -3,6 +3,7 @@ package httprate_test import ( "bytes" "context" + "encoding/json" "io" "net/http" "net/http/httptest" @@ -437,3 +438,59 @@ func TestOverrideRequestLimit(t *testing.T) { } } } + +func TestRateLimitPayload(t *testing.T) { + loginRateLimiter := httprate.NewRateLimiter(5, time.Minute) + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var payload struct { + Username string `json:"username"` + Password string `json:"password"` + } + err := json.NewDecoder(r.Body).Decode(&payload) + if err != nil || payload.Username == "" || payload.Password == "" { + w.WriteHeader(400) + return + } + + // Rate-limit login at 5 req/min. + if loginRateLimiter.OnLimit(w, r, payload.Username) { + return + } + + w.Write([]byte("login at 5 req/min\n")) + }) + + responses := []struct { + StatusCode int + Body string + }{ + {StatusCode: 200, Body: "login at 5 req/min"}, + {StatusCode: 200, Body: "login at 5 req/min"}, + {StatusCode: 200, Body: "login at 5 req/min"}, + {StatusCode: 200, Body: "login at 5 req/min"}, + {StatusCode: 200, Body: "login at 5 req/min"}, + {StatusCode: 429, Body: "Too Many Requests"}, + {StatusCode: 429, Body: "Too Many Requests"}, + {StatusCode: 429, Body: "Too Many Requests"}, + } + for i, response := range responses { + req, err := http.NewRequest("GET", "/", strings.NewReader(`{"username":"alice","password":"***"}`)) + if err != nil { + t.Errorf("failed = %v", err) + } + + recorder := httptest.NewRecorder() + h.ServeHTTP(recorder, req) + result := recorder.Result() + if respStatus := result.StatusCode; respStatus != response.StatusCode { + t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respStatus, response.StatusCode) + } + body, _ := io.ReadAll(result.Body) + respBody := strings.TrimSuffix(string(body), "\n") + + if string(respBody) != response.Body { + t.Errorf("resp.Body(%v) = %q, want %q", i, respBody, response.Body) + } + } +}