Skip to content

Commit

Permalink
Add Pathways Support to Benchmark Runner
Browse files Browse the repository at this point in the history
  • Loading branch information
SujeethJinesh committed Dec 10, 2024
1 parent 0fe43b7 commit 7e619b7
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 28 deletions.
67 changes: 66 additions & 1 deletion benchmarks/benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from maxtext_xpk_runner import BenchmarkRunner
from maxtext_xpk_runner import HWConfig
from maxtext_xpk_runner import SWconfig
from maxtext_xpk_runner import PathwaysConfig
from maxtext_xpk_runner import xpk_benchmark_runner
from maxtext_xpk_runner import XpkConfig

Expand Down Expand Up @@ -86,6 +87,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
'llama2_7b_4096',
'llama2_70b_4096',
'llama2_70b_4096_real_data',
'llama2_70b_4096_pw_long_run',
'llama2_70b_4096_real_data_pw_long_run',
'llama3_70b_8192',
'llama3_1_405b_8192_fsdp_dcn',
'mixtral_8x7b_dropped',
Expand All @@ -101,6 +104,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
'llama2_7b_4096 '
'llama2_70b_4096 '
'llama2_70b_4096_real_data '
'llama2_70b_4096_pw_long_run '
'llama2_70b_4096_real_data_pw_long_run '
'llama3_1_405b_8192_fsdp_dcn '
'mixtral_8x7b_dropped '
'mixtral_8x7b_dropped_int8 '
Expand All @@ -122,6 +127,57 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
default='maxtext_base_image',
help='version of base docker image to be benchmarked command.',
)
def _parse_bool(value: str) -> bool:
  """Parse a boolean CLI value such as 'true'/'false', '1'/'0', 'yes'/'no'.

  Why: argparse's `type=bool` is a well-known footgun -- bool() of any
  non-empty string is True, so `--use_pathways False` silently enabled
  Pathways. This converter interprets the usual spellings explicitly and
  rejects anything else with a clear error.

  Raises:
    argparse.ArgumentTypeError: if the value is not a recognized boolean.
  """
  lowered = value.strip().lower()
  if lowered in ('1', 'true', 't', 'yes', 'y'):
    return True
  if lowered in ('0', 'false', 'f', 'no', 'n'):
    return False
  raise argparse.ArgumentTypeError(f'invalid boolean value: {value!r}')


# --- Pathways-specific flags -------------------------------------------------
custom_parser.add_argument(
    '--pathways_server_image',
    type=str,
    default=(
        'us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest'
    ),
    help='version of pathways server image to be benchmarked command.',
)
custom_parser.add_argument(
    '--pathways_proxy_image',
    type=str,
    default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest',
    help='version of pathways proxy image to be benchmarked command.',
)
custom_parser.add_argument(
    '--pathways_runner_image',
    type=str,
    default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest',
    help='version of pathways runner image to be benchmarked command.',
)
custom_parser.add_argument(
    '--use_pathways',
    # Fix: was `type=bool`, under which every non-empty value (including
    # 'False') parsed as True. A value is still accepted on the command
    # line, so existing invocations keep working; false-y spellings now
    # actually disable Pathways.
    type=_parse_bool,
    default=False,
    help='whether to use pathways or not.',
)
# --- General runner flags ----------------------------------------------------
custom_parser.add_argument(
    '--xpk_path',
    type=str,
    default='~/xpk',
    help='path to xpk dir.',
)
custom_parser.add_argument(
    '--num_steps',
    type=int,
    # Fix: default was the string '20'. argparse does run string defaults
    # through `type`, but an int literal is unambiguous and avoids the
    # parse-on-default code path entirely.
    default=20,
    help='Number of steps to run benchmark for.',
)
custom_parser.add_argument(
    '--priority',
    type=str,
    default='medium',
    help='Priority the XPK workload should run with.',
)
custom_parser.add_argument(
    '--max_restarts',
    type=int,
    default=0,
    help='Number of restarts to attempt.',
)


def main() -> None:
parser = argparse.ArgumentParser(
Expand All @@ -137,11 +193,19 @@ def main() -> None:
num_slices=options.num_slices,
device_type=options.device_type,
base_output_directory=options.base_output_directory,
priority=options.priority,
max_restarts=options.max_restarts,
)

v6e_env_configs = SWconfig(
base_docker_image=options.base_docker_image,
libtpu_version=options.libtpu_version,
pathways_config=PathwaysConfig(
use_pathways=options.use_pathways,
server_image=options.pathways_server_image,
proxy_image=options.pathways_proxy_image,
runner_image=options.pathways_runner_image,
),
)

v6e_256_configs = HWConfig(
Expand All @@ -155,9 +219,10 @@ def main() -> None:
model_name=benchmark_model,
software_config=v6e_env_configs,
hardware_config=v6e_256_configs,
num_steps=options.num_steps,
)

xpk_benchmark_runner(cluster_config, [model_runner])
xpk_benchmark_runner(cluster_config, [model_runner], options.xpk_path)


if __name__ == '__main__':
Expand Down
80 changes: 80 additions & 0 deletions benchmarks/maxtext_trillium_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,45 @@ class MaxTextModel:
),
)


# Pathways long-running benchmark config: Llama2-70B, 4096 max target length,
# on real (TFDS) data. Name suffixes: "rd" = real data, "pw" = Pathways,
# "lr" = long run.
# NOTE(review): tuning-param semantics below are inferred from key names;
# confirm against the MaxTextModel consumer before relying on them.
llama2_70b_4096_real_data_pw_long_run = MaxTextModel(
    model_name="llama2-70b-4096-rd-pw-lr",
    model_type="llama2-70b",
    tuning_params={
        "per_device_batch_size": 7,
        # -1 presumably means "infer FSDP degree from the slice" — TODO confirm.
        "ici_fsdp_parallelism": -1,
        "remat_policy": "full",
        "max_target_length": 4096,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "reuse_example_batch": 1,
        "profiler": "xplane",
        # Real-data path: TFDS dataset read from GCS (vs. synthetic in the
        # non-"real_data" variant of this config).
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "tfds",
        "tokenizer_path": "assets/tokenizer.llama2",
        # Splash/flash attention block sizes.
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,

        # Additional tuning params for pathways long running test.
        "enable_checkpointing": True,
        "async_checkpointing": True,
        "checkpoint_period": 100,
        # OCDBT/Zarr3 checkpoint formats explicitly disabled for this run.
        "checkpoint_storage_use_ocdbt": False,
        "checkpoint_storage_use_zarr3": False,
        "metrics_file": "metrics.txt",
        "goodput_upload_interval_seconds": 30,
        "enable_pathways_goodput": True,
        "enable_checkpoint_cloud_logger": True,
        "enable_single_controller": True,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)

# ici_fsdp_transpose_parallelism gives one TFLOP better performance.
llama2_70b_4096 = MaxTextModel(
model_name="llama2-70b-4096",
Expand Down Expand Up @@ -320,6 +359,45 @@ class MaxTextModel:
),
)

# Pathways long-running benchmark config: Llama2-70B, 4096 max target length,
# on synthetic data. Name suffixes: "pw" = Pathways, "lr" = long run.
# NOTE(review): tuning-param semantics below are inferred from key names;
# confirm against the MaxTextModel consumer before relying on them.
llama2_70b_4096_pw_long_run = MaxTextModel(
    model_name="llama2-70b-4096-pw-lr",
    model_type="llama2-70b",
    tuning_params={
        "per_device_batch_size": 4,
        # FSDP-transpose sharding layout (fsdp=1, transpose=-1) rather than
        # plain FSDP; mirrors the base llama2_70b_4096 config's choice.
        "ici_fsdp_parallelism": 1,
        "ici_fsdp_transpose_parallelism": -1,
        "ici_tensor_parallelism": 1,
        "remat_policy": "full",
        "max_target_length": 4096,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        # dataset_path is set but dataset_type is synthetic, so no real data
        # is read — presumably kept for parity with the real-data variant.
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "reuse_example_batch": 1,
        "profiler": "xplane",
        # Splash/flash attention block sizes.
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,

        # Additional tuning params for pathways long running test.
        "enable_checkpointing": True,
        "async_checkpointing": True,
        "checkpoint_period": 100,
        # OCDBT/Zarr3 checkpoint formats explicitly disabled for this run.
        "checkpoint_storage_use_ocdbt": False,
        "checkpoint_storage_use_zarr3": False,
        "metrics_file": "metrics.txt",
        "goodput_upload_interval_seconds": 30,
        "enable_pathways_goodput": True,
        "enable_checkpoint_cloud_logger": True,
        "enable_single_controller": True,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)

llama3_8b_8192 = MaxTextModel(
model_name="llama3-8b-8192",
model_type="llama3-8b",
Expand Down Expand Up @@ -615,7 +693,9 @@ class MaxTextModel:
gpt_3_175b,
llama2_7b_4096,
llama2_70b_4096,
llama2_70b_4096_pw_long_run,
llama2_70b_4096_real_data,
llama2_70b_4096_real_data_pw_long_run,
llama3_8b_8192, # Not Optimized yet
llama3_70b_8192, # Not Optimized yet
llama3_1_405b_8192_fsdp_dcn,
Expand Down
Loading

0 comments on commit 7e619b7

Please sign in to comment.