Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions apps/backend/src/app.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ import { RateLimitMiddleware } from './middleware/rate-limit.middleware';
import { MailModule } from './mail/mail.module';
import { NewsModule } from './news/news.module';
import { TasksModule } from './tasks/tasks.module';
import { RateLimitViolation } from './entities/rate-limit-violation.entity';
import { getThrottlerConfig } from './config/throttler.config';

import { ForumReportModule } from './forum-report/forum-report.module';
import { AdminModule } from './admin/admin.module';


const ENV = process.env.NODE_ENV || 'development';
console.log('Current environment:', ENV);

Expand Down
19 changes: 19 additions & 0 deletions apps/backend/src/config/throttler.config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { ThrottlerModule } from '@nestjs/throttler';
import { ThrottlerStorageRedisService } from 'nestjs-throttler-storage-redis';
import { ConfigService } from '@nestjs/config';

export const getThrottlerConfig = (configService: ConfigService) => {
return ThrottlerModule.forRootAsync({
inject: [ConfigService],
useFactory: (config: ConfigService) => ({
ttl: config.get('THROTTLE_TTL', 60), // Default 60 seconds window
limit: config.get('THROTTLE_LIMIT', 10), // Default 10 requests per window
storage: new ThrottlerStorageRedisService({
host: config.get('REDIS_HOST', 'localhost'),
port: config.get('REDIS_PORT', 6379),
password: config.get('REDIS_PASSWORD', ''),
keyPrefix: 'throttle:',
}),
}),
});
};
25 changes: 25 additions & 0 deletions apps/backend/src/entities/rate-limit-violation.entity.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import { Entity, Column, PrimaryColumn } from 'typeorm';

@Entity('rate_limit_violations')
export class RateLimitViolation {
@PrimaryColumn('uuid')
id: string;

@Column({ name: 'ip_address', nullable: false })
ipAddress: string;

@Column({ name: 'wallet_address', nullable: true })
walletAddress: string | null;

@Column({ nullable: false })
endpoint: string;

@Column({ nullable: false })
method: string;

@Column({ type: 'timestamp', nullable: false })
timestamp: Date;

@Column({ name: 'violated_rule', type: 'json', nullable: false })
violatedRule: { limit: number; window: string };
}
169 changes: 169 additions & 0 deletions apps/backend/src/middleware/rate-limit.middleware.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import { Test, TestingModule } from '@nestjs/testing';
import { ThrottlerGuard, ThrottlerException } from '@nestjs/throttler';
import { getRepositoryToken } from '@nestjs/typeorm';
import { RateLimitMiddleware } from './rate-limit.middleware';
import { RateLimitViolation } from '../entities/rate-limit-violation.entity';
import { createMock } from '@golevelup/ts-jest';
import { Request, Response } from 'express';

describe('RateLimitMiddleware', () => {
let middleware: RateLimitMiddleware;
let throttlerGuard: ThrottlerGuard;
let repo: any;

beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [
RateLimitMiddleware,
{
provide: ThrottlerGuard,
useValue: {
handleRequest: jest.fn(),
},
},
{
provide: getRepositoryToken(RateLimitViolation),
useValue: {
save: jest.fn(),
},
},
],
}).compile();

middleware = module.get<RateLimitMiddleware>(RateLimitMiddleware);
throttlerGuard = module.get<ThrottlerGuard>(ThrottlerGuard);
repo = module.get(getRepositoryToken(RateLimitViolation));
});

it('should be defined', () => {
expect(middleware).toBeDefined();
});

describe('when handling a non-protected route', () => {
it('should call next without throttling', async () => {
const req = createMock<Request>({
path: '/unprotected',
method: 'GET',
});
const res = createMock<Response>();
const next = jest.fn();

await middleware.use(req, res, next);

expect(next).toHaveBeenCalledTimes(1);
expect(throttlerGuard.handleRequest).not.toHaveBeenCalled();
});
});

describe('when handling a protected route', () => {
it('should apply throttling and continue if limit not exceeded', async () => {
const req = createMock<Request>({
path: '/signal',
method: 'POST',
ip: '192.168.1.1',
headers: {},
});
const res = createMock<Response>();
const next = jest.fn();

(throttlerGuard.handleRequest as jest.Mock).mockResolvedValueOnce(true);

await middleware.use(req, res, next);

expect(throttlerGuard.handleRequest).toHaveBeenCalled();
expect(next).toHaveBeenCalledTimes(1);
});

it('should log abuse and return 429 if limit exceeded', async () => {
const req = createMock<Request>({
path: '/vote',
method: 'POST',
ip: '192.168.1.2',
headers: {
'x-wallet-address': '0x123456',
},
});
const res = createMock<Response>({
status: jest.fn().mockReturnThis(),
json: jest.fn(),
});
const next = jest.fn();

(throttlerGuard.handleRequest as jest.Mock).mockRejectedValueOnce(
new ThrottlerException()
);

await middleware.use(req, res, next);

expect(throttlerGuard.handleRequest).toHaveBeenCalled();
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(429);
expect(res.json).toHaveBeenCalled();
expect(repo.save).toHaveBeenCalledWith(
expect.objectContaining({
ipAddress: '192.168.1.2',
walletAddress: '0x123456',
endpoint: '/vote',
method: 'POST',
})
);
});

it('should prioritize wallet address over IP when both are available', async () => {
const req = createMock<Request>({
path: '/webhook',
method: 'POST',
ip: '192.168.1.3',
headers: {
'x-wallet-address': '0xABCDEF',
},
});
const res = createMock<Response>();
const next = jest.fn();

(throttlerGuard.handleRequest as jest.Mock).mockResolvedValueOnce(true);

await middleware.use(req, res, next);

expect(req.throttlerKey).toBe('wallet:0xABCDEF');
expect(next).toHaveBeenCalledTimes(1);
});

it('should use IP address when wallet is not available', async () => {
const req = createMock<Request>({
path: '/signal',
method: 'POST',
ip: '192.168.1.4',
headers: {},
});
const res = createMock<Response>();
const next = jest.fn();

(throttlerGuard.handleRequest as jest.Mock).mockResolvedValueOnce(true);

await middleware.use(req, res, next);

expect(req.throttlerKey).toBe('ip:192.168.1.4');
expect(next).toHaveBeenCalledTimes(1);
});
});

describe('error handling', () => {
it('should pass through non-throttler errors', async () => {
const req = createMock<Request>({
path: '/signal',
method: 'POST',
ip: '192.168.1.5',
});
const res = createMock<Response>();
const next = jest.fn();
const testError = new Error('Test error');

(throttlerGuard.handleRequest as jest.Mock).mockRejectedValueOnce(testError);

await middleware.use(req, res, next);

expect(next).toHaveBeenCalledWith(testError);
});
});
});
146 changes: 133 additions & 13 deletions apps/backend/src/middleware/rate-limit.middleware.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,139 @@
import { Injectable, type NestMiddleware } from '@nestjs/common';
import { Injectable, NestMiddleware, Logger } from '@nestjs/common';
import { Request, Response, NextFunction } from 'express';
import rateLimit from 'express-rate-limit';
import { ThrottlerGuard, ThrottlerException } from '@nestjs/throttler';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { RateLimitViolation } from '../entities/rate-limit-violation.entity';
import { v4 as uuidv4 } from 'uuid';

@Injectable()
export class RateLimitMiddleware implements NestMiddleware {
private rateLimiter = rateLimit({
windowMs: 15 * 60 * 1000, // 15 minutes in milliseconds
max: 100, // Limit each IP to 100 requests per windowMs
standardHeaders: true, // Return rate limit info in the `RateLimit-*` headers
legacyHeaders: false, // Disable the `X-RateLimit-*` headers
message: { message: 'Too many requests', statusCode: 429 },
skipSuccessfulRequests: false, // Don't skip successful requests
});

use(req: Request, res: Response, next: NextFunction) {
this.rateLimiter(req, res, next);
private readonly logger = new Logger(RateLimitMiddleware.name);
private readonly protectedRoutes = [
{ path: '/signal', method: 'POST' },
{ path: '/vote', method: 'POST' },
{ path: '/webhook', method: 'POST' },
];

// Default rate limit settings
private readonly defaultRateLimit = {
limit: 10,
ttl: 60, // 60 seconds window
};

constructor(
private readonly throttlerGuard: ThrottlerGuard,
@InjectRepository(RateLimitViolation)
private rateLimitViolationRepo: Repository<RateLimitViolation>,
) {}

async use(req: Request, res: Response, next: NextFunction) {
// Check if route is protected
const isProtectedRoute = this.protectedRoutes.some(
(route) => route.path === req.path && route.method === req.method,
);

if (!isProtectedRoute) {
return next();
}

// Get IP address
const ipAddress = req.ip || req.socket.remoteAddress || 'unknown';

// Get wallet address if available
const walletAddress = req.headers['x-wallet-address'] as string || null;

// Create tracking key - prioritize wallet if available, otherwise use IP
const trackingKey = walletAddress ? `wallet:${walletAddress}` : `ip:${ipAddress}`;

try {
// Apply throttling using NestJS throttler
// We're utilizing the ThrottlerGuard but calling it manually
req.throttlerKey = trackingKey;

// Apply rate limiting
const throttlerContext = {
getHandler: () => this.use,
getClass: () => RateLimitMiddleware,
};

await this.throttlerGuard.handleRequest(
{ ...req, params: {}, query: {}, body: {} },
res,
throttlerContext,
);

return next();
} catch (error) {
if (error instanceof ThrottlerException) {
// Log abuse attempt
await this.logAbuseAttempt({
ipAddress,
walletAddress,
endpoint: req.path,
method: req.method,
violatedRule: {
limit: this.defaultRateLimit.limit,
window: `${this.defaultRateLimit.ttl}s`,
},
});

// Return standard throttling response
res.status(429).json({
statusCode: 429,
message: 'Too Many Requests',
error: 'Rate limit exceeded',
});
} else {
// Pass through other errors
next(error);
}
}
}

private async logAbuseAttempt(data: {
ipAddress: string;
walletAddress: string | null;
endpoint: string;
method: string;
violatedRule: { limit: number; window: string };
}) {
try {
const violation = new RateLimitViolation();
violation.id = uuidv4();
violation.ipAddress = data.ipAddress;
violation.walletAddress = data.walletAddress;
violation.endpoint = data.endpoint;
violation.method = data.method;
violation.timestamp = new Date();
violation.violatedRule = data.violatedRule;

await this.rateLimitViolationRepo.save(violation);
this.logger.warn(`Rate limit violation recorded: ${JSON.stringify(data)}`);
} catch (error) {
this.logger.error(`Failed to log rate limit violation: ${error.message}`, error.stack);
}
}
}



// import { Injectable, type NestMiddleware } from '@nestjs/common';
// import { Request, Response, NextFunction } from 'express';
// import rateLimit from 'express-rate-limit';

// @Injectable()
// export class RateLimitMiddleware implements NestMiddleware {
// private rateLimiter = rateLimit({
// windowMs: 15 * 60 * 1000, // 15 minutes in milliseconds
// max: 100, // Limit each IP to 100 requests per windowMs
// standardHeaders: true, // Return rate limit info in the `RateLimit-*` headers
// legacyHeaders: false, // Disable the `X-RateLimit-*` headers
// message: { message: 'Too many requests', statusCode: 429 },
// skipSuccessfulRequests: false, // Don't skip successful requests
// });

// use(req: Request, res: Response, next: NextFunction) {
// this.rateLimiter(req, res, next);
// }
// }
Loading