
Conversation

@Shagun-G Shagun-G (Contributor) commented Feb 4, 2026

Summary:
With the change from averaging over microbatches to averaging over global tokens, the implicit gradient averaging was disabled for FSDP in D91432940 but not for DDP. As a result, pure DDP applies an extra scaling that divides the gradients by the number of DP ranks twice. The error does not show up as a change in loss for short runs, due to the scale-invariance property of the AdamW optimizer, but it is clearly visible in the grad norm measurement and in its divergence from the FSDP measurement.

To fix this, since DDP does not expose a set_gradient_divide_factor() method like FSDP, a comm hook is registered that replaces the default all-reduce average with an all-reduce sum.
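
For reference, a minimal sketch of such a hook (assuming the standard DDP comm-hook interface; the function name and the registration call are illustrative and may differ from the PR's actual code):

    import torch.distributed as dist

    def allreduce_sum_hook(process_group, bucket):
        # DDP's default hook averages across ranks; this one only sums, since the
        # division by the global token count is handled manually in the train loop.
        return (
            dist.all_reduce(bucket.buffer(), group=process_group, async_op=True)
            .get_future()
            .then(lambda fut: fut.value()[0])
        )

    # Registration (illustrative):
    # ddp_model.register_comm_hook(state=process_group, hook=allreduce_sum_hook)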

Differential Revision: D92301896

@meta-cla meta-cla bot added the CLA Signed label Feb 4, 2026
@meta-codesync meta-codesync bot commented Feb 4, 2026

@Shagun-G has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92301896.

is handled manually in the training loop (e.g., dividing by global token count).
"""
return (
dist.all_reduce(bucket.buffer(), group=process_group, async_op=True)
A Contributor commented on this snippet:

Thanks, have you tried using

instead of all_reduce directly? The difference is that this function handles the case where the tensor is a DTensor (e.g., when TP is applied and the gradient bucket tensor might be sharded).

Can you attach more test results, eg:

  1. Does HSDP + TP work?
  2. DDP only vs FSDP vs HSDP grad norm curve?

@Shagun-G Shagun-G (Contributor, Author) Feb 4, 2026

I have updated the test plan to show all the cases. Also, since DDP in torchtitan does not use the DTensor backend and, by design, cannot be combined with TP or any other parallelism, I think using all_reduce directly is better here to keep things simple.

Shagun-G pushed a commit to Shagun-G/torchtitan that referenced this pull request Feb 4, 2026

Summary:

With the change from averaging over microbatches to averaging over global tokens, the implicit gradient averaging was disabled for FSDP in D91432940 but not for DDP. As a result, pure DDP applies an extra scaling that divides the gradients by the number of DP ranks twice. The error does not show up as a change in loss for short runs, due to the scale-invariance property of the AdamW optimizer, but it is clearly visible in the grad norm measurement and in its divergence from the FSDP measurement.

To fix this, since DDP does not expose a set_gradient_divide_factor() method like FSDP, a comm hook is registered that replaces the default all-reduce average with an all-reduce sum.

This also requires controlling when DDP syncs: DDP launches its all-reduce on every forward-backward pass, which, now that no averaging is performed, sums the gradients an extra time across microbatches during gradient accumulation. A no_sync context has therefore been added to the train loop so that DDP performs its all-reduce sum only on the final backward pass of a train step.

Differential Revision: D92301896
@Shagun-G Shagun-G (Contributor, Author) commented Feb 4, 2026

One major update in the latest version of this pull request is the addition of no_sync context management in the training loop for DDP. I am not sure of another way to manage this, but as noted in the diff summary, it is necessary to ensure correct gradients when performing gradient accumulation.
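
A minimal sketch of the gating described here (assuming DistributedDataParallel's no_sync() context; the function and variable names are illustrative, not the PR's actual train loop):

    import contextlib

    def train_step(ddp_model, microbatches, optimizer):
        last = len(microbatches) - 1
        for i, batch in enumerate(microbatches):
            # Sync (all-reduce sum) only on the final backward pass; otherwise each
            # microbatch would trigger its own all-reduce of the running gradient sum.
            ctx = contextlib.nullcontext() if i == last else ddp_model.no_sync()
            with ctx:
                loss = ddp_model(batch).sum()  # placeholder loss for the sketch
                loss.backward()
        optimizer.step()
        optimizer.zero_grad()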

@fegin fegin (Contributor) left a comment

This has been done in PP, https://github.com/pytorch/pytorch/blob/main/torch/distributed/pipelining/stage.py#L639. Instead of adding the logic to TorchTitan, we should investigate why PP's logic doesn't work. My best guess is that we are using replicate, so the module type is not DistributedDataParallel but replicate (or some other module class). cc @H-Huang

@tianyu-l tianyu-l (Contributor) left a comment

@fegin

This has been done in PP

Wait why does it have anything to do with PP?

The proper fix is to support set_gradient_divide_factor in replicate(), no?

cc @anshul-si @wwwjn

@anshul-si

@fegin @tianyu-l to be clear, we still haven't merged the changes integrating the replicate_fsdp version with torchtitan. I think the old replicate is based on DDP, right?

@fegin fegin (Contributor) commented Feb 5, 2026

@tianyu-l My comment was not clear. The set_gradient_divide_factor is just one change, which this PR achieves with register_comm_hook. But the logic determining whether to set set_requires_gradient_sync to True or False should already be handled by PP, https://github.com/pytorch/pytorch/blob/main/torch/distributed/pipelining/stage.py#L639. That's what I meant.

@Shagun-G Shagun-G (Contributor, Author) commented Feb 5, 2026

Hi everyone, I just want to summarize the diff to help clear up the confusion. The root cause is the switch, in a previous PR, from local averaging over microbatches to global averaging over tokens. While the implicit default averaging over DP ranks was disabled for FSDP, the DDP averaging inside the all-reduce was not. Since the DDP setup here uses the old primitive without the DTensor backend (as far as I can tell, please correct me if wrong), set_gradient_divide_factor is not available, so the all-reduce operation has been explicitly changed from averaging to sum.

This change brings up another difficulty: the old DDP primitive performs the all-reduce on every backward pass. That was fine previously because an average was taken each time, but it is now incorrect since it double-counts gradients across microbatches when using gradient accumulation. To counteract this, the all-reduce is disabled for all backward passes except the last one. This is not needed for FSDP, where it is already handled. It is also completely independent of PP, since it is required to obtain correct gradients even when PP = 1.
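
For contrast, a rough sketch of the FSDP-side knob that pure DDP lacks (assuming FSDP2 fully_shard-wrapped model parts; model_parts is an illustrative name, and this is not necessarily how D91432940 disabled the averaging):

    # Assumed FSDP2 API: FSDPModule.set_gradient_divide_factor(factor).
    # A factor of 1.0 keeps the gradient reduction a pure sum, mirroring the
    # all-reduce-sum comm hook used for DDP in this PR.
    for part in model_parts:  # illustrative list of fully_shard-wrapped modules
        part.set_gradient_divide_factor(1.0)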

@tianyu-l tianyu-l added the bug label Feb 5, 2026
@fegin fegin (Contributor) commented Feb 5, 2026

I misunderstood the code change. I had the impression that it was in forward_backward_step, but it is in train_step. So my comment was incorrect; this is independent of PP.

@wwwjn wwwjn (Contributor) commented Feb 5, 2026

@fegin @tianyu-l to be clear, we still haven't merged the changes integrating the replicate_fsdp version with torchtitan. I think the old replicate is based on DDP, right?

Yes, the ideal fix would be integrating replicate_fsdp and then calling set_gradient_divide_factor() as well when pure data parallelism is applied. Can you remind me how much work needs to be done before we can merge the replicate_fsdp PR?

Labels: bug, CLA Signed, fb-exported, meta-exported