diff --git a/src/commands/ssh.ts b/src/commands/ssh.ts index 3fd0ed9..28e168b 100644 --- a/src/commands/ssh.ts +++ b/src/commands/ssh.ts @@ -27,20 +27,37 @@ import { getDoc, onSnapshot } from "firebase/firestore"; import { pick } from "lodash"; import yargs from "yargs"; +type SshCommandArgs = { + instance: string; + command?: string; + arguments: string[]; +}; + /** Maximum amount of time to wait after access is approved to wait for access * to be configured */ const GRANT_TIMEOUT_MILLIS = 60e3; export const sshCommand = (yargs: yargs.Argv) => - yargs.command<{ instance: string }>( - "ssh ", + yargs.command( + "ssh [command [arguments..]]", "SSH into a virtual machine", (yargs) => - yargs.positional("instance", { - type: "string", - demandOption: true, - }), + yargs + .positional("instance", { + type: "string", + demandOption: true, + }) + .positional("command", { + type: "string", + describe: "Pass command to the shell", + }) + .positional("arguments", { + describe: "Command arguments", + array: true, + string: true, + default: [] as string[], + }), guard(ssh) ); @@ -107,7 +124,7 @@ const waitForProvisioning = async

( * Supported SSH mechanisms: * - AWS EC2 via SSM with Okta SAML */ -const ssh = async (args: yargs.ArgumentsCamelCase<{ instance: string }>) => { +const ssh = async (args: yargs.ArgumentsCamelCase) => { // Prefix is required because the backend uses it to determine that this is an AWS request const authn = await authenticate(); await validateSshInstall(authn); @@ -127,5 +144,18 @@ const ssh = async (args: yargs.ArgumentsCamelCase<{ instance: string }>) => { const { id, isPreexisting } = response; if (!isPreexisting) print2("Waiting for access to be provisioned"); const requestData = await waitForProvisioning(authn, id); - await ssm(authn, { ...requestData, id }); + await ssm(authn, { + ...requestData, + id, + command: args.command + ? `${args.command} ${args.arguments + .map( + (argument) => + // escape all double quotes (") in commands such as `p0 ssh > echo 'hello; "world"'` because we + // need to encapsulate command arguments in double quotes as we pass them along to the remote shell + `"${argument.replace(/"/g, '\\"')}"` + ) + .join(" ")}`.trim() + : undefined, + }); }; diff --git a/src/plugins/aws/ssm.ts b/src/plugins/aws/ssm.ts index 3ae703a..bf01cb6 100644 --- a/src/plugins/aws/ssm.ts +++ b/src/plugins/aws/ssm.ts @@ -41,6 +41,7 @@ type SsmArgs = { requestId: string; documentName: string; credential: AwsCredentials; + command?: string; }; /** Checks if access has propagated through AWS to the SSM agent @@ -84,26 +85,36 @@ const accessPropagationGuard = ( }; }; +const createSsmCommand = (args: Omit) => { + const ssmCommand = [ + "aws", + "ssm", + "start-session", + "--region", + args.region, + "--target", + args.instance, + "--document-name", + args.documentName, + ]; + + if (args.command && args.command.trim()) { + ssmCommand.push("--parameters", `command='${args.command}'`); + } + + return ssmCommand; +}; + /** Starts an SSM session in the terminal by spawning `aws ssm` as a subprocess * * Requires `aws ssm` to be installed on the client machine. */ const spawnSsmNode = async ( - args: Pick, + args: Omit, options?: { attemptsRemaining?: number } ): Promise => new Promise((resolve, reject) => { - const ssmCommand = [ - "aws", - "ssm", - "start-session", - "--region", - args.region, - "--target", - args.instance, - "--document-name", - args.documentName, - ]; + const ssmCommand = createSsmCommand(args); const child = spawn("/usr/bin/env", ssmCommand, { env: { ...process.env, @@ -145,7 +156,7 @@ const spawnSsmNode = async ( /** Connect to an SSH backend using AWS Systems Manager (SSM) */ export const ssm = async ( authn: Authn, - request: Request & { id: string } + request: Request & { id: string; command?: string } ) => { const match = request.permission.spec.arn.match(INSTANCE_ARN_PATTERN); if (!match) throw "Did not receive a properly formatted instance identifier"; @@ -161,6 +172,7 @@ export const ssm = async ( documentName: request.generated.documentName, requestId: request.id, credential, + command: request.command, }; await spawnSsmNode(args); };