From 7e619b7705004e098cccec9501fcf162c74a470b Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Tue, 10 Dec 2024 04:04:36 +0000 Subject: [PATCH] Add Pathways Support to Benchmark Runner --- benchmarks/benchmark_runner.py | 67 ++++++++++++- benchmarks/maxtext_trillium_model_configs.py | 80 ++++++++++++++++ benchmarks/maxtext_xpk_runner.py | 98 ++++++++++++++------ 3 files changed, 217 insertions(+), 28 deletions(-) diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index 909c36276..50653b8e6 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -28,6 +28,7 @@ from maxtext_xpk_runner import BenchmarkRunner from maxtext_xpk_runner import HWConfig from maxtext_xpk_runner import SWconfig +from maxtext_xpk_runner import PathwaysConfig from maxtext_xpk_runner import xpk_benchmark_runner from maxtext_xpk_runner import XpkConfig @@ -86,6 +87,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_7b_4096', 'llama2_70b_4096', 'llama2_70b_4096_real_data', + 'llama2_70b_4096_pw_long_run', + 'llama2_70b_4096_real_data_pw_long_run', 'llama3_70b_8192', 'llama3_1_405b_8192_fsdp_dcn', 'mixtral_8x7b_dropped', @@ -101,6 +104,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_7b_4096 ' 'llama2_70b_4096 ' 'llama2_70b_4096_real_data ' + 'llama2_70b_4096_pw_long_run ' + 'llama2_70b_4096_real_data_pw_long_run ' 'llama3_1_405b_8192_fsdp_dcn ' 'mixtral_8x7b_dropped ' 'mixtral_8x7b_dropped_int8 ' @@ -122,6 +127,57 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): default='maxtext_base_image', help='version of base docker image to be benchmarked command.', ) + custom_parser.add_argument( + '--pathways_server_image', + type=str, + default=( + 'us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest' + ), + help='version of pathways server image to be benchmarked command.', + ) + custom_parser.add_argument( + '--pathways_proxy_image', + type=str, + default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest', + help='version of pathways proxy image to be benchmarked command.', + ) + custom_parser.add_argument( + '--pathways_runner_image', + type=str, + default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest', + help='version of pathways runner image to be benchmarked command.', + ) + custom_parser.add_argument( + '--use_pathways', + type=bool, + default=False, + help='whether to use pathways or not.', + ) + custom_parser.add_argument( + '--xpk_path', + type=str, + default='~/xpk', + help='path to xpk dir.', + ) + custom_parser.add_argument( + '--num_steps', + type=int, + default='20', + help='Number of steps to run benchmark for.', + ) + custom_parser.add_argument( + '--priority', + type=str, + default='medium', + help='Priority the XPK workload should run with.', + ) + custom_parser.add_argument( + '--max_restarts', + type=int, + default=0, + help='Number of restarts to attempt.', + ) + def main() -> None: parser = argparse.ArgumentParser( @@ -137,11 +193,19 @@ def main() -> None: num_slices=options.num_slices, device_type=options.device_type, base_output_directory=options.base_output_directory, + priority=options.priority, + max_restarts=options.max_restarts, ) v6e_env_configs = SWconfig( base_docker_image=options.base_docker_image, libtpu_version=options.libtpu_version, + pathways_config=PathwaysConfig( + use_pathways=options.use_pathways, + server_image=options.pathways_server_image, + proxy_image=options.pathways_proxy_image, + runner_image=options.pathways_runner_image, + ), ) v6e_256_configs = HWConfig( @@ -155,9 +219,10 @@ def main() -> None: model_name=benchmark_model, software_config=v6e_env_configs, hardware_config=v6e_256_configs, + num_steps=options.num_steps, ) - xpk_benchmark_runner(cluster_config, [model_runner]) + xpk_benchmark_runner(cluster_config, [model_runner], options.xpk_path) if __name__ == '__main__': diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index b3606ec8f..a3df6aabf 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -291,6 +291,45 @@ class MaxTextModel: ), ) + +llama2_70b_4096_real_data_pw_long_run = MaxTextModel( + model_name="llama2-70b-4096-rd-pw-lr", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 7, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "reuse_example_batch": 1, + "profiler": "xplane", + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "tfds", + "tokenizer_path": "assets/tokenizer.llama2", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), +) + # ici_fsdp_transpose_parallelism gives one TFLOP better performance. llama2_70b_4096 = MaxTextModel( model_name="llama2-70b-4096", @@ -320,6 +359,45 @@ class MaxTextModel: ), ) +llama2_70b_4096_pw_long_run = MaxTextModel( + model_name="llama2-70b-4096-pw-lr", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), +) + llama3_8b_8192 = MaxTextModel( model_name="llama3-8b-8192", model_type="llama3-8b", @@ -615,7 +693,9 @@ class MaxTextModel: gpt_3_175b, llama2_7b_4096, llama2_70b_4096, + llama2_70b_4096_pw_long_run, llama2_70b_4096_real_data, + llama2_70b_4096_real_data_pw_long_run, llama3_8b_8192, # Not Optimizied yet llama3_70b_8192, # Not Optimizied yet llama3_1_405b_8192_fsdp_dcn, diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index 8c728da3e..ce5361341 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -49,6 +49,16 @@ class XpkConfig: num_slices: str device_type: str base_output_directory: str + priority: str + max_restarts: int + + +@dataclasses.dataclass +class PathwaysConfig: + use_pathways: bool + server_image: str + proxy_image: str + runner_image: str @dataclasses.dataclass @@ -61,6 +71,7 @@ class HWConfig: class SWconfig: libtpu_version: str base_docker_image: str + pathways_config: PathwaysConfig @dataclasses.dataclass @@ -68,6 +79,7 @@ class BenchmarkRunner: model_name: str hardware_config: HWConfig software_config: SWconfig + num_steps: int def chunks(lst: list, n: int): @@ -257,21 +269,25 @@ def run_command_with_updates(command, task, verbose=True) -> int: def build_user_command( + name: str, model: model_configs.MaxTextModel, num_slices: int, num_steps: int, libtpu_type: LibTpuType, libtpu_date: str, cluster_config: XpkConfig, - base_output_directory: str, + base_output_directory: str, buffer_size: int, + pathways_config: PathwaysConfig = None, ): config_tuning_params = '' for key, value in model.tuning_params.items(): config_tuning_params += f'{key}={value} ' install_libtpu_cmd = '' - if libtpu_type == LibTpuType.NIGHTLY: + if pathways_config.use_pathways: + pass + elif libtpu_type == LibTpuType.NIGHTLY: install_libtpu_cmd += ( f' pip install libtpu-nightly==0.1.dev{libtpu_date} -f' ' https://storage.googleapis.com/libtpu-releases/index.html &&' @@ -288,35 +304,30 @@ def build_user_command( # model.xla_flags += ' --grpc_enable_rpc_receive_coalescing=true' # model.xla_flags += ' --grpc_experiments=tcp_rcv_lowat' + # Use single quotes for LIBTPU_INIT_ARGS and escape inner single quotes libtpu_flags = f"LIBTPU_INIT_ARGS='{model.xla_flags}'" + jax_platforms = 'proxy' if pathways_config.use_pathways else 'tpu,cpu' + vertex_tensorboard = ' vertex_tensorboard_project="" vertex_tensorboard_region=""' if pathways_config.use_pathways else '' - return ( - # f'python3 -m pip install google-cloud-aiplatform==v1.61.0 &&' - # f'pip install -U "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html &&' - # f' pip install https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-0.4.27.dev20240501-cp310-cp310-manylinux2014_x86_64.whl &&' - # f' pip install git+https://github.com/jax-ml/jax.git@57bfe81260545556ec22509347f7ced112496200 &&' - f' {install_libtpu_cmd}' - # f' mv libtpu.so /lib/ &&' - # f' export TPU_LIBRARY_PATH=$PWD/libtpu.so &&' + # Construct the command string with proper formatting and line continuations + command = ( + f'{install_libtpu_cmd}' f' echo {libtpu_flags} &&' - # f' echo {model.tuning_params["sa_block_q"]}-q-dq-{model.tuning_params["sa_block_q_dq"]}-q-dkv-{model.tuning_params["sa_block_q_dkv"]} &&' - # f' echo {model.tuning_params["ici_fsdp_parallelism"]} {model.tuning_params["ici_tensor_parallelism"]} &&' - f' export JAX_PLATFORMS=tpu,cpu &&' - # f' export JAX_DEBUG_NANS=True &&' - # f' export TPU_MEGACORE=megachip_tccontrol &&' - # f' echo TPU MEGACORE: $TPU_MEGACORE &&' + f' export ENABLE_PATHWAYS_PERSISTENCE=1 &&' + f' export JAX_PLATFORMS={jax_platforms} &&' f' export TPU_PREMAPPED_BUFFER_SIZE={buffer_size} &&' f' echo {buffer_size} &&' f' export ENABLE_PJRT_COMPATIBILITY=true &&' - f' export {libtpu_flags} && ' + f' export {libtpu_flags} &&' ' python3 MaxText/train.py MaxText/configs/base.yml' - f' {config_tuning_params} steps={num_steps} enable_checkpointing=false' + f' {config_tuning_params} steps={num_steps}' f' model_name={model.model_type}' f' base_output_directory={base_output_directory}' f' use_vertex_tensorboard=false' - ' vertex_tensorboard_project="" vertex_tensorboard_region=""' - f' run_name="{model.model_name}-{num_slices}-{libtpu_date}"' + f' {vertex_tensorboard}' + f' run_name={name}' ) + return command def generate_xpk_workload_cmd( @@ -327,11 +338,12 @@ def generate_xpk_workload_cmd( libtpu_version: str, base_output_directory: str, buffer_size: int, + num_steps: int = 100, + xpk_path: str = '~/xpk', + pathways_config: PathwaysConfig = None, ): """Generates a command to run a maxstar model on XPK.""" - num_steps = 20 time.localtime() - test_purpose_name = f'maxstar-benchmarks-{model.model_name}-{libtpu_version}' N = 3 temp_post_fix = ''.join( random.choice(string.ascii_lowercase + string.digits) for _ in range(N) @@ -340,7 +352,14 @@ def generate_xpk_workload_cmd( name = ( f"{model.model_name.replace('_', '-')}-{cluster_config.num_slices}-{time.strftime('%m%d%H', time.localtime())}-{temp_post_fix}" ) + if pathways_config.use_pathways: + # Pathways run names are long and need to be shortened. + name = ( + f"pw-{model.model_name.replace('_', '-')}-{cluster_config.num_slices}-{temp_post_fix}" + ) + user_command = build_user_command( + name, model, num_slices, num_steps, @@ -349,6 +368,7 @@ def generate_xpk_workload_cmd( cluster_config, base_output_directory, buffer_size, + pathways_config, ) additional_flags = '' @@ -361,23 +381,40 @@ def generate_xpk_workload_cmd( ' https://raw.githubusercontent.com/GoogleCloudPlatform/ai-on-gke/9ff340f07f70be0130454f9e7238551587242b75/scripts/network-setup/v6e-network-optimization.yaml' ) + # pathways-related flags + pathways_specific_flags = '' + docker_image_flag = f'--base-docker-image="{BASE_DOCKER_IMAGE}"' + if pathways_config.use_pathways: + pathways_specific_flags = ( + '--use-pathways' + f' --server-image={pathways_config.server_image}' + f' --proxy-server-image={pathways_config.proxy_image}' + ' --termination-grace-period-seconds=300' + f' --pathways-gcs-location={base_output_directory}' + f' --restart-on-user-code-failure' + ) + docker_image_flag = ( + f'--docker-image={pathways_config.runner_image}' + ) + print(f'User command: {user_command}') return ( ( # f'{perf_optimzation_dcn} &&' - 'python3 ~/xpk/xpk.py workload create' + f'python3 {xpk_path}/xpk.py workload create' + f' {pathways_specific_flags}' f' --cluster={cluster_config.cluster_name}' f' --project={cluster_config.project}' f' --zone={cluster_config.zone}' f' --device-type={cluster_config.device_type}' f' --num-slices={cluster_config.num_slices}' f' --command="{user_command}"' - f' --base-docker-image="{BASE_DOCKER_IMAGE}"' + f' {docker_image_flag}' ' --enable-debug-logs' f' --workload={name}' - ' --priority=medium' + f' --priority={cluster_config.priority}' + f' --max-restarts={cluster_config.max_restarts}' # ' --use-vertex-tensorboard' - # f' --experiment-name={test_purpose_name}' f' {additional_flags}' ), name, @@ -406,7 +443,7 @@ def run_xpk_workload( return run_command_with_updates(command, 'Run XPK workload', cluster_config) -def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRunner]): +def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRunner], xpk_path: str): xpk_workload_names = [] xpk_workload_cmds = [] for benchmark in benchmarks: @@ -418,8 +455,15 @@ def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRu libtpu_version=benchmark.software_config.libtpu_version, base_output_directory=cluster_config.base_output_directory, buffer_size=4294967296, + num_steps=benchmark.num_steps, + xpk_path=xpk_path, + pathways_config=benchmark.software_config.pathways_config, ) + + print(f"name of the workload is: {name}") xpk_workload_names.append(name) + + print(f"XPK command to be used is: {command}") xpk_workload_cmds.append(command) returncodes = run_commands(