Skip to content

Commit

Permalink
Add nightly AOT compilation GKE tests for MaxText configs (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
raymondzouu authored Mar 11, 2024
1 parent 5893de7 commit 4750150
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Iterable


def get_maxtext_gke_config(
def get_gke_config(
tpu_version: TpuVersion,
tpu_cores: int,
tpu_zone: str,
Expand Down
105 changes: 105 additions & 0 deletions dags/multipod/maxtext_configs_aot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A DAG to run AOT compilation tests for MaxText model configs.
"""
import datetime
from airflow import models
from dags import composer_env, test_owner
from dags.vm_resource import TpuVersion, Zone, DockerImage
from dags.multipod.configs import gke_config
from dags.multipod.configs.common import SetupMode

# Run once a day at 5 am UTC (9 pm PST) in the prod environment; in any other
# environment the schedule is None.
SCHEDULED_TIME = "0 5 * * *" if composer_env.is_prod_env() else None

with models.DAG(
    dag_id="maxtext_configs_aot",
    schedule=SCHEDULED_TIME,
    tags=["multipod_team", "maxtext", "stable", "nightly"],
    start_date=datetime.datetime(2024, 2, 19),
    catchup=False,
    concurrency=2,
) as dag:
    # Testing configurations: every (model_size, num_cores) pair below is
    # compiled once per entry in `num_slices`.
    model_configs = {
        # accelerator: [(model_size, num_cores), ...],
        "v4": [("22b", 128), ("52b", 384)],
        "v5e": [("16b", 256), ("32b", 256), ("64b", 256), ("128b", 256)],
        "v5p": [
            ("32b", 128),
            ("64b", 128),
            ("128b", 256),
            ("128b", 512),
            ("256b", 1024),
            ("512b", 1024),
            ("1024b", 2048),
            ("1024b", 4096),
        ],
    }
    num_slices = [1, 2]
    docker_images = [
        (SetupMode.STABLE, DockerImage.MAXTEXT_JAX_STABLE),
        (SetupMode.NIGHTLY, DockerImage.MAXTEXT_JAX_NIGHTLY),
    ]

    # Map each accelerator to the list of shell commands that compile its
    # model configs. The per-accelerator list is deliberately NOT named
    # `models`: that would rebind the imported `airflow.models` module for
    # the remainder of this file.
    run_model_cmds_dict = {}
    for tpu, size_and_cores in model_configs.items():
        run_model_cmds_dict[tpu] = [
            (
                f"bash MaxText/configs/{tpu}/{model_size}.sh"
                " EXECUTABLE=train_compile.py"
                f" M_COMPILE_TOPOLOGY={tpu}-{num_cores}"
                f" M_COMPILE_TOPOLOGY_NUM_SLICES={n}"
            )
            for model_size, num_cores in size_and_cores
            for n in num_slices
        ]

    for mode, image in docker_images:
        # For each docker image, create one test per accelerator and chain
        # them in the insertion order of `model_configs` (v4 >> v5e >> v5p).
        previous_test = None
        for tpu, run_model_cmds in run_model_cmds_dict.items():
            # NOTE(review): every test is requested with TpuVersion.V4,
            # 8 cores, in US_CENTRAL2_B regardless of the accelerator being
            # compiled for; the target topology is only passed through
            # M_COMPILE_TOPOLOGY. Presumably intentional for AOT
            # compilation -- confirm.
            configs_test = gke_config.get_gke_config(
                tpu_version=TpuVersion.V4,
                tpu_cores=8,
                tpu_zone=Zone.US_CENTRAL2_B.value,
                time_out_in_min=60,
                test_name=f"maxtext-aot-{tpu}-{mode.value}",
                run_model_cmds=run_model_cmds,
                docker_image=image.value,
                test_owner=test_owner.RAYMOND_Z,
            ).run()

            if previous_test is not None:
                previous_test >> configs_test
            previous_test = configs_test
4 changes: 2 additions & 2 deletions dags/multipod/maxtext_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from airflow import models
from dags import composer_env, test_owner, gcs_bucket
from dags.vm_resource import TpuVersion, Zone, DockerImage, ClusterName
from dags.multipod.configs import maxtext_gke_config
from dags.multipod.configs import gke_config
from dags.multipod.configs.common import SetupMode
from xlml.apis import gcp_config, metric_config, task, test_config

Expand Down Expand Up @@ -53,7 +53,7 @@
}

for test_name, run_command in convergence_tests.items():
maxtext_v4_configs_test = maxtext_gke_config.get_maxtext_gke_config(
maxtext_v4_configs_test = gke_config.get_gke_config(
tpu_version=TpuVersion.V4,
tpu_cores=128,
tpu_zone=Zone.US_CENTRAL2_B.value,
Expand Down
1 change: 1 addition & 0 deletions dags/vm_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class ClusterName(enum.Enum):
V5E_4_CLUSTER = "mas-v5e-4"
V5E_16_CLUSTER = "mas-v5e-16"
V4_8_MULTISLICE_CLUSTER = "v4-8-maxtext"
V4_16_MULTISLICE_CLUSTER = "v4-16-maxtext"
V4_128_MULTISLICE_CLUSTER = "v4-bodaborg"
V5E_16_MULTISLICE_CLUSTER = "v5e-16-bodaborg"
V5E_256_MULTISLICE_CLUSTER = "v5e-256-bodaborg"
Expand Down

0 comments on commit 4750150

Please sign in to comment.