@@ -462,6 +462,45 @@ class MaxTextModel:
462
462
),
463
463
)
464
464
465
# Llama2-70B benchmark configuration with a 4096-token target length,
# extended with the settings used by the Pathways long-running test.
# NOTE(review): the name suggests a TFDS real-data run, but no dataset-type
# key is set here, only "dataset_path" -- confirm the default is intended.
llama2_70b_4096_pw_rd_tfds = MaxTextModel(
    model_name="llama2_70b_4096_pw_rd_tfds",
    model_type="llama2-70b",
    tuning_params={
        "per_device_batch_size": 2,
        "ici_fsdp_parallelism": 1,
        "ici_fsdp_transpose_parallelism": -1,
        "ici_tensor_parallelism": 1,
        "remat_policy": "qkv_proj_offloaded",
        "max_target_length": 4096,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": "gs://trillium-storage-datasets-sr",
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,

        # Additional tuning params for pathways long running test.
        # "enable_checkpointing" used to be listed twice in this dict
        # (False, then True). A dict literal keeps only the last value for a
        # repeated key, so the single entry below preserves the effective
        # setting while removing the dead, misleading duplicate.
        "enable_checkpointing": True,
        "async_checkpointing": True,
        "checkpoint_period": 100,
        "checkpoint_storage_use_ocdbt": False,
        "checkpoint_storage_use_zarr3": False,
        "metrics_file": "metrics.txt",
        "goodput_upload_interval_seconds": 30,
        "enable_pathways_goodput": True,
        "enable_checkpoint_cloud_logger": True,
        "enable_single_controller": True,
    },
    # The two flag constants are combined with "+" exactly as before.
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
)
502
+
503
+
465
504
llama3_8b_8192 = MaxTextModel (
466
505
model_name = "llama3-8b-8192" ,
467
506
model_type = "llama3-8b" ,
@@ -760,6 +799,7 @@ class MaxTextModel:
760
799
llama2_70b_4096_pw_long_run ,
761
800
llama2_70b_4096_real_data ,
762
801
llama2_70b_4096_real_data_pw_long_run ,
802
+ llama2_70b_4096_pw_rd_tfds ,
763
803
llama3_8b_8192 , # Not Optimized yet
764
804
llama3_70b_8192 , # Not Optimized yet
765
805
llama2_70b_4096_synthetic_pw_lr ,
0 commit comments