Skip to content

Commit

Permalink
Add subfolder to GCS location (#225)
Browse files (browse the repository at this point in the history)
* Add subfolder to GCS location

* address example
  • Loading branch information
RissyRan authored Mar 29, 2024
1 parent 12aa063 commit 22fd87a
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 9 deletions.
7 changes: 5 additions & 2 deletions dags/examples/xpk_example_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from airflow import models
from dags.vm_resource import TpuVersion, Project, Zone, ClusterName, DockerImage
from dags.examples.configs import xpk_example_config as config
from dags import test_owner
from xlml.utils import name_format
from airflow.utils.task_group import TaskGroup

Expand Down Expand Up @@ -59,11 +60,13 @@
# The value of 'test_group_id':
# 1) a task group name for those chained tests
# 2) an ID to generate gcs folder path in format:
# "{gcs_bucket.BASE_OUTPUT_DIR}/{group_id}-{current_datetime}/"
# "{gcs_bucket.BASE_OUTPUT_DIR}/{gcs_subfolder}/{group_id}-{current_datetime}/"
test_group_id = "chained_tests"
gcs_subfolder = f"{test_owner.Team.MULTIPOD.value}/maxtext"
with TaskGroup(group_id=test_group_id) as group:
shared_gcs_location = name_format.generate_gcs_folder_location(
test_group_id
gcs_subfolder,
test_group_id,
)
chained_resnet_tpu_singleslice_v4_8 = config.get_flax_resnet_xpk_config(
tpu_version=TpuVersion.V4,
Expand Down
7 changes: 6 additions & 1 deletion dags/pytorch_xla/configs/pytorchxla_torchbench_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
from typing import Tuple
from xlml.apis import gcp_config, metric_config, task, test_config
import dags.vm_resource as resource
from dags import gcs_bucket, test_owner
from dags import test_owner


GCS_SUBFOLDER_PREFIX = test_owner.Team.PYTORCH_XLA.value


class VERSION(enum.Enum):
Expand Down Expand Up @@ -170,6 +173,7 @@ def get_torchbench_tpu_config(
run_model_cmds=run_script_cmds,
time_out_in_min=time_out_in_min,
task_owner=test_owner.PEI_Z,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/torchbench",
)

job_metric_config = metric_config.MetricConfig(
Expand Down Expand Up @@ -330,6 +334,7 @@ def get_torchbench_gpu_config(
run_model_cmds=run_script_cmds,
time_out_in_min=time_out_in_min,
task_owner=test_owner.PEI_Z,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/torchbench",
)

job_metric_config = metric_config.MetricConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value
RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value
RUN_DATE = datetime.now().strftime("%Y_%m_%d")
GCS_SUBFOLDER_PREFIX = test_owner.Team.SOLUTIONS_TEAM.value


def get_flax_resnet_config(
Expand Down Expand Up @@ -204,6 +205,7 @@ def get_flax_vit_conv_config(
run_model_cmds=run_model_cmds,
time_out_in_min=time_out_in_min,
task_owner=test_owner.SHIVA_S,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/flax",
)

job_metric_config = metric_config.MetricConfig(
Expand Down Expand Up @@ -456,6 +458,7 @@ def get_flax_bart_conv_config(
run_model_cmds=run_model_cmds,
time_out_in_min=time_out_in_min,
task_owner=test_owner.SHIVA_S,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/flax",
)

job_metric_config = metric_config.MetricConfig(
Expand Down Expand Up @@ -578,6 +581,7 @@ def get_flax_bert_conv_config(
run_model_cmds=run_model_cmds,
time_out_in_min=time_out_in_min,
task_owner=test_owner.SHIVA_S,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/flax",
)

job_metric_config = metric_config.MetricConfig(
Expand Down
10 changes: 10 additions & 0 deletions dags/test_owner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@

"""The file of test owners."""

import enum


class Team(enum.Enum):
    """Teams that own tests.

    Each value is the team's lowercase name string; test configs use it as
    the leading component of their GCS subfolder (e.g. "pytorch_xla/...").
    """

    SOLUTIONS_TEAM = "solutions_team"
    PYTORCH_XLA = "pytorch_xla"
    MULTIPOD = "multipod"
    MLCOMPASS = "mlcompass"


# XLML - JAX/FLAX
SHIVA_S = "Shiva S."

Expand Down
1 change: 0 additions & 1 deletion dags/vm_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ class ClusterName(enum.Enum):
V5E_16_MULTISLICE_CLUSTER = "v5e-16-bodaborg"
V5E_256_MULTISLICE_CLUSTER = "v5e-256-bodaborg"
V5E_256_US_WEST_4_MULTISLICE_CLUSTER = "v5e-256-bodaborg-us-west4"

A3_CLUSTER = "maxtext-a3-20n"


Expand Down
9 changes: 6 additions & 3 deletions xlml/apis/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def provision(
)
ssh_keys = ssh.generate_ssh_keys()
output_location = name_format.generate_gcs_folder_location(
self.task_test_config.benchmark_id
self.task_test_config.gcs_subfolder,
self.task_test_config.benchmark_id,
)

queued_resource_op, queued_resource_name = tpu.create_queued_resource(
Expand Down Expand Up @@ -395,7 +396,8 @@ def run_model(
gcs_path = gcs_location
else:
gcs_path = name_format.generate_gcs_folder_location(
self.task_test_config.benchmark_id
self.task_test_config.gcs_subfolder,
self.task_test_config.benchmark_id,
)
launch_workload = self.launch_workload(workload_id, gcs_path)
wait_for_workload_completion = xpk.wait_for_workload_completion.override(
Expand Down Expand Up @@ -545,7 +547,8 @@ def provision(
gpu_name = gpu.generate_gpu_name()
ssh_keys = ssh.generate_ssh_keys()
gcs_location = name_format.generate_gcs_folder_location(
self.task_test_config.benchmark_id
self.task_test_config.gcs_subfolder,
self.task_test_config.benchmark_id,
)

ip_address = gpu.create_resource(
Expand Down
1 change: 1 addition & 0 deletions xlml/apis/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class TestConfig(abc.ABC, Generic[A]):
# TODO(wcromar): make this a timedelta
time_out_in_min: Optional[int] = attrs.field(default=None, kw_only=True)
task_owner: str = attrs.field(default='unowned', kw_only=True)
gcs_subfolder: str = attrs.field(default='unowned', kw_only=True)

@property
@abc.abstractmethod
Expand Down
9 changes: 7 additions & 2 deletions xlml/utils/name_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,18 @@ def generate_tb_file_location(run_name: str, base_output_directory: str) -> str:


@task
def generate_gcs_folder_location(subfolder: str, benchmark_id: str) -> str:
    """Generates folder location in GCS.

    The result has the form
    "{gcs_bucket.BASE_OUTPUT_DIR}/{subfolder}/{benchmark_id}-{datetime}/",
    where the datetime is the current time formatted as
    "%Y-%m-%d-%H-%M-%S" (taken from a naive `datetime.now()` call).

    Args:
      subfolder: Folder name/path for artifacts, such as 'solutions_team/flax'
      benchmark_id: Benchmark id of the test

    Returns:
      GCS folder name with location
    """
    current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    # The trailing "/" on the last component is kept so the result names a
    # folder rather than an object.
    return os.path.join(
        gcs_bucket.BASE_OUTPUT_DIR,
        subfolder,
        f"{benchmark_id}-{current_datetime}/",
    )

0 comments on commit 22fd87a

Please sign in to comment.