Skip to content

Commit

Permalink
Add shared gcs location example for xpk (#224)
Browse files Browse the repository at this point in the history
* Add shared gcs location example for xpk

* Address comments
  • Loading branch information
RissyRan authored Mar 28, 2024
1 parent 3cc317d commit 2fbfe80
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 7 deletions.
42 changes: 41 additions & 1 deletion dags/examples/xpk_example_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from airflow import models
from dags.vm_resource import TpuVersion, Project, Zone, ClusterName, DockerImage
from dags.examples.configs import xpk_example_config as config

from xlml.utils import name_format
from airflow.utils.task_group import TaskGroup

# TODO(ranran): add following examples:
# 1) jax_resnet_tpu_qr (diff dag)
Expand Down Expand Up @@ -53,3 +54,42 @@
time_out_in_min=60,
num_slices=2,
).run()

# Example to run multiple tests that share one GCS location for artifacts
# The value of 'test_group_id':
# 1) a task group name for those chained tests
# 2) an ID to generate gcs folder path in format:
# "{gcs_bucket.BASE_OUTPUT_DIR}/{group_id}-{current_datetime}/"
test_group_id = "chained_tests"
with TaskGroup(group_id=test_group_id) as group:
shared_gcs_location = name_format.generate_gcs_folder_location(
test_group_id
)
chained_resnet_tpu_singleslice_v4_8 = config.get_flax_resnet_xpk_config(
tpu_version=TpuVersion.V4,
tpu_cores=8,
tpu_zone=Zone.US_CENTRAL2_B.value,
test_name="chained-resnet-single-slice",
project_name=Project.CLOUD_ML_AUTO_SOLUTIONS.value,
cluster_name=ClusterName.V4_8_CLUSTER.value,
docker_image=DockerImage.XPK_JAX_TEST.value,
time_out_in_min=60,
).run(gcs_location=shared_gcs_location)

chained_resnet_tpu_multislice_v4_128 = config.get_flax_resnet_xpk_config(
tpu_version=TpuVersion.V4,
tpu_cores=128,
tpu_zone=Zone.US_CENTRAL2_B.value,
test_name="chained-resnet-multi-slice",
project_name=Project.TPU_PROD_ENV_MULTIPOD.value,
cluster_name=ClusterName.V4_128_MULTISLICE_CLUSTER.value,
docker_image=DockerImage.XPK_JAX_TEST.value,
time_out_in_min=60,
num_slices=2,
).run(gcs_location=shared_gcs_location)

(
shared_gcs_location
>> chained_resnet_tpu_singleslice_v4_8
>> chained_resnet_tpu_multislice_v4_128
)
29 changes: 23 additions & 6 deletions xlml/apis/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ class XpkTask(BaseTask):
task_test_config: Test configs to run on this TPU/GPU.
task_gcp_config: Runtime TPU/GPU creation parameters.
task_metric_config: Metric configs to process metrics.
workload_provision_timeout: Time allowed for provisioning a workload.
"""

task_test_config: Union[test_config.TpuGkeTest, test_config.GpuXpkTest]
Expand All @@ -324,15 +325,22 @@ class XpkTask(BaseTask):
minutes=300
)

def run(self) -> DAGNode:
def run(
self,
*,
gcs_location: Optional[airflow.XComArg] = None,
) -> DAGNode:
"""Run a test job within a docker image.
Attributes:
gcs_location: GCS path for all artifacts of the test.
Returns:
A task group with the following tasks chained: run_model and
post_process.
"""
with TaskGroup(group_id=self.task_test_config.benchmark_id) as group:
self.run_model() >> self.post_process()
self.run_model(gcs_location) >> self.post_process()

return group

Expand Down Expand Up @@ -369,17 +377,26 @@ def run_with_run_name_generation(self) -> DAGNode:

return group

def run_model(self) -> DAGNode:
def run_model(
self,
gcs_location: Optional[airflow.XComArg] = None,
) -> DAGNode:
"""Run the TPU/GPU test in `task_test_config` using xpk.
Attributes:
gcs_location: GCS path for all artifacts of the test.
Returns:
A DAG node that executes the model test.
"""
with TaskGroup(group_id="run_model") as group:
workload_id = xpk.generate_workload_id(self.task_test_config.benchmark_id)
gcs_path = name_format.generate_gcs_folder_location(
self.task_test_config.benchmark_id
)
if gcs_location:
gcs_path = gcs_location
else:
gcs_path = name_format.generate_gcs_folder_location(
self.task_test_config.benchmark_id
)
launch_workload = self.launch_workload(workload_id, gcs_path)
wait_for_workload_completion = xpk.wait_for_workload_completion.override(
timeout=self.task_test_config.time_out_in_min * 60,
Expand Down

0 comments on commit 2fbfe80

Please sign in to comment.