diff --git a/README.md b/README.md index cf45648..47e348b 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,8 @@ * [Example](#example) * [With a Redis driver](#with-a-redis-driver) * [With a memory driver](#with-a-memory-driver) +* [Group-Based Rate Limiting (New Feature)](#group-based-rate-limiting-new-feature) + * [Parameters](#parameters) * [Options](#options) * [Responses](#responses) * [License](#license) @@ -118,6 +120,47 @@ app.listen( ``` +## Group-Based Rate Limiting (New Feature) + +You can apply rate limits by group using the new `groupRateLimit` middleware: + +```js +const Koa = require('koa'); +const { groupRateLimit } = require('koa-ratelimit'); + +const app = new Koa(); + +app.use( + groupRateLimit({ + routeMap: { + user: { + duration: 60000, + max: 10, + id: (ctx) => ctx.ip + }, + admin: { + duration: 60000, + max: 100, + id: (ctx) => ctx.ip + } + }, + groupBy: (ctx) => (ctx.path.startsWith('/admin') ? 'admin' : 'user') + }) +); + +app.use((ctx) => { + ctx.body = 'Hello, rate limited world!'; +}); +``` + +### Parameters + +* `routeMap`: Object mapping group names to individual `ratelimit` configurations. +* `groupBy(ctx)`: Function that returns a group name for the current request. + +This enables route-aware rate limiting logic for better control across app segments. + + ## Options * `driver` memory or redis \[redis] diff --git a/group-limiter.js b/group-limiter.js new file mode 100644 index 0000000..4d943f0 --- /dev/null +++ b/group-limiter.js @@ -0,0 +1,48 @@ +const ratelimit = require('./index'); + +/** + * Create group-based rate limiter middleware. + * + * @param {Object} options + * @param {Object} options.routeMap - Map of groupName -> rateLimitOptions + * @param {Function} options.groupBy - Function(ctx) => groupName + */ +module.exports = function (options) { + const { routeMap = {}, groupBy } = options; + + if (typeof groupBy !== 'function') { + throw new TypeError('groupBy must be a function'); + } + + // Assign shared db per group if using memory driver + const sharedDbMap = {}; + + for (const [groupName, config] of Object.entries(routeMap)) { + if (config.driver === 'memory' || !config.driver) { + sharedDbMap[groupName] = config.db || new Map(); + } + } + + return async function (ctx, next) { + try { + const group = groupBy(ctx); + const baseOpts = routeMap[group]; + + if (!baseOpts) { + return await next(); // skip rate limiting if group not found + } + + const groupOpts = { + driver: 'memory', + ...baseOpts, + db: sharedDbMap[group] // reuse persistent in-memory store + }; + + const rateLimitMiddleware = ratelimit(groupOpts); + return await rateLimitMiddleware(ctx, next); + } catch (err) { + console.error('groupLimiter error:', err); + ctx.throw(500, 'Group limiter internal error'); + } + }; +}; diff --git a/index.js b/index.js index 3d11178..0176069 100644 --- a/index.js +++ b/index.js @@ -126,3 +126,5 @@ module.exports = function ratelimit(opts = {}) { } }; }; + +module.exports.groupRateLimit = require('./group-limiter'); diff --git a/test/group-limiter.test.js b/test/group-limiter.test.js new file mode 100644 index 0000000..bf88641 --- /dev/null +++ b/test/group-limiter.test.js @@ -0,0 +1,55 @@ +const Koa = require('koa'); +const request = require('supertest'); +const groupRateLimit = require('../group-limiter'); + +// const sleep = (ms) => new Promise((res) => setTimeout(res, ms)); + +describe('groupRateLimit middleware', () => { + let app; + + beforeEach(() => { + app = new Koa(); + }); + + it('applies rate limit per group', async () => { + const middleware = groupRateLimit({ + routeMap: { + groupA: { duration: 1000, max: 1 }, + groupB: { duration: 1000, max: 2 } + }, + groupBy: (ctx) => (ctx.path === '/a' ? 'groupA' : 'groupB') + }); + + app.use(middleware); + app.use((ctx) => { + ctx.body = 'ok'; + }); + + const server = app.callback(); + + // Group A - should block 2nd request + await request(server).get('/a').expect(200); + await request(server).get('/a').expect(429); + + // Group B - should allow 2 requests + await request(server).get('/b').expect(200); + await request(server).get('/b').expect(200); + await request(server).get('/b').expect(429); + }); + + it('calls next() when no group matches', async () => { + const middleware = groupRateLimit({ + routeMap: { + groupA: { duration: 1000, max: 1 } + }, + groupBy: () => 'nonexistent' + }); + + app.use(middleware); + app.use((ctx) => { + ctx.body = 'ok'; + }); + + await request(app.callback()).get('/x').expect(200); + }); +});