diff --git a/dags/mantaray/run_mantaray_jobs.py b/dags/mantaray/run_mantaray_jobs.py index 88b80fa3..3e9b861c 100644 --- a/dags/mantaray/run_mantaray_jobs.py +++ b/dags/mantaray/run_mantaray_jobs.py @@ -20,6 +20,7 @@ from xlml.utils import mantaray import yaml from dags import composer_env +import re # Skip running this script in unit test because gcs loading will fail. if composer_env.is_prod_env() or composer_env.is_dev_env(): @@ -29,18 +30,41 @@ f"{mantaray.MANTARAY_G3_GS_BUCKET}/xlml_jobs/xlml_jobs.yaml" ) xlml_jobs = yaml.safe_load(xlml_jobs_yaml) - # Create a DAG for each job + + # Create a DAG for PyTorch/XLA tests + pattern = r"^(ptxla|pytorchxla).*" + workload_file_name_list = [] for job in xlml_jobs: - with models.DAG( - dag_id=job["task_name"], - schedule=job["schedule"], - tags=["mantaray"], - start_date=datetime.datetime(2024, 4, 22), - catchup=False, - ) as dag: + if re.match(pattern, job["task_name"]): + workload_file_name_list.append(job["file_name"]) + + # merge all PyTorch/XLA tests ino one Dag + with models.DAG( + dag_id="pytorch_xla_model_regression_test_on_trillium", + schedule="0 0 * * *", # everyday at midnight # job["schedule"], + tags=["mantaray", "pytorchxla", "xlml"], + start_date=datetime.datetime(2024, 4, 22), + catchup=False, + ) as dag: + for workload_file_name in workload_file_name_list: run_workload = mantaray.run_workload( - workload_file_name=job["file_name"], + workload_file_name=workload_file_name, ) + run_workload + + # Create a DAG for each job from maxtext + for job in xlml_jobs: + if not re.match(pattern, job["task_name"]): + with models.DAG( + dag_id=job["task_name"], + schedule=job["schedule"], + tags=["mantaray"], + start_date=datetime.datetime(2024, 4, 22), + catchup=False, + ) as dag: + run_workload = mantaray.run_workload( + workload_file_name=job["file_name"], + ) run_workload else: print(