|
21 | 21 | """ |
22 | 22 |
|
23 | 23 | import os |
| 24 | +import sys |
24 | 25 |
|
25 | | -import args_helper as helper |
| 26 | +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) |
| 27 | +sys.path.append(parent_dir) |
| 28 | +from . import args_helper as helper |
| 29 | +from . import user_configs |
26 | 30 |
|
27 | | -from benchmarks.disruption_management.disruption_handler import DisruptionConfig |
28 | 31 | from benchmarks.disruption_management.disruption_handler import DisruptionMethod |
29 | | -from benchmarks.disruption_management.disruption_handler import MCJAX_STANDARD_TARGET_POD_REGEX_SUFFIX |
30 | | -from benchmarks.disruption_management.disruption_handler import MCJAX_WORKER_CONTAINER_NAME |
31 | | -from benchmarks.disruption_management.disruption_handler import PATHWAYS_STANDARD_TARGET_POD_REGEX_SUFFIX |
32 | | -from benchmarks.disruption_management.disruption_handler import PATHWAYS_WORKER_CONTAINER_NAME |
33 | | -from benchmarks.disruption_management.disruption_handler import TriggerType |
34 | | -from benchmarks.maxtext_trillium_model_configs import MaxTextModel |
35 | | -from benchmarks import maxtext_v5e_model_configs as v5e_model_configs |
36 | | -from benchmarks import maxtext_xpk_runner as mxr |
37 | | -from benchmarks.xpk_configs import XpkClusterConfig |
| 32 | +from .runner_utils import generate_and_run_workloads |
38 | 33 |
|
39 | | -PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server" |
40 | | -SERVER_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server" |
41 | | -RUNNER = "us-docker.pkg.dev/path/to/maxtext_runner" |
42 | | - |
43 | | -# Cluster Params |
44 | | -CLUSTER = "v6e-256-cluster" |
45 | | -PROJECT = "tpu-prod-env-cluster" |
46 | | -ZONE = "us-east5-b" |
47 | | -COUNTRY = "us" |
48 | | -DEVICE_TYPE = "v6e-256" |
49 | | - |
50 | | -# Other parameters (MUST BE SET BY USER) |
51 | | -XPK_PATH = "../xpk" # We're running this script from the maxtext directory |
52 | | -USER = os.environ["USER"] |
53 | | -BASE_OUTPUT_DIRECTORY = f"gs://{USER}-{PROJECT}-{COUNTRY}/disruption_management/" |
54 | | -MAX_RESTARTS = 10 |
55 | | -NUM_SLICES = 2 |
56 | | -BENCHMARK_STEPS = 101 |
| 34 | +user_configs.USER_CONFIG.max_restarts = 10 |
57 | 35 | COMPARE_WITH_MCJAX = True |
58 | 36 |
|
59 | | - |
60 | | -# Do 2 total disruptions, once after 2 minutes and once after 6 minutes. |
61 | | -def construct_disruption_configs( |
62 | | - pathways_config: mxr.PathwaysConfig, |
63 | | -) -> list[DisruptionConfig]: |
64 | | - """Constructs the disruption configs for the benchmark.""" |
65 | | - |
66 | | - if pathways_config: |
67 | | - target_pod_regex = PATHWAYS_STANDARD_TARGET_POD_REGEX_SUFFIX |
68 | | - worker_container_name = PATHWAYS_WORKER_CONTAINER_NAME |
69 | | - else: |
70 | | - target_pod_regex = MCJAX_STANDARD_TARGET_POD_REGEX_SUFFIX |
71 | | - worker_container_name = MCJAX_WORKER_CONTAINER_NAME |
72 | | - |
73 | | - # Do 2 total disruptions, once after 2 minutes and once after 6 minutes. |
74 | | - return [ |
75 | | - DisruptionConfig( |
76 | | - name="sigill_2min", |
77 | | - trigger_type=TriggerType.TIME_SECONDS, |
78 | | - trigger_value=2 * 60, # 2 minutes |
79 | | - disruption_method=DisruptionMethod.SIGILL, |
80 | | - target_pod_regex=target_pod_regex, |
81 | | - worker_container_name=worker_container_name, |
82 | | - ), |
83 | | - DisruptionConfig( |
84 | | - name="sigill_6min", |
85 | | - trigger_type=TriggerType.TIME_SECONDS, |
86 | | - trigger_value=6 * 60, # 6 minutes |
87 | | - disruption_method=DisruptionMethod.SIGILL, |
88 | | - target_pod_regex=target_pod_regex, |
89 | | - worker_container_name=worker_container_name, |
90 | | - ), |
91 | | - ] |
92 | | - |
93 | | - |
94 | | -def construct_workload_config_with_disruptions( |
95 | | - cluster_config: XpkClusterConfig, |
96 | | - model: MaxTextModel, |
97 | | - pathways_config: mxr.PathwaysConfig = None, |
98 | | -) -> list[mxr.WorkloadConfig]: |
99 | | - """Constructs the workload configs for the benchmark.""" |
100 | | - return mxr.WorkloadConfig( |
101 | | - model=model, |
102 | | - num_slices=NUM_SLICES, |
103 | | - device_type=cluster_config.device_type, |
104 | | - base_output_directory=BASE_OUTPUT_DIRECTORY, |
105 | | - max_restarts=MAX_RESTARTS, |
106 | | - libtpu_type=None, |
107 | | - libtpu_nightly_version="", |
108 | | - base_docker_image=RUNNER, |
109 | | - pathways_config=pathways_config, |
110 | | - xpk_path=XPK_PATH, |
111 | | - num_steps=BENCHMARK_STEPS, |
112 | | - disruption_configs=construct_disruption_configs(pathways_config), |
113 | | - ) |
| 37 | +DISRUPTION_METHOD = DisruptionMethod.SIGILL |
| 38 | +DISRUPTIONS = { |
| 39 | + "time_seconds": [120, 600], |
| 40 | + # "step":[3] |
| 41 | +} |
114 | 42 |
|
115 | 43 |
|
116 | 44 | def main() -> None: |
117 | 45 | """Main function to run the elastic training disruption test.""" |
118 | | - |
119 | | - # Cluster Configuration |
120 | | - cluster_config = XpkClusterConfig( |
121 | | - cluster_name=CLUSTER, |
122 | | - project=PROJECT, |
123 | | - zone=ZONE, |
124 | | - device_type=DEVICE_TYPE, |
| 46 | + user_configs.USER_CONFIG.headless = False |
| 47 | + should_continue = helper.handle_cmd_args( |
| 48 | + user_configs.USER_CONFIG.cluster_config, helper.DELETE, xpk_path=user_configs.USER_CONFIG.xpk_path |
125 | 49 | ) |
126 | 50 |
|
127 | | - # Handle command line arguments using args_helper |
128 | | - should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, xpk_path=XPK_PATH) |
129 | | - |
130 | 51 | if not should_continue: |
131 | | - return |
132 | | - |
133 | | - # Model Configuration - Using a simple default model for testing |
134 | | - model = v5e_model_configs.llama3_1_8b_8192 |
135 | | - |
136 | | - pathways_config = mxr.PathwaysConfig( |
137 | | - server_image=SERVER_IMAGE, |
138 | | - proxy_server_image=PROXY_IMAGE, |
139 | | - runner_image=RUNNER, |
140 | | - # User can add additional flags here. |
141 | | - server_flags="--enable_metrics_collection=false", |
142 | | - proxy_flags="--enable_metrics_collection=false", |
143 | | - worker_flags="--enable_metrics_collection=false", |
| 52 | + return 0 |
| 53 | + |
| 54 | + return_code = generate_and_run_workloads( |
| 55 | + user_configs.USER_CONFIG, |
| 56 | + user_configs.USER_CONFIG.num_slices_list, |
| 57 | + user_configs.USER_CONFIG.benchmark_steps, |
| 58 | + user_configs.USER_CONFIG.priority, |
| 59 | + DISRUPTION_METHOD, |
| 60 | + DISRUPTIONS, |
144 | 61 | ) |
145 | 62 |
|
146 | | - # Pathways Workload Configuration with Disruption |
147 | | - workload_configs = [] |
148 | | - pathways_workload_config = construct_workload_config_with_disruptions(cluster_config, model, pathways_config) |
149 | | - workload_configs.append(pathways_workload_config) |
150 | | - |
151 | | - if COMPARE_WITH_MCJAX: |
152 | | - # Add a workload config for MCJAX |
153 | | - mcjax_workload_config = construct_workload_config_with_disruptions(cluster_config, model, None) |
154 | | - workload_configs.append(mcjax_workload_config) |
155 | | - |
156 | | - # Run the benchmark and use the returned disruption manager. |
157 | | - disruption_manager = mxr.xpk_benchmark_runner( |
158 | | - cluster_config=cluster_config, |
159 | | - workload_configs=workload_configs, |
160 | | - ) |
161 | | - |
162 | | - # Wait for disruptions to complete |
163 | | - disruption_manager.start_disruptions_and_wait_for_completion() |
164 | | - |
165 | 63 | print("Elastic Training disruptions completed. Please check logs for results.") |
166 | 64 |
|
| 65 | + return return_code |
| 66 | + |
167 | 67 |
|
168 | 68 | if __name__ == "__main__": |
169 | 69 | main() |
0 commit comments