Skip to content

Commit

Permalink
Refactor Python interpreter state handling
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro committed Feb 1, 2025
1 parent 100f641 commit 987e40f
Show file tree
Hide file tree
Showing 15 changed files with 252 additions and 417 deletions.
1 change: 1 addition & 0 deletions granian/server/mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def _spawn_asgi_lifespan_worker(
loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio
)
serve(scheduler, loop, shutdown_event)
loop.run_until_complete(lifespan_handler.shutdown())

@staticmethod
def _spawn_rsgi_worker(
Expand Down
128 changes: 35 additions & 93 deletions src/asgi/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use super::{
use crate::{
callbacks::ArcCBScheduler,
http::{response_500, HTTPResponse},
runtime::RuntimeRef,
runtime::{Runtime, RuntimeRef},
utils::log_application_callable_exception,
ws::{HyperWebsocket, UpgradeData},
};
Expand All @@ -35,9 +35,9 @@ macro_rules! callback_impl_done_ws {
}

macro_rules! callback_impl_done_err {
($self:expr, $err:expr) => {
($self:expr, $py:expr, $err:expr) => {
$self.done();
log_application_callable_exception($err);
log_application_callable_exception($py, $err);
};
}

Expand Down Expand Up @@ -72,8 +72,8 @@ impl CallbackWatcherHTTP {
callback_impl_done_http!(self);
}

fn err(&self, err: Bound<PyAny>) {
callback_impl_done_err!(self, &PyErr::from_value(err));
fn err(&self, py: Python, err: Bound<PyAny>) {
callback_impl_done_err!(self, py, &PyErr::from_value(err));
}

fn taskref(&self, py: Python, task: PyObject) {
Expand Down Expand Up @@ -106,8 +106,8 @@ impl CallbackWatcherWebsocket {
callback_impl_done_ws!(self);
}

fn err(&self, err: Bound<PyAny>) {
callback_impl_done_err!(self, &PyErr::from_value(err));
fn err(&self, py: Python, err: Bound<PyAny>) {
callback_impl_done_err!(self, py, &PyErr::from_value(err));
}

fn taskref(&self, py: Python, task: PyObject) {
Expand Down Expand Up @@ -138,7 +138,6 @@ impl CallbackWatcherWebsocket {
// }
// }

#[cfg(not(Py_GIL_DISABLED))]
#[inline]
pub(crate) fn call_http(
cb: ArcCBScheduler,
Expand All @@ -149,12 +148,11 @@ pub(crate) fn call_http(
req: hyper::http::request::Parts,
body: hyper::body::Incoming,
) -> oneshot::Receiver<HTTPResponse> {
let brt = rt.innerb.clone();
let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, body, tx);
let protocol = HTTPProtocol::new(rt.clone(), body, tx);
let scheme: Arc<str> = scheme.into();

let _ = brt.run(move || {
rt.spawn_blocking(move |py| {
scope_native_parts!(
req,
server_addr,
Expand All @@ -165,45 +163,18 @@ pub(crate) fn call_http(
server,
client
);
Python::with_gil(|py| {
let scope = build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap();
let watcher = Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap();
cb.get().schedule(py, watcher.as_any());
});
});

rx
}

#[cfg(Py_GIL_DISABLED)]
#[inline]
pub(crate) fn call_http(
cb: ArcCBScheduler,
rt: RuntimeRef,
server_addr: SocketAddr,
client_addr: SocketAddr,
scheme: &str,
req: hyper::http::request::Parts,
body: hyper::body::Incoming,
) -> oneshot::Receiver<HTTPResponse> {
let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, body, tx);
let scheme: Arc<str> = scheme.into();

scope_native_parts!(
req,
server_addr,
client_addr,
path,
query_string,
version,
server,
client
);
Python::with_gil(|py| {
let scope = build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap();
let watcher = Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap();
cb.get().schedule(py, watcher.as_any());
cb.get().schedule(
py,
Py::new(
py,
CallbackWatcherHTTP::new(
py,
protocol,
build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap(),
),
)
.unwrap(),
);
});

rx
Expand All @@ -221,12 +192,11 @@ pub(crate) fn call_ws(
req: hyper::http::request::Parts,
upgrade: UpgradeData,
) -> oneshot::Receiver<WebsocketDetachedTransport> {
let brt = rt.innerb.clone();
let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);
let protocol = WebsocketProtocol::new(rt.clone(), tx, ws, upgrade);
let scheme: Arc<str> = scheme.into();

let _ = brt.run(move || {
rt.spawn_blocking(move |py| {
scope_native_parts!(
req,
server_addr,
Expand All @@ -237,46 +207,18 @@ pub(crate) fn call_ws(
server,
client
);
Python::with_gil(|py| {
let scope = build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap();
let watcher = Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap();
cb.get().schedule(py, watcher.as_any());
});
});

rx
}

#[cfg(Py_GIL_DISABLED)]
#[inline]
pub(crate) fn call_ws(
cb: ArcCBScheduler,
rt: RuntimeRef,
server_addr: SocketAddr,
client_addr: SocketAddr,
scheme: &str,
ws: HyperWebsocket,
req: hyper::http::request::Parts,
upgrade: UpgradeData,
) -> oneshot::Receiver<WebsocketDetachedTransport> {
let (tx, rx) = oneshot::channel();
let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade);
let scheme: Arc<str> = scheme.into();

scope_native_parts!(
req,
server_addr,
client_addr,
path,
query_string,
version,
server,
client
);
Python::with_gil(|py| {
let scope = build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap();
let watcher = Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap();
cb.get().schedule(py, watcher.as_any());
cb.get().schedule(
py,
Py::new(
py,
CallbackWatcherWebsocket::new(
py,
protocol,
build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap(),
),
)
.unwrap(),
);
});

rx
Expand Down
71 changes: 62 additions & 9 deletions src/blocking.rs
Original file line number Diff line number Diff line change
@@ -1,50 +1,103 @@
use crossbeam_channel as channel;
use pyo3::prelude::*;
use std::thread;

pub(crate) struct BlockingTask {
inner: Box<dyn FnOnce() + Send + 'static>,
inner: Box<dyn FnOnce(Python) + Send + 'static>,
}

impl BlockingTask {
pub fn new<T>(inner: T) -> BlockingTask
where
T: FnOnce() + Send + 'static,
T: FnOnce(Python) + Send + 'static,
{
Self { inner: Box::new(inner) }
}

pub fn run(self) {
(self.inner)();
pub fn run(self, py: Python) {
(self.inner)(py);
}
}

#[derive(Clone)]
pub(crate) struct BlockingRunner {
queue: channel::Sender<BlockingTask>,
#[cfg(Py_GIL_DISABLED)]
sig: channel::Sender<()>,
}

impl BlockingRunner {
#[cfg(not(Py_GIL_DISABLED))]
pub fn new() -> Self {
let queue = blocking_thread();
Self { queue }
}

#[cfg(Py_GIL_DISABLED)]
pub fn new() -> Self {
let (sigtx, sigrx) = channel::bounded(1);
let queue = blocking_thread(sigrx);
Self { queue, sig: sigtx }
}

pub fn run<T>(&self, task: T) -> Result<(), channel::SendError<BlockingTask>>
where
T: FnOnce() + Send + 'static,
T: FnOnce(Python) + Send + 'static,
{
self.queue.send(BlockingTask::new(task))
}

#[cfg(Py_GIL_DISABLED)]
pub fn shutdown(&self) {
_ = self.sig.send(());
}
}

fn bloking_loop(queue: channel::Receiver<BlockingTask>) {
#[cfg(not(Py_GIL_DISABLED))]
fn blocking_loop(queue: channel::Receiver<BlockingTask>) {
while let Ok(task) = queue.recv() {
task.run();
Python::with_gil(|py| task.run(py));
}
}

// NOTE: for some reason, on no-gil callback watchers are not GCd until following req.
// It's not clear atm wether this is an issue with pyo3, CPython itself, or smth
// different in terms of pointers due to multi-threaded environment.
// Thus, we need a signal to manually stop the loop and let the server shutdown.
// The following function would be the intended one if we hadn't the issue just described.
//
// #[cfg(Py_GIL_DISABLED)]
// fn blocking_loop(queue: channel::Receiver<BlockingTask>) {
// Python::with_gil(|py| {
// while let Ok(task) = queue.recv() {
// task.run(py);
// }
// });
// }
#[cfg(Py_GIL_DISABLED)]
fn blocking_loop(queue: channel::Receiver<BlockingTask>, sig: channel::Receiver<()>) {
Python::with_gil(|py| loop {
crossbeam_channel::select! {
recv(queue) -> task => match task {
Ok(task) => task.run(py),
_ => break,
},
recv(sig) -> _ => break
}
});
}

#[cfg(not(Py_GIL_DISABLED))]
fn blocking_thread() -> channel::Sender<BlockingTask> {
let (qtx, qrx) = channel::unbounded();
thread::spawn(|| bloking_loop(qrx));
thread::spawn(|| blocking_loop(qrx));

qtx
}

#[cfg(Py_GIL_DISABLED)]
fn blocking_thread(sig: channel::Receiver<()>) -> channel::Sender<BlockingTask> {
let (qtx, qrx) = channel::unbounded();
thread::spawn(|| blocking_loop(qrx, sig));

qtx
}
33 changes: 21 additions & 12 deletions src/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ pub(crate) struct CallbackScheduler {
#[cfg(not(PyPy))]
impl CallbackScheduler {
#[inline]
pub(crate) fn schedule(&self, _py: Python, watcher: &PyObject) {
pub(crate) fn schedule<T>(&self, py: Python, watcher: Py<T>) {
let cbarg = watcher.as_ptr();
let sched = self.schedule_fn.get().unwrap().as_ptr();

unsafe {
pyo3::ffi::PyObject_CallOneArg(sched, cbarg);
}

watcher.drop_ref(py);
}

#[inline]
Expand Down Expand Up @@ -130,13 +132,15 @@ impl CallbackScheduler {
#[cfg(PyPy)]
impl CallbackScheduler {
#[inline]
pub(crate) fn schedule(&self, py: Python, watcher: &PyObject) {
pub(crate) fn schedule(&self, py: Python, watcher: Py<T>) {
let cbarg = (watcher,).into_pyobject(py).unwrap().into_ptr();
let sched = self.schedule_fn.get().unwrap().as_ptr();

unsafe {
pyo3::ffi::PyObject_CallObject(sched, cbarg);
}

watcher.drop_ref(py);
}

#[inline]
Expand Down Expand Up @@ -508,8 +512,9 @@ impl PyIterAwaitable {
}

#[inline]
pub(crate) fn set_result(&self, py: Python, result: FutureResultToPy) {
let _ = self.result.set(result.into_pyobject(py).map(Bound::unbind));
pub(crate) fn set_result(pyself: Py<Self>, py: Python, result: FutureResultToPy) {
_ = pyself.get().result.set(result.into_pyobject(py).map(Bound::unbind));
pyself.drop_ref(py);
}
}

Expand Down Expand Up @@ -583,18 +588,22 @@ impl PyFutureAwaitable {
)
.is_err()
{
pyself.drop_ref(py);
return;
}

let ack = rself.ack.read().unwrap();
if let Some((cb, ctx)) = &*ack {
let _ = rself.event_loop.clone_ref(py).call_method(
py,
pyo3::intern!(py, "call_soon_threadsafe"),
(cb, pyself.clone_ref(py)),
Some(ctx.bind(py)),
);
{
let ack = rself.ack.read().unwrap();
if let Some((cb, ctx)) = &*ack {
_ = rself.event_loop.clone_ref(py).call_method(
py,
pyo3::intern!(py, "call_soon_threadsafe"),
(cb, pyself.clone_ref(py)),
Some(ctx.bind(py)),
);
}
}
pyself.drop_ref(py);
}
}

Expand Down
Loading

0 comments on commit 987e40f

Please sign in to comment.