From f88aca9327c635278b35ceca675eb8a38e7050ea Mon Sep 17 00:00:00 2001 From: Furkan-rgb <50831308+Furkan-rgb@users.noreply.github.com> Date: Fri, 27 Mar 2026 13:17:10 +0000 Subject: [PATCH] Fix free-threaded cleanup race in dunder patching --- src/lightning/fabric/utilities/data.py | 10 +++++--- tests/tests_fabric/utilities/test_data.py | 29 +++++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py index ea35d8c3da4a9..ae1b28f29889f 100644 --- a/src/lightning/fabric/utilities/data.py +++ b/src/lightning/fabric/utilities/data.py @@ -382,9 +382,13 @@ def _replace_dunder_methods(base_cls: type, store_explicit_arg: Optional[str] = for patched_name in ("__setattr__", "__delattr__", "__init__"): # Check that __old__{init,setattr,delattr} belongs to the class # https://stackoverflow.com/a/5253424 - if f"__old{patched_name}" in cls.__dict__: - setattr(cls, patched_name, getattr(cls, f"__old{patched_name}")) - delattr(cls, f"__old{patched_name}") + old_name = f"__old{patched_name}" + if old_name in cls.__dict__: + try: + setattr(cls, patched_name, getattr(cls, old_name)) + delattr(cls, old_name) + except AttributeError: + pass def _replace_value_in_saved_args( diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py index 91b0a4e47b8b0..960fcb3b2b9cc 100644 --- a/tests/tests_fabric/utilities/test_data.py +++ b/tests/tests_fabric/utilities/test_data.py @@ -80,6 +80,35 @@ def test_replace_dunder_methods_multiple_loaders_without_init(): assert before[cls] == cls.__init__ +def test_replace_dunder_methods_cleanup_tolerates_concurrent_restore(): + class ConcurrentCleanupMeta(type): + def __getattribute__(cls, name): + if ( + name == "__old__delattr__" + and type.__getattribute__(cls, "_cleanup_started") + and not type.__getattribute__(cls, "_restore_complete") + ): + original_method = type.__getattribute__(cls, name) + type.__setattr__(cls, "__delattr__", original_method) + type.__delattr__(cls, name) + type.__setattr__(cls, "_restore_complete", True) + raise AttributeError + return type.__getattribute__(cls, name) + + class ConcurrentBatchSampler(BatchSampler, metaclass=ConcurrentCleanupMeta): + _cleanup_started = False + _restore_complete = False + + pass + + original_delattr = ConcurrentBatchSampler.__delattr__ + with _replace_dunder_methods(ConcurrentBatchSampler): + ConcurrentBatchSampler._cleanup_started = True + + assert ConcurrentBatchSampler.__delattr__ is original_delattr + assert "__old__delattr__" not in ConcurrentBatchSampler.__dict__ + + class MyBaseDataLoader(DataLoader): pass