Skip to content

Commit

Permalink
Make PyClassBorrowChecker thread safe (#4544)
Browse files Browse the repository at this point in the history
* use a mutex in PyClassBorrowChecker

* add a test that triggers a reference race

* make BorrowFlag wrap an AtomicUsize

* fix errors seen on CI

* add changelog entry

* use a compare-exchange loop in try_borrow

* move atomic increment implementation into increment method

* fix bug pointed out in code review

* make test use an atomic

* make test runnable on GIL-enabled build

* use less restrictive ordering, add comments

* fix ruff error

* relax ordering in mutable borrows as well

* Update impl_.rs

Co-authored-by: David Hewitt <mail@davidhewitt.dev>

* fix path

* use AcqRel for mutable borrow compare_exchange loop

* add test from david

* one more test

* disable thread safety tests on WASM

* skip python test on WASM as well

* fix format

* fixup skipif reason

---------

Co-authored-by: David Hewitt <mail@davidhewitt.dev>
  • Loading branch information
ngoldbaum and davidhewitt authored Oct 5, 2024
1 parent 71012db commit 8288fb9
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 27 deletions.
2 changes: 2 additions & 0 deletions newsfragments/4544.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
* Refactored runtime borrow checking for mutable pyclass instances
to be thread-safe when the GIL is disabled.
26 changes: 26 additions & 0 deletions pytests/src/pyclasses.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::{thread, time};

use pyo3::exceptions::{PyStopIteration, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyType;
Expand Down Expand Up @@ -43,6 +45,29 @@ impl PyClassIter {
}
}

#[pyclass]
#[derive(Default)]
struct PyClassThreadIter {
count: usize,
}

#[pymethods]
impl PyClassThreadIter {
#[new]
pub fn new() -> Self {
Default::default()
}

fn __next__(&mut self, py: Python<'_>) -> usize {
let current_count = self.count;
self.count += 1;
if current_count == 0 {
py.allow_threads(|| thread::sleep(time::Duration::from_millis(100)));
}
self.count
}
}

/// Demonstrates a base class which can operate on the relevant subclass in its constructor.
#[pyclass(subclass)]
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -83,6 +108,7 @@ impl ClassWithDict {
pub fn pyclasses(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<EmptyClass>()?;
m.add_class::<PyClassIter>()?;
m.add_class::<PyClassThreadIter>()?;
m.add_class::<AssertingBaseClass>()?;
m.add_class::<ClassWithoutConstructor>()?;
#[cfg(any(Py_3_10, not(Py_LIMITED_API)))]
Expand Down
22 changes: 22 additions & 0 deletions pytests/tests/test_pyclasses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import platform
from typing import Type

import pytest
Expand Down Expand Up @@ -53,6 +54,27 @@ def test_iter():
assert excinfo.value.value == "Ended"


@pytest.mark.skipif(
platform.machine() in ["wasm32", "wasm64"],
reason="not supporting threads in CI for WASM yet",
)
def test_parallel_iter():
import concurrent.futures

i = pyclasses.PyClassThreadIter()

def func():
next(i)

# the second thread attempts to borrow a reference to the instance's
# state while the first thread is still sleeping, so we trigger a
# runtime borrow-check error
with pytest.raises(RuntimeError, match="Already borrowed"):
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as tpe:
futures = [tpe.submit(func), tpe.submit(func)]
[f.result() for f in futures]


class AssertingSubClass(pyclasses.AssertingBaseClass):
pass

Expand Down
169 changes: 142 additions & 27 deletions src/pycell/impl_.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#![allow(missing_docs)]
//! Crate-private implementation of PyClassObject

use std::cell::{Cell, UnsafeCell};
use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::sync::atomic::{AtomicUsize, Ordering};

use crate::impl_::pyclass::{
PyClassBaseType, PyClassDict, PyClassImpl, PyClassThreadChecker, PyClassWeakRef,
Expand Down Expand Up @@ -50,22 +51,49 @@ impl<M: PyClassMutability> PyClassMutability for ExtendsMutableAncestor<M> {
type MutableChild = ExtendsMutableAncestor<MutableClass>;
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct BorrowFlag(usize);
#[derive(Debug)]
struct BorrowFlag(AtomicUsize);

impl BorrowFlag {
pub(crate) const UNUSED: BorrowFlag = BorrowFlag(0);
const HAS_MUTABLE_BORROW: BorrowFlag = BorrowFlag(usize::MAX);
const fn increment(self) -> Self {
Self(self.0 + 1)
pub(crate) const UNUSED: usize = 0;
const HAS_MUTABLE_BORROW: usize = usize::MAX;
fn increment(&self) -> Result<(), PyBorrowError> {
let mut value = self.0.load(Ordering::Relaxed);
loop {
if value == BorrowFlag::HAS_MUTABLE_BORROW {
return Err(PyBorrowError { _private: () });
}
match self.0.compare_exchange(
// only increment if the value hasn't changed since the
// last atomic load
value,
value + 1,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(..) => {
// value has been successfully incremented, we need an acquire fence
// so that data this borrow flag protects can be read safely in this thread
std::sync::atomic::fence(Ordering::Acquire);
break Ok(());
}
Err(changed_value) => {
// value changed under us, need to try again
value = changed_value;
}
}
}
}
const fn decrement(self) -> Self {
Self(self.0 - 1)
fn decrement(&self) {
// impossible to get into a bad state from here so relaxed
// ordering is fine, the decrement only needs to eventually
// be visible
self.0.fetch_sub(1, Ordering::Relaxed);
}
}

pub struct EmptySlot(());
pub struct BorrowChecker(Cell<BorrowFlag>);
pub struct BorrowChecker(BorrowFlag);

pub trait PyClassBorrowChecker {
/// Initial value for self
Expand Down Expand Up @@ -110,36 +138,38 @@ impl PyClassBorrowChecker for EmptySlot {
impl PyClassBorrowChecker for BorrowChecker {
#[inline]
fn new() -> Self {
Self(Cell::new(BorrowFlag::UNUSED))
Self(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)))
}

fn try_borrow(&self) -> Result<(), PyBorrowError> {
let flag = self.0.get();
if flag != BorrowFlag::HAS_MUTABLE_BORROW {
self.0.set(flag.increment());
Ok(())
} else {
Err(PyBorrowError { _private: () })
}
self.0.increment()
}

fn release_borrow(&self) {
let flag = self.0.get();
self.0.set(flag.decrement())
self.0.decrement();
}

fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
let flag = self.0.get();
if flag == BorrowFlag::UNUSED {
self.0.set(BorrowFlag::HAS_MUTABLE_BORROW);
Ok(())
} else {
Err(PyBorrowMutError { _private: () })
let flag = &self.0;
match flag.0.compare_exchange(
// only allowed to transition to mutable borrow if the reference is
// currently unused
BorrowFlag::UNUSED,
BorrowFlag::HAS_MUTABLE_BORROW,
// On success, reading the flag and updating its state are an atomic
// operation
Ordering::AcqRel,
// It doesn't matter precisely when the failure gets turned
// into an error
Ordering::Relaxed,
) {
Ok(..) => Ok(()),
Err(..) => Err(PyBorrowMutError { _private: () }),
}
}

fn release_borrow_mut(&self) {
self.0.set(BorrowFlag::UNUSED)
self.0 .0.store(BorrowFlag::UNUSED, Ordering::Release)
}
}

Expand Down Expand Up @@ -497,4 +527,89 @@ mod tests {
assert!(mmm_bound.extract::<PyRefMut<'_, MutableBase>>().is_ok());
})
}

#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_thread_safety() {
#[crate::pyclass(crate = "crate")]
struct MyClass {
x: u64,
}

Python::with_gil(|py| {
let inst = Py::new(py, MyClass { x: 0 }).unwrap();

let total_modifications = py.allow_threads(|| {
std::thread::scope(|s| {
// Spawn a bunch of threads all racing to write to
// the same instance of `MyClass`.
let threads = (0..10)
.map(|_| {
s.spawn(|| {
Python::with_gil(|py| {
// Each thread records its own view of how many writes it made
let mut local_modifications = 0;
for _ in 0..100 {
if let Ok(mut i) = inst.try_borrow_mut(py) {
i.x += 1;
local_modifications += 1;
}
}
local_modifications
})
})
})
.collect::<Vec<_>>();

// Sum up the total number of writes made by all threads
threads.into_iter().map(|t| t.join().unwrap()).sum::<u64>()
})
});

// If the implementation is free of data races, the total number of writes
// should match the final value of `x`.
assert_eq!(total_modifications, inst.borrow(py).x);
});
}

#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_thread_safety_2() {
struct SyncUnsafeCell<T>(UnsafeCell<T>);
unsafe impl<T> Sync for SyncUnsafeCell<T> {}

impl<T> SyncUnsafeCell<T> {
fn get(&self) -> *mut T {
self.0.get()
}
}

let data = SyncUnsafeCell(UnsafeCell::new(0));
let data2 = SyncUnsafeCell(UnsafeCell::new(0));
let borrow_checker = BorrowChecker(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)));

std::thread::scope(|s| {
s.spawn(|| {
for _ in 0..1_000_000 {
if borrow_checker.try_borrow_mut().is_ok() {
// thread 1 writes to both values during the mutable borrow
unsafe { *data.get() += 1 };
unsafe { *data2.get() += 1 };
borrow_checker.release_borrow_mut();
}
}
});

s.spawn(|| {
for _ in 0..1_000_000 {
if borrow_checker.try_borrow().is_ok() {
// if the borrow checker is working correctly, it should be impossible
// for thread 2 to observe a difference in the two values
assert_eq!(unsafe { *data.get() }, unsafe { *data2.get() });
borrow_checker.release_borrow();
}
}
});
});
}
}

0 comments on commit 8288fb9

Please sign in to comment.