Skip to content

Commit

Permalink
simplify the flow, and add default parameters to the xlml-state config (#506)
Browse files Browse the repository at this point in the history

* simplify the flow, and add default parameters to the xlml-state configuration

* reformat

* reformat 2

* reformat 3
  • Loading branch information
ortibazar authored Dec 16, 2024
1 parent 1c6bbe4 commit 8e27039
Showing 1 changed file with 60 additions and 66 deletions.
126 changes: 60 additions & 66 deletions dags/mlcompass/maxtext_gke.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,59 +20,16 @@
--location=us-central1 dags trigger \
-- \
mlcompass_maxtext_gke \
--conf={\\\"uuid\\\":\\\"abc\\\"} 70
--conf={\\\"uuid\\\":\\\"abc\\\"}
"""

import datetime
import json
from airflow import models
from airflow.decorators import task
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from xlml.apis.xpk_cluster_config import XpkClusterConfig
from dags import test_owner
from dags.vm_resource import Project, XpkClusters
from xlml.apis import gcp_config, metric_config, task as xlml_task, test_config
import json


def get_config_gke(
    docker_image: str,
    model_name: str,
    base_output_directory: str,
    task_owner: str = test_owner.ORTI_B,
    cluster: XpkClusterConfig = XpkClusters.TPU_V4_8_MAXTEXT_CLUSTER,
    time_out_in_min: int = 60,
    num_slices: int = 1,
    dataset_name: metric_config.DatasetOption = metric_config.DatasetOption.XLML_DATASET,
    dataset_project: str = Project.CLOUD_ML_AUTO_SOLUTIONS.value,
    composer_project: str = Project.CLOUD_ML_AUTO_SOLUTIONS.value,
) -> xlml_task.XpkTask:
    """Assemble the XPK task that runs a MaxText benchmark on a GKE cluster.

    Args:
        docker_image: Image reference handed to the TPU GKE test config.
        model_name: First argument passed to ``run`` in the benchmark command.
        base_output_directory: Second argument passed to ``run`` in the
            benchmark command.
        task_owner: Owner recorded on the test config.
        cluster: Cluster description; its ``project``, ``zone``,
            ``device_version``, ``core_count`` and ``name`` are read here.
        time_out_in_min: Test timeout, expressed in minutes.
        num_slices: Slice count forwarded to the test config.
        dataset_name: Dataset option forwarded to the GCP config.
        dataset_project: Dataset project forwarded to the GCP config.
        composer_project: Composer project forwarded to the GCP config.

    Returns:
        An ``XpkTask`` pairing the test config with the GCP config.
    """
    # Project and zone always come from the cluster the workload targets.
    gcp_settings = gcp_config.GCPConfig(
        project_name=cluster.project,
        zone=cluster.zone,
        dataset_name=dataset_name,
        dataset_project=dataset_project,
        composer_project=composer_project,
    )

    accelerator = test_config.Tpu(
        version=cluster.device_version,
        cores=cluster.core_count,
    )
    benchmark_cmd = (
        f"source benchmark_run.sh;run {model_name} {base_output_directory}"
    )
    run_timeout = datetime.timedelta(minutes=time_out_in_min)

    test_settings = test_config.TpuGkeTest(
        accelerator,
        test_name="maxtext",
        run_model_cmds=[benchmark_cmd],
        set_up_cmds=None,
        timeout=run_timeout,
        task_owner=task_owner,
        num_slices=num_slices,
        cluster_name=cluster.name,
        docker_image=docker_image,
    )

    return xlml_task.XpkTask(
        task_test_config=test_settings,
        task_gcp_config=gcp_settings,
    )
from xlml.utils import xpk


with models.DAG(
Expand All @@ -89,7 +46,7 @@ def get_config_gke(
},
) as dag:

@task.python
@task.python(multiple_outputs=True)
def load_xlml_state(params: dict = None):
dag.log.info(params)
uuid = params["uuid"]
Expand All @@ -101,27 +58,64 @@ def load_xlml_state(params: dict = None):
)
return json.loads(file_content)

@task.python
def get_docker_image_path(state: dict) -> str:
return state["docker_image_path"]
xlml_state = load_xlml_state()

@task.python
def get_model_name(state: dict) -> str:
return state["model_name"]
cluster_name = xlml_state["cluster_name"]
cluster_project = xlml_state["cluster_project"]
cluster_region = xlml_state["cluster_region"]
cluster_zone = xlml_state["cluster_zone"]
benchmark_id = xlml_state["test_name"]

@task.python
def get_base_output_directory(state: dict) -> str:
bucket = state["workdir_bucket"]
path = state["workdir_path"]
return f"gs://{bucket}/{path}"
docker_image_path = xlml_state["docker_image_path"]
accelerator_type = xlml_state["accelerator_type"]
num_slices = xlml_state["num_slices"]

xlml_state = load_xlml_state()
docker_image_path = get_docker_image_path(xlml_state)
model_name_arg = get_model_name(xlml_state)
base_output_directory_arg = get_base_output_directory(xlml_state)
model_name = xlml_state["model_name"]
workdir_bucket = xlml_state["workdir_bucket"]
workdir_path = xlml_state["workdir_path"]
gcs_path = f"gs://{workdir_bucket}/{workdir_path}"
workload_id = f'mlc-{xlml_state["uuid"]}'

default_benchmark = get_config_gke(
workload_provision_timeout = datetime.timedelta(minutes=300).total_seconds()
workload_run_timeout = datetime.timedelta(minutes=60).total_seconds()

run_workload = xpk.run_workload.override(owner=test_owner.ORTI_B)(
task_id="run_workload",
cluster_project=cluster_project,
zone=cluster_zone,
cluster_name=cluster_name,
benchmark_id=benchmark_id,
workload_id=workload_id,
gcs_path=gcs_path,
docker_image=docker_image_path,
model_name=model_name_arg,
base_output_directory=base_output_directory_arg,
).run(skip_post_process=True)
accelerator_type=accelerator_type,
run_cmds=f"source benchmark_run.sh;run {model_name} {gcs_path}",
num_slices=num_slices,
use_vertex_tensorboard=False,
use_pathways=False,
)

wait_for_workload_start = xpk.wait_for_workload_start.override(
timeout=workload_provision_timeout
)(
workload_id=workload_id,
project_id=cluster_project,
region=cluster_region,
cluster_name=cluster_name,
)

wait_for_workload_completion = xpk.wait_for_workload_completion.override(
timeout=workload_run_timeout
)(
workload_id=workload_id,
project_id=cluster_project,
region=cluster_region,
cluster_name=cluster_name,
)

(
xlml_state
>> run_workload
>> wait_for_workload_start
>> wait_for_workload_completion
)

0 comments on commit 8e27039

Please sign in to comment.