diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 87f4f91ca9..d56c5b04ae 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -8,6 +8,7 @@
 # training techniques (e.g. activation checkpointing and compile) to the Llama model.
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable.fsdp import FSDPModule
 from torch.distributed._composable.replicate import replicate
@@ -284,6 +285,38 @@ def disable_fsdp_gradient_division(model: nn.Module) -> None:
             module.set_gradient_divide_factor(1.0)
 
 
+def _ddp_sum_allreduce_hook(
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    DDP communication hook that performs all-reduce with SUM (no averaging).
+
+    Unlike the default DDP hook which divides by world size before all-reduce,
+    this hook performs a pure sum reduction. This is used when gradient scaling
+    is handled manually in the training loop (e.g., dividing by global token count).
+    """
+    return (
+        dist.all_reduce(bucket.buffer(), group=process_group, async_op=True)
+        .get_future()
+        .then(lambda fut: fut.value()[0])
+    )
+
+
+def disable_ddp_gradient_averaging(model: nn.Module, dp_mesh: DeviceMesh) -> None:
+    """
+    Disable DDP's automatic gradient averaging by registering a custom comm hook.
+
+    By default, DDP divides gradients by world size before all-reduce.
+    This function registers a custom hook that uses SUM reduction instead,
+    allowing manual gradient scaling in the training loop (e.g., by global token count).
+
+    Args:
+        model: The model wrapped with DDP (via replicate())
+        dp_mesh: The device mesh used for data parallelism
+    """
+    model.register_comm_hook(dp_mesh.get_group(), _ddp_sum_allreduce_hook)
+
+
 def apply_fsdp(
     model: nn.Module,
     dp_mesh: DeviceMesh,
@@ -371,4 +404,7 @@ def apply_ddp(
     # pyrefly: ignore [invalid-param-spec]
     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
 
+    # Disable DDP's automatic gradient averaging for all DDP modules
+    disable_ddp_gradient_averaging(model, dp_mesh)
+
     logger.info("Applied DDP to the model")
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 9378d742e3..f61b13db73 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -578,7 +578,24 @@ def train_step(
 
         # Process each microbatch: move to GPU, forward/backward, then free
         accumulated_losses = []
-        for input_dict, labels in microbatches:
+        num_microbatches = len(microbatches)
+
+        # Check if we're using DDP (not FSDP) and need to manage gradient sync.
+        # DDP syncs gradients on every backward, so we need to disable sync
+        # for intermediate microbatches to avoid redundant all-reduces.
+        using_ddp = (
+            parallel_dims.dp_replicate_enabled and not parallel_dims.fsdp_enabled
+        )
+
+        for microbatch_idx, (input_dict, labels) in enumerate(microbatches):
+            is_last_microbatch = microbatch_idx == num_microbatches - 1
+
+            # For DDP with gradient accumulation: disable gradient sync for
+            # all but the last microbatch to avoid redundant all-reduces.
+            if using_ddp and num_microbatches > 1:
+                for model in self.model_parts:
+                    model.set_requires_gradient_sync(is_last_microbatch)
+
             # Move tensors to GPU
             for k, v in input_dict.items():
                 if isinstance(v, torch.Tensor):
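For context, below is a minimal standalone sketch (not part of the patch) of the same two mechanisms: a SUM-only DDP comm hook plus skipping gradient sync on intermediate microbatches. It assumes the stock `torch.nn.parallel.DistributedDataParallel` wrapper and its `no_sync()` context manager rather than torchtitan's `replicate()`-based module and `set_requires_gradient_sync()`; the names `sum_allreduce_hook` and `global_sample_count` are illustrative only, and the gloo backend is used just to keep the sketch CPU-only. Launch with torchrun, e.g. `torchrun --nproc_per_node=2 sketch.py`.

```python
import contextlib

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def sum_allreduce_hook(process_group, bucket):
    # Pure SUM all-reduce over the bucket's flattened gradients: no division by
    # world size, so the training loop must normalize gradients itself.
    return (
        dist.all_reduce(bucket.buffer(), group=process_group, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )


def main():
    dist.init_process_group("gloo")  # CPU-only backend to keep the sketch simple
    model = DDP(nn.Linear(8, 8))
    model.register_comm_hook(dist.group.WORLD, sum_allreduce_hook)

    microbatches = [torch.randn(4, 8) for _ in range(4)]
    # Stand-in for the "global token count" the patch refers to.
    global_sample_count = sum(mb.shape[0] for mb in microbatches) * dist.get_world_size()

    for i, mb in enumerate(microbatches):
        is_last = i == len(microbatches) - 1
        # no_sync() plays the role of set_requires_gradient_sync(False): it skips
        # the all-reduce so intermediate backwards only accumulate local gradients.
        ctx = contextlib.nullcontext() if is_last else model.no_sync()
        with ctx:
            model(mb).sum().backward()

    # Gradients now hold an un-averaged SUM over ranks and microbatches;
    # scale them manually, mirroring the patch's training-loop normalization.
    for p in model.parameters():
        p.grad.div_(global_sample_count)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```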