Disable DDP averaging to avoid repeated gradient averaging #2323
Conversation
```python
    is handled manually in the training loop (e.g., dividing by global token count).
    """
    return (
        dist.all_reduce(bucket.buffer(), group=process_group, async_op=True)
```
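For context, the snippet above appears to be the body of the sum-only comm hook under review. A minimal self-contained sketch of such a hook is below; the name `allreduce_sum_hook` and the exact docstring are illustrative, not necessarily the PR's code. It follows the standard DDP comm-hook contract of returning a future that resolves to the reduced bucket tensor.

```python
import torch
import torch.distributed as dist


def allreduce_sum_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """All-reduce gradients with SUM instead of DDP's default averaging.

    The divide-by-DP-degree step is skipped because the loss scaling
    is handled manually in the training loop (e.g., dividing by global token count).
    """
    return (
        dist.all_reduce(bucket.buffer(), group=process_group, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )
```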
Thanks, have you tried using `_dist_reduce`?
Can you attach more test results, e.g.:
- Does HSDP + TP work?
- DDP only vs FSDP vs HSDP grad norm curve?
I have updated the test plan showing all the cases. Also, since DDP does not use DTensor within torchtitan and cannot be combined with TP or any other parallelism (this is restricted by design), I think using all_reduce directly is better here to keep things simple.
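For reference, with a plain DistributedDataParallel wrapper the hook above would be installed roughly as sketched below. This is illustrative only: torchtitan wraps the model with replicate(), whose module class differs, so the exact registration call there may not be identical.

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes `model` is an nn.Module and the default process group is initialized.
model = DDP(model)
model.register_comm_hook(state=dist.group.WORLD, hook=allreduce_sum_hook)
```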
One major update in the subsequent version of this pull request is the addition of no_sync context management in the training loop for DDP. I am not sure what other way there is to manage this, but as noted in the diff summary, it is necessary to ensure correct computations when performing gradient accumulation.
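A sketch of the gradient-accumulation pattern this describes is below, assuming the sum-only comm hook is installed: DDP's all-reduce is suppressed with no_sync() for every micro-batch except the last, so gradients are summed across DP ranks exactly once per optimizer step. Helper names such as `compute_loss` and the loop structure are placeholders, not torchtitan's actual code.

```python
import contextlib


def train_step(model, optimizer, microbatches, global_token_count):
    optimizer.zero_grad()
    num_microbatches = len(microbatches)
    for i, batch in enumerate(microbatches):
        is_last = i == num_microbatches - 1
        # no_sync() suppresses the all-reduce launched in backward; only the
        # final micro-batch triggers the (sum) all-reduce across DP ranks.
        ctx = contextlib.nullcontext() if is_last else model.no_sync()
        with ctx:
            # Loss is scaled by the global token count, so no further averaging
            # over DP ranks should happen in the all-reduce.
            loss = compute_loss(model, batch) / global_token_count
            loss.backward()
    optimizer.step()
```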
This has been done in PP: https://github.com/pytorch/pytorch/blob/main/torch/distributed/pipelining/stage.py#L639. Instead of adding the logic to TorchTitan, we should investigate why PP's logic doesn't work. My best guess is that we are using replicate, so the module type is not DistributedDataParallel but replicate (or some other module class). cc @H-Huang
tianyu-l left a comment
> This has been done in PP

Wait, why does it have anything to do with PP?
The proper fix is to support set_gradient_divide_factor in replicate(), no?
cc @anshul-si @wwwjn
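For comparison, a sketch of the FSDP2 knob referenced here is below: fully_shard-wrapped modules expose set_gradient_divide_factor, which is how the implicit division was disabled for FSDP in D91432940, while replicate() has no equivalent yet. The import path and setup are assumptions for illustration, not torchtitan's exact code.

```python
from torch.distributed.fsdp import fully_shard

# Assumes the device mesh / process group is already set up.
model = fully_shard(model)
model.set_gradient_divide_factor(1.0)  # factor 1.0 leaves the reduction as a plain sum
```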
@tianyu-l My comment was not clear.
Hi everyone, I just want to summarize the diff to help clear up confusion. The main source of change is the switch from local averaging over microbatches to global averaging over tokens in a previous PR. While the implicit default averaging over the individual DP ranks was disabled for FSDP, the DDP averaging in the all-reduce operation was not. As the DDP setup used here is the old primitive that does not use the DTensor backend (as far as I can tell, please correct me if wrong), no set_gradient_divide_factor option is available, so the all-reduce operation has been explicitly changed from averaging to sum. This change brings up another difficulty: the old DDP primitive performs the all-reduce comm operation on every backward pass. This was fine previously, since an average was performed every backward pass, but it is now incorrect because it causes double summation of gradients when using gradient accumulation. To counteract this, the all-reduce has been disabled for all backward passes other than the last one. This is not necessary under FSDP, as it is already implemented there. It is also completely independent of PP, as it is needed to obtain the correct gradients even when PP = 1.
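A toy numeric illustration of the double division described above (dp_degree and the gradient value are made up, not from the test plan):

```python
dp_degree = 8
per_rank_grad = 0.5  # gradient of (local_loss / global_token_count) on each DP rank

# With global-token scaling, the desired combined gradient is the plain SUM over ranks.
desired = per_rank_grad * dp_degree                       # 4.0
# DDP's default AVG all-reduce divides by the DP degree a second time.
ddp_default_avg = per_rank_grad * dp_degree / dp_degree   # 0.5

# The update direction is unchanged, so AdamW hides the error in the loss curve,
# but the reported grad norm is dp_degree times smaller than under FSDP.
assert ddp_default_avg == desired / dp_degree
```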
I misunderstood the code change. I had the impression that the code change was in forward_backward_step, but it is in train_step. So my comment was incorrect; this is independent of PP.
Yes, the ideal fix would be integrating set_gradient_divide_factor into replicate().
Summary:
From the change from averaging over microbatches to averaging over the global tokens, the averaging for FSDP was disabled in D91432940 but not for DDP. This adds an additional scaling to the gradients, dividing them twice by the number of DP ranks when using pure DDP. The error does not show up as a change in loss for short runs due to the scale-invariance property of the AdamW optimizer, but it can be seen clearly in the grad norm measurement and in the difference of that measurement from FSDP.
To fix this, as DDP does not have a set_gradient_divide_factor() property like FSDP, the solution employed is to install a comm hook that replaces the default all-reduce average operation with an all-reduce sum operation. This also requires controlling the syncing performed by DDP, as the all-reduce for DDP is launched every forward-backward pass, which would now cause gradients to be added multiple times during gradient accumulation since no averaging is performed. Thus, a no_sync context has been added to the train loop so that the all-reduce sum in DDP happens only in the final backward pass of a train step.
Differential Revision: D92301896