
Add Pathways Support to Benchmark Runner #1094

Open · wants to merge 17 commits into main
Conversation

@SujeethJinesh (Collaborator) commented on Dec 10, 2024

Description

Add Pathways support to the Benchmark Runner. This makes it easier to run benchmarks with Pathways for at-scale and long-running testing. Also adds Pathways-specific configs for long-running tests for now.

Tests

Ran at scale and in long-running tests.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [x] I have performed a self-review of my code.
  • [x] I have added necessary comments to my code, particularly in hard-to-understand areas.
  • [x] I have run end-to-end tests and provided workload links above if applicable.
  • [x] I have made or will make corresponding changes to the docs if needed.

@SujeethJinesh force-pushed the sujinesh/llama2_v6e_pw_long_running_test branch 9 times, most recently from c51c8c1 to 6ed15bc on December 10, 2024 06:41
@SujeethJinesh force-pushed the sujinesh/llama2_v6e_pw_long_running_test branch from 6ed15bc to 7e619b7 on December 10, 2024 17:22
@SujeethJinesh SujeethJinesh marked this pull request as ready for review December 11, 2024 17:17
@suexu1025 suexu1025 self-requested a review December 11, 2024 18:31
default='~/xpk',
help='path to xpk dir.',
)
custom_parser.add_argument(
Collaborator:
num_steps is defined in the model configs. Adding an extra parser flag for steps would cause unintended overrides.

Collaborator Author (@SujeethJinesh) · Dec 11, 2024:

Sorry, I don't see it in the model configs. I'm looking in benchmarks/maxtext_trillium_model_configs.py and they don't appear to be defined there. Are you referring to base.yml?

Collaborator:
It would be like steps = 100 in tuning_params.

Collaborator Author:
From the other comments: we can leave this here so the user can override it if they want, but if they don't specify it, fall back to what's in tuning_params.

Collaborator Author:
Also, none of the other configs seem to have num_steps defined, so people are just using the default at the moment. I've edited the code so that remains the default.

Collaborator:
We don't want to make num_steps mandatory, as it is already defined in some of the workloads. Can we make it optional? It would only be used if the user needs it; otherwise fall back to tuning_params or a default of 20 steps.
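A minimal sketch of the compromise discussed in this thread, assuming an argparse-based parser like the custom_parser in the diff: expose --num_steps as an optional flag defaulting to None, so a missing value can fall back to tuning_params later.

```python
import argparse

# Hypothetical sketch, not the PR's exact code: --num_steps is optional and
# defaults to None, which means "not set by the user".
custom_parser = argparse.ArgumentParser(description='Benchmark runner options.')
custom_parser.add_argument(
    '--num_steps',
    type=int,
    default=None,  # Caller falls back to tuning_params, then to 20.
    help='Optional override for the number of benchmark steps.',
)

options = custom_parser.parse_args(['--num_steps', '50'])
print(options.num_steps)  # 50
```

When the flag is omitted, options.num_steps is None and the runner can consult the workload's tuning_params instead.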



@dataclasses.dataclass
class BenchmarkRunner:
  model_name: str
  hardware_config: HWConfig
  software_config: SWconfig
  num_steps: int
Collaborator:
same for num_steps

Collaborator:
Same; it's better as xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRunner], num_steps=0).
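The optional-argument shape the reviewer suggests could look like this sketch (simplified: the real BenchmarkRunner has more fields, and num_steps=0 here stands for "not set"):

```python
import dataclasses

# Simplified sketch of the reviewer's suggestion, not the PR's exact code.
@dataclasses.dataclass
class BenchmarkRunner:
  model_name: str
  num_steps: int = 0  # 0 means "use the workload's own steps setting".

def xpk_benchmark_runner(benchmarks: list, num_steps: int = 0) -> list:
  # Apply the override only when the caller actually supplied one, so
  # workload-specified steps are not clobbered.
  for benchmark in benchmarks:
    if num_steps:
      benchmark.num_steps = num_steps
  return benchmarks
```

Existing call sites that never pass num_steps keep working unchanged, which is the point of making it an optional keyword rather than a required dataclass field.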

@@ -257,21 +269,25 @@ def run_command_with_updates(command, task, verbose=True) -> int:


def build_user_command(
Collaborator:
Currently this adds extra Pathways configs by default. We could either add a maxtext_xpk_pw_runner.py or another build_user_command function. What do you think? @SujeethJinesh

Collaborator Author:
The PathwaysConfig is only used if --use_pathways is passed in. This is intentional: it lets us keep configs and code changes similar to McJAX and just make drop-in replacements to support Pathways. I think a new function or file would bifurcate the code too much and make it harder to share configs over time.
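The single-code-path approach the author describes can be sketched as one builder that appends the Pathways-specific flags only on demand (flag names are taken from this thread; the command shape is illustrative, not MaxText's actual invocation):

```python
def build_user_command(base_args: list, use_pathways: bool = False) -> str:
  """Sketch only: one command builder shared by McJAX and Pathways runs,
  with Pathways flags appended only when --use_pathways was requested."""
  args = list(base_args)
  if use_pathways:
    # Only the Pathways run pays for these; McJAX runs are unchanged.
    args += [
        'enable_pathways_goodput=True',
        'enable_single_controller=True',
    ]
  return ' '.join(args)
```

This keeps one function for both backends, which is the trade-off the author prefers over a separate maxtext_xpk_pw_runner.py.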

"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,

# Additional tuning params for pathways long running test.
Collaborator:
Can we move the Pathways long-running-test configurations ("metrics_file": "metrics.txt", "goodput_upload_interval_seconds": 30, "enable_pathways_goodput": True, "enable_checkpoint_cloud_logger": True, "enable_single_controller": True) into maxtext_xpk_runner, e.g. in a new xpk-command-for-pathways function? maxtext_trillium_model_configs.py is supposed to hold the different benchmarks; that way it would be easy for Pathways to run any of them. What do you think?

Collaborator Author:
These configs are actually for the specific test we're running (long-running tests). I think keeping this as a config makes more sense, because we may have other tests that are perf-only, non-goodput-related, etc. I wouldn't want to bifurcate the code too much; it's much easier to add a specialized config for Pathways and flip on the specific flags for that test.
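For reference, the specialized additions the author describes would look roughly like this fragment (keys and values copied from the flags quoted earlier in the thread; the real llama2_70b_4096_pw_long_run config in maxtext_trillium_model_configs.py contains many more tuning params):

```python
# Illustrative fragment only, not the full config.
pw_long_run_extras = {
    "metrics_file": "metrics.txt",
    "goodput_upload_interval_seconds": 30,
    "enable_pathways_goodput": True,
    "enable_checkpoint_cloud_logger": True,
    "enable_single_controller": True,
}
```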

@@ -320,6 +359,45 @@ class MaxTextModel:
),
)

llama2_70b_4096_pw_long_run = MaxTextModel(
model_name="llama2-70b-4096-pw-lr",
Collaborator:
Could the benchmark name be llama2-70b-4096-checkpoint?

Collaborator Author:
We don't want the name to be too long because it may overrun XPK's 40-character limit :(
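The 40-character XPK limit mentioned here could be guarded with a small check (the helper name is hypothetical and not part of the PR):

```python
XPK_NAME_LIMIT = 40  # Workload-name length limit cited in this thread.

def check_name(name: str) -> str:
  # Hypothetical guard: fail early instead of failing inside XPK.
  if len(name) > XPK_NAME_LIMIT:
    raise ValueError(f'{name!r} exceeds the {XPK_NAME_LIMIT}-char XPK limit')
  return name

check_name('llama2-70b-4096-pw-lr')  # 21 chars, fits
```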

time.localtime()
test_purpose_name = f'maxstar-benchmarks-{model.model_name}-{libtpu_version}'
Collaborator:
Can we add back

 if "steps" not in model.tuning_params:
    num_steps = model.tuning_params["steps"]

so that workload-specified steps will not be overwritten?

Collaborator Author:
I'm not quite sure that does what's intended. I believe this code is better; I've added it in.

  if not num_steps and 'steps' in model.tuning_params:
    num_steps = model.tuning_params['steps']
  elif not num_steps:
    num_steps = 20
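A runnable form of that fallback, making the precedence explicit (explicit user value, then the workload's tuning_params, then a default of 20; the function name is assumed, not from the PR):

```python
DEFAULT_NUM_STEPS = 20

def resolve_num_steps(num_steps, tuning_params):
  # Precedence: explicit user value > workload's tuning_params > default.
  if num_steps:
    return num_steps
  if 'steps' in tuning_params:
    return tuning_params['steps']
  return DEFAULT_NUM_STEPS
```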

Collaborator Author:
(Minor update; ignore the comment above.) Made the change where we do the tuning in particular.

@suexu1025 left a review:
a few comments

Collaborator Author (@SujeethJinesh):
Thank you @suexu1025, resolved your comments

@@ -155,9 +222,10 @@ def main() -> None:
model_name=benchmark_model,
software_config=v6e_env_configs,
hardware_config=v6e_256_configs,
num_steps=options.num_steps,
Collaborator (@suexu1025) · Dec 13, 2024:

num_steps is benchmark-specific; can we change it to an optional argument of xpk_benchmark_runner, like exp_name in xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRunner], exp_name=None), instead of a non-optional dataclass field?

@sadikneipp force-pushed the sujinesh/llama2_v6e_pw_long_running_test branch 2 times, most recently from f202161 to 90107c2 on December 16, 2024 21:37
3 participants