Register _ScaledPartial placement
#2313
Conversation
| "is_replicate": MemberType.USE_REAL, | ||
| "__eq__": MemberType.USE_REAL, | ||
| } | ||
| register_opaque_type( |
Please comment on:
- what this function is doing
- why we need it
- what allowed_members is
Thanks!
Thanks for the review. I have added a comment. However, PR pytorch/pytorch#171482 got reverted; once it relands, this PR should land as well.
| "__eq__": MemberType.USE_REAL, | ||
| } | ||
| register_opaque_type( | ||
| _ScaledPartial, |
@wwwjn We introduced _ScaledPartial to mimic FSDP2's set_gradient_divide_factor. Now that we handle gradient scaling ourselves, I feel we can deprecate this field. WDYT?
As long as simple FSDP doesn't rescale gradients, we should be good. Does _ScaledPartial mean the gradients on each rank are scaled (divided by the FSDP degree) while the information in the tensor is still partial / unreduced?
Yeah, good point. It does the scaling here https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/simple_fsdp/simple_fsdp.py#L204 by using Partial(avg) instead of Partial(sum).
@Aidyn-A could you help change it to Partial(sum) and deprecate _ScaledPartial? We would also need to compare the loss curve with FSDP2 to make sure the change is correct.
If it sounds too involved, we can do it ourselves.
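For illustration, here is a minimal standalone sketch (not the torchtitan code; the mesh and tensor names are made up, and it assumes a CUDA/NCCL setup launched via torchrun) of the difference being discussed: reducing a partial gradient with Partial("sum") leaves any 1/N scaling to the caller, while Partial("avg") folds the division into the reduction itself, which is roughly what _ScaledPartial emulates with its divide factor.

```python
# Hypothetical sketch, not torchtitan code. Run with:
#   torchrun --nproc-per-node=2 partial_reduce_demo.py
# Assumes CUDA GPUs and the NCCL backend (needed for the "avg" reduce op).
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(local_rank)
mesh = init_device_mesh("cuda", (world_size,))

# Each rank holds an unreduced ("partial") gradient.
local_grad = torch.ones(4, device="cuda")

# Partial("sum"): redistributing to Replicate() all-reduces with SUM, so every
# rank ends up with world_size; any 1/N scaling must be applied separately
# (the explicit division the thread refers to as the divide factor).
grad_sum = DTensor.from_local(local_grad, mesh, [Partial("sum")])
print(grad_sum.redistribute(mesh, [Replicate()]).to_local())  # all world_size

# Partial("avg"): the reduction divides by world_size itself, so every rank
# ends up with ones -- the behavior _ScaledPartial was emulating.
grad_avg = DTensor.from_local(local_grad, mesh, [Partial("avg")])
print(grad_avg.redistribute(mesh, [Replicate()]).to_local())  # all ones
```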
Sure, do you want me to add a warning that _ScaledPartial is deprecated, or remove _ScaledPartial immediately? In either case, what should I do with reduction_divide_factor? Where will that go?
No need to worry about a warning. For reduction_divide_factor, the definition and all usages should go away. Thanks!
A follow-up on #2313 (comment). This PR removes the `_ScaledPartial` placement in favor of the `Partial(reduce_op="sum")` placement. cc @tianyu-l, @wwwjn
In PR pytorch/pytorch#171482, `PlacementClassVariable` and `PlacementVariable` were removed from TorchDynamo, so an explicit registration for `_ScaledPartial` is now required; otherwise Dynamo tracing fails. This PR adds the registration for the `_ScaledPartial` placement to simple FSDP.
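For readers without the diff open, the registration this PR adds looks roughly like the sketch below, reconstructed from the hunks quoted in the conversation. It would live next to the `_ScaledPartial` definition in simple_fsdp.py; the import path for `register_opaque_type` / `MemberType`, the exact call signature, and the semantics of `MemberType.USE_REAL` are assumptions here, not verified, so defer to the actual diff.

```python
# Sketch reconstructed from this PR's diff hunks -- the import path below is an
# assumption (take the real one from the diff), and the diff may register more
# members than the two visible in the quoted hunks.
from torch._dynamo.decorators import MemberType, register_opaque_type  # assumed path

# Why this is needed: pytorch/pytorch#171482 removed PlacementClassVariable and
# PlacementVariable from TorchDynamo, so custom Placement subclasses such as
# _ScaledPartial are no longer handled specially during tracing; registering the
# class as an opaque type presumably lets Dynamo treat instances as opaque
# values instead of erroring out.
#
# allowed_members maps member names to how Dynamo should handle them;
# MemberType.USE_REAL presumably means "call the real (eager) member rather
# than tracing through it".
allowed_members = {
    "is_replicate": MemberType.USE_REAL,
    "__eq__": MemberType.USE_REAL,
}
register_opaque_type(
    _ScaledPartial,
    allowed_members=allowed_members,
)
```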