Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use pyo3-asyncio to get a fresh tokio runtime #2

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ reqwest = { version = "*", features = ["native-tls-vendored"] }
version = "0.20"
features = ["extension-module", "abi3", "abi3-py38"]

[dependencies.pyo3-asyncio]
version = "0.20"
features = ["tokio-runtime"]

[dependencies.deltalake]
path = "../crates/deltalake"
version = "0"
Expand Down
79 changes: 39 additions & 40 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod error;
mod filesystem;
mod schema;
mod utils;
extern crate pyo3_asyncio;

use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
Expand Down Expand Up @@ -52,10 +53,15 @@ use crate::filesystem::FsConfig;
use crate::schema::schema_to_pyobject;

#[inline]
fn rt() -> PyResult<tokio::runtime::Runtime> {
fn rt_pyo3() -> PyResult<tokio::runtime::Runtime> {
tokio::runtime::Runtime::new().map_err(|err| PyRuntimeError::new_err(err.to_string()))
}

/// Returns a `'static` reference to the Tokio runtime handed out by
/// `pyo3_asyncio::tokio::get_runtime()`.
///
/// The return type is a plain reference rather than a `PyResult`, so call
/// sites use `rt().block_on(...)` directly, without a `?` on the runtime.
///
/// NOTE(review): presumably pyo3-asyncio creates this runtime once and shares
/// it across calls; confirm against the pyo3-asyncio documentation.
#[inline]
fn rt() -> &'static tokio::runtime::Runtime {
pyo3_asyncio::tokio::get_runtime()
}

#[derive(FromPyObject)]
enum PartitionFilterValue<'a> {
Single(&'a str),
Expand Down Expand Up @@ -113,7 +119,7 @@ impl RawDeltaTable {
.map_err(PythonError::from)?;
}

let table = rt()?.block_on(builder.load()).map_err(PythonError::from)?;
let table = rt().block_on(builder.load()).map_err(PythonError::from)?;
Ok(RawDeltaTable {
_table: table,
_config: FsConfig {
Expand All @@ -135,7 +141,7 @@ impl RawDeltaTable {
) -> PyResult<String> {
let data_catalog = deltalake::data_catalog::get_data_catalog(data_catalog, catalog_options)
.map_err(|e| PyValueError::new_err(format!("{}", e)))?;
let table_uri = rt()?
let table_uri = rt()
.block_on(data_catalog.get_table_storage_location(
data_catalog_id,
database_name,
Expand Down Expand Up @@ -174,13 +180,13 @@ impl RawDeltaTable {
}

pub fn load_version(&mut self, version: i64) -> PyResult<()> {
Ok(rt()?
Ok(rt()
.block_on(self._table.load_version(version))
.map_err(PythonError::from)?)
}

pub fn get_latest_version(&mut self) -> PyResult<i64> {
Ok(rt()?
Ok(rt()
.block_on(self._table.get_latest_version())
.map_err(PythonError::from)?)
}
Expand All @@ -190,7 +196,7 @@ impl RawDeltaTable {
DateTime::<Utc>::from(DateTime::<FixedOffset>::parse_from_rfc3339(ds).map_err(
|err| PyValueError::new_err(format!("Failed to parse datetime string: {err}")),
)?);
Ok(rt()?
Ok(rt()
.block_on(self._table.load_with_datetime(datetime))
.map_err(PythonError::from)?)
}
Expand Down Expand Up @@ -280,7 +286,7 @@ impl RawDeltaTable {
if let Some(retention_period) = retention_hours {
cmd = cmd.with_retention_period(Duration::hours(retention_period as i64));
}
let (table, metrics) = rt()?
let (table, metrics) = rt()
.block_on(cmd.into_future())
.map_err(PythonError::from)?;
self._table.state = table.state;
Expand Down Expand Up @@ -333,7 +339,7 @@ impl RawDeltaTable {
cmd = cmd.with_predicate(update_predicate);
}

let (table, metrics) = rt()?
let (table, metrics) = rt()
.block_on(cmd.into_future())
.map_err(PythonError::from)?;
self._table.state = table.state;
Expand Down Expand Up @@ -361,7 +367,7 @@ impl RawDeltaTable {
.map_err(PythonError::from)?;
cmd = cmd.with_filters(&converted_filters);

let (table, metrics) = rt()?
let (table, metrics) = rt()
.block_on(cmd.into_future())
.map_err(PythonError::from)?;
self._table.state = table.state;
Expand Down Expand Up @@ -394,7 +400,7 @@ impl RawDeltaTable {
.map_err(PythonError::from)?;
cmd = cmd.with_filters(&converted_filters);

let (table, metrics) = rt()?
let (table, metrics) = rt()
.block_on(cmd.into_future())
.map_err(PythonError::from)?;
self._table.state = table.state;
Expand Down Expand Up @@ -593,7 +599,7 @@ impl RawDeltaTable {
}
}

let (table, metrics) = rt()?
let (table, metrics) = rt()
.block_on(cmd.into_future())
.map_err(PythonError::from)?;
self._table.state = table.state;
Expand Down Expand Up @@ -624,7 +630,7 @@ impl RawDeltaTable {
}
cmd = cmd.with_ignore_missing_files(ignore_missing_files);
cmd = cmd.with_protocol_downgrade_allowed(protocol_downgrade_allowed);
let (table, metrics) = rt()?
let (table, metrics) = rt()
.block_on(cmd.into_future())
.map_err(PythonError::from)?;
self._table.state = table.state;
Expand All @@ -633,7 +639,7 @@ impl RawDeltaTable {

/// Run the History command on the Delta Table: Returns provenance information, including the operation, user, and so on, for each write to a table.
pub fn history(&mut self, limit: Option<usize>) -> PyResult<Vec<String>> {
let history = rt()?
let history = rt()
.block_on(self._table.history(limit))
.map_err(PythonError::from)?;
Ok(history
Expand All @@ -643,7 +649,7 @@ impl RawDeltaTable {
}

pub fn update_incremental(&mut self) -> PyResult<()> {
Ok(rt()?
Ok(rt()
.block_on(self._table.update_incremental(None))
.map_err(PythonError::from)?)
}
Expand Down Expand Up @@ -821,39 +827,36 @@ impl RawDeltaTable {
};
let store = self._table.log_store();

rt()?
.block_on(commit(
&*store,
&actions,
operation,
self._table.get_state(),
None,
))
.map_err(PythonError::from)?;
rt().block_on(commit(
&*store,
&actions,
operation,
self._table.get_state(),
None,
))
.map_err(PythonError::from)?;

Ok(())
}

pub fn get_py_storage_backend(&self) -> PyResult<filesystem::DeltaFileSystemHandler> {
Ok(filesystem::DeltaFileSystemHandler {
inner: self._table.object_store(),
rt: Arc::new(rt()?),
rt: Arc::new(rt_pyo3()?),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DeltaFileSystemHandler (https://github.com/delta-io/delta-rs/blob/main/python/src/filesystem.rs#L58) uses the runtime in utils.rs (https://github.com/delta-io/delta-rs/blob/main/python/src/utils.rs#L10-L13), which is not the pyo3-asyncio runtime. So I kept the original rt function but renamed it to rt_pyo3 to avoid errors.
Is there a better way to handle this?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, so utils.rs and lib.rs both have a runtime that they create? I don't actually know what the right approach is here without more knowledge of runtimes and tokio. Should we get help or learn it ourselves?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://sourcecode.vectra.io/projects/DP/repos/delta-rs/pull-requests/26/overview I think this provides the answer. This change seems worthy of pushing upstream, and would require a separate PR.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That answers the question of why we're changing lib.rs. However, my question is specifically about our having two runtime functions: should we also be using pyo3-asyncio's tokio-runtime in utils.rs?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's get this out and have the community decide. Frankly @tsh56 and I have higher-priority issues to resolve.

config: self._config.clone(),
known_sizes: None,
})
}

pub fn create_checkpoint(&self) -> PyResult<()> {
rt()?
.block_on(create_checkpoint(&self._table))
rt().block_on(create_checkpoint(&self._table))
.map_err(PythonError::from)?;

Ok(())
}

pub fn cleanup_metadata(&self) -> PyResult<()> {
rt()?
.block_on(cleanup_metadata(&self._table))
rt().block_on(cleanup_metadata(&self._table))
.map_err(PythonError::from)?;

Ok(())
Expand All @@ -875,7 +878,7 @@ impl RawDeltaTable {
if let Some(predicate) = predicate {
cmd = cmd.with_predicate(predicate);
}
let (table, metrics) = rt()?
let (table, metrics) = rt()
.block_on(cmd.into_future())
.map_err(PythonError::from)?;
self._table.state = table.state;
Expand All @@ -889,7 +892,7 @@ impl RawDeltaTable {
let cmd = FileSystemCheckBuilder::new(self._table.log_store(), self._table.state.clone())
.with_dry_run(dry_run);

let (table, metrics) = rt()?
let (table, metrics) = rt()
.block_on(cmd.into_future())
.map_err(PythonError::from)?;
self._table.state = table.state;
Expand Down Expand Up @@ -1076,7 +1079,7 @@ fn batch_distinct(batch: PyArrowType<RecordBatch>) -> PyResult<PyArrowType<Recor
let schema = batch.0.schema();
ctx.register_batch("batch", batch.0)
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
let batches = rt()?
let batches = rt()
.block_on(async { ctx.table("batch").await?.distinct()?.collect().await })
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;

Expand Down Expand Up @@ -1142,7 +1145,7 @@ fn write_to_deltalake(
let save_mode = mode.parse().map_err(PythonError::from)?;

let options = storage_options.clone().unwrap_or_default();
let table = rt()?
let table = rt()
.block_on(DeltaOps::try_from_uri_with_storage_options(
&table_uri, options,
))
Expand Down Expand Up @@ -1174,8 +1177,7 @@ fn write_to_deltalake(
builder = builder.with_configuration(config);
};

rt()?
.block_on(builder.into_future())
rt().block_on(builder.into_future())
.map_err(PythonError::from)?;

Ok(())
Expand Down Expand Up @@ -1219,8 +1221,7 @@ fn create_deltalake(
builder = builder.with_configuration(config);
};

rt()?
.block_on(builder.into_future())
rt().block_on(builder.into_future())
.map_err(PythonError::from)?;

Ok(())
Expand Down Expand Up @@ -1264,8 +1265,7 @@ fn write_new_deltalake(
builder = builder.with_configuration(config);
};

rt()?
.block_on(builder.into_future())
rt().block_on(builder.into_future())
.map_err(PythonError::from)?;

Ok(())
Expand Down Expand Up @@ -1317,8 +1317,7 @@ fn convert_to_deltalake(
builder = builder.with_metadata(json_metadata);
};

rt()?
.block_on(builder.into_future())
rt().block_on(builder.into_future())
.map_err(PythonError::from)?;
Ok(())
}
Expand Down