From df6c6bf2cb5dcef8d9b9438b2578d41ddcf2bdf8 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 16 Dec 2024 16:01:57 -0500 Subject: [PATCH 1/5] Support async iteration of RecordBatchStream --- Cargo.lock | 14 +++++++++ Cargo.toml | 3 +- python/datafusion/record_batch.py | 14 ++++++--- src/record_batch.rs | 51 +++++++++++++++++++++++++------ 4 files changed, 66 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d1f291be9..352771cdb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1303,6 +1303,7 @@ dependencies = [ "prost", "prost-types", "pyo3", + "pyo3-async-runtimes", "pyo3-build-config", "tokio", "url", @@ -2672,6 +2673,19 @@ dependencies = [ "unindent", ] +[[package]] +name = "pyo3-async-runtimes" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2529f0be73ffd2be0cc43c013a640796558aa12d7ca0aab5cc14f375b4733031" +dependencies = [ + "futures", + "once_cell", + "pin-project-lite", + "pyo3", + "tokio", +] + [[package]] name = "pyo3-build-config" version = "0.22.6" diff --git a/Cargo.toml b/Cargo.toml index 703fc5a26..d28844685 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ substrait = ["dep:datafusion-substrait"] [dependencies] tokio = { version = "1.41", features = ["macros", "rt", "rt-multi-thread", "sync"] } pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] } +pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"]} arrow = { version = "53", features = ["pyarrow"] } datafusion = { version = "43.0.0", features = ["pyarrow", "avro", "unicode_expressions"] } datafusion-substrait = { version = "43.0.0", optional = true } @@ -60,4 +61,4 @@ crate-type = ["cdylib", "rlib"] [profile.release] lto = true -codegen-units = 1 \ No newline at end of file +codegen-units = 1 diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index 44936f7d8..005a423f6 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -59,18 +59,22 @@ def __init__(self, record_batch_stream: df_internal.RecordBatchStream) -> None: def next(self) -> RecordBatch | None: """See :py:func:`__next__` for the iterator function.""" - try: - next_batch = next(self) - except StopIteration: - return None + return next(self) - return next_batch + async def __anext__(self) -> RecordBatch: + """Async iterator function.""" + next_batch = anext(self.rbs) + return RecordBatch(next_batch) def __next__(self) -> RecordBatch: """Iterator function.""" next_batch = next(self.rbs) return RecordBatch(next_batch) + def __aiter__(self) -> typing_extensions.Self: + """Async iterator function.""" + return self + def __iter__(self) -> typing_extensions.Self: """Iterator function.""" return self diff --git a/src/record_batch.rs b/src/record_batch.rs index 427807f22..eacdb5867 100644 --- a/src/record_batch.rs +++ b/src/record_batch.rs @@ -15,13 +15,17 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::utils::wait_for_future; use datafusion::arrow::pyarrow::ToPyArrow; use datafusion::arrow::record_batch::RecordBatch; use datafusion::physical_plan::SendableRecordBatchStream; use futures::StreamExt; +use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration}; use pyo3::prelude::*; use pyo3::{pyclass, pymethods, PyObject, PyResult, Python}; +use tokio::sync::Mutex; #[pyclass(name = "RecordBatch", module = "datafusion", subclass)] pub struct PyRecordBatch { @@ -43,31 +47,58 @@ impl From for PyRecordBatch { #[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)] pub struct PyRecordBatchStream { - stream: SendableRecordBatchStream, + stream: Arc>, } impl PyRecordBatchStream { pub fn new(stream: SendableRecordBatchStream) -> Self { - Self { stream } + Self { + stream: Arc::new(Mutex::new(stream)), + } } } #[pymethods] impl PyRecordBatchStream { - fn next(&mut self, py: Python) -> PyResult> { - let result = self.stream.next(); - match wait_for_future(py, result) { - None => Ok(None), - Some(Ok(b)) => Ok(Some(b.into())), - Some(Err(e)) => Err(e.into()), - } + fn next(&mut self, py: Python) -> PyResult { + let stream = self.stream.clone(); + wait_for_future(py, next_stream(stream, true)) } - fn __next__(&mut self, py: Python) -> PyResult> { + fn __next__(&mut self, py: Python) -> PyResult { self.next(py) } + fn __anext__<'py>(&'py self, py: Python<'py>) -> PyResult> { + let stream = self.stream.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false)) + } + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { slf } + + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } +} + +async fn next_stream( + stream: Arc>, + sync: bool, +) -> PyResult { + let mut stream = stream.lock().await; + match stream.next().await { + Some(Ok(batch)) => Ok(batch.into()), + Some(Err(e)) => Err(e.into()), + None => { + // Depending on whether the iteration is sync or not, we raise either a + // StopIteration or a StopAsyncIteration + if sync { + Err(PyStopIteration::new_err("stream exhausted")) + } else { + Err(PyStopAsyncIteration::new_err("stream exhausted")) + } + } + } } From 4862d9c18f0b9db180436637d5cefa5629264b94 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 16 Dec 2024 16:52:04 -0500 Subject: [PATCH 2/5] use __anext__ --- python/datafusion/record_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index 005a423f6..a2ad850c0 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -63,7 +63,7 @@ def next(self) -> RecordBatch | None: async def __anext__(self) -> RecordBatch: """Async iterator function.""" - next_batch = anext(self.rbs) + next_batch = self.rbs.__anext__() return RecordBatch(next_batch) def __next__(self) -> RecordBatch: From 0c159ce842fcdf38fa2977f75c19a2ef55ff259c Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 16 Dec 2024 16:52:16 -0500 Subject: [PATCH 3/5] use await --- python/datafusion/record_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index a2ad850c0..633e341ff 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -63,7 +63,7 @@ def next(self) -> RecordBatch | None: async def __anext__(self) -> RecordBatch: """Async iterator function.""" - next_batch = self.rbs.__anext__() + next_batch = await self.rbs.__anext__() return RecordBatch(next_batch) def __next__(self) -> RecordBatch: From e620b820570925b4a13ece7c035a09ea0f154cd2 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Tue, 17 Dec 2024 13:11:29 -0500 Subject: [PATCH 4/5] fix failing test --- python/tests/test_dataframe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index b82f95e35..e3bd1b2a5 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -761,8 +761,8 @@ def test_execution_plan(aggregate_df): batch = stream.next() assert batch is not None # there should be no more batches - batch = stream.next() - assert batch is None + with pytest.raises(StopIteration): + stream.next() def test_repartition(df): From 65864ce01af7193c55feeb23ce1c4fc16504327d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 8 Jan 2025 17:09:12 -0500 Subject: [PATCH 5/5] Since we are raising an error instead of returning a None, we can update the type hint. --- python/datafusion/record_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index 633e341ff..75e58998f 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -57,7 +57,7 @@ def __init__(self, record_batch_stream: df_internal.RecordBatchStream) -> None: """This constructor is typically not called by the end user.""" self.rbs = record_batch_stream - def next(self) -> RecordBatch | None: + def next(self) -> RecordBatch: """See :py:func:`__next__` for the iterator function.""" return next(self)