From 475015023cf986e8f7d0746a83df3a20b1b33d9b Mon Sep 17 00:00:00 2001 From: raymondzouu <31597464+raymondzouu@users.noreply.github.com> Date: Mon, 11 Mar 2024 14:43:32 -0700 Subject: [PATCH] Add nightly AOT compilation GKE tests for MaxText configs (#162) --- .../{maxtext_gke_config.py => gke_config.py} | 2 +- dags/multipod/maxtext_configs_aot.py | 105 ++++++++++++++++++ dags/multipod/maxtext_convergence.py | 4 +- dags/vm_resource.py | 1 + 4 files changed, 109 insertions(+), 3 deletions(-) rename dags/multipod/configs/{maxtext_gke_config.py => gke_config.py} (98%) create mode 100644 dags/multipod/maxtext_configs_aot.py diff --git a/dags/multipod/configs/maxtext_gke_config.py b/dags/multipod/configs/gke_config.py similarity index 98% rename from dags/multipod/configs/maxtext_gke_config.py rename to dags/multipod/configs/gke_config.py index 2dd81bab5..cfff8fea9 100644 --- a/dags/multipod/configs/maxtext_gke_config.py +++ b/dags/multipod/configs/gke_config.py @@ -20,7 +20,7 @@ from typing import Iterable -def get_maxtext_gke_config( +def get_gke_config( tpu_version: TpuVersion, tpu_cores: int, tpu_zone: str, diff --git a/dags/multipod/maxtext_configs_aot.py b/dags/multipod/maxtext_configs_aot.py new file mode 100644 index 000000000..439ba9722 --- /dev/null +++ b/dags/multipod/maxtext_configs_aot.py @@ -0,0 +1,105 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A DAG to run AOT compilation tests for MaxText model configs. +""" +import datetime +from airflow import models +from dags import composer_env, test_owner +from dags.vm_resource import TpuVersion, Zone, DockerImage +from dags.multipod.configs import gke_config +from dags.multipod.configs.common import SetupMode + +# Run once a day at 6 am UTC (10 pm PST) +SCHEDULED_TIME = "0 5 * * *" if composer_env.is_prod_env() else None + +with models.DAG( + dag_id="maxtext_configs_aot", + schedule=SCHEDULED_TIME, + tags=["multipod_team", "maxtext", "stable", "nightly"], + start_date=datetime.datetime(2024, 2, 19), + catchup=False, + concurrency=2, +) as dag: + # Testing configurations + model_configs = { + # accelerator: [(model_size, num_cores), ...], + "v4": [("22b", 128), ("52b", 384)], + "v5e": [("16b", 256), ("32b", 256), ("64b", 256), ("128b", 256)], + "v5p": [ + ("32b", 128), + ("64b", 128), + ("128b", 256), + ("128b", 512), + ("256b", 1024), + ("512b", 1024), + ("1024b", 2048), + ("1024b", 4096), + ], + } + num_slices = [1, 2] + docker_images = [ + (SetupMode.STABLE, DockerImage.MAXTEXT_JAX_STABLE), + (SetupMode.NIGHTLY, DockerImage.MAXTEXT_JAX_NIGHTLY), + ] + + run_model_cmds_dict = {} + for tpu, models in model_configs.items(): + run_model_cmds = [] + for model_size, num_cores in models: + for n in num_slices: + cmd = f"bash MaxText/configs/{tpu}/{model_size}.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY={tpu}-{num_cores} M_COMPILE_TOPOLOGY_NUM_SLICES={n}" + run_model_cmds.append(cmd) + run_model_cmds_dict[tpu] = run_model_cmds + + for mode, image in docker_images: + maxtext_v4_configs_test = gke_config.get_gke_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + test_name=f"maxtext-aot-v4-{mode.value}", + run_model_cmds=run_model_cmds_dict["v4"], + docker_image=image.value, + test_owner=test_owner.RAYMOND_Z, + ).run() + + maxtext_v5e_configs_test = gke_config.get_gke_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + test_name=f"maxtext-aot-v5e-{mode.value}", + run_model_cmds=run_model_cmds_dict["v5e"], + docker_image=image.value, + test_owner=test_owner.RAYMOND_Z, + ).run() + + maxtext_v5p_configs_test = gke_config.get_gke_config( + tpu_version=TpuVersion.V4, + tpu_cores=8, + tpu_zone=Zone.US_CENTRAL2_B.value, + time_out_in_min=60, + test_name=f"maxtext-aot-v5p-{mode.value}", + run_model_cmds=run_model_cmds_dict["v5p"], + docker_image=image.value, + test_owner=test_owner.RAYMOND_Z, + ).run() + + ( + maxtext_v4_configs_test + >> maxtext_v5e_configs_test + >> maxtext_v5p_configs_test + ) diff --git a/dags/multipod/maxtext_convergence.py b/dags/multipod/maxtext_convergence.py index 9262c14a3..9ab27f8c8 100644 --- a/dags/multipod/maxtext_convergence.py +++ b/dags/multipod/maxtext_convergence.py @@ -19,7 +19,7 @@ from airflow import models from dags import composer_env, test_owner, gcs_bucket from dags.vm_resource import TpuVersion, Zone, DockerImage, ClusterName -from dags.multipod.configs import maxtext_gke_config +from dags.multipod.configs import gke_config from dags.multipod.configs.common import SetupMode from xlml.apis import gcp_config, metric_config, task, test_config @@ -53,7 +53,7 @@ } for test_name, run_command in convergence_tests.items(): - maxtext_v4_configs_test = maxtext_gke_config.get_maxtext_gke_config( + maxtext_v4_configs_test = gke_config.get_gke_config( tpu_version=TpuVersion.V4, tpu_cores=128, tpu_zone=Zone.US_CENTRAL2_B.value, diff --git a/dags/vm_resource.py b/dags/vm_resource.py index 01814fe24..34b30339e 100644 --- a/dags/vm_resource.py +++ b/dags/vm_resource.py @@ -118,6 +118,7 @@ class ClusterName(enum.Enum): V5E_4_CLUSTER = "mas-v5e-4" V5E_16_CLUSTER = "mas-v5e-16" V4_8_MULTISLICE_CLUSTER = "v4-8-maxtext" + V4_16_MULTISLICE_CLUSTER = "v4-16-maxtext" V4_128_MULTISLICE_CLUSTER = "v4-bodaborg" V5E_16_MULTISLICE_CLUSTER = "v5e-16-bodaborg" V5E_256_MULTISLICE_CLUSTER = "v5e-256-bodaborg"