Skip to content

Commit 3cc317d

Browse files
authored
Migrate PyTorch multislice tests (#220)
1 parent 0ee1a95 commit 3cc317d

File tree

3 files changed

+126
-0
lines changed

3 files changed

+126
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dags.vm_resource import TpuVersion, Zone, DockerImage, ClusterName
16+
from dags.multipod.configs import gke_config
17+
from xlml.apis import task
18+
from typing import List
19+
20+
21+
# TODO(jonbolin): Refactor this to cluster definition
# Per-cluster TPU settings, keyed by ClusterName. Each value is a dict of
# keyword arguments that get_nightly_pytorch_config unpacks into
# gke_config.get_gke_config. The clusters listed here share the same TPU
# version and zone and differ only in their core count.
CLUSTER_CONFIG = {
    cluster: {
        'tpu_version': TpuVersion.V4,
        'tpu_cores': tpu_cores,
        'tpu_zone': Zone.US_CENTRAL2_B.value,
    }
    for cluster, tpu_cores in (
        (ClusterName.V4_8_MULTISLICE_CLUSTER, 8),
        (ClusterName.V4_16_MULTISLICE_CLUSTER, 16),
    )
}
34+
35+
36+
def get_nightly_pytorch_config(
    test_name: str,
    test_owner: str,
    run_commands: List[str],
    cluster: ClusterName,
    num_slices: int,
) -> task.XpkTask:
    """Build the XPK task for a nightly PyTorch/XLA multislice test.

    The workload first clones https://github.com/pytorch/xla into
    /pytorch/xla and then runs the caller's commands in order, using the
    PYTORCH_NIGHTLY docker image and a 60 minute timeout.

    Args:
        test_name: Name passed through to the GKE test config.
        test_owner: Owner passed through to the GKE test config.
        run_commands: Commands to run after the clone. The value is only
            unpacked with ``*``, so tuples work as well as lists.
        cluster: Target cluster; must be a key of CLUSTER_CONFIG, which
            supplies the TPU version, core count and zone.
        num_slices: Number of slices passed through to the GKE test config.

    Returns:
        The task produced by gke_config.get_gke_config.

    Raises:
        KeyError: If ``cluster`` has no entry in CLUSTER_CONFIG.
    """
    clone_cmd = 'git clone https://github.com/pytorch/xla /pytorch/xla'
    workload_cmds = (clone_cmd, *run_commands)

    return gke_config.get_gke_config(
        cluster_name=cluster.value,
        test_name=test_name,
        run_model_cmds=workload_cmds,
        num_slices=num_slices,
        docker_image=DockerImage.PYTORCH_NIGHTLY.value,
        test_owner=test_owner,
        time_out_in_min=60,
        # tpu_version / tpu_cores / tpu_zone come from the per-cluster table.
        **CLUSTER_CONFIG[cluster],
    )

dags/multipod/pytorch.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
A DAG to run PyTorch multislice tests
17+
"""
18+
import datetime
19+
from airflow import models
20+
from dags import composer_env, test_owner
21+
from dags.vm_resource import TpuVersion, Zone, DockerImage, ClusterName
22+
from dags.multipod.configs import pytorch_config
23+
from xlml.apis import metric_config
24+
25+
# Cron schedule for the DAG: once a day at 10:00 UTC (2 am PST / 3 am PDT)
# when running in the prod environment; None (no cron schedule) otherwise.
if composer_env.is_prod_env():
    SCHEDULED_TIME = "0 10 * * *"
else:
    SCHEDULED_TIME = None
27+
28+
with models.DAG(
    dag_id="pytorch_multislice",
    schedule=SCHEDULED_TIME,
    tags=["multipod_team", "pytorch", "nightly"],
    start_date=datetime.datetime(2024, 3, 1),
    catchup=False,
    concurrency=2,
) as dag:
    v4_8 = ClusterName.V4_8_MULTISLICE_CLUSTER
    v4_16 = ClusterName.V4_16_MULTISLICE_CLUSTER

    # Sharding-strategy test matrix: (num_slices, cluster, ici_chips).
    # ici_chips feeds --ici_fsdp_parallelism (4 on the v4_8 cluster, 8 on
    # v4_16); num_slices feeds --dcn_data_parallelism.
    sharding_matrix = (
        (1, v4_8, 4),
        (2, v4_8, 4),
        (1, v4_16, 8),
    )
    for num_slices, cluster, ici_chips in sharding_matrix:
        sharding_cmd = " ".join(
            (
                "python /pytorch/xla/test/spmd/test_sharding_strategies.py",
                f"--ici_fsdp_parallelism {ici_chips}",
                f"--dcn_data_parallelism {num_slices}",
            )
        )
        sharding_task = pytorch_config.get_nightly_pytorch_config(
            test_name="shardings",
            test_owner=test_owner.JON_B,
            run_commands=(sharding_cmd,),
            cluster=cluster,
            num_slices=num_slices,
        )
        sharding_task.run()

    # Distributed checkpoint test: runs once, on two slices of the v4_16
    # cluster. CHKPT_PATH is exported from the GCS output variable and gcsfs
    # is installed before the test script is invoked.
    checkpoint_cmds = (
        f"export CHKPT_PATH={metric_config.SshEnvVars.GCS_OUTPUT.value}",
        "pip install gcsfs",
        (
            "python /pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py "
            "EndToEndCheckpointTest.test_multihost_checkpoint"
        ),
    )
    checkpoint_task = pytorch_config.get_nightly_pytorch_config(
        test_name="checkpoint",
        test_owner=test_owner.JON_B,
        run_commands=checkpoint_cmds,
        cluster=v4_16,
        num_slices=2,
    )
    checkpoint_task.run()

dags/vm_resource.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class DockerImage(enum.Enum):
137137
"""Common docker images."""
138138

139139
XPK_JAX_TEST = "gcr.io/cloud-ml-auto-solutions/xpk_jax_test:latest"
140+
PYTORCH_NIGHTLY = f"us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_{datetime.datetime.today().strftime('%Y%m%d')}"
140141
MAXTEXT_TPU_JAX_STABLE = f"gcr.io/tpu-prod-env-multipod/maxtext_jax_stable:{datetime.datetime.today().strftime('%Y-%m-%d')}"
141142
MAXTEXT_TPU_JAX_NIGHTLY = f"gcr.io/tpu-prod-env-multipod/maxtext_jax_nightly:{datetime.datetime.today().strftime('%Y-%m-%d')}"
142143
MAXTEXT_GPU_JAX_STABLE = f"gcr.io/tpu-prod-env-multipod/maxtext_gpu_jax_stable:{datetime.datetime.today().strftime('%Y-%m-%d')}"

0 commit comments

Comments
 (0)