From dbeb3b3b2c7d83220b802ae6b3cc48c81264c05b Mon Sep 17 00:00:00 2001 From: Luke Vella Date: Sat, 8 Feb 2025 15:09:11 +0700 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Rate=20Limit=20middleware?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/web/src/trpc/routers/polls.ts | 6 ++-- apps/web/src/trpc/routers/user.ts | 38 +++++++++++---------- apps/web/src/trpc/trpc.ts | 53 +++++++++++++++++------------- 3 files changed, 55 insertions(+), 42 deletions(-) diff --git a/apps/web/src/trpc/routers/polls.ts b/apps/web/src/trpc/routers/polls.ts index b634d0c5d15..f28b46f8e94 100644 --- a/apps/web/src/trpc/routers/polls.ts +++ b/apps/web/src/trpc/routers/polls.ts @@ -12,11 +12,11 @@ import { getEmailClient } from "@/utils/emails"; import { getTimeZoneAbbreviation } from "../../utils/date"; import { + createRateLimitMiddleware, possiblyPublicProcedure, privateProcedure, proProcedure, publicProcedure, - rateLimitMiddleware, requireUserMiddleware, router, } from "../trpc"; @@ -130,7 +130,7 @@ export const polls = router({ // START LEGACY ROUTES create: possiblyPublicProcedure - .use(rateLimitMiddleware) + .use(createRateLimitMiddleware(20, "1 h")) .use(requireUserMiddleware) .input( z.object({ @@ -233,6 +233,7 @@ export const polls = router({ return { id: poll.id }; }), update: possiblyPublicProcedure + .use(createRateLimitMiddleware(60, "1 h")) .input( z.object({ urlId: z.string(), @@ -305,6 +306,7 @@ export const polls = router({ }); }), delete: possiblyPublicProcedure + .use(createRateLimitMiddleware(30, "1 h")) .input( z.object({ urlId: z.string(), diff --git a/apps/web/src/trpc/routers/user.ts b/apps/web/src/trpc/routers/user.ts index e20974eb026..e2f9654ebea 100644 --- a/apps/web/src/trpc/routers/user.ts +++ b/apps/web/src/trpc/routers/user.ts @@ -12,9 +12,9 @@ import { createToken } from "@/utils/session"; import { getSubscriptionStatus } from "@/utils/subscription"; import { + createRateLimitMiddleware, privateProcedure, publicProcedure, - rateLimitMiddleware, router, } from "../trpc"; @@ -53,20 +53,22 @@ export const user = router({ }, }); }), - delete: privateProcedure.mutation(async ({ ctx }) => { - if (ctx.user.isGuest) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Guest users cannot be deleted", - }); - } + delete: privateProcedure + .use(createRateLimitMiddleware(5, "1 h")) + .mutation(async ({ ctx }) => { + if (ctx.user.isGuest) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Guest users cannot be deleted", + }); + } - await prisma.user.delete({ - where: { - id: ctx.user.id, - }, - }); - }), + await prisma.user.delete({ + where: { + id: ctx.user.id, + }, + }); + }), subscription: publicProcedure.query( async ({ ctx }): Promise<{ legacy?: boolean; active: boolean }> => { if (!ctx.user || ctx.user.isGuest) { @@ -80,6 +82,7 @@ export const user = router({ }, ), changeName: privateProcedure + .use(createRateLimitMiddleware(20, "1 h")) .input( z.object({ name: z.string().min(1).max(100), @@ -96,6 +99,7 @@ export const user = router({ }); }), updatePreferences: privateProcedure + .use(createRateLimitMiddleware(30, "1 h")) .input( z.object({ locale: z.string().optional(), @@ -122,7 +126,7 @@ export const user = router({ return { success: true }; }), requestEmailChange: privateProcedure - .use(rateLimitMiddleware) + .use(createRateLimitMiddleware(10, "1 h")) .input(z.object({ email: z.string().email() })) .mutation(async ({ input, ctx }) => { const currentUser = await prisma.user.findUnique({ @@ -174,7 +178,7 @@ export const user = router({ return { success: true as const }; }), getAvatarUploadUrl: privateProcedure - .use(rateLimitMiddleware) + .use(createRateLimitMiddleware(20, "1 h")) .input( z.object({ fileType: z.enum(["image/jpeg", "image/png"]), @@ -220,7 +224,7 @@ export const user = router({ }), updateAvatar: privateProcedure .input(z.object({ imageKey: z.string().max(255) })) - .use(rateLimitMiddleware) + .use(createRateLimitMiddleware(10, "1 h")) .mutation(async ({ ctx, input }) => { const userId = ctx.user.id; const oldImageKey = ctx.user.image; diff --git a/apps/web/src/trpc/trpc.ts b/apps/web/src/trpc/trpc.ts index 7ad2b4fe8ae..f94acc7b1b5 100644 --- a/apps/web/src/trpc/trpc.ts +++ b/apps/web/src/trpc/trpc.ts @@ -89,33 +89,40 @@ export const proProcedure = privateProcedure.use(async ({ ctx, next }) => { return next(); }); -export const rateLimitMiddleware = middleware(async ({ ctx, next }) => { - if (!process.env.KV_REST_API_URL) { - return next(); - } - - const ratelimit = new Ratelimit({ - redis: kv, - limiter: Ratelimit.slidingWindow(5, "1 m"), - }); +export const createRateLimitMiddleware = ( + requests: number, + duration: "1 m" | "1 h", +) => { + return middleware(async ({ ctx, next }) => { + if (!process.env.KV_REST_API_URL) { + return next(); + } - if (!ctx.ip) { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: "Failed to get client IP", + if (!ctx.ip) { + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: "Failed to get client IP", + }); + } + const ratelimit = new Ratelimit({ + redis: kv, + limiter: Ratelimit.slidingWindow(requests, duration), }); - } - const res = await ratelimit.limit(ctx.ip); + const res = await ratelimit.limit(ctx.ip); - if (!res.success) { - throw new TRPCError({ - code: "TOO_MANY_REQUESTS", - message: "Too many requests", - }); - } + if (!res.success) { + throw new TRPCError({ + code: "TOO_MANY_REQUESTS", + message: "Too many requests", + }); + } - return next(); -}); + return next(); + }); +}; + +// Usage example: +export const rateLimitMiddleware = createRateLimitMiddleware(5, "1 m"); export const mergeRouters = t.mergeRouters;