diff --git a/spec_classes/utils/mutation.py b/spec_classes/utils/mutation.py index 3f315d0..904074c 100644 --- a/spec_classes/utils/mutation.py +++ b/spec_classes/utils/mutation.py @@ -3,7 +3,7 @@ import copyreg import functools import inspect -from contextlib import contextmanager +from threading import RLock from types import ModuleType from typing import Any, Callable, Dict, Optional, Set, Type, Union @@ -31,6 +31,10 @@ def protect_via_deepcopy(obj: Any, memo: Any = None) -> Any: - For base immutable types copying is not required to ensure object protection, and so such objects are returned as is. - Modules are not copyable, and so are also returned as is. + - During copying, we hack the `copyreg.dispatch_table` to allow for the + passthrough of modules. We revert this change afterwards. To prevent + race conditions with threads, we only revert after all all threads have + completed their copying. """ if isinstance(obj, (bool, int, float, str, bytes, type, ModuleType)): return obj @@ -38,14 +42,33 @@ def protect_via_deepcopy(obj: Any, memo: Any = None) -> Any: return copy.deepcopy(obj, memo) -@contextmanager -def _modules_copyable(): - module_reductor = copyreg.dispatch_table.get(ModuleType, MISSING) - if module_reductor is MISSING: - copyreg.dispatch_table[ModuleType] = lambda module: "passthrough" - yield - if module_reductor is MISSING: - del copyreg.dispatch_table[ModuleType] +class _modules_copyable: + """ + During copying we hack the `copyreg.dispatch_table` to allow modules to be + 'copied'. In multi-threading contexts this can lead to a race condition + where the hack is rolled back during a copying process. To prevent this, we + use a re-entrant lock that ensures that copying is not interrupted by thread + context switches. + """ + + def __init__(self): + self.lock = RLock() + self.refcount = 0 + self.patched_table = False + + def __enter__(self): + with self.lock: + self.refcount += 1 + module_reductor = copyreg.dispatch_table.get(ModuleType, MISSING) + if module_reductor is MISSING: + copyreg.dispatch_table[ModuleType] = lambda module: "passthrough" + self.patched_table = True + + def __exit__(self, *args): + with self.lock: + self.refcount -= 1 + if self.patched_table and self.refcount == 0: + del copyreg.dispatch_table[ModuleType] def mutate_attr(