diff --git a/lib/get-aws-sts-token.js b/lib/get-aws-sts-token.js index 018670e..0686506 100644 --- a/lib/get-aws-sts-token.js +++ b/lib/get-aws-sts-token.js @@ -7,22 +7,23 @@ const CACHE_EXPIRY = process.env.AWS_STS_SESSION_RESET_EXPIRY || (EXPIRY - 600); async function getAwsAuthToken( logger, createHash, retrieveHash, - {accessKeyId, secretAccessKey, region, roleArn}) { + {speech_credential_sid, accessKeyId, secretAccessKey, region, roleArn}) { logger = logger || noopLogger; try { - const key = makeAwsKey(roleArn || accessKeyId); + // if incase instance profile is used, speech_credential_sid will be used as key to lookup cache + const key = makeAwsKey(roleArn || accessKeyId || speech_credential_sid); const obj = await retrieveHash(key); if (obj) return {...obj, servedFromCache: true}; - + /* access token not found in cache, so generate it using STS */ let data; + let expiry = CACHE_EXPIRY; if (roleArn) { const stsClient = new STSClient({ region }); const roleToAssume = { RoleArn: roleArn, RoleSessionName: 'Jambonz_Speech', DurationSeconds: EXPIRY}; const command = new AssumeRoleCommand(roleToAssume); data = await stsClient.send(command); - } else { - /* access token not found in cache, so generate it using STS */ + } else if (accessKeyId) { const stsClient = new STSClient({ region, credentials: { @@ -32,6 +33,26 @@ async function getAwsAuthToken( }); const command = new GetSessionTokenCommand({DurationSeconds: EXPIRY}); data = await stsClient.send(command); + } else { + // instance profile is used. + const stsClient = new STSClient({ region }); + const cred = await stsClient.config.credentials(); + // method in the AWS SDK automatically fetches credentials using the default credential + // provider chain. If the credentials come from an instance profile or an environment + // variable, their expiration is controlled by AWS and not explicitly by our code. + if (cred && cred.expiration) { + const currentTime = new Date(); + const expiryTime = new Date(cred.expiration); + const remainingTimeInSeconds = Math.round((expiryTime - currentTime) / 1000); + expiry = remainingTimeInSeconds; + } + data = { + Credentials: { + AccessKeyId: cred.accessKeyId, + SecretAccessKey: cred.secretAccessKey, + SessionToken: cred.sessionToken + } + }; } const credentials = { @@ -40,9 +61,11 @@ async function getAwsAuthToken( sessionToken: data.Credentials.SessionToken, securityToken: data.Credentials.SessionToken }; - - createHash(key, credentials, CACHE_EXPIRY) - .catch((err) => logger.error(err, `Error saving hash for key ${key}`)); + // Only cache if expiry is good + if (expiry > 0) { + createHash(key, credentials, expiry) + .catch((err) => logger.error(err, `Error saving hash for key ${key}`)); + } return {...credentials, servedFromCache: false}; } catch (err) {