diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 07ad9a96..e0e22d73 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,13 +1,15 @@ # Default owners for everything in the repo, unless a later match takes precedence. -* @mbzomowski @RissyRan @allenwang28 +* @richardsliu @yixinshi @RissyRan @allenwang28 + +dags/common @richardsliu @yixinshi @RissyRan @allenwang28 @zpcore @ManfeiBai @gobbleturk @shralex @jiangjy1982 @ortibazar @parambole @vipannalla @crankshaw-google @polydier1 dags/solutions_team/configs/tensorflow @chandrasekhard2 @ZhaoyueCheng @richardsliu dags/solutions_team/solutionsteam_tf* @chandrasekhard2 @ZhaoyueCheng @richardsliu -dags/pytorch_xla @JackCaoG @vanbasten23 @zpcore @ManfeiBai -dags/legacy_test/tests/pytorch @JackCaoG @vanbasten23 @zpcore @ManfeiBai +dags/pytorch_xla @vanbasten23 @zpcore @ManfeiBai +dags/legacy_test/tests/pytorch @vanbasten23 @zpcore @ManfeiBai -dags/multipod @jonb377 @tonyjohnchen @raymondzouu @gobbleturk @shralex @RissyRan @jiangjy1982 +dags/multipod @tonyjohnchen @raymondzouu @gobbleturk @shralex @RissyRan @jiangjy1982 dags/mlcompass @ortibazar @sganeshb @brajiang @wlzhg @@ -15,6 +17,6 @@ dags/sparsity_diffusion_devx @RissyRan @parambole @jiangjy1982 @aireenmei @miche dags/sparsity_diffusion_devx/project_bite* @RissyRan @parambole @jiangjy1982 @aireenmei @michelle-yooh @jiya-zhang dags/sparsity_diffusion_devx/configs/project_bite* @RissyRan @parambole @jiangjy1982 @aireenmei @michelle-yooh @jiya-zhang -dags/inference @yeandy @vipannalla @morgandu @mailvijayasingh @sixiang-google @joezijunzhou @singh-mitali +dags/inference @vipannalla @mailvijayasingh @sixiang-google @joezijunzhou @singh-mitali dags/map_reproducibility @crankshaw-google @polydier1 diff --git a/dags/common/__init__.py b/dags/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dags/quarantined_tests.py b/dags/common/quarantined_tests.py similarity index 99% rename from dags/quarantined_tests.py rename to dags/common/quarantined_tests.py index a5ba9f6b..e1f9f35c 100644 --- a/dags/quarantined_tests.py +++ b/dags/common/quarantined_tests.py @@ -15,7 +15,7 @@ """Lists all currently broken tests.""" import dataclasses -from dags.test_owner import Team as team +from dags.common.test_owner import Team as team @dataclasses.dataclass diff --git a/dags/test_owner.py b/dags/common/test_owner.py similarity index 100% rename from dags/test_owner.py rename to dags/common/test_owner.py diff --git a/dags/vm_resource.py b/dags/common/vm_resource.py similarity index 100% rename from dags/vm_resource.py rename to dags/common/vm_resource.py diff --git a/dags/examples/configs/xpk_example_config.py b/dags/examples/configs/xpk_example_config.py index 3f031f3c..5f3fad0c 100644 --- a/dags/examples/configs/xpk_example_config.py +++ b/dags/examples/configs/xpk_example_config.py @@ -17,8 +17,8 @@ import datetime from xlml.apis import gcp_config, metric_config, task, test_config from xlml.apis.xpk_cluster_config import XpkClusterConfig -from dags import test_owner -from dags.vm_resource import TpuVersion +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion def get_flax_resnet_xpk_config( diff --git a/dags/examples/maxtext_aqtp_version_sweep_gke_example_dag.py b/dags/examples/maxtext_aqtp_version_sweep_gke_example_dag.py index d2da7954..98045666 100644 --- a/dags/examples/maxtext_aqtp_version_sweep_gke_example_dag.py +++ b/dags/examples/maxtext_aqtp_version_sweep_gke_example_dag.py @@ -20,8 +20,8 @@ import datetime from airflow import models -from dags import test_owner -from dags.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage from dags.multipod.configs import maxtext_sweep_gke_config # Set concurrency to number of workers otherwise tasks may time out diff --git a/dags/examples/maxtext_sweep_gke_example_dag.py b/dags/examples/maxtext_sweep_gke_example_dag.py index 316a0fb6..fe7c86be 100644 --- a/dags/examples/maxtext_sweep_gke_example_dag.py +++ b/dags/examples/maxtext_sweep_gke_example_dag.py @@ -20,8 +20,8 @@ import datetime from airflow import models -from dags import test_owner -from dags.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage from dags.multipod.configs import maxtext_sweep_gke_config # Set concurrency to number of workers otherwise tasks may time out diff --git a/dags/examples/xpk_example_dag.py b/dags/examples/xpk_example_dag.py index 1ddf4a32..fc9506d2 100644 --- a/dags/examples/xpk_example_dag.py +++ b/dags/examples/xpk_example_dag.py @@ -16,9 +16,9 @@ import datetime from airflow import models -from dags.vm_resource import TpuVersion, Project, Zone, XpkClusters, DockerImage +from dags.common.vm_resource import TpuVersion, Project, Zone, XpkClusters, DockerImage from dags.examples.configs import xpk_example_config as config -from dags import test_owner +from dags.common import test_owner from xlml.utils import name_format from airflow.utils.task_group import TaskGroup diff --git a/dags/framework3p/configs/microbenchmarks_config.py b/dags/framework3p/configs/microbenchmarks_config.py index 8f9e57a6..f2fb215e 100644 --- a/dags/framework3p/configs/microbenchmarks_config.py +++ b/dags/framework3p/configs/microbenchmarks_config.py @@ -1,7 +1,8 @@ +from dags.common import test_owner from xlml.apis import gcp_config, metric_config, task, test_config -from dags import gcs_bucket, test_owner +from dags import gcs_bucket import datetime -import dags.vm_resource as resource +import dags.common.vm_resource as resource def get_microbenchmark_config( diff --git a/dags/framework3p/microbenchmarks_dag.py b/dags/framework3p/microbenchmarks_dag.py index 8620d2db..9fe1dbd7 100644 --- a/dags/framework3p/microbenchmarks_dag.py +++ b/dags/framework3p/microbenchmarks_dag.py @@ -1,6 +1,7 @@ import datetime from airflow import models -from dags import composer_env, vm_resource, test_owner +from dags import composer_env +from dags.common import test_owner, vm_resource from dags.framework3p.configs.microbenchmarks_config import get_microbenchmark_config, get_microbenchmark_xpk_config diff --git a/dags/inference/configs/jetstream_benchmark_serving_gce_config.py b/dags/inference/configs/jetstream_benchmark_serving_gce_config.py index f0314a93..d6f87b70 100644 --- a/dags/inference/configs/jetstream_benchmark_serving_gce_config.py +++ b/dags/inference/configs/jetstream_benchmark_serving_gce_config.py @@ -18,9 +18,9 @@ import json from typing import Dict from xlml.apis import gcp_config, metric_config, task, test_config -from dags import test_owner +from dags.common import test_owner from dags.multipod.configs import common -from dags.vm_resource import TpuVersion, Project, RuntimeVersion +from dags.common.vm_resource import TpuVersion, Project, RuntimeVersion PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value diff --git a/dags/inference/configs/jetstream_pytorch_gce_config.py b/dags/inference/configs/jetstream_pytorch_gce_config.py index 42201e8f..cc101aa0 100644 --- a/dags/inference/configs/jetstream_pytorch_gce_config.py +++ b/dags/inference/configs/jetstream_pytorch_gce_config.py @@ -18,9 +18,9 @@ import json from typing import Dict from xlml.apis import gcp_config, metric_config, task, test_config -from dags import test_owner +from dags.common import test_owner from dags.multipod.configs import common -from dags.vm_resource import TpuVersion, Project, RuntimeVersion +from dags.common.vm_resource import TpuVersion, Project, RuntimeVersion PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value diff --git a/dags/inference/configs/maxtext_inference_microbenchmark_gce_config.py b/dags/inference/configs/maxtext_inference_microbenchmark_gce_config.py index ffaac119..fd9bb75a 100644 --- a/dags/inference/configs/maxtext_inference_microbenchmark_gce_config.py +++ b/dags/inference/configs/maxtext_inference_microbenchmark_gce_config.py @@ -18,9 +18,9 @@ import json from typing import Dict from xlml.apis import gcp_config, metric_config, task, test_config -from dags import test_owner +from dags.common import test_owner from dags.multipod.configs import common -from dags.vm_resource import TpuVersion, Project, RuntimeVersion +from dags.common.vm_resource import TpuVersion, Project, RuntimeVersion PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value diff --git a/dags/inference/configs/trt_llm_inference_config.py b/dags/inference/configs/trt_llm_inference_config.py index 913f29b6..c08a9132 100644 --- a/dags/inference/configs/trt_llm_inference_config.py +++ b/dags/inference/configs/trt_llm_inference_config.py @@ -15,9 +15,10 @@ """Utilities to construct configs for TensorRT-LLM inference DAG.""" import datetime +from dags.common import test_owner from xlml.apis import gcp_config, metric_config, task, test_config -from dags import test_owner, vm_resource -from dags.vm_resource import Project, RuntimeVersion +from dags.common import vm_resource +from dags.common.vm_resource import Project, RuntimeVersion RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value GCS_SUBFOLDER_PREFIX = test_owner.Team.INFERENCE.value diff --git a/dags/inference/configs/trt_llm_mlperf_v40_config.py b/dags/inference/configs/trt_llm_mlperf_v40_config.py index 5eeee49a..d50b000a 100644 --- a/dags/inference/configs/trt_llm_mlperf_v40_config.py +++ b/dags/inference/configs/trt_llm_mlperf_v40_config.py @@ -16,9 +16,10 @@ import datetime from typing import Dict +from dags.common import test_owner from xlml.apis import gcp_config, metric_config, task, test_config -from dags import test_owner, vm_resource -from dags.vm_resource import Project, RuntimeVersion +from dags.common import vm_resource +from dags.common.vm_resource import Project, RuntimeVersion RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value GCS_SUBFOLDER_PREFIX = test_owner.Team.INFERENCE.value diff --git a/dags/inference/configs/trt_llm_mlperf_v41_config.py b/dags/inference/configs/trt_llm_mlperf_v41_config.py index 00295581..7371739a 100644 --- a/dags/inference/configs/trt_llm_mlperf_v41_config.py +++ b/dags/inference/configs/trt_llm_mlperf_v41_config.py @@ -16,9 +16,10 @@ import datetime from typing import Dict, List +from dags.common import test_owner from xlml.apis import gcp_config, metric_config, task, test_config -from dags import test_owner, vm_resource -from dags.vm_resource import GpuVersion, Project, RuntimeVersion +from dags.common import vm_resource +from dags.common.vm_resource import GpuVersion, Project, RuntimeVersion RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value GCS_SUBFOLDER_PREFIX = test_owner.Team.INFERENCE.value diff --git a/dags/inference/jetstream_inference_e2e.py b/dags/inference/jetstream_inference_e2e.py index 749a9316..68be011c 100644 --- a/dags/inference/jetstream_inference_e2e.py +++ b/dags/inference/jetstream_inference_e2e.py @@ -16,7 +16,7 @@ import datetime from airflow import models -from dags.vm_resource import TpuVersion +from dags.common.vm_resource import TpuVersion from dags.inference.maxtext_model_config_generator import generate_model_configs """A JetStream inference E2E test (JAX nightly, no schedule) DAG. diff --git a/dags/inference/jetstream_pytorch_inference.py b/dags/inference/jetstream_pytorch_inference.py index 52cc5d01..0b95db5c 100644 --- a/dags/inference/jetstream_pytorch_inference.py +++ b/dags/inference/jetstream_pytorch_inference.py @@ -3,8 +3,9 @@ import datetime from airflow import models from airflow.models.baseoperator import chain -from dags import composer_env, test_owner -from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK +from dags import composer_env +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK from dags.inference.configs import jetstream_pytorch_gce_config from dags.multipod.configs.common import SetupMode, Platform import numpy as np diff --git a/dags/inference/maxtext_inference.py b/dags/inference/maxtext_inference.py index c737682d..ae55e2ae 100644 --- a/dags/inference/maxtext_inference.py +++ b/dags/inference/maxtext_inference.py @@ -18,7 +18,7 @@ import numpy as np from airflow import models from dags import composer_env -from dags.vm_resource import TpuVersion +from dags.common.vm_resource import TpuVersion from dags.inference.maxtext_model_config_generator import generate_model_configs USER_PREFIX = "" diff --git a/dags/inference/maxtext_inference_microbenchmark.py b/dags/inference/maxtext_inference_microbenchmark.py index 519a93fa..e6588768 100644 --- a/dags/inference/maxtext_inference_microbenchmark.py +++ b/dags/inference/maxtext_inference_microbenchmark.py @@ -19,7 +19,7 @@ import itertools import numpy from airflow import models -from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK +from dags.common.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK from dags.inference.configs import maxtext_inference_microbenchmark_gce_config from dags.multipod.configs.common import SetupMode diff --git a/dags/inference/maxtext_inference_offline_benchmark.py b/dags/inference/maxtext_inference_offline_benchmark.py index a5e70b02..c306b19b 100644 --- a/dags/inference/maxtext_inference_offline_benchmark.py +++ b/dags/inference/maxtext_inference_offline_benchmark.py @@ -17,8 +17,9 @@ import datetime from airflow import models -from dags import test_owner, composer_env -from dags.vm_resource import TpuVersion, Zone, Project, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK +from dags import composer_env +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, Project, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK from dags.multipod.configs import common from dags.multipod.configs.common import SetupMode from xlml.apis import gcp_config, metric_config, task, test_config diff --git a/dags/inference/maxtext_model_config_generator.py b/dags/inference/maxtext_model_config_generator.py index 085f3d2a..c7d4927c 100644 --- a/dags/inference/maxtext_model_config_generator.py +++ b/dags/inference/maxtext_model_config_generator.py @@ -14,7 +14,7 @@ """A helper to generate maxtext model configs.""" -from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK +from dags.common.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK from dags.inference.configs import jetstream_benchmark_serving_gce_config from dags.multipod.configs.common import SetupMode diff --git a/dags/inference/trt_llm_inference.py b/dags/inference/trt_llm_inference.py index 7b0f5b97..342e10b6 100644 --- a/dags/inference/trt_llm_inference.py +++ b/dags/inference/trt_llm_inference.py @@ -17,7 +17,7 @@ import datetime from airflow import models from dags import composer_env -from dags.vm_resource import H100_INFERENCE_SUBNETWORKS, INFERENCE_NETWORKS, GpuVersion, Zone, ImageFamily, ImageProject, MachineVersion, Project +from dags.common.vm_resource import H100_INFERENCE_SUBNETWORKS, INFERENCE_NETWORKS, GpuVersion, Zone, ImageFamily, ImageProject, MachineVersion, Project from dags.inference.configs import trt_llm_inference_config # Run once a day at 4 am UTC (8 pm PST) diff --git a/dags/inference/trt_llm_mlperf_v40_inference.py b/dags/inference/trt_llm_mlperf_v40_inference.py index 4d716ea4..2c4d2152 100644 --- a/dags/inference/trt_llm_mlperf_v40_inference.py +++ b/dags/inference/trt_llm_mlperf_v40_inference.py @@ -17,7 +17,7 @@ import datetime from airflow import models from dags import composer_env -from dags.vm_resource import GpuVersion, Zone, ImageFamily, ImageProject, MachineVersion, Project, INFERENCE_NETWORKS, H100_INFERENCE_SUBNETWORKS +from dags.common.vm_resource import GpuVersion, Zone, ImageFamily, ImageProject, MachineVersion, Project, INFERENCE_NETWORKS, H100_INFERENCE_SUBNETWORKS from dags.inference.configs import trt_llm_mlperf_v40_config # Run once a day at 4 am UTC (8 pm PST) diff --git a/dags/inference/trt_llm_mlperf_v41_inference.py b/dags/inference/trt_llm_mlperf_v41_inference.py index 5546324d..15948e80 100644 --- a/dags/inference/trt_llm_mlperf_v41_inference.py +++ b/dags/inference/trt_llm_mlperf_v41_inference.py @@ -17,7 +17,7 @@ import datetime from airflow import models from dags import composer_env -from dags.vm_resource import A100_INFERENCE_SUBNETWORKS, H100_INFERENCE_SUBNETWORKS, GpuVersion, Zone, ImageFamily, ImageProject, MachineVersion, Project, INFERENCE_NETWORKS, L4_INFERENCE_SUBNETWORKS +from dags.common.vm_resource import A100_INFERENCE_SUBNETWORKS, H100_INFERENCE_SUBNETWORKS, GpuVersion, Zone, ImageFamily, ImageProject, MachineVersion, Project, INFERENCE_NETWORKS, L4_INFERENCE_SUBNETWORKS from dags.inference.configs import trt_llm_mlperf_v41_config # Run once a day at 1 pm UTC (5 am PST) diff --git a/dags/infra/clean_up.py b/dags/infra/clean_up.py index ee522ae4..1ef2cea6 100644 --- a/dags/infra/clean_up.py +++ b/dags/infra/clean_up.py @@ -17,7 +17,7 @@ import datetime from airflow import models from dags import composer_env -from dags.vm_resource import Project, Zone +from dags.common.vm_resource import Project, Zone from xlml.utils import tpu diff --git a/dags/mlcompass/configs/simple_config.py b/dags/mlcompass/configs/simple_config.py index ffcedb08..10e6f863 100644 --- a/dags/mlcompass/configs/simple_config.py +++ b/dags/mlcompass/configs/simple_config.py @@ -16,8 +16,8 @@ import datetime from xlml.apis import gcp_config, metric_config, task, test_config -from dags import test_owner -from dags.vm_resource import TpuVersion, Zone, Project, RuntimeVersion +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, Project, RuntimeVersion def get_simple_config(): diff --git a/dags/mlcompass/maxtext_gke.py b/dags/mlcompass/maxtext_gke.py index 879f40b9..31a0a55f 100644 --- a/dags/mlcompass/maxtext_gke.py +++ b/dags/mlcompass/maxtext_gke.py @@ -28,7 +28,7 @@ from airflow import models from airflow.decorators import task from airflow.providers.google.cloud.hooks.gcs import GCSHook -from dags import test_owner +from dags.common import test_owner from xlml.utils import xpk diff --git a/dags/multipod/configs/gke_config.py b/dags/multipod/configs/gke_config.py index 6227f3cc..cecbd5f5 100644 --- a/dags/multipod/configs/gke_config.py +++ b/dags/multipod/configs/gke_config.py @@ -14,10 +14,11 @@ """Utilities to construct configs for maxtext DAG on GKE.""" +from dags.common import test_owner from xlml.apis import gcp_config, metric_config, task, test_config from xlml.apis.xpk_cluster_config import XpkClusterConfig -from dags import test_owner, gcs_bucket -from dags.vm_resource import TpuVersion, Project, XpkClusters, GpuVersion, CpuVersion +from dags import gcs_bucket +from dags.common.vm_resource import TpuVersion, Project, XpkClusters, GpuVersion, CpuVersion from typing import Iterable import datetime diff --git a/dags/multipod/configs/jax_tests_gce_config.py b/dags/multipod/configs/jax_tests_gce_config.py index 9d785462..6fadda20 100644 --- a/dags/multipod/configs/jax_tests_gce_config.py +++ b/dags/multipod/configs/jax_tests_gce_config.py @@ -15,9 +15,9 @@ """Utilities to construct configs for JAX tests for GCE.""" from xlml.apis import gcp_config, metric_config, task, test_config -from dags import test_owner +from dags.common import test_owner from dags.multipod.configs import common -from dags.vm_resource import TpuVersion, Project, RuntimeVersion +from dags.common.vm_resource import TpuVersion, Project, RuntimeVersion import datetime PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value diff --git a/dags/multipod/configs/jax_tests_gke_config.py b/dags/multipod/configs/jax_tests_gke_config.py index 183c4d6e..4bb77f6f 100644 --- a/dags/multipod/configs/jax_tests_gke_config.py +++ b/dags/multipod/configs/jax_tests_gke_config.py @@ -14,9 +14,9 @@ """Utilities to construct configs for JAX tests for GCE.""" -from dags import test_owner +from dags.common import test_owner from dags.multipod.configs import gke_config -from dags.vm_resource import XpkClusterConfig +from dags.common.vm_resource import XpkClusterConfig def get_jax_distributed_initialize_config( diff --git a/dags/multipod/configs/legacy_unit_test.py b/dags/multipod/configs/legacy_unit_test.py index 949658a6..0e81bf31 100644 --- a/dags/multipod/configs/legacy_unit_test.py +++ b/dags/multipod/configs/legacy_unit_test.py @@ -20,9 +20,9 @@ from xlml.apis.xpk_cluster_config import XpkClusterConfig from base64 import b64encode from collections.abc import Iterable -from dags import test_owner +from dags.common import test_owner from dags.multipod.configs import common -from dags.vm_resource import TpuVersion, Project, RuntimeVersion, XpkClusters +from dags.common.vm_resource import TpuVersion, Project, RuntimeVersion, XpkClusters def get_legacy_unit_test_config( diff --git a/dags/multipod/configs/maxtext_gce_config.py b/dags/multipod/configs/maxtext_gce_config.py index a7a5bd52..8b425146 100644 --- a/dags/multipod/configs/maxtext_gce_config.py +++ b/dags/multipod/configs/maxtext_gce_config.py @@ -14,10 +14,11 @@ """Utilities to construct configs for maxtext DAG.""" +from dags.common import test_owner from xlml.apis import gcp_config, metric_config, task, test_config -from dags import test_owner, gcs_bucket +from dags import gcs_bucket from dags.multipod.configs import common -from dags.vm_resource import TpuVersion, Project, RuntimeVersion +from dags.common.vm_resource import TpuVersion, Project, RuntimeVersion import datetime PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value diff --git a/dags/multipod/configs/maxtext_sweep_gke_config.py b/dags/multipod/configs/maxtext_sweep_gke_config.py index 60a09032..d712ab7a 100644 --- a/dags/multipod/configs/maxtext_sweep_gke_config.py +++ b/dags/multipod/configs/maxtext_sweep_gke_config.py @@ -16,7 +16,7 @@ import datetime from xlml.apis import gcp_config, metric_config, task, test_config -from dags.vm_resource import TpuVersion, XpkClusterConfig +from dags.common.vm_resource import TpuVersion, XpkClusterConfig import itertools from typing import List, Iterable, Dict, Any diff --git a/dags/multipod/configs/mxla_collective_config.py b/dags/multipod/configs/mxla_collective_config.py index 85dea755..594c9673 100644 --- a/dags/multipod/configs/mxla_collective_config.py +++ b/dags/multipod/configs/mxla_collective_config.py @@ -14,10 +14,11 @@ """Utilities to construct configs for maxtext DAG.""" +from dags.common import test_owner from xlml.apis import gcp_config, metric_config, task, test_config -from dags import test_owner, gcs_bucket +from dags import gcs_bucket from dags.multipod.configs import common -from dags.vm_resource import TpuVersion, Project, RuntimeVersion +from dags.common.vm_resource import TpuVersion, Project, RuntimeVersion import datetime PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value diff --git a/dags/multipod/configs/pytorch_config.py b/dags/multipod/configs/pytorch_config.py index 702e4321..a5992ebe 100644 --- a/dags/multipod/configs/pytorch_config.py +++ b/dags/multipod/configs/pytorch_config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dags.vm_resource import TpuVersion, Zone, DockerImage +from dags.common.vm_resource import TpuVersion, Zone, DockerImage from dags.multipod.configs import gke_config from xlml.apis.xpk_cluster_config import XpkClusterConfig from xlml.apis import task diff --git a/dags/multipod/jax_functional_tests.py b/dags/multipod/jax_functional_tests.py index 1b88a0fc..2dc2af3c 100644 --- a/dags/multipod/jax_functional_tests.py +++ b/dags/multipod/jax_functional_tests.py @@ -17,7 +17,7 @@ import datetime from airflow import models from dags import composer_env -from dags.vm_resource import DockerImage, TpuVersion, Zone, Project, V5_NETWORKS, V5P_SUBNETWORKS, RuntimeVersion, XpkClusters +from dags.common.vm_resource import DockerImage, TpuVersion, Zone, Project, V5_NETWORKS, V5P_SUBNETWORKS, RuntimeVersion, XpkClusters from dags.multipod.configs import jax_tests_gce_config, jax_tests_gke_config from dags.multipod.configs.common import SetupMode diff --git a/dags/multipod/legacy.py b/dags/multipod/legacy.py index 996da8d9..65bcb870 100644 --- a/dags/multipod/legacy.py +++ b/dags/multipod/legacy.py @@ -16,8 +16,9 @@ import datetime from airflow import models -from dags import composer_env, gcs_bucket, test_owner -from dags.vm_resource import TpuVersion, Zone, Project, DockerImage, XpkClusters +from dags import composer_env, gcs_bucket +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, Project, DockerImage, XpkClusters from dags.multipod.configs import legacy_unit_test, gke_config from dags.multipod.configs.common import SetupMode, Platform diff --git a/dags/multipod/maxtext_checkpointing.py b/dags/multipod/maxtext_checkpointing.py index 64d54282..8a9c60f1 100644 --- a/dags/multipod/maxtext_checkpointing.py +++ b/dags/multipod/maxtext_checkpointing.py @@ -17,8 +17,9 @@ """ import datetime from airflow import models -from dags import composer_env, test_owner, gcs_bucket -from dags.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters +from dags import composer_env, gcs_bucket +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters from dags.multipod.configs import gke_config from dags.multipod.configs.common import SetupMode diff --git a/dags/multipod/maxtext_configs_aot.py b/dags/multipod/maxtext_configs_aot.py index 224c7bfd..d9d463a5 100644 --- a/dags/multipod/maxtext_configs_aot.py +++ b/dags/multipod/maxtext_configs_aot.py @@ -18,8 +18,9 @@ import datetime from airflow import models from airflow.utils.task_group import TaskGroup -from dags import composer_env, test_owner -from dags.vm_resource import GpuVersion, TpuVersion, Zone, DockerImage, XpkClusters +from dags import composer_env +from dags.common import test_owner +from dags.common.vm_resource import GpuVersion, TpuVersion, Zone, DockerImage, XpkClusters from dags.multipod.configs import gke_config from dags.multipod.configs.common import SetupMode diff --git a/dags/multipod/maxtext_configs_aot_hybridsim.py b/dags/multipod/maxtext_configs_aot_hybridsim.py index 0afa4608..2ae44ed9 100644 --- a/dags/multipod/maxtext_configs_aot_hybridsim.py +++ b/dags/multipod/maxtext_configs_aot_hybridsim.py @@ -18,9 +18,10 @@ import datetime from airflow import models from airflow.utils.task_group import TaskGroup -from dags import composer_env, test_owner -from dags.quarantined_tests import QuarantineTests -from dags.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters, Project +from dags import composer_env +from dags.common.quarantined_tests import QuarantineTests +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters, Project from dags.multipod.configs import gke_config from xlml.utils import name_format from dags.multipod.configs import gke_config diff --git a/dags/multipod/maxtext_convergence.py b/dags/multipod/maxtext_convergence.py index 813123c6..b68f2c20 100644 --- a/dags/multipod/maxtext_convergence.py +++ b/dags/multipod/maxtext_convergence.py @@ -17,8 +17,9 @@ """ import datetime from airflow import models -from dags import composer_env, test_owner, gcs_bucket -from dags.vm_resource import XpkClusters, DockerImage, Project, TpuVersion, Zone +from dags import composer_env, gcs_bucket +from dags.common import test_owner +from dags.common.vm_resource import XpkClusters, DockerImage, Project, TpuVersion, Zone from dags.multipod.configs import gke_config from dags.multipod.configs.common import SetupMode from xlml.apis import gcp_config, metric_config, task, test_config diff --git a/dags/multipod/maxtext_end_to_end.py b/dags/multipod/maxtext_end_to_end.py index 7670a28b..2bea0ed6 100644 --- a/dags/multipod/maxtext_end_to_end.py +++ b/dags/multipod/maxtext_end_to_end.py @@ -18,9 +18,10 @@ import datetime from airflow import models from airflow.utils.task_group import TaskGroup -from dags import composer_env, test_owner -from dags.quarantined_tests import QuarantineTests -from dags.vm_resource import XpkClusters, DockerImage +from dags import composer_env +from dags.common.quarantined_tests import QuarantineTests +from dags.common import test_owner +from dags.common.vm_resource import XpkClusters, DockerImage from dags.multipod.configs import gke_config from xlml.utils import name_format diff --git a/dags/multipod/maxtext_gpu_end_to_end.py b/dags/multipod/maxtext_gpu_end_to_end.py index c99256a5..cde04693 100644 --- a/dags/multipod/maxtext_gpu_end_to_end.py +++ b/dags/multipod/maxtext_gpu_end_to_end.py @@ -18,8 +18,9 @@ import datetime from airflow import models from airflow.utils.task_group import TaskGroup -from dags import composer_env, test_owner -from dags.vm_resource import XpkClusters, CpuVersion, DockerImage, GpuVersion, Project, TpuVersion, Zone +from dags import composer_env +from dags.common import test_owner +from dags.common.vm_resource import XpkClusters, CpuVersion, DockerImage, GpuVersion, Project, TpuVersion, Zone from dags.multipod.configs import gke_config from xlml.utils import name_format diff --git a/dags/multipod/maxtext_profiling.py b/dags/multipod/maxtext_profiling.py index ed427779..bb64ec4c 100644 --- a/dags/multipod/maxtext_profiling.py +++ b/dags/multipod/maxtext_profiling.py @@ -17,8 +17,9 @@ """ import datetime from airflow import models -from dags import composer_env, test_owner, gcs_bucket -from dags.vm_resource import TpuVersion, Zone, DockerImage +from dags import composer_env, gcs_bucket +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, DockerImage from dags.multipod.configs import gke_config from dags.multipod.configs.common import SetupMode diff --git a/dags/multipod/maxtext_profiling_vertex_ai_tensorboard.py b/dags/multipod/maxtext_profiling_vertex_ai_tensorboard.py index a8b2996d..eee8ba91 100644 --- a/dags/multipod/maxtext_profiling_vertex_ai_tensorboard.py +++ b/dags/multipod/maxtext_profiling_vertex_ai_tensorboard.py @@ -17,8 +17,9 @@ """ import datetime from airflow import models -from dags import composer_env, test_owner, gcs_bucket -from dags.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters +from dags import composer_env, gcs_bucket +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters from dags.multipod.configs import gke_config from dags.multipod.configs.common import SetupMode diff --git a/dags/multipod/maxtext_trillium_configs_perf.py b/dags/multipod/maxtext_trillium_configs_perf.py index 7fc8a107..d12e3bab 100644 --- a/dags/multipod/maxtext_trillium_configs_perf.py +++ b/dags/multipod/maxtext_trillium_configs_perf.py @@ -18,8 +18,9 @@ import datetime from airflow import models from airflow.utils.task_group import TaskGroup -from dags import composer_env, test_owner -from dags.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage +from dags import composer_env +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage from dags.multipod.configs import maxtext_sweep_gke_config from dags.multipod.configs.common import SetupMode from xlml.apis import metric_config diff --git a/dags/multipod/maxtext_v5e_configs_perf.py b/dags/multipod/maxtext_v5e_configs_perf.py index adf7b4c0..10a9da51 100644 --- a/dags/multipod/maxtext_v5e_configs_perf.py +++ b/dags/multipod/maxtext_v5e_configs_perf.py @@ -18,8 +18,9 @@ import datetime from airflow import models from airflow.utils.task_group import TaskGroup -from dags import composer_env, test_owner -from dags.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage +from dags import composer_env +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage from dags.multipod.configs import maxtext_sweep_gke_config from dags.multipod.configs.common import SetupMode from xlml.apis import metric_config diff --git a/dags/multipod/mxla_collective_nightly.py b/dags/multipod/mxla_collective_nightly.py index c5a44ac8..e9360759 100644 --- a/dags/multipod/mxla_collective_nightly.py +++ b/dags/multipod/mxla_collective_nightly.py @@ -17,7 +17,7 @@ import datetime from airflow import models from dags import composer_env -from dags.vm_resource import TpuVersion, Zone +from dags.common.vm_resource import TpuVersion, Zone from dags.multipod.configs import mxla_collective_config from dags.multipod.configs.common import SetupMode, Platform diff --git a/dags/multipod/mxla_gpt3_6b_nightly_gke.py b/dags/multipod/mxla_gpt3_6b_nightly_gke.py index 64edc164..0b142285 100644 --- a/dags/multipod/mxla_gpt3_6b_nightly_gke.py +++ b/dags/multipod/mxla_gpt3_6b_nightly_gke.py @@ -18,8 +18,9 @@ import datetime from airflow import models from airflow.utils.task_group import TaskGroup -from dags import composer_env, test_owner -from dags.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters, Project +from dags import composer_env +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters, Project from dags.multipod.configs import gke_config # Run once a day at 9 am UTC (1 am PST) diff --git a/dags/multipod/mxla_maxtext_nightly.py b/dags/multipod/mxla_maxtext_nightly.py index 395bf5fa..4018ceb3 100644 --- a/dags/multipod/mxla_maxtext_nightly.py +++ b/dags/multipod/mxla_maxtext_nightly.py @@ -17,7 +17,7 @@ import datetime from airflow import models from dags import composer_env -from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5P_SUBNETWORKS, RuntimeVersion +from dags.common.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5P_SUBNETWORKS, RuntimeVersion from dags.multipod.configs import maxtext_gce_config from dags.multipod.configs.common import SetupMode, Platform diff --git a/dags/multipod/mxla_maxtext_nightly_gke.py b/dags/multipod/mxla_maxtext_nightly_gke.py index f0d4ab60..760a25d2 100644 --- a/dags/multipod/mxla_maxtext_nightly_gke.py +++ b/dags/multipod/mxla_maxtext_nightly_gke.py @@ -18,8 +18,9 @@ import datetime from airflow import models from airflow.utils.task_group import TaskGroup -from dags import composer_env, test_owner -from dags.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters, Project +from dags import composer_env +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters, Project from dags.multipod.configs import gke_config # Run once a day at 9 am UTC (1 am PST) diff --git a/dags/multipod/pytorch.py b/dags/multipod/pytorch.py index 07346a9b..079243dc 100644 --- a/dags/multipod/pytorch.py +++ b/dags/multipod/pytorch.py @@ -17,8 +17,9 @@ """ import datetime from airflow import models -from dags import composer_env, test_owner -from dags.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters +from dags import composer_env +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters from dags.multipod.configs import pytorch_config from xlml.apis import metric_config diff --git a/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py b/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py index 475122c0..f0f9407c 100644 --- a/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py +++ b/dags/pytorch_xla/configs/pytorchxla_torchbench_config.py @@ -18,8 +18,8 @@ import enum from typing import Tuple from xlml.apis import gcp_config, metric_config, task, test_config -import dags.vm_resource as resource -from dags import test_owner +import dags.common.vm_resource as resource +from dags.common import test_owner GCS_SUBFOLDER_PREFIX = test_owner.Team.PYTORCH_XLA.value diff --git a/dags/pytorch_xla/nightly.py b/dags/pytorch_xla/nightly.py index 11c40822..7df8b36c 100644 --- a/dags/pytorch_xla/nightly.py +++ b/dags/pytorch_xla/nightly.py @@ -17,7 +17,7 @@ from airflow import models from xlml.apis import gcp_config, metric_config, task, test_config from dags import composer_env -from dags.vm_resource import Project, Zone, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, V6E_SUBNETWORKS +from dags.common.vm_resource import Project, Zone, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, V6E_SUBNETWORKS # Run once a day at 2 pm UTC (6 am PST) diff --git a/dags/pytorch_xla/pytorchxla-torchbench-release.py b/dags/pytorch_xla/pytorchxla-torchbench-release.py index 722a4d56..41f313c3 100644 --- a/dags/pytorch_xla/pytorchxla-torchbench-release.py +++ b/dags/pytorch_xla/pytorchxla-torchbench-release.py @@ -18,7 +18,7 @@ import datetime from dags import composer_env from dags.pytorch_xla.configs import pytorchxla_torchbench_config as config -import dags.vm_resource as resource +import dags.common.vm_resource as resource SCHEDULED_TIME = None diff --git a/dags/pytorch_xla/pytorchxla2_torchbench.py b/dags/pytorch_xla/pytorchxla2_torchbench.py index 0f432ff5..9ab19fd3 100644 --- a/dags/pytorch_xla/pytorchxla2_torchbench.py +++ b/dags/pytorch_xla/pytorchxla2_torchbench.py @@ -18,7 +18,7 @@ import datetime from dags import composer_env from dags.pytorch_xla.configs import pytorchxla_torchbench_config as config -import dags.vm_resource as resource +import dags.common.vm_resource as resource # Schudule the job to run everyday at 3:00AM PST (11:00AM UTC). SCHEDULED_TIME = "0 11 * * *" if composer_env.is_prod_env() else None diff --git a/dags/pytorch_xla/pytorchxla_torchbench.py b/dags/pytorch_xla/pytorchxla_torchbench.py index 82046c7b..e580c396 100644 --- a/dags/pytorch_xla/pytorchxla_torchbench.py +++ b/dags/pytorch_xla/pytorchxla_torchbench.py @@ -18,7 +18,7 @@ import datetime from dags import composer_env from dags.pytorch_xla.configs import pytorchxla_torchbench_config as config -import dags.vm_resource as resource +import dags.common.vm_resource as resource # Schudule the job to run everyday at 3:00AM PST (11:00AM UTC). SCHEDULED_TIME = "0 11 * * *" if composer_env.is_prod_env() else None diff --git a/dags/pytorch_xla/r2_6.py b/dags/pytorch_xla/r2_6.py index 3dcb54ed..14f68192 100644 --- a/dags/pytorch_xla/r2_6.py +++ b/dags/pytorch_xla/r2_6.py @@ -17,7 +17,7 @@ from airflow import models from xlml.apis import gcp_config, metric_config, task, test_config from dags import composer_env -from dags.vm_resource import Project, Zone, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, V6E_SUBNETWORKS +from dags.common.vm_resource import Project, Zone, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, V6E_SUBNETWORKS # Run once a day at 2 pm UTC (6 am PST) diff --git a/dags/solutions_team/configs/tensorflow/solutionsteam_tf_nightly_supported_config.py b/dags/solutions_team/configs/tensorflow/solutionsteam_tf_nightly_supported_config.py index 3a4d78de..b9e249b0 100644 --- a/dags/solutions_team/configs/tensorflow/solutionsteam_tf_nightly_supported_config.py +++ b/dags/solutions_team/configs/tensorflow/solutionsteam_tf_nightly_supported_config.py @@ -16,11 +16,12 @@ import time import datetime +from dags.common import test_owner from xlml.apis import gcp_config, metric_config, task, test_config -from dags import gcs_bucket, test_owner +from dags import gcs_bucket from dags.solutions_team.configs.tensorflow import common from airflow.models import Variable -from dags.vm_resource import TpuVersion, Project, RuntimeVersion +from dags.common.vm_resource import TpuVersion, Project, RuntimeVersion from typing import List diff --git a/dags/solutions_team/configs/tensorflow/solutionsteam_tf_release_supported_config.py b/dags/solutions_team/configs/tensorflow/solutionsteam_tf_release_supported_config.py index e647327d..87979bdc 100644 --- a/dags/solutions_team/configs/tensorflow/solutionsteam_tf_release_supported_config.py +++ b/dags/solutions_team/configs/tensorflow/solutionsteam_tf_release_supported_config.py @@ -18,11 +18,12 @@ import datetime import time from datetime import date +from dags.common import test_owner from xlml.apis import gcp_config, metric_config, task, test_config -from dags import gcs_bucket, test_owner +from dags import gcs_bucket from dags.solutions_team.configs.tensorflow import common from airflow.models import Variable -from dags.vm_resource import TpuVersion, Project, RuntimeVersion +from dags.common.vm_resource import TpuVersion, Project, RuntimeVersion MAJOR_VERSION = "2" diff --git a/dags/solutions_team/configs/vllm/vllm_benchmark_config.py b/dags/solutions_team/configs/vllm/vllm_benchmark_config.py index b822e1b3..5623428f 100644 --- a/dags/solutions_team/configs/vllm/vllm_benchmark_config.py +++ b/dags/solutions_team/configs/vllm/vllm_benchmark_config.py @@ -21,9 +21,9 @@ from typing import Dict from xlml.apis import gcp_config, metric_config, task, test_config from airflow.models import Variable -from dags import test_owner +from dags.common import test_owner from dags.multipod.configs import common -from dags.vm_resource import MachineVersion, ImageFamily, ImageProject, GpuVersion, TpuVersion, Project, RuntimeVersion, Zone +from dags.common.vm_resource import MachineVersion, ImageFamily, ImageProject, GpuVersion, TpuVersion, Project, RuntimeVersion, Zone PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value diff --git a/dags/solutions_team/solutionsteam_tf_dlrm_benchmarks.py b/dags/solutions_team/solutionsteam_tf_dlrm_benchmarks.py index 3cb04331..b211fabd 100644 --- a/dags/solutions_team/solutionsteam_tf_dlrm_benchmarks.py +++ b/dags/solutions_team/solutionsteam_tf_dlrm_benchmarks.py @@ -18,7 +18,7 @@ import datetime from airflow import models from dags import composer_env -from dags.vm_resource import TpuVersion, Project, Zone, RuntimeVersion, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS +from dags.common.vm_resource import TpuVersion, Project, Zone, RuntimeVersion, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS from dags.solutions_team.configs.tensorflow import solutionsteam_tf_release_supported_config as tf_config from dags.solutions_team.configs.tensorflow import common diff --git a/dags/solutions_team/solutionsteam_tf_nightly_supported.py b/dags/solutions_team/solutionsteam_tf_nightly_supported.py index 5e8c72df..2e7c64e0 100644 --- a/dags/solutions_team/solutionsteam_tf_nightly_supported.py +++ b/dags/solutions_team/solutionsteam_tf_nightly_supported.py @@ -18,7 +18,7 @@ import time from airflow import models from dags import composer_env -from dags.vm_resource import TpuVersion, Project, Zone, RuntimeVersion, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS +from dags.common.vm_resource import TpuVersion, Project, Zone, RuntimeVersion, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS from dags.solutions_team.configs.tensorflow import solutionsteam_tf_nightly_supported_config as tf_config from dags.solutions_team.configs.tensorflow import common diff --git a/dags/solutions_team/solutionsteam_tf_release_se_supported.py b/dags/solutions_team/solutionsteam_tf_release_se_supported.py index dc93eb38..5cecae3a 100644 --- a/dags/solutions_team/solutionsteam_tf_release_se_supported.py +++ b/dags/solutions_team/solutionsteam_tf_release_se_supported.py @@ -18,7 +18,7 @@ import time from airflow import models from dags import composer_env -from dags.vm_resource import TpuVersion, Project, Zone, RuntimeVersion, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS +from dags.common.vm_resource import TpuVersion, Project, Zone, RuntimeVersion, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS from dags.solutions_team.configs.tensorflow import solutionsteam_tf_release_supported_config as tf_config from dags.solutions_team.configs.tensorflow import common diff --git a/dags/solutions_team/solutionsteam_tf_release_supported.py b/dags/solutions_team/solutionsteam_tf_release_supported.py index 10a33bf0..3880a9c7 100644 --- a/dags/solutions_team/solutionsteam_tf_release_supported.py +++ b/dags/solutions_team/solutionsteam_tf_release_supported.py @@ -18,7 +18,7 @@ import time from airflow import models from dags import composer_env -from dags.vm_resource import TpuVersion, Project, Zone, RuntimeVersion, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS +from dags.common.vm_resource import TpuVersion, Project, Zone, RuntimeVersion, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS from dags.solutions_team.configs.tensorflow import solutionsteam_tf_release_supported_config as tf_config from dags.solutions_team.configs.tensorflow import common diff --git a/dags/solutions_team/solutionsteam_vllm_benchmarks.py b/dags/solutions_team/solutionsteam_vllm_benchmarks.py index 6c6d0605..c300a8f4 100644 --- a/dags/solutions_team/solutionsteam_vllm_benchmarks.py +++ b/dags/solutions_team/solutionsteam_vllm_benchmarks.py @@ -4,8 +4,9 @@ import enum from airflow import models from airflow.models.baseoperator import chain -from dags import composer_env, test_owner -from dags.vm_resource import AcceleratorType, GpuVersion, TpuVersion, Region, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, BM_NETWORKS, A100_BM_SUBNETWORKS, ImageProject, ImageFamily, MachineVersion, RuntimeVersion +from dags import composer_env +from dags.common import test_owner +from dags.common.vm_resource import AcceleratorType, GpuVersion, TpuVersion, Region, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, BM_NETWORKS, A100_BM_SUBNETWORKS, ImageProject, ImageFamily, MachineVersion, RuntimeVersion from dags.multipod.configs.common import SetupMode, Platform from dags.solutions_team.configs.vllm import vllm_benchmark_config diff --git a/dags/sparsity_diffusion_devx/configs/gke_config.py b/dags/sparsity_diffusion_devx/configs/gke_config.py index 3c8429bd..8eda3f99 100644 --- a/dags/sparsity_diffusion_devx/configs/gke_config.py +++ b/dags/sparsity_diffusion_devx/configs/gke_config.py @@ -14,10 +14,11 @@ """Utilities to construct configs for solutionsteam_jax_bite DAG.""" +from dags.common import test_owner from xlml.apis import gcp_config, metric_config, task, test_config from xlml.apis.xpk_cluster_config import XpkClusterConfig -from dags import test_owner, gcs_bucket -from dags.vm_resource import TpuVersion, Project, XpkClusters, GpuVersion, CpuVersion, Zone +from dags import gcs_bucket +from dags.common.vm_resource import TpuVersion, Project, XpkClusters, GpuVersion, CpuVersion, Zone from typing import Iterable import datetime diff --git a/dags/sparsity_diffusion_devx/configs/project_bite_config.py b/dags/sparsity_diffusion_devx/configs/project_bite_config.py index 65aff62c..542cb908 100644 --- a/dags/sparsity_diffusion_devx/configs/project_bite_config.py +++ b/dags/sparsity_diffusion_devx/configs/project_bite_config.py @@ -17,10 +17,11 @@ import datetime from typing import Tuple, Optional +from dags.common import test_owner from xlml.apis import gcp_config, metric_config, task, test_config -from dags import gcs_bucket, test_owner +from dags import gcs_bucket from dags.sparsity_diffusion_devx.configs import common -from dags.vm_resource import TpuVersion, Project +from dags.common.vm_resource import TpuVersion, Project from airflow.models.taskmixin import DAGNode diff --git a/dags/sparsity_diffusion_devx/jax_stable_stack_gpu_e2e.py b/dags/sparsity_diffusion_devx/jax_stable_stack_gpu_e2e.py index 6719b364..d353e5cc 100644 --- a/dags/sparsity_diffusion_devx/jax_stable_stack_gpu_e2e.py +++ b/dags/sparsity_diffusion_devx/jax_stable_stack_gpu_e2e.py @@ -17,8 +17,9 @@ import datetime from airflow import models -from dags import composer_env, test_owner, gcs_bucket -from dags.vm_resource import Project, TpuVersion, CpuVersion, Zone, DockerImage, GpuVersion, XpkClusters +from dags import composer_env, gcs_bucket +from dags.common import test_owner +from dags.common.vm_resource import Project, TpuVersion, CpuVersion, Zone, DockerImage, GpuVersion, XpkClusters from airflow.utils.task_group import TaskGroup from dags.sparsity_diffusion_devx.configs import gke_config as config from xlml.utils import name_format diff --git a/dags/sparsity_diffusion_devx/jax_stable_stack_tpu_e2e.py b/dags/sparsity_diffusion_devx/jax_stable_stack_tpu_e2e.py index 572a6c32..761e4261 100644 --- a/dags/sparsity_diffusion_devx/jax_stable_stack_tpu_e2e.py +++ b/dags/sparsity_diffusion_devx/jax_stable_stack_tpu_e2e.py @@ -18,8 +18,9 @@ import datetime from airflow import models from airflow.utils.task_group import TaskGroup -from dags import composer_env, test_owner, gcs_bucket -from dags.vm_resource import Project, TpuVersion, CpuVersion, Zone, DockerImage, GpuVersion, XpkClusters +from dags import composer_env, gcs_bucket +from dags.common import test_owner +from dags.common.vm_resource import Project, TpuVersion, CpuVersion, Zone, DockerImage, GpuVersion, XpkClusters from dags.sparsity_diffusion_devx.configs import gke_config as config from dags.multipod.configs.common import SetupMode from xlml.utils import name_format diff --git a/dags/sparsity_diffusion_devx/maxdiffusion_e2e.py b/dags/sparsity_diffusion_devx/maxdiffusion_e2e.py index 3c9c7311..072ef970 100644 --- a/dags/sparsity_diffusion_devx/maxdiffusion_e2e.py +++ b/dags/sparsity_diffusion_devx/maxdiffusion_e2e.py @@ -18,8 +18,9 @@ import datetime from airflow import models from airflow.utils.task_group import TaskGroup -from dags import composer_env, test_owner, gcs_bucket -from dags.vm_resource import Project, TpuVersion, CpuVersion, Zone, DockerImage, GpuVersion, XpkClusters +from dags import composer_env, gcs_bucket +from dags.common import test_owner +from dags.common.vm_resource import Project, TpuVersion, CpuVersion, Zone, DockerImage, GpuVersion, XpkClusters from dags.sparsity_diffusion_devx.configs import gke_config as config from xlml.utils import name_format diff --git a/dags/sparsity_diffusion_devx/maxtext_moe_tpu_e2e.py b/dags/sparsity_diffusion_devx/maxtext_moe_tpu_e2e.py index 3860eb54..a926abc5 100644 --- a/dags/sparsity_diffusion_devx/maxtext_moe_tpu_e2e.py +++ b/dags/sparsity_diffusion_devx/maxtext_moe_tpu_e2e.py @@ -18,9 +18,10 @@ import datetime from airflow import models from airflow.utils.task_group import TaskGroup -from dags import composer_env, test_owner -from dags.quarantined_tests import QuarantineTests -from dags.vm_resource import XpkClusters, DockerImage +from dags import composer_env +from dags.common.quarantined_tests import QuarantineTests +from dags.common import test_owner +from dags.common.vm_resource import XpkClusters, DockerImage from dags.multipod.configs import gke_config from xlml.utils import name_format diff --git a/dags/sparsity_diffusion_devx/project_bite_gpu_e2e.py b/dags/sparsity_diffusion_devx/project_bite_gpu_e2e.py index 8a4c74ba..7d7809a7 100644 --- a/dags/sparsity_diffusion_devx/project_bite_gpu_e2e.py +++ b/dags/sparsity_diffusion_devx/project_bite_gpu_e2e.py @@ -17,8 +17,9 @@ import datetime from airflow import models -from dags import composer_env, test_owner, gcs_bucket -from dags.vm_resource import DockerImage, XpkClusters +from dags import composer_env, gcs_bucket +from dags.common import test_owner +from dags.common.vm_resource import DockerImage, XpkClusters from dags.sparsity_diffusion_devx.configs import gke_config as config from xlml.utils import name_format diff --git a/dags/sparsity_diffusion_devx/project_bite_tpu_e2e.py b/dags/sparsity_diffusion_devx/project_bite_tpu_e2e.py index 804b2f32..7103cc23 100644 --- a/dags/sparsity_diffusion_devx/project_bite_tpu_e2e.py +++ b/dags/sparsity_diffusion_devx/project_bite_tpu_e2e.py @@ -16,8 +16,9 @@ import datetime from airflow import models -from dags import composer_env, test_owner -from dags.vm_resource import TpuVersion, Zone, RuntimeVersion +from dags import composer_env +from dags.common import test_owner +from dags.common.vm_resource import TpuVersion, Zone, RuntimeVersion from dags.sparsity_diffusion_devx.configs import project_bite_config as config diff --git a/xlml/apis/gcp_config.py b/xlml/apis/gcp_config.py index 501ffcf6..78a96d3a 100644 --- a/xlml/apis/gcp_config.py +++ b/xlml/apis/gcp_config.py @@ -16,7 +16,7 @@ import dataclasses -from dags.vm_resource import Project +from dags.common.vm_resource import Project from xlml.apis import metric_config diff --git a/xlml/apis/task.py b/xlml/apis/task.py index 117bc3c8..5e721cee 100644 --- a/xlml/apis/task.py +++ b/xlml/apis/task.py @@ -18,7 +18,7 @@ import dataclasses import datetime import shlex -from dags.quarantined_tests import QuarantineTests +from dags.common.quarantined_tests import QuarantineTests from typing import Optional, Tuple, Union import airflow from airflow.models.taskmixin import DAGNode diff --git a/xlml/apis/test_config.py b/xlml/apis/test_config.py index ea734a72..e6d724cc 100644 --- a/xlml/apis/test_config.py +++ b/xlml/apis/test_config.py @@ -51,7 +51,7 @@ def __init__(self, accelerator, task_owner=None, test_name): import attrs import datetime -from dags.vm_resource import TpuVersion, CpuVersion +from dags.common.vm_resource import TpuVersion, CpuVersion class Accelerator(abc.ABC): diff --git a/xlml/utils/metric_test.py b/xlml/utils/metric_test.py index 02a01ecd..5cea794d 100644 --- a/xlml/utils/metric_test.py +++ b/xlml/utils/metric_test.py @@ -27,7 +27,7 @@ from xlml.utils import bigquery, composer, metric import jsonlines import tensorflow as tf -from dags.vm_resource import TpuVersion, RuntimeVersion +from dags.common.vm_resource import TpuVersion, RuntimeVersion class BenchmarkMetricTest(parameterized.TestCase, absltest.TestCase): diff --git a/xlml/utils/xpk.py b/xlml/utils/xpk.py index 19f5c6a1..a5317bc4 100644 --- a/xlml/utils/xpk.py +++ b/xlml/utils/xpk.py @@ -24,7 +24,7 @@ from kubernetes import client as k8s_client from xlml.apis import metric_config from xlml.utils import gke -from dags.vm_resource import GpuVersion +from dags.common.vm_resource import GpuVersion WORKLOAD_URL_FORMAT = "https://console.cloud.google.com/kubernetes/service/{region}/{cluster}/default/{workload_id}/details?project={project}"