torchtitan/experiments/simple_fsdp/simple_fsdp.py: 35 additions, 0 deletions
@@ -347,3 +347,38 @@ def data_parallel(
),
)
return model


def _maybe_register_placement_as_opaque():
    # `PlacementVariable` was removed from TorchDynamo in
    # https://github.com/pytorch/pytorch/pull/171482, so TorchDynamo now
    # requires an opaque-type registration for placements, including derived
    # classes such as `_ScaledPartial`. According to the notes in
    # https://github.com/pytorch/pytorch/blob/main/torch/_library/opaque_object.py,
    # opaque objects are the way TorchDynamo allows custom operators to accept
    # a user-defined "black box" object as an input. Users register their
    # custom classes via `register_opaque_type(MyClass, typ=..., members=...)`,
    # where `typ` is either "reference" or "value", and `members` is a dictionary
    # mapping member names (attributes, properties, or methods) to their
    # MemberType, which controls how they are handled during torch.compile
    # tracing.

    from torch._dynamo import variables

    # Older PyTorch builds still ship `PlacementVariable`, so the opaque-type
    # registration is only needed when it is absent.
    if not hasattr(variables, "PlacementVariable"):
        from torch._library.opaque_object import MemberType, register_opaque_type

        allowed_members = {
            "reduce_op": MemberType.USE_REAL,
            "is_shard": MemberType.USE_REAL,
            "is_partial": MemberType.USE_REAL,
            "is_replicate": MemberType.USE_REAL,
            "__eq__": MemberType.USE_REAL,
        }
        register_opaque_type(
            _ScaledPartial,
            typ="value",
            members=allowed_members,
        )


_maybe_register_placement_as_opaque()

Review thread anchored on the `_ScaledPartial,` argument above:

Contributor:
@wwwjn We introduced _ScaledPartial to mimic FSDP2's set_gradient_divide_factor. Now that we handle gradient scaling ourselves, I feel we can deprecate this field. WDYT?

Contributor:
As long as SimpleFSDP doesn't rescale gradients, we should be good. Does _ScaledPartial mean the gradients on each rank are scaled (divided by the FSDP degree) while the information in the tensor is still partial / unreduced?

tianyu-l (Contributor), Feb 6, 2026:
Yeah, good point. It does the scaling here https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/simple_fsdp/simple_fsdp.py#L204 by using Partial(avg) instead of P(sum).

@Aidyn-A could you help change it to P(sum) and deprecate _ScaledPartial? We also need to compare the loss curve with FSDP2 to make sure the change is correct (see the sketch after this thread for the avg-vs-sum distinction).

If it sounds too involved, we can do it.

Contributor Author:
Sure, do you want me to add a warning that _ScaledPartial is deprecated, or should I remove _ScaledPartial immediately? In any case, what should I do with reduction_divide_factor? Where will that go?

Contributor:
No need to worry about a warning. For reduction_divide_factor, the definition and usage should all go away. Thanks!
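
For context, a minimal standalone sketch of the avg-vs-sum distinction discussed in this thread. It is illustrative only, not the proposed torchtitan change: the function names, grad, and dp_degree are hypothetical, and it uses plain all_reduce calls rather than DTensor placements.

import torch
import torch.distributed as dist


def reduce_grad_avg(grad: torch.Tensor) -> torch.Tensor:
    # avg-style reduction: the divide happens inside the collective, roughly
    # what the Partial(avg) placement expresses.
    dist.all_reduce(grad, op=dist.ReduceOp.AVG)  # AVG needs a backend that supports it, e.g. NCCL
    return grad


def reduce_grad_sum(grad: torch.Tensor, dp_degree: int) -> torch.Tensor:
    # sum-style reduction: reduce with SUM and scale explicitly afterwards,
    # i.e. the "handle gradient scaling ourselves" direction that would let
    # _ScaledPartial and reduction_divide_factor go away.
    dist.all_reduce(grad, op=dist.ReduceOp.SUM)
    grad.div_(dp_degree)
    return grad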

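
As a side note, a minimal sketch of the registration API described in the comment inside _maybe_register_placement_as_opaque, assuming a PyTorch build that exposes register_opaque_type and MemberType exactly as the diff uses them. The StageTag class and its members are invented for illustration.

from torch._library.opaque_object import MemberType, register_opaque_type


class StageTag:
    # Invented example: a small "black box" object a custom operator might accept.
    def __init__(self, stage: int) -> None:
        self.stage = stage

    def is_first(self) -> bool:
        return self.stage == 0


# Mirrors the registration in the diff: typ is either "reference" or "value"
# (the diff uses "value"), and members maps attribute/method names to a
# MemberType that controls how they are handled during torch.compile tracing.
register_opaque_type(
    StageTag,
    typ="value",
    members={
        "stage": MemberType.USE_REAL,
        "is_first": MemberType.USE_REAL,
    },
)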