Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/lightning/fabric/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,13 @@ def _replace_dunder_methods(base_cls: type, store_explicit_arg: Optional[str] =
for patched_name in ("__setattr__", "__delattr__", "__init__"):
# Check that __old__{init,setattr,delattr} belongs to the class
# https://stackoverflow.com/a/5253424
if f"__old{patched_name}" in cls.__dict__:
setattr(cls, patched_name, getattr(cls, f"__old{patched_name}"))
delattr(cls, f"__old{patched_name}")
old_name = f"__old{patched_name}"
if old_name in cls.__dict__:
try:
setattr(cls, patched_name, getattr(cls, old_name))
delattr(cls, old_name)
except AttributeError:
pass


def _replace_value_in_saved_args(
Expand Down
29 changes: 29 additions & 0 deletions tests/tests_fabric/utilities/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,35 @@ def test_replace_dunder_methods_multiple_loaders_without_init():
assert before[cls] == cls.__init__


def test_replace_dunder_methods_cleanup_tolerates_concurrent_restore():
class ConcurrentCleanupMeta(type):
def __getattribute__(cls, name):
if (
name == "__old__delattr__"
and type.__getattribute__(cls, "_cleanup_started")
and not type.__getattribute__(cls, "_restore_complete")
):
original_method = type.__getattribute__(cls, name)
type.__setattr__(cls, "__delattr__", original_method)
type.__delattr__(cls, name)
type.__setattr__(cls, "_restore_complete", True)
raise AttributeError
return type.__getattribute__(cls, name)

class ConcurrentBatchSampler(BatchSampler, metaclass=ConcurrentCleanupMeta):
_cleanup_started = False
_restore_complete = False

pass

original_delattr = ConcurrentBatchSampler.__delattr__
with _replace_dunder_methods(ConcurrentBatchSampler):
ConcurrentBatchSampler._cleanup_started = True

assert ConcurrentBatchSampler.__delattr__ is original_delattr
assert "__old__delattr__" not in ConcurrentBatchSampler.__dict__


class MyBaseDataLoader(DataLoader):
pass

Expand Down
Loading