diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index 6597c45f9d..de2196298e 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -347,3 +347,38 @@ def data_parallel( ), ) return model + + +def _maybe_register_placement_as_opaque(): + # The `PlacementVariable` was removed from TorchDynamo in + # https://github.com/pytorch/pytorch/pull/171482, so now TorchDynamo + # requires an opaque type, registration for placements. That includes + # derived classes like `_ScaledPartial`. According to the notes in: + # https://github.com/pytorch/pytorch/blob/main/torch/_library/opaque_object.py + # Opaque objects are the way TorchDynamo allows custom operators to accept + # a user-defined "black box" object as an input. Users can register their + # custom classes as `register_opaque_type(MyClass, typ=..., members=...)`, + # where `typ` is either "reference" or "value", and `members` is a dictionary + # mapping member names (attributes, properties, or methods) to their MemberType, + # which controls how they are handled during torch.compile tracing. + + from torch._dynamo import variables + + if not hasattr("PlacementVariable", variables): + from torch._library.opaque_object import MemberType, register_opaque_type + + allowed_members = { + "reduce_op": MemberType.USE_REAL, + "is_shard": MemberType.USE_REAL, + "is_partial": MemberType.USE_REAL, + "is_replicate": MemberType.USE_REAL, + "__eq__": MemberType.USE_REAL, + } + register_opaque_type( + _ScaledPartial, + typ="value", + members=allowed_members, + ) + + +_maybe_register_placement_as_opaque()