Skip to content

Commit

Permalink
Add maxtext nightly test on 1 slice of v4-8 (#72)
Browse files Browse the repository at this point in the history
* save maxtext progs

* save prog

* Run maxtext on 1 slice

* Finish the maxtext test on single slice.

* Remove unused vars
  • Loading branch information
tonyjohnchen authored Jan 12, 2024
1 parent 24c2ff9 commit 14ba06c
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 0 deletions.
27 changes: 27 additions & 0 deletions configs/maxtext/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to construct common configs."""

from typing import Tuple

# Shell command that upgrades pip before anything else is installed.
UPGRADE_PIP = "pip install --upgrade pip"


def download_maxtext() -> Tuple[str, ...]:
    """Return the common set-up commands for the maxtext repo.

    Returns:
      A tuple of shell commands, in execution order: upgrade pip, then clone
      the maxtext repository into /tmp/maxtext.
    """
    return (
        UPGRADE_PIP,
        "git clone https://github.com/google/maxtext.git /tmp/maxtext",
    )
80 changes: 80 additions & 0 deletions configs/maxtext/maxtext_gce_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to construct configs for maxtext DAG."""

from typing import Tuple
import uuid
from apis import gcp_config, metric_config, task, test_config
from configs import gcs_bucket, test_owner
from configs.maxtext import common
from configs.vm_resource import TpuVersion, Project, RuntimeVersion
import datetime

# Defaults for get_maxtext_nightly_config's `project_name` and
# `runtime_version` parameters when the caller does not override them.
PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value
RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value


def get_maxtext_nightly_config(
    tpu_version: TpuVersion,
    tpu_cores: int,
    tpu_zone: str,
    time_out_in_min: int,
    project_name: str = PROJECT_NAME,
    runtime_version: str = RUNTIME_IMAGE,
    network: str = "default",
    subnetwork: str = "default",
    is_tpu_reserved: bool = True,
) -> task.TpuQueuedResourceTask:
    """Build the queued-resource task for the single-slice MaxText nightly test.

    Args:
      tpu_version: TPU version passed to `test_config.Tpu` and embedded in the
        run name.
      tpu_cores: Number of TPU cores passed to `test_config.Tpu` and embedded in
        the run name.
      tpu_zone: Zone used for the GCP config.
      time_out_in_min: Timeout, in minutes, passed to the test config.
      project_name: GCP project name; defaults to `PROJECT_NAME`.
      runtime_version: TPU runtime version; defaults to `RUNTIME_IMAGE`.
      network: Network passed to `test_config.Tpu`.
      subnetwork: Subnetwork passed to `test_config.Tpu`.
      is_tpu_reserved: Passed as `reserved` to `test_config.Tpu`.

    Returns:
      A `task.TpuQueuedResourceTask` wrapping the test and GCP configs.
    """
    job_gcp_config = gcp_config.GCPConfig(
        project_name=project_name,
        zone=tpu_zone,
        dataset_name=metric_config.DatasetOption.XLML_DATASET,
    )

    # The timestamp is taken when this function is called, not when the
    # training command later executes on the TPU VM.
    # NOTE(review): `tpu_version` is interpolated as-is; if TpuVersion is a
    # plain Enum this renders as "TpuVersion.<NAME>" rather than its value --
    # confirm that is the intended run name.
    current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    run_name = f"1slice-{tpu_version}_{tpu_cores}-maxtext-nightly-{current_datetime}"

    set_up_cmds = common.download_maxtext()

    # The variables must be exported: a bare `VAR=value ... && cmd` only sets
    # unexported shell variables, so JAX_PLATFORM_NAME and XLA_FLAGS would never
    # reach the python3 process. RUN_NAME is still expanded by the shell.
    run_model_cmds = (
        (
            "cd /tmp/maxtext && bash setup.sh MODE=nightly;"
            " export JAX_PLATFORM_NAME=TPU"
            ' XLA_FLAGS="--xla_dump_to=/tmp/xla_dump/"'
            f' RUN_NAME="{run_name}" &&'
            " python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME"
            " base_output_directory=gs://tonyjohnchen-maxtext-nightly/"
            " dataset_path=gs://max-datasets-rogue dataset_type=synthetic"
            " per_device_batch_size=6 reuse_example_batch=1 global_parameter_scale=1 metrics_file='metrics.txt'"
            " steps=50 enable_checkpointing=false enable_profiler=true gcs_metrics=true;"
        ),
    )

    job_test_config = test_config.TpuVmTest(
        test_config.Tpu(
            version=tpu_version,
            cores=tpu_cores,
            runtime_version=runtime_version,
            reserved=is_tpu_reserved,
            network=network,
            subnetwork=subnetwork,
        ),
        test_name="maxtext_nightly",
        set_up_cmds=set_up_cmds,
        run_model_cmds=run_model_cmds,
        time_out_in_min=time_out_in_min,
        task_owner=test_owner.Tony_C,
    )

    return task.TpuQueuedResourceTask(
        task_test_config=job_test_config,
        task_gcp_config=job_gcp_config,
    )
3 changes: 3 additions & 0 deletions configs/test_owner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@

# PYTORCH
PEI_Z = "Pei Z."

# Maxtext
Tony_C = "Tony C."
44 changes: 44 additions & 0 deletions dags/maxtext/maxtext_nightly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A DAG to run the MaxText nightly test on a single v4-8 TPU slice."""

import datetime
from airflow import models
from configs import composer_env
from configs.vm_resource import Project, TpuVersion, Zone, RuntimeVersion, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS
from configs.maxtext import maxtext_gce_config


# Nightly cadence in production: 02:00 UTC every day (18:00 PST).
# Outside the prod environment the schedule is None.
SCHEDULED_TIME = None
if composer_env.is_prod_env():
    SCHEDULED_TIME = "0 2 * * *"


with models.DAG(
    dag_id="maxtext_test",
    schedule=SCHEDULED_TIME,
    tags=["multipod_team", "maxtext"],
    start_date=datetime.datetime(2024, 1, 10),
    catchup=False,
) as dag:
    # MaxText nightly run: TPU v4, 8 cores, us-central2-b, 60-minute timeout.
    maxtext_nightly_v4_8 = maxtext_gce_config.get_maxtext_nightly_config(
        tpu_version=TpuVersion.V4,
        tpu_cores=8,
        tpu_zone=Zone.US_CENTRAL2_B.value,
        time_out_in_min=60,
    ).run()

    # Test dependencies (a single task, so there is nothing to chain yet).
    maxtext_nightly_v4_8

0 comments on commit 14ba06c

Please sign in to comment.