Skip to content

Commit 90107c2

Browse files
committed
Added TFDS config.
1 parent cd1999e commit 90107c2

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

benchmarks/benchmark_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
8989
'llama2_70b_4096_real_data',
9090
'llama2_70b_4096_pw_long_run',
9191
'llama2_70b_4096_real_data_pw_long_run',
92+
'llama2_70b_4096_pw_rd_tfds ',
9293
'llama2_70b_4096_synthetic_pw_lr',
9394
'llama2_70b_4096_synthetic',
9495
'llama3_70b_8192',
@@ -108,6 +109,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
108109
'llama2_70b_4096_real_data '
109110
'llama2_70b_4096_pw_long_run '
110111
'llama2_70b_4096_real_data_pw_long_run '
112+
'llama2_70b_4096_pw_rd_tfds '
111113
'llama2_70b_4096_synthetic_pw_lr '
112114
'llama2_70b_4096_synthetic '
113115
'llama3_1_405b_8192_fsdp_dcn '

benchmarks/maxtext_trillium_model_configs.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,45 @@ class MaxTextModel:
462462
),
463463
)
464464

465+
llama2_70b_4096_pw_rd_tfds = MaxTextModel(
466+
model_name="llama2_70b_4096_pw_rd_tfds",
467+
model_type="llama2-70b",
468+
tuning_params={
469+
"per_device_batch_size": 2,
470+
"ici_fsdp_parallelism": 1,
471+
"ici_fsdp_transpose_parallelism": -1,
472+
"ici_tensor_parallelism": 1,
473+
"remat_policy": "qkv_proj_offloaded",
474+
"max_target_length": 4096,
475+
"attention": "flash",
476+
"gcs_metrics": True,
477+
"use_iota_embed": True,
478+
"dataset_path": "gs://trillium-storage-datasets-sr",
479+
"enable_checkpointing": False,
480+
"profiler": "xplane",
481+
"sa_block_q": 1024,
482+
"sa_block_q_dkv": 2048,
483+
"sa_block_q_dq": 2048,
484+
485+
# Additional tuning params for pathways long running test.
486+
"enable_checkpointing": True,
487+
"async_checkpointing": True,
488+
"checkpoint_period": 100,
489+
"checkpoint_storage_use_ocdbt": False,
490+
"checkpoint_storage_use_zarr3": False,
491+
"metrics_file": "metrics.txt",
492+
"goodput_upload_interval_seconds": 30,
493+
"enable_pathways_goodput": True,
494+
"enable_checkpoint_cloud_logger": True,
495+
"enable_single_controller": True,
496+
},
497+
xla_flags=(
498+
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
499+
+ xla_flags_library.CF_FOR_ALL_GATHER
500+
),
501+
)
502+
503+
465504
llama3_8b_8192 = MaxTextModel(
466505
model_name="llama3-8b-8192",
467506
model_type="llama3-8b",
@@ -760,6 +799,7 @@ class MaxTextModel:
760799
llama2_70b_4096_pw_long_run,
761800
llama2_70b_4096_real_data,
762801
llama2_70b_4096_real_data_pw_long_run,
802+
llama2_70b_4096_pw_rd_tfds,
763803
llama3_8b_8192, # Not Optimizied yet
764804
llama3_70b_8192, # Not Optimizied yet
765805
llama2_70b_4096_synthetic_pw_lr,

0 commit comments

Comments
 (0)