Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/app.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import { Logger, Module } from '@nestjs/common';
import { ConfigModule, ConfigService } from '@nestjs/config';
import { SequelizeModule } from '@nestjs/sequelize';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { seconds, ThrottlerModule, ThrottlerGuard } from '@nestjs/throttler';
import { LoggerModule } from 'nestjs-pino';
import { FileModule } from './modules/file/file.module';
import { TrashModule } from './modules/trash/trash.module';
Expand Down Expand Up @@ -33,6 +32,8 @@ import { getClientIdFromHeaders } from './common/decorators/client.decorator';
import { CustomThrottlerGuard } from './guards/throttler.guard';
import { AuthGuard } from './modules/auth/auth.guard';
import { CustomThrottlerModule } from './guards/throttler.module';
import { CustomEndpointThrottleGuard } from './guards/custom-endpoint-throttle.guard';
import { CacheManagerModule } from './modules/cache-manager/cache-manager.module';

@Module({
imports: [
Expand Down Expand Up @@ -147,6 +148,7 @@ import { CustomThrottlerModule } from './guards/throttler.module';
PlanModule,
WorkspacesModule,
GatewayModule,
CacheManagerModule,
],
controllers: [],
providers: [
Expand All @@ -156,11 +158,15 @@ import { CustomThrottlerModule } from './guards/throttler.module';
},
{
provide: APP_GUARD,
useClass: AuthGuard,
useClass: CustomThrottlerGuard,
},
{
provide: APP_GUARD,
useClass: CustomThrottlerGuard,
useClass: CustomEndpointThrottleGuard
},
{
provide: APP_GUARD,
useClass: AuthGuard,
},
],
})
Expand Down
63 changes: 63 additions & 0 deletions src/guards/custom-endpoint-throttle.guard.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import { CustomEndpointThrottleGuard } from './custom-endpoint-throttle.guard';
import { CacheManagerService } from '../modules/cache-manager/cache-manager.service';
import { ThrottlerException } from '@nestjs/throttler';
import jwt from 'jsonwebtoken';

describe('CustomThrottleGuard', () => {
let guard: CustomEndpointThrottleGuard;
Expand Down Expand Up @@ -38,6 +39,7 @@
const request: any = {
route: { path: route },
user: { uuid: 'user-1' },
headers: {},
ip: '1.2.3.4',
};
(cacheService.increment as jest.Mock).mockResolvedValue({
Expand All @@ -63,6 +65,7 @@
const request: any = {
route: { path: route },
user: { uuid: 'user-2' },
headers: {},
ip: '2.2.2.2',
};
(cacheService.increment as jest.Mock).mockResolvedValue({
Expand All @@ -82,6 +85,64 @@
});
});

describe('When request.user is undefined but authorization header is present', () => {
const route = '/some-route';

it('When under limit then it throttles using the decoded token identity', async () => {
const policy = { ttl: 60, limit: 5 };
(reflector.get as jest.Mock).mockReturnValue(policy);

const token = jwt.sign({ uuid: 'token-user-1' }, 'secret');

const request: any = {
route: { path: route },
headers: { authorization: `Bearer ${token}` },
ip: '5.5.5.5',
};
(cacheService.increment as jest.Mock).mockResolvedValue({
totalHits: 1,
timeToExpire: 5000,
});
const context = tsjest.createMock<ExecutionContext>();
(context as any).switchToHttp = () => ({ getRequest: () => request });

const result = await guard.canActivate(context);

expect(result).toBe(true);
expect(cacheService.increment).toHaveBeenCalledWith(
`${route}:policy0:cet:uid:token-user-1`,
60,
);
});

it('When over the limit then the request is throttled', async () => {
const policy = { ttl: 60, limit: 1 };
(reflector.get as jest.Mock).mockReturnValue(policy);

const token = jwt.sign({ uuid: 'token-user-2' }, 'secret');

const request: any = {
route: { path: route },
headers: { authorization: `Bearer ${token}` },
ip: '5.5.5.5',
};
(cacheService.increment as jest.Mock).mockResolvedValue({
totalHits: 2,
timeToExpire: 1000,
});
const context = tsjest.createMock<ExecutionContext>();
(context as any).switchToHttp = () => ({ getRequest: () => request });

await expect(guard.canActivate(context)).rejects.toBeInstanceOf(
ThrottlerException,
);
expect(cacheService.increment).toHaveBeenCalledWith(
`${route}:policy0:cet:uid:token-user-2`,
60,
);
});
});

describe('Applying multiple policies', () => {
const route = '/login';

Expand All @@ -94,6 +155,7 @@
const request: any = {
route: { path: route },
user: null,
headers: {},
ip: '9.9.9.9',
};

Expand Down Expand Up @@ -132,6 +194,7 @@
const request: any = {
route: { path: route },
user: null,
headers: {},
ip: '11.11.11.11',
};

Expand Down
28 changes: 18 additions & 10 deletions src/guards/custom-endpoint-throttle.guard.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import {
CanActivate,
ExecutionContext,
Injectable,
Inject,
} from '@nestjs/common';
import { CanActivate, ExecutionContext, Injectable } from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import { ThrottlerException } from '@nestjs/throttler';
import jwt from 'jsonwebtoken';
import { CacheManagerService } from '../modules/cache-manager/cache-manager.service';
import {
CUSTOM_ENDPOINT_THROTTLE_KEY,
Expand All @@ -19,6 +15,19 @@ export class CustomEndpointThrottleGuard implements CanActivate {
private readonly cacheService: CacheManagerService,
) {}

private decodeAuthIfPresent(request: any): string | null {
if (!request.headers.authorization) {
return null;
}
try {
const token = request.headers.authorization.split(' ')[1];
const decoded: any = jwt.decode(token);
return decoded?.uuid || decoded?.payload?.uuid;
} catch {
return null;
}
}

async canActivate(context: ExecutionContext): Promise<boolean> {
const raw = this.reflector.get<any>(
CUSTOM_ENDPOINT_THROTTLE_KEY,
Expand Down Expand Up @@ -49,7 +58,6 @@ export class CustomEndpointThrottleGuard implements CanActivate {
}

const request = context.switchToHttp().getRequest();
const user = request.user;

let ip = request.headers['cf-connecting-ip'];
if (Array.isArray(ip)) {
Expand All @@ -59,9 +67,9 @@ export class CustomEndpointThrottleGuard implements CanActivate {
ip = request.ip;
}

const identifierBase = user?.uuid
? `cet:uid:${user.uuid}`
: `cet:ip:${ip}`;
const userId = request.user?.uuid || this.decodeAuthIfPresent(request);

const identifierBase = userId ? `cet:uid:${userId}` : `cet:ip:${ip}`;
const route = request.route?.path ?? request.originalUrl ?? 'unknown';

// Apply all policies. If any policy is violated, throw.
Expand Down
21 changes: 20 additions & 1 deletion src/guards/throttler.guard.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { Injectable } from '@nestjs/common';
import { ThrottlerGuard as BaseThrottlerGuard } from '@nestjs/throttler';
import jwt from 'jsonwebtoken'

@Injectable()
export class ThrottlerGuard extends BaseThrottlerGuard {
protected async getTracker(req: Record<string, any>): Promise<string> {
Expand All @@ -15,8 +17,25 @@ export class ThrottlerGuard extends BaseThrottlerGuard {

@Injectable()
export class CustomThrottlerGuard extends ThrottlerGuard {
private decodeAuthIfPresent(request: any): string | null {
if (!request.headers.authorization) {
return null;
}
try {
const token = request.headers.authorization.split(' ')[1];
const decoded: any = jwt.decode(token);
return decoded?.uuid || decoded?.payload?.uuid;
} catch {
return null;
}
}

protected async getTracker(req: any): Promise<string> {
const userId = req.user?.uuid;
let userId = req.user?.uuid;

if (!userId) {
userId = this.decodeAuthIfPresent(req);
}

if (userId) {
return `rl:${userId}`;
Expand Down
Loading