Skip to content

Commit f7ddd94

Browse files
authored
Merge pull request #614 from TaloDev/fix-rate-limiting
Use rate-limiter-flexible for rate-limiting
2 parents cde88ef + ffdf915 commit f7ddd94

File tree

4 files changed

+43
-38
lines changed

4 files changed

+43
-38
lines changed

package-lock.json

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
"otplib": "^12.0.1",
9494
"qrcode": "^1.5.0",
9595
"qs": "^6.11.0",
96+
"rate-limiter-flexible": "^7.3.0",
9697
"stripe": "^18.0.0",
9798
"uuid": "^9.0.0",
9899
"ws": "^8.18.0",
Lines changed: 26 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,38 @@
11
import { Redis } from 'ioredis'
2+
import { RateLimiterRedis } from 'rate-limiter-flexible'
23

3-
const cache = new Map<string, { count: number, expires: number }>()
4+
const rateLimiters = new Map<string, RateLimiterRedis>()
45

5-
setInterval(() => {
6-
const now = Date.now()
7-
for (const [key, value] of cache.entries()) {
8-
if (now > value.expires) {
9-
cache.delete(key)
10-
}
6+
function getRateLimiter(redis: Redis, maxRequests: number, duration = 60): RateLimiterRedis {
7+
const limiterKey = `${maxRequests}_${duration}`
8+
if (!rateLimiters.has(limiterKey)) {
9+
rateLimiters.set(limiterKey, new RateLimiterRedis({
10+
storeClient: redis,
11+
keyPrefix: `rl_${maxRequests}_${duration}`,
12+
points: maxRequests,
13+
duration: duration,
14+
blockDuration: duration
15+
}))
1116
}
12-
}, 5000)
13-
14-
const script = `
15-
local current = redis.call('INCR', KEYS[1])
16-
if current == 1 then
17-
redis.call('EXPIRE', KEYS[1], ARGV[1])
18-
end
19-
return current
20-
`
17+
return rateLimiters.get(limiterKey)!
18+
}
2119

2220
export default async function checkRateLimitExceeded(
2321
redis: Redis,
2422
key: string,
25-
maxRequests: number
23+
maxRequests: number,
24+
duration = 60
2625
): Promise<boolean> {
27-
// Skip cache in test environment for predictable behavior
28-
if (process.env.NODE_ENV !== 'test') {
29-
const cached = cache.get(key)
30-
if (cached && Date.now() < cached.expires) {
31-
return cached.count > maxRequests
32-
}
33-
}
34-
35-
const current = await redis.eval(script, 1, key, 1) as number
26+
const rateLimiter = getRateLimiter(redis, maxRequests, duration)
3627

37-
// Only cache in production
38-
if (process.env.NODE_ENV !== 'test') {
39-
cache.set(key, {
40-
count: current,
41-
expires: Date.now() + 500
42-
})
28+
try {
29+
await rateLimiter.consume(key)
30+
return false
31+
} catch (err) {
32+
if (err && typeof err === 'object' && 'remainingPoints' in err) {
33+
return true
34+
}
35+
// re-throw actual errors
36+
throw err
4337
}
44-
45-
return current > maxRequests
4638
}

tests/socket/rateLimiting.test.ts

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,15 @@ import { APIKeyScope } from '../../src/entities/api-key'
22
import createSocketIdentifyMessage from '../utils/createSocketIdentifyMessage'
33
import GameChannelFactory from '../fixtures/GameChannelFactory'
44
import createTestSocket from '../utils/createTestSocket'
5+
import * as checkRateLimitExceeded from '../../src/lib/errors/checkRateLimitExceeded'
56

67
describe('Socket rate limiting', () => {
8+
const checkRateLimitExceededMock = vi.spyOn(checkRateLimitExceeded, 'default')
9+
10+
afterEach(() => {
11+
checkRateLimitExceededMock.mockClear()
12+
})
13+
714
it('should return a rate limiting error', async () => {
815
const { identifyMessage, ticket, player } = await createSocketIdentifyMessage([
916
APIKeyScope.READ_PLAYERS,
@@ -14,11 +21,10 @@ describe('Socket rate limiting', () => {
1421
channel.members.add(player.aliases[0])
1522
await em.persistAndFlush(channel)
1623

17-
await createTestSocket(`/?ticket=${ticket}`, async (client, socket) => {
24+
await createTestSocket(`/?ticket=${ticket}`, async (client) => {
1825
await client.identify(identifyMessage)
1926

20-
const conn = socket.findConnections((conn) => conn.playerAliasId === player.aliases[0].id)[0]
21-
await redis.set(conn.rateLimitKey, 999)
27+
checkRateLimitExceededMock.mockResolvedValueOnce(true)
2228

2329
client.sendJson({
2430
req: 'v1.channels.message',
@@ -56,7 +62,7 @@ describe('Socket rate limiting', () => {
5662
const conn = socket.findConnections((conn) => conn.playerAliasId === player.aliases[0].id)[0]
5763
conn.rateLimitWarnings = 3
5864

59-
await redis.set(conn.rateLimitKey, 999)
65+
checkRateLimitExceededMock.mockResolvedValueOnce(true)
6066

6167
client.sendJson({
6268
req: 'v1.channels.message',

0 commit comments

Comments
 (0)