Add Pathways Support to Benchmark Runner #1094

Open. Wants to merge 17 commits into base: main.
67 changes: 66 additions & 1 deletion benchmarks/benchmark_runner.py
@@ -28,6 +28,7 @@
from maxtext_xpk_runner import BenchmarkRunner
from maxtext_xpk_runner import HWConfig
from maxtext_xpk_runner import SWconfig
from maxtext_xpk_runner import PathwaysConfig
from maxtext_xpk_runner import xpk_benchmark_runner
from maxtext_xpk_runner import XpkConfig
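
For reference, PathwaysConfig (imported above) is presumed to be a small dataclass in maxtext_xpk_runner. A minimal sketch, with field names taken from the keyword arguments used in main() below; the defaults and dataclass form are assumptions:

```python
# Hypothetical sketch of PathwaysConfig; the real definition lives in
# maxtext_xpk_runner and may differ. Field names mirror the keyword
# arguments passed in main() below.
import dataclasses


@dataclasses.dataclass
class PathwaysConfig:
  use_pathways: bool = False
  server_image: str = ''
  proxy_image: str = ''
  runner_image: str = ''
```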

@@ -86,6 +87,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
'llama2_7b_4096',
'llama2_70b_4096',
'llama2_70b_4096_real_data',
'llama2_70b_4096_pw_long_run',
'llama2_70b_4096_real_data_pw_long_run',
'llama3_70b_8192',
'llama3_1_405b_8192_fsdp_dcn',
'mixtral_8x7b_dropped',
@@ -101,6 +104,8 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
'llama2_7b_4096 '
'llama2_70b_4096 '
'llama2_70b_4096_real_data '
'llama2_70b_4096_pw_long_run '
'llama2_70b_4096_real_data_pw_long_run '
'llama3_1_405b_8192_fsdp_dcn '
'mixtral_8x7b_dropped '
'mixtral_8x7b_dropped_int8 '
@@ -122,6 +127,57 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
default='maxtext_base_image',
help='Version of the base docker image to benchmark.',
)
custom_parser.add_argument(
'--pathways_server_image',
type=str,
default=(
'us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest'
),
help='Version of the Pathways server image to benchmark.',
)
custom_parser.add_argument(
'--pathways_proxy_image',
type=str,
default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest',
help='Version of the Pathways proxy image to benchmark.',
)
custom_parser.add_argument(
'--pathways_runner_image',
type=str,
default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest',
help='Version of the Pathways runner image to benchmark.',
)
custom_parser.add_argument(
'--use_pathways',
# argparse's type=bool treats any non-empty string as True, so a
# store_true flag is safer; the default remains False.
action='store_true',
help='Whether to use Pathways or not.',
)
custom_parser.add_argument(
'--xpk_path',
type=str,
default='~/xpk',
help='path to xpk dir.',
)
Review thread on the --num_steps flag:

Collaborator: num_steps is defined in the model configs. Adding an extra parser argument for steps will cause unintentional replacement.

Author (@SujeethJinesh, Dec 11, 2024): Sorry, I don't see it in the model configs. I'm looking in this file (benchmarks/maxtext_trillium_model_configs.py) and it doesn't appear to be defined. Are you referring to base.yml?

Collaborator: It would be something like steps = 100 in tuning_params.

Author: From the other comments, we can leave this here so the user can override it if they want; if they don't specify it, fall back to what's in tuning_params.

Author: Also, none of the other configs seem to have num_steps defined, so people are just using the default at the moment. I've edited the code to keep that as the default.

Collaborator: We don't want to make num_steps mandatory, as it is already defined in some of the workloads. Can we make it optional? Only used if specified by the user; otherwise it will use either tuning_params or the default of 20 steps.

custom_parser.add_argument(

'--num_steps',
type=int,
default=20,
help='Number of steps to run benchmark for.',
)
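
The fallback order agreed in the thread above could be implemented along these lines. A sketch only: the helper name resolve_num_steps is hypothetical, and it assumes the flag's default is changed to None so that "unspecified" is detectable:

```python
# Hypothetical helper showing the agreed fallback order for num_steps:
# explicit CLI value > "steps" in the model's tuning_params > default of 20.
def resolve_num_steps(cli_num_steps, tuning_params, default=20):
  if cli_num_steps is not None:
    return cli_num_steps
  return tuning_params.get('steps', default)
```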
custom_parser.add_argument(
'--priority',
type=str,
default='medium',
help='Priority the XPK workload should run with.',
)
custom_parser.add_argument(
'--max_restarts',
type=int,
default=0,
help='Number of restarts to attempt.',
)
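
Taken together, an illustrative invocation of the runner with the new flags (the device type and output directory flag names are inferred from the options.* usage in main() below and may differ):

python3 benchmarks/benchmark_runner.py --device_type=v6e-256 --base_output_directory=gs://&lt;bucket&gt;/output --model_name=llama2_70b_4096_pw_long_run --use_pathways --num_steps=100 --max_restarts=3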


def main() -> None:
parser = argparse.ArgumentParser(
@@ -137,11 +193,19 @@ def main() -> None:
num_slices=options.num_slices,
device_type=options.device_type,
base_output_directory=options.base_output_directory,
priority=options.priority,
max_restarts=options.max_restarts,
)

v6e_env_configs = SWconfig(
base_docker_image=options.base_docker_image,
libtpu_version=options.libtpu_version,
pathways_config=PathwaysConfig(
use_pathways=options.use_pathways,
server_image=options.pathways_server_image,
proxy_image=options.pathways_proxy_image,
runner_image=options.pathways_runner_image,
),
)

v6e_256_configs = HWConfig(
@@ -155,9 +219,10 @@
model_name=benchmark_model,
software_config=v6e_env_configs,
hardware_config=v6e_256_configs,
num_steps=options.num_steps,
)

Review thread on num_steps in BenchmarkRunner:

Collaborator (@suexu1025, Dec 13, 2024): num_steps is benchmark specific; can we change it to an optional argument of xpk_benchmark_runner, like exp_name, i.e. xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRunner], exp_name=None), instead of a non-optional dataclass field?

xpk_benchmark_runner(cluster_config, [model_runner])
xpk_benchmark_runner(cluster_config, [model_runner], options.xpk_path)
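
A sketch of the signature change proposed in the thread above. The xpk_path parameter is inferred from the call site directly above; the num_steps parameter is the reviewer's proposal, not the merged code:

```python
# Sketch of the proposed optional num_steps argument on xpk_benchmark_runner,
# mirroring the exp_name example from the review comment. Hypothetical only.
def xpk_benchmark_runner(
    cluster_config: XpkConfig,
    benchmarks: list[BenchmarkRunner],
    xpk_path: str = '~/xpk',
    num_steps: int | None = None,  # None: fall back to tuning_params or 20
) -> None:
  ...
```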


if __name__ == '__main__':
80 changes: 80 additions & 0 deletions benchmarks/maxtext_trillium_model_configs.py
@@ -291,6 +291,45 @@ class MaxTextModel:
),
)
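
For reference, a hypothetical sketch of the MaxTextModel container that the configs below instantiate. Field names mirror the constructor calls in this file; the types and dataclass form are assumptions (the real definition appears earlier in this file):

```python
# Hypothetical sketch of MaxTextModel; field names mirror the constructor
# calls below, types are assumed.
import dataclasses
from typing import Any


@dataclasses.dataclass
class MaxTextModel:
  model_name: str
  model_type: str
  tuning_params: dict[str, Any]
  xla_flags: str
```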


llama2_70b_4096_real_data_pw_long_run = MaxTextModel(
model_name="llama2-70b-4096-rd-pw-lr",
model_type="llama2-70b",
tuning_params={
"per_device_batch_size": 7,
"ici_fsdp_parallelism": -1,
"remat_policy": "full",
"max_target_length": 4096,
"attention": "flash",
"gcs_metrics": True,
"use_iota_embed": True,
"reuse_example_batch": 1,
"profiler": "xplane",
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "tfds",
"tokenizer_path": "assets/tokenizer.llama2",
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,

# Additional tuning params for pathways long running test.
"enable_checkpointing": True,
"async_checkpointing": True,
"checkpoint_period": 100,
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.CF_FOR_ALL_GATHER
),
)

# ici_fsdp_transpose_parallelism gives one TFLOP better performance.
llama2_70b_4096 = MaxTextModel(
model_name="llama2-70b-4096",
@@ -320,6 +359,45 @@ class MaxTextModel:
),
)

Review thread on the model name:

Collaborator: Could the benchmark name be llama2-70b-4096-checkpoint?

Author: We don't want the name to be too long because it may overrun the 40 character limit for XPK :(

llama2_70b_4096_pw_long_run = MaxTextModel(
model_name="llama2-70b-4096-pw-lr",
model_type="llama2-70b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": 1,
"ici_fsdp_transpose_parallelism": -1,
"ici_tensor_parallelism": 1,
"remat_policy": "full",
"max_target_length": 4096,
"attention": "flash",
"gcs_metrics": True,
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"reuse_example_batch": 1,
"profiler": "xplane",
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,

# Additional tuning params for pathways long running test.
Review thread on the pathways long-running-test params below:

Collaborator: Can we move the pathways long-running-test related configurations ("metrics_file": "metrics.txt", "goodput_upload_interval_seconds": 30, "enable_pathways_goodput": True, "enable_checkpoint_cloud_logger": True, "enable_single_controller": True) into maxtext_xpk_runner, e.g. a new function like xpk-command-for-pathway? maxtext_trillium_model_configs.py is supposed to hold the different benchmarks; that way it would be easy for Pathways to run any of them. What do you think?

Author: These configs are actually for the specific test we're running (long running tests). I think keeping this as a config makes more sense because we may have different tests that are just for perf, or non-goodput related, etc. I wouldn't want to bifurcate the code too much; it's much easier to just add a specialized config for Pathways and flip on the specific flags for that test.
"enable_checkpointing": True,
"async_checkpointing": True,
"checkpoint_period": 100,
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.CF_FOR_ALL_GATHER
),
)
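
As a quick sanity check of the new config, a hypothetical snippet; it assumes this module is importable as maxtext_trillium_model_configs and that MaxTextModel exposes the fields used below:

```python
# Hypothetical usage: look up the new Pathways long-run config and verify
# its checkpoint cadence. The import path shown is an assumption.
import maxtext_trillium_model_configs as configs

model = configs.llama2_70b_4096_pw_long_run
assert model.tuning_params["checkpoint_period"] == 100
print(model.model_name)  # llama2-70b-4096-pw-lr
```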

llama3_8b_8192 = MaxTextModel(
model_name="llama3-8b-8192",
model_type="llama3-8b",
@@ -615,7 +693,9 @@ class MaxTextModel:
gpt_3_175b,
llama2_7b_4096,
llama2_70b_4096,
llama2_70b_4096_pw_long_run,
llama2_70b_4096_real_data,
llama2_70b_4096_real_data_pw_long_run,
llama3_8b_8192, # Not optimized yet
llama3_70b_8192, # Not optimized yet
llama3_1_405b_8192_fsdp_dcn,