From be24b11a7029e2dbf41d95d7cf37417e801bcf6d Mon Sep 17 00:00:00 2001 From: thom Date: Mon, 7 Oct 2024 12:30:54 -0700 Subject: [PATCH] fix: refactor trpc auth guard name + add two more --- src/backend/routers/admin.ts | 4 +- src/backend/routers/case_manager.ts | 22 ++++----- src/backend/routers/file.ts | 12 ++--- src/backend/routers/iep.ts | 36 +++++++------- src/backend/routers/para.ts | 10 ++-- src/backend/routers/public.ts | 4 +- src/backend/routers/student.ts | 14 +++--- src/backend/routers/user.ts | 6 +-- src/backend/trpc.ts | 76 +++++++++++++++++++++++++++-- 9 files changed, 126 insertions(+), 58 deletions(-) diff --git a/src/backend/routers/admin.ts b/src/backend/routers/admin.ts index 5c5099a8..80a2eb98 100644 --- a/src/backend/routers/admin.ts +++ b/src/backend/routers/admin.ts @@ -1,8 +1,8 @@ import { sql } from "kysely"; -import { adminProcedure, router } from "../trpc"; +import { hasAdmin, router } from "../trpc"; export const admin = router({ - getPostgresInfo: adminProcedure.query(async (req) => { + getPostgresInfo: hasAdmin.query(async (req) => { const result = await sql<{ version: string }>`SELECT version()`.execute( req.ctx.db ); diff --git a/src/backend/routers/case_manager.ts b/src/backend/routers/case_manager.ts index aaa5d91a..7d7091c4 100644 --- a/src/backend/routers/case_manager.ts +++ b/src/backend/routers/case_manager.ts @@ -1,5 +1,5 @@ import { z } from "zod"; -import { authenticatedProcedure, router } from "../trpc"; +import { hasAuthenticated, router } from "../trpc"; import { createPara, assignParaToCaseManager, @@ -10,7 +10,7 @@ export const case_manager = router({ /** * Get all students assigned to the current user */ - getMyStudents: authenticatedProcedure.query(async (req) => { + getMyStudents: hasAuthenticated.query(async (req) => { const { userId } = req.ctx.auth; const result = await req.ctx.db @@ -22,7 +22,7 @@ export const case_manager = router({ return result; }), - getMyStudentsAndIepInfo: authenticatedProcedure.query(async (req) => { + getMyStudentsAndIepInfo: hasAuthenticated.query(async (req) => { const { userId } = req.ctx.auth; const studentData = await req.ctx.db @@ -50,7 +50,7 @@ export const case_manager = router({ * it doesn't already exist. Throws an error if the student is already * assigned to another CM. */ - addStudent: authenticatedProcedure + addStudent: hasAuthenticated .input( z.object({ first_name: z.string(), @@ -72,7 +72,7 @@ export const case_manager = router({ /** * Edits the given student in the CM's roster. Throws an error if the student was not found in the db. */ - editStudent: authenticatedProcedure + editStudent: hasAuthenticated .input( z.object({ student_id: z.string(), @@ -115,7 +115,7 @@ export const case_manager = router({ /** * Removes the case manager associated with this student. */ - removeStudent: authenticatedProcedure + removeStudent: hasAuthenticated .input( z.object({ student_id: z.string(), @@ -131,7 +131,7 @@ export const case_manager = router({ .execute(); }), - getMyParas: authenticatedProcedure.query(async (req) => { + getMyParas: hasAuthenticated.query(async (req) => { const { userId } = req.ctx.auth; const result = await req.ctx.db @@ -152,7 +152,7 @@ export const case_manager = router({ * Handles creation of para and assignment to user, attempts to send * email but does not await email success */ - addStaff: authenticatedProcedure + addStaff: hasAuthenticated .input( z.object({ first_name: z.string(), @@ -180,7 +180,7 @@ export const case_manager = router({ /** * Deprecated: use addStaff instead */ - addPara: authenticatedProcedure + addPara: hasAuthenticated .input( z.object({ para_id: z.string(), @@ -195,7 +195,7 @@ export const case_manager = router({ return; }), - editPara: authenticatedProcedure + editPara: hasAuthenticated .input( z.object({ para_id: z.string(), @@ -236,7 +236,7 @@ export const case_manager = router({ .executeTakeFirstOrThrow(); }), - removePara: authenticatedProcedure + removePara: hasAuthenticated .input( z.object({ para_id: z.string(), diff --git a/src/backend/routers/file.ts b/src/backend/routers/file.ts index cd99aed8..1206a89e 100644 --- a/src/backend/routers/file.ts +++ b/src/backend/routers/file.ts @@ -5,12 +5,12 @@ import { GetObjectCommand, } from "@aws-sdk/client-s3"; import { getSignedUrl } from "@aws-sdk/s3-request-presigner"; -import { authenticatedProcedure, router } from "../trpc"; +import { hasAuthenticated, router } from "../trpc"; import { randomUUID } from "crypto"; import { deleteFile } from "../lib/files"; export const file = router({ - getMyFiles: authenticatedProcedure.query(async (req) => { + getMyFiles: hasAuthenticated.query(async (req) => { return req.ctx.db .selectFrom("file") .selectAll() @@ -18,7 +18,7 @@ export const file = router({ .execute(); }), - getPresignedUrlForFileDownload: authenticatedProcedure + getPresignedUrlForFileDownload: hasAuthenticated .input( z.object({ file_id: z.string().uuid(), @@ -50,7 +50,7 @@ export const file = router({ }; }), - getPresignedUrlForFileUpload: authenticatedProcedure + getPresignedUrlForFileUpload: hasAuthenticated .input( z.object({ type: z.string(), @@ -71,7 +71,7 @@ export const file = router({ return { url, key }; }), - finishFileUpload: authenticatedProcedure + finishFileUpload: hasAuthenticated .input( z.object({ filename: z.string(), @@ -99,7 +99,7 @@ export const file = router({ return file; }), - deleteFile: authenticatedProcedure + deleteFile: hasAuthenticated .input( z.object({ file_id: z.string().uuid(), diff --git a/src/backend/routers/iep.ts b/src/backend/routers/iep.ts index 0a7ae50a..2f5a716a 100644 --- a/src/backend/routers/iep.ts +++ b/src/backend/routers/iep.ts @@ -1,5 +1,5 @@ import { z } from "zod"; -import { authenticatedProcedure, router } from "../trpc"; +import { hasAuthenticated, router } from "../trpc"; import { jsonArrayFrom } from "kysely/helpers/postgres"; import { deleteFile } from "../lib/files"; import { substituteTransactionOnContext } from "../lib/utils/context"; @@ -7,7 +7,7 @@ import { TRPCError } from "@trpc/server"; // TODO: define .output() schemas for all procedures export const iep = router({ - addGoal: authenticatedProcedure + addGoal: hasAuthenticated .input( z.object({ iep_id: z.string(), @@ -31,7 +31,7 @@ export const iep = router({ return result; }), - editGoal: authenticatedProcedure + editGoal: hasAuthenticated .input( z.object({ goal_id: z.string(), @@ -70,7 +70,7 @@ export const iep = router({ return result; }), - addSubgoal: authenticatedProcedure + addSubgoal: hasAuthenticated .input( z.object({ // current_level not included, should be calculated as trial data is collected @@ -123,7 +123,7 @@ export const iep = router({ return result; }), - addTask: authenticatedProcedure + addTask: hasAuthenticated .input( z.object({ subgoal_id: z.string(), @@ -148,7 +148,7 @@ export const iep = router({ return result; }), - assignTaskToParas: authenticatedProcedure + assignTaskToParas: hasAuthenticated .input( z.object({ subgoal_id: z.string().uuid(), @@ -175,7 +175,7 @@ export const iep = router({ return result; }), //Temporary function to easily assign tasks to self for testing - tempAddTaskToSelf: authenticatedProcedure + tempAddTaskToSelf: hasAuthenticated .input( z.object({ subgoal_id: z.string(), @@ -217,7 +217,7 @@ export const iep = router({ return result; }), - addTrialData: authenticatedProcedure + addTrialData: hasAuthenticated .input( z.object({ task_id: z.string(), @@ -246,7 +246,7 @@ export const iep = router({ return result; }), - updateTrialData: authenticatedProcedure + updateTrialData: hasAuthenticated .input( z.object({ trial_data_id: z.string(), @@ -271,7 +271,7 @@ export const iep = router({ .execute(); }), - getGoals: authenticatedProcedure + getGoals: hasAuthenticated .input( z.object({ iep_id: z.string(), @@ -289,7 +289,7 @@ export const iep = router({ return result; }), - getGoal: authenticatedProcedure + getGoal: hasAuthenticated .input( z.object({ goal_id: z.string(), @@ -307,7 +307,7 @@ export const iep = router({ return result; }), - getSubgoals: authenticatedProcedure + getSubgoals: hasAuthenticated .input( z.object({ goal_id: z.string(), @@ -325,7 +325,7 @@ export const iep = router({ return result; }), - getSubgoal: authenticatedProcedure + getSubgoal: hasAuthenticated .input( z.object({ subgoal_id: z.string(), @@ -342,7 +342,7 @@ export const iep = router({ return result; }), - getSubgoalsByAssignee: authenticatedProcedure + getSubgoalsByAssignee: hasAuthenticated .input( z.object({ assignee_id: z.string(), @@ -361,7 +361,7 @@ export const iep = router({ return result; }), - getSubgoalAndTrialData: authenticatedProcedure + getSubgoalAndTrialData: hasAuthenticated .input( z.object({ task_id: z.string(), @@ -424,7 +424,7 @@ export const iep = router({ return result; }), - markAsSeen: authenticatedProcedure + markAsSeen: hasAuthenticated .input( z.object({ task_id: z.string(), @@ -442,7 +442,7 @@ export const iep = router({ .execute(); }), - attachFileToTrialData: authenticatedProcedure + attachFileToTrialData: hasAuthenticated .input( z.object({ trial_data_id: z.string(), @@ -461,7 +461,7 @@ export const iep = router({ .execute(); }), - removeFileFromTrialDataAndDelete: authenticatedProcedure + removeFileFromTrialDataAndDelete: hasAuthenticated .input( z.object({ trial_data_id: z.string(), diff --git a/src/backend/routers/para.ts b/src/backend/routers/para.ts index 2a5ac7e0..44acccee 100644 --- a/src/backend/routers/para.ts +++ b/src/backend/routers/para.ts @@ -1,9 +1,9 @@ import { z } from "zod"; -import { authenticatedProcedure, router } from "../trpc"; +import { hasAuthenticated, router } from "../trpc"; import { createPara } from "../lib/db_helpers/case_manager"; export const para = router({ - getParaById: authenticatedProcedure + getParaById: hasAuthenticated .input(z.object({ user_id: z.string().uuid() })) .query(async (req) => { const { user_id } = req.input; @@ -17,7 +17,7 @@ export const para = router({ return result; }), - getParaByEmail: authenticatedProcedure + getParaByEmail: hasAuthenticated .input(z.object({ email: z.string() })) .query(async (req) => { const { email } = req.input; @@ -34,7 +34,7 @@ export const para = router({ /** * Deprecated: use case_manager.addStaff instead */ - createPara: authenticatedProcedure + createPara: hasAuthenticated .input( z.object({ first_name: z.string(), @@ -61,7 +61,7 @@ export const para = router({ // TODO elsewhere: add "email_verified_at" timestamp when para first signs in with their email address (entered into db by cm) }), - getMyTasks: authenticatedProcedure.query(async (req) => { + getMyTasks: hasAuthenticated.query(async (req) => { const { userId } = req.ctx.auth; const result = await req.ctx.db diff --git a/src/backend/routers/public.ts b/src/backend/routers/public.ts index ae5de6b7..d7ab270d 100644 --- a/src/backend/routers/public.ts +++ b/src/backend/routers/public.ts @@ -1,7 +1,7 @@ -import { publicProcedure, router } from "../trpc"; +import { noAuth, router } from "../trpc"; export const publicRouter = router({ - healthCheck: publicProcedure.query(() => { + healthCheck: noAuth.query(() => { return "Ok"; }), }); diff --git a/src/backend/routers/student.ts b/src/backend/routers/student.ts index 263a09d4..9170475d 100644 --- a/src/backend/routers/student.ts +++ b/src/backend/routers/student.ts @@ -1,9 +1,9 @@ import { z } from "zod"; -import { authenticatedProcedure, router } from "../trpc"; +import { hasAuthenticated, router } from "../trpc"; // TODO: define .output() schemas for all procedures export const student = router({ - getStudentById: authenticatedProcedure + getStudentById: hasAuthenticated .input(z.object({ student_id: z.string().uuid() })) .query(async (req) => { const { student_id } = req.input; @@ -17,7 +17,7 @@ export const student = router({ return result; }), - getStudentByTaskId: authenticatedProcedure + getStudentByTaskId: hasAuthenticated .input(z.object({ task_id: z.string().uuid() })) .query(async (req) => { const { task_id } = req.input; @@ -38,7 +38,7 @@ export const student = router({ /** * Adds a new IEP for the given student. */ - addIep: authenticatedProcedure + addIep: hasAuthenticated .input( z.object({ student_id: z.string(), @@ -67,7 +67,7 @@ export const student = router({ /** * Adds a new IEP for the given student. */ - editIep: authenticatedProcedure + editIep: hasAuthenticated .input( z.object({ student_id: z.string(), @@ -104,7 +104,7 @@ export const student = router({ /** * Returns all the IEPs associated with the given student. */ - getIeps: authenticatedProcedure + getIeps: hasAuthenticated .input( z.object({ student_id: z.string(), @@ -130,7 +130,7 @@ export const student = router({ * per the MVP that there will only be one IEP per student, * but this should be revisited after the MVP. */ - getActiveStudentIep: authenticatedProcedure + getActiveStudentIep: hasAuthenticated .input( z.object({ student_id: z.string().uuid(), diff --git a/src/backend/routers/user.ts b/src/backend/routers/user.ts index aa031aeb..0c0bfdb3 100644 --- a/src/backend/routers/user.ts +++ b/src/backend/routers/user.ts @@ -1,7 +1,7 @@ -import { authenticatedProcedure, router } from "../trpc"; +import { hasAuthenticated, router } from "../trpc"; export const user = router({ - getMe: authenticatedProcedure.query(async (req) => { + getMe: hasAuthenticated.query(async (req) => { const { userId } = req.ctx.auth; const user = await req.ctx.db @@ -23,7 +23,7 @@ export const user = router({ /** * @returns Whether the current user is a case manager */ - isCaseManager: authenticatedProcedure.query(async (req) => { + isCaseManager: hasAuthenticated.query(async (req) => { const { userId } = req.ctx.auth; const result = await req.ctx.db diff --git a/src/backend/trpc.ts b/src/backend/trpc.ts index 58391158..a149497d 100644 --- a/src/backend/trpc.ts +++ b/src/backend/trpc.ts @@ -2,6 +2,28 @@ import { TRPCError, initTRPC } from "@trpc/server"; import { createContext } from "./context"; import superjson from "superjson"; +// Role-based access control type +type RoleLevel = { + user: 0; + para: 1; + case_manager: 2; + admin: 3; +}; + +const ROLE_LEVELS: RoleLevel = { + user: 0, + para: 1, + case_manager: 2, + admin: 3, +}; + +type Role = keyof RoleLevel; + +// Function to compare roles +function hasMinimumRole(userRole: Role, requiredRole: Role): boolean { + return ROLE_LEVELS[userRole] >= ROLE_LEVELS[requiredRole]; +} + // initialize tRPC exactly once per application: export const t = initTRPC.context().create({ // SuperJSON allows us to transparently use, e.g., standard Date/Map/Sets @@ -22,8 +44,43 @@ const isAuthenticated = t.middleware(({ next, ctx }) => { }); }); +const atLeastPara = t.middleware(({ next, ctx }) => { + if ( + ctx.auth.type !== "session" || + !hasMinimumRole(ctx.auth.role as Role, "para") + ) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + + return next({ + ctx: { + ...ctx, + auth: ctx.auth, + }, + }); +}); + +const atLeastCaseManager = t.middleware(({ next, ctx }) => { + if ( + ctx.auth.type !== "session" || + !hasMinimumRole(ctx.auth.role as Role, "case_manager") + ) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + + return next({ + ctx: { + ...ctx, + auth: ctx.auth, + }, + }); +}); + const isAdmin = t.middleware(({ next, ctx }) => { - if (ctx.auth.type !== "session" || ctx.auth.role !== "admin") { + if ( + ctx.auth.type !== "session" || + !hasMinimumRole(ctx.auth.role as Role, "admin") + ) { throw new TRPCError({ code: "UNAUTHORIZED" }); } @@ -37,6 +94,17 @@ const isAdmin = t.middleware(({ next, ctx }) => { // Define and export the tRPC router export const router = t.router; -export const publicProcedure = t.procedure; -export const authenticatedProcedure = t.procedure.use(isAuthenticated); -export const adminProcedure = t.procedure.use(isAuthenticated).use(isAdmin); + +// Define and export the tRPC procedures that can be used as auth guards inside routes +export const noAuth = t.procedure; // Can be used for public routes +export const hasAuthenticated = t.procedure // for routes that require authentication only, no specific role + .use(isAuthenticated); +export const hasPara = t.procedure // for routes that require at least para role + .use(isAuthenticated) + .use(atLeastPara); +export const hasCaseManager = t.procedure // for routes that require at least case manager role + .use(isAuthenticated) + .use(atLeastCaseManager); +export const hasAdmin = t.procedure // for routes that require admin role + .use(isAuthenticated) + .use(isAdmin);