diff --git a/dags/examples/xpk_example_dag.py b/dags/examples/xpk_example_dag.py index ba0d3889f..e28cdfd8b 100644 --- a/dags/examples/xpk_example_dag.py +++ b/dags/examples/xpk_example_dag.py @@ -18,6 +18,7 @@ from airflow import models from dags.vm_resource import TpuVersion, Project, Zone, ClusterName, DockerImage from dags.examples.configs import xpk_example_config as config +from dags import test_owner from xlml.utils import name_format from airflow.utils.task_group import TaskGroup @@ -59,11 +60,13 @@ # The value of 'test_group_id': # 1) a task group name for those chained tests # 2) an ID to generate gcs folder path in format: - # "{gcs_bucket.BASE_OUTPUT_DIR}/{group_id}-{current_datetime}/" + # "{gcs_bucket.BASE_OUTPUT_DIR}/{gcs_subfolder}/{group_id}-{current_datetime}/" test_group_id = "chained_tests" + gcs_subfolder = f"{test_owner.Team.MULTIPOD.value}/maxtext" with TaskGroup(group_id=test_group_id) as group: shared_gcs_location = name_format.generate_gcs_folder_location( - test_group_id + gcs_subfolder, + test_group_id, ) chained_resnet_tpu_singleslice_v4_8 = config.get_flax_resnet_xpk_config( tpu_version=TpuVersion.V4, diff --git a/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py b/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py index 7a9bf4a59..43d979afd 100644 --- a/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py +++ b/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py @@ -18,7 +18,10 @@ from typing import Tuple from xlml.apis import gcp_config, metric_config, task, test_config import dags.vm_resource as resource -from dags import gcs_bucket, test_owner +from dags import test_owner + + +GCS_SUBFOLDER_PREFIX = test_owner.Team.PYTORCH_XLA.value class VERSION(enum.Enum): @@ -170,6 +173,7 @@ def get_torchbench_tpu_config( run_model_cmds=run_script_cmds, time_out_in_min=time_out_in_min, task_owner=test_owner.PEI_Z, + gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/torchbench", ) job_metric_config = metric_config.MetricConfig( @@ -330,6 +334,7 @@ def get_torchbench_gpu_config( run_model_cmds=run_script_cmds, time_out_in_min=time_out_in_min, task_owner=test_owner.PEI_Z, + gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/torchbench", ) job_metric_config = metric_config.MetricConfig( diff --git a/dags/solutions_team/configs/flax/solutionsteam_flax_latest_supported_config.py b/dags/solutions_team/configs/flax/solutionsteam_flax_latest_supported_config.py index f9c59ab6c..9fe404347 100644 --- a/dags/solutions_team/configs/flax/solutionsteam_flax_latest_supported_config.py +++ b/dags/solutions_team/configs/flax/solutionsteam_flax_latest_supported_config.py @@ -27,6 +27,7 @@ PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value RUN_DATE = datetime.now().strftime("%Y_%m_%d") +GCS_SUBFOLDER_PREFIX = test_owner.Team.SOLUTIONS_TEAM.value def get_flax_resnet_config( @@ -204,6 +205,7 @@ def get_flax_vit_conv_config( run_model_cmds=run_model_cmds, time_out_in_min=time_out_in_min, task_owner=test_owner.SHIVA_S, + gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/flax", ) job_metric_config = metric_config.MetricConfig( @@ -456,6 +458,7 @@ def get_flax_bart_conv_config( run_model_cmds=run_model_cmds, time_out_in_min=time_out_in_min, task_owner=test_owner.SHIVA_S, + gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/flax", ) job_metric_config = metric_config.MetricConfig( @@ -578,6 +581,7 @@ def get_flax_bert_conv_config( run_model_cmds=run_model_cmds, time_out_in_min=time_out_in_min, task_owner=test_owner.SHIVA_S, + gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/flax", ) job_metric_config = metric_config.MetricConfig( diff --git a/dags/test_owner.py b/dags/test_owner.py index 4910b9b8d..1f01adf64 100644 --- a/dags/test_owner.py +++ b/dags/test_owner.py @@ -14,6 +14,16 @@ """The file of test owners.""" +import enum + + +class Team(enum.Enum): + SOLUTIONS_TEAM = "solutions_team" + PYTORCH_XLA = "pytorch_xla" + MULTIPOD = "multipod" + MLCOMPASS = "mlcompass" + + # XLML - JAX/FLAX SHIVA_S = "Shiva S." diff --git a/dags/vm_resource.py b/dags/vm_resource.py index 66f4bf843..792a49f78 100644 --- a/dags/vm_resource.py +++ b/dags/vm_resource.py @@ -129,7 +129,6 @@ class ClusterName(enum.Enum): V5E_16_MULTISLICE_CLUSTER = "v5e-16-bodaborg" V5E_256_MULTISLICE_CLUSTER = "v5e-256-bodaborg" V5E_256_US_WEST_4_MULTISLICE_CLUSTER = "v5e-256-bodaborg-us-west4" - A3_CLUSTER = "maxtext-a3-20n" diff --git a/xlml/apis/task.py b/xlml/apis/task.py index 32a961af2..14b73cc84 100644 --- a/xlml/apis/task.py +++ b/xlml/apis/task.py @@ -184,7 +184,8 @@ def provision( ) ssh_keys = ssh.generate_ssh_keys() output_location = name_format.generate_gcs_folder_location( - self.task_test_config.benchmark_id + self.task_test_config.gcs_subfolder, + self.task_test_config.benchmark_id, ) queued_resource_op, queued_resource_name = tpu.create_queued_resource( @@ -395,7 +396,8 @@ def run_model( gcs_path = gcs_location else: gcs_path = name_format.generate_gcs_folder_location( - self.task_test_config.benchmark_id + self.task_test_config.gcs_subfolder, + self.task_test_config.benchmark_id, ) launch_workload = self.launch_workload(workload_id, gcs_path) wait_for_workload_completion = xpk.wait_for_workload_completion.override( @@ -545,7 +547,8 @@ def provision( gpu_name = gpu.generate_gpu_name() ssh_keys = ssh.generate_ssh_keys() gcs_location = name_format.generate_gcs_folder_location( - self.task_test_config.benchmark_id + self.task_test_config.gcs_subfolder, + self.task_test_config.benchmark_id, ) ip_address = gpu.create_resource( diff --git a/xlml/apis/test_config.py b/xlml/apis/test_config.py index f00cadfe9..8c3f6b676 100644 --- a/xlml/apis/test_config.py +++ b/xlml/apis/test_config.py @@ -130,6 +130,7 @@ class TestConfig(abc.ABC, Generic[A]): # TODO(wcromar): make this a timedelta time_out_in_min: Optional[int] = attrs.field(default=None, kw_only=True) task_owner: str = attrs.field(default='unowned', kw_only=True) + gcs_subfolder: str = attrs.field(default='unowned', kw_only=True) @property @abc.abstractmethod diff --git a/xlml/utils/name_format.py b/xlml/utils/name_format.py index 28344cf20..c25483bb1 100644 --- a/xlml/utils/name_format.py +++ b/xlml/utils/name_format.py @@ -48,13 +48,18 @@ def generate_tb_file_location(run_name: str, base_output_directory: str) -> str: @task -def generate_gcs_folder_location(benchmark_id: str) -> str: +def generate_gcs_folder_location(subfolder: str, benchmark_id: str) -> str: """Generates folder location in GCS. Args: + subfolder: Folder name/path for artifacts, such as 'solutions_team/flax' benchmark_id: Benchmark id of the test Returns: GCS folder name with location """ current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") - return f"{gcs_bucket.BASE_OUTPUT_DIR}/{benchmark_id}-{current_datetime}/" + return os.path.join( + gcs_bucket.BASE_OUTPUT_DIR, + subfolder, + f"{benchmark_id}-{current_datetime}/", + )