diff --git a/newsfragments/4544.changed.md b/newsfragments/4544.changed.md new file mode 100644 index 00000000000..c94758a770d --- /dev/null +++ b/newsfragments/4544.changed.md @@ -0,0 +1,2 @@ +* Refactored runtime borrow checking for mutable pyclass instances + to be thread-safe when the GIL is disabled. diff --git a/pytests/src/pyclasses.rs b/pytests/src/pyclasses.rs index e499b436395..4e52dbc8712 100644 --- a/pytests/src/pyclasses.rs +++ b/pytests/src/pyclasses.rs @@ -1,3 +1,5 @@ +use std::{thread, time}; + use pyo3::exceptions::{PyStopIteration, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyType; @@ -43,6 +45,29 @@ impl PyClassIter { } } +#[pyclass] +#[derive(Default)] +struct PyClassThreadIter { + count: usize, +} + +#[pymethods] +impl PyClassThreadIter { + #[new] + pub fn new() -> Self { + Default::default() + } + + fn __next__(&mut self, py: Python<'_>) -> usize { + let current_count = self.count; + self.count += 1; + if current_count == 0 { + py.allow_threads(|| thread::sleep(time::Duration::from_millis(100))); + } + self.count + } +} + /// Demonstrates a base class which can operate on the relevant subclass in its constructor. #[pyclass(subclass)] #[derive(Clone, Debug)] @@ -83,6 +108,7 @@ impl ClassWithDict { pub fn pyclasses(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; #[cfg(any(Py_3_10, not(Py_LIMITED_API)))] diff --git a/pytests/tests/test_pyclasses.py b/pytests/tests/test_pyclasses.py index a1424fc75aa..9f611b634b6 100644 --- a/pytests/tests/test_pyclasses.py +++ b/pytests/tests/test_pyclasses.py @@ -1,3 +1,4 @@ +import platform from typing import Type import pytest @@ -53,6 +54,27 @@ def test_iter(): assert excinfo.value.value == "Ended" +@pytest.mark.skipif( + platform.machine() in ["wasm32", "wasm64"], + reason="not supporting threads in CI for WASM yet", +) +def test_parallel_iter(): + import concurrent.futures + + i = pyclasses.PyClassThreadIter() + + def func(): + next(i) + + # the second thread attempts to borrow a reference to the instance's + # state while the first thread is still sleeping, so we trigger a + # runtime borrow-check error + with pytest.raises(RuntimeError, match="Already borrowed"): + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as tpe: + futures = [tpe.submit(func), tpe.submit(func)] + [f.result() for f in futures] + + class AssertingSubClass(pyclasses.AssertingBaseClass): pass diff --git a/src/pycell/impl_.rs b/src/pycell/impl_.rs index 1b5ce774379..1bd225de830 100644 --- a/src/pycell/impl_.rs +++ b/src/pycell/impl_.rs @@ -1,9 +1,10 @@ #![allow(missing_docs)] //! Crate-private implementation of PyClassObject -use std::cell::{Cell, UnsafeCell}; +use std::cell::UnsafeCell; use std::marker::PhantomData; use std::mem::ManuallyDrop; +use std::sync::atomic::{AtomicUsize, Ordering}; use crate::impl_::pyclass::{ PyClassBaseType, PyClassDict, PyClassImpl, PyClassThreadChecker, PyClassWeakRef, @@ -50,22 +51,49 @@ impl PyClassMutability for ExtendsMutableAncestor { type MutableChild = ExtendsMutableAncestor; } -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -struct BorrowFlag(usize); +#[derive(Debug)] +struct BorrowFlag(AtomicUsize); impl BorrowFlag { - pub(crate) const UNUSED: BorrowFlag = BorrowFlag(0); - const HAS_MUTABLE_BORROW: BorrowFlag = BorrowFlag(usize::MAX); - const fn increment(self) -> Self { - Self(self.0 + 1) + pub(crate) const UNUSED: usize = 0; + const HAS_MUTABLE_BORROW: usize = usize::MAX; + fn increment(&self) -> Result<(), PyBorrowError> { + let mut value = self.0.load(Ordering::Relaxed); + loop { + if value == BorrowFlag::HAS_MUTABLE_BORROW { + return Err(PyBorrowError { _private: () }); + } + match self.0.compare_exchange( + // only increment if the value hasn't changed since the + // last atomic load + value, + value + 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(..) => { + // value has been successfully incremented, we need an acquire fence + // so that data this borrow flag protects can be read safely in this thread + std::sync::atomic::fence(Ordering::Acquire); + break Ok(()); + } + Err(changed_value) => { + // value changed under us, need to try again + value = changed_value; + } + } + } } - const fn decrement(self) -> Self { - Self(self.0 - 1) + fn decrement(&self) { + // impossible to get into a bad state from here so relaxed + // ordering is fine, the decrement only needs to eventually + // be visible + self.0.fetch_sub(1, Ordering::Relaxed); } } pub struct EmptySlot(()); -pub struct BorrowChecker(Cell); +pub struct BorrowChecker(BorrowFlag); pub trait PyClassBorrowChecker { /// Initial value for self @@ -110,36 +138,38 @@ impl PyClassBorrowChecker for EmptySlot { impl PyClassBorrowChecker for BorrowChecker { #[inline] fn new() -> Self { - Self(Cell::new(BorrowFlag::UNUSED)) + Self(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED))) } fn try_borrow(&self) -> Result<(), PyBorrowError> { - let flag = self.0.get(); - if flag != BorrowFlag::HAS_MUTABLE_BORROW { - self.0.set(flag.increment()); - Ok(()) - } else { - Err(PyBorrowError { _private: () }) - } + self.0.increment() } fn release_borrow(&self) { - let flag = self.0.get(); - self.0.set(flag.decrement()) + self.0.decrement(); } fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> { - let flag = self.0.get(); - if flag == BorrowFlag::UNUSED { - self.0.set(BorrowFlag::HAS_MUTABLE_BORROW); - Ok(()) - } else { - Err(PyBorrowMutError { _private: () }) + let flag = &self.0; + match flag.0.compare_exchange( + // only allowed to transition to mutable borrow if the reference is + // currently unused + BorrowFlag::UNUSED, + BorrowFlag::HAS_MUTABLE_BORROW, + // On success, reading the flag and updating its state are an atomic + // operation + Ordering::AcqRel, + // It doesn't matter precisely when the failure gets turned + // into an error + Ordering::Relaxed, + ) { + Ok(..) => Ok(()), + Err(..) => Err(PyBorrowMutError { _private: () }), } } fn release_borrow_mut(&self) { - self.0.set(BorrowFlag::UNUSED) + self.0 .0.store(BorrowFlag::UNUSED, Ordering::Release) } } @@ -497,4 +527,89 @@ mod tests { assert!(mmm_bound.extract::>().is_ok()); }) } + + #[test] + #[cfg(not(target_arch = "wasm32"))] + fn test_thread_safety() { + #[crate::pyclass(crate = "crate")] + struct MyClass { + x: u64, + } + + Python::with_gil(|py| { + let inst = Py::new(py, MyClass { x: 0 }).unwrap(); + + let total_modifications = py.allow_threads(|| { + std::thread::scope(|s| { + // Spawn a bunch of threads all racing to write to + // the same instance of `MyClass`. + let threads = (0..10) + .map(|_| { + s.spawn(|| { + Python::with_gil(|py| { + // Each thread records its own view of how many writes it made + let mut local_modifications = 0; + for _ in 0..100 { + if let Ok(mut i) = inst.try_borrow_mut(py) { + i.x += 1; + local_modifications += 1; + } + } + local_modifications + }) + }) + }) + .collect::>(); + + // Sum up the total number of writes made by all threads + threads.into_iter().map(|t| t.join().unwrap()).sum::() + }) + }); + + // If the implementation is free of data races, the total number of writes + // should match the final value of `x`. + assert_eq!(total_modifications, inst.borrow(py).x); + }); + } + + #[test] + #[cfg(not(target_arch = "wasm32"))] + fn test_thread_safety_2() { + struct SyncUnsafeCell(UnsafeCell); + unsafe impl Sync for SyncUnsafeCell {} + + impl SyncUnsafeCell { + fn get(&self) -> *mut T { + self.0.get() + } + } + + let data = SyncUnsafeCell(UnsafeCell::new(0)); + let data2 = SyncUnsafeCell(UnsafeCell::new(0)); + let borrow_checker = BorrowChecker(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED))); + + std::thread::scope(|s| { + s.spawn(|| { + for _ in 0..1_000_000 { + if borrow_checker.try_borrow_mut().is_ok() { + // thread 1 writes to both values during the mutable borrow + unsafe { *data.get() += 1 }; + unsafe { *data2.get() += 1 }; + borrow_checker.release_borrow_mut(); + } + } + }); + + s.spawn(|| { + for _ in 0..1_000_000 { + if borrow_checker.try_borrow().is_ok() { + // if the borrow checker is working correctly, it should be impossible + // for thread 2 to observe a difference in the two values + assert_eq!(unsafe { *data.get() }, unsafe { *data2.get() }); + borrow_checker.release_borrow(); + } + } + }); + }); + } }