From 7d6f7b66029e07eb42da5015107a6414ae807610 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 4 Nov 2025 10:48:39 -0700 Subject: [PATCH 1/5] Cleaned up make_internal_command --- mason.py | 588 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 334 insertions(+), 254 deletions(-) diff --git a/mason.py b/mason.py index 447286ba56..1fa27a82b8 100644 --- a/mason.py +++ b/mason.py @@ -52,6 +52,47 @@ def parse_env_var(env_var_str: str) -> Dict[str, str]: return {"name": name, "value": value} +def get_clusters(beaker_client: beaker.Beaker = None) -> tuple[list[str], list[str], list[str]]: + """Get cluster lists from Beaker API or return defaults. + + Returns: + Tuple of (weka_clusters, gcp_clusters, interconnect_clusters) + """ + default_weka = ["ai2/jupiter", "ai2/saturn", "ai2/titan", "ai2/neptune", "ai2/ceres", "ai2/triton", "ai2/rhea"] + default_gcp = ["ai2/augusta"] + default_interconnect = ["ai2/jupiter", "ai2/ceres", "ai2/titan", "ai2/augusta"] + + if beaker_client is None: + return default_weka, default_gcp, default_interconnect + + weka_clusters = [] + gcp_clusters = [] + interconnect_clusters = [] + + for cluster in beaker_client.cluster.list(): + cluster_name = f"ai2/{cluster.name}" + has_interconnect = False + has_gcp = False + has_weka = False + + for tag in cluster.tags: + if tag.startswith("interconnect:"): + has_interconnect = True + if tag.startswith("provider:gcp"): + has_gcp = True + if tag.startswith("storage:weka"): + has_weka = True + + if has_interconnect: + interconnect_clusters.append(cluster_name) + if has_gcp: + gcp_clusters.append(cluster_name) + if has_weka: + weka_clusters.append(cluster_name) + + return weka_clusters, gcp_clusters, interconnect_clusters + + WEKA_CLUSTERS = ["ai2/jupiter", "ai2/saturn", "ai2/titan", "ai2/neptune", "ai2/ceres", "ai2/triton", "ai2/rhea"] GCP_CLUSTERS = ["ai2/augusta"] @@ -386,8 +427,285 @@ def get_datasets(beaker_datasets, cluster: List[str]): return res +def maybe_cache_dataset(command: List[str], args: argparse.Namespace) -> tuple[List[str], List[str], List[str]]: + """Cache datasets locally before running on beaker if auto-caching is enabled. 
+ + Returns: + Tuple of (modified_command, dataset_cache_paths, dataset_config_hashes) + """ + + def find_list_idx(lst: List[str], item: str): + for i in range(len(lst)): + if item == lst[i]: + return i + return -1 + + def remove_arg_from_list(lst: List[str], item: str, remove_value: bool = False): + idx = find_list_idx(lst, item) + if idx != -1 and idx + 1 < len(lst): + if remove_value: + lst.pop(idx + 1) + lst.pop(idx) + + if not any("hf_entity" in c for c in command): + command.append("--hf_entity") + command.append("allenai") + if not any("wandb_entity" in c for c in command): + command.append("--wandb_entity") + command.append("ai2-llm") + + dataset_cache_paths = [] + dataset_config_hashes = [] + + if not args.no_auto_dataset_cache: + for file in OPEN_INSTRUCT_COMMANDS: + idx = find_list_idx(command, file) + if idx != -1: + caching_command = command.copy() + remove_arg_from_list(caching_command, "--with_tracking", False) + remove_arg_from_list(caching_command, "--checkpoint_state_freq", True) + remove_arg_from_list(caching_command, "--checkpoint_state_dir", True) + remove_arg_from_list(caching_command, "--gs_checkpoint_state_dir", True) + caching_command = "python " + " ".join(caching_command[idx:]) + " --cache_dataset_only" + console.log("📦📦📦 Running the caching command with `--cache_dataset_only`") + import subprocess + + process = subprocess.Popen( + caching_command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + stdout_data, stderr_data = [], [] + + streams = [process.stdout, process.stderr] + while True: + reads = select.select(streams, [], [])[0] + + done = True + for stream in reads: + line = stream.readline() + if line: + done = False + is_stdout = stream == process.stdout + print(line.rstrip(), file=sys.stdout if is_stdout else sys.stderr) + if is_stdout: + stdout_data.append(line) + else: + stderr_data.append(line) + + if done and process.poll() is not None: + break + + result = type( + "SubprocessResult", + (), + { + "returncode": process.returncode, + "stdout": "".join(stdout_data), + "stderr": "".join(stderr_data), + }, + ) + stdout = result.stdout + for line in stdout.splitlines(): + if "✅ Found cached dataset at" in line: + dataset_cache_path = line.split("✅ Found cached dataset at")[1].strip() + dataset_config_hash = dataset_cache_path.split("/")[-1] + console.log(f"📦 Found cached dataset at: {dataset_cache_path}") + console.log(f"📦 Found cached dataset config hash: {dataset_config_hash}") + dataset_cache_paths.append(dataset_cache_path) + dataset_config_hashes.append(dataset_config_hash) + stderr = result.stderr + return_code = result.returncode + if return_code != 0: + raise Exception(f"Error code {return_code} when creating cached dataset") + console.log("✅✅✅ Finished running the caching command") + + return command, dataset_cache_paths, dataset_config_hashes + + +def maybe_override_output_dir( + command: List[str], + args: argparse.Namespace, + whoami: str, + is_external_user: bool, + is_open_instruct_training: bool, + weka_clusters: List[str], +) -> List[str]: + """Override output_dir for Weka clusters to enable auto-evaluation. 
+
+    Returns:
+        Modified command list
+    """
+    if any(c in weka_clusters for c in args.cluster):
+        if len(args.auto_output_dir_path) > 0:
+            need_to_override_output_dir = True
+            for idx, cmd in enumerate(command):
+                if cmd == "--output_dir":
+                    if "/weka/" in command[idx + 1]:
+                        need_to_override_output_dir = False
+                        break
+            if need_to_override_output_dir and is_open_instruct_training and not is_external_user:
+                new_output_dir = f"{args.auto_output_dir_path}/{whoami}/"
+                console.log(f"🔍🔍🔍 Automatically overriding the `--output_dir` argument to be in `{new_output_dir}`")
+                command.append("--output_dir")
+                command.append(new_output_dir)
+        else:
+            no_eval_commands = [
+                ["--try_launch_beaker_eval_jobs", "False"],
+                ["--try_launch_beaker_eval_jobs_on_weka", "False"],
+                ["--no_try_launch_beaker_eval_jobs"],
+                ["--no_try_launch_beaker_eval_jobs_on_weka"],
+            ]
+            no_eval_concat_commands = [" ".join(cmd) for cmd in no_eval_commands]
+            no_eval_concat_command_exists = any(cmd in command for cmd in no_eval_concat_commands)
+            if not no_eval_concat_command_exists:
+                raise ValueError(
+                    "Auto-evaluation is turned on by default; to make sure it works, you must:\n"
+                    "1. run mason with `--auto_output_dir_path /weka/...`, or\n"
+                    "2. in the training command, disable auto-evaluation with `--no_try_launch_beaker_eval_jobs`, or\n"
+                    "3. in the training command, use a `--output_dir` that starts with `/weka/`"
+                )
+
+    return command
+
+
+def maybe_optimize_gcp_model_loading(
+    command: List[str],
+    args: argparse.Namespace,
+    dataset_cache_paths: List[str],
+    dataset_config_hashes: List[str],
+    gcp_clusters: List[str],
+) -> List[str]:
+    """Optimize model loading for GCP clusters by uploading to GCS and downloading on compute nodes.
+
+    Returns:
+        Modified command list with GCS download prefix
+    """
+    from open_instruct.dataset_transformation import get_commit_hash
+    from open_instruct.utils import download_from_hf, gs_folder_exists, upload_to_gs_bucket
+
+    if any(c in gcp_clusters for c in args.cluster):
+        model_name_or_path = None
+        for idx, cmd in enumerate(command):
+            if cmd == "--model_name_or_path":
+                model_name_or_path = command[idx + 1]
+                break
+        model_revision = "main"
+        for idx, cmd in enumerate(command):
+            if cmd == "--model_revision":
+                model_revision = command[idx + 1]
+                break
+
+        commit_hash = get_commit_hash(model_name_or_path, model_revision, "config.json", "model")
+        if os.path.exists(model_name_or_path):
+            path = model_name_or_path
+            assert args.gs_model_name is not None, "for local models to upload to gs, you must set --gs_model_name"
+            model_name_or_path = args.gs_model_name
+            commit_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8]
+            console.log(
+                f"Local model is already downloaded, using gs_model_name {model_name_or_path}, with hash of model path {commit_hash}"
+            )
+        else:
+            download_from_hf(model_name_or_path, model_revision)  # first download the model
+            path = download_from_hf(model_name_or_path, model_revision)  # then get the path
+        gs_saved_path = f"gs://ai2-llm/post-training/deletable_cache_models/{model_name_or_path}/{commit_hash}"
+        gs_folder = gs_folder_exists(gs_saved_path)  # race condition exists, but it's fine since we are launching mason sequentially
+        if not gs_folder:
+            upload_to_gs_bucket(path, gs_saved_path)
+
+        download_path = gs_saved_path.replace("gs://", "/gs/")
+        download_path_without_last_folder = download_path.rsplit("/", 1)[0]
+        gs_download_command = [
+            "mkdir",
+            "-p",
+            download_path,
+            "&&",
+            "gsutil",
+            "-o",
+            "GSUtil:parallel_thread_count=1",
+            "-o",
+            "GSUtil:sliced_object_download_threshold=150",
+            "-m",
+            "cp",
+            "-r",
+            gs_saved_path,
download_path_without_last_folder, + "&&", + "ls", + download_path_without_last_folder, + "&&", + "ls", + download_path, + "&&", + ] + + command.append("--gs_bucket_path") + command.append("gs://ai2-llm/post-training/") + + for idx, cmd in enumerate(command): + if cmd == "--model_name_or_path": + command[idx + 1] = download_path + break + for idx, cmd in enumerate(command): + if cmd == "--model_revision": + command[idx + 1] = "main" + break + + if len(dataset_cache_paths) > 0: + for cidx, (dataset_cache_path, dataset_config_hash) in enumerate( + zip(dataset_cache_paths, dataset_config_hashes) + ): + gs_saved_path = f"gs://ai2-llm/post-training/deletable_cache_datasets/{dataset_cache_path}" + gs_folder = gs_folder_exists(gs_saved_path) + if not gs_folder: + upload_to_gs_bucket(dataset_cache_path, gs_saved_path) + dataset_cache_path_without_last_folder = dataset_cache_path.rsplit("/", 1)[0] + gs_download_command += [ + "mkdir", + "-p", + dataset_cache_path_without_last_folder, + "&&", + "gsutil", + "cp", + "-r", + gs_saved_path, + dataset_cache_path_without_last_folder, + "&&", + "ls", + dataset_cache_path_without_last_folder, + "&&", + "ls", + dataset_cache_path, + "&&", + ] + if cidx == 0: + command.append("--dataset_config_hash") + command.append(dataset_config_hash) + elif cidx == 1: + command.append("--dataset_config_eval_hash") + command.append(dataset_config_hash) + command = gs_download_command + command + + return command + + +def escape_strings(command: List[str]) -> List[str]: + """Escape JSON strings in command arguments by wrapping them in single quotes. + + Returns: + Modified command list with escaped JSON strings + """ + for idx in range(len(command)): + if "{" in command[idx]: + command[idx] = "'" + command[idx] + "'" + return command + + def make_internal_command(command: List[str], args: argparse.Namespace, whoami: str, is_external_user: bool) -> str: - # pass through WANDB_ENTITY and WANDB_PROJECT if "WANDB_ENTITY" in os.environ: command = [f"WANDB_ENTITY={os.environ['WANDB_ENTITY']}"] + command if "WANDB_PROJECT" in os.environ: @@ -395,118 +713,26 @@ def make_internal_command(command: List[str], args: argparse.Namespace, whoami: if "WANDB_TAGS" in os.environ: command = [f"WANDB_TAGS={os.environ['WANDB_TAGS']}"] + command - # escape the command (e.g., --stop_strings "") for i in range(len(command)): if " 0: need_to_override_checkpoint_state_dir = True default_checkpoint_state_freq = 200 @@ -529,160 +755,14 @@ def remove_arg_from_list(lst: List[str], item: str, remove_value: bool = False): command.append("--checkpoint_state_freq") command.append(str(default_checkpoint_state_freq)) - # For Weka clusters, we need to override the output_dir parameter to make auto-evaluation work - # If the output_dir is already set to a path in /weka/, we'll keep that path - # Otherwise, we'll set a default path in the user's directory on Weka - if any(c in WEKA_CLUSTERS for c in args.cluster): - if len(args.auto_output_dir_path) > 0: - need_to_override_output_dir = True - for idx, cmd in enumerate(command): - if cmd == "--output_dir": - if "/weka/" in command[idx + 1]: - need_to_override_output_dir = False - break - if need_to_override_output_dir and is_open_instruct_training and not is_external_user: - new_output_dir = f"{args.auto_output_dir_path}/{whoami}/" - console.log( - f"🔍🔍🔍 Automatically overriding the `--output_dir` argument to be in `{new_output_dir}`" - ) - command.append("--output_dir") - command.append(new_output_dir) - else: - no_eval_commands = [ - 
["--try_launch_beaker_eval_jobs", "False"], - ["--try_launch_beaker_eval_jobs_on_weka", "False"], - ["--no_try_launch_beaker_eval_jobs"], - ["--no_try_launch_beaker_eval_jobs_on_weka"], - ] - no_eval_concat_commands = [" ".join(cmd) for cmd in no_eval_commands] - no_eval_concat_command_exists = any(cmd in command for cmd in no_eval_concat_commands) - if not no_eval_concat_command_exists: - raise ValueError( - "To auto-evaluation is turned on by default, to make sure it works, you must:\n" - "1. run mason with`--auto_output_dir_path /weka/...`, or\n" - "2. in the training command, disable auto-evaluation with `--no_try_launch_beaker_eval_jobs`, or\n" - "3. in the training command, use a `--output_dir` that starts with `/weka/`" - ) - - # For GCP clusters, since shared storage is slow, we optimize model loading by: - if any(c in GCP_CLUSTERS for c in args.cluster): - # 1. First downloading the model from HuggingFace to a local path - # 2. Uploading it to a Google Storage bucket (if not already there) - # 3. Then downloading it from the bucket to the compute node - # 4. Finally, replacing the original --model_name_or_path argument with the local path - model_name_or_path = None - for idx, cmd in enumerate(command): - if cmd == "--model_name_or_path": - model_name_or_path = command[idx + 1] - break - model_revision = "main" - for idx, cmd in enumerate(command): - if cmd == "--model_revision": - model_revision = command[idx + 1] - break - - commit_hash = get_commit_hash(model_name_or_path, model_revision, "config.json", "model") - if os.path.exists(model_name_or_path): - path = model_name_or_path - assert args.gs_model_name is not None, "for local models to upload to gs, you must set --gs_model_name" - model_name_or_path = args.gs_model_name - commit_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8] - console.log( - f"Local model is already downloaded, using gs_model_name {model_name_or_path}, with hash of model path {commit_hash}" - ) - else: - download_from_hf(model_name_or_path, model_revision) # first download the model - path = download_from_hf(model_name_or_path, model_revision) # then get the path - gs_saved_path = f"gs://ai2-llm/post-training/deletable_cache_models/{model_name_or_path}/{commit_hash}" - gs_folder = gs_folder_exists( - gs_saved_path - ) # race condition exists, but it's fine since we are launching mason sequentially - if not gs_folder: - upload_to_gs_bucket(path, gs_saved_path) - - download_path = gs_saved_path.replace("gs://", "/gs/") - download_path_without_last_folder = download_path.rsplit("/", 1)[0] - gs_download_command = [ - "mkdir", - "-p", - download_path, - "&&", - "gsutil", - "-o", - "GSUtil:parallel_thread_count=1", - "-o", - "GSUtil:sliced_object_download_threshold=150", - "-m", - "cp", - "-r", - gs_saved_path, - download_path_without_last_folder, - "&&", - "ls", - download_path_without_last_folder, - "&&", - "ls", - download_path, - "&&", - ] - - command.append("--gs_bucket_path") - command.append("gs://ai2-llm/post-training/") + command = maybe_override_output_dir( + command, args, whoami, is_external_user, is_open_instruct_training, WEKA_CLUSTERS + ) + command = maybe_optimize_gcp_model_loading( + command, args, dataset_cache_paths, dataset_config_hashes, GCP_CLUSTERS + ) - # Replace the model_name_or_path with the downloaded path - for idx, cmd in enumerate(command): - if cmd == "--model_name_or_path": - command[idx + 1] = download_path - break - for idx, cmd in enumerate(command): - if cmd == "--model_revision": - command[idx + 1] = 
"main" - break - - # Save dataset to GCS - if len(dataset_cache_paths) > 0: - for cidx, (dataset_cache_path, dataset_config_hash) in enumerate( - zip(dataset_cache_paths, dataset_config_hashes) - ): - gs_saved_path = f"gs://ai2-llm/post-training/deletable_cache_datasets/{dataset_cache_path}" - gs_folder = gs_folder_exists( - gs_saved_path - ) # race condition exists, but it's fine since we are launching mason sequentially - if not gs_folder: - upload_to_gs_bucket(dataset_cache_path, gs_saved_path) - dataset_cache_path_without_last_folder = dataset_cache_path.rsplit("/", 1)[0] - gs_download_command += [ - "mkdir", - "-p", - dataset_cache_path_without_last_folder, - "&&", - "gsutil", - "cp", - "-r", - gs_saved_path, - dataset_cache_path_without_last_folder, - "&&", - "ls", - dataset_cache_path_without_last_folder, - "&&", - "ls", - dataset_cache_path, - "&&", - ] - if cidx == 0: - command.append("--dataset_config_hash") - command.append(dataset_config_hash) - elif cidx == 1: - command.append("--dataset_config_eval_hash") - command.append(dataset_config_hash) - command = gs_download_command + command - - # special logic to deal with escape like - # python mason.py ... -- python x.py --dataset_mixer '{"trl-internal-testing/sentiment-trl-style": 1.0}' - # we need to wrap the json string with single quote - for idx in range(len(command)): - if "{" in command[idx]: - command[idx] = "'" + command[idx] + "'" + command = escape_strings(command) full_command = command setup_commands = "" if not args.pure_docker_mode: From 7c94ab70aa049a975b19bcb2504afbae20710042 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 4 Nov 2025 11:03:52 -0700 Subject: [PATCH 2/5] Updated code --- mason.py | 89 ++++++++++++++++++++++++++++++++------------------------ 1 file changed, 51 insertions(+), 38 deletions(-) diff --git a/mason.py b/mason.py index 1fa27a82b8..0bc0197752 100644 --- a/mason.py +++ b/mason.py @@ -8,6 +8,7 @@ import string import sys import time +from dataclasses import dataclass from typing import Dict, List import beaker @@ -32,6 +33,13 @@ OPEN_INSTRUCT_RESUMABLES = ["open_instruct/grpo_fast.py"] +@dataclass +class ClusterConfig: + weka: list[str] + gcp: list[str] + interconnect: list[str] + + # ---------------------------------------------------------------------- # Mason logic def parse_beaker_dataset(dataset_str): @@ -52,25 +60,31 @@ def parse_env_var(env_var_str: str) -> Dict[str, str]: return {"name": name, "value": value} -def get_clusters(beaker_client: beaker.Beaker = None) -> tuple[list[str], list[str], list[str]]: - """Get cluster lists from Beaker API or return defaults. +def get_clusters(beaker_client: beaker.Beaker, selected_clusters: List[str]) -> ClusterConfig: + """Get cluster properties for the user's selected clusters from Beaker API. 
+ + Args: + beaker_client: Beaker client instance + selected_clusters: List of cluster names the user wants to use Returns: - Tuple of (weka_clusters, gcp_clusters, interconnect_clusters) + ClusterConfig with lists of which selected clusters have weka/gcp/interconnect properties """ - default_weka = ["ai2/jupiter", "ai2/saturn", "ai2/titan", "ai2/neptune", "ai2/ceres", "ai2/triton", "ai2/rhea"] - default_gcp = ["ai2/augusta"] - default_interconnect = ["ai2/jupiter", "ai2/ceres", "ai2/titan", "ai2/augusta"] - if beaker_client is None: - return default_weka, default_gcp, default_interconnect + raise ValueError("You need access to Beaker to run mason.py") weka_clusters = [] gcp_clusters = [] interconnect_clusters = [] - for cluster in beaker_client.cluster.list(): - cluster_name = f"ai2/{cluster.name}" + for cluster_full_name in selected_clusters: + cluster_short_name = cluster_full_name.replace("ai2/", "") + try: + cluster = beaker_client.cluster.get(cluster_short_name) + except Exception: + console.log(f"Warning: Could not get info for cluster {cluster_full_name}, skipping property checks") + continue + has_interconnect = False has_gcp = False has_weka = False @@ -84,19 +98,14 @@ def get_clusters(beaker_client: beaker.Beaker = None) -> tuple[list[str], list[s has_weka = True if has_interconnect: - interconnect_clusters.append(cluster_name) + interconnect_clusters.append(cluster_full_name) if has_gcp: - gcp_clusters.append(cluster_name) + gcp_clusters.append(cluster_full_name) if has_weka: - weka_clusters.append(cluster_name) - - return weka_clusters, gcp_clusters, interconnect_clusters - + weka_clusters.append(cluster_full_name) -WEKA_CLUSTERS = ["ai2/jupiter", "ai2/saturn", "ai2/titan", "ai2/neptune", "ai2/ceres", "ai2/triton", "ai2/rhea"] -GCP_CLUSTERS = ["ai2/augusta"] + return ClusterConfig(weka=weka_clusters, gcp=gcp_clusters, interconnect=interconnect_clusters) -INTERCONNECT_CLUSTERS = ["ai2/jupiter", "ai2/ceres", "ai2/titan", "ai2/augusta"] # by default, we turn off vllm compile cache # torch compile caching seems consistently broken, but the actual compiling isn't. 
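[Reviewer sketch, not part of the diff] The tag-based classification in get_clusters above can be exercised against a stub client. FakeCluster, FakeClusterAPI, FakeBeaker, and the concrete tag values are hypothetical stand-ins for the Beaker objects; only the tag prefixes (storage:weka, provider:gcp, interconnect:) come from the code itself.

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class FakeCluster:
        name: str
        tags: List[str] = field(default_factory=list)

    class FakeClusterAPI:
        def __init__(self, clusters: List[FakeCluster]):
            self._by_name = {c.name: c for c in clusters}

        def get(self, name: str) -> FakeCluster:
            return self._by_name[name]

    class FakeBeaker:
        def __init__(self, clusters: List[FakeCluster]):
            self.cluster = FakeClusterAPI(clusters)

    client = FakeBeaker([
        FakeCluster("jupiter", ["storage:weka", "interconnect:ib"]),  # tag values are made up
        FakeCluster("augusta", ["provider:gcp", "interconnect:gpudirect"]),
    ])

    for full_name in ["ai2/jupiter", "ai2/augusta"]:
        # Mirrors get_clusters: strip the "ai2/" prefix, fetch the cluster, scan its tags.
        tags = client.cluster.get(full_name.replace("ai2/", "")).tags
        print(
            full_name,
            any(t.startswith("storage:weka") for t in tags),   # weka
            any(t.startswith("provider:gcp") for t in tags),   # gcp
            any(t.startswith("interconnect:") for t in tags),  # interconnect
        )
    # ai2/jupiter True False True
    # ai2/augusta False True True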
@@ -271,6 +280,7 @@ def get_env_vars( num_nodes: int, additional_env_vars: List[Dict[str, str]], additional_secrets: List[Dict[str, str]], + cluster_config: ClusterConfig, ): additional_env_var_names = {var["name"] for var in additional_env_vars} @@ -310,7 +320,7 @@ def get_env_vars( env_vars.extend([beaker.BeakerEnvVar(name="PATH", value=os.getenv("PATH"))]) # if all cluster is in weka, we mount the weka - if all(c in WEKA_CLUSTERS for c in cluster): + if all(c in cluster_config.weka for c in cluster): env_vars.extend( [ beaker.BeakerEnvVar(name="HF_HOME", value="/weka/oe-adapt-default/allennlp/.cache/huggingface"), @@ -333,7 +343,7 @@ def get_env_vars( ) # if all cluster is in gcp we add the following env - elif all(c in GCP_CLUSTERS for c in cluster): + elif all(c in cluster_config.gcp for c in cluster): env_vars.extend( [ beaker.BeakerEnvVar(name="HF_HOME", value="/filestore/.cache/huggingface"), @@ -399,11 +409,11 @@ def get_env_vars( return env_vars -def get_datasets(beaker_datasets, cluster: List[str]): +def get_datasets(beaker_datasets, cluster: List[str], cluster_config: ClusterConfig): """if pure docker mode we don't mount the NFS; so we can run it on jupiter2""" res = [] # if all cluster is in weka, we mount the weka - if all(c in WEKA_CLUSTERS for c in cluster): + if all(c in cluster_config.weka for c in cluster): res = [ beaker.BeakerDataMount( source=beaker.BeakerDataSource(weka="oe-adapt-default"), mount_path="/weka/oe-adapt-default" @@ -412,7 +422,7 @@ def get_datasets(beaker_datasets, cluster: List[str]): source=beaker.BeakerDataSource(weka="oe-training-default"), mount_path="/weka/oe-training-default" ), ] - elif all(c in GCP_CLUSTERS for c in cluster): + elif all(c in cluster_config.gcp for c in cluster): res = [ beaker.BeakerDataMount( source=beaker.BeakerDataSource(host_path="/mnt/filestore_1"), mount_path="/filestore" @@ -533,14 +543,14 @@ def maybe_override_output_dir( whoami: str, is_external_user: bool, is_open_instruct_training: bool, - weka_clusters: List[str], + cluster_config: ClusterConfig, ) -> List[str]: """Override output_dir for Weka clusters to enable auto-evaluation. Returns: Modified command list """ - if any(c in weka_clusters for c in args.cluster): + if any(c in cluster_config.weka for c in args.cluster): if len(args.auto_output_dir_path) > 0: need_to_override_output_dir = True for idx, cmd in enumerate(command): @@ -578,7 +588,7 @@ def maybe_optimize_gcp_model_loading( args: argparse.Namespace, dataset_cache_paths: List[str], dataset_config_hashes: List[str], - gcp_clusters: List[str], + cluster_config: ClusterConfig, ) -> List[str]: """Optimize model loading for GCP clusters by uploading to GCS and downloading on compute nodes. 
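[Reviewer sketch, not part of the diff] get_env_vars and get_datasets mount Weka (or the GCP filestore) only when every selected cluster qualifies; one mismatched cluster silently falls through to the default branch. A minimal illustration of the predicate, with example cluster names:

    weka = ["ai2/jupiter", "ai2/saturn"]  # example weka list

    def mounts_weka(selected):
        # Mirrors the `all(c in cluster_config.weka for c in cluster)` check above.
        return all(c in weka for c in selected)

    assert mounts_weka(["ai2/jupiter"])
    assert mounts_weka(["ai2/jupiter", "ai2/saturn"])
    assert not mounts_weka(["ai2/jupiter", "ai2/augusta"])  # one non-weka cluster disables the mounts
    assert mounts_weka([])  # all() is vacuously True for an empty selection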
@@ -588,7 +598,7 @@ def maybe_optimize_gcp_model_loading( from open_instruct.dataset_transformation import get_commit_hash from open_instruct.utils import download_from_hf, gs_folder_exists, upload_to_gs_bucket - if any(c in gcp_clusters for c in args.cluster): + if any(c in cluster_config.gcp for c in args.cluster): model_name_or_path = None for idx, cmd in enumerate(command): if cmd == "--model_name_or_path": @@ -705,7 +715,7 @@ def escape_strings(command: List[str]) -> List[str]: return command -def make_internal_command(command: List[str], args: argparse.Namespace, whoami: str, is_external_user: bool) -> str: +def make_internal_command(command: List[str], args: argparse.Namespace, whoami: str, is_external_user: bool, cluster_config: ClusterConfig) -> str: if "WANDB_ENTITY" in os.environ: command = [f"WANDB_ENTITY={os.environ['WANDB_ENTITY']}"] + command if "WANDB_PROJECT" in os.environ: @@ -756,10 +766,10 @@ def find_list_idx(lst: List[str], item: str): command.append(str(default_checkpoint_state_freq)) command = maybe_override_output_dir( - command, args, whoami, is_external_user, is_open_instruct_training, WEKA_CLUSTERS + command, args, whoami, is_external_user, is_open_instruct_training, cluster_config ) command = maybe_optimize_gcp_model_loading( - command, args, dataset_cache_paths, dataset_config_hashes, GCP_CLUSTERS + command, args, dataset_cache_paths, dataset_config_hashes, cluster_config ) command = escape_strings(command) @@ -790,9 +800,9 @@ def find_list_idx(lst: List[str], item: str): return full_command -def make_task_spec(args, full_command: str, i: int, beaker_secrets: str, whoami: str, resumable: bool): +def make_task_spec(args, full_command: str, i: int, beaker_secrets: str, whoami: str, resumable: bool, cluster_config: ClusterConfig): # Add a check to ensure that the user is using the correct clusters for multi-node jobs - if args.num_nodes > 1 and not all(c in INTERCONNECT_CLUSTERS for c in args.cluster): + if args.num_nodes > 1 and not all(c in cluster_config.interconnect for c in args.cluster): confirmation = False while not confirmation: confirmation = input( @@ -802,11 +812,11 @@ def make_task_spec(args, full_command: str, i: int, beaker_secrets: str, whoami: confirmation = True elif confirmation == "n": raise ValueError( - f"Interconnect clusters are required for multi-node jobs; please only use the following clusters: {INTERCONNECT_CLUSTERS}" + f"Interconnect clusters are required for multi-node jobs; please only use the following clusters: {cluster_config.interconnect}" ) else: print("Invalid input. 
Please enter 'y' or 'n'.") - if args.image == "ai2/cuda11.8-cudnn8-dev-ubuntu20.04" and any(c in GCP_CLUSTERS for c in args.cluster): + if args.image == "ai2/cuda11.8-cudnn8-dev-ubuntu20.04" and any(c in cluster_config.gcp for c in args.cluster): raise ValueError("GCP clusters do not have the dev filesystem, please use a proper image") if args.hostname is not None: @@ -819,7 +829,7 @@ def make_task_spec(args, full_command: str, i: int, beaker_secrets: str, whoami: command=["/bin/bash", "-c"], arguments=[full_command], result=beaker.BeakerResultSpec(path="/output"), - datasets=get_datasets(args.beaker_datasets, args.cluster), + datasets=get_datasets(args.beaker_datasets, args.cluster, cluster_config), context=beaker.BeakerTaskContext( priority=beaker.BeakerJobPriority[args.priority], preemptible=args.preemptible ), @@ -833,6 +843,7 @@ def make_task_spec(args, full_command: str, i: int, beaker_secrets: str, whoami: args.num_nodes, args.env, args.secret, + cluster_config, ), resources=beaker.BeakerTaskResources(gpu_count=args.gpus, shared_memory=args.shared_memory), replicas=args.num_nodes, @@ -860,6 +871,7 @@ def main(): if is_external_user: whoami = "external_user" beaker_secrets = [] + beaker_client = None else: if args.workspace: beaker_client = beaker.Beaker.from_env(default_workspace=args.workspace) @@ -868,7 +880,8 @@ def main(): beaker_secrets = [secret.name for secret in beaker_client.secret.list()] whoami = beaker_client.user.get().name - full_commands = [make_internal_command(command, args, whoami, is_external_user) for command in commands] + cluster_config = get_clusters(beaker_client, args.cluster) + full_commands = [make_internal_command(command, args, whoami, is_external_user, cluster_config) for command in commands] if is_external_user: console.rule("[bold red]Non-Ai2 User Detected[/bold red]") console.print( @@ -888,7 +901,7 @@ def main(): experiment_spec = beaker.BeakerExperimentSpec( description=args.description, tasks=[ - make_task_spec(args, full_command, i, beaker_secrets, whoami, args.resumable) + make_task_spec(args, full_command, i, beaker_secrets, whoami, args.resumable, cluster_config) for i, full_command in enumerate(full_commands) ], budget=args.budget, From 4d140f607f5a915a3732ca3a46d1314232376406 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 4 Nov 2025 11:27:41 -0700 Subject: [PATCH 3/5] Cleaned up code --- mason.py | 94 +++++++++++++++++++++++--------------------------------- 1 file changed, 38 insertions(+), 56 deletions(-) diff --git a/mason.py b/mason.py index b040490574..e4ffd3a9aa 100644 --- a/mason.py +++ b/mason.py @@ -42,9 +42,11 @@ @dataclass class ClusterConfig: - weka: list[str] - gcp: list[str] - interconnect: list[str] + all_weka: bool + any_weka: bool + all_gcp: bool + any_gcp: bool + all_interconnect: bool # ---------------------------------------------------------------------- @@ -109,38 +111,25 @@ def get_clusters(beaker_client: beaker.Beaker, selected_clusters: List[str]) -> if beaker_client is None: raise ValueError("You need access to Beaker to run mason.py") - weka_clusters = [] - gcp_clusters = [] - interconnect_clusters = [] + num_weka = 0 + num_gcp = 0 + num_interconnect = 0 - for cluster_full_name in selected_clusters: - cluster_short_name = cluster_full_name.replace("ai2/", "") - try: - cluster = beaker_client.cluster.get(cluster_short_name) - except Exception: - console.log(f"Warning: Could not get info for cluster {cluster_full_name}, skipping property checks") - continue - - has_interconnect = False - has_gcp = False - 
has_weka = False - - for tag in cluster.tags: + for cluster_name in selected_clusters: + for tag in beaker_client.cluster.get(cluster_name).tags: if tag.startswith("interconnect:"): - has_interconnect = True + num_interconnect += 1 if tag.startswith("provider:gcp"): - has_gcp = True + num_gcp += 1 if tag.startswith("storage:weka"): - has_weka = True - - if has_interconnect: - interconnect_clusters.append(cluster_full_name) - if has_gcp: - gcp_clusters.append(cluster_full_name) - if has_weka: - weka_clusters.append(cluster_full_name) - - return ClusterConfig(weka=weka_clusters, gcp=gcp_clusters, interconnect=interconnect_clusters) + num_weka += 1 + return ClusterConfig( + all_weka=(num_weka == len(selected_clusters)), + any_weka=(num_weka > 0), + all_gcp=(num_gcp == len(selected_clusters)), + any_gcp=(num_gcp > 0), + all_interconnect=(num_interconnect == len(selected_clusters)), + ) # by default, we turn off vllm compile cache @@ -269,7 +258,6 @@ def _commands_include_resumable_target(cmds: List[List[str]]) -> bool: "--non_resumable is not set, but the command is not in OPEN_INSTRUCT_RESUMABLES, so the job will not be resumable" ) setattr(mason_args, "resumable", is_resumable) - return mason_args, commands @@ -356,7 +344,7 @@ def get_env_vars( env_vars.extend([beaker.BeakerEnvVar(name="PATH", value=os.getenv("PATH"))]) # if all cluster is in weka, we mount the weka - if all(c in cluster_config.weka for c in cluster): + if cluster_config.all_weka: env_vars.extend( [ beaker.BeakerEnvVar(name="HF_HOME", value="/weka/oe-adapt-default/allennlp/.cache/huggingface"), @@ -379,7 +367,7 @@ def get_env_vars( ) # if all cluster is in gcp we add the following env - elif all(c in cluster_config.gcp for c in cluster): + elif cluster_config.all_gcp: env_vars.extend( [ beaker.BeakerEnvVar(name="HF_HOME", value="/filestore/.cache/huggingface"), @@ -449,7 +437,7 @@ def get_datasets(beaker_datasets, cluster: List[str], cluster_config: ClusterCon """if pure docker mode we don't mount the NFS; so we can run it on jupiter2""" res = [] # if all cluster is in weka, we mount the weka - if all(c in cluster_config.weka for c in cluster): + if cluster_config.all_weka: res = [ beaker.BeakerDataMount( source=beaker.BeakerDataSource(weka="oe-adapt-default"), mount_path="/weka/oe-adapt-default" @@ -458,7 +446,7 @@ def get_datasets(beaker_datasets, cluster: List[str], cluster_config: ClusterCon source=beaker.BeakerDataSource(weka="oe-training-default"), mount_path="/weka/oe-training-default" ), ] - elif all(c in cluster_config.gcp for c in cluster): + elif cluster_config.all_gcp: res = [ beaker.BeakerDataMount( source=beaker.BeakerDataSource(host_path="/mnt/filestore_1"), mount_path="/filestore" @@ -517,12 +505,7 @@ def remove_arg_from_list(lst: List[str], item: str, remove_value: bool = False): import subprocess process = subprocess.Popen( - caching_command, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - bufsize=1, + caching_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1 ) stdout_data, stderr_data = [], [] @@ -549,11 +532,7 @@ def remove_arg_from_list(lst: List[str], item: str, remove_value: bool = False): result = type( "SubprocessResult", (), - { - "returncode": process.returncode, - "stdout": "".join(stdout_data), - "stderr": "".join(stderr_data), - }, + {"returncode": process.returncode, "stdout": "".join(stdout_data), "stderr": "".join(stderr_data)}, ) stdout = result.stdout for line in stdout.splitlines(): @@ -586,7 +565,7 @@ def 
maybe_override_output_dir( Returns: Modified command list """ - if any(c in cluster_config.weka for c in args.cluster): + if cluster_config.any_weka: if len(args.auto_output_dir_path) > 0: need_to_override_output_dir = True for idx, cmd in enumerate(command): @@ -634,7 +613,7 @@ def maybe_optimize_gcp_model_loading( from open_instruct.dataset_transformation import get_commit_hash from open_instruct.utils import download_from_hf, gs_folder_exists, upload_to_gs_bucket - if any(c in cluster_config.gcp for c in args.cluster): + if cluster_config.any_gcp: model_name_or_path = None for idx, cmd in enumerate(command): if cmd == "--model_name_or_path": @@ -751,7 +730,9 @@ def escape_strings(command: List[str]) -> List[str]: return command -def make_internal_command(command: List[str], args: argparse.Namespace, whoami: str, is_external_user: bool, cluster_config: ClusterConfig) -> str: +def make_internal_command( + command: List[str], args: argparse.Namespace, whoami: str, is_external_user: bool, cluster_config: ClusterConfig +) -> str: if "WANDB_ENTITY" in os.environ: command = [f"WANDB_ENTITY={os.environ['WANDB_ENTITY']}"] + command if "WANDB_PROJECT" in os.environ: @@ -914,9 +895,11 @@ def make_internal_command(command: List[str], args: argparse.Namespace, whoami: return full_command -def make_task_spec(args, full_command: str, i: int, beaker_secrets: str, whoami: str, resumable: bool, cluster_config: ClusterConfig): +def make_task_spec( + args, full_command: str, i: int, beaker_secrets: str, whoami: str, resumable: bool, cluster_config: ClusterConfig +): # Add a check to ensure that the user is using the correct clusters for multi-node jobs - if args.num_nodes > 1 and not all(c in cluster_config.interconnect for c in args.cluster): + if args.num_nodes > 1 and not cluster_config.all_interconnect: confirmation = False while not confirmation: confirmation = input( @@ -930,9 +913,6 @@ def make_task_spec(args, full_command: str, i: int, beaker_secrets: str, whoami: ) else: print("Invalid input. 
Please enter 'y' or 'n'.") - if args.image == "ai2/cuda11.8-cudnn8-dev-ubuntu20.04" and any(c in cluster_config.gcp for c in args.cluster): - raise ValueError("GCP clusters do not have the dev filesystem, please use a proper image") - if args.hostname is not None: constraints = beaker.BeakerConstraints(hostname=args.hostname) else: @@ -995,7 +975,9 @@ def main(): whoami = beaker_client.user.get().name cluster_config = get_clusters(beaker_client, args.cluster) - full_commands = [make_internal_command(command, args, whoami, is_external_user, cluster_config) for command in commands] + full_commands = [ + make_internal_command(command, args, whoami, is_external_user, cluster_config) for command in commands + ] if is_external_user: console.rule("[bold red]Non-Ai2 User Detected[/bold red]") console.print( From cfcb3ccaeeaa7161ff50d4e1ac683612327185b5 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 4 Nov 2025 11:29:47 -0700 Subject: [PATCH 4/5] Added tests --- test_mason.py | 163 +++++++++++++++++++++++++++++--------------------- 1 file changed, 94 insertions(+), 69 deletions(-) diff --git a/test_mason.py b/test_mason.py index 7babe2e845..abe800bdf2 100644 --- a/test_mason.py +++ b/test_mason.py @@ -5,79 +5,104 @@ import mason -class TestBuildCommandWithoutArgs(unittest.TestCase): - @parameterized.parameterized.expand([ - ( - "remove_arg_without_value", - ["python", "script.py", "--with_tracking", "--output", "out.txt"], - {"--with_tracking": False}, - ["python", "script.py", "--output", "out.txt"], - ), - ( - "remove_arg_with_value", - ["python", "script.py", "--checkpoint_state_dir", "/path/to/dir", "--output", "out.txt"], - {"--checkpoint_state_dir": True}, - ["python", "script.py", "--output", "out.txt"], - ), - ( - "remove_multiple_args", - ["python", "script.py", "--with_tracking", "--checkpoint_state_dir", "/path", "--output", "out.txt"], - {"--with_tracking": False, "--checkpoint_state_dir": True}, - ["python", "script.py", "--output", "out.txt"], - ), - ( - "arg_not_present", - ["python", "script.py", "--output", "out.txt"], - {"--nonexistent": True}, - ["python", "script.py", "--output", "out.txt"], - ), - ( - "empty_command", - [], - {"--with_tracking": False}, - [], - ), - ( - "empty_args_to_remove", - ["python", "script.py", "--output", "out.txt"], - {}, - ["python", "script.py", "--output", "out.txt"], - ), - ( - "remove_all_cache_excluded_args", - [ - "python", - "open_instruct/grpo_fast.py", - "--with_tracking", - "--checkpoint_state_freq", - "200", - "--checkpoint_state_dir", - "/weka/path", - "--gs_checkpoint_state_dir", - "gs://bucket", - "--output", - "out.txt", - ], - mason.CACHE_EXCLUDED_ARGS, - ["python", "open_instruct/grpo_fast.py", "--output", "out.txt"], - ), - ( - "arg_at_end_without_value", - ["python", "script.py", "--output", "out.txt", "--with_tracking"], - {"--with_tracking": False}, - ["python", "script.py", "--output", "out.txt"], - ), - ( - "arg_at_end_with_value", - ["python", "script.py", "--output", "out.txt", "--checkpoint_dir", "/path"], - {"--checkpoint_dir": True}, - ["python", "script.py", "--output", "out.txt"], - ), - ]) +class TestMason(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ( + "remove_arg_without_value", + ["python", "script.py", "--with_tracking", "--output", "out.txt"], + {"--with_tracking": False}, + ["python", "script.py", "--output", "out.txt"], + ), + ( + "remove_arg_with_value", + ["python", "script.py", "--checkpoint_state_dir", "/path/to/dir", "--output", "out.txt"], + {"--checkpoint_state_dir": True}, + 
["python", "script.py", "--output", "out.txt"], + ), + ( + "remove_multiple_args", + ["python", "script.py", "--with_tracking", "--checkpoint_state_dir", "/path", "--output", "out.txt"], + {"--with_tracking": False, "--checkpoint_state_dir": True}, + ["python", "script.py", "--output", "out.txt"], + ), + ( + "arg_not_present", + ["python", "script.py", "--output", "out.txt"], + {"--nonexistent": True}, + ["python", "script.py", "--output", "out.txt"], + ), + ("empty_command", [], {"--with_tracking": False}, []), + ( + "empty_args_to_remove", + ["python", "script.py", "--output", "out.txt"], + {}, + ["python", "script.py", "--output", "out.txt"], + ), + ( + "remove_all_cache_excluded_args", + [ + "python", + "open_instruct/grpo_fast.py", + "--with_tracking", + "--checkpoint_state_freq", + "200", + "--checkpoint_state_dir", + "/weka/path", + "--gs_checkpoint_state_dir", + "gs://bucket", + "--output", + "out.txt", + ], + mason.CACHE_EXCLUDED_ARGS, + ["python", "open_instruct/grpo_fast.py", "--output", "out.txt"], + ), + ( + "arg_at_end_without_value", + ["python", "script.py", "--output", "out.txt", "--with_tracking"], + {"--with_tracking": False}, + ["python", "script.py", "--output", "out.txt"], + ), + ( + "arg_at_end_with_value", + ["python", "script.py", "--output", "out.txt", "--checkpoint_dir", "/path"], + {"--checkpoint_dir": True}, + ["python", "script.py", "--output", "out.txt"], + ), + ] + ) def test_build_command_without_args(self, name, command, args_to_remove, expected): result = mason.build_command_without_args(command, args_to_remove) self.assertEqual(result, expected) + @parameterized.parameterized.expand( + [ + ([], []), + (["python", "script.py"], ["python", "script.py"]), + ( + ["python", "script.py", "--arg", '{"key": "value"}'], + ["python", "script.py", "--arg", '\'{"key": "value"}\''], + ), + ( + ["python", "--dataset_mixer", '{"trl-internal-testing/sentiment-trl-style": 1.0}'], + ["python", "--dataset_mixer", "'{\"trl-internal-testing/sentiment-trl-style\": 1.0}'"], + ), + ( + ["python", "--arg1", '{"nested": {"key": "value"}}', "--arg2", "normal"], + ["python", "--arg1", '\'{"nested": {"key": "value"}}\'', "--arg2", "normal"], + ), + (["python", "--config", "{a:1,b:2}", "--flag"], ["python", "--config", "'{a:1,b:2}'", "--flag"]), + (["echo", "no braces here"], ["echo", "no braces here"]), + ( + ["python", "--json", '{"a":1}', "--json2", '{"b":2}'], + ["python", "--json", "'{\"a\":1}'", "--json2", "'{\"b\":2}'"], + ), + ] + ) + def test_escape_strings(self, input_command, expected_output): + result = mason.escape_strings(input_command.copy()) + self.assertEqual(result, expected_output) + if __name__ == "__main__": unittest.main() From 77df721094748922998df766546720c9497ca3f0 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 4 Nov 2025 11:40:04 -0700 Subject: [PATCH 5/5] Cleaned up code --- mason.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/mason.py b/mason.py index e4ffd3a9aa..b4d3332efc 100644 --- a/mason.py +++ b/mason.py @@ -98,7 +98,7 @@ def parse_env_var(env_var_str: str) -> Dict[str, str]: return {"name": name, "value": value} -def get_clusters(beaker_client: beaker.Beaker, selected_clusters: List[str]) -> ClusterConfig: +def get_clusters(beaker_client: beaker.Beaker | None, selected_clusters: List[str]) -> ClusterConfig: """Get cluster properties for the user's selected clusters from Beaker API. 
Args: @@ -109,7 +109,7 @@ def get_clusters(beaker_client: beaker.Beaker, selected_clusters: List[str]) -> ClusterConfig with lists of which selected clusters have weka/gcp/interconnect properties """ if beaker_client is None: - raise ValueError("You need access to Beaker to run mason.py") + return ClusterConfig(all_weka=False, any_weka=False, all_gcp=False, any_gcp=False, all_interconnect=False) num_weka = 0 num_gcp = 0 @@ -782,12 +782,7 @@ def make_internal_command( # Use Popen to get real-time output while also capturing it process = subprocess.Popen( - caching_command, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - bufsize=1, + caching_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1 ) stdout_data, stderr_data = [], [] @@ -816,11 +811,7 @@ def make_internal_command( result = type( "SubprocessResult", (), - { - "returncode": process.returncode, - "stdout": "".join(stdout_data), - "stderr": "".join(stderr_data), - }, + {"returncode": process.returncode, "stdout": "".join(stdout_data), "stderr": "".join(stderr_data)}, ) stdout = result.stdout # Extract the cached dataset path from stdout if it exists @@ -967,13 +958,9 @@ def main(): beaker_secrets = [] beaker_client = None else: - if args.workspace: - beaker_client = beaker.Beaker.from_env(default_workspace=args.workspace) - else: - beaker_client = beaker.Beaker.from_env() + beaker_client = beaker.Beaker.from_env(default_workspace=args.workspace) beaker_secrets = [secret.name for secret in beaker_client.secret.list()] whoami = beaker_client.user.get().name - cluster_config = get_clusters(beaker_client, args.cluster) full_commands = [ make_internal_command(command, args, whoami, is_external_user, cluster_config) for command in commands
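[Reviewer sketch, not part of the diff] After this last patch, external users (beaker_client is None) get an all-False ClusterConfig instead of a ValueError, so the Weka- and GCP-specific branches are simply skipped. A simplified stand-in showing only that branch, assuming the boolean ClusterConfig from patch 3 (Optional[object] replaces the real beaker.Beaker type to keep the sketch self-contained):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class ClusterConfig:
        all_weka: bool
        any_weka: bool
        all_gcp: bool
        any_gcp: bool
        all_interconnect: bool

    def get_clusters(beaker_client: Optional[object], selected_clusters: List[str]) -> ClusterConfig:
        # Only the None branch from the final patch; the real function queries the Beaker API.
        if beaker_client is None:
            return ClusterConfig(
                all_weka=False, any_weka=False, all_gcp=False, any_gcp=False, all_interconnect=False
            )
        raise NotImplementedError

    config = get_clusters(None, ["ai2/jupiter"])
    assert not config.any_weka and not config.any_gcp  # external users skip weka/GCP-specific setup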