From 9bb10deeac1159b55c1008486b2bcfd8f3799252 Mon Sep 17 00:00:00 2001 From: Pedram Bakh <56321501+PedramBakh@users.noreply.github.com> Date: Wed, 22 Nov 2023 13:10:29 +0100 Subject: [PATCH] Enchance argument parsing to support flexible command execution with and without --cmd flag --- carbontracker/cli.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/carbontracker/cli.py b/carbontracker/cli.py index e1ae05f..f69dc87 100644 --- a/carbontracker/cli.py +++ b/carbontracker/cli.py @@ -1,20 +1,18 @@ import argparse -import shlex import subprocess -from carbontracker.tracker import CarbonTracker import ast - +from carbontracker.tracker import CarbonTracker +import shlex def main(): + # Create the parser for the script's own arguments parser = argparse.ArgumentParser(description="CarbonTracker CLI") - - # Accept a list of arguments - parser.add_argument("command", type=str, nargs='+', - help="Command and arguments to execute. E.g., 'python myscript.py arg1 arg2'") parser.add_argument("--log_dir", type=str, help="Log directory", default="./logs") parser.add_argument("--api_keys", type=str, help="API keys in a dictionary-like format, e.g., " "'{\"electricitymaps\": \"YOUR_KEY\"}'", default=None) - args = parser.parse_args() + + # Parse known arguments and remaining arguments + args, remaining_args = parser.parse_known_args() # Parse the API keys string into a dictionary api_keys = ast.literal_eval(args.api_keys) if args.api_keys else None @@ -22,16 +20,22 @@ def main(): tracker = CarbonTracker(epochs=1, log_dir=args.log_dir, epochs_before_pred=0, api_keys=api_keys) tracker.epoch_start() - # Execute the provided command with its arguments - try: - subprocess.run(args.command, check=True) - except subprocess.CalledProcessError: - print(f"Error executing command: {' '.join(map(shlex.quote, args.command))}") - # Handle errors or exceptions if needed + # Handle the command + if remaining_args: + if '--cmd' in remaining_args: + cmd_index = remaining_args.index('--cmd') + command_args = remaining_args[cmd_index + 1:] + else: + command_args = remaining_args + + # Execute the command + try: + subprocess.run(command_args, check=True) + except subprocess.CalledProcessError: + print(f"Error executing command: {' '.join(map(shlex.quote, command_args))}") tracker.epoch_end() tracker.stop() - if __name__ == "__main__": main()