Skip to content

Commit

Permalink
chore(cb2-14762): implement microsoft key caching to combat rate limi…
Browse files Browse the repository at this point in the history
…ting (#109)

* chore: debug auth flow

* chore: cache ms keys

* chore: cache ms keys

* chore: implement logs behind debug flag

* chore: add test for checking fetch keys only called once

* chore: add extra config to logger
  • Loading branch information
matthew2564 authored Oct 21, 2024
1 parent 90ef922 commit da37cb5
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 3 deletions.
29 changes: 29 additions & 0 deletions src/common/Logger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,32 @@ export const writeLogMessage = (event: APIGatewayTokenAuthorizerEvent, log: ILog
}
return log;
};

export enum LogLevel {
DEBUG = "DEBUG",
INFO = "INFO",
WARN = "WARN",
ERROR = "ERROR",
}

export const envLogger = (level: LogLevel, ...messages: string[]) => {
if (process.env.DEBUG === "true" || process.env.DEBUG === "log") {
switch (level) {
case LogLevel.DEBUG:
console.debug(messages);
break;
case LogLevel.INFO:
console.info(messages);
break;
case LogLevel.WARN:
console.warn(messages);
break;
case LogLevel.ERROR:
console.error(messages);
break;
default:
console.log(messages);
return;
}
}
};
14 changes: 13 additions & 1 deletion src/functions/authorizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { generatePolicy as generateFunctionalPolicy } from "./functionalPolicyFa
import { getValidJwt } from "../services/tokens";
import { JWT_MESSAGE } from "../models/enums";
import { ILogEvent } from "../models/ILogEvent";
import { writeLogMessage } from "../common/Logger";
import { envLogger, LogLevel, writeLogMessage } from "../common/Logger";
import newPolicyDocument from "./newPolicyDocument";
import { Jwt, JwtPayload } from "jsonwebtoken";

Expand All @@ -20,25 +20,35 @@ import { Jwt, JwtPayload } from "jsonwebtoken";
export const authorizer = async (event: APIGatewayTokenAuthorizerEvent, context: Context): Promise<APIGatewayAuthorizerResult> => {
const logEvent: ILogEvent = {};

envLogger(LogLevel.DEBUG, "Invoked authoriser");

if (!process.env.AZURE_TENANT_ID || !process.env.AZURE_CLIENT_ID) {
writeLogMessage(event, logEvent, JWT_MESSAGE.INVALID_ID_SETUP);
return unauthorisedPolicy();
}

envLogger(LogLevel.DEBUG, "AZURE_TENANT_ID and AZURE_CLIENT_ID are set");

try {
initialiseLogEvent(event);

envLogger(LogLevel.INFO, "Getting valid JWT");
const jwt = await getValidJwt(event.authorizationToken, logEvent, process.env.AZURE_TENANT_ID, process.env.AZURE_CLIENT_ID);

envLogger(LogLevel.INFO, "Generating role policy");
const policy = generateRolePolicy(jwt, logEvent) ?? generateFunctionalPolicy(jwt, logEvent);

if (policy !== undefined) {
envLogger(LogLevel.INFO, "Role policy generated");
return policy;
}

reportNoValidRoles(jwt, event, context, logEvent);
writeLogMessage(event, logEvent, JWT_MESSAGE.INVALID_ROLES);

return unauthorisedPolicy();
} catch (error: any) {
envLogger(LogLevel.ERROR, "Catch - Error occurred", error);
writeLogMessage(event, logEvent, error);
return unauthorisedPolicy();
}
Expand Down Expand Up @@ -67,6 +77,8 @@ const reportNoValidRoles = (jwt: Jwt, event: APIGatewayTokenAuthorizerEvent, con
* @param event
*/
const initialiseLogEvent = (event: APIGatewayTokenAuthorizerEvent): ILogEvent => {
envLogger(LogLevel.DEBUG, "Init log event");

return {
requestUrl: event.methodArn,
timeOfRequest: new Date().toISOString(),
Expand Down
18 changes: 17 additions & 1 deletion src/services/azure.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
import { KeyResponse } from "../models/KeyResponse";
import { envLogger, LogLevel } from "../common/Logger";

const cache: Map<string, Map<string, string>> = new Map();

export const getCertificateChain = async (tenantId: string, keyId: string): Promise<string> => {
const keys: Map<string, string> = await getKeys(tenantId);
const cacheKeys = cache.get(tenantId);

envLogger(LogLevel.DEBUG, `Cache ${cacheKeys ? "hit" : "not hit"}`);

const keys: Map<string, string> = cacheKeys ?? (await getKeys(tenantId));

envLogger(LogLevel.DEBUG, "Public keys read");

if (!cache.has(tenantId)) {
cache.set(tenantId, keys);
}

const certificateChain = keys.get(keyId);

Expand All @@ -25,9 +38,12 @@ const getKeys = async (tenantId: string): Promise<Map<string, string>> => {

map.set(keyId, certificateChain);
}

envLogger(LogLevel.DEBUG, "Key Map Created");
return map;
};

export const fetchKeys = (tenantId: string) => {
envLogger(LogLevel.DEBUG, `Fetching keys from https://login.microsoftonline.com/${tenantId}/discovery/keys`);
return fetch(`https://login.microsoftonline.com/${tenantId}/discovery/keys`);
};
3 changes: 3 additions & 0 deletions src/services/signature-check.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import * as JWT from "jsonwebtoken";
import { getCertificateChain } from "./azure";
import { envLogger, LogLevel } from "../common/Logger";

export const checkSignature = async (encodedToken: string, decodedToken: JWT.Jwt, tenantId: string, clientId: string): Promise<void> => {
// tid = tenant ID, kid = key ID
envLogger(LogLevel.DEBUG, "Getting cert chain");
const certificate = await getCertificateChain(tenantId, decodedToken.header.kid as string);

envLogger(LogLevel.INFO, "Verifying token");
JWT.verify(encodedToken, certificate, {
audience: clientId.split(","),
issuer: [`https://sts.windows.net/${tenantId}/`, `https://login.microsoftonline.com/${tenantId}/v2.0`],
Expand Down
13 changes: 12 additions & 1 deletion tests/unit/services/azure.unitTest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ describe("getCertificateChain()", () => {
it("should throw an error if no key matches the given key ID", async (): Promise<void> => {
fetchSpy("somethingElse", "mySuperSecurePublicKey");

await expect(azure.getCertificateChain("tenantId", "keyToTheKingdom")).rejects.toThrow("no public key");
await expect(azure.getCertificateChain("tenantId", "otherKeyToTheKingdom")).rejects.toThrow("no public key");
});

// simulate multiple calls to the function
[1, 2, 3].forEach(() => {
it(`should call fetchKeys only once and then hit the cache`, async (): Promise<void> => {
const publicKey = "mySuperSecurePublicKey";
fetchSpy("keyToTheKingdom", publicKey);

await azure.getCertificateChain("tenantId", "keyToTheKingdom");
expect(azure.fetchKeys).toHaveBeenCalledTimes(1);
});
});
});

0 comments on commit da37cb5

Please sign in to comment.