From 4ebff3597ee61c8d91bd03530e3c14868ec37658 Mon Sep 17 00:00:00 2001 From: feyishola Date: Wed, 30 Apr 2025 14:08:12 +0100 Subject: [PATCH] implemented the abuse prevention middleware and test files --- apps/backend/src/app.module.ts | 2 + apps/backend/src/config/throttler.config.ts | 19 ++ .../entities/rate-limit-violation.entity.ts | 25 +++ .../middleware/rate-limit.middleware.spec.ts | 169 ++++++++++++++++++ .../src/middleware/rate-limit.middleware.ts | 146 +++++++++++++-- apps/backend/test/rate-limit.e2e-spec.ts | 140 +++++++++++++++ 6 files changed, 488 insertions(+), 13 deletions(-) create mode 100644 apps/backend/src/config/throttler.config.ts create mode 100644 apps/backend/src/entities/rate-limit-violation.entity.ts create mode 100644 apps/backend/src/middleware/rate-limit.middleware.spec.ts create mode 100644 apps/backend/test/rate-limit.e2e-spec.ts diff --git a/apps/backend/src/app.module.ts b/apps/backend/src/app.module.ts index 1760162..6ffc88b 100644 --- a/apps/backend/src/app.module.ts +++ b/apps/backend/src/app.module.ts @@ -18,6 +18,8 @@ import { RateLimitMiddleware } from './middleware/rate-limit.middleware'; import { MailModule } from './mail/mail.module'; import { NewsModule } from './news/news.module'; import { TasksModule } from './tasks/tasks.module'; +import { RateLimitViolation } from './entities/rate-limit-violation.entity'; +import { getThrottlerConfig } from './config/throttler.config'; const ENV = process.env.NODE_ENV || 'development'; console.log('Current environment:', ENV); diff --git a/apps/backend/src/config/throttler.config.ts b/apps/backend/src/config/throttler.config.ts new file mode 100644 index 0000000..0ccb5b5 --- /dev/null +++ b/apps/backend/src/config/throttler.config.ts @@ -0,0 +1,19 @@ +import { ThrottlerModule } from '@nestjs/throttler'; +import { ThrottlerStorageRedisService } from 'nestjs-throttler-storage-redis'; +import { ConfigService } from '@nestjs/config'; + +export const getThrottlerConfig = (configService: ConfigService) => { + return ThrottlerModule.forRootAsync({ + inject: [ConfigService], + useFactory: (config: ConfigService) => ({ + ttl: config.get('THROTTLE_TTL', 60), // Default 60 seconds window + limit: config.get('THROTTLE_LIMIT', 10), // Default 10 requests per window + storage: new ThrottlerStorageRedisService({ + host: config.get('REDIS_HOST', 'localhost'), + port: config.get('REDIS_PORT', 6379), + password: config.get('REDIS_PASSWORD', ''), + keyPrefix: 'throttle:', + }), + }), + }); +}; \ No newline at end of file diff --git a/apps/backend/src/entities/rate-limit-violation.entity.ts b/apps/backend/src/entities/rate-limit-violation.entity.ts new file mode 100644 index 0000000..324c5f8 --- /dev/null +++ b/apps/backend/src/entities/rate-limit-violation.entity.ts @@ -0,0 +1,25 @@ +import { Entity, Column, PrimaryColumn } from 'typeorm'; + +@Entity('rate_limit_violations') +export class RateLimitViolation { + @PrimaryColumn('uuid') + id: string; + + @Column({ name: 'ip_address', nullable: false }) + ipAddress: string; + + @Column({ name: 'wallet_address', nullable: true }) + walletAddress: string | null; + + @Column({ nullable: false }) + endpoint: string; + + @Column({ nullable: false }) + method: string; + + @Column({ type: 'timestamp', nullable: false }) + timestamp: Date; + + @Column({ name: 'violated_rule', type: 'json', nullable: false }) + violatedRule: { limit: number; window: string }; +} \ No newline at end of file diff --git a/apps/backend/src/middleware/rate-limit.middleware.spec.ts b/apps/backend/src/middleware/rate-limit.middleware.spec.ts new file mode 100644 index 0000000..6cfd769 --- /dev/null +++ b/apps/backend/src/middleware/rate-limit.middleware.spec.ts @@ -0,0 +1,169 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { ThrottlerGuard, ThrottlerException } from '@nestjs/throttler'; +import { getRepositoryToken } from '@nestjs/typeorm'; +import { RateLimitMiddleware } from './rate-limit.middleware'; +import { RateLimitViolation } from '../entities/rate-limit-violation.entity'; +import { createMock } from '@golevelup/ts-jest'; +import { Request, Response } from 'express'; + +describe('RateLimitMiddleware', () => { + let middleware: RateLimitMiddleware; + let throttlerGuard: ThrottlerGuard; + let repo: any; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + providers: [ + RateLimitMiddleware, + { + provide: ThrottlerGuard, + useValue: { + handleRequest: jest.fn(), + }, + }, + { + provide: getRepositoryToken(RateLimitViolation), + useValue: { + save: jest.fn(), + }, + }, + ], + }).compile(); + + middleware = module.get(RateLimitMiddleware); + throttlerGuard = module.get(ThrottlerGuard); + repo = module.get(getRepositoryToken(RateLimitViolation)); + }); + + it('should be defined', () => { + expect(middleware).toBeDefined(); + }); + + describe('when handling a non-protected route', () => { + it('should call next without throttling', async () => { + const req = createMock({ + path: '/unprotected', + method: 'GET', + }); + const res = createMock(); + const next = jest.fn(); + + await middleware.use(req, res, next); + + expect(next).toHaveBeenCalledTimes(1); + expect(throttlerGuard.handleRequest).not.toHaveBeenCalled(); + }); + }); + + describe('when handling a protected route', () => { + it('should apply throttling and continue if limit not exceeded', async () => { + const req = createMock({ + path: '/signal', + method: 'POST', + ip: '192.168.1.1', + headers: {}, + }); + const res = createMock(); + const next = jest.fn(); + + (throttlerGuard.handleRequest as jest.Mock).mockResolvedValueOnce(true); + + await middleware.use(req, res, next); + + expect(throttlerGuard.handleRequest).toHaveBeenCalled(); + expect(next).toHaveBeenCalledTimes(1); + }); + + it('should log abuse and return 429 if limit exceeded', async () => { + const req = createMock({ + path: '/vote', + method: 'POST', + ip: '192.168.1.2', + headers: { + 'x-wallet-address': '0x123456', + }, + }); + const res = createMock({ + status: jest.fn().mockReturnThis(), + json: jest.fn(), + }); + const next = jest.fn(); + + (throttlerGuard.handleRequest as jest.Mock).mockRejectedValueOnce( + new ThrottlerException() + ); + + await middleware.use(req, res, next); + + expect(throttlerGuard.handleRequest).toHaveBeenCalled(); + expect(next).not.toHaveBeenCalled(); + expect(res.status).toHaveBeenCalledWith(429); + expect(res.json).toHaveBeenCalled(); + expect(repo.save).toHaveBeenCalledWith( + expect.objectContaining({ + ipAddress: '192.168.1.2', + walletAddress: '0x123456', + endpoint: '/vote', + method: 'POST', + }) + ); + }); + + it('should prioritize wallet address over IP when both are available', async () => { + const req = createMock({ + path: '/webhook', + method: 'POST', + ip: '192.168.1.3', + headers: { + 'x-wallet-address': '0xABCDEF', + }, + }); + const res = createMock(); + const next = jest.fn(); + + (throttlerGuard.handleRequest as jest.Mock).mockResolvedValueOnce(true); + + await middleware.use(req, res, next); + + expect(req.throttlerKey).toBe('wallet:0xABCDEF'); + expect(next).toHaveBeenCalledTimes(1); + }); + + it('should use IP address when wallet is not available', async () => { + const req = createMock({ + path: '/signal', + method: 'POST', + ip: '192.168.1.4', + headers: {}, + }); + const res = createMock(); + const next = jest.fn(); + + (throttlerGuard.handleRequest as jest.Mock).mockResolvedValueOnce(true); + + await middleware.use(req, res, next); + + expect(req.throttlerKey).toBe('ip:192.168.1.4'); + expect(next).toHaveBeenCalledTimes(1); + }); + }); + + describe('error handling', () => { + it('should pass through non-throttler errors', async () => { + const req = createMock({ + path: '/signal', + method: 'POST', + ip: '192.168.1.5', + }); + const res = createMock(); + const next = jest.fn(); + const testError = new Error('Test error'); + + (throttlerGuard.handleRequest as jest.Mock).mockRejectedValueOnce(testError); + + await middleware.use(req, res, next); + + expect(next).toHaveBeenCalledWith(testError); + }); + }); +}); \ No newline at end of file diff --git a/apps/backend/src/middleware/rate-limit.middleware.ts b/apps/backend/src/middleware/rate-limit.middleware.ts index b4f2f68..affd8ed 100644 --- a/apps/backend/src/middleware/rate-limit.middleware.ts +++ b/apps/backend/src/middleware/rate-limit.middleware.ts @@ -1,19 +1,139 @@ -import { Injectable, type NestMiddleware } from '@nestjs/common'; +import { Injectable, NestMiddleware, Logger } from '@nestjs/common'; import { Request, Response, NextFunction } from 'express'; -import rateLimit from 'express-rate-limit'; +import { ThrottlerGuard, ThrottlerException } from '@nestjs/throttler'; +import { InjectRepository } from '@nestjs/typeorm'; +import { Repository } from 'typeorm'; +import { RateLimitViolation } from '../entities/rate-limit-violation.entity'; +import { v4 as uuidv4 } from 'uuid'; @Injectable() export class RateLimitMiddleware implements NestMiddleware { - private rateLimiter = rateLimit({ - windowMs: 15 * 60 * 1000, // 15 minutes in milliseconds - max: 100, // Limit each IP to 100 requests per windowMs - standardHeaders: true, // Return rate limit info in the `RateLimit-*` headers - legacyHeaders: false, // Disable the `X-RateLimit-*` headers - message: { message: 'Too many requests', statusCode: 429 }, - skipSuccessfulRequests: false, // Don't skip successful requests - }); - - use(req: Request, res: Response, next: NextFunction) { - this.rateLimiter(req, res, next); + private readonly logger = new Logger(RateLimitMiddleware.name); + private readonly protectedRoutes = [ + { path: '/signal', method: 'POST' }, + { path: '/vote', method: 'POST' }, + { path: '/webhook', method: 'POST' }, + ]; + + // Default rate limit settings + private readonly defaultRateLimit = { + limit: 10, + ttl: 60, // 60 seconds window + }; + + constructor( + private readonly throttlerGuard: ThrottlerGuard, + @InjectRepository(RateLimitViolation) + private rateLimitViolationRepo: Repository, + ) {} + + async use(req: Request, res: Response, next: NextFunction) { + // Check if route is protected + const isProtectedRoute = this.protectedRoutes.some( + (route) => route.path === req.path && route.method === req.method, + ); + + if (!isProtectedRoute) { + return next(); + } + + // Get IP address + const ipAddress = req.ip || req.socket.remoteAddress || 'unknown'; + + // Get wallet address if available + const walletAddress = req.headers['x-wallet-address'] as string || null; + + // Create tracking key - prioritize wallet if available, otherwise use IP + const trackingKey = walletAddress ? `wallet:${walletAddress}` : `ip:${ipAddress}`; + + try { + // Apply throttling using NestJS throttler + // We're utilizing the ThrottlerGuard but calling it manually + req.throttlerKey = trackingKey; + + // Apply rate limiting + const throttlerContext = { + getHandler: () => this.use, + getClass: () => RateLimitMiddleware, + }; + + await this.throttlerGuard.handleRequest( + { ...req, params: {}, query: {}, body: {} }, + res, + throttlerContext, + ); + + return next(); + } catch (error) { + if (error instanceof ThrottlerException) { + // Log abuse attempt + await this.logAbuseAttempt({ + ipAddress, + walletAddress, + endpoint: req.path, + method: req.method, + violatedRule: { + limit: this.defaultRateLimit.limit, + window: `${this.defaultRateLimit.ttl}s`, + }, + }); + + // Return standard throttling response + res.status(429).json({ + statusCode: 429, + message: 'Too Many Requests', + error: 'Rate limit exceeded', + }); + } else { + // Pass through other errors + next(error); + } + } + } + + private async logAbuseAttempt(data: { + ipAddress: string; + walletAddress: string | null; + endpoint: string; + method: string; + violatedRule: { limit: number; window: string }; + }) { + try { + const violation = new RateLimitViolation(); + violation.id = uuidv4(); + violation.ipAddress = data.ipAddress; + violation.walletAddress = data.walletAddress; + violation.endpoint = data.endpoint; + violation.method = data.method; + violation.timestamp = new Date(); + violation.violatedRule = data.violatedRule; + + await this.rateLimitViolationRepo.save(violation); + this.logger.warn(`Rate limit violation recorded: ${JSON.stringify(data)}`); + } catch (error) { + this.logger.error(`Failed to log rate limit violation: ${error.message}`, error.stack); + } } } + + + +// import { Injectable, type NestMiddleware } from '@nestjs/common'; +// import { Request, Response, NextFunction } from 'express'; +// import rateLimit from 'express-rate-limit'; + +// @Injectable() +// export class RateLimitMiddleware implements NestMiddleware { +// private rateLimiter = rateLimit({ +// windowMs: 15 * 60 * 1000, // 15 minutes in milliseconds +// max: 100, // Limit each IP to 100 requests per windowMs +// standardHeaders: true, // Return rate limit info in the `RateLimit-*` headers +// legacyHeaders: false, // Disable the `X-RateLimit-*` headers +// message: { message: 'Too many requests', statusCode: 429 }, +// skipSuccessfulRequests: false, // Don't skip successful requests +// }); + +// use(req: Request, res: Response, next: NextFunction) { +// this.rateLimiter(req, res, next); +// } +// } diff --git a/apps/backend/test/rate-limit.e2e-spec.ts b/apps/backend/test/rate-limit.e2e-spec.ts new file mode 100644 index 0000000..7bfe1b6 --- /dev/null +++ b/apps/backend/test/rate-limit.e2e-spec.ts @@ -0,0 +1,140 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { INestApplication } from '@nestjs/common'; +import * as request from 'supertest'; +import { AppModule } from '../src/app.module'; +import { TypeOrmModule } from '@nestjs/typeorm'; +import { ConfigModule, ConfigService } from '@nestjs/config'; +import { RateLimitViolation } from '../src/entities/rate-limit-violation.entity'; +import { getThrottlerConfig } from '../src/config/throttler.config'; +import { ThrottlerModule } from '@nestjs/throttler'; + +describe('Rate Limiting (e2e)', () => { + let app: INestApplication; + + beforeAll(async () => { + // Create a test module with in-memory database + const moduleFixture: TestingModule = await Test.createTestingModule({ + imports: [ + ConfigModule.forRoot({ + isGlobal: true, + // Override with test config + load: [() => ({ + THROTTLE_TTL: 1, + THROTTLE_LIMIT: 2, // Very restrictive limit for testing + })], + }), + TypeOrmModule.forRootAsync({ + useFactory: () => ({ + type: 'sqlite', + database: ':memory:', + entities: [RateLimitViolation], + synchronize: true, + }), + }), + TypeOrmModule.forFeature([RateLimitViolation]), + ThrottlerModule.forRoot({ + ttl: 1, + limit: 2, + }), + AppModule, + ], + }).compile(); + + app = moduleFixture.createNestApplication(); + await app.init(); + }); + + afterAll(async () => { + await app.close(); + }); + + describe('Rate limiting on protected routes', () => { + const testProtectedRoutes = [ + { path: '/signal', method: 'post' }, + { path: '/vote', method: 'post' }, + { path: '/webhook', method: 'post' }, + ]; + + testProtectedRoutes.forEach(route => { + it(`should allow requests within rate limit for ${route.method.toUpperCase()} ${route.path}`, async () => { + // First request - should succeed + await request(app.getHttpServer()) + [route.method](route.path) + .expect(res => { + // This is a test so we might get different status codes depending on the route + // implementation but it should not be 429 + expect(res.status).not.toBe(429); + }); + + // Second request - should still succeed + await request(app.getHttpServer()) + [route.method](route.path) + .expect(res => { + expect(res.status).not.toBe(429); + }); + }); + + it(`should block requests exceeding rate limit for ${route.method.toUpperCase()} ${route.path}`, async () => { + // First two requests + await request(app.getHttpServer())[route.method](route.path); + await request(app.getHttpServer())[route.method](route.path); + + // Third request in quick succession - should be rate limited + await request(app.getHttpServer()) + [route.method](route.path) + .expect(429); + + // Wait for ttl to expire (2 seconds to be safe) + await new Promise(resolve => setTimeout(resolve, 2000)); + + // After waiting, should work again + await request(app.getHttpServer()) + [route.method](route.path) + .expect(res => { + expect(res.status).not.toBe(429); + }); + }); + + it(`should track wallet address when provided for ${route.method.toUpperCase()} ${route.path}`, async () => { + const wallet1 = '0xWallet1'; + const wallet2 = '0xWallet2'; + + // First wallet can make 2 requests + await request(app.getHttpServer()) + [route.method](route.path) + .set('x-wallet-address', wallet1); + + await request(app.getHttpServer()) + [route.method](route.path) + .set('x-wallet-address', wallet1); + + // Third request from first wallet should be blocked + await request(app.getHttpServer()) + [route.method](route.path) + .set('x-wallet-address', wallet1) + .expect(429); + + // Second wallet should still be allowed + await request(app.getHttpServer()) + [route.method](route.path) + .set('x-wallet-address', wallet2) + .expect(res => { + expect(res.status).not.toBe(429); + }); + }); + }); + }); + + describe('Unprotected routes', () => { + it('should not apply rate limiting to unprotected route', async () => { + // Make many requests to an unprotected route - none should be rate limited + for (let i = 0; i < 10; i++) { + await request(app.getHttpServer()) + .get('/health') + .expect(res => { + expect(res.status).not.toBe(429); + }); + } + }); + }); +}); \ No newline at end of file