Skip to content

Commit

Permalink
fix: Prevent race conditions during multithreading during deepcopying.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop committed Nov 27, 2023
1 parent e5b3b20 commit 437545f
Showing 1 changed file with 32 additions and 9 deletions.
41 changes: 32 additions & 9 deletions spec_classes/utils/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copyreg
import functools
import inspect
from contextlib import contextmanager
from threading import RLock
from types import ModuleType
from typing import Any, Callable, Dict, Optional, Set, Type, Union

Expand Down Expand Up @@ -31,21 +31,44 @@ def protect_via_deepcopy(obj: Any, memo: Any = None) -> Any:
- For base immutable types copying is not required to ensure object
protection, and so such objects are returned as is.
- Modules are not copyable, and so are also returned as is.
- During copying, we hack the `copyreg.dispatch_table` to allow for the
passthrough of modules. We revert this change afterwards. To prevent
race conditions with threads, we only revert after all all threads have
completed their copying.
"""
if isinstance(obj, (bool, int, float, str, bytes, type, ModuleType)):
return obj
with _modules_copyable():
return copy.deepcopy(obj, memo)


@contextmanager
def _modules_copyable():
module_reductor = copyreg.dispatch_table.get(ModuleType, MISSING)
if module_reductor is MISSING:
copyreg.dispatch_table[ModuleType] = lambda module: "passthrough"
yield
if module_reductor is MISSING:
del copyreg.dispatch_table[ModuleType]
class _modules_copyable:
"""
During copying we hack the `copyreg.dispatch_table` to allow modules to be
'copied'. In multi-threading contexts this can lead to a race condition
where the hack is rolled back during a copying process. To prevent this, we
use a re-entrant lock that ensures that copying is not interrupted by thread
context switches.
"""

def __init__(self):
self.lock = RLock()
self.refcount = 0
self.patched_table = False

def __enter__(self):
with self.lock:
self.refcount += 1
module_reductor = copyreg.dispatch_table.get(ModuleType, MISSING)
if module_reductor is MISSING:
copyreg.dispatch_table[ModuleType] = lambda module: "passthrough"
self.patched_table = True

def __exit__(self, *args):
with self.lock:
self.refcount -= 1
if self.patched_table and self.refcount == 0:
del copyreg.dispatch_table[ModuleType]


def mutate_attr(
Expand Down

0 comments on commit 437545f

Please sign in to comment.