From 19ea7df0dea4e1a4abc69302e2519acf6d113503 Mon Sep 17 00:00:00 2001 From: Shagun Gupta Date: Wed, 4 Feb 2026 13:31:36 -0800 Subject: [PATCH] Disable DDP averaging to avoid repeated gradient averaging (#2323) Summary: From the change to averaging over microbatches to averaging over the global tokens, the averaging for FSDP was disabled in D91432940 but not for DDP. This adds an additional scaling to the gradients diving them twice by the number for DP ranks when using pure DDP. This error does not emerge in change in loss for short runs due to the scale invariance property of the AdamW optimizer but can be seen clearly in the grad norm measuerment and difference in the measurement from FSDP. To fix this, as DDP does not have a property set_gradient_divide_factor() as FSDP, the solution employed is to put in a comm hook that replaces the default all reduce average operation with an all reduce sum operation. This also requires controlling the syncing performed in DDP as the all reduce operation for DDP is launched every forward backward pass which causes additional addition of gradients during gradient accumulation as no averaging is now performed in DDP. Thus, a no sync context has been added to the train loop to perform an all reduce sum in DDP only in the final backward pass of a train step. Differential Revision: D92301896 --- torchtitan/models/llama3/infra/parallelize.py | 36 +++++++++++++++++++ torchtitan/train.py | 19 +++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 87f4f91ca9..d56c5b04ae 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -8,6 +8,7 @@ # training techniques (e.g. activation checkpointing and compile) to the Llama model. import torch +import torch.distributed as dist import torch.nn as nn from torch.distributed._composable.fsdp import FSDPModule from torch.distributed._composable.replicate import replicate @@ -284,6 +285,38 @@ def disable_fsdp_gradient_division(model: nn.Module) -> None: module.set_gradient_divide_factor(1.0) +def _ddp_sum_allreduce_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + """ + DDP communication hook that performs all-reduce with SUM (no averaging). + + Unlike the default DDP hook which divides by world size before all-reduce, + this hook performs a pure sum reduction. This is used when gradient scaling + is handled manually in the training loop (e.g., dividing by global token count). + """ + return ( + dist.all_reduce(bucket.buffer(), group=process_group, async_op=True) + .get_future() + .then(lambda fut: fut.value()[0]) + ) + + +def disable_ddp_gradient_averaging(model: nn.Module, dp_mesh: DeviceMesh) -> None: + """ + Disable DDP's automatic gradient averaging by registering a custom comm hook. + + By default, DDP divides gradients by world size before all-reduce. + This function registers a custom hook that uses SUM reduction instead, + allowing manual gradient scaling in the training loop (e.g., by global token count). + + Args: + model: The model wrapped with DDP (via replicate()) + dp_mesh: The device mesh used for data parallelism + """ + model.register_comm_hook(dp_mesh.get_group(), _ddp_sum_allreduce_hook) + + def apply_fsdp( model: nn.Module, dp_mesh: DeviceMesh, @@ -371,4 +404,7 @@ def apply_ddp( # pyrefly: ignore [invalid-param-spec] replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + # Disable DDP's automatic gradient averaging for all DDP modules + disable_ddp_gradient_averaging(model, dp_mesh) + logger.info("Applied DDP to the model") diff --git a/torchtitan/train.py b/torchtitan/train.py index 9378d742e3..f61b13db73 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -578,7 +578,24 @@ def train_step( # Process each microbatch: move to GPU, forward/backward, then free accumulated_losses = [] - for input_dict, labels in microbatches: + num_microbatches = len(microbatches) + + # Check if we're using DDP (not FSDP) and need to manage gradient sync + # DDP syncs gradients on every backward, so we need to disable sync + # for intermediate microbatches to avoid redundant all-reduces + using_ddp = ( + parallel_dims.dp_replicate_enabled and not parallel_dims.fsdp_enabled + ) + + for microbatch_idx, (input_dict, labels) in enumerate(microbatches): + is_last_microbatch = microbatch_idx == num_microbatches - 1 + + # For DDP with gradient accumulation: disable gradient sync for + # all but the last microbatch to avoid redundant all-reduces + if using_ddp and num_microbatches > 1: + for model in self.model_parts: + model.set_requires_gradient_sync(is_last_microbatch) + # Move tensors to GPU for k, v in input_dict.items(): if isinstance(v, torch.Tensor):