From 740beac52f1d80e63491515feda51c301d89260c Mon Sep 17 00:00:00 2001
From: Anshul Sinha
Date: Mon, 15 Sep 2025 16:08:04 -0700
Subject: [PATCH 1/4] [torchtitan][replicate] experimenting new replicate integration with torchtitan

[ghstack-poisoned]
---
 torchtitan/distributed/parallel_dims.py           | 12 ++++++++++--
 torchtitan/models/llama3/infra/parallelize.py     | 12 +++++++-----
 .../models/llama3/train_configs/debug_model.toml  |  2 +-
 3 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index 44822039a6..8779435ff4 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -100,7 +100,13 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
         ):
             # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping
             # helps the MoE layers do mixed precision training
-            if d > 1 or name == "dp_shard_mod_ep":
+            # dp_shard_in_ep is included even if it equals 1 when replicate > 1
+            # to make device_mesh compatible with replicate function
+            if (
+                d > 1
+                or name == "dp_shard_mod_ep"
+                or (name == "dp_shard_in_ep" and self.dp_replicate > 1)
+            ):
                 dims.append(d)
                 names.append(name)
 
@@ -151,7 +157,9 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            # Include dp_shard dimension even if it equals 1 when replicate > 1
+            # to make device_mesh compatible with replicate function
+            if d > 1 or (name == "dp_shard" and self.dp_replicate > 1):
                 dims.append(d)
                 names.append(name)
 
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 7d8aa76f0d..21df427444 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -9,7 +9,7 @@
 import torch
 import torch.nn as nn
 
-from torch.distributed._composable.replicate import replicate
+from torch.distributed._composable.replicate_with_fsdp import replicate
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 
@@ -135,11 +135,13 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
+        # if world_mesh.ndim > 1:
+        #     raise RuntimeError("DDP has not supported > 1D parallelism")
+
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         apply_ddp(
             model,
-            world_mesh,
+            world_mesh[tuple(dp_mesh_dim_names)],
             enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
@@ -328,6 +330,6 @@ def apply_ddp(
     else:
         torch._dynamo.config.optimize_ddp = "ddp_optimizer"
 
-    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    replicate(model, device_mesh=dp_mesh)
 
     logger.info("Applied DDP to the model")
diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml
index 79789629a6..cb7b05632e 100644
--- a/torchtitan/models/llama3/train_configs/debug_model.toml
+++ b/torchtitan/models/llama3/train_configs/debug_model.toml
@@ -43,7 +43,7 @@ steps = 10
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]
-data_parallel_replicate_degree = 1
+data_parallel_replicate_degree = 8
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
"default" # default / never / always tensor_parallel_degree = 1 From 82ccb85184302c8df04584c55ebcba750141d76c Mon Sep 17 00:00:00 2001 From: Anshul Sinha Date: Tue, 16 Sep 2025 15:06:28 -0700 Subject: [PATCH 2/4] Update on "[torchtitan][replicate] experimenting new replicate integration with torchtitan" **Summary:** During this experiment to integrate the new replicate function into torchtitan, I used https://github.com/pytorch/pytorch/pull/162021, which has not been landed. However, since this is more about making replicate more efficient rather than changing replicate's core code, https://github.com/pytorch/pytorch/pull/160135, which has landed, should be fine. https://github.com/pytorch/pytorch/pull/160133 is the last time replicate_with_fsdp.py and its replicate api were touched. In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include dp_shard dim = 1 only if dp_replicate > 1, and created device mesh that I pass down in apply_ddp. **Test Case** 1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh Expected output of this experiment should be something like: [rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training [rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config **[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]** [rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank0]:NCCL version 2.27.5+cuda12.6 [rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json [rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test [rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0) [rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory [rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters [rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model **[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model** [rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%) [rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json. 
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)

[ghstack-poisoned]
---
 torchtitan/distributed/parallel_dims.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index 8779435ff4..6843b23e3b 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -102,11 +102,7 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
             # helps the MoE layers do mixed precision training
             # dp_shard_in_ep is included even if it equals 1 when replicate > 1
             # to make device_mesh compatible with replicate function
-            if (
-                d > 1
-                or name == "dp_shard_mod_ep"
-                or (name == "dp_shard_in_ep" and self.dp_replicate > 1)
-            ):
+            if d > 1 or name == "dp_shard_mod_ep":
                 dims.append(d)
                 names.append(name)
 

From b9467848d414245d841800b4c14fd96c4a826f47 Mon Sep 17 00:00:00 2001
From: Anshul Sinha
Date: Tue, 16 Sep 2025 15:07:50 -0700
Subject: [PATCH 3/4] Update on "[torchtitan][replicate] experimenting new replicate integration with torchtitan"

**Summary:** During this experiment to integrate the new replicate function into torchtitan, I used https://github.com/pytorch/pytorch/pull/162021, which has not yet landed. However, since that PR is more about making replicate more efficient than about changing replicate's core code, https://github.com/pytorch/pytorch/pull/160135, which has landed, should be fine. https://github.com/pytorch/pytorch/pull/160133 is the last time replicate_with_fsdp.py and its replicate API were touched.

In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include the dp_shard dim = 1 only if dp_replicate > 1, and created a device mesh that I pass down to apply_ddp.

**Test Case**
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh Expected output of this experiment should be something like: [rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training [rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config **[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]** [rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank0]:NCCL version 2.27.5+cuda12.6 [rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json [rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test [rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0) [rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory [rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters [rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model **[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model** [rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%) [rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json. 
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)

[ghstack-poisoned]
---
 torchtitan/distributed/parallel_dims.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index 6843b23e3b..5e3f597385 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -100,8 +100,6 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
         ):
             # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping
             # helps the MoE layers do mixed precision training
-            # dp_shard_in_ep is included even if it equals 1 when replicate > 1
-            # to make device_mesh compatible with replicate function
             if d > 1 or name == "dp_shard_mod_ep":
                 dims.append(d)
                 names.append(name)

From 23421c2b6e2421fc6252b3d7c08d2b39b81142d6 Mon Sep 17 00:00:00 2001
From: Anshul Sinha
Date: Tue, 23 Sep 2025 11:22:24 -0700
Subject: [PATCH 4/4] Update on "[torchtitan][replicate] experimenting new replicate integration with torchtitan"

**Summary:** During this experiment to integrate the new replicate function into torchtitan, I used https://github.com/pytorch/pytorch/pull/162021, which has not yet landed. However, since that PR is more about making replicate more efficient than about changing replicate's core code, https://github.com/pytorch/pytorch/pull/160135, which has landed, should be fine. https://github.com/pytorch/pytorch/pull/160133 is the last time replicate_with_fsdp.py and its replicate API were touched.

In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include the dp_shard dim = 1 only if dp_replicate > 1, and created a device mesh that I pass down to apply_ddp.

**Test Case**
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh Expected output of this experiment should be something like: [rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training [rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config **[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]** [rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank0]:NCCL version 2.27.5+cuda12.6 [rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json [rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test [rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0) [rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory [rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters [rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model **[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model** [rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%) [rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json. 
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)

[ghstack-poisoned]
---
 torchtitan/models/llama3/infra/parallelize.py | 35 ++++++++++++++-----
 1 file changed, 27 insertions(+), 8 deletions(-)

diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 21df427444..851f7d542f 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -142,8 +142,11 @@ def parallelize_llama(
         apply_ddp(
             model,
             world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            cpu_offload=job_config.training.enable_cpu_offload,
         )
 
     return model
@@ -319,17 +322,33 @@ def apply_fsdp(
 def apply_ddp(
     model: nn.Module,
     dp_mesh: DeviceMesh,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
     enable_compile: bool,
     enable_compiled_autograd: bool,
+    cpu_offload: bool = False,
 ):
-    if enable_compile:
-        if enable_compiled_autograd:
-            torch._dynamo.config.optimize_ddp = (
-                "python_reducer_without_compiled_forward"
-            )
-        else:
-            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+    replicate_config = {"device_mesh": dp_mesh, "mp_policy": mp_policy}
+    if cpu_offload:
+        replicate_config["offload_policy"] = CPUOffloadPolicy()
 
-    replicate(model, device_mesh=dp_mesh)
+    if model.tok_embeddings is not None:
+        replicate(
+            model.tok_embeddings,
+            **replicate_config,
+        )
+    for layer_id, transformer_block in model.layers.items():
+        replicate(
+            transformer_block,
+            **replicate_config,
+        )
+
+    if model.norm is not None and model.output is not None:
+        replicate(
+            [model.norm, model.output],
+            **replicate_config,
+        )
+    replicate(model, **replicate_config)
 
     logger.info("Applied DDP to the model")
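
For context outside the patches, below is a minimal sketch of how the per-block wrapping in `apply_ddp` composes with the 2-D mesh. The helper name `apply_replicate`, the mesh construction via `init_device_mesh`, and the bf16/fp32 dtypes are illustrative assumptions; `replicate_with_fsdp` is the prototype API from the PyTorch PRs referenced in the summaries above, not a released interface.

```python
# Illustrative sketch only; assumes the prototype replicate_with_fsdp API and a
# torchtitan-style llama3 Transformer exposing tok_embeddings/layers/norm/output.
import torch
import torch.nn as nn
from torch.distributed._composable.replicate_with_fsdp import replicate  # prototype API
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy


def apply_replicate(model: nn.Module, dp_replicate: int) -> nn.Module:
    # The new replicate is a specialization of HSDP, so it expects a 2-D mesh;
    # dp_shard is kept at size 1 so the data-parallel dimension is pure replication.
    mesh = init_device_mesh(
        "cuda", (dp_replicate, 1), mesh_dim_names=("dp_replicate", "dp_shard")
    )
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    cfg = {"device_mesh": mesh, "mp_policy": mp_policy}

    # Wrap the embedding, each transformer block, the final norm + output, and
    # finally the root module, mirroring apply_ddp in the last patch above.
    replicate(model.tok_embeddings, **cfg)
    for block in model.layers.values():
        replicate(block, **cfg)
    replicate([model.norm, model.output], **cfg)
    replicate(model, **cfg)
    return model
```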