# ---- configs/maxtext/common.py (new file) ----
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to construct common configs."""

from typing import Tuple

# Shell command that upgrades pip before anything else is installed.
UPGRADE_PIP = "pip install --upgrade pip"


def download_maxtext() -> Tuple[str, ...]:
  """Return the set-up shell commands that fetch the MaxText repo.

  Returns:
    A tuple of shell command strings, in order: upgrade pip, then clone
    https://github.com/google/maxtext.git into /tmp/maxtext.
  """
  # The annotation is variadic because the tuple holds more than one command.
  return (
      UPGRADE_PIP,
      "git clone https://github.com/google/maxtext.git /tmp/maxtext",
  )


# ---- configs/maxtext/maxtext_gce_config.py (new file) ----
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to construct configs for maxtext DAG."""

import datetime
from typing import Tuple
import uuid

from apis import gcp_config, metric_config, task, test_config
from configs import gcs_bucket, test_owner
from configs.maxtext import common
from configs.vm_resource import TpuVersion, Project, RuntimeVersion

PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value
RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value


def get_maxtext_nightly_config(
    tpu_version: TpuVersion,
    tpu_cores: int,
    tpu_zone: str,
    time_out_in_min: int,
    project_name: str = PROJECT_NAME,
    runtime_version: str = RUNTIME_IMAGE,
    network: str = "default",
    subnetwork: str = "default",
    is_tpu_reserved: bool = True,
) -> task.TpuQueuedResourceTask:
  """Build the queued-resource task for the MaxText nightly test.

  Args:
    tpu_version: TPU version forwarded to `test_config.Tpu`; also formatted
      into the run name.
    tpu_cores: Core count forwarded to `test_config.Tpu`; also formatted into
      the run name.
    tpu_zone: Zone used for the `gcp_config.GCPConfig`.
    time_out_in_min: Timeout forwarded to `test_config.TpuVmTest`.
    project_name: Project used for the `gcp_config.GCPConfig`.
    runtime_version: Runtime version forwarded to `test_config.Tpu`.
    network: Network forwarded to `test_config.Tpu`.
    subnetwork: Subnetwork forwarded to `test_config.Tpu`.
    is_tpu_reserved: Forwarded to `test_config.Tpu` as `reserved`.

  Returns:
    A `task.TpuQueuedResourceTask` built from the test config and GCP config.
  """
  job_gcp_config = gcp_config.GCPConfig(
      project_name=project_name,
      zone=tpu_zone,
      dataset_name=metric_config.DatasetOption.XLML_DATASET,
  )

  # The timestamp is taken when this function is called, not when the test
  # command executes.
  # NOTE(review): `tpu_version` is formatted as-is; if TpuVersion is a plain
  # Enum this renders the member repr rather than its value — confirm.
  current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
  run_name = f"1slice-{tpu_version}_{tpu_cores}-maxtext-nightly-{current_datetime}"

  set_up_cmds = common.download_maxtext()

  # The variables are exported so that the python3 child process inherits
  # them. A bare `VAR=value && cmd` only sets unexported shell variables,
  # which would leave JAX_PLATFORM_NAME and XLA_FLAGS invisible to train.py.
  # NOTE(review): setup.sh is followed by `;`, so training is attempted even
  # if setup fails — confirm whether `&&` is preferred.
  run_model_cmds = (
      (
          "cd /tmp/maxtext && bash setup.sh MODE=nightly;"
          f' export JAX_PLATFORM_NAME=TPU XLA_FLAGS="--xla_dump_to=/tmp/xla_dump/" RUN_NAME="{run_name}" &&'
          " python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME"
          " base_output_directory=gs://tonyjohnchen-maxtext-nightly/"
          " dataset_path=gs://max-datasets-rogue dataset_type=synthetic"
          " per_device_batch_size=6 reuse_example_batch=1 global_parameter_scale=1 metrics_file='metrics.txt'"
          " steps=50 enable_checkpointing=false enable_profiler=true gcs_metrics=true;"
      ),
  )

  job_test_config = test_config.TpuVmTest(
      test_config.Tpu(
          version=tpu_version,
          cores=tpu_cores,
          runtime_version=runtime_version,
          reserved=is_tpu_reserved,
          network=network,
          subnetwork=subnetwork,
      ),
      test_name="maxtext_nightly",
      set_up_cmds=set_up_cmds,
      run_model_cmds=run_model_cmds,
      time_out_in_min=time_out_in_min,
      task_owner=test_owner.Tony_C,
  )

  return task.TpuQueuedResourceTask(
      task_test_config=job_test_config,
      task_gcp_config=job_gcp_config,
  )


# ---- configs/test_owner.py (existing file; hunk at line 28 adds the Maxtext owner) ----

# PYTORCH
PEI_Z = "Pei Z."

# Maxtext
Tony_C = "Tony C."


# ---- dags/maxtext/maxtext_nightly.py (new file) ----
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Airflow DAG that runs the MaxText nightly test on a TPU v4-8."""

import datetime

from airflow import models

from configs import composer_env
from configs.maxtext import maxtext_gce_config
from configs.vm_resource import (
    Project,
    TpuVersion,
    Zone,
    RuntimeVersion,
    V5_NETWORKS,
    V5E_SUBNETWORKS,
    V5P_SUBNETWORKS,
)


# Cron expression for once a day at 2 am UTC (6 pm PST). Outside the prod
# environment the DAG is left without a schedule.
SCHEDULED_TIME = None
if composer_env.is_prod_env():
  SCHEDULED_TIME = "0 2 * * *"


with models.DAG(
    dag_id="maxtext_test",
    start_date=datetime.datetime(2024, 1, 10),
    schedule=SCHEDULED_TIME,
    catchup=False,
    tags=["multipod_team", "maxtext"],
) as dag:
  # Maxtext: build the nightly config for a v4-8 slice, then turn it into
  # the runnable task inside this DAG context.
  _v4_8_config = maxtext_gce_config.get_maxtext_nightly_config(
      tpu_version=TpuVersion.V4,
      tpu_cores=8,
      tpu_zone=Zone.US_CENTRAL2_B.value,
      time_out_in_min=60,
  )
  maxtext_nightly_v4_8 = _v4_8_config.run()

  # Test dependencies: a single task, so there is nothing to chain.
  maxtext_nightly_v4_8