Skip to content

Commit be7c2de

Browse files
Merge pull request #2431 from AI-Hypercomputer:chzheng/disruption_manager
PiperOrigin-RevId: 829461550
2 parents 69ed0c5 + b65dfb7 commit be7c2de

File tree

6 files changed

+180
-283
lines changed

6 files changed

+180
-283
lines changed

benchmarks/benchmark_utils.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,13 @@
2222
import dataclasses
2323
import typing
2424

25+
from enum import Enum
26+
27+
28+
class Framework(Enum):
  """Closed set of framework identifiers, keyed by their lowercase string value.

  A member can be looked up from its string value, e.g.
  ``Framework("pathways") is Framework.PATHWAYS``.
  """

  # Value used to select the Pathways-specific settings.
  PATHWAYS = "pathways"
  # Value used to select the MCJAX-specific settings.
  MCJAX = "mcjax"
31+
2532

2633
def str2bool(v: str) -> bool:
2734
"""Convert a string of truth to True or False.

benchmarks/disruption_management/disruption_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
MCJAX_STANDARD_TARGET_POD_REGEX_SUFFIX = ".*slice-job-0-0.*"
3434
MCJAX_STANDARD_STEP_POD_REGEX_SUFFIX = ".*slice-job-0-0.*"
3535
PATHWAYS_STANDARD_TARGET_POD_REGEX_SUFFIX = ".*worker-0-0.*"
36-
PATHWAYS_STANDARD_STEP_POD_REGEX_SUFFIX = ".*main-0-0.*"
36+
PATHWAYS_STANDARD_STEP_POD_REGEX_SUFFIX = ".*head-0-0.*"
3737

3838
PATHWAYS_WORKER_CONTAINER_NAME = "pathways-worker"
3939
MCJAX_WORKER_CONTAINER_NAME = "jax-tpu"

benchmarks/disruption_management/disruption_manager.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,18 @@
2424
from collections import defaultdict
2525
import threading
2626

27+
from benchmarks.benchmark_utils import Framework
2728
from benchmarks.disruption_management.disruption_handler import create_disruption_handler
2829
from benchmarks.disruption_management.disruption_handler import DisruptionConfig
2930
from benchmarks.disruption_management.disruption_handler import DisruptionHandler
31+
from benchmarks.disruption_management.disruption_handler import DisruptionMethod
32+
from benchmarks.disruption_management.disruption_handler import MCJAX_STANDARD_TARGET_POD_REGEX_SUFFIX
33+
from benchmarks.disruption_management.disruption_handler import MCJAX_STANDARD_STEP_POD_REGEX_SUFFIX
34+
from benchmarks.disruption_management.disruption_handler import MCJAX_WORKER_CONTAINER_NAME
35+
from benchmarks.disruption_management.disruption_handler import PATHWAYS_STANDARD_TARGET_POD_REGEX_SUFFIX
36+
from benchmarks.disruption_management.disruption_handler import PATHWAYS_STANDARD_STEP_POD_REGEX_SUFFIX
37+
from benchmarks.disruption_management.disruption_handler import PATHWAYS_WORKER_CONTAINER_NAME
38+
from benchmarks.disruption_management.disruption_handler import TriggerType
3039
from benchmarks.disruption_management.monitor import create_monitor
3140
from benchmarks.disruption_management.monitor import Monitor
3241
from benchmarks.xpk_configs import XpkClusterConfig
@@ -131,3 +140,36 @@ def _monitor_and_disrupt_workload(
131140
def _monitor_recovery(self) -> None:
132141
"""Monitors for recovery trigger and initiates recovery."""
133142
raise NotImplementedError("Recovery not implemented yet.")
143+
144+
145+
def construct_disruption_configs(
    framework: str,
    disruption_method: DisruptionMethod,
    disruptions,
) -> list[DisruptionConfig]:
  """Build one DisruptionConfig per (trigger type, trigger value) pair.

  Args:
    framework: Value accepted by ``Framework(...)``. ``Framework.PATHWAYS``
      selects the PATHWAYS_* pod regex / container constants; every other
      valid framework selects the MCJAX_* constants.
    disruption_method: Method stored unchanged on every generated config.
    disruptions: Mapping from a trigger type name to an iterable of trigger
      values, e.g. ``{"time_seconds": [120, 600]}``. The key
      ``"time_seconds"`` maps to ``TriggerType.TIME_SECONDS``.
      NOTE(review): any other key is treated as ``TriggerType.STEP``, so a
      misspelled key is not rejected here.

  Returns:
    The configs in the iteration order of ``disruptions`` and of each value
    list. Each config is named ``"<trigger_value>_<trigger_type>"``.

  Raises:
    ValueError: If ``framework`` is not a valid ``Framework`` value.
  """
  # Resolve the framework first so an invalid value fails before any config
  # is built.
  if Framework(framework) == Framework.PATHWAYS:
    pod_settings = {
        "target_pod_regex": PATHWAYS_STANDARD_TARGET_POD_REGEX_SUFFIX,
        "step_pod_regex": PATHWAYS_STANDARD_STEP_POD_REGEX_SUFFIX,
        "worker_container_name": PATHWAYS_WORKER_CONTAINER_NAME,
    }
  else:
    pod_settings = {
        "target_pod_regex": MCJAX_STANDARD_TARGET_POD_REGEX_SUFFIX,
        "step_pod_regex": MCJAX_STANDARD_STEP_POD_REGEX_SUFFIX,
        "worker_container_name": MCJAX_WORKER_CONTAINER_NAME,
    }

  def _to_trigger_type(type_name):
    """Map a trigger type name to its TriggerType (non-time keys mean STEP)."""
    if type_name == "time_seconds":
      return TriggerType.TIME_SECONDS
    return TriggerType.STEP

  return [
      DisruptionConfig(
          name="_".join((str(value), type_name)),
          trigger_type=_to_trigger_type(type_name),
          trigger_value=value,
          disruption_method=disruption_method,
          **pod_settings,
      )
      for type_name, values in disruptions.items()
      for value in values
  ]

benchmarks/recipes/pw_elastic_training_recipe.py

Lines changed: 26 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -21,149 +21,49 @@
2121
"""
2222

2323
import os
24+
import sys
2425

25-
import args_helper as helper
26+
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
27+
sys.path.append(parent_dir)
28+
from . import args_helper as helper
29+
from . import user_configs
2630

27-
from benchmarks.disruption_management.disruption_handler import DisruptionConfig
2831
from benchmarks.disruption_management.disruption_handler import DisruptionMethod
29-
from benchmarks.disruption_management.disruption_handler import MCJAX_STANDARD_TARGET_POD_REGEX_SUFFIX
30-
from benchmarks.disruption_management.disruption_handler import MCJAX_WORKER_CONTAINER_NAME
31-
from benchmarks.disruption_management.disruption_handler import PATHWAYS_STANDARD_TARGET_POD_REGEX_SUFFIX
32-
from benchmarks.disruption_management.disruption_handler import PATHWAYS_WORKER_CONTAINER_NAME
33-
from benchmarks.disruption_management.disruption_handler import TriggerType
34-
from benchmarks.maxtext_trillium_model_configs import MaxTextModel
35-
from benchmarks import maxtext_v5e_model_configs as v5e_model_configs
36-
from benchmarks import maxtext_xpk_runner as mxr
37-
from benchmarks.xpk_configs import XpkClusterConfig
32+
from .runner_utils import generate_and_run_workloads
3833

39-
PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server"
40-
SERVER_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server"
41-
RUNNER = "us-docker.pkg.dev/path/to/maxtext_runner"
42-
43-
# Cluster Params
44-
CLUSTER = "v6e-256-cluster"
45-
PROJECT = "tpu-prod-env-cluster"
46-
ZONE = "us-east5-b"
47-
COUNTRY = "us"
48-
DEVICE_TYPE = "v6e-256"
49-
50-
# Other parameters (MUST BE SET BY USER)
51-
XPK_PATH = "../xpk" # We're running this script from the maxtext directory
52-
USER = os.environ["USER"]
53-
BASE_OUTPUT_DIRECTORY = f"gs://{USER}-{PROJECT}-{COUNTRY}/disruption_management/"
54-
MAX_RESTARTS = 10
55-
NUM_SLICES = 2
56-
BENCHMARK_STEPS = 101
34+
user_configs.USER_CONFIG.max_restarts = 10
5735
COMPARE_WITH_MCJAX = True
5836

59-
60-
# Do 2 total disruptions, once after 2 minutes and once after 6 minutes.
61-
def construct_disruption_configs(
62-
pathways_config: mxr.PathwaysConfig,
63-
) -> list[DisruptionConfig]:
64-
"""Constructs the disruption configs for the benchmark."""
65-
66-
if pathways_config:
67-
target_pod_regex = PATHWAYS_STANDARD_TARGET_POD_REGEX_SUFFIX
68-
worker_container_name = PATHWAYS_WORKER_CONTAINER_NAME
69-
else:
70-
target_pod_regex = MCJAX_STANDARD_TARGET_POD_REGEX_SUFFIX
71-
worker_container_name = MCJAX_WORKER_CONTAINER_NAME
72-
73-
# Do 2 total disruptions, once after 2 minutes and once after 6 minutes.
74-
return [
75-
DisruptionConfig(
76-
name="sigill_2min",
77-
trigger_type=TriggerType.TIME_SECONDS,
78-
trigger_value=2 * 60, # 2 minutes
79-
disruption_method=DisruptionMethod.SIGILL,
80-
target_pod_regex=target_pod_regex,
81-
worker_container_name=worker_container_name,
82-
),
83-
DisruptionConfig(
84-
name="sigill_6min",
85-
trigger_type=TriggerType.TIME_SECONDS,
86-
trigger_value=6 * 60, # 6 minutes
87-
disruption_method=DisruptionMethod.SIGILL,
88-
target_pod_regex=target_pod_regex,
89-
worker_container_name=worker_container_name,
90-
),
91-
]
92-
93-
94-
def construct_workload_config_with_disruptions(
95-
cluster_config: XpkClusterConfig,
96-
model: MaxTextModel,
97-
pathways_config: mxr.PathwaysConfig = None,
98-
) -> list[mxr.WorkloadConfig]:
99-
"""Constructs the workload configs for the benchmark."""
100-
return mxr.WorkloadConfig(
101-
model=model,
102-
num_slices=NUM_SLICES,
103-
device_type=cluster_config.device_type,
104-
base_output_directory=BASE_OUTPUT_DIRECTORY,
105-
max_restarts=MAX_RESTARTS,
106-
libtpu_type=None,
107-
libtpu_nightly_version="",
108-
base_docker_image=RUNNER,
109-
pathways_config=pathways_config,
110-
xpk_path=XPK_PATH,
111-
num_steps=BENCHMARK_STEPS,
112-
disruption_configs=construct_disruption_configs(pathways_config),
113-
)
37+
DISRUPTION_METHOD = DisruptionMethod.SIGILL
38+
DISRUPTIONS = {
39+
"time_seconds": [120, 600],
40+
# "step":[3]
41+
}
11442

11543

11644
def main() -> None:
11745
"""Main function to run the elastic training disruption test."""
118-
119-
# Cluster Configuration
120-
cluster_config = XpkClusterConfig(
121-
cluster_name=CLUSTER,
122-
project=PROJECT,
123-
zone=ZONE,
124-
device_type=DEVICE_TYPE,
46+
user_configs.USER_CONFIG.headless = False
47+
should_continue = helper.handle_cmd_args(
48+
user_configs.USER_CONFIG.cluster_config, helper.DELETE, xpk_path=user_configs.USER_CONFIG.xpk_path
12549
)
12650

127-
# Handle command line arguments using args_helper
128-
should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, xpk_path=XPK_PATH)
129-
13051
if not should_continue:
131-
return
132-
133-
# Model Configuration - Using a simple default model for testing
134-
model = v5e_model_configs.llama3_1_8b_8192
135-
136-
pathways_config = mxr.PathwaysConfig(
137-
server_image=SERVER_IMAGE,
138-
proxy_server_image=PROXY_IMAGE,
139-
runner_image=RUNNER,
140-
# User can add additional flags here.
141-
server_flags="--enable_metrics_collection=false",
142-
proxy_flags="--enable_metrics_collection=false",
143-
worker_flags="--enable_metrics_collection=false",
52+
return 0
53+
54+
return_code = generate_and_run_workloads(
55+
user_configs.USER_CONFIG,
56+
user_configs.USER_CONFIG.num_slices_list,
57+
user_configs.USER_CONFIG.benchmark_steps,
58+
user_configs.USER_CONFIG.priority,
59+
DISRUPTION_METHOD,
60+
DISRUPTIONS,
14461
)
14562

146-
# Pathways Workload Configuration with Disruption
147-
workload_configs = []
148-
pathways_workload_config = construct_workload_config_with_disruptions(cluster_config, model, pathways_config)
149-
workload_configs.append(pathways_workload_config)
150-
151-
if COMPARE_WITH_MCJAX:
152-
# Add a workload config for MCJAX
153-
mcjax_workload_config = construct_workload_config_with_disruptions(cluster_config, model, None)
154-
workload_configs.append(mcjax_workload_config)
155-
156-
# Run the benchmark and use the returned disruption manager.
157-
disruption_manager = mxr.xpk_benchmark_runner(
158-
cluster_config=cluster_config,
159-
workload_configs=workload_configs,
160-
)
161-
162-
# Wait for disruptions to complete
163-
disruption_manager.start_disruptions_and_wait_for_completion()
164-
16563
print("Elastic Training disruptions completed. Please check logs for results.")
16664

65+
return return_code
66+
16767

16868
if __name__ == "__main__":
16969
main()

0 commit comments

Comments
 (0)