Skip to content

Commit

Permalink
♻️ Rate Limit middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
lukevella committed Feb 8, 2025
1 parent 24558c2 commit dbeb3b3
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 42 deletions.
6 changes: 4 additions & 2 deletions apps/web/src/trpc/routers/polls.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import { getEmailClient } from "@/utils/emails";

import { getTimeZoneAbbreviation } from "../../utils/date";
import {
createRateLimitMiddleware,
possiblyPublicProcedure,
privateProcedure,
proProcedure,
publicProcedure,
rateLimitMiddleware,
requireUserMiddleware,
router,
} from "../trpc";
Expand Down Expand Up @@ -130,7 +130,7 @@ export const polls = router({

// START LEGACY ROUTES
create: possiblyPublicProcedure
.use(rateLimitMiddleware)
.use(createRateLimitMiddleware(20, "1 h"))
.use(requireUserMiddleware)
.input(
z.object({
Expand Down Expand Up @@ -233,6 +233,7 @@ export const polls = router({
return { id: poll.id };
}),
update: possiblyPublicProcedure
.use(createRateLimitMiddleware(60, "1 h"))
.input(
z.object({
urlId: z.string(),
Expand Down Expand Up @@ -305,6 +306,7 @@ export const polls = router({
});
}),
delete: possiblyPublicProcedure
.use(createRateLimitMiddleware(30, "1 h"))
.input(
z.object({
urlId: z.string(),
Expand Down
38 changes: 21 additions & 17 deletions apps/web/src/trpc/routers/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import { createToken } from "@/utils/session";
import { getSubscriptionStatus } from "@/utils/subscription";

import {
createRateLimitMiddleware,
privateProcedure,
publicProcedure,
rateLimitMiddleware,
router,
} from "../trpc";

Expand Down Expand Up @@ -53,20 +53,22 @@ export const user = router({
},
});
}),
delete: privateProcedure.mutation(async ({ ctx }) => {
if (ctx.user.isGuest) {
throw new TRPCError({
code: "BAD_REQUEST",
message: "Guest users cannot be deleted",
});
}
delete: privateProcedure
.use(createRateLimitMiddleware(5, "1 h"))
.mutation(async ({ ctx }) => {
if (ctx.user.isGuest) {
throw new TRPCError({
code: "BAD_REQUEST",
message: "Guest users cannot be deleted",
});
}

await prisma.user.delete({
where: {
id: ctx.user.id,
},
});
}),
await prisma.user.delete({
where: {
id: ctx.user.id,
},
});
}),
subscription: publicProcedure.query(
async ({ ctx }): Promise<{ legacy?: boolean; active: boolean }> => {
if (!ctx.user || ctx.user.isGuest) {
Expand All @@ -80,6 +82,7 @@ export const user = router({
},
),
changeName: privateProcedure
.use(createRateLimitMiddleware(20, "1 h"))
.input(
z.object({
name: z.string().min(1).max(100),
Expand All @@ -96,6 +99,7 @@ export const user = router({
});
}),
updatePreferences: privateProcedure
.use(createRateLimitMiddleware(30, "1 h"))
.input(
z.object({
locale: z.string().optional(),
Expand All @@ -122,7 +126,7 @@ export const user = router({
return { success: true };
}),
requestEmailChange: privateProcedure
.use(rateLimitMiddleware)
.use(createRateLimitMiddleware(10, "1 h"))
.input(z.object({ email: z.string().email() }))
.mutation(async ({ input, ctx }) => {
const currentUser = await prisma.user.findUnique({
Expand Down Expand Up @@ -174,7 +178,7 @@ export const user = router({
return { success: true as const };
}),
getAvatarUploadUrl: privateProcedure
.use(rateLimitMiddleware)
.use(createRateLimitMiddleware(20, "1 h"))
.input(
z.object({
fileType: z.enum(["image/jpeg", "image/png"]),
Expand Down Expand Up @@ -220,7 +224,7 @@ export const user = router({
}),
updateAvatar: privateProcedure
.input(z.object({ imageKey: z.string().max(255) }))
.use(rateLimitMiddleware)
.use(createRateLimitMiddleware(10, "1 h"))
.mutation(async ({ ctx, input }) => {
const userId = ctx.user.id;
const oldImageKey = ctx.user.image;
Expand Down
53 changes: 30 additions & 23 deletions apps/web/src/trpc/trpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,33 +89,40 @@ export const proProcedure = privateProcedure.use(async ({ ctx, next }) => {
return next();
});

export const rateLimitMiddleware = middleware(async ({ ctx, next }) => {
if (!process.env.KV_REST_API_URL) {
return next();
}

const ratelimit = new Ratelimit({
redis: kv,
limiter: Ratelimit.slidingWindow(5, "1 m"),
});
export const createRateLimitMiddleware = (
requests: number,
duration: "1 m" | "1 h",
) => {
return middleware(async ({ ctx, next }) => {
if (!process.env.KV_REST_API_URL) {
return next();
}

if (!ctx.ip) {
throw new TRPCError({
code: "INTERNAL_SERVER_ERROR",
message: "Failed to get client IP",
if (!ctx.ip) {
throw new TRPCError({
code: "INTERNAL_SERVER_ERROR",
message: "Failed to get client IP",
});
}
const ratelimit = new Ratelimit({
redis: kv,
limiter: Ratelimit.slidingWindow(requests, duration),
});
}

const res = await ratelimit.limit(ctx.ip);
const res = await ratelimit.limit(ctx.ip);

if (!res.success) {
throw new TRPCError({
code: "TOO_MANY_REQUESTS",
message: "Too many requests",
});
}
if (!res.success) {
throw new TRPCError({
code: "TOO_MANY_REQUESTS",
message: "Too many requests",
});
}

return next();
});
return next();
});
};

// Usage example:
export const rateLimitMiddleware = createRateLimitMiddleware(5, "1 m");

export const mergeRouters = t.mergeRouters;

0 comments on commit dbeb3b3

Please sign in to comment.