From 4501c045aacbcec97e789da59a9d6da6973d5c2e Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 15 Sep 2025 13:56:46 +0100 Subject: [PATCH] experiment with "ref" form of extra --- src/serializers/type_serializers/function.rs | 139 +++++++++++++------ 1 file changed, 100 insertions(+), 39 deletions(-) diff --git a/src/serializers/type_serializers/function.rs b/src/serializers/type_serializers/function.rs index 6504aee77..72458cc5e 100644 --- a/src/serializers/type_serializers/function.rs +++ b/src/serializers/type_serializers/function.rs @@ -1,5 +1,8 @@ use std::borrow::Cow; -use std::sync::Arc; +use std::f32::consts::E; +use std::marker::PhantomData; +use std::ptr::{self, NonNull}; +use std::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard}; use pyo3::exceptions::{PyAttributeError, PyRecursionError, PyRuntimeError}; use pyo3::gc::PyVisit; @@ -11,6 +14,7 @@ use pyo3::PyTraverseError; use pyo3::types::PyString; use crate::definitions::DefinitionsBuilder; +use crate::serializers::extra; use crate::tools::SchemaDict; use crate::tools::{function_name, py_err, py_error_type}; use crate::{PydanticOmit, PydanticSerializationUnexpectedValue}; @@ -393,7 +397,9 @@ impl FunctionWrapSerializer { ) -> PyResult<(bool, PyObject)> { let py = value.py(); if self.when_used.should_use(value, extra) { - let serialize = SerializationCallable::new(&self.serializer, include, exclude, extra); + let extra_ref_guard = ExtraRef::new(extra); + let serialize = + SerializationCallable::new(&self.serializer, include, exclude, extra_ref_guard.inner().clone()); let v = if self.is_field_serializer { if let Some(model) = extra.model { if self.info_arg { @@ -434,11 +440,56 @@ impl_py_gc_traverse!(FunctionWrapSerializer { function_type_serializer!(FunctionWrapSerializer); +/// A wrapper around `&Extra` which drops the lifetime, in order to be stored inside a Python object. +#[derive(Clone)] +struct ExtraRef { + value: Arc>>>, +} + +// Safety: `&Extra` is `Send + Sync` +unsafe impl Send for ExtraRef {} +unsafe impl Sync for ExtraRef {} + +impl ExtraRef { + fn new<'a>(extra: &'a Extra<'a>) -> ExtraRefGuard<'a> { + ExtraRefGuard( + ExtraRef { + value: Arc::new(RwLock::new(Some(ptr::from_ref(extra).cast()))), + }, + PhantomData, + ) + } + + fn map(&self, f: impl FnOnce(&Extra<'_>) -> R) -> Option { + // FIXME: deal with lock poisoning?, use try_read + let guard = self.value.read().unwrap(); + guard.as_ref().map(|ptr| { + // Safety: we ensure that the pointer is valid while `ExtraRef` is alive + let extra: &Extra = unsafe { &**ptr }; + f(extra) + }) + } +} + +struct ExtraRefGuard<'a>(ExtraRef, PhantomData<&'a Extra<'a>>); + +impl ExtraRefGuard<'_> { + fn inner(&self) -> &ExtraRef { + &self.0 + } +} + +impl Drop for ExtraRefGuard<'_> { + fn drop(&mut self) { + let mut guard = self.0.value.write().unwrap(); + *guard = None; + } +} + #[pyclass(module = "pydantic_core._pydantic_core")] -#[cfg_attr(debug_assertions, derive(Debug))] pub(crate) struct SerializationCallable { serializer: Arc, - extra_owned: ExtraOwned, + extra: ExtraRef, filter: AnyFilter, include: Option, exclude: Option, @@ -449,11 +500,11 @@ impl SerializationCallable { serializer: &Arc, include: Option<&Bound<'_, PyAny>>, exclude: Option<&Bound<'_, PyAny>>, - extra: &Extra, + extra: ExtraRef, ) -> Self { Self { serializer: serializer.clone(), - extra_owned: ExtraOwned::new(extra), + extra: extra, filter: AnyFilter::new(), include: include.map(|v| v.clone().unbind()), exclude: exclude.map(|v| v.clone().unbind()), @@ -467,24 +518,22 @@ impl SerializationCallable { if let Some(exclude) = &self.exclude { visit.call(exclude)?; } - if let Some(model) = &self.extra_owned.model { - visit.call(model)?; - } - if let Some(fallback) = &self.extra_owned.fallback { - visit.call(fallback)?; - } - if let Some(context) = &self.extra_owned.context { - visit.call(context)?; - } + self.extra + .map(|extra| { + // FIXME: not sound to get .read() of extra inside GC, probably need to make `Extra` not + // have the `'py` lifetime + visit.call(extra.model.map(Bound::as_unbound))?; + visit.call(extra.fallback.map(Bound::as_unbound))?; + visit.call(extra.context.map(Bound::as_unbound))?; + Ok(()) + }) + .transpose()?; Ok(()) } fn __clear__(&mut self) { self.include = None; self.exclude = None; - self.extra_owned.model = None; - self.extra_owned.fallback = None; - self.extra_owned.context = None; } } @@ -503,28 +552,40 @@ impl SerializationCallable { let include = self.include.as_ref().map(|o| o.bind(py)); let exclude = self.exclude.as_ref().map(|o| o.bind(py)); - let extra = self.extra_owned.to_extra(py); - if let Some(index_key) = index_key { - let filter = if let Ok(index) = index_key.extract::() { - self.filter.index_filter(index, include, exclude, None)? - } else { - self.filter.key_filter(index_key, include, exclude)? - }; - if let Some((next_include, next_exclude)) = filter { - let v = - self.serializer - .to_python_no_infer(value, next_include.as_ref(), next_exclude.as_ref(), &extra)?; - extra.warnings.final_check(py)?; - Ok(Some(v)) - } else { - Err(PydanticOmit::new_err()) - } - } else { - let v = self.serializer.to_python_no_infer(value, include, exclude, &extra)?; - extra.warnings.final_check(py)?; - Ok(Some(v)) - } + // FIXME: the &T is not sound here, since the guard is dropped at the end of this statement. + // Probably need to have a .map() method to avoid scope leak? + self.extra + .map(|extra| { + if let Some(index_key) = index_key { + let filter = if let Ok(index) = index_key.extract::() { + self.filter.index_filter(index, include, exclude, None)? + } else { + self.filter.key_filter(index_key, include, exclude)? + }; + if let Some((next_include, next_exclude)) = filter { + let v = self.serializer.to_python_no_infer( + value, + next_include.as_ref(), + next_exclude.as_ref(), + &extra, + )?; + extra.warnings.final_check(py)?; + Ok(Some(v)) + } else { + Err(PydanticOmit::new_err()) + } + } else { + let v = self.serializer.to_python_no_infer(value, include, exclude, &extra)?; + extra.warnings.final_check(py)?; + Ok(Some(v)) + } + }) + .unwrap_or_else(|| { + Err(PyRuntimeError::new_err( + "Attempted to use SerializationCallable after its wrap validation context was exited", + )) + }) } fn __repr__(&self) -> PyResult {