From cf4360df87ad9446ca85e5903141583f8b12edfa Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Mon, 2 Feb 2026 21:48:02 +0400 Subject: [PATCH 1/3] Register _ScaledPartial placement as opaque --- .../experiments/simple_fsdp/simple_fsdp.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index 6597c45f9d..20b50632fc 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -347,3 +347,22 @@ def data_parallel( ), ) return model + + +def _register_placement_as_opaque(): + from torch._library.opaque_object import MemberType, register_opaque_type + + allowed_members = { + "reduce_op": MemberType.USE_REAL, + "is_shard": MemberType.USE_REAL, + "is_partial": MemberType.USE_REAL, + "is_replicate": MemberType.USE_REAL, + "__eq__": MemberType.USE_REAL, + } + register_opaque_type( + _ScaledPartial, + typ="value", + members=allowed_members, + ) + +_register_placement_as_opaque() From 5928f10c59f2ea7456d5089657ef2d1aafd8a4fe Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Tue, 3 Feb 2026 12:51:21 +0400 Subject: [PATCH 2/3] add comment --- .../experiments/simple_fsdp/simple_fsdp.py | 47 ++++++++++++------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index 20b50632fc..d0d1ff2b3b 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -349,20 +349,35 @@ def data_parallel( return model -def _register_placement_as_opaque(): - from torch._library.opaque_object import MemberType, register_opaque_type - - allowed_members = { - "reduce_op": MemberType.USE_REAL, - "is_shard": MemberType.USE_REAL, - "is_partial": MemberType.USE_REAL, - "is_replicate": MemberType.USE_REAL, - "__eq__": MemberType.USE_REAL, - } - register_opaque_type( - _ScaledPartial, - typ="value", - members=allowed_members, - ) +def _maybe_register_placement_as_opaque(): + # The PyTorch PR https://github.com/pytorch/pytorch/pull/171482 + # has removed PlacementVariable, so now TorchDynamo requires an opaque type + # registration for placements and derived objects like `_ScaledPartial`. + # According to https://github.com/pytorch/pytorch/blob/main/torch/_library/opaque_object.py + # Opaque objects are the way TorchDynamo allows custom operators to accept + # a user-defined "black box" object as an input. Users can register a custom class as: + # register_opaque_type(MyClass, typ=..., members=...) + # where `typ` is either "reference" or "value", and `members` is a dictionary mapping + # member names (attributes, properties, or methods) to their MemberType, which controls + # how they are handled during torch.compile tracing + + from torch._dynamo import variables + + if not hasattr("PlacementVariable", variables): + from torch._library.opaque_object import MemberType, register_opaque_type + + allowed_members = { + "reduce_op": MemberType.USE_REAL, + "is_shard": MemberType.USE_REAL, + "is_partial": MemberType.USE_REAL, + "is_replicate": MemberType.USE_REAL, + "__eq__": MemberType.USE_REAL, + } + register_opaque_type( + _ScaledPartial, + typ="value", + members=allowed_members, + ) + -_register_placement_as_opaque() +_maybe_register_placement_as_opaque() From 9d6cca9dbdf060f0ab56bc7ee6728ef4f9983e3e Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Wed, 4 Feb 2026 20:53:03 +0400 Subject: [PATCH 3/3] Some touch-ups in the comment --- .../experiments/simple_fsdp/simple_fsdp.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index d0d1ff2b3b..de2196298e 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -350,16 +350,17 @@ def data_parallel( def _maybe_register_placement_as_opaque(): - # The PyTorch PR https://github.com/pytorch/pytorch/pull/171482 - # has removed PlacementVariable, so now TorchDynamo requires an opaque type - # registration for placements and derived objects like `_ScaledPartial`. - # According to https://github.com/pytorch/pytorch/blob/main/torch/_library/opaque_object.py + # The `PlacementVariable` was removed from TorchDynamo in + # https://github.com/pytorch/pytorch/pull/171482, so now TorchDynamo + # requires an opaque type, registration for placements. That includes + # derived classes like `_ScaledPartial`. According to the notes in: + # https://github.com/pytorch/pytorch/blob/main/torch/_library/opaque_object.py # Opaque objects are the way TorchDynamo allows custom operators to accept - # a user-defined "black box" object as an input. Users can register a custom class as: - # register_opaque_type(MyClass, typ=..., members=...) - # where `typ` is either "reference" or "value", and `members` is a dictionary mapping - # member names (attributes, properties, or methods) to their MemberType, which controls - # how they are handled during torch.compile tracing + # a user-defined "black box" object as an input. Users can register their + # custom classes as `register_opaque_type(MyClass, typ=..., members=...)`, + # where `typ` is either "reference" or "value", and `members` is a dictionary + # mapping member names (attributes, properties, or methods) to their MemberType, + # which controls how they are handled during torch.compile tracing. from torch._dynamo import variables