From 7e619b7705004e098cccec9501fcf162c74a470b Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Tue, 10 Dec 2024 04:04:36 +0000 Subject: [PATCH 01/16] Add Pathways Support to Benchmark Runner --- benchmarks/benchmark_runner.py | 67 ++++++++++++- benchmarks/maxtext_trillium_model_configs.py | 80 ++++++++++++++++ benchmarks/maxtext_xpk_runner.py | 98 ++++++++++++++------ 3 files changed, 217 insertions(+), 28 deletions(-) diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index 909c36276..50653b8e6 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -28,6 +28,7 @@ from maxtext_xpk_runner import BenchmarkRunner from maxtext_xpk_runner import HWConfig from maxtext_xpk_runner import SWconfig +from maxtext_xpk_runner import PathwaysConfig from maxtext_xpk_runner import xpk_benchmark_runner from maxtext_xpk_runner import XpkConfig @@ -86,6 +87,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_7b_4096', 'llama2_70b_4096', 'llama2_70b_4096_real_data', + 'llama2_70b_4096_pw_long_run', + 'llama2_70b_4096_real_data_pw_long_run', 'llama3_70b_8192', 'llama3_1_405b_8192_fsdp_dcn', 'mixtral_8x7b_dropped', @@ -101,6 +104,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_7b_4096 ' 'llama2_70b_4096 ' 'llama2_70b_4096_real_data ' + 'llama2_70b_4096_pw_long_run ' + 'llama2_70b_4096_real_data_pw_long_run ' 'llama3_1_405b_8192_fsdp_dcn ' 'mixtral_8x7b_dropped ' 'mixtral_8x7b_dropped_int8 ' @@ -122,6 +127,57 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): default='maxtext_base_image', help='version of base docker image to be benchmarked command.', ) + custom_parser.add_argument( + '--pathways_server_image', + type=str, + default=( + 'us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest' + ), + help='version of pathways server image to be benchmarked command.', + ) + custom_parser.add_argument( + '--pathways_proxy_image', + type=str, + default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest', + help='version of pathways proxy image to be benchmarked command.', + ) + custom_parser.add_argument( + '--pathways_runner_image', + type=str, + default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest', + help='version of pathways runner image to be benchmarked command.', + ) + custom_parser.add_argument( + '--use_pathways', + type=bool, + default=False, + help='whether to use pathways or not.', + ) + custom_parser.add_argument( + '--xpk_path', + type=str, + default='~/xpk', + help='path to xpk dir.', + ) + custom_parser.add_argument( + '--num_steps', + type=int, + default='20', + help='Number of steps to run benchmark for.', + ) + custom_parser.add_argument( + '--priority', + type=str, + default='medium', + help='Priority the XPK workload should run with.', + ) + custom_parser.add_argument( + '--max_restarts', + type=int, + default=0, + help='Number of restarts to attempt.', + ) + def main() -> None: parser = argparse.ArgumentParser( @@ -137,11 +193,19 @@ def main() -> None: num_slices=options.num_slices, device_type=options.device_type, base_output_directory=options.base_output_directory, + priority=options.priority, + max_restarts=options.max_restarts, ) v6e_env_configs = SWconfig( base_docker_image=options.base_docker_image, libtpu_version=options.libtpu_version, + pathways_config=PathwaysConfig( + use_pathways=options.use_pathways, + server_image=options.pathways_server_image, + 
proxy_image=options.pathways_proxy_image, + runner_image=options.pathways_runner_image, + ), ) v6e_256_configs = HWConfig( @@ -155,9 +219,10 @@ def main() -> None: model_name=benchmark_model, software_config=v6e_env_configs, hardware_config=v6e_256_configs, + num_steps=options.num_steps, ) - xpk_benchmark_runner(cluster_config, [model_runner]) + xpk_benchmark_runner(cluster_config, [model_runner], options.xpk_path) if __name__ == '__main__': diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index b3606ec8f..a3df6aabf 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -291,6 +291,45 @@ class MaxTextModel: ), ) + +llama2_70b_4096_real_data_pw_long_run = MaxTextModel( + model_name="llama2-70b-4096-rd-pw-lr", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 7, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "reuse_example_batch": 1, + "profiler": "xplane", + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "tfds", + "tokenizer_path": "assets/tokenizer.llama2", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), +) + # ici_fsdp_transpose_parallelism gives one TFLOP better performance. llama2_70b_4096 = MaxTextModel( model_name="llama2-70b-4096", @@ -320,6 +359,45 @@ class MaxTextModel: ), ) +llama2_70b_4096_pw_long_run = MaxTextModel( + model_name="llama2-70b-4096-pw-lr", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + + # Additional tuning params for pathways long running test. 
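+        # These mirror the real-data long-running variant above: checkpoint
+        # every 100 steps, upload goodput metrics every 30 seconds, and run
+        # single-controller so multi-day Pathways jobs stay observable.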
+ "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), +) + llama3_8b_8192 = MaxTextModel( model_name="llama3-8b-8192", model_type="llama3-8b", @@ -615,7 +693,9 @@ class MaxTextModel: gpt_3_175b, llama2_7b_4096, llama2_70b_4096, + llama2_70b_4096_pw_long_run, llama2_70b_4096_real_data, + llama2_70b_4096_real_data_pw_long_run, llama3_8b_8192, # Not Optimizied yet llama3_70b_8192, # Not Optimizied yet llama3_1_405b_8192_fsdp_dcn, diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index 8c728da3e..ce5361341 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -49,6 +49,16 @@ class XpkConfig: num_slices: str device_type: str base_output_directory: str + priority: str + max_restarts: int + + +@dataclasses.dataclass +class PathwaysConfig: + use_pathways: bool + server_image: str + proxy_image: str + runner_image: str @dataclasses.dataclass @@ -61,6 +71,7 @@ class HWConfig: class SWconfig: libtpu_version: str base_docker_image: str + pathways_config: PathwaysConfig @dataclasses.dataclass @@ -68,6 +79,7 @@ class BenchmarkRunner: model_name: str hardware_config: HWConfig software_config: SWconfig + num_steps: int def chunks(lst: list, n: int): @@ -257,21 +269,25 @@ def run_command_with_updates(command, task, verbose=True) -> int: def build_user_command( + name: str, model: model_configs.MaxTextModel, num_slices: int, num_steps: int, libtpu_type: LibTpuType, libtpu_date: str, cluster_config: XpkConfig, - base_output_directory: str, + base_output_directory: str, buffer_size: int, + pathways_config: PathwaysConfig = None, ): config_tuning_params = '' for key, value in model.tuning_params.items(): config_tuning_params += f'{key}={value} ' install_libtpu_cmd = '' - if libtpu_type == LibTpuType.NIGHTLY: + if pathways_config.use_pathways: + pass + elif libtpu_type == LibTpuType.NIGHTLY: install_libtpu_cmd += ( f' pip install libtpu-nightly==0.1.dev{libtpu_date} -f' ' https://storage.googleapis.com/libtpu-releases/index.html &&' @@ -288,35 +304,30 @@ def build_user_command( # model.xla_flags += ' --grpc_enable_rpc_receive_coalescing=true' # model.xla_flags += ' --grpc_experiments=tcp_rcv_lowat' + # Use single quotes for LIBTPU_INIT_ARGS and escape inner single quotes libtpu_flags = f"LIBTPU_INIT_ARGS='{model.xla_flags}'" + jax_platforms = 'proxy' if pathways_config.use_pathways else 'tpu,cpu' + vertex_tensorboard = ' vertex_tensorboard_project="" vertex_tensorboard_region=""' if pathways_config.use_pathways else '' - return ( - # f'python3 -m pip install google-cloud-aiplatform==v1.61.0 &&' - # f'pip install -U "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html &&' - # f' pip install https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-0.4.27.dev20240501-cp310-cp310-manylinux2014_x86_64.whl &&' - # f' pip install git+https://github.com/jax-ml/jax.git@57bfe81260545556ec22509347f7ced112496200 &&' - f' {install_libtpu_cmd}' - # f' mv libtpu.so /lib/ &&' - # f' export TPU_LIBRARY_PATH=$PWD/libtpu.so &&' + # Construct the command string with proper formatting and line continuations + command = ( 
+ f'{install_libtpu_cmd}' f' echo {libtpu_flags} &&' - # f' echo {model.tuning_params["sa_block_q"]}-q-dq-{model.tuning_params["sa_block_q_dq"]}-q-dkv-{model.tuning_params["sa_block_q_dkv"]} &&' - # f' echo {model.tuning_params["ici_fsdp_parallelism"]} {model.tuning_params["ici_tensor_parallelism"]} &&' - f' export JAX_PLATFORMS=tpu,cpu &&' - # f' export JAX_DEBUG_NANS=True &&' - # f' export TPU_MEGACORE=megachip_tccontrol &&' - # f' echo TPU MEGACORE: $TPU_MEGACORE &&' + f' export ENABLE_PATHWAYS_PERSISTENCE=1 &&' + f' export JAX_PLATFORMS={jax_platforms} &&' f' export TPU_PREMAPPED_BUFFER_SIZE={buffer_size} &&' f' echo {buffer_size} &&' f' export ENABLE_PJRT_COMPATIBILITY=true &&' - f' export {libtpu_flags} && ' + f' export {libtpu_flags} &&' ' python3 MaxText/train.py MaxText/configs/base.yml' - f' {config_tuning_params} steps={num_steps} enable_checkpointing=false' + f' {config_tuning_params} steps={num_steps}' f' model_name={model.model_type}' f' base_output_directory={base_output_directory}' f' use_vertex_tensorboard=false' - ' vertex_tensorboard_project="" vertex_tensorboard_region=""' - f' run_name="{model.model_name}-{num_slices}-{libtpu_date}"' + f' {vertex_tensorboard}' + f' run_name={name}' ) + return command def generate_xpk_workload_cmd( @@ -327,11 +338,12 @@ def generate_xpk_workload_cmd( libtpu_version: str, base_output_directory: str, buffer_size: int, + num_steps: int = 100, + xpk_path: str = '~/xpk', + pathways_config: PathwaysConfig = None, ): """Generates a command to run a maxstar model on XPK.""" - num_steps = 20 time.localtime() - test_purpose_name = f'maxstar-benchmarks-{model.model_name}-{libtpu_version}' N = 3 temp_post_fix = ''.join( random.choice(string.ascii_lowercase + string.digits) for _ in range(N) @@ -340,7 +352,14 @@ def generate_xpk_workload_cmd( name = ( f"{model.model_name.replace('_', '-')}-{cluster_config.num_slices}-{time.strftime('%m%d%H', time.localtime())}-{temp_post_fix}" ) + if pathways_config.use_pathways: + # Pathways run names are long and need to be shortened. 
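+    # The short form drops the timestamp and keeps the random suffix, so
+    # the generated workload name stays within the name-length limit.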
+ name = ( + f"pw-{model.model_name.replace('_', '-')}-{cluster_config.num_slices}-{temp_post_fix}" + ) + user_command = build_user_command( + name, model, num_slices, num_steps, @@ -349,6 +368,7 @@ def generate_xpk_workload_cmd( cluster_config, base_output_directory, buffer_size, + pathways_config, ) additional_flags = '' @@ -361,23 +381,40 @@ def generate_xpk_workload_cmd( ' https://raw.githubusercontent.com/GoogleCloudPlatform/ai-on-gke/9ff340f07f70be0130454f9e7238551587242b75/scripts/network-setup/v6e-network-optimization.yaml' ) + # pathways-related flags + pathways_specific_flags = '' + docker_image_flag = f'--base-docker-image="{BASE_DOCKER_IMAGE}"' + if pathways_config.use_pathways: + pathways_specific_flags = ( + '--use-pathways' + f' --server-image={pathways_config.server_image}' + f' --proxy-server-image={pathways_config.proxy_image}' + ' --termination-grace-period-seconds=300' + f' --pathways-gcs-location={base_output_directory}' + f' --restart-on-user-code-failure' + ) + docker_image_flag = ( + f'--docker-image={pathways_config.runner_image}' + ) + print(f'User command: {user_command}') return ( ( # f'{perf_optimzation_dcn} &&' - 'python3 ~/xpk/xpk.py workload create' + f'python3 {xpk_path}/xpk.py workload create' + f' {pathways_specific_flags}' f' --cluster={cluster_config.cluster_name}' f' --project={cluster_config.project}' f' --zone={cluster_config.zone}' f' --device-type={cluster_config.device_type}' f' --num-slices={cluster_config.num_slices}' f' --command="{user_command}"' - f' --base-docker-image="{BASE_DOCKER_IMAGE}"' + f' {docker_image_flag}' ' --enable-debug-logs' f' --workload={name}' - ' --priority=medium' + f' --priority={cluster_config.priority}' + f' --max-restarts={cluster_config.max_restarts}' # ' --use-vertex-tensorboard' - # f' --experiment-name={test_purpose_name}' f' {additional_flags}' ), name, @@ -406,7 +443,7 @@ def run_xpk_workload( return run_command_with_updates(command, 'Run XPK workload', cluster_config) -def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRunner]): +def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRunner], xpk_path: str): xpk_workload_names = [] xpk_workload_cmds = [] for benchmark in benchmarks: @@ -418,8 +455,15 @@ def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRu libtpu_version=benchmark.software_config.libtpu_version, base_output_directory=cluster_config.base_output_directory, buffer_size=4294967296, + num_steps=benchmark.num_steps, + xpk_path=xpk_path, + pathways_config=benchmark.software_config.pathways_config, ) + + print(f"name of the workload is: {name}") xpk_workload_names.append(name) + + print(f"XPK command to be used is: {command}") xpk_workload_cmds.append(command) returncodes = run_commands( From 988b0a18531873b49dcbe028f63d73f06b744f6c Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Fri, 13 Dec 2024 00:33:22 +0000 Subject: [PATCH 02/16] Minor fixes --- benchmarks/maxtext_xpk_runner.py | 37 ++++++++++++++++---------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index ce5361341..e326b0522 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -307,26 +307,26 @@ def build_user_command( # Use single quotes for LIBTPU_INIT_ARGS and escape inner single quotes libtpu_flags = f"LIBTPU_INIT_ARGS='{model.xla_flags}'" jax_platforms = 'proxy' if pathways_config.use_pathways else 'tpu,cpu' - vertex_tensorboard = ' 
vertex_tensorboard_project="" vertex_tensorboard_region=""' if pathways_config.use_pathways else '' + vertex_tensorboard = 'vertex_tensorboard_project="" vertex_tensorboard_region=""' if pathways_config.use_pathways else '' # Construct the command string with proper formatting and line continuations - command = ( - f'{install_libtpu_cmd}' - f' echo {libtpu_flags} &&' - f' export ENABLE_PATHWAYS_PERSISTENCE=1 &&' - f' export JAX_PLATFORMS={jax_platforms} &&' - f' export TPU_PREMAPPED_BUFFER_SIZE={buffer_size} &&' - f' echo {buffer_size} &&' - f' export ENABLE_PJRT_COMPATIBILITY=true &&' - f' export {libtpu_flags} &&' - ' python3 MaxText/train.py MaxText/configs/base.yml' - f' {config_tuning_params} steps={num_steps}' - f' model_name={model.model_type}' - f' base_output_directory={base_output_directory}' - f' use_vertex_tensorboard=false' - f' {vertex_tensorboard}' - f' run_name={name}' - ) + command = ' '.join([ + f'{install_libtpu_cmd}', + f'echo {libtpu_flags} &&' if not pathways_config.use_pathways else '', + f'export {libtpu_flags} &&' if not pathways_config.use_pathways else '', + 'export ENABLE_PATHWAYS_PERSISTENCE=1 &&', + f'export JAX_PLATFORMS={jax_platforms} &&', + f'export TPU_PREMAPPED_BUFFER_SIZE={buffer_size} &&', + f'echo {buffer_size} &&', + 'export ENABLE_PJRT_COMPATIBILITY=true &&', + 'python3 MaxText/train.py MaxText/configs/base.yml', + f'{config_tuning_params}steps={num_steps}', + f'model_name={model.model_type}', + f'base_output_directory={base_output_directory}', + 'use_vertex_tensorboard=false', + f'{vertex_tensorboard}', + f'run_name={name}' + ]) return command @@ -392,6 +392,7 @@ def generate_xpk_workload_cmd( ' --termination-grace-period-seconds=300' f' --pathways-gcs-location={base_output_directory}' f' --restart-on-user-code-failure' + f' --debug-dump-gcs={base_output_directory}' ) docker_image_flag = ( f'--docker-image={pathways_config.runner_image}' From 9a086ea0323d12bbfe1c64e76c2a376b796b4f21 Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Fri, 13 Dec 2024 01:01:24 +0000 Subject: [PATCH 03/16] Update PW Long Running Config --- benchmarks/maxtext_trillium_model_configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index a3df6aabf..8a2c0c415 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -296,14 +296,14 @@ class MaxTextModel: model_name="llama2-70b-4096-rd-pw-lr", model_type="llama2-70b", tuning_params={ - "per_device_batch_size": 7, + "per_device_batch_size": 4, "ici_fsdp_parallelism": -1, "remat_policy": "full", "max_target_length": 4096, "attention": "flash", "gcs_metrics": True, "use_iota_embed": True, - "reuse_example_batch": 1, + "reuse_example_batch": 0, "profiler": "xplane", "dataset_path": "gs://max-datasets-rogue", "dataset_type": "tfds", From ac4e74076a98a38e01abcbcc8ffbf91b45a1f00c Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Fri, 13 Dec 2024 01:21:16 +0000 Subject: [PATCH 04/16] Resolve Comments --- benchmarks/maxtext_xpk_runner.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index e326b0522..36577d45c 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -308,6 +308,7 @@ def build_user_command( libtpu_flags = f"LIBTPU_INIT_ARGS='{model.xla_flags}'" jax_platforms = 'proxy' if pathways_config.use_pathways else 
'tpu,cpu' vertex_tensorboard = 'vertex_tensorboard_project="" vertex_tensorboard_region=""' if pathways_config.use_pathways else '' + steps_config = f'steps={num_steps}' if num_steps else '' # default to whatever the config has. # Construct the command string with proper formatting and line continuations command = ' '.join([ @@ -320,7 +321,7 @@ def build_user_command( f'echo {buffer_size} &&', 'export ENABLE_PJRT_COMPATIBILITY=true &&', 'python3 MaxText/train.py MaxText/configs/base.yml', - f'{config_tuning_params}steps={num_steps}', + f'{config_tuning_params}{steps_config}', f'model_name={model.model_type}', f'base_output_directory={base_output_directory}', 'use_vertex_tensorboard=false', @@ -338,11 +339,16 @@ def generate_xpk_workload_cmd( libtpu_version: str, base_output_directory: str, buffer_size: int, - num_steps: int = 100, + num_steps: int, xpk_path: str = '~/xpk', pathways_config: PathwaysConfig = None, ): """Generates a command to run a maxstar model on XPK.""" + if not num_steps and 'steps' in model.tuning_params: + num_steps = model.tuning_params['steps'] + elif not num_steps: + num_steps = 20 + time.localtime() N = 3 temp_post_fix = ''.join( @@ -444,7 +450,11 @@ def run_xpk_workload( return run_command_with_updates(command, 'Run XPK workload', cluster_config) -def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRunner], xpk_path: str): +def xpk_benchmark_runner( + cluster_config: XpkConfig, + benchmarks: list[BenchmarkRunner], + xpk_path: str = '~/xpk', +): xpk_workload_names = [] xpk_workload_cmds = [] for benchmark in benchmarks: From eac72e37fc13d52448743464b24d293fc9a3d498 Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Fri, 13 Dec 2024 01:23:52 +0000 Subject: [PATCH 05/16] Remove default for num_steps --- benchmarks/benchmark_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index 50653b8e6..dd28ed582 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -162,7 +162,6 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): custom_parser.add_argument( '--num_steps', type=int, - default='20', help='Number of steps to run benchmark for.', ) custom_parser.add_argument( From 6a9efdbfe9b1016ea246120cfe0f071ad93cf8ab Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Fri, 13 Dec 2024 01:30:23 +0000 Subject: [PATCH 06/16] Minor change --- benchmarks/maxtext_xpk_runner.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index 36577d45c..20b60e1af 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -282,6 +282,10 @@ def build_user_command( ): config_tuning_params = '' for key, value in model.tuning_params.items(): + # If the user provides a number of steps use that, otherwise, + # use the tuning params value. + if key == 'steps' and num_steps: + value = num_steps config_tuning_params += f'{key}={value} ' install_libtpu_cmd = '' @@ -308,7 +312,6 @@ def build_user_command( libtpu_flags = f"LIBTPU_INIT_ARGS='{model.xla_flags}'" jax_platforms = 'proxy' if pathways_config.use_pathways else 'tpu,cpu' vertex_tensorboard = 'vertex_tensorboard_project="" vertex_tensorboard_region=""' if pathways_config.use_pathways else '' - steps_config = f'steps={num_steps}' if num_steps else '' # default to whatever the config has. 
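+    # 'steps' is now emitted by the tuning-params loop above, so the
+    # separate steps_config segment is no longer needed.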
# Construct the command string with proper formatting and line continuations command = ' '.join([ @@ -321,7 +324,7 @@ def build_user_command( f'echo {buffer_size} &&', 'export ENABLE_PJRT_COMPATIBILITY=true &&', 'python3 MaxText/train.py MaxText/configs/base.yml', - f'{config_tuning_params}{steps_config}', + f'{config_tuning_params}', f'model_name={model.model_type}', f'base_output_directory={base_output_directory}', 'use_vertex_tensorboard=false', @@ -344,11 +347,6 @@ def generate_xpk_workload_cmd( pathways_config: PathwaysConfig = None, ): """Generates a command to run a maxstar model on XPK.""" - if not num_steps and 'steps' in model.tuning_params: - num_steps = model.tuning_params['steps'] - elif not num_steps: - num_steps = 20 - time.localtime() N = 3 temp_post_fix = ''.join( From ad820315ada1f3c0d697412ee6b3af96d850a000 Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Fri, 13 Dec 2024 01:32:59 +0000 Subject: [PATCH 07/16] More code --- benchmarks/maxtext_xpk_runner.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index 20b60e1af..799dbdac6 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -281,13 +281,18 @@ def build_user_command( pathways_config: PathwaysConfig = None, ): config_tuning_params = '' + steps_set = False for key, value in model.tuning_params.items(): # If the user provides a number of steps use that, otherwise, # use the tuning params value. if key == 'steps' and num_steps: value = num_steps + steps_set = True config_tuning_params += f'{key}={value} ' + if not steps_set: + config_tuning_params += f'steps=20' + install_libtpu_cmd = '' if pathways_config.use_pathways: pass From 4f7df81095670dc20f758163a4ca10a4972fe719 Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Fri, 13 Dec 2024 01:38:56 +0000 Subject: [PATCH 08/16] fix logic --- benchmarks/maxtext_xpk_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index 799dbdac6..efbbf5b55 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -285,8 +285,9 @@ def build_user_command( for key, value in model.tuning_params.items(): # If the user provides a number of steps use that, otherwise, # use the tuning params value. - if key == 'steps' and num_steps: - value = num_steps + if key == 'steps': + if num_steps: + value = num_steps steps_set = True config_tuning_params += f'{key}={value} ' From 28c2daab0353c630b49426d51e460c8175b68e46 Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Fri, 13 Dec 2024 01:40:23 +0000 Subject: [PATCH 09/16] fix logic --- benchmarks/maxtext_xpk_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index efbbf5b55..3c9769755 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -292,7 +292,7 @@ def build_user_command( config_tuning_params += f'{key}={value} ' if not steps_set: - config_tuning_params += f'steps=20' + config_tuning_params += f'steps=20 ' install_libtpu_cmd = '' if pathways_config.use_pathways: From 386139f24f398119c76aceb04032fda9e1ce3508 Mon Sep 17 00:00:00 2001 From: RoshaniN Date: Fri, 13 Dec 2024 18:32:05 +0000 Subject: [PATCH 10/16] Added a new config for synthetic run. 
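Adds llama2_70b_4096_sc_synthetic, a synthetic-data run using the
qkv_proj_offloaded remat policy, plus llama2_70b_4096_sc_synthetic_pw_lr,
which layers the Pathways long-running checkpoint and goodput settings on
top of the same base tuning params.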
--- benchmarks/maxtext_trillium_model_configs.py | 64 ++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 8a2c0c415..e2983be8c 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -358,6 +358,70 @@ class MaxTextModel: + xla_flags_library.CF_FOR_ALL_GATHER ), ) +llama2_70b_4096_sc_synthetic = MaxTextModel( +model_name="llama2-70b-4096", +model_type="llama2-70b", +tuning_params={ +"per_device_batch_size": 2, +"ici_fsdp_parallelism": 1, +"ici_fsdp_transpose_parallelism": -1, +"ici_tensor_parallelism": 1, +"remat_policy": "qkv_proj_offloaded", +"max_target_length": 4096, +"attention": "flash", +"gcs_metrics": True, +"use_iota_embed": True, +"dataset_path": "gs://max-datasets-rogue", +"dataset_type": "synthetic", +"enable_checkpointing": False, +"profiler": "xplane", +"sa_block_q": 1024, +"sa_block_q_dkv": 2048, +"sa_block_q_dq": 2048, +}, +xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER +), +) + +llama2_70b_4096_sc_synthetic_pw_lr = MaxTextModel( +model_name="llama2-70b-4096-pw-lr", +model_type="llama2-70b", +tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + # "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, +}, +xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER +), +) llama2_70b_4096_pw_long_run = MaxTextModel( model_name="llama2-70b-4096-pw-lr", From 0aba7a8a4d125f507f17648c1bb7ea6cc5e335d2 Mon Sep 17 00:00:00 2001 From: RoshaniN Date: Fri, 13 Dec 2024 19:41:43 +0000 Subject: [PATCH 11/16] Allow nightly JAX VERSION also. --- setup.sh | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/setup.sh b/setup.sh index 4b4ba090e..724597be4 100644 --- a/setup.sh +++ b/setup.sh @@ -60,10 +60,10 @@ if [[ $LIBTPU_GCS_PATH == NONE ]]; then unset LIBTPU_GCS_PATH fi -if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE || ($MODE == "nightly" && $DEVICE == "gpu")) ]]; then - echo -e "\n\nError: You can only specify a JAX_VERSION with stable mode (plus nightly mode on GPU).\n\n" - exit 1 -fi +# if [[ -n $JAX_VERSION && ! 
($MODE == "stable" || -z $MODE || ($MODE == "nightly" && $DEVICE == "gpu")) ]]; then +# echo -e "\n\nError: You can only specify a JAX_VERSION with stable mode (plus nightly mode on GPU).\n\n" +# exit 1 +# fi if [[ $DEVICE == "tpu" ]]; then libtpu_path="$HOME/custom_libtpu/libtpu.so" @@ -170,12 +170,17 @@ elif [[ $MODE == "nightly" ]]; then export NVTE_FRAMEWORK=jax pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable elif [[ $DEVICE == "tpu" ]]; then - echo "Installing jax-nightly, jaxlib-nightly" # Install jax-nightly - pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - # Install jaxlib-nightly - pip3 install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - + if [[ -n "$JAX_VERSION" ]]; then + echo "Installing jax-nightly, jaxlib-nightly ${JAX_VERSION}" + pip install -U --pre jax==${JAX_VERSION} jaxlib==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + else + echo "Installing jax-nightly, jaxlib-nightly" + # Install jax-nightly + pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + # Install jaxlib-nightly + pip3 install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + fi if [[ -n "$LIBTPU_GCS_PATH" ]]; then # Install custom libtpu echo "Installing libtpu.so from $LIBTPU_GCS_PATH to $libtpu_path" From a216153512b8d0901f01bd5e2bfd67debbfe660c Mon Sep 17 00:00:00 2001 From: RoshaniN Date: Fri, 13 Dec 2024 19:57:49 +0000 Subject: [PATCH 12/16] Update mode name. --- benchmarks/maxtext_trillium_model_configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index e2983be8c..be19cd494 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -359,7 +359,7 @@ class MaxTextModel: ), ) llama2_70b_4096_sc_synthetic = MaxTextModel( -model_name="llama2-70b-4096", +model_name="llama2_70b_4096_sc_synthetic", model_type="llama2-70b", tuning_params={ "per_device_batch_size": 2, @@ -386,7 +386,7 @@ class MaxTextModel: ) llama2_70b_4096_sc_synthetic_pw_lr = MaxTextModel( -model_name="llama2-70b-4096-pw-lr", +model_name="llama2_70b_4096_sc_synthetic_pw_lr", model_type="llama2-70b", tuning_params={ "per_device_batch_size": 2, From 77c8c044999e8f4179bd777de23e330b0a8d1530 Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Fri, 13 Dec 2024 20:13:16 +0000 Subject: [PATCH 13/16] Fix model name config --- benchmarks/benchmark_runner.py | 4 ++ benchmarks/maxtext_trillium_model_configs.py | 58 ++++++++++---------- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index dd28ed582..afaf5c0bc 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -89,6 +89,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_70b_4096_real_data', 'llama2_70b_4096_pw_long_run', 'llama2_70b_4096_real_data_pw_long_run', + 'llama2_70b_4096_sc_synthetic_pw_lr', + 'llama2_70b_4096_sc_synthetic', 'llama3_70b_8192', 'llama3_1_405b_8192_fsdp_dcn', 'mixtral_8x7b_dropped', @@ -106,6 +108,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_70b_4096_real_data ' 'llama2_70b_4096_pw_long_run ' 'llama2_70b_4096_real_data_pw_long_run ' + 'llama2_70b_4096_sc_synthetic_pw_lr ' + 
'llama2_70b_4096_sc_synthetic ' 'llama3_1_405b_8192_fsdp_dcn ' 'mixtral_8x7b_dropped ' 'mixtral_8x7b_dropped_int8 ' diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index be19cd494..f70938764 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -359,36 +359,36 @@ class MaxTextModel: ), ) llama2_70b_4096_sc_synthetic = MaxTextModel( -model_name="llama2_70b_4096_sc_synthetic", -model_type="llama2-70b", -tuning_params={ -"per_device_batch_size": 2, -"ici_fsdp_parallelism": 1, -"ici_fsdp_transpose_parallelism": -1, -"ici_tensor_parallelism": 1, -"remat_policy": "qkv_proj_offloaded", -"max_target_length": 4096, -"attention": "flash", -"gcs_metrics": True, -"use_iota_embed": True, -"dataset_path": "gs://max-datasets-rogue", -"dataset_type": "synthetic", -"enable_checkpointing": False, -"profiler": "xplane", -"sa_block_q": 1024, -"sa_block_q_dkv": 2048, -"sa_block_q_dq": 2048, -}, -xla_flags=( + model_name="llama2_70b_4096_sc_synthetic", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER -), + ), ) llama2_70b_4096_sc_synthetic_pw_lr = MaxTextModel( -model_name="llama2_70b_4096_sc_synthetic_pw_lr", -model_type="llama2-70b", -tuning_params={ + model_name="llama2_70b_4096_sc_synthetic_pw_lr", + model_type="llama2-70b", + tuning_params={ "per_device_batch_size": 2, "ici_fsdp_parallelism": 1, "ici_fsdp_transpose_parallelism": -1, @@ -416,11 +416,11 @@ class MaxTextModel: "enable_pathways_goodput": True, "enable_checkpoint_cloud_logger": True, "enable_single_controller": True, -}, -xla_flags=( + }, + xla_flags=( xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER -), + ), ) llama2_70b_4096_pw_long_run = MaxTextModel( @@ -762,6 +762,8 @@ class MaxTextModel: llama2_70b_4096_real_data_pw_long_run, llama3_8b_8192, # Not Optimizied yet llama3_70b_8192, # Not Optimizied yet + llama2_70b_4096_sc_synthetic_pw_lr, + llama2_70b_4096_sc_synthetic, llama3_1_405b_8192_fsdp_dcn, llama3_1_70b_129024, mixtral_8x7b_dropped, From cd1999e118bf6dcbc9558d055f8ccaf7fd37b820 Mon Sep 17 00:00:00 2001 From: RoshaniN Date: Mon, 16 Dec 2024 15:26:32 +0000 Subject: [PATCH 14/16] Shortening model names to be under 40 chars. 
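Renames the sc_synthetic configs to llama2_70b_4096_synthetic and
llama2_70b_4096_synthetic_pw_lr so the XPK workload names derived from
them stay within the 40-character limit.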
--- benchmarks/benchmark_runner.py | 8 ++++---- benchmarks/maxtext_trillium_model_configs.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index afaf5c0bc..291672ddb 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -89,8 +89,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_70b_4096_real_data', 'llama2_70b_4096_pw_long_run', 'llama2_70b_4096_real_data_pw_long_run', - 'llama2_70b_4096_sc_synthetic_pw_lr', - 'llama2_70b_4096_sc_synthetic', + 'llama2_70b_4096_synthetic_pw_lr', + 'llama2_70b_4096_synthetic', 'llama3_70b_8192', 'llama3_1_405b_8192_fsdp_dcn', 'mixtral_8x7b_dropped', @@ -108,8 +108,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_70b_4096_real_data ' 'llama2_70b_4096_pw_long_run ' 'llama2_70b_4096_real_data_pw_long_run ' - 'llama2_70b_4096_sc_synthetic_pw_lr ' - 'llama2_70b_4096_sc_synthetic ' + 'llama2_70b_4096_synthetic_pw_lr ' + 'llama2_70b_4096_synthetic ' 'llama3_1_405b_8192_fsdp_dcn ' 'mixtral_8x7b_dropped ' 'mixtral_8x7b_dropped_int8 ' diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index f70938764..61c8b1883 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -358,8 +358,8 @@ class MaxTextModel: + xla_flags_library.CF_FOR_ALL_GATHER ), ) -llama2_70b_4096_sc_synthetic = MaxTextModel( - model_name="llama2_70b_4096_sc_synthetic", +llama2_70b_4096_synthetic = MaxTextModel( + model_name="llama2_70b_4096_synthetic", model_type="llama2-70b", tuning_params={ "per_device_batch_size": 2, @@ -385,8 +385,8 @@ class MaxTextModel: ), ) -llama2_70b_4096_sc_synthetic_pw_lr = MaxTextModel( - model_name="llama2_70b_4096_sc_synthetic_pw_lr", +llama2_70b_4096_synthetic_pw_lr = MaxTextModel( + model_name="llama2_70b_4096_synthetic_pw_lr", model_type="llama2-70b", tuning_params={ "per_device_batch_size": 2, @@ -762,8 +762,8 @@ class MaxTextModel: llama2_70b_4096_real_data_pw_long_run, llama3_8b_8192, # Not Optimizied yet llama3_70b_8192, # Not Optimizied yet - llama2_70b_4096_sc_synthetic_pw_lr, - llama2_70b_4096_sc_synthetic, + llama2_70b_4096_synthetic_pw_lr, + llama2_70b_4096_synthetic, llama3_1_405b_8192_fsdp_dcn, llama3_1_70b_129024, mixtral_8x7b_dropped, From 90107c26e95607b75f5ad7e45c86349ad256b956 Mon Sep 17 00:00:00 2001 From: RoshaniN Date: Mon, 16 Dec 2024 19:48:56 +0000 Subject: [PATCH 15/16] Added TFDS config. 
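Adds llama2_70b_4096_pw_rd_tfds, a Pathways long-running config that reads
real data from gs://trillium-storage-datasets-sr, and registers it in the
runner's --model_name choices.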
--- benchmarks/benchmark_runner.py | 2 + benchmarks/maxtext_trillium_model_configs.py | 40 ++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index 291672ddb..bf942bda3 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -89,6 +89,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_70b_4096_real_data', 'llama2_70b_4096_pw_long_run', 'llama2_70b_4096_real_data_pw_long_run', + 'llama2_70b_4096_pw_rd_tfds ', 'llama2_70b_4096_synthetic_pw_lr', 'llama2_70b_4096_synthetic', 'llama3_70b_8192', @@ -108,6 +109,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_70b_4096_real_data ' 'llama2_70b_4096_pw_long_run ' 'llama2_70b_4096_real_data_pw_long_run ' + 'llama2_70b_4096_pw_rd_tfds ' 'llama2_70b_4096_synthetic_pw_lr ' 'llama2_70b_4096_synthetic ' 'llama3_1_405b_8192_fsdp_dcn ' diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 61c8b1883..74f16203c 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -462,6 +462,45 @@ class MaxTextModel: ), ) +llama2_70b_4096_pw_rd_tfds = MaxTextModel( + model_name="llama2_70b_4096_pw_rd_tfds", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://trillium-storage-datasets-sr", + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), +) + + llama3_8b_8192 = MaxTextModel( model_name="llama3-8b-8192", model_type="llama3-8b", @@ -760,6 +799,7 @@ class MaxTextModel: llama2_70b_4096_pw_long_run, llama2_70b_4096_real_data, llama2_70b_4096_real_data_pw_long_run, + llama2_70b_4096_pw_rd_tfds, llama3_8b_8192, # Not Optimizied yet llama3_70b_8192, # Not Optimizied yet llama2_70b_4096_synthetic_pw_lr, From f7dd4ccdb215e9822066d0bb0c1bc18e392bc2ed Mon Sep 17 00:00:00 2001 From: Sadi Kneipp Date: Mon, 16 Dec 2024 21:38:11 +0000 Subject: [PATCH 16/16] remove trailing space --- benchmarks/benchmark_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index bf942bda3..328b9d579 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -89,7 +89,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser): 'llama2_70b_4096_real_data', 'llama2_70b_4096_pw_long_run', 'llama2_70b_4096_real_data_pw_long_run', - 'llama2_70b_4096_pw_rd_tfds ', + 'llama2_70b_4096_pw_rd_tfds', 'llama2_70b_4096_synthetic_pw_lr', 'llama2_70b_4096_synthetic', 'llama3_70b_8192',
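
Taken together, the series lets a Pathways benchmark be launched end to end
from the runner. The sketch below is a hypothetical invocation, not taken
from the patches: the project, cluster, zone, and bucket values are
placeholders, and the cluster-selection flags (--project, --cluster_name,
--zone) are assumed to exist in add_shared_arguments based on main()'s use
of the corresponding options.

  python3 benchmarks/benchmark_runner.py \
      --project=my-gcp-project \
      --cluster_name=my-v6e-cluster \
      --zone=us-east5-b \
      --device_type=v6e-256 \
      --num_slices=2 \
      --base_output_directory=gs://my-bucket/pathways-runs \
      --model_name=llama2_70b_4096_pw_long_run \
      --use_pathways=True \
      --num_steps=1000 \
      --priority=medium \
      --max_restarts=10 \
      --xpk_path=~/xpk

One caveat: --use_pathways is declared with type=bool, so argparse treats
any non-empty value (including "False") as true; omit the flag entirely to
run without Pathways.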