Skip to content

Commit

Permalink
Added TFDS config.
Browse files Browse the repository at this point in the history
  • Loading branch information
RoshaniN committed Dec 16, 2024
1 parent cd1999e commit 90107c2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
2 changes: 2 additions & 0 deletions benchmarks/benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
'llama2_70b_4096_real_data',
'llama2_70b_4096_pw_long_run',
'llama2_70b_4096_real_data_pw_long_run',
'llama2_70b_4096_pw_rd_tfds',
'llama2_70b_4096_synthetic_pw_lr',
'llama2_70b_4096_synthetic',
'llama3_70b_8192',
Expand All @@ -108,6 +109,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
'llama2_70b_4096_real_data '
'llama2_70b_4096_pw_long_run '
'llama2_70b_4096_real_data_pw_long_run '
'llama2_70b_4096_pw_rd_tfds '
'llama2_70b_4096_synthetic_pw_lr '
'llama2_70b_4096_synthetic '
'llama3_1_405b_8192_fsdp_dcn '
Expand Down
40 changes: 40 additions & 0 deletions benchmarks/maxtext_trillium_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,45 @@ class MaxTextModel:
),
)

# Llama2-70B benchmark configuration with a 4096-token target length,
# registered under the name "llama2_70b_4096_pw_rd_tfds".
# NOTE(review): the "pw_rd_tfds" suffix presumably means Pathways + real data
# + TFDS input pipeline -- confirm against the benchmark documentation.
llama2_70b_4096_pw_rd_tfds = MaxTextModel(
    model_name="llama2_70b_4096_pw_rd_tfds",
    model_type="llama2-70b",
    tuning_params={
        "per_device_batch_size": 2,
        "ici_fsdp_parallelism": 1,
        "ici_fsdp_transpose_parallelism": -1,
        "ici_tensor_parallelism": 1,
        "remat_policy": "qkv_proj_offloaded",
        "max_target_length": 4096,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": "gs://trillium-storage-datasets-sr",
        # Checkpointing is enabled for the pathways long running test; the
        # related checkpoint settings are grouped further below. This key
        # must appear only once: a repeated key in a dict literal silently
        # keeps just the last value.
        "enable_checkpointing": True,
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,

        # Additional tuning params for pathways long running test.
        "async_checkpointing": True,
        "checkpoint_period": 100,
        "checkpoint_storage_use_ocdbt": False,
        "checkpoint_storage_use_zarr3": False,
        "metrics_file": "metrics.txt",
        "goodput_upload_interval_seconds": 30,
        "enable_pathways_goodput": True,
        "enable_checkpoint_cloud_logger": True,
        "enable_single_controller": True,
    },
    # XLA flags are concatenated from the shared flag library.
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)


llama3_8b_8192 = MaxTextModel(
model_name="llama3-8b-8192",
model_type="llama3-8b",
Expand Down Expand Up @@ -760,6 +799,7 @@ class MaxTextModel:
llama2_70b_4096_pw_long_run,
llama2_70b_4096_real_data,
llama2_70b_4096_real_data_pw_long_run,
llama2_70b_4096_pw_rd_tfds,
llama3_8b_8192, # Not Optimized yet
llama3_70b_8192, # Not Optimized yet
llama2_70b_4096_synthetic_pw_lr,
Expand Down

0 comments on commit 90107c2

Please sign in to comment.