Skip to content

Commit

Permalink
pyink
Browse files Browse the repository at this point in the history
  • Loading branch information
guptaaka committed Nov 15, 2024
1 parent cf6e37d commit 5847467
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 30 deletions.
2 changes: 1 addition & 1 deletion dags/multipod/configs/jax_tests_gke_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_jax_distributed_initialize_config(
num_slices: int = 1,
):
run_model_cmds = [
"bash end_to_end/test_jdi.sh",
"bash end_to_end/test_jdi.sh",
]

return gke_config.get_gke_config(
Expand Down
60 changes: 31 additions & 29 deletions dags/multipod/jax_functional_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,25 @@
v5p_subnetwork = V5P_SUBNETWORKS
v5p_runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value
test_modes_with_docker_images = [
(SetupMode.STABLE, None),
(SetupMode.JAX_STABLE_STACK, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK),
(SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_JAX_NIGHTLY),
]
(SetupMode.STABLE, None),
(SetupMode.JAX_STABLE_STACK, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK),
(SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_JAX_NIGHTLY),
]

v4_task_arr, v5p_task_arr = [], []

for test_mode, gke_docker_image in test_modes_with_docker_images:
for num_slices in (1, 2):
# v4 GCE
jax_gce_v4_8 = (
jax_tests_gce_config.get_jax_distributed_initialize_config(
tpu_version=TpuVersion.V4,
tpu_cores=8,
tpu_zone=Zone.US_CENTRAL2_B.value,
time_out_in_min=60,
is_tpu_reserved=False,
num_slices=num_slices,
test_name=f"{default_test_name}-gce-{test_mode.value}",
test_mode=test_mode,
)
jax_gce_v4_8 = jax_tests_gce_config.get_jax_distributed_initialize_config(
tpu_version=TpuVersion.V4,
tpu_cores=8,
tpu_zone=Zone.US_CENTRAL2_B.value,
time_out_in_min=60,
is_tpu_reserved=False,
num_slices=num_slices,
test_name=f"{default_test_name}-gce-{test_mode.value}",
test_mode=test_mode,
)
if len(v4_task_arr) > 1:
# pylint: disable-next=pointless-statement
Expand All @@ -66,13 +64,15 @@

# v4 GKE
if gke_docker_image is not None:
jax_gke_v4_8 = jax_tests_gke_config.get_jax_distributed_initialize_config(
cluster=XpkClusters.TPU_V4_8_MAXTEXT_CLUSTER,
time_out_in_min=60,
num_slices=num_slices,
test_name=f"{default_test_name}-gke-{test_mode.value}",
docker_image=gke_docker_image.value,
).run()
jax_gke_v4_8 = (
jax_tests_gke_config.get_jax_distributed_initialize_config(
cluster=XpkClusters.TPU_V4_8_MAXTEXT_CLUSTER,
time_out_in_min=60,
num_slices=num_slices,
test_name=f"{default_test_name}-gke-{test_mode.value}",
docker_image=gke_docker_image.value,
).run()
)
# pylint: disable-next=pointless-statement
v4_task_arr[-1] >> jax_gke_v4_8
v4_task_arr.append(jax_gke_v4_8)
Expand Down Expand Up @@ -101,13 +101,15 @@

# v5p GKE
if gke_docker_image is not None:
jax_gke_v5p_8 = jax_tests_gke_config.get_jax_distributed_initialize_config(
cluster=XpkClusters.TPU_V5P_8_CLUSTER,
time_out_in_min=60,
num_slices=num_slices,
test_name=f"{default_test_name}-gke-{test_mode.value}",
docker_image=gke_docker_image.value,
).run()
jax_gke_v5p_8 = (
jax_tests_gke_config.get_jax_distributed_initialize_config(
cluster=XpkClusters.TPU_V5P_8_CLUSTER,
time_out_in_min=60,
num_slices=num_slices,
test_name=f"{default_test_name}-gke-{test_mode.value}",
docker_image=gke_docker_image.value,
).run()
)
# pylint: disable-next=pointless-statement
v5p_task_arr[-1] >> jax_gke_v5p_8
v5p_task_arr.append(jax_gke_v5p_8)

0 comments on commit 5847467

Please sign in to comment.