diff --git a/website/api/nus/auth/login.ts b/website/api/nus/auth/login.ts index fc6a7211e1..cef11b0bc2 100644 --- a/website/api/nus/auth/login.ts +++ b/website/api/nus/auth/login.ts @@ -1,4 +1,4 @@ -import { authenticate } from '../../../src/serverless/nus-auth'; +import { authenticate, isCallbackUrlValid } from '../../../src/serverless/nus-auth'; import { createRouteHandler, defaultFallback, @@ -9,6 +9,7 @@ import { const errors = { noRelayState: 'ERR_NO_RELAY_STATE', + invalidRelayState: 'ERR_INVALID_RELAY_STATE', }; const handlePost: Handler = async (req, res) => { @@ -18,6 +19,9 @@ const handlePost: Handler = async (req, res) => { throw new Error(errors.noRelayState); } + if (!isCallbackUrlValid(relayState)) { + throw new Error(errors.invalidRelayState); + } const userURL = new URL(relayState); userURL.searchParams.append('token', token); @@ -27,6 +31,10 @@ const handlePost: Handler = async (req, res) => { res.json({ message: 'Relay state not found in request', }); + } else if (err.message === errors.invalidRelayState) { + res.json({ + message: 'Invalid relay state given. URL must be from a valid domain.', + }); } else { throw err; } diff --git a/website/api/nus/auth/sso.ts b/website/api/nus/auth/sso.ts index 0fed0c7c10..a2d5158e01 100644 --- a/website/api/nus/auth/sso.ts +++ b/website/api/nus/auth/sso.ts @@ -1,4 +1,4 @@ -import { createLoginURL } from '../../../src/serverless/nus-auth'; +import { createLoginURL, isCallbackUrlValid } from '../../../src/serverless/nus-auth'; import { createRouteHandler, defaultFallback, @@ -9,6 +9,7 @@ import { const errors = { noCallbackUrl: 'ERR_NO_REFERER', + invalidCallbackUrl: 'ERR_INVALID_REFERER', }; function getCallbackUrl(callback: string | string[] | undefined) { @@ -24,12 +25,20 @@ const handleGet: Handler = async (req, res) => { throw new Error(errors.noCallbackUrl); } + if (!isCallbackUrlValid(callback)) { + throw new Error(errors.invalidCallbackUrl); + } + res.send(createLoginURL(callback)); } catch (err) { if (err.message === errors.noCallbackUrl) { res.json({ message: 'Request needs a referer', }); + } else if (err.message === errors.invalidCallbackUrl) { + res.json({ + message: 'Invalid referer given. URL must be from a valid domain.', + }); } else { throw err; } diff --git a/website/api/nus/auth/user.ts b/website/api/nus/auth/user.ts new file mode 100644 index 0000000000..bd6834c74b --- /dev/null +++ b/website/api/nus/auth/user.ts @@ -0,0 +1,19 @@ +import { authenticate } from '../../../src/serverless/nus-auth'; +import { + createRouteHandler, + defaultFallback, + defaultRescue, + Handler, + MethodHandlers, +} from '../../../src/serverless/handler'; + +const handleGet: Handler = async (req, res) => { + const { user } = await authenticate(req); + res.json(user); +}; + +const methodHandlers: MethodHandlers = { + GET: handleGet, +}; + +export default createRouteHandler(methodHandlers, defaultFallback, defaultRescue(true)); diff --git a/website/src/serverless/nus-auth.ts b/website/src/serverless/nus-auth.ts index b75f4ee065..d03c6622b3 100644 --- a/website/src/serverless/nus-auth.ts +++ b/website/src/serverless/nus-auth.ts @@ -15,6 +15,9 @@ const errors = { noTokenSupplied: 'ERR_NO_TOKEN_SUPPLIED', }; +// Domains allowed as callback URLs +const allowedDomains = ['nusmods.com', 'nuscourses.com', 'modsn.us', 'localhost']; + export type User = { accountName: string; upn: string; @@ -45,6 +48,28 @@ export const createLoginURL = (relayState = '') => { return ssoLoginURL.toString(); }; +export const isCallbackUrlValid = (callbackUrl: string): boolean => { + try { + const url = new URL(callbackUrl); + + const validMatch = allowedDomains.some( + (allowedDomain) => + url.hostname.endsWith(`.${allowedDomain}`) || url.hostname === allowedDomain, + ); + + if (!validMatch) { + // eslint-disable-next-line no-console + console.error('Invalid callback URL given by user:', callbackUrl); + } + + return validMatch; + } catch (error) { + // eslint-disable-next-line no-console + console.error('Invalid callback URL:', error); + return false; + } +}; + export const authenticate = async (req: Request) => { const tokenProvided = req.headers.authorization || (req.body && req.body.SAMLResponse); if (!tokenProvided) {