Skip to content

Commit 9c7b33c

Browse files
authored
Keep asyncio task ref when running callbacks (#441)
1 parent 9ce9551 commit 9c7b33c

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

src/asgi/callbacks.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use pyo3::prelude::*;
22
use pyo3::types::PyDict;
3-
use std::{net::SocketAddr, sync::Arc};
3+
use std::{
4+
net::SocketAddr,
5+
sync::{Arc, Mutex},
6+
};
47
use tokio::sync::oneshot;
58

69
use super::{
@@ -117,6 +120,7 @@ pub(crate) struct CallbackWrappedRunnerHTTP {
117120
cb: PyObject,
118121
#[pyo3(get)]
119122
scope: PyObject,
123+
pytaskref: Arc<Mutex<Option<PyObject>>>,
120124
}
121125

122126
impl CallbackWrappedRunnerHTTP {
@@ -126,6 +130,7 @@ impl CallbackWrappedRunnerHTTP {
126130
context: cb.context,
127131
cb: cb.callback.clone_ref(py),
128132
scope: scope.into_py(py),
133+
pytaskref: Arc::new(Mutex::new(None)),
129134
}
130135
}
131136

@@ -140,10 +145,12 @@ impl CallbackWrappedRunnerHTTP {
140145

141146
fn done(&self) {
142147
callback_impl_done_http!(self);
148+
self.pytaskref.lock().unwrap().take();
143149
}
144150

145151
fn err(&self, err: Bound<PyAny>) {
146152
callback_impl_done_err!(self, &PyErr::from_value_bound(err));
153+
self.pytaskref.lock().unwrap().take();
147154
}
148155
}
149156

@@ -238,6 +245,7 @@ pub(crate) struct CallbackWrappedRunnerWebsocket {
238245
cb: PyObject,
239246
#[pyo3(get)]
240247
scope: PyObject,
248+
pytaskref: Arc<Mutex<Option<PyObject>>>,
241249
}
242250

243251
impl CallbackWrappedRunnerWebsocket {
@@ -247,6 +255,7 @@ impl CallbackWrappedRunnerWebsocket {
247255
context: cb.context,
248256
cb: cb.callback.clone_ref(py),
249257
scope: scope.into_py(py),
258+
pytaskref: Arc::new(Mutex::new(None)),
250259
}
251260
}
252261

@@ -261,10 +270,12 @@ impl CallbackWrappedRunnerWebsocket {
261270

262271
fn done(&self) {
263272
callback_impl_done_ws!(self);
273+
self.pytaskref.lock().unwrap().take();
264274
}
265275

266276
fn err(&self, err: Bound<PyAny>) {
267277
callback_impl_done_err!(self, &PyErr::from_value_bound(err));
278+
self.pytaskref.lock().unwrap().take();
268279
}
269280
}
270281

src/callbacks.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,15 @@ macro_rules! callback_impl_run {
300300
macro_rules! callback_impl_run_pytask {
301301
() => {
302302
pub fn run(self, py: Python<'_>) -> PyResult<Bound<PyAny>> {
303+
let taskref = self.pytaskref.clone();
303304
let event_loop = self.context.event_loop(py);
304305
let context = self.context.context(py);
305306
let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?;
306307
let kwctx = pyo3::types::PyDict::new_bound(py);
308+
{
309+
let mut taskref_guard = taskref.lock().unwrap();
310+
*taskref_guard = Some(target.clone_ref(py));
311+
}
307312
kwctx.set_item(pyo3::intern!(py, "context"), context)?;
308313
event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(&kwctx))
309314
}

src/rsgi/callbacks.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use pyo3::prelude::*;
2+
use std::sync::{Arc, Mutex};
23
use tokio::sync::oneshot;
34

45
use super::{
@@ -114,6 +115,7 @@ pub(crate) struct CallbackWrappedRunnerHTTP {
114115
cb: PyObject,
115116
#[pyo3(get)]
116117
scope: PyObject,
118+
pytaskref: Arc<Mutex<Option<PyObject>>>,
117119
}
118120

119121
impl CallbackWrappedRunnerHTTP {
@@ -123,6 +125,7 @@ impl CallbackWrappedRunnerHTTP {
123125
context: cb.context,
124126
cb: cb.callback.clone_ref(py),
125127
scope: scope.into_py(py),
128+
pytaskref: Arc::new(Mutex::new(None)),
126129
}
127130
}
128131

@@ -137,10 +140,12 @@ impl CallbackWrappedRunnerHTTP {
137140

138141
fn done(&self) {
139142
callback_impl_done_http!(self);
143+
self.pytaskref.lock().unwrap().take();
140144
}
141145

142146
fn err(&self, err: Bound<PyAny>) {
143147
callback_impl_done_err!(self, &PyErr::from_value_bound(err));
148+
self.pytaskref.lock().unwrap().take();
144149
}
145150
}
146151

@@ -233,6 +238,7 @@ pub(crate) struct CallbackWrappedRunnerWebsocket {
233238
cb: PyObject,
234239
#[pyo3(get)]
235240
scope: PyObject,
241+
pytaskref: Arc<Mutex<Option<PyObject>>>,
236242
}
237243

238244
impl CallbackWrappedRunnerWebsocket {
@@ -242,6 +248,7 @@ impl CallbackWrappedRunnerWebsocket {
242248
context: cb.context,
243249
cb: cb.callback.clone_ref(py),
244250
scope: scope.into_py(py),
251+
pytaskref: Arc::new(Mutex::new(None)),
245252
}
246253
}
247254

@@ -256,10 +263,12 @@ impl CallbackWrappedRunnerWebsocket {
256263

257264
fn done(&self) {
258265
callback_impl_done_ws!(self);
266+
self.pytaskref.lock().unwrap().take();
259267
}
260268

261269
fn err(&self, err: Bound<PyAny>) {
262270
callback_impl_done_err!(self, &PyErr::from_value_bound(err));
271+
self.pytaskref.lock().unwrap().take();
263272
}
264273
}
265274

0 commit comments

Comments
 (0)