diff --git a/src/app.module.ts b/src/app.module.ts index 416a076a2..88d7cb0e3 100644 --- a/src/app.module.ts +++ b/src/app.module.ts @@ -2,7 +2,6 @@ import { Logger, Module } from '@nestjs/common'; import { ConfigModule, ConfigService } from '@nestjs/config'; import { SequelizeModule } from '@nestjs/sequelize'; import { EventEmitterModule } from '@nestjs/event-emitter'; -import { seconds, ThrottlerModule, ThrottlerGuard } from '@nestjs/throttler'; import { LoggerModule } from 'nestjs-pino'; import { FileModule } from './modules/file/file.module'; import { TrashModule } from './modules/trash/trash.module'; @@ -33,6 +32,8 @@ import { getClientIdFromHeaders } from './common/decorators/client.decorator'; import { CustomThrottlerGuard } from './guards/throttler.guard'; import { AuthGuard } from './modules/auth/auth.guard'; import { CustomThrottlerModule } from './guards/throttler.module'; +import { CustomEndpointThrottleGuard } from './guards/custom-endpoint-throttle.guard'; +import { CacheManagerModule } from './modules/cache-manager/cache-manager.module'; @Module({ imports: [ @@ -147,6 +148,7 @@ import { CustomThrottlerModule } from './guards/throttler.module'; PlanModule, WorkspacesModule, GatewayModule, + CacheManagerModule, ], controllers: [], providers: [ @@ -156,11 +158,15 @@ import { CustomThrottlerModule } from './guards/throttler.module'; }, { provide: APP_GUARD, - useClass: AuthGuard, + useClass: CustomThrottlerGuard, }, { provide: APP_GUARD, - useClass: CustomThrottlerGuard, + useClass: CustomEndpointThrottleGuard + }, + { + provide: APP_GUARD, + useClass: AuthGuard, }, ], }) diff --git a/src/guards/custom-endpoint-throttle.guard.spec.ts b/src/guards/custom-endpoint-throttle.guard.spec.ts index e1d0f6b15..955f22d5e 100644 --- a/src/guards/custom-endpoint-throttle.guard.spec.ts +++ b/src/guards/custom-endpoint-throttle.guard.spec.ts @@ -4,6 +4,7 @@ import { Reflector } from '@nestjs/core'; import { CustomEndpointThrottleGuard } from './custom-endpoint-throttle.guard'; import { CacheManagerService } from '../modules/cache-manager/cache-manager.service'; import { ThrottlerException } from '@nestjs/throttler'; +import jwt from 'jsonwebtoken'; describe('CustomThrottleGuard', () => { let guard: CustomEndpointThrottleGuard; @@ -38,6 +39,7 @@ describe('CustomThrottleGuard', () => { const request: any = { route: { path: route }, user: { uuid: 'user-1' }, + headers: {}, ip: '1.2.3.4', }; (cacheService.increment as jest.Mock).mockResolvedValue({ @@ -63,6 +65,7 @@ describe('CustomThrottleGuard', () => { const request: any = { route: { path: route }, user: { uuid: 'user-2' }, + headers: {}, ip: '2.2.2.2', }; (cacheService.increment as jest.Mock).mockResolvedValue({ @@ -82,6 +85,64 @@ describe('CustomThrottleGuard', () => { }); }); + describe('When request.user is undefined but authorization header is present', () => { + const route = '/some-route'; + + it('When under limit then it throttles using the decoded token identity', async () => { + const policy = { ttl: 60, limit: 5 }; + (reflector.get as jest.Mock).mockReturnValue(policy); + + const token = jwt.sign({ uuid: 'token-user-1' }, 'secret'); + + const request: any = { + route: { path: route }, + headers: { authorization: `Bearer ${token}` }, + ip: '5.5.5.5', + }; + (cacheService.increment as jest.Mock).mockResolvedValue({ + totalHits: 1, + timeToExpire: 5000, + }); + const context = tsjest.createMock(); + (context as any).switchToHttp = () => ({ getRequest: () => request }); + + const result = await guard.canActivate(context); + + expect(result).toBe(true); + expect(cacheService.increment).toHaveBeenCalledWith( + `${route}:policy0:cet:uid:token-user-1`, + 60, + ); + }); + + it('When over the limit then the request is throttled', async () => { + const policy = { ttl: 60, limit: 1 }; + (reflector.get as jest.Mock).mockReturnValue(policy); + + const token = jwt.sign({ uuid: 'token-user-2' }, 'secret'); + + const request: any = { + route: { path: route }, + headers: { authorization: `Bearer ${token}` }, + ip: '5.5.5.5', + }; + (cacheService.increment as jest.Mock).mockResolvedValue({ + totalHits: 2, + timeToExpire: 1000, + }); + const context = tsjest.createMock(); + (context as any).switchToHttp = () => ({ getRequest: () => request }); + + await expect(guard.canActivate(context)).rejects.toBeInstanceOf( + ThrottlerException, + ); + expect(cacheService.increment).toHaveBeenCalledWith( + `${route}:policy0:cet:uid:token-user-2`, + 60, + ); + }); + }); + describe('Applying multiple policies', () => { const route = '/login'; @@ -94,6 +155,7 @@ describe('CustomThrottleGuard', () => { const request: any = { route: { path: route }, user: null, + headers: {}, ip: '9.9.9.9', }; @@ -132,6 +194,7 @@ describe('CustomThrottleGuard', () => { const request: any = { route: { path: route }, user: null, + headers: {}, ip: '11.11.11.11', }; diff --git a/src/guards/custom-endpoint-throttle.guard.ts b/src/guards/custom-endpoint-throttle.guard.ts index 5e21b010c..eeefc1f9c 100644 --- a/src/guards/custom-endpoint-throttle.guard.ts +++ b/src/guards/custom-endpoint-throttle.guard.ts @@ -1,11 +1,7 @@ -import { - CanActivate, - ExecutionContext, - Injectable, - Inject, -} from '@nestjs/common'; +import { CanActivate, ExecutionContext, Injectable } from '@nestjs/common'; import { Reflector } from '@nestjs/core'; import { ThrottlerException } from '@nestjs/throttler'; +import jwt from 'jsonwebtoken'; import { CacheManagerService } from '../modules/cache-manager/cache-manager.service'; import { CUSTOM_ENDPOINT_THROTTLE_KEY, @@ -19,6 +15,19 @@ export class CustomEndpointThrottleGuard implements CanActivate { private readonly cacheService: CacheManagerService, ) {} + private decodeAuthIfPresent(request: any): string | null { + if (!request.headers.authorization) { + return null; + } + try { + const token = request.headers.authorization.split(' ')[1]; + const decoded: any = jwt.decode(token); + return decoded?.uuid || decoded?.payload?.uuid; + } catch { + return null; + } + } + async canActivate(context: ExecutionContext): Promise { const raw = this.reflector.get( CUSTOM_ENDPOINT_THROTTLE_KEY, @@ -49,7 +58,6 @@ export class CustomEndpointThrottleGuard implements CanActivate { } const request = context.switchToHttp().getRequest(); - const user = request.user; let ip = request.headers['cf-connecting-ip']; if (Array.isArray(ip)) { @@ -59,9 +67,9 @@ export class CustomEndpointThrottleGuard implements CanActivate { ip = request.ip; } - const identifierBase = user?.uuid - ? `cet:uid:${user.uuid}` - : `cet:ip:${ip}`; + const userId = request.user?.uuid || this.decodeAuthIfPresent(request); + + const identifierBase = userId ? `cet:uid:${userId}` : `cet:ip:${ip}`; const route = request.route?.path ?? request.originalUrl ?? 'unknown'; // Apply all policies. If any policy is violated, throw. diff --git a/src/guards/throttler.guard.ts b/src/guards/throttler.guard.ts index edee54d6f..8b8b8c00a 100644 --- a/src/guards/throttler.guard.ts +++ b/src/guards/throttler.guard.ts @@ -1,5 +1,7 @@ import { Injectable } from '@nestjs/common'; import { ThrottlerGuard as BaseThrottlerGuard } from '@nestjs/throttler'; +import jwt from 'jsonwebtoken' + @Injectable() export class ThrottlerGuard extends BaseThrottlerGuard { protected async getTracker(req: Record): Promise { @@ -15,8 +17,25 @@ export class ThrottlerGuard extends BaseThrottlerGuard { @Injectable() export class CustomThrottlerGuard extends ThrottlerGuard { + private decodeAuthIfPresent(request: any): string | null { + if (!request.headers.authorization) { + return null; + } + try { + const token = request.headers.authorization.split(' ')[1]; + const decoded: any = jwt.decode(token); + return decoded?.uuid || decoded?.payload?.uuid; + } catch { + return null; + } + } + protected async getTracker(req: any): Promise { - const userId = req.user?.uuid; + let userId = req.user?.uuid; + + if (!userId) { + userId = this.decodeAuthIfPresent(req); + } if (userId) { return `rl:${userId}`;