torchtitan/models/llama3/infra/parallelize.py: 36 additions & 0 deletions

@@ -8,6 +8,7 @@
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import FSDPModule
from torch.distributed._composable.replicate import replicate
@@ -284,6 +285,38 @@ def disable_fsdp_gradient_division(model: nn.Module) -> None:
module.set_gradient_divide_factor(1.0)


def _ddp_sum_allreduce_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    DDP communication hook that performs all-reduce with SUM (no averaging).

    Unlike the default DDP hook, which divides by world size before all-reduce,
    this hook performs a pure sum reduction. This is used when gradient scaling
    is handled manually in the training loop (e.g., dividing by global token count).
    """
    return (
        dist.all_reduce(bucket.buffer(), group=process_group, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )

Contributor:
Thanks, have you tried using

instead of all_reduce directly? The difference is that this function handles the case when the tensor is a DTensor (e.g., when TP is applied and the gradient bucket tensor might be sharded).

Can you attach more test results, e.g.:

  1. Does HSDP + TP work?
  2. DDP only vs FSDP vs HSDP grad norm curve?

Contributor Author (@Shagun-G, Feb 4, 2026):
I have updated the test plan showing all the cases. Also, since DDP does not use DTensor within torchtitan and, by design, cannot be combined with TP or any other parallelism, I think using all_reduce directly is better here to keep things simple.
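
For contrast with the SUM-only hook above, the following is a minimal sketch of what the default averaging behavior described in the docstring amounts to (illustrative only, not the actual PyTorch source; the name _averaging_allreduce_hook is made up here): divide the bucket by the data-parallel world size, then SUM, which yields a mean across ranks.

import torch
import torch.distributed as dist


def _averaging_allreduce_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    # Illustrative only: pre-divide by world size, then SUM, i.e. a mean across ranks.
    group = process_group if process_group is not None else dist.group.WORLD
    buf = bucket.buffer().div_(dist.get_world_size(group))
    return (
        dist.all_reduce(buf, group=group, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )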


def disable_ddp_gradient_averaging(model: nn.Module, dp_mesh: DeviceMesh) -> None:
    """
    Disable DDP's automatic gradient averaging by registering a custom comm hook.

    By default, DDP divides gradients by world size before all-reduce.
    This function registers a custom hook that uses SUM reduction instead,
    allowing manual gradient scaling in the training loop (e.g., by global token count).

    Args:
        model: The model wrapped with DDP (via replicate())
        dp_mesh: The device mesh used for data parallelism
    """
    model.register_comm_hook(dp_mesh.get_group(), _ddp_sum_allreduce_hook)
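
The docstring above defers the actual scaling to the training loop. A rough sketch of that pattern, with a made-up helper name and arguments that are not part of this PR: each rank all-reduces its local token count over the data-parallel group and divides its summed loss by the global count, so the SUM gradient all-reduce then completes a true per-token average.

import torch
import torch.distributed as dist


def scale_loss_by_global_tokens(
    local_loss_sum: torch.Tensor,  # sum of per-token losses on this rank
    local_num_tokens: int,
    dp_group: dist.ProcessGroup,
) -> torch.Tensor:
    # Every rank learns the global token count for this step (SUM is the default op).
    num_tokens = torch.tensor(
        local_num_tokens, dtype=torch.long, device=local_loss_sum.device
    )
    dist.all_reduce(num_tokens, group=dp_group)
    # backward() on this loss yields local_grad_sum / global_tokens on each rank;
    # the SUM comm hook then completes the global per-token average.
    return local_loss_sum / num_tokens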


def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,

@@ -371,4 +404,7 @@ def apply_ddp(
    # pyrefly: ignore [invalid-param-spec]
    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    # Disable DDP's automatic gradient averaging for all DDP modules
    disable_ddp_gradient_averaging(model, dp_mesh)

    logger.info("Applied DDP to the model")
torchtitan/train.py: 18 additions & 1 deletion

@@ -578,7 +578,24 @@ def train_step(

        # Process each microbatch: move to GPU, forward/backward, then free
        accumulated_losses = []
        for input_dict, labels in microbatches:
        num_microbatches = len(microbatches)

        # Check if we're using DDP (not FSDP) and need to manage gradient sync
        # DDP syncs gradients on every backward, so we need to disable sync
        # for intermediate microbatches to avoid redundant all-reduces
        using_ddp = (
            parallel_dims.dp_replicate_enabled and not parallel_dims.fsdp_enabled
        )

        for microbatch_idx, (input_dict, labels) in enumerate(microbatches):
            is_last_microbatch = microbatch_idx == num_microbatches - 1

            # For DDP with gradient accumulation: disable gradient sync for
            # all but the last microbatch to avoid redundant all-reduces
            if using_ddp and num_microbatches > 1:
                for model in self.model_parts:
                    model.set_requires_gradient_sync(is_last_microbatch)

            # Move tensors to GPU
            for k, v in input_dict.items():
                if isinstance(v, torch.Tensor):
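
The gradient-sync logic in the train.py hunk above calls set_requires_gradient_sync on torchtitan's model parts. For a model wrapped directly in torch.nn.parallel.DistributedDataParallel, the analogous accumulation pattern is usually written with DDP's no_sync() context manager; a minimal sketch follows (the helper name and arguments are illustrative, not from this PR).

import contextlib

from torch.nn.parallel import DistributedDataParallel as DDP


def accumulate_microbatches(model: DDP, microbatches, loss_fn) -> None:
    num_microbatches = len(microbatches)
    for idx, (inputs, labels) in enumerate(microbatches):
        is_last = idx == num_microbatches - 1
        # Skip the gradient all-reduce on every microbatch except the last one.
        ctx = contextlib.nullcontext() if is_last else model.no_sync()
        with ctx:
            loss_fn(model(inputs), labels).backward()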