diff --git a/Cargo.lock b/Cargo.lock index c708c516ab36..fb8198117c7c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2237,7 +2237,9 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-functions-aggregate-common", "datafusion-proto", + "datafusion-proto-common", "doc-comment", "futures", "log", diff --git a/Cargo.toml b/Cargo.toml index d26446c11167..fc4e4cbdff8d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -211,5 +211,5 @@ large_futures = "warn" used_underscore_binding = "warn" [workspace.lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] } +unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)", "cfg(tarpaulin_include)"] } unused_qualifications = "deny" diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 27a7e7ae3c06..d75a8ea7ba93 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -104,10 +104,8 @@ pub(crate) mod test_util { mod tests { use std::fmt::{self, Display, Formatter}; - use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use std::task::{Context, Poll}; use std::time::Duration; use crate::datasource::file_format::parquet::test_util::store_parquet; @@ -117,7 +115,7 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use arrow::array::RecordBatch; - use arrow_schema::{Schema, SchemaRef}; + use arrow_schema::Schema; use datafusion_catalog::Session; use datafusion_common::cast::{ as_binary_array, as_binary_view_array, as_boolean_array, as_float32_array, @@ -137,7 +135,7 @@ mod tests { }; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; - use datafusion_execution::{RecordBatchStream, TaskContext}; + use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; use 
datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{collect, ExecutionPlan}; @@ -150,7 +148,7 @@ mod tests { use async_trait::async_trait; use datafusion_datasource::file_groups::FileGroup; use futures::stream::BoxStream; - use futures::{Stream, StreamExt}; + use futures::StreamExt; use insta::assert_snapshot; use log::error; use object_store::local::LocalFileSystem; @@ -166,6 +164,8 @@ mod tests { use parquet::format::FileMetaData; use tokio::fs::File; + use crate::test_util::bounded_stream; + enum ForceViews { Yes, No, @@ -1646,42 +1646,110 @@ mod tests { Ok(()) } - /// Creates an bounded stream for testing purposes. - fn bounded_stream( - batch: RecordBatch, - limit: usize, - ) -> datafusion_execution::SendableRecordBatchStream { - Box::pin(BoundedStream { - count: 0, - limit, - batch, - }) - } + #[tokio::test] + async fn test_memory_reservation_column_parallel() -> Result<()> { + async fn test_memory_reservation(global: ParquetOptions) -> Result<()> { + let field_a = Field::new("a", DataType::Utf8, false); + let field_b = Field::new("b", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let object_store_url = ObjectStoreUrl::local_filesystem(); - struct BoundedStream { - limit: usize, - count: usize, - batch: RecordBatch, - } + let file_sink_config = FileSinkConfig { + original_url: String::default(), + object_store_url: object_store_url.clone(), + file_group: FileGroup::new(vec![PartitionedFile::new( + "/tmp".to_string(), + 1, + )]), + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![], + insert_op: InsertOp::Overwrite, + keep_partition_by_columns: false, + file_extension: "parquet".into(), + }; + let parquet_sink = Arc::new(ParquetSink::new( + file_sink_config, + TableParquetOptions { + key_value_metadata: std::collections::HashMap::from([ + ("my-data".to_string(), Some("stuff".to_string())), + 
("my-data-bool-key".to_string(), None), + ]), + global, + ..Default::default() + }, + )); + + // create data + let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); + let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); + let batch = + RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); - impl Stream for BoundedStream { - type Item = Result; + // create task context + let task_context = build_ctx(object_store_url.as_ref()); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no bytes are reserved yet" + ); - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - if self.count >= self.limit { - return Poll::Ready(None); + let mut write_task = FileSink::write_all( + parquet_sink.as_ref(), + Box::pin(RecordBatchStreamAdapter::new( + schema, + bounded_stream(batch, 1000), + )), + &task_context, + ); + + // incrementally poll and check for memory reservation + let mut reserved_bytes = 0; + while futures::poll!(&mut write_task).is_pending() { + reserved_bytes += task_context.memory_pool().reserved(); + tokio::time::sleep(Duration::from_micros(1)).await; } - self.count += 1; - Poll::Ready(Some(Ok(self.batch.clone()))) - } - } + assert!( + reserved_bytes > 0, + "should have bytes reserved during write" + ); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no leaking byte reservation" + ); - impl RecordBatchStream for BoundedStream { - fn schema(&self) -> SchemaRef { - self.batch.schema() + Ok(()) } + + let write_opts = ParquetOptions { + allow_single_file_parallelism: false, + ..Default::default() + }; + test_memory_reservation(write_opts) + .await + .expect("should track for non-parallel writes"); + + let row_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 10, + maximum_buffered_record_batches_per_stream: 1, + ..Default::default() + }; + test_memory_reservation(row_parallel_write_opts) + .await + 
.expect("should track for row-parallel writes"); + + let col_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 1, + maximum_buffered_record_batches_per_stream: 2, + ..Default::default() + }; + test_memory_reservation(col_parallel_write_opts) + .await + .expect("should track for column-parallel writes"); + + Ok(()) } } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index cc510bc81f1a..a1afec3d90fa 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -829,6 +829,7 @@ pub mod test; mod schema_equivalence; pub mod test_util; +// pub use test_util::bounded_stream; #[cfg(doctest)] doc_comment::doctest!("../../../README.md", readme_example_test); diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index d6865ca3d532..2f8e66a2bbfb 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -22,12 +22,14 @@ pub mod parquet; pub mod csv; +use futures::Stream; use std::any::Any; use std::collections::HashMap; use std::fs::File; use std::io::Write; use std::path::Path; use std::sync::Arc; +use std::task::{Context, Poll}; use crate::catalog::{TableProvider, TableProviderFactory}; use crate::dataframe::DataFrame; @@ -38,11 +40,13 @@ use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; use crate::prelude::{CsvReadOptions, SessionContext}; +use crate::execution::SendableRecordBatchStream; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_catalog::Session; use datafusion_common::TableReference; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; +use std::pin::Pin; use async_trait::async_trait; @@ -52,6 +56,8 @@ use tempfile::TempDir; pub use datafusion_common::test_util::parquet_test_data; pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; +use 
crate::execution::RecordBatchStream; + /// Scan an empty data source, mainly used in tests pub fn scan_empty( name: Option<&str>, @@ -234,3 +240,44 @@ pub fn register_unbounded_file_with_ordering( ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } + +/// Creates a bounded stream that emits the same record batch a specified number of times. +/// This is useful for testing purposes. +pub fn bounded_stream( + record_batch: RecordBatch, + limit: usize, +) -> SendableRecordBatchStream { + Box::pin(BoundedStream { + record_batch, + count: 0, + limit, + }) +} + +struct BoundedStream { + record_batch: RecordBatch, + count: usize, + limit: usize, +} + +impl Stream for BoundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.count >= self.limit { + Poll::Ready(None) + } else { + self.count += 1; + Poll::Ready(Some(Ok(self.record_batch.clone()))) + } + } +} + +impl RecordBatchStream for BoundedStream { + fn schema(&self) -> SchemaRef { + self.record_batch.schema() + } +} diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 5ff1c1d07216..9bcc1edff882 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -21,7 +21,7 @@ use arrow::array::{ArrayRef, BooleanArray}; use datafusion_common::{not_impl_err, Result}; /// Describes how many rows should be emitted during grouping. 
-#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum EmitTo { /// Emit all groups All, diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 5c80c1b04225..fce926fcc10f 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -43,7 +43,9 @@ arrow = { workspace = true, features = ["ffi"] } async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } +datafusion-functions-aggregate-common = { workspace = true } datafusion-proto = { workspace = true } +datafusion-proto-common = { workspace = true } futures = { workspace = true } log = { workspace = true } prost = { workspace = true } @@ -55,3 +57,4 @@ doc-comment = { workspace = true } [features] integration-tests = [] +tarpaulin_include = [] # Exists only to prevent warnings on stable and still have accurate coverage diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index a18e6df59bf1..45049f696a4f 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -21,7 +21,8 @@ use abi_stable::StableAbi; use arrow::{ array::{make_array, ArrayRef}, datatypes::{Schema, SchemaRef}, - ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + error::ArrowError, + ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; use log::error; @@ -31,30 +32,37 @@ use log::error; #[derive(Debug, StableAbi)] pub struct WrappedSchema(#[sabi(unsafe_opaque_field)] pub FFI_ArrowSchema); +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. 
+#[cfg(not(tarpaulin_include))] +fn catch_ffi_schema_error(e: ArrowError) -> FFI_ArrowSchema { + error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); + FFI_ArrowSchema::empty() +} + impl From for WrappedSchema { fn from(value: SchemaRef) -> Self { - let ffi_schema = match FFI_ArrowSchema::try_from(value.as_ref()) { - Ok(s) => s, - Err(e) => { - error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); - FFI_ArrowSchema::empty() - } - }; - - WrappedSchema(ffi_schema) + WrappedSchema( + FFI_ArrowSchema::try_from(value.as_ref()) + .unwrap_or_else(catch_ffi_schema_error), + ) } } +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. +#[cfg(not(tarpaulin_include))] +fn catch_df_schema_error(e: ArrowError) -> Schema { + error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {}", e); + Schema::empty() +} impl From for SchemaRef { fn from(value: WrappedSchema) -> Self { - let schema = match Schema::try_from(&value.0) { - Ok(s) => s, - Err(e) => { - error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {}", e); - Schema::empty() - } - }; - Arc::new(schema) + Arc::new(Schema::try_from(&value.0).unwrap_or_else(catch_df_schema_error)) } } @@ -71,7 +79,7 @@ pub struct WrappedArray { } impl TryFrom for ArrayRef { - type Error = arrow::error::ArrowError; + type Error = ArrowError; fn try_from(value: WrappedArray) -> Result { let data = unsafe { from_ffi(value.array, &value.schema.0)? 
}; @@ -79,3 +87,14 @@ impl TryFrom for ArrayRef { Ok(make_array(data)) } } + +impl TryFrom<&ArrayRef> for WrappedArray { + type Error = ArrowError; + + fn try_from(array: &ArrayRef) -> Result { + let (array, schema) = to_ffi(&array.to_data())?; + let schema = WrappedSchema(schema); + + Ok(WrappedArray { array, schema }) + } +} diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index 877129fc5bb1..4fb227abecee 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -34,6 +34,7 @@ pub mod schema_provider; pub mod session_config; pub mod table_provider; pub mod table_source; +pub mod udaf; pub mod udf; pub mod util; pub mod volatility; diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 3592c16b8fab..1b2b37708b54 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -300,7 +300,10 @@ impl From for EmissionType { #[cfg(test)] mod tests { - use datafusion::physical_plan::Partitioning; + use datafusion::{ + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::Partitioning, + }; use super::*; @@ -311,8 +314,13 @@ mod tests { Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); let original_props = PlanProperties::new( - EquivalenceProperties::new(schema), - Partitioning::UnknownPartitioning(3), + EquivalenceProperties::new(Arc::clone(&schema)).with_reorder( + LexOrdering::new(vec![PhysicalSortExpr { + expr: datafusion::physical_plan::expressions::col("a", &schema)?, + options: Default::default(), + }]), + ), + Partitioning::RoundRobinBatch(3), EmissionType::Incremental, Boundedness::Bounded, ); diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 939c4050028c..78d65a816fcc 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -196,3 +196,49 @@ impl Stream for FFI_RecordBatchStream { } } } + +#[cfg(test)] +mod tests { + 
use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::{ + common::record_batch, error::Result, execution::SendableRecordBatchStream, + test_util::bounded_stream, + }; + + use super::FFI_RecordBatchStream; + use futures::StreamExt; + + #[tokio::test] + async fn test_round_trip_record_batch_stream() -> Result<()> { + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 3]), + ("b", Float64, vec![Some(4.0), None, Some(5.0)]) + )?; + let original_rbs = bounded_stream(record_batch.clone(), 1); + + let ffi_rbs: FFI_RecordBatchStream = original_rbs.into(); + let mut ffi_rbs: SendableRecordBatchStream = Box::pin(ffi_rbs); + + let schema = ffi_rbs.schema(); + assert_eq!( + schema, + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true) + ])) + ); + + let batch = ffi_rbs.next().await; + assert!(batch.is_some()); + assert!(batch.as_ref().unwrap().is_ok()); + assert_eq!(batch.unwrap().unwrap(), record_batch); + + // There should only be one batch + let no_batch = ffi_rbs.next().await; + assert!(no_batch.is_none()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index c7a9816431e1..0590ff02b9a2 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -29,6 +29,8 @@ use catalog::create_catalog_provider; use crate::catalog_provider::FFI_CatalogProvider; +use crate::udaf::FFI_AggregateUDF; + use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; use arrow::array::RecordBatch; use async_provider::create_async_table_provider; @@ -37,7 +39,10 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_random_func}; +use udf_udaf_udwf::{ + create_ffi_abs_func, create_ffi_random_func, create_ffi_stddev_func, + create_ffi_sum_func, +}; mod async_provider; pub mod catalog; @@ -61,6 +66,11 @@ pub struct 
ForeignLibraryModule { /// Create a scalar UDF pub create_scalar_udf: extern "C" fn() -> FFI_ScalarUDF, + /// Create an aggregate UDAF using sum + pub create_sum_udaf: extern "C" fn() -> FFI_AggregateUDF, + + /// Create a grouping UDAF using stddev + pub create_stddev_udaf: extern "C" fn() -> FFI_AggregateUDF, pub create_nullary_udf: extern "C" fn() -> FFI_ScalarUDF, pub version: extern "C" fn() -> u64, @@ -108,6 +118,8 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { create_catalog: create_catalog_provider, create_table: construct_table_provider, create_scalar_udf: create_ffi_abs_func, + create_sum_udaf: create_ffi_sum_func, + create_stddev_udaf: create_ffi_stddev_func, create_nullary_udf: create_ffi_random_func, version: super::version, } diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index b40bec762bd7..1df846220b25 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -15,22 +15,31 @@ // specific language governing permissions and limitations // under the License. 
-use crate::udf::FFI_ScalarUDF; +use crate::{udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF}; use datafusion::{ functions::math::{abs::AbsFunc, random::RandomFunc}, - logical_expr::ScalarUDF, + functions_aggregate::{stddev::Stddev, sum::Sum}, + logical_expr::{AggregateUDF, ScalarUDF}, }; use std::sync::Arc; pub(crate) extern "C" fn create_ffi_abs_func() -> FFI_ScalarUDF { let udf: Arc = Arc::new(AbsFunc::new().into()); - udf.into() } +pub(crate) extern "C" fn create_ffi_sum_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Sum::new().into()); + udaf.into() +} + +pub(crate) extern "C" fn create_ffi_stddev_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Stddev::new().into()); + udaf.into() +} + pub(crate) extern "C" fn create_ffi_random_func() -> FFI_ScalarUDF { let udf: Arc = Arc::new(RandomFunc::new().into()); - udf.into() } diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs new file mode 100644 index 000000000000..897cd9f49cc3 --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -0,0 +1,361 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::ffi::c_void; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::{array::ArrayRef, error::ArrowError}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::Accumulator, + scalar::ScalarValue, +}; +use prost::Message; + +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; + +/// A stable struct for sharing [`Accumulator`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_Accumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValue as protobuf bytes + pub evaluate: unsafe extern "C" fn(accumulator: &Self) -> RResult, RString>, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: + unsafe extern "C" fn(accumulator: &Self) -> RResult>, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + states: RVec, + ) -> RResult<(), RString>, + + pub retract_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + pub supports_retract_batch: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignAccumulator`] should never attempt to access this data. 
+ pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_Accumulator {} +unsafe impl Sync for FFI_Accumulator {} + +pub struct AccumulatorPrivateData { + pub accumulator: Box, +} + +impl FFI_Accumulator { + #[inline] + unsafe fn inner(&self) -> &mut AccumulatorPrivateData { + let private_data = self.private_data as *mut AccumulatorPrivateData; + &mut (*private_data) + } +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let accum_data = accumulator.inner(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + rresult!(accum_data.accumulator.update_batch(&values_arrays)) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &FFI_Accumulator, +) -> RResult, RString> { + let accum_data = accumulator.inner(); + + let scalar_result = rresult_return!(accum_data.accumulator.evaluate()); + let proto_result: datafusion_proto::protobuf::ScalarValue = + rresult_return!((&scalar_result).try_into()); + + RResult::ROk(proto_result.encode_to_vec().into()) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { + // let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + // let accum_data = &mut (*private_data); + accumulator.inner().accumulator.size() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &FFI_Accumulator, +) -> RResult>, RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let state = rresult_return!(accum_data.accumulator.state()); + let state = state + .into_iter() + .map(|state_val| { + datafusion_proto::protobuf::ScalarValue::try_from(&state_val) + .map_err(DataFusionError::from) + .map(|v| RVec::from(v.encode_to_vec())) + }) + .collect::>>() + .map(|state_vec| state_vec.into()); + + rresult!(state) +} 
+ +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + states: RVec, +) -> RResult<(), RString> { + let accum_data = accumulator.inner(); + + let states = rresult_return!(states + .into_iter() + .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) + .collect::>>()); + + rresult!(accum_data.accumulator.merge_batch(&states)) +} + +unsafe extern "C" fn retract_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let accum_data = accumulator.inner(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + rresult!(accum_data.accumulator.retract_batch(&values_arrays)) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); + drop(private_data); +} + +impl From> for FFI_Accumulator { + fn from(accumulator: Box) -> Self { + let supports_retract_batch = accumulator.supports_retract_batch(); + let private_data = AccumulatorPrivateData { accumulator }; + + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + retract_batch: retract_batch_fn_wrapper, + supports_retract_batch, + + // clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } + } +} + +impl Drop for FFI_Accumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. 
All interaction with the UDF +/// must occur through the functions defined in FFI_Accumulator. +#[derive(Debug)] +pub struct ForeignAccumulator { + accumulator: FFI_Accumulator, +} + +unsafe impl Send for ForeignAccumulator {} +unsafe impl Sync for ForeignAccumulator {} + +impl From for ForeignAccumulator { + fn from(accumulator: FFI_Accumulator) -> Self { + Self { accumulator } + } +} + +impl Accumulator for ForeignAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn evaluate(&mut self) -> Result { + unsafe { + let scalar_bytes = + df_result!((self.accumulator.evaluate)(&self.accumulator))?; + + let proto_scalar = + datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn state(&mut self) -> Result> { + unsafe { + let state_protos = df_result!((self.accumulator.state)(&self.accumulator))?; + + state_protos + .into_iter() + .map(|proto_bytes| { + datafusion_proto::protobuf::ScalarValue::decode(proto_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e))) + .and_then(|proto_value| { + ScalarValue::try_from(&proto_value) + .map_err(DataFusionError::from) + }) + }) + .collect::>>() + } + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + unsafe { + let states = states + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + states.into() + )) + } + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + 
.map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.retract_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn supports_retract_batch(&self) -> bool { + self.accumulator.supports_retract_batch + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{make_array, Array}; + use datafusion::{ + common::create_array, error::Result, + functions_aggregate::average::AvgAccumulator, logical_expr::Accumulator, + scalar::ScalarValue, + }; + + use super::{FFI_Accumulator, ForeignAccumulator}; + + #[test] + fn test_foreign_avg_accumulator() -> Result<()> { + let original_accum = AvgAccumulator::default(); + let original_size = original_accum.size(); + let original_supports_retract = original_accum.supports_retract_batch(); + + let boxed_accum: Box = Box::new(original_accum); + let ffi_accum: FFI_Accumulator = boxed_accum.into(); + let mut foreign_accum: ForeignAccumulator = ffi_accum.into(); + + // Send in an array to average. There are 5 values and it should average to 30.0 + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + foreign_accum.update_batch(&[values])?; + + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + let state = foreign_accum.state()?; + assert_eq!(state.len(), 2); + assert_eq!(state[0], ScalarValue::UInt64(Some(5))); + assert_eq!(state[1], ScalarValue::Float64(Some(150.0))); + + // To verify merging batches works, create a second state to add in + // This should cause our average to go down to 25.0 + let second_states = vec![ + make_array(create_array!(UInt64, vec![1]).to_data()), + make_array(create_array!(Float64, vec![0.0]).to_data()), + ]; + + foreign_accum.merge_batch(&second_states)?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(25.0))); + + // If we remove a batch that is equivalent to the state we added + // we should go back to our original value of 30.0 + let values = create_array!(Float64, 
vec![0.0]); + foreign_accum.retract_batch(&[values])?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + assert_eq!(original_size, foreign_accum.size()); + assert_eq!( + original_supports_retract, + foreign_accum.supports_retract_batch() + ); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs new file mode 100644 index 000000000000..3a25d09c4a55 --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use abi_stable::{ + std_types::{RString, RVec}, + StableAbi, +}; +use arrow::{ + datatypes::{DataType, Schema}, + ffi::FFI_ArrowSchema, +}; +use datafusion::{ + error::DataFusionError, logical_expr::function::AccumulatorArgs, + physical_expr::LexOrdering, physical_plan::PhysicalExpr, prelude::SessionContext, +}; +use datafusion_proto::{ + physical_plan::{ + from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, + to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs}, + DefaultPhysicalExtensionCodec, + }, + protobuf::PhysicalAggregateExprNode, +}; +use prost::Message; + +use crate::arrow_wrappers::WrappedSchema; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AccumulatorArgs { + return_type: WrappedSchema, + schema: WrappedSchema, + is_reversed: bool, + name: RString, + physical_expr_def: RVec, +} + +impl TryFrom> for FFI_AccumulatorArgs { + type Error = DataFusionError; + + fn try_from(args: AccumulatorArgs) -> Result { + let return_type = WrappedSchema(FFI_ArrowSchema::try_from(args.return_type)?); + let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); + + let codec = DefaultPhysicalExtensionCodec {}; + let ordering_req = + serialize_physical_sort_exprs(args.ordering_req.to_owned(), &codec)?; + + let expr = serialize_physical_exprs(args.exprs, &codec)?; + + let physical_expr_def = PhysicalAggregateExprNode { + expr, + ordering_req, + distinct: args.is_distinct, + ignore_nulls: args.ignore_nulls, + fun_definition: None, + aggregate_function: None, + }; + let physical_expr_def = physical_expr_def.encode_to_vec().into(); + + Ok(Self { + return_type, + schema, + is_reversed: args.is_reversed, + name: args.name.into(), + physical_expr_def, + }) + } +} + +/// This struct mirrors AccumulatorArgs except that it contains owned data. 
+/// It is necessary to create this struct so that we can parse the protobuf +/// data across the FFI boundary and turn it into owned data that +/// AccumulatorArgs can then reference. +pub struct ForeignAccumulatorArgs { + pub return_type: DataType, + pub schema: Schema, + pub ignore_nulls: bool, + pub ordering_req: LexOrdering, + pub is_reversed: bool, + pub name: String, + pub is_distinct: bool, + pub exprs: Vec>, +} + +impl TryFrom for ForeignAccumulatorArgs { + type Error = DataFusionError; + + fn try_from(value: FFI_AccumulatorArgs) -> Result { + let proto_def = + PhysicalAggregateExprNode::decode(value.physical_expr_def.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + + let return_type = (&value.return_type.0).try_into()?; + let schema = Schema::try_from(&value.schema.0)?; + + let default_ctx = SessionContext::new(); + let codex = DefaultPhysicalExtensionCodec {}; + + // let proto_ordering_req = + // rresult_return!(PhysicalSortExprNodeCollection::decode(ordering_req.as_ref())); + let ordering_req = parse_physical_sort_exprs( + &proto_def.ordering_req, + &default_ctx, + &schema, + &codex, + )?; + + let exprs = parse_physical_exprs(&proto_def.expr, &default_ctx, &schema, &codex)?; + + Ok(Self { + return_type, + schema, + ignore_nulls: proto_def.ignore_nulls, + ordering_req, + is_reversed: value.is_reversed, + name: value.name.to_string(), + is_distinct: proto_def.distinct, + exprs, + }) + } +} + +impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { + fn from(value: &'a ForeignAccumulatorArgs) -> Self { + Self { + return_type: &value.return_type, + schema: &value.schema, + ignore_nulls: value.ignore_nulls, + ordering_req: &value.ordering_req, + is_reversed: value.is_reversed, + name: value.name.as_str(), + is_distinct: value.is_distinct, + exprs: &value.exprs, + } + } +} + +#[cfg(test)] +mod tests { + use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; + use arrow::datatypes::{DataType, Field, Schema}; + use 
datafusion::{ + error::Result, + logical_expr::function::AccumulatorArgs, + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::expressions::col, + }; + + #[test] + fn test_round_trip_accumulator_args() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let orig_args = AccumulatorArgs { + return_type: &DataType::Float64, + schema: &schema, + ignore_nulls: false, + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + let orig_str = format!("{:?}", orig_args); + + let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?; + let foreign_args: ForeignAccumulatorArgs = ffi_args.try_into()?; + let round_trip_args: AccumulatorArgs = (&foreign_args).into(); + + let round_trip_str = format!("{:?}", round_trip_args); + + // Since AccumulatorArgs doesn't implement Eq, simply compare + // the debug strings. + assert_eq!(orig_str, round_trip_str); + println!("{}", round_trip_str); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs new file mode 100644 index 000000000000..3f6b5def4f9b --- /dev/null +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -0,0 +1,542 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ffi::c_void; + +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::{ + array::{Array, ArrayRef, BooleanArray}, + error::ArrowError, + ffi::to_ffi, +}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::{EmitTo, GroupsAccumulator}, +}; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, +}; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_GroupsAccumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + // Evaluate and return the per-group results as an FFI-wrapped Arrow array + pub evaluate: unsafe extern "C" fn( + accumulator: &Self, + emit_to: FFI_EmitTo, + ) -> RResult, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: unsafe extern "C" fn( + accumulator: &Self, + emit_to: FFI_EmitTo, + ) -> RResult, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + pub convert_to_state: unsafe extern "C" fn( + accumulator: &Self, + values: RVec, + opt_filter: ROption, + ) + -> RResult, RString>, + + pub supports_convert_to_state: bool, + + /// Release the memory of the private data when it is no longer being used.
+ pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignGroupsAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_GroupsAccumulator {} +unsafe impl Sync for FFI_GroupsAccumulator {} + +pub struct GroupsAccumulatorPrivateData { + pub accumulator: Box, +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + let group_indices: Vec = group_indices.into_iter().collect(); + + let maybe_filter = opt_filter.into_option().and_then(|filter| { + match ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. 
{}", e); + None + } + } + }).map(|arr| arr.into_data()); + let opt_filter = maybe_filter.map(BooleanArray::from); + + rresult!(accum_data.accumulator.update_batch( + &values_arrays, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let result = rresult_return!(accum_data.accumulator.evaluate(emit_to.into())); + + rresult!(WrappedArray::try_from(&result)) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + accum_data.accumulator.size() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult, RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let state = rresult_return!(accum_data.accumulator.state(emit_to.into())); + rresult!(state + .into_iter() + .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + let group_indices: Vec = group_indices.into_iter().collect(); + + let maybe_filter = opt_filter.into_option().and_then(|filter| { + match 
ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. {}", e); + None + } + } + }).map(|arr| arr.into_data()); + let opt_filter = maybe_filter.map(BooleanArray::from); + + rresult!(accum_data.accumulator.merge_batch( + &values_arrays, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn convert_to_state_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + values: RVec, + opt_filter: ROption, +) -> RResult, RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let values = rresult_return!(values + .into_iter() + .map(|v| ArrayRef::try_from(v).map_err(DataFusionError::from)) + .collect::>>()); + + let opt_filter = opt_filter.into_option().and_then(|filter| { + match ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. 
{}", e); + None + } + } + }).map(|arr| arr.into_data()).map(BooleanArray::from); + + let state = rresult_return!(accum_data + .accumulator + .convert_to_state(&values, opt_filter.as_ref())); + + rresult!(state + .iter() + .map(|arr| WrappedArray::try_from(arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData); + drop(private_data); +} + +// unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> FFI_GroupsAccumulator { +// let private_data = accumulator.private_data as *const GroupsAccumulatorPrivateData; +// let accum_data = &(*private_data); + +// Box::new(accum_data.accumulator).into() +// } + +// impl Clone for FFI_GroupsAccumulator { +// fn clone(&self) -> Self { +// unsafe { (self.clone)(self) } +// } +// } + +impl From> for FFI_GroupsAccumulator { + fn from(accumulator: Box) -> Self { + let supports_convert_to_state = accumulator.supports_convert_to_state(); + let private_data = GroupsAccumulatorPrivateData { accumulator }; + + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + convert_to_state: convert_to_state_fn_wrapper, + supports_convert_to_state, + + release: release_fn_wrapper, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } + } +} + +impl Drop for FFI_GroupsAccumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignGroupsAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_GroupsAccumulator. 
+#[derive(Debug)] +pub struct ForeignGroupsAccumulator { + accumulator: FFI_GroupsAccumulator, +} + +unsafe impl Send for ForeignGroupsAccumulator {} +unsafe impl Sync for ForeignGroupsAccumulator {} + +impl From for ForeignGroupsAccumulator { + fn from(accumulator: FFI_GroupsAccumulator) -> Self { + Self { accumulator } + } +} + +impl GroupsAccumulator for ForeignGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + unsafe { + let return_array = df_result!((self.accumulator.evaluate)( + &self.accumulator, + emit_to.into() + ))?; + + return_array.try_into().map_err(DataFusionError::from) + } + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + unsafe { + let returned_arrays = + df_result!((self.accumulator.state)(&self.accumulator, emit_to.into()))?; + + returned_arrays + .into_iter() + .map(|wrapped_array| { + wrapped_array.try_into().map_err(DataFusionError::from) + }) + .collect::>>() + } + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let 
group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + let returned_array = df_result!((self.accumulator.convert_to_state)( + &self.accumulator, + values, + opt_filter + ))?; + + returned_array + .into_iter() + .map(|arr| arr.try_into().map_err(DataFusionError::from)) + .collect() + } + } + + fn supports_convert_to_state(&self) -> bool { + self.accumulator.supports_convert_to_state + } +} + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_EmitTo { + All, + First(usize), +} + +impl From for FFI_EmitTo { + fn from(value: EmitTo) -> Self { + match value { + EmitTo::All => Self::All, + EmitTo::First(v) => Self::First(v), + } + } +} + +impl From for EmitTo { + fn from(value: FFI_EmitTo) -> Self { + match value { + FFI_EmitTo::All => Self::All, + FFI_EmitTo::First(v) => Self::First(v), + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{make_array, Array, BooleanArray}; + use datafusion::{ + common::create_array, + error::Result, + logical_expr::{EmitTo, GroupsAccumulator}, + }; + use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; + + use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + + 
#[test] + fn test_foreign_avg_accumulator() -> Result<()> { + let boxed_accum: Box = + Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true)); + let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); + let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into(); + + // Send in an array to evaluate. We expect groups 0, 1 and 2 to AND to true, false and null. + let values = create_array!(Boolean, vec![true, true, true, false, true, true]); + let opt_filter = + create_array!(Boolean, vec![true, true, true, true, false, false]); + foreign_accum.update_batch( + &[values], + &[0, 0, 1, 1, 2, 2], + Some(opt_filter.as_ref()), + 3, + )?; + + let groups_bool = foreign_accum.evaluate(EmitTo::All)?; + let groups_bool = groups_bool.as_any().downcast_ref::().unwrap(); + + assert_eq!( + groups_bool, + create_array!(Boolean, vec![Some(true), Some(false), None]).as_ref() + ); + + let state = foreign_accum.state(EmitTo::All)?; + assert_eq!(state.len(), 1); + + // To verify merging batches works, create a second state to add in + // This should cause the single remaining group to evaluate to false + let second_states = + vec![make_array(create_array!(Boolean, vec![false]).to_data())]; + + let opt_filter = create_array!(Boolean, vec![true]); + foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?; + let groups_bool = foreign_accum.evaluate(EmitTo::All)?; + assert_eq!(groups_bool.len(), 1); + assert_eq!( + groups_bool.as_ref(), + make_array(create_array!(Boolean, vec![false]).to_data()).as_ref() + ); + + let values = create_array!(Boolean, vec![false]); + let opt_filter = create_array!(Boolean, vec![true]); + let groups_bool = + foreign_accum.convert_to_state(&[values], Some(opt_filter.as_ref()))?; + + assert_eq!( + groups_bool[0].as_ref(), + make_array(create_array!(Boolean, vec![false]).to_data()).as_ref() + ); + + Ok(()) + } + + fn test_emit_to_round_trip(value: EmitTo) -> Result<()> { + let ffi_value: FFI_EmitTo = value.into(); + let round_trip_value: EmitTo =
ffi_value.into(); + + assert_eq!(value, round_trip_value); + Ok(()) + } + + /// This test ensures all enum values are properly translated + #[test] + fn test_all_emit_to_round_trip() -> Result<()> { + test_emit_to_round_trip(EmitTo::All)?; + test_emit_to_round_trip(EmitTo::First(10))?; + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs new file mode 100644 index 000000000000..ac59cef8fdf4 --- /dev/null +++ b/datafusion/ffi/src/udaf/mod.rs @@ -0,0 +1,707 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ROption, RResult, RStr, RString, RVec}, + StableAbi, +}; +use accumulator::{FFI_Accumulator, ForeignAccumulator}; +use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; +use arrow::datatypes::{DataType, Field}; +use arrow::ffi::FFI_ArrowSchema; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + type_coercion::functions::data_types_with_aggregate_udf, + utils::AggregateOrderSensitivity, + Accumulator, GroupsAccumulator, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{AggregateUDF, AggregateUDFImpl, Signature}, +}; +use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; +use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + +use crate::{ + arrow_wrappers::WrappedSchema, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; +use prost::{DecodeError, Message}; + +mod accumulator; +mod accumulator_args; +mod groups_accumulator; + +/// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries. 
+#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AggregateUDF { + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub name: RString, + + /// FFI equivalent to the `aliases` of a [`AggregateUDF`] + pub aliases: RVec, + + /// FFI equivalent to the `volatility` of a [`AggregateUDF`] + pub volatility: FFI_Volatility, + + pub return_type: unsafe extern "C" fn( + udaf: &Self, + arg_types: RVec, + ) -> RResult, + + pub is_nullable: bool, + + pub groups_accumulator_supported: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool, + + pub accumulator: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, + + pub create_sliding_accumulator: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, + + #[allow(clippy::type_complexity)] + pub state_fields: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_types: RVec, + return_type: WrappedSchema, + ordering_fields: RVec>, + is_distinct: bool, + ) -> RResult>, RString>, + + pub create_groups_accumulator: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, + + pub with_beneficial_ordering: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, + ) -> RResult, RString>, + + pub order_sensitivity: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity, + + /// Performs type coercion. To simplify this interface, all UDFs are treated as having + /// user defined signatures, which will in turn cause coerce_types to be called. This + /// call should be transparent to most users as the internal function performs the + /// appropriate calls on the underlying [`AggregateUDF`] + pub coerce_types: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec, + ) -> RResult, RString>, + + /// Used to create a clone on the provider of the udaf.
This should + /// only need to be called by the receiver of the udaf. + pub clone: unsafe extern "C" fn(udaf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udaf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udaf. + /// A [`ForeignAggregateUDF`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_AggregateUDF {} +unsafe impl Sync for FFI_AggregateUDF {} + +pub struct AggregateUDFPrivateData { + pub udaf: Arc, +} + +impl FFI_AggregateUDF { + unsafe fn inner(&self) -> &Arc { + let private_data = self.private_data as *const AggregateUDFPrivateData; + &(*private_data).udaf + } +} + +unsafe extern "C" fn return_type_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec, +) -> RResult { + let udaf = udaf.inner(); + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_type = udaf + .return_type(&arg_types) + .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from)) + .map(WrappedSchema); + + rresult!(return_type) +} + +unsafe extern "C" fn accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_sliding_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_sliding_accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_groups_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = 
udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_groups_accumulator(accumulator_args.into()) + .map(FFI_GroupsAccumulator::from)) +} + +unsafe extern "C" fn groups_accumulator_supported_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> bool { + let udaf = udaf.inner(); + + ForeignAccumulatorArgs::try_from(args) + .map(|a| udaf.groups_accumulator_supported((&a).into())) + .unwrap_or_else(|e| { + log::warn!("Unable to parse accumulator args. {}", e); + false + }) +} + +unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, +) -> RResult, RString> { + let udaf = udaf.inner().as_ref().clone(); + + let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering)); + let result = rresult_return!(result + .map(|func| func.with_beneficial_ordering(beneficial_ordering)) + .transpose()) + .flatten() + .map(|func| FFI_AggregateUDF::from(Arc::new(func))); + + RResult::ROk(result.into()) +} + +unsafe extern "C" fn state_fields_fn_wrapper( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_types: RVec, + return_type: WrappedSchema, + ordering_fields: RVec>, + is_distinct: bool, +) -> RResult>, RString> { + let udaf = udaf.inner(); + + let input_types = &rresult_return!(rvec_wrapped_to_vec_datatype(&input_types)); + let return_type = &rresult_return!(DataType::try_from(&return_type.0)); + + let ordering_fields = &rresult_return!(ordering_fields + .into_iter() + .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref())) + .collect::, DecodeError>>()); + + let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields)); + + let args = StateFieldsArgs { + name: name.as_str(), + input_types, + return_type, + ordering_fields, + is_distinct, + }; + + let state_fields = rresult_return!(udaf.state_fields(args)); + let state_fields = rresult_return!(state_fields + .iter() + 
.map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()) + .into_iter() + .map(|field| field.encode_to_vec().into()) + .collect(); + + RResult::ROk(state_fields) +} + +unsafe extern "C" fn order_sensitivity_fn_wrapper( + udaf: &FFI_AggregateUDF, +) -> FFI_AggregateOrderSensitivity { + udaf.inner().order_sensitivity().into() +} + +unsafe extern "C" fn coerce_types_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec, +) -> RResult, RString> { + let udaf = udaf.inner(); + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_types = rresult_return!(data_types_with_aggregate_udf(&arg_types, udaf)); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} + +unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { + let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF { + Arc::clone(udaf.inner()).into() +} + +impl Clone for FFI_AggregateUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_AggregateUDF { + fn from(udaf: Arc) -> Self { + let name = udaf.name().into(); + let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect(); + let is_nullable = udaf.is_nullable(); + let volatility = udaf.signature().volatility.into(); + + let private_data = Box::new(AggregateUDFPrivateData { udaf }); + + Self { + name, + is_nullable, + volatility, + aliases, + return_type: return_type_fn_wrapper, + accumulator: accumulator_fn_wrapper, + create_sliding_accumulator: create_sliding_accumulator_fn_wrapper, + create_groups_accumulator: create_groups_accumulator_fn_wrapper, + groups_accumulator_supported: groups_accumulator_supported_fn_wrapper, + with_beneficial_ordering: with_beneficial_ordering_fn_wrapper, + state_fields: state_fields_fn_wrapper, + order_sensitivity: 
order_sensitivity_fn_wrapper, + coerce_types: coerce_types_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_AggregateUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAggregateUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_AggregateUDF. +#[derive(Debug)] +pub struct ForeignAggregateUDF { + signature: Signature, + aliases: Vec, + udaf: FFI_AggregateUDF, +} + +unsafe impl Send for ForeignAggregateUDF {} +unsafe impl Sync for ForeignAggregateUDF {} + +impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { + type Error = DataFusionError; + + fn try_from(udaf: &FFI_AggregateUDF) -> Result { + let signature = Signature::user_defined((&udaf.volatility).into()); + let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Self { + udaf: udaf.clone(), + signature, + aliases, + }) + } +} + +impl AggregateUDFImpl for ForeignAggregateUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + self.udaf.name.as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + + let result = unsafe { (self.udaf.return_type)(&self.udaf, arg_types) }; + + let result = df_result!(result); + + result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) + } + + fn is_nullable(&self) -> bool { + self.udaf.is_nullable + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let args = acc_args.try_into()?; + unsafe { + df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| { + 
Box::new(ForeignAccumulator::from(accum)) as Box + }) + } + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + unsafe { + let name = RStr::from_str(args.name); + let input_types = vec_datatype_to_rvec_wrapped(args.input_types)?; + let return_type = WrappedSchema(FFI_ArrowSchema::try_from(args.return_type)?); + let ordering_fields = args + .ordering_fields + .iter() + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()? + .into_iter() + .map(|proto_field| proto_field.encode_to_vec().into()) + .collect(); + + let fields = df_result!((self.udaf.state_fields)( + &self.udaf, + &name, + input_types, + return_type, + ordering_fields, + args.is_distinct + ))?; + let fields = fields + .into_iter() + .map(|field_bytes| { + datafusion_proto_common::Field::decode(field_bytes.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string())) + }) + .collect::>>()?; + + parse_proto_fields_to_fields(fields.iter()) + .map_err(|e| DataFusionError::Execution(e.to_string())) + } + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + let args = match FFI_AccumulatorArgs::try_from(args) { + Ok(v) => v, + Err(e) => { + log::warn!("Attempting to convert accumulator arguments: {}", e); + return false; + } + }; + + unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) } + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let args = FFI_AccumulatorArgs::try_from(args)?; + + unsafe { + df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map( + |accum| { + Box::new(ForeignGroupsAccumulator::from(accum)) + as Box + }, + ) + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let args = args.try_into()?; + unsafe { + df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map( + |accum| 
Box::new(ForeignAccumulator::from(accum)) as Box, + ) + } + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + unsafe { + let result = df_result!((self.udaf.with_beneficial_ordering)( + &self.udaf, + beneficial_ordering + ))? + .into_option(); + + let result = result + .map(|func| ForeignAggregateUDF::try_from(&func)) + .transpose()?; + + Ok(result.map(|func| Arc::new(func) as Arc)) + } + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() } + } + + fn simplify(&self) -> Option { + None + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let result_types = + df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?) + } + } +} + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_AggregateOrderSensitivity { + Insensitive, + HardRequirement, + Beneficial, +} + +impl From for AggregateOrderSensitivity { + fn from(value: FFI_AggregateOrderSensitivity) -> Self { + match value { + FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive, + FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +impl From for FFI_AggregateOrderSensitivity { + fn from(value: AggregateOrderSensitivity) -> Self { + match value { + AggregateOrderSensitivity::Insensitive => Self::Insensitive, + AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::Schema; + use datafusion::{ + common::create_array, + functions_aggregate::sum::Sum, + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::expressions::col, + scalar::ScalarValue, + }; + + use super::*; + + 
fn create_test_foreign_udaf( + original_udaf: impl AggregateUDFImpl + 'static, + ) -> Result { + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + Ok(foreign_udaf.into()) + } + + #[test] + fn test_round_trip_udaf() -> Result<()> { + let original_udaf = Sum::new(); + let original_name = original_udaf.name().to_owned(); + + let foreign_udaf = create_test_foreign_udaf(original_udaf)?; + // let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + // let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + // let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + // let foreign_udaf: AggregateUDF = foreign_udaf.into(); + + assert_eq!(original_name, foreign_udaf.name()); + Ok(()) + } + + #[test] + fn test_foreign_udaf_aliases() -> Result<()> { + let foreign_udaf = + create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]); + + let return_type = foreign_udaf.return_type(&[DataType::Float64])?; + assert_eq!(return_type, DataType::Float64); + Ok(()) + } + + #[test] + fn test_foreign_udaf_accumulator() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let acc_args = AccumulatorArgs { + return_type: &DataType::Float64, + schema: &schema, + ignore_nulls: true, + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + let mut accumulator = foreign_udaf.accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); 
+ + Ok(()) + } + + #[test] + fn test_beneficial_ordering() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf( + datafusion::functions_aggregate::first_last::FirstValue::new(), + )?; + + let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap(); + + assert_eq!( + foreign_udaf.order_sensitivity(), + AggregateOrderSensitivity::Beneficial + ); + + let a_field = Field::new("a", DataType::Float64, true); + let state_fields = foreign_udaf.state_fields(StateFieldsArgs { + name: "a", + input_types: &[DataType::Float64], + return_type: &DataType::Float64, + ordering_fields: &[a_field.clone()], + is_distinct: false, + })?; + + println!("{:#?}", state_fields); + assert_eq!(state_fields.len(), 3); + assert_eq!(state_fields[1], a_field); + Ok(()) + } + + #[test] + fn test_sliding_accumulator() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let acc_args = AccumulatorArgs { + return_type: &DataType::Float64, + schema: &schema, + ignore_nulls: true, + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + + let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); + + Ok(()) + } + + fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) { + let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into(); + let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into(); + + assert_eq!(sensitivity, round_trip_sensitivity); + } + + #[test] + fn test_round_trip_all_order_sensitivities() { + 
test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial); + } +} diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index 9d5f2aefe324..97f6490509ae 100644 --- a/datafusion/ffi/src/util.rs +++ b/datafusion/ffi/src/util.rs @@ -20,7 +20,7 @@ use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; use crate::arrow_wrappers::WrappedSchema; -/// This macro is a helpful conversion utility to conver from an abi_stable::RResult to a +/// This macro is a helpful conversion utility to convert from an abi_stable::RResult to a /// DataFusion result. #[macro_export] macro_rules! df_result { diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index c6df324e9a17..eb7e54b42d87 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -20,12 +20,19 @@ #[cfg(feature = "integration-tests")] mod tests { + use abi_stable::library::RootModule; + use arrow::array::Float64Array; + use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; - use datafusion::prelude::SessionContext; + use datafusion::logical_expr::{AggregateUDF, ScalarUDF}; + use datafusion::prelude::{col, SessionContext}; use datafusion_ffi::catalog_provider::ForeignCatalogProvider; use datafusion_ffi::table_provider::ForeignTableProvider; - use datafusion_ffi::tests::create_record_batch; use datafusion_ffi::tests::utils::get_module; + use datafusion_ffi::tests::{create_record_batch, ForeignLibraryModuleRef}; + use datafusion_ffi::udaf::ForeignAggregateUDF; + use datafusion_ffi::udf::ForeignScalarUDF; + use std::path::Path; use std::sync::Arc; /// It is important that this test is in the `tests` directory and not in the @@ -96,4 +103,103 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_ffi_udaf() -> Result<()> 
{ + let module = get_module()?; + + let ffi_avg_func = + module + .create_sum_udaf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); + let foreign_avg_func: ForeignAggregateUDF = (&ffi_avg_func).try_into()?; + + let udaf: AggregateUDF = foreign_avg_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ("b", Float64, vec![1.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0]) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("sum_b")], + )? + .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + + let expected = record_batch!( + ("a", Int32, vec![1, 2, 4]), + ("sum_b", Float64, vec![1.0, 4.0, 16.0]) + )?; + + assert_eq!(result[0], expected); + + Ok(()) + } + + #[tokio::test] + async fn test_ffi_grouping_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_stddev_func = + module + .create_stddev_udaf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); + let foreign_stddev_func: ForeignAggregateUDF = (&ffi_stddev_func).try_into()?; + + let udaf: AggregateUDF = foreign_stddev_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ( + "b", + Float64, + vec![ + 1.0, + 2.0, + 2.0 + 2.0_f64.sqrt(), + 4.0, + 4.0, + 4.0 + 3.0_f64.sqrt(), + 4.0 + 3.0_f64.sqrt() + ] + ) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("stddev_b")], + )? 
+ .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + let result = result[0].column_by_name("stddev_b").unwrap(); + let result = result + .as_any() + .downcast_ref::() + .unwrap() + .values(); + + assert!(result.first().unwrap().is_nan()); + assert!(result.get(1).unwrap() - 1.0 < 0.00001); + assert!(result.get(2).unwrap() - 1.0 < 0.00001); + + Ok(()) + } }