From aed3d5498d7bad32e32ac54bf8ca64bacf7bb170 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 15 Feb 2025 09:20:48 -0500 Subject: [PATCH 01/18] Work in progress adding user defined aggregate function FFI support --- datafusion/ffi/src/arrow_wrappers.rs | 13 +- datafusion/ffi/src/lib.rs | 1 + datafusion/ffi/src/udaf/accumulator.rs | 342 +++++++++++++++++++++++++ datafusion/ffi/src/udaf/mod.rs | 302 ++++++++++++++++++++++ 4 files changed, 657 insertions(+), 1 deletion(-) create mode 100644 datafusion/ffi/src/udaf/accumulator.rs create mode 100644 datafusion/ffi/src/udaf/mod.rs diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index a18e6df59bf12..d87d47a9b1848 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -21,7 +21,7 @@ use abi_stable::StableAbi; use arrow::{ array::{make_array, ArrayRef}, datatypes::{Schema, SchemaRef}, - ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; use log::error; @@ -79,3 +79,14 @@ impl TryFrom for ArrayRef { Ok(make_array(data)) } } + +impl TryFrom<&ArrayRef> for WrappedArray { + type Error = arrow::error::ArrowError; + + fn try_from(array: &ArrayRef) -> Result { + let (array, schema) = to_ffi(&array.to_data())?; + let schema = WrappedSchema(schema); + + Ok(WrappedArray { array, schema }) + } +} diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index 4eabf91d892a9..93185473318fd 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -31,6 +31,7 @@ pub mod record_batch_stream; pub mod session_config; pub mod table_provider; pub mod table_source; +pub mod udaf; pub mod udf; pub mod util; pub mod volatility; diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs new file mode 100644 index 0000000000000..b46c9a479f35a --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -0,0 +1,342 @@ +use std::{ + ffi::c_void, + sync::{Arc, Mutex}, +}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::{array::ArrayRef, error::ArrowError}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::Accumulator, + scalar::ScalarValue, +}; +use prost::Message; + +use crate::{ + arrow_wrappers::WrappedArray, + df_result, rresult, rresult_return, +}; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_Accumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: unsafe extern "C" fn(accumulator: &Self) -> RResult, RString>, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: + unsafe extern "C" fn(accumulator: &Self) -> RResult>, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + states: RVec, + ) -> RResult<(), RString>, + + pub retract_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + pub supports_retract_batch: unsafe extern "C" fn(accumulator: &Self) -> bool, + + /// Used to create a clone on the provider of the accumulator. This should + /// only need to be called by the receiver of the accumulator. + pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_Accumulator {} +unsafe impl Sync for FFI_Accumulator {} + +pub struct AccumulatorPrivateData { + pub accumulator: Arc>, +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + let mut accumulator = rresult_return!(accum_data + .accumulator + .lock() + .map_err(|e| DataFusionError::Execution(e.to_string()))); + + rresult!(accumulator.update_batch(&values_arrays)) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &FFI_Accumulator, +) -> RResult, RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + let mut accumulator_internal = rresult_return!(accum_data + .accumulator + .lock() + .map_err(|_| DataFusionError::Execution( + "Unable to aquire lock on FFI Accumulator".to_string() + ))); + + let scalar_result = rresult_return!(accumulator_internal.evaluate()); + let proto_result: datafusion_proto::protobuf::ScalarValue = + rresult_return!((&scalar_result).try_into()); + + RResult::ROk(proto_result.encode_to_vec().into()) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + + accum_data + .accumulator + .lock() + .map(|accum| accum.size()) + .unwrap_or_default() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &FFI_Accumulator, +) -> RResult>, RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + let mut accumulator_internal = rresult_return!(accum_data + .accumulator + .lock() + .map_err(|_| DataFusionError::Execution( + "Unable to aquire lock on FFI Accumulator".to_string() + ))); + + let state = rresult_return!(accumulator_internal.state()); + let state = state + .into_iter() + .map(|state_val| { + datafusion_proto::protobuf::ScalarValue::try_from(&state_val) + .map_err(DataFusionError::from) + .map(|v| RVec::from(v.encode_to_vec())) + }) + .collect::>>() + .map(|state_vec| state_vec.into()); + + rresult!(state) +} + +unsafe extern "C" fn merge_batch_fn_wrapper(accumulator: &mut FFI_Accumulator, + states: RVec, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + let mut accumulator_internal = rresult_return!(accum_data + .accumulator + .lock() + .map_err(|_| DataFusionError::Execution( + "Unable to aquire lock on FFI Accumulator".to_string() + ))); + + let states = rresult_return!(states.into_iter() + .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) + .collect::>>()); + + rresult!(accumulator_internal.merge_batch(&states)) +} + +unsafe extern "C" fn retract_batch_fn_wrapper(accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + let mut accumulator_internal = rresult_return!(accum_data + .accumulator + .lock() + .map_err(|_| DataFusionError::Execution( + "Unable to aquire lock on FFI Accumulator".to_string() + ))); + + let values = rresult_return!(values.into_iter() + .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) + .collect::>>()); + + rresult!(accumulator_internal.retract_batch(&values)) +} + +unsafe extern "C" fn supports_retract_batch_fn_wrapper(accumulator: &FFI_Accumulator, +) -> bool { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + accum_data + .accumulator + .lock() + .map(|accum| accum.supports_retract_batch()) + .unwrap_or(false) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_Accumulator) -> FFI_Accumulator { + let private_data = accumulator.private_data as *const AccumulatorPrivateData; + let accum_data = &(*private_data); + + Arc::clone(&accum_data.accumulator).into() +} + +impl Clone for FFI_Accumulator { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From>> for FFI_Accumulator { + fn from(accumulator: Arc>) -> Self { + let private_data = Box::new(AccumulatorPrivateData { accumulator }); + + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + retract_batch: retract_batch_fn_wrapper, + supports_retract_batch: supports_retract_batch_fn_wrapper, + + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_Accumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_Accumulator. +#[derive(Debug)] +pub struct ForeignAccumulator { + accumulator: FFI_Accumulator, +} + +unsafe impl Send for ForeignAccumulator {} +unsafe impl Sync for ForeignAccumulator {} + +impl From<&FFI_Accumulator> for ForeignAccumulator { + fn from(accumulator: &FFI_Accumulator) -> Self { + Self { + accumulator: accumulator.clone(), + } + } +} + +impl Accumulator for ForeignAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn evaluate(&mut self) -> Result { + unsafe { + let scalar_bytes = + df_result!((self.accumulator.evaluate)(&self.accumulator))?; + + let proto_scalar = + datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn state(&mut self) -> Result> { + unsafe { + let state_protos = df_result!((self.accumulator.state)(&self.accumulator))?; + + state_protos + .into_iter() + .map(|proto_bytes| { + datafusion_proto::protobuf::ScalarValue::decode(proto_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e))) + .and_then(|proto_value| { + ScalarValue::try_from(&proto_value) + .map_err(DataFusionError::from) + }) + }) + .collect::>>() + } + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + unsafe { + let states = states + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + states.into() + )) + } + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.retract_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn supports_retract_batch(&self) -> bool { + unsafe { (self.accumulator.supports_retract_batch)(&self.accumulator) } + } +} diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs new file mode 100644 index 0000000000000..1c8794f96f0e7 --- /dev/null +++ b/datafusion/ffi/src/udaf/mod.rs @@ -0,0 +1,302 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::datatypes::DataType; +use arrow::ffi::{from_ffi, to_ffi, FFI_ArrowSchema}; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + utils::AggregateOrderSensitivity, + Accumulator, GroupsAccumulator, ReversedUDAF, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{ + AggregateUDF, AggregateUDFImpl, ColumnarValue, ScalarFunctionArgs, Signature, + }, +}; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, + signature::{self, rvec_wrapped_to_vec_datatype, FFI_Signature}, +}; + +mod accumulator; + +/// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AggregateUDF { + /// Return the udaf name. + pub name: RString, + + pub signature: unsafe extern "C" fn(udaf: &Self) -> RResult, + + pub aliases: unsafe extern "C" fn(udaf: &Self) -> RVec, + + pub return_type: unsafe extern "C" fn( + udaf: &Self, + arg_types: RVec, + ) -> RResult, + + pub is_nullable: bool, + + + /// Used to create a clone on the provider of the udaf. This should + /// only need to be called by the receiver of the udaf. + pub clone: unsafe extern "C" fn(udaf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udaf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udaf. + /// A [`ForeignAggregateUDF`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_AggregateUDF {} +unsafe impl Sync for FFI_AggregateUDF {} + +pub struct AggregateUDFPrivateData { + pub udaf: Arc, +} + +unsafe extern "C" fn name_fn_wrapper(udaf: &FFI_AggregateUDF) -> RString { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + udaf.name().into() +} + +unsafe extern "C" fn signature_fn_wrapper( + udaf: &FFI_AggregateUDF, +) -> RResult { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + rresult!(udaf.signature().try_into()) +} + +unsafe extern "C" fn aliases_fn_wrapper(udaf: &FFI_AggregateUDF) -> RVec { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + udaf.aliases().iter().map(|s| s.to_owned().into()).collect() +} + +unsafe extern "C" fn return_type_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec, +) -> RResult { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_type = udaf + .return_type(&arg_types) + .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from)) + .map(WrappedSchema); + + rresult!(return_type) +} + +unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { + let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf_data = &(*private_data); + + Arc::clone(&udaf_data.udaf).into() +} + +impl Clone for FFI_AggregateUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_AggregateUDF { + fn from(udaf: Arc) -> Self { + let name = udaf.name().into(); + let is_nullable = udaf.is_nullable(); + + let private_data = Box::new(AggregateUDFPrivateData { udaf }); + + Self { + name, + is_nullable, + signature: signature_fn_wrapper, + aliases: aliases_fn_wrapper, + return_type: return_type_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_AggregateUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAggregateUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_AggregateUDF. +#[derive(Debug)] +pub struct ForeignAggregateUDF { + signature: Signature, + aliases: Vec, + udaf: FFI_AggregateUDF, +} + +unsafe impl Send for ForeignAggregateUDF {} +unsafe impl Sync for ForeignAggregateUDF {} + +impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { + type Error = DataFusionError; + + fn try_from(udaf: &FFI_AggregateUDF) -> Result { + unsafe { + let ffi_signature = df_result!((udaf.signature)(udaf))?; + let signature = (&ffi_signature).try_into()?; + + let aliases = (udaf.aliases)(udaf) + .into_iter() + .map(|s| s.to_string()) + .collect(); + + Ok(Self { + udaf: udaf.clone(), + signature, + aliases, + }) + } + } +} + +impl AggregateUDFImpl for ForeignAggregateUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + self.udaf.name.as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_types = signature::vec_datatype_to_rvec_wrapped(arg_types)?; + + let result = unsafe { (self.udaf.return_type)(&self.udaf, arg_types) }; + + let result = df_result!(result); + + result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) + } + + fn is_nullable(&self) -> bool { + self.udaf.is_nullable + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> {} + + fn state_fields(&self, args: StateFieldsArgs) -> Result> {} + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {} + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + } + + fn with_beneficial_ordering( + self: Arc, + _beneficial_ordering: bool, + ) -> Result>> { + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity {} + + fn simplify(&self) -> Option {} + + fn reverse_expr(&self) -> ReversedUDAF {} + + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> {} + + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {} + + fn is_descending(&self) -> Option {} + + fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option { + } + + fn default_value(&self, data_type: &DataType) -> Result {} + + fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {} +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_round_trip_udaf() -> Result<()> { + let original_udaf = datafusion::functions::math::abs::AbsFunc::new(); + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + + assert!(original_udaf.name() == foreign_udaf.name()); + + Ok(()) + } +} From 2443b3b464b7ffc599b8ed473e3ed2e625dc2b1a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 18 Feb 2025 16:25:57 -0500 Subject: [PATCH 02/18] Intermediate work. Going through groups accumulator --- Cargo.lock | 1 + datafusion/ffi/Cargo.toml | 1 + datafusion/ffi/src/udaf/accumulator.rs | 121 ++--- datafusion/ffi/src/udaf/accumulator_args.rs | 108 +++++ datafusion/ffi/src/udaf/groups_accumulator.rs | 445 ++++++++++++++++++ datafusion/ffi/src/udaf/mod.rs | 163 ++++++- 6 files changed, 750 insertions(+), 89 deletions(-) create mode 100644 datafusion/ffi/src/udaf/accumulator_args.rs create mode 100644 datafusion/ffi/src/udaf/groups_accumulator.rs diff --git a/Cargo.lock b/Cargo.lock index 716e0cf10386a..6477c003ea037 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2027,6 +2027,7 @@ dependencies = [ "async-trait", "datafusion", "datafusion-proto", + "datafusion-proto-common", "doc-comment", "futures", "log", diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 97914666688fd..532ccd9ff83d7 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -44,6 +44,7 @@ async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } datafusion-proto = { workspace = true } +datafusion-proto-common = { workspace = true } futures = { workspace = true } log = { workspace = true } prost = { workspace = true } diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index b46c9a479f35a..4b709ff1a3c3f 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -15,10 +15,7 @@ use datafusion::{ }; use prost::Message; -use crate::{ - arrow_wrappers::WrappedArray, - df_result, rresult, rresult_return, -}; +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; #[repr(C)] #[derive(Debug, StableAbi)] @@ -51,7 +48,7 @@ pub struct FFI_Accumulator { /// Used to create a clone on the provider of the accumulator. This should /// only need to be called by the receiver of the accumulator. - pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, + // pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, /// Release the memory of the private data when it is no longer being used. pub release: unsafe extern "C" fn(accumulator: &mut Self), @@ -65,7 +62,7 @@ unsafe impl Send for FFI_Accumulator {} unsafe impl Sync for FFI_Accumulator {} pub struct AccumulatorPrivateData { - pub accumulator: Arc>, + pub accumulator: Box, } unsafe extern "C" fn update_batch_fn_wrapper( @@ -81,12 +78,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( .collect::>>(); let values_arrays = rresult_return!(values_arrays); - let mut accumulator = rresult_return!(accum_data - .accumulator - .lock() - .map_err(|e| DataFusionError::Execution(e.to_string()))); - - rresult!(accumulator.update_batch(&values_arrays)) + rresult!(accum_data.accumulator.update_batch(&values_arrays)) } unsafe extern "C" fn evaluate_fn_wrapper( @@ -94,14 +86,8 @@ unsafe extern "C" fn evaluate_fn_wrapper( ) -> RResult, RString> { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - let mut accumulator_internal = rresult_return!(accum_data - .accumulator - .lock() - .map_err(|_| DataFusionError::Execution( - "Unable to aquire lock on FFI Accumulator".to_string() - ))); - - let scalar_result = rresult_return!(accumulator_internal.evaluate()); + + let scalar_result = rresult_return!(accum_data.accumulator.evaluate()); let proto_result: datafusion_proto::protobuf::ScalarValue = rresult_return!((&scalar_result).try_into()); @@ -112,11 +98,7 @@ unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - accum_data - .accumulator - .lock() - .map(|accum| accum.size()) - .unwrap_or_default() + accum_data.accumulator.size() } unsafe extern "C" fn state_fn_wrapper( @@ -124,14 +106,8 @@ unsafe extern "C" fn state_fn_wrapper( ) -> RResult>, RString> { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - let mut accumulator_internal = rresult_return!(accum_data - .accumulator - .lock() - .map_err(|_| DataFusionError::Execution( - "Unable to aquire lock on FFI Accumulator".to_string() - ))); - - let state = rresult_return!(accumulator_internal.state()); + + let state = rresult_return!(accum_data.accumulator.state()); let state = state .into_iter() .map(|state_val| { @@ -145,53 +121,42 @@ unsafe extern "C" fn state_fn_wrapper( rresult!(state) } -unsafe extern "C" fn merge_batch_fn_wrapper(accumulator: &mut FFI_Accumulator, +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, states: RVec, ) -> RResult<(), RString> { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - let mut accumulator_internal = rresult_return!(accum_data - .accumulator - .lock() - .map_err(|_| DataFusionError::Execution( - "Unable to aquire lock on FFI Accumulator".to_string() - ))); - - let states = rresult_return!(states.into_iter() + + let states = rresult_return!(states + .into_iter() .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) .collect::>>()); - rresult!(accumulator_internal.merge_batch(&states)) + rresult!(accum_data.accumulator.merge_batch(&states)) } -unsafe extern "C" fn retract_batch_fn_wrapper(accumulator: &mut FFI_Accumulator, +unsafe extern "C" fn retract_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, values: RVec, ) -> RResult<(), RString> { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - let mut accumulator_internal = rresult_return!(accum_data - .accumulator - .lock() - .map_err(|_| DataFusionError::Execution( - "Unable to aquire lock on FFI Accumulator".to_string() - ))); - - let values = rresult_return!(values.into_iter() + + let values = rresult_return!(values + .into_iter() .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) .collect::>>()); - rresult!(accumulator_internal.retract_batch(&values)) + rresult!(accum_data.accumulator.retract_batch(&values)) } -unsafe extern "C" fn supports_retract_batch_fn_wrapper(accumulator: &FFI_Accumulator, +unsafe extern "C" fn supports_retract_batch_fn_wrapper( + accumulator: &FFI_Accumulator, ) -> bool { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - accum_data - .accumulator - .lock() - .map(|accum| accum.supports_retract_batch()) - .unwrap_or(false) + accum_data.accumulator.supports_retract_batch() } unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { @@ -200,23 +165,21 @@ unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { drop(private_data); } -unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_Accumulator) -> FFI_Accumulator { - let private_data = accumulator.private_data as *const AccumulatorPrivateData; - let accum_data = &(*private_data); - - Arc::clone(&accum_data.accumulator).into() -} +// unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_Accumulator) -> FFI_Accumulator { +// let private_data = accumulator.private_data as *const AccumulatorPrivateData; +// let accum_data = &(*private_data); -impl Clone for FFI_Accumulator { - fn clone(&self) -> Self { - unsafe { (self.clone)(self) } - } -} +// Box::new(accum_data.accumulator).into() +// } -impl From>> for FFI_Accumulator { - fn from(accumulator: Arc>) -> Self { - let private_data = Box::new(AccumulatorPrivateData { accumulator }); +// impl Clone for FFI_Accumulator { +// fn clone(&self) -> Self { +// unsafe { (self.clone)(self) } +// } +// } +impl From> for FFI_Accumulator { + fn from(accumulator: Box) -> Self { Self { update_batch: update_batch_fn_wrapper, evaluate: evaluate_fn_wrapper, @@ -226,9 +189,9 @@ impl From>> for FFI_Accumulator { retract_batch: retract_batch_fn_wrapper, supports_retract_batch: supports_retract_batch_fn_wrapper, - clone: clone_fn_wrapper, + // clone: clone_fn_wrapper, release: release_fn_wrapper, - private_data: Box::into_raw(private_data) as *mut c_void, + private_data: Box::into_raw(accumulator) as *mut c_void, } } } @@ -253,11 +216,9 @@ pub struct ForeignAccumulator { unsafe impl Send for ForeignAccumulator {} unsafe impl Sync for ForeignAccumulator {} -impl From<&FFI_Accumulator> for ForeignAccumulator { - fn from(accumulator: &FFI_Accumulator) -> Self { - Self { - accumulator: accumulator.clone(), - } +impl From for ForeignAccumulator { + fn from(accumulator: FFI_Accumulator) -> Self { + Self { accumulator } } } diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs new file mode 100644 index 0000000000000..2eeccca56c961 --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -0,0 +1,108 @@ +use std::sync::Arc; + +use abi_stable::{ + std_types::{RString, RVec}, + StableAbi, +}; +use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::function::AccumulatorArgs, + prelude::SessionContext, +}; +use datafusion_proto::{ + physical_plan::{ + from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, + to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs}, + DefaultPhysicalExtensionCodec, + }, + protobuf::PhysicalAggregateExprNode, +}; +use prost::Message; + +use crate::{arrow_wrappers::WrappedSchema, rresult_return}; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AccumulatorArgs { + return_type: WrappedSchema, + schema: WrappedSchema, + is_reversed: bool, + name: RString, + physical_expr_def: RVec, +} + +impl FFI_AccumulatorArgs { + pub fn to_accumulator_args(&self) -> Result { + let proto_def = + PhysicalAggregateExprNode::decode(self.physical_expr_def.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + + let return_type = &(&self.return_type.0).try_into()?; + let schema = &Arc::new(Schema::try_from(&self.schema.0)?); + + let default_ctx = SessionContext::new(); + let codex = DefaultPhysicalExtensionCodec {}; + + // let proto_ordering_req = + // rresult_return!(PhysicalSortExprNodeCollection::decode(ordering_req.as_ref())); + let ordering_req = &parse_physical_sort_exprs( + &proto_def.ordering_req, + &default_ctx, + &schema, + &codex, + )?; + + let exprs = &rresult_return!(parse_physical_exprs( + &proto_def.expr, + &default_ctx, + &schema, + &codex + )); + + Ok(AccumulatorArgs { + return_type, + schema, + ignore_nulls: proto_def.ignore_nulls, + ordering_req, + is_reversed: self.is_reversed, + name: self.name.as_str(), + is_distinct: proto_def.distinct, + exprs, + }) + } +} + +impl<'a> TryFrom> for FFI_AccumulatorArgs { + type Error = DataFusionError; + + fn try_from(args: AccumulatorArgs) -> std::result::Result { + let return_type = WrappedSchema(FFI_ArrowSchema::try_from(args.return_type)?); + let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); + + let codec = DefaultPhysicalExtensionCodec {}; + let ordering_req = + serialize_physical_sort_exprs(args.ordering_req.to_owned(), &codec)?; + + let expr = serialize_physical_exprs(args.exprs, &codec)?; + + let physical_expr_def = PhysicalAggregateExprNode { + expr, + ordering_req, + distinct: args.is_distinct, + ignore_nulls: args.ignore_nulls, + fun_definition: None, + aggregate_function: None, + }; + let physical_expr_def = physical_expr_def.encode_to_vec().into(); + + Ok(Self { + return_type, + schema, + is_reversed: args.is_reversed, + name: args.name.into(), + physical_expr_def, + }) + } +} diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs new file mode 100644 index 0000000000000..0c5eb9475d782 --- /dev/null +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -0,0 +1,445 @@ +use std::{ + ffi::c_void, + sync::{Arc, Mutex}, +}; + +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::{ + array::{Array, ArrayRef, BooleanArray}, + error::ArrowError, + ffi::{from_ffi, to_ffi, FFI_ArrowArray}, +}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::{EmitTo, GroupsAccumulator}, + scalar::ScalarValue, +}; +use prost::Message; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, +}; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_GroupsAccumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: unsafe extern "C" fn( + accumulator: &Self, + emit_to: FFI_EmitTo, + ) -> RResult, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: unsafe extern "C" fn( + accumulator: &Self, + emit_to: FFI_EmitTo, + ) -> RResult, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + pub convert_to_state: unsafe extern "C" fn( + accumulator: &Self, + values: RVec, + opt_filter: ROption, + ) + -> RResult, RString>, + + pub supports_convert_to_state: bool, + + /// Used to create a clone on the provider of the accumulator. This should + /// only need to be called by the receiver of the accumulator. + // pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignGroupsAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_GroupsAccumulator {} +unsafe impl Sync for FFI_GroupsAccumulator {} + +pub struct GroupsAccumulatorPrivateData { + pub accumulator: Box, +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + let group_indices: Vec = group_indices.into_iter().collect(); + + let maybe_filter = opt_filter.into_option().and_then(|filter| { + match ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. {}", e); + None + } + } + }).map(|arr| arr.into_data()); + let opt_filter = maybe_filter.map(|arr| BooleanArray::from(arr)); + + rresult!(accum_data.accumulator.update_batch( + &values_arrays, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let result = rresult_return!(accum_data.accumulator.evaluate(emit_to.into())); + + rresult!(WrappedArray::try_from(&result)) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + accum_data.accumulator.size() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult, RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let state = rresult_return!(accum_data.accumulator.state(emit_to.into())); + rresult!(state + .into_iter() + .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + let group_indices: Vec = group_indices.into_iter().collect(); + + let maybe_filter = opt_filter.into_option().and_then(|filter| { + match ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. {}", e); + None + } + } + }).map(|arr| arr.into_data()); + let opt_filter = maybe_filter.map(|arr| BooleanArray::from(arr)); + + rresult!(accum_data.accumulator.merge_batch( + &values_arrays, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn convert_to_state_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + values: RVec, + opt_filter: ROption, +) -> RResult, RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let values = rresult_return!(values + .into_iter() + .map(|v| ArrayRef::try_from(v).map_err(DataFusionError::from)) + .collect::>>()); + + let opt_filter = opt_filter.into_option().and_then(|filter| { + match ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. {}", e); + None + } + } + }).map(|arr| arr.into_data()).map(|arr| BooleanArray::from(arr)); + + let state = rresult_return!(accum_data + .accumulator + .convert_to_state(&values, opt_filter.as_ref())); + + rresult!(state + .iter() + .map(|arr| WrappedArray::try_from(arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData); + drop(private_data); +} + +// unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> FFI_GroupsAccumulator { +// let private_data = accumulator.private_data as *const GroupsAccumulatorPrivateData; +// let accum_data = &(*private_data); + +// Box::new(accum_data.accumulator).into() +// } + +// impl Clone for FFI_GroupsAccumulator { +// fn clone(&self) -> Self { +// unsafe { (self.clone)(self) } +// } +// } + +impl From> for FFI_GroupsAccumulator { + fn from(accumulator: Box) -> Self { + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + convert_to_state: convert_to_state_fn_wrapper, + supports_convert_to_state: accumulator.supports_convert_to_state(), + + release: release_fn_wrapper, + private_data: Box::into_raw(accumulator) as *mut c_void, + } + } +} + +impl Drop for FFI_GroupsAccumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignGroupsAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_GroupsAccumulator. +#[derive(Debug)] +pub struct ForeignGroupsAccumulator { + accumulator: FFI_GroupsAccumulator, +} + +unsafe impl Send for ForeignGroupsAccumulator {} +unsafe impl Sync for ForeignGroupsAccumulator {} + +impl From for ForeignGroupsAccumulator { + fn from(accumulator: FFI_GroupsAccumulator) -> Self { + Self { accumulator } + } +} + +impl GroupsAccumulator for ForeignGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + unsafe { + let return_array = df_result!((self.accumulator.evaluate)( + &self.accumulator, + emit_to.into() + ))?; + + return_array.try_into().map_err(DataFusionError::from) + } + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + unsafe { + let returned_arrays = + df_result!((self.accumulator.state)(&self.accumulator, emit_to.into()))?; + + returned_arrays + .into_iter() + .map(|wrapped_array| { + wrapped_array.try_into().map_err(DataFusionError::from) + }) + .collect::>>() + } + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + let returned_array = df_result!((self.accumulator.convert_to_state)( + &self.accumulator, + values, + opt_filter + ))?; + + returned_array + .into_iter() + .map(|arr| arr.try_into().map_err(DataFusionError::from)) + .collect() + } + } + + fn supports_convert_to_state(&self) -> bool { + self.accumulator.supports_convert_to_state + } +} + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_EmitTo { + All, + First(usize), +} + +impl From for FFI_EmitTo { + fn from(value: EmitTo) -> Self { + match value { + EmitTo::All => Self::All, + EmitTo::First(v) => Self::First(v), + } + } +} + +impl From for EmitTo { + fn from(value: FFI_EmitTo) -> Self { + match value { + FFI_EmitTo::All => Self::All, + FFI_EmitTo::First(v) => Self::First(v), + } + } +} diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 1c8794f96f0e7..75b200b5d1474 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -15,13 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::{ffi::c_void, sync::Arc}; +use std::{ + ffi::c_void, + sync::{Arc, Mutex}, +}; use abi_stable::{ - std_types::{RResult, RString, RVec}, + std_types::{RResult, RStr, RString, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use accumulator::FFI_Accumulator; +use accumulator_args::FFI_AccumulatorArgs; +use arrow::datatypes::{DataType, Field, SchemaRef}; use arrow::ffi::{from_ffi, to_ffi, FFI_ArrowSchema}; use datafusion::{ error::DataFusionError, @@ -30,6 +35,8 @@ use datafusion::{ utils::AggregateOrderSensitivity, Accumulator, GroupsAccumulator, ReversedUDAF, }, + physical_plan::aggregates::order, + prelude::SessionContext, }; use datafusion::{ error::Result, @@ -37,14 +44,31 @@ use datafusion::{ AggregateUDF, AggregateUDFImpl, ColumnarValue, ScalarFunctionArgs, Signature, }, }; +use datafusion_proto::{ + physical_plan::{ + from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, + to_proto::{ + serialize_physical_expr, serialize_physical_exprs, + serialize_physical_sort_exprs, + }, + DefaultPhysicalExtensionCodec, + }, + protobuf::{PhysicalAggregateExprNode, PhysicalSortExprNodeCollection}, +}; +use groups_accumulator::FFI_GroupsAccumulator; use crate::{ arrow_wrappers::{WrappedArray, WrappedSchema}, df_result, rresult, rresult_return, - signature::{self, rvec_wrapped_to_vec_datatype, FFI_Signature}, + signature::{ + self, rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped, FFI_Signature, + }, }; +use prost::Message; mod accumulator; +mod accumulator_args; +mod groups_accumulator; /// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries. #[repr(C)] @@ -65,6 +89,28 @@ pub struct FFI_AggregateUDF { pub is_nullable: bool, + pub groups_accumulator_supported: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool, + + pub accumulator: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, + + pub state_fields: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_types: RVec, + return_type: WrappedSchema, + ordering_fields: RVec>, + is_distinct: bool, + ) -> RResult>, RString>, + + pub create_groups_accumulator: + unsafe extern "C" fn( + &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, /// Used to create a clone on the provider of the udaf. This should /// only need to be called by the receiver of the udaf. @@ -125,6 +171,49 @@ unsafe extern "C" fn return_type_fn_wrapper( rresult!(return_type) } +unsafe extern "C" fn accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + let accumulator_args = rresult_return!(args.to_accumulator_args()); + + rresult!(udaf + .accumulator(accumulator_args) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_groups_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + let accumulator_args = rresult_return!(args.to_accumulator_args()); + + rresult!(udaf + .create_groups_accumulator(accumulator_args) + .map(FFI_GroupsAccumulator::from)) +} + +unsafe extern "C" fn groups_accumulator_supported_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> bool { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + args.to_accumulator_args() + .map(|a| udaf.groups_accumulator_supported(a)) + .unwrap_or_else(|e| { + log::warn!("Unable to parse accumulator args. {}", e); + false + }) +} + unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); drop(private_data); @@ -156,6 +245,9 @@ impl From> for FFI_AggregateUDF { signature: signature_fn_wrapper, aliases: aliases_fn_wrapper, return_type: return_type_fn_wrapper, + accumulator: accumulator_fn_wrapper, + create_groups_accumulator: create_groups_accumulator_fn_wrapper, + groups_accumulator_supported: groups_accumulator_supported_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, @@ -221,7 +313,7 @@ impl AggregateUDFImpl for ForeignAggregateUDF { } fn return_type(&self, arg_types: &[DataType]) -> Result { - let arg_types = signature::vec_datatype_to_rvec_wrapped(arg_types)?; + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; let result = unsafe { (self.udaf.return_type)(&self.udaf, arg_types) }; @@ -234,16 +326,69 @@ impl AggregateUDFImpl for ForeignAggregateUDF { self.udaf.is_nullable } - fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> {} + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let args = acc_args.try_into()?; + + unsafe { df_result!((self.udaf.accumulator)(&self.udaf, args)) } + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + unsafe { + let name = RStr::from_str(args.name); + let input_types = vec_datatype_to_rvec_wrapped(args.input_types)?; + let return_type = WrappedSchema(FFI_ArrowSchema::try_from(args.return_type)?); + let ordering_fields = args + .ordering_fields + .iter() + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()? + .into_iter() + .map(|proto_field| proto_field.encode_to_vec().into()) + .collect(); + + let fields = df_result!((self.udaf.state_fields)( + &self.udaf, + &name, + input_types, + return_type, + ordering_fields, + args.is_distinct + ))?; + let fields = fields + .into_iter() + .map(|field_bytes| { + datafusion_proto_common::Field::decode(field_bytes.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string())) + }) + .collect::>>()?; + + datafusion_proto_common::from_proto::parse_proto_fields_to_fields( + fields.iter(), + ) + .map_err(|e| DataFusionError::Execution(e.to_string())) + } + } - fn state_fields(&self, args: StateFieldsArgs) -> Result> {} + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + let args = match FFI_AccumulatorArgs::try_from(args) { + Ok(v) => v, + Err(e) => { + log::warn!("Attempting to convert accumulator arguments: {}", e); + return false; + } + }; - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {} + unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) } + } fn create_groups_accumulator( &self, - _args: AccumulatorArgs, + args: AccumulatorArgs, ) -> Result> { + let args = FFI_AccumulatorArgs::try_from(args)?; + + unsafe { df_result!((self.udaf.accumulator)(&self.udaf, args)) } } fn aliases(&self) -> &[String] { From eaa4d434d615be5d9b464b5d8b81c6f967521e21 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 19 Feb 2025 08:59:37 -0500 Subject: [PATCH 03/18] MVP for aggregate udf via FFI --- datafusion/ffi/src/udaf/accumulator.rs | 26 +- datafusion/ffi/src/udaf/accumulator_args.rs | 142 ++++++--- datafusion/ffi/src/udaf/groups_accumulator.rs | 36 ++- datafusion/ffi/src/udaf/mod.rs | 283 ++++++++++++------ 4 files changed, 333 insertions(+), 154 deletions(-) diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index 4b709ff1a3c3f..6ebe67a4b1c52 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -1,7 +1,21 @@ -use std::{ - ffi::c_void, - sync::{Arc, Mutex}, -}; +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ffi::c_void; use abi_stable::{ std_types::{RResult, RString, RVec}, @@ -46,10 +60,6 @@ pub struct FFI_Accumulator { pub supports_retract_batch: unsafe extern "C" fn(accumulator: &Self) -> bool, - /// Used to create a clone on the provider of the accumulator. This should - /// only need to be called by the receiver of the accumulator. - // pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, - /// Release the memory of the private data when it is no longer being used. pub release: unsafe extern "C" fn(accumulator: &mut Self), diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index 2eeccca56c961..f15bed5aa2c25 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -1,14 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use std::sync::Arc; use abi_stable::{ std_types::{RString, RVec}, StableAbi, }; -use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema}; +use arrow::{ + datatypes::{DataType, Schema}, + ffi::FFI_ArrowSchema, +}; use datafusion::{ - error::{DataFusionError, Result}, - logical_expr::function::AccumulatorArgs, - prelude::SessionContext, + error::DataFusionError, logical_expr::function::AccumulatorArgs, + physical_expr::LexOrdering, physical_plan::PhysicalExpr, prelude::SessionContext, }; use datafusion_proto::{ physical_plan::{ @@ -20,7 +39,7 @@ use datafusion_proto::{ }; use prost::Message; -use crate::{arrow_wrappers::WrappedSchema, rresult_return}; +use crate::arrow_wrappers::WrappedSchema; #[repr(C)] #[derive(Debug, StableAbi)] @@ -33,51 +52,10 @@ pub struct FFI_AccumulatorArgs { physical_expr_def: RVec, } -impl FFI_AccumulatorArgs { - pub fn to_accumulator_args(&self) -> Result { - let proto_def = - PhysicalAggregateExprNode::decode(self.physical_expr_def.as_ref()) - .map_err(|e| DataFusionError::Execution(e.to_string()))?; - - let return_type = &(&self.return_type.0).try_into()?; - let schema = &Arc::new(Schema::try_from(&self.schema.0)?); - - let default_ctx = SessionContext::new(); - let codex = DefaultPhysicalExtensionCodec {}; - - // let proto_ordering_req = - // rresult_return!(PhysicalSortExprNodeCollection::decode(ordering_req.as_ref())); - let ordering_req = &parse_physical_sort_exprs( - &proto_def.ordering_req, - &default_ctx, - &schema, - &codex, - )?; - - let exprs = &rresult_return!(parse_physical_exprs( - &proto_def.expr, - &default_ctx, - &schema, - &codex - )); - - Ok(AccumulatorArgs { - return_type, - schema, - ignore_nulls: proto_def.ignore_nulls, - ordering_req, - is_reversed: self.is_reversed, - name: self.name.as_str(), - is_distinct: proto_def.distinct, - exprs, - }) - } -} - -impl<'a> TryFrom> for FFI_AccumulatorArgs { +impl TryFrom> for FFI_AccumulatorArgs { type Error = DataFusionError; - fn try_from(args: AccumulatorArgs) -> std::result::Result { + fn try_from(args: AccumulatorArgs) -> Result { let return_type = WrappedSchema(FFI_ArrowSchema::try_from(args.return_type)?); let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); @@ -106,3 +84,71 @@ impl<'a> TryFrom> for FFI_AccumulatorArgs { }) } } + +/// This struct mirrors AccumulatorArgs except that it contains owned data. +/// It is necessary to create this struct so that we can parse the protobuf +/// data across the FFI boundary and turn it into owned data that +/// AccumulatorArgs can then reference. +pub struct ForeignAccumulatorArgs { + pub return_type: DataType, + pub schema: Schema, + pub ignore_nulls: bool, + pub ordering_req: LexOrdering, + pub is_reversed: bool, + pub name: String, + pub is_distinct: bool, + pub exprs: Vec>, +} + +impl TryFrom for ForeignAccumulatorArgs { + type Error = DataFusionError; + + fn try_from(value: FFI_AccumulatorArgs) -> Result { + let proto_def = + PhysicalAggregateExprNode::decode(value.physical_expr_def.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + + let return_type = (&value.return_type.0).try_into()?; + let schema = Schema::try_from(&value.schema.0)?; + + let default_ctx = SessionContext::new(); + let codex = DefaultPhysicalExtensionCodec {}; + + // let proto_ordering_req = + // rresult_return!(PhysicalSortExprNodeCollection::decode(ordering_req.as_ref())); + let ordering_req = parse_physical_sort_exprs( + &proto_def.ordering_req, + &default_ctx, + &schema, + &codex, + )?; + + let exprs = parse_physical_exprs(&proto_def.expr, &default_ctx, &schema, &codex)?; + + Ok(Self { + return_type, + schema, + ignore_nulls: proto_def.ignore_nulls, + ordering_req, + is_reversed: value.is_reversed, + name: value.name.to_string(), + is_distinct: proto_def.distinct, + exprs, + }) + } +} + +impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { + fn from(value: &'a ForeignAccumulatorArgs) -> Self { + Self { + return_type: &value.return_type, + schema: &value.schema, + ignore_nulls: value.ignore_nulls, + ordering_req: &value.ordering_req, + is_reversed: value.is_reversed, + name: value.name.as_str(), + is_distinct: value.is_distinct, + exprs: &value.exprs, + } + } +} diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index 0c5eb9475d782..eaf4b991477aa 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -1,7 +1,21 @@ -use std::{ - ffi::c_void, - sync::{Arc, Mutex}, -}; +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ffi::c_void; use abi_stable::{ std_types::{ROption, RResult, RString, RVec}, @@ -10,14 +24,12 @@ use abi_stable::{ use arrow::{ array::{Array, ArrayRef, BooleanArray}, error::ArrowError, - ffi::{from_ffi, to_ffi, FFI_ArrowArray}, + ffi::to_ffi, }; use datafusion::{ error::{DataFusionError, Result}, logical_expr::{EmitTo, GroupsAccumulator}, - scalar::ScalarValue, }; -use prost::Message; use crate::{ arrow_wrappers::{WrappedArray, WrappedSchema}, @@ -66,10 +78,6 @@ pub struct FFI_GroupsAccumulator { pub supports_convert_to_state: bool, - /// Used to create a clone on the provider of the accumulator. This should - /// only need to be called by the receiver of the accumulator. - // pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, - /// Release the memory of the private data when it is no longer being used. pub release: unsafe extern "C" fn(accumulator: &mut Self), @@ -112,7 +120,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( } } }).map(|arr| arr.into_data()); - let opt_filter = maybe_filter.map(|arr| BooleanArray::from(arr)); + let opt_filter = maybe_filter.map(BooleanArray::from); rresult!(accum_data.accumulator.update_batch( &values_arrays, @@ -181,7 +189,7 @@ unsafe extern "C" fn merge_batch_fn_wrapper( } } }).map(|arr| arr.into_data()); - let opt_filter = maybe_filter.map(|arr| BooleanArray::from(arr)); + let opt_filter = maybe_filter.map(BooleanArray::from); rresult!(accum_data.accumulator.merge_batch( &values_arrays, @@ -212,7 +220,7 @@ unsafe extern "C" fn convert_to_state_fn_wrapper( None } } - }).map(|arr| arr.into_data()).map(|arr| BooleanArray::from(arr)); + }).map(|arr| arr.into_data()).map(BooleanArray::from); let state = rresult_return!(accum_data .accumulator diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 75b200b5d1474..d5def52cfc313 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -15,56 +15,39 @@ // specific language governing permissions and limitations // under the License. -use std::{ - ffi::c_void, - sync::{Arc, Mutex}, -}; +use std::{ffi::c_void, sync::Arc}; use abi_stable::{ - std_types::{RResult, RStr, RString, RVec}, + std_types::{ROption, RResult, RStr, RString, RVec}, StableAbi, }; -use accumulator::FFI_Accumulator; -use accumulator_args::FFI_AccumulatorArgs; -use arrow::datatypes::{DataType, Field, SchemaRef}; -use arrow::ffi::{from_ffi, to_ffi, FFI_ArrowSchema}; +use accumulator::{FFI_Accumulator, ForeignAccumulator}; +use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; +use arrow::datatypes::{DataType, Field}; +use arrow::ffi::FFI_ArrowSchema; use datafusion::{ error::DataFusionError, logical_expr::{ function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, utils::AggregateOrderSensitivity, - Accumulator, GroupsAccumulator, ReversedUDAF, + Accumulator, GroupsAccumulator, }, - physical_plan::aggregates::order, - prelude::SessionContext, }; use datafusion::{ error::Result, - logical_expr::{ - AggregateUDF, AggregateUDFImpl, ColumnarValue, ScalarFunctionArgs, Signature, - }, -}; -use datafusion_proto::{ - physical_plan::{ - from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, - to_proto::{ - serialize_physical_expr, serialize_physical_exprs, - serialize_physical_sort_exprs, - }, - DefaultPhysicalExtensionCodec, - }, - protobuf::{PhysicalAggregateExprNode, PhysicalSortExprNodeCollection}, + logical_expr::{AggregateUDF, AggregateUDFImpl, Signature}, }; -use groups_accumulator::FFI_GroupsAccumulator; +use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; +use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; use crate::{ - arrow_wrappers::{WrappedArray, WrappedSchema}, + arrow_wrappers::WrappedSchema, df_result, rresult, rresult_return, signature::{ - self, rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped, FFI_Signature, + rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped, FFI_Signature, }, }; -use prost::Message; +use prost::{DecodeError, Message}; mod accumulator; mod accumulator_args; @@ -97,6 +80,13 @@ pub struct FFI_AggregateUDF { args: FFI_AccumulatorArgs, ) -> RResult, + pub create_sliding_accumulator: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, + + #[allow(clippy::type_complexity)] pub state_fields: unsafe extern "C" fn( udaf: &FFI_AggregateUDF, name: &RStr, @@ -108,9 +98,18 @@ pub struct FFI_AggregateUDF { pub create_groups_accumulator: unsafe extern "C" fn( - &FFI_AggregateUDF, + udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs, - ) -> RResult, + ) -> RResult, + + pub with_beneficial_ordering: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, + ) -> RResult, RString>, + + pub order_sensitivity: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity, /// Used to create a clone on the provider of the udaf. This should /// only need to be called by the receiver of the udaf. @@ -131,25 +130,23 @@ pub struct AggregateUDFPrivateData { pub udaf: Arc, } -unsafe extern "C" fn name_fn_wrapper(udaf: &FFI_AggregateUDF) -> RString { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; - - udaf.name().into() +impl FFI_AggregateUDF { + unsafe fn inner(&self) -> &Arc { + let private_data = self.private_data as *const AggregateUDFPrivateData; + &(*private_data).udaf + } } unsafe extern "C" fn signature_fn_wrapper( udaf: &FFI_AggregateUDF, ) -> RResult { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); rresult!(udaf.signature().try_into()) } unsafe extern "C" fn aliases_fn_wrapper(udaf: &FFI_AggregateUDF) -> RVec { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); udaf.aliases().iter().map(|s| s.to_owned().into()).collect() } @@ -158,8 +155,7 @@ unsafe extern "C" fn return_type_fn_wrapper( udaf: &FFI_AggregateUDF, arg_types: RVec, ) -> RResult { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); @@ -175,27 +171,38 @@ unsafe extern "C" fn accumulator_fn_wrapper( udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs, ) -> RResult { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); - let accumulator_args = rresult_return!(args.to_accumulator_args()); + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); rresult!(udaf - .accumulator(accumulator_args) + .accumulator(accumulator_args.into()) .map(FFI_Accumulator::from)) } -unsafe extern "C" fn create_groups_accumulator_fn_wrapper( +unsafe extern "C" fn create_sliding_accumulator_fn_wrapper( udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs, ) -> RResult { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); - let accumulator_args = rresult_return!(args.to_accumulator_args()); + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); rresult!(udaf - .create_groups_accumulator(accumulator_args) + .create_sliding_accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_groups_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_groups_accumulator(accumulator_args.into()) .map(FFI_GroupsAccumulator::from)) } @@ -203,27 +210,86 @@ unsafe extern "C" fn groups_accumulator_supported_fn_wrapper( udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs, ) -> bool { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); - args.to_accumulator_args() - .map(|a| udaf.groups_accumulator_supported(a)) + ForeignAccumulatorArgs::try_from(args) + .map(|a| udaf.groups_accumulator_supported((&a).into())) .unwrap_or_else(|e| { log::warn!("Unable to parse accumulator args. {}", e); false }) } +unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, +) -> RResult, RString> { + let udaf = udaf.inner().as_ref().clone(); + + let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering)); + let result = rresult_return!(result + .map(|func| func.with_beneficial_ordering(beneficial_ordering)) + .transpose()) + .flatten() + .map(|func| FFI_AggregateUDF::from(Arc::new(func))); + + RResult::ROk(result.into()) +} + +unsafe extern "C" fn state_fields_fn_wrapper( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_types: RVec, + return_type: WrappedSchema, + ordering_fields: RVec>, + is_distinct: bool, +) -> RResult>, RString> { + let udaf = udaf.inner(); + + let input_types = &rresult_return!(rvec_wrapped_to_vec_datatype(&input_types)); + let return_type = &rresult_return!(DataType::try_from(&return_type.0)); + + let ordering_fields = &rresult_return!(ordering_fields + .into_iter() + .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref())) + .collect::, DecodeError>>()); + + let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields)); + + let args = StateFieldsArgs { + name: name.as_str(), + input_types, + return_type, + ordering_fields, + is_distinct, + }; + + let state_fields = rresult_return!(udaf.state_fields(args)); + let state_fields = rresult_return!(state_fields + .iter() + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()) + .into_iter() + .map(|field| field.encode_to_vec().into()) + .collect(); + + RResult::ROk(state_fields) +} + +unsafe extern "C" fn order_sensitivity_fn_wrapper( + udaf: &FFI_AggregateUDF, +) -> FFI_AggregateOrderSensitivity { + udaf.inner().order_sensitivity().into() +} + unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); drop(private_data); } unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf_data = &(*private_data); - - Arc::clone(&udaf_data.udaf).into() + Arc::clone(udaf.inner()).into() } impl Clone for FFI_AggregateUDF { @@ -246,8 +312,12 @@ impl From> for FFI_AggregateUDF { aliases: aliases_fn_wrapper, return_type: return_type_fn_wrapper, accumulator: accumulator_fn_wrapper, + create_sliding_accumulator: create_sliding_accumulator_fn_wrapper, create_groups_accumulator: create_groups_accumulator_fn_wrapper, groups_accumulator_supported: groups_accumulator_supported_fn_wrapper, + with_beneficial_ordering: with_beneficial_ordering_fn_wrapper, + state_fields: state_fields_fn_wrapper, + order_sensitivity: order_sensitivity_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, @@ -328,8 +398,11 @@ impl AggregateUDFImpl for ForeignAggregateUDF { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { let args = acc_args.try_into()?; - - unsafe { df_result!((self.udaf.accumulator)(&self.udaf, args)) } + unsafe { + df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| { + Box::new(ForeignAccumulator::from(accum)) as Box + }) + } } fn state_fields(&self, args: StateFieldsArgs) -> Result> { @@ -363,10 +436,8 @@ impl AggregateUDFImpl for ForeignAggregateUDF { }) .collect::>>()?; - datafusion_proto_common::from_proto::parse_proto_fields_to_fields( - fields.iter(), - ) - .map_err(|e| DataFusionError::Execution(e.to_string())) + parse_proto_fields_to_fields(fields.iter()) + .map_err(|e| DataFusionError::Execution(e.to_string())) } } @@ -388,7 +459,14 @@ impl AggregateUDFImpl for ForeignAggregateUDF { ) -> Result> { let args = FFI_AccumulatorArgs::try_from(args)?; - unsafe { df_result!((self.udaf.accumulator)(&self.udaf, args)) } + unsafe { + df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map( + |accum| { + Box::new(ForeignGroupsAccumulator::from(accum)) + as Box + }, + ) + } } fn aliases(&self) -> &[String] { @@ -399,32 +477,40 @@ impl AggregateUDFImpl for ForeignAggregateUDF { &self, args: AccumulatorArgs, ) -> Result> { + let args = args.try_into()?; + unsafe { + df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map( + |accum| Box::new(ForeignAccumulator::from(accum)) as Box, + ) + } } fn with_beneficial_ordering( self: Arc, - _beneficial_ordering: bool, + beneficial_ordering: bool, ) -> Result>> { - } - - fn order_sensitivity(&self) -> AggregateOrderSensitivity {} - - fn simplify(&self) -> Option {} - - fn reverse_expr(&self) -> ReversedUDAF {} - - fn coerce_types(&self, _arg_types: &[DataType]) -> Result> {} - - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {} + unsafe { + let result = df_result!((self.udaf.with_beneficial_ordering)( + &self.udaf, + beneficial_ordering + ))? + .into_option(); - fn is_descending(&self) -> Option {} + let result = result + .map(|func| ForeignAggregateUDF::try_from(&func)) + .transpose()?; - fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option { + Ok(result.map(|func| Arc::new(func) as Arc)) + } } - fn default_value(&self, data_type: &DataType) -> Result {} + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() } + } - fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {} + fn simplify(&self) -> Option { + None + } } #[cfg(test)] @@ -433,7 +519,7 @@ mod tests { #[test] fn test_round_trip_udaf() -> Result<()> { - let original_udaf = datafusion::functions::math::abs::AbsFunc::new(); + let original_udaf = datafusion::functions_aggregate::sum::Sum::new(); let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); @@ -445,3 +531,32 @@ mod tests { Ok(()) } } + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_AggregateOrderSensitivity { + Insensitive, + HardRequirement, + Beneficial, +} + +impl From for AggregateOrderSensitivity { + fn from(value: FFI_AggregateOrderSensitivity) -> Self { + match value { + FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive, + FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +impl From for FFI_AggregateOrderSensitivity { + fn from(value: AggregateOrderSensitivity) -> Self { + match value { + AggregateOrderSensitivity::Insensitive => Self::Insensitive, + AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} From 62ffbe9da3857939e263d4580ee83e10ba8db0c6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 20 Feb 2025 17:57:54 -0500 Subject: [PATCH 04/18] Clean up after rebase --- datafusion/ffi/src/udaf/mod.rs | 54 ++++++++++++---------------------- 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index d5def52cfc313..4107f7bd0ff44 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -43,9 +43,8 @@ use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; use crate::{ arrow_wrappers::WrappedSchema, df_result, rresult, rresult_return, - signature::{ - rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped, FFI_Signature, - }, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, }; use prost::{DecodeError, Message}; @@ -58,12 +57,14 @@ mod groups_accumulator; #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] pub struct FFI_AggregateUDF { - /// Return the udaf name. + /// FFI equivalent to the `name` of a [`AggregateUDF`] pub name: RString, - pub signature: unsafe extern "C" fn(udaf: &Self) -> RResult, + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub aliases: RVec, - pub aliases: unsafe extern "C" fn(udaf: &Self) -> RVec, + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub volatility: FFI_Volatility, pub return_type: unsafe extern "C" fn( udaf: &Self, @@ -137,20 +138,6 @@ impl FFI_AggregateUDF { } } -unsafe extern "C" fn signature_fn_wrapper( - udaf: &FFI_AggregateUDF, -) -> RResult { - let udaf = udaf.inner(); - - rresult!(udaf.signature().try_into()) -} - -unsafe extern "C" fn aliases_fn_wrapper(udaf: &FFI_AggregateUDF) -> RVec { - let udaf = udaf.inner(); - - udaf.aliases().iter().map(|s| s.to_owned().into()).collect() -} - unsafe extern "C" fn return_type_fn_wrapper( udaf: &FFI_AggregateUDF, arg_types: RVec, @@ -301,15 +288,17 @@ impl Clone for FFI_AggregateUDF { impl From> for FFI_AggregateUDF { fn from(udaf: Arc) -> Self { let name = udaf.name().into(); + let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect(); let is_nullable = udaf.is_nullable(); + let volatility = udaf.signature().volatility.into(); let private_data = Box::new(AggregateUDFPrivateData { udaf }); Self { name, is_nullable, - signature: signature_fn_wrapper, - aliases: aliases_fn_wrapper, + volatility, + aliases, return_type: return_type_fn_wrapper, accumulator: accumulator_fn_wrapper, create_sliding_accumulator: create_sliding_accumulator_fn_wrapper, @@ -351,21 +340,14 @@ impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { type Error = DataFusionError; fn try_from(udaf: &FFI_AggregateUDF) -> Result { - unsafe { - let ffi_signature = df_result!((udaf.signature)(udaf))?; - let signature = (&ffi_signature).try_into()?; - - let aliases = (udaf.aliases)(udaf) - .into_iter() - .map(|s| s.to_string()) - .collect(); + let signature = Signature::user_defined((&udaf.volatility).into()); + let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect(); - Ok(Self { - udaf: udaf.clone(), - signature, - aliases, - }) - } + Ok(Self { + udaf: udaf.clone(), + signature, + aliases, + }) } } From 08601d804f64a76de30de420ed1a6830172fa47e Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 21 Feb 2025 19:57:24 -0500 Subject: [PATCH 05/18] Add unit test for FFI Accumulator Args --- datafusion/ffi/src/udaf/accumulator_args.rs | 37 +++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index f15bed5aa2c25..b3933c2670256 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -152,3 +152,40 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { } } } + +#[cfg(test)] +mod tests { + use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; + use arrow::datatypes::{DataType, Schema}; + use datafusion::{ + error::Result, logical_expr::function::AccumulatorArgs, + physical_expr::LexOrdering, + }; + + #[test] + fn test_round_trip_accumulator_args() -> Result<()> { + let orig_args = AccumulatorArgs { + return_type: &DataType::Float64, + schema: &Schema::empty(), + ignore_nulls: false, + ordering_req: &LexOrdering::new(vec![]), + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[], + }; + let orig_str = format!("{:?}", orig_args); + + let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?; + let foreign_args: ForeignAccumulatorArgs = ffi_args.try_into()?; + let round_trip_args: AccumulatorArgs = (&foreign_args).into(); + + let round_trip_str = format!("{:?}", round_trip_args); + + // Since AccumulatorArgs doesn't implement Eq, simply compare + // the debug strings. + assert_eq!(orig_str, round_trip_str); + + Ok(()) + } +} From 61c54b73c1ed254e99214805092d71bade8c791d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 23 Feb 2025 03:54:37 -0500 Subject: [PATCH 06/18] Adding unit tests and fixing memory errors in aggregate ffi udf --- .../expr-common/src/groups_accumulator.rs | 2 +- datafusion/ffi/src/tests/mod.rs | 8 +- datafusion/ffi/src/tests/udf_udaf_udwf.rs | 14 +++- datafusion/ffi/src/udaf/accumulator.rs | 83 +++++++++++++------ datafusion/ffi/src/udaf/groups_accumulator.rs | 77 ++++++++++++++++- datafusion/ffi/src/udaf/mod.rs | 33 ++++++++ datafusion/ffi/tests/ffi_integration.rs | 43 +++++++++- 7 files changed, 228 insertions(+), 32 deletions(-) diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 5ff1c1d072164..9bcc1edff8824 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -21,7 +21,7 @@ use arrow::array::{ArrayRef, BooleanArray}; use datafusion_common::{not_impl_err, Result}; /// Describes how many rows should be emitted during grouping. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum EmitTo { /// Emit all groups All, diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index 5a471cb8fe434..8a4b376c558fd 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -26,6 +26,8 @@ use abi_stable::{ StableAbi, }; +use crate::udaf::FFI_AggregateUDF; + use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; use arrow::array::RecordBatch; use async_provider::create_async_table_provider; @@ -34,7 +36,7 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::create_ffi_abs_func; +use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_avg_func}; mod async_provider; mod sync_provider; @@ -53,6 +55,9 @@ pub struct ForeignLibraryModule { /// Create a scalar UDF pub create_scalar_udf: extern "C" fn() -> FFI_ScalarUDF, + /// Create an aggregate UDF + pub create_udaf: extern "C" fn() -> FFI_AggregateUDF, + pub version: extern "C" fn() -> u64, } @@ -97,6 +102,7 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { ForeignLibraryModule { create_table: construct_table_provider, create_scalar_udf: create_ffi_abs_func, + create_udaf: create_ffi_avg_func, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index e8a13aac13081..5717bb56c87f7 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -15,8 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::udf::FFI_ScalarUDF; -use datafusion::{functions::math::abs::AbsFunc, logical_expr::ScalarUDF}; +use crate::{udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF}; +use datafusion::{ + functions::math::abs::AbsFunc, + functions_aggregate::sum::Sum, + logical_expr::{AggregateUDF, ScalarUDF}, +}; use std::sync::Arc; @@ -25,3 +29,9 @@ pub(crate) extern "C" fn create_ffi_abs_func() -> FFI_ScalarUDF { udf.into() } + +pub(crate) extern "C" fn create_ffi_avg_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Sum::new().into()); + + udaf.into() +} diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index 6ebe67a4b1c52..c4f8edfeafc53 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -58,7 +58,7 @@ pub struct FFI_Accumulator { values: RVec, ) -> RResult<(), RString>, - pub supports_retract_batch: unsafe extern "C" fn(accumulator: &Self) -> bool, + pub supports_retract_batch: bool, /// Release the memory of the private data when it is no longer being used. pub release: unsafe extern "C" fn(accumulator: &mut Self), @@ -161,35 +161,17 @@ unsafe extern "C" fn retract_batch_fn_wrapper( rresult!(accum_data.accumulator.retract_batch(&values)) } -unsafe extern "C" fn supports_retract_batch_fn_wrapper( - accumulator: &FFI_Accumulator, -) -> bool { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); - accum_data.accumulator.supports_retract_batch() -} - unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { let private_data = Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); drop(private_data); } -// unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_Accumulator) -> FFI_Accumulator { -// let private_data = accumulator.private_data as *const AccumulatorPrivateData; -// let accum_data = &(*private_data); - -// Box::new(accum_data.accumulator).into() -// } - -// impl Clone for FFI_Accumulator { -// fn clone(&self) -> Self { -// unsafe { (self.clone)(self) } -// } -// } - impl From> for FFI_Accumulator { fn from(accumulator: Box) -> Self { + let supports_retract_batch = accumulator.supports_retract_batch(); + let private_data = AccumulatorPrivateData { accumulator }; + Self { update_batch: update_batch_fn_wrapper, evaluate: evaluate_fn_wrapper, @@ -197,11 +179,11 @@ impl From> for FFI_Accumulator { state: state_fn_wrapper, merge_batch: merge_batch_fn_wrapper, retract_batch: retract_batch_fn_wrapper, - supports_retract_batch: supports_retract_batch_fn_wrapper, + supports_retract_batch, // clone: clone_fn_wrapper, release: release_fn_wrapper, - private_data: Box::into_raw(accumulator) as *mut c_void, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, } } } @@ -308,6 +290,57 @@ impl Accumulator for ForeignAccumulator { } fn supports_retract_batch(&self) -> bool { - unsafe { (self.accumulator.supports_retract_batch)(&self.accumulator) } + self.accumulator.supports_retract_batch + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{make_array, Array}; + use datafusion::{ + common::create_array, error::Result, + functions_aggregate::average::AvgAccumulator, logical_expr::Accumulator, + scalar::ScalarValue, + }; + + use super::{FFI_Accumulator, ForeignAccumulator}; + + #[test] + fn test_foreign_avg_accumulator() -> Result<()> { + let boxed_accum: Box = Box::new(AvgAccumulator::default()); + let ffi_accum: FFI_Accumulator = boxed_accum.into(); + let mut foreign_accum: ForeignAccumulator = ffi_accum.into(); + + // Send in an array to average. There are 5 values and it should average to 30.0 + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + foreign_accum.update_batch(&[values])?; + + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + let state = foreign_accum.state()?; + assert_eq!(state.len(), 2); + assert_eq!(state[0], ScalarValue::UInt64(Some(5))); + assert_eq!(state[1], ScalarValue::Float64(Some(150.0))); + + // To verify merging batches works, create a second state to add in + // This should cause our average to go down to 25.0 + let second_states = vec![ + make_array(create_array!(UInt64, vec![1]).to_data()), + make_array(create_array!(Float64, vec![0.0]).to_data()), + ]; + + foreign_accum.merge_batch(&second_states)?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(25.0))); + + // If we remove a batch that is equivalent to the state we added + // we should go back to our original value of 30.0 + let values = create_array!(Float64, vec![0.0]); + foreign_accum.retract_batch(&[values])?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + Ok(()) } } diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index eaf4b991477aa..d5559e24b5f6c 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -253,6 +253,9 @@ unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) impl From> for FFI_GroupsAccumulator { fn from(accumulator: Box) -> Self { + let supports_convert_to_state = accumulator.supports_convert_to_state(); + let private_data = GroupsAccumulatorPrivateData { accumulator }; + Self { update_batch: update_batch_fn_wrapper, evaluate: evaluate_fn_wrapper, @@ -260,10 +263,10 @@ impl From> for FFI_GroupsAccumulator { state: state_fn_wrapper, merge_batch: merge_batch_fn_wrapper, convert_to_state: convert_to_state_fn_wrapper, - supports_convert_to_state: accumulator.supports_convert_to_state(), + supports_convert_to_state, release: release_fn_wrapper, - private_data: Box::into_raw(accumulator) as *mut c_void, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, } } } @@ -451,3 +454,73 @@ impl From for EmitTo { } } } + +#[cfg(test)] +mod tests { + use arrow::array::{make_array, Array, Float64Array}; + use datafusion::{ + common::create_array, + error::Result, + functions_aggregate::stddev::StddevGroupsAccumulator, + logical_expr::{EmitTo, GroupsAccumulator}, + physical_plan::expressions::StatsType, + }; + + use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + + #[test] + fn test_foreign_avg_accumulator() -> Result<()> { + let boxed_accum: Box = + Box::new(StddevGroupsAccumulator::new(StatsType::Population)); + let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); + let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into(); + + // Send in an array to evaluate. We want a mean of 30 and standard deviation of 4. + let values = create_array!(Float64, vec![26., 26., 34., 34.]); + foreign_accum.update_batch(&[values], &[0; 4], None, 1)?; + + let groups_avg = foreign_accum.evaluate(EmitTo::All)?; + let groups_avg = groups_avg.as_any().downcast_ref::().unwrap(); + let expected = 4.0; + assert_eq!(groups_avg.len(), 1); + assert!((groups_avg.value(0) - expected).abs() < 0.0001); + + let state = foreign_accum.state(EmitTo::All)?; + assert_eq!(state.len(), 3); + + // To verify merging batches works, create a second state to add in + // This should cause our average to go down to 25.0 + let second_states = vec![ + make_array(create_array!(UInt64, vec![1]).to_data()), + make_array(create_array!(Float64, vec![30.0]).to_data()), + make_array(create_array!(Float64, vec![64.0]).to_data()), + ]; + + foreign_accum.merge_batch(&second_states, &[0], None, 1)?; + let avg = foreign_accum.evaluate(EmitTo::All)?; + assert_eq!(avg.len(), 1); + assert_eq!( + avg.as_ref(), + make_array(create_array!(Float64, vec![8.0]).to_data()).as_ref() + ); + + Ok(()) + } + + fn test_emit_to_round_trip(value: EmitTo) -> Result<()> { + let ffi_value: FFI_EmitTo = value.into(); + let round_trip_value: EmitTo = ffi_value.into(); + + assert_eq!(value, round_trip_value); + Ok(()) + } + + /// This test ensures all enum values are properly translated + #[test] + fn test_all_emit_to_round_trip() -> Result<()> { + test_emit_to_round_trip(EmitTo::All)?; + test_emit_to_round_trip(EmitTo::First(10))?; + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 4107f7bd0ff44..018b2b07c2ebe 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -29,6 +29,7 @@ use datafusion::{ error::DataFusionError, logical_expr::{ function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + type_coercion::functions::data_types_with_aggregate_udf, utils::AggregateOrderSensitivity, Accumulator, GroupsAccumulator, }, @@ -112,6 +113,15 @@ pub struct FFI_AggregateUDF { pub order_sensitivity: unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity, + /// Performs type coersion. To simply this interface, all UDFs are treated as having + /// user defined signatures, which will in turn call coerce_types to be called. This + /// call should be transparent to most users as the internal function performs the + /// appropriate calls on the underlying [`ScalarUDF`] + pub coerce_types: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec, + ) -> RResult, RString>, + /// Used to create a clone on the provider of the udaf. This should /// only need to be called by the receiver of the udaf. pub clone: unsafe extern "C" fn(udaf: &Self) -> Self, @@ -270,6 +280,19 @@ unsafe extern "C" fn order_sensitivity_fn_wrapper( udaf.inner().order_sensitivity().into() } +unsafe extern "C" fn coerce_types_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec, +) -> RResult, RString> { + let udaf = udaf.inner(); + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_types = rresult_return!(data_types_with_aggregate_udf(&arg_types, udaf)); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} + unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); drop(private_data); @@ -307,6 +330,7 @@ impl From> for FFI_AggregateUDF { with_beneficial_ordering: with_beneficial_ordering_fn_wrapper, state_fields: state_fields_fn_wrapper, order_sensitivity: order_sensitivity_fn_wrapper, + coerce_types: coerce_types_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, @@ -493,6 +517,15 @@ impl AggregateUDFImpl for ForeignAggregateUDF { fn simplify(&self) -> Option { None } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let result_types = + df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?) + } + } } #[cfg(test)] diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index 84e120df4299c..227ec24c13dfc 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -23,10 +23,11 @@ mod tests { use abi_stable::library::RootModule; use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; - use datafusion::logical_expr::ScalarUDF; + use datafusion::logical_expr::{AggregateUDF, ScalarUDF}; use datafusion::prelude::{col, SessionContext}; use datafusion_ffi::table_provider::ForeignTableProvider; use datafusion_ffi::tests::{create_record_batch, ForeignLibraryModuleRef}; + use datafusion_ffi::udaf::ForeignAggregateUDF; use datafusion_ffi::udf::ForeignScalarUDF; use std::path::Path; use std::sync::Arc; @@ -179,4 +180,44 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_ffi_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_avg_func = + module.create_udaf().ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); + let foreign_avg_func: ForeignAggregateUDF = (&ffi_avg_func).try_into()?; + + let udf: AggregateUDF = foreign_avg_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ("b", Float64, vec![1.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0]) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udf.call(vec![col("b")]).alias("sum_b")], + )? + .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + + let expected = record_batch!( + ("a", Int32, vec![1, 2, 4]), + ("sum_b", Float64, vec![1.0, 4.0, 16.0]) + )?; + + assert_eq!(result[0], expected); + + Ok(()) + } } From 866dbf5dea40af99433f1bd6aad2e0019cc01e09 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 23 Feb 2025 11:57:58 +0100 Subject: [PATCH 07/18] Working through additional unit and integration tests for UDAF ffi --- datafusion/ffi/src/tests/mod.rs | 12 ++-- datafusion/ffi/src/tests/udf_udaf_udwf.rs | 8 ++- datafusion/ffi/src/udaf/groups_accumulator.rs | 8 ++- datafusion/ffi/tests/ffi_integration.rs | 70 +++++++++++++++++-- 4 files changed, 85 insertions(+), 13 deletions(-) diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index 8a4b376c558fd..de9f89e18ad17 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -36,7 +36,7 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_avg_func}; +use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_avg_func, create_ffi_stddev_func}; mod async_provider; mod sync_provider; @@ -55,8 +55,11 @@ pub struct ForeignLibraryModule { /// Create a scalar UDF pub create_scalar_udf: extern "C" fn() -> FFI_ScalarUDF, - /// Create an aggregate UDF - pub create_udaf: extern "C" fn() -> FFI_AggregateUDF, + /// Create an aggregate UDAF using sum + pub create_sum_udaf: extern "C" fn() -> FFI_AggregateUDF, + + /// Createa grouping UDAF using stddev + pub create_stddev_udaf: extern "C" fn() -> FFI_AggregateUDF, pub version: extern "C" fn() -> u64, } @@ -102,7 +105,8 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { ForeignLibraryModule { create_table: construct_table_provider, create_scalar_udf: create_ffi_abs_func, - create_udaf: create_ffi_avg_func, + create_sum_udaf: create_ffi_avg_func, + create_stddev_udaf: create_ffi_stddev_func, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index 5717bb56c87f7..24247d521ed25 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -18,7 +18,7 @@ use crate::{udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF}; use datafusion::{ functions::math::abs::AbsFunc, - functions_aggregate::sum::Sum, + functions_aggregate::{stddev::Stddev, sum::Sum}, logical_expr::{AggregateUDF, ScalarUDF}, }; @@ -35,3 +35,9 @@ pub(crate) extern "C" fn create_ffi_avg_func() -> FFI_AggregateUDF { udaf.into() } + +pub(crate) extern "C" fn create_ffi_stddev_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Stddev::new().into()); + + udaf.into() +} diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index d5559e24b5f6c..ccdf4000bdc70 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -476,8 +476,9 @@ mod tests { let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into(); // Send in an array to evaluate. We want a mean of 30 and standard deviation of 4. - let values = create_array!(Float64, vec![26., 26., 34., 34.]); - foreign_accum.update_batch(&[values], &[0; 4], None, 1)?; + let values = create_array!(Float64, vec![26., 26., 34., 34., 0.0]); + let opt_filter = create_array!(Boolean, vec![true, true, true, true, false]); + foreign_accum.update_batch(&[values], &[0; 5], Some(opt_filter.as_ref()), 1)?; let groups_avg = foreign_accum.evaluate(EmitTo::All)?; let groups_avg = groups_avg.as_any().downcast_ref::().unwrap(); @@ -496,7 +497,8 @@ mod tests { make_array(create_array!(Float64, vec![64.0]).to_data()), ]; - foreign_accum.merge_batch(&second_states, &[0], None, 1)?; + let opt_filter = create_array!(Boolean, vec![true]); + foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?; let avg = foreign_accum.evaluate(EmitTo::All)?; assert_eq!(avg.len(), 1); assert_eq!( diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index 227ec24c13dfc..c036763e546b8 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -21,6 +21,7 @@ mod tests { use abi_stable::library::RootModule; + use arrow::array::Float64Array; use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{AggregateUDF, ScalarUDF}; @@ -186,12 +187,14 @@ mod tests { let module = get_module()?; let ffi_avg_func = - module.create_udaf().ok_or(DataFusionError::NotImplemented( - "External table provider failed to implement create_udaf".to_string(), - ))?(); + module + .create_sum_udaf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); let foreign_avg_func: ForeignAggregateUDF = (&ffi_avg_func).try_into()?; - let udf: AggregateUDF = foreign_avg_func.into(); + let udaf: AggregateUDF = foreign_avg_func.into(); let ctx = SessionContext::default(); let record_batch = record_batch!( @@ -205,7 +208,7 @@ mod tests { let df = df .aggregate( vec![col("a")], - vec![udf.call(vec![col("b")]).alias("sum_b")], + vec![udaf.call(vec![col("b")]).alias("sum_b")], )? .sort_by(vec![col("a")])?; @@ -220,4 +223,61 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_ffi_grouping_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_stddev_func = + module + .create_stddev_udaf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); + let foreign_stddev_func: ForeignAggregateUDF = (&ffi_stddev_func).try_into()?; + + let udaf: AggregateUDF = foreign_stddev_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ( + "b", + Float64, + vec![ + 1.0, + 2.0, + 2.0 + 2.0_f64.sqrt(), + 4.0, + 4.0, + 4.0 + 3.0_f64.sqrt(), + 4.0 + 3.0_f64.sqrt() + ] + ) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("stddev_b")], + )? + .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + let result = result[0].column_by_name("stddev_b").unwrap(); + let result = result + .as_any() + .downcast_ref::() + .unwrap() + .values(); + + assert!(result.first().unwrap().is_nan()); + assert!(result.get(1).unwrap() - 1.0 < 0.00001); + assert!(result.get(2).unwrap() - 1.0 < 0.00001); + + Ok(()) + } } From 7edcc14f414b10e1a843a0e76b3786211f7a8ec0 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 23 Feb 2025 14:43:05 +0100 Subject: [PATCH 08/18] Switch to a accumulator that supports convert to state to get a little better coverage --- Cargo.lock | 1 + datafusion/ffi/Cargo.toml | 1 + datafusion/ffi/src/udaf/groups_accumulator.rs | 58 ++++++++++++------- 3 files changed, 38 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6477c003ea037..0394842caee46 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2026,6 +2026,7 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-functions-aggregate-common", "datafusion-proto", "datafusion-proto-common", "doc-comment", diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 532ccd9ff83d7..80a16074b6126 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -45,6 +45,7 @@ async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } datafusion-proto = { workspace = true } datafusion-proto-common = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } futures = { workspace = true } log = { workspace = true } prost = { workspace = true } diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index ccdf4000bdc70..3f6b5def4f9b5 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -457,53 +457,67 @@ impl From for EmitTo { #[cfg(test)] mod tests { - use arrow::array::{make_array, Array, Float64Array}; + use arrow::array::{make_array, Array, BooleanArray}; use datafusion::{ common::create_array, error::Result, - functions_aggregate::stddev::StddevGroupsAccumulator, logical_expr::{EmitTo, GroupsAccumulator}, - physical_plan::expressions::StatsType, }; + use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator}; #[test] fn test_foreign_avg_accumulator() -> Result<()> { let boxed_accum: Box = - Box::new(StddevGroupsAccumulator::new(StatsType::Population)); + Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true)); let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into(); // Send in an array to evaluate. We want a mean of 30 and standard deviation of 4. - let values = create_array!(Float64, vec![26., 26., 34., 34., 0.0]); - let opt_filter = create_array!(Boolean, vec![true, true, true, true, false]); - foreign_accum.update_batch(&[values], &[0; 5], Some(opt_filter.as_ref()), 1)?; + let values = create_array!(Boolean, vec![true, true, true, false, true, true]); + let opt_filter = + create_array!(Boolean, vec![true, true, true, true, false, false]); + foreign_accum.update_batch( + &[values], + &[0, 0, 1, 1, 2, 2], + Some(opt_filter.as_ref()), + 3, + )?; + + let groups_bool = foreign_accum.evaluate(EmitTo::All)?; + let groups_bool = groups_bool.as_any().downcast_ref::().unwrap(); - let groups_avg = foreign_accum.evaluate(EmitTo::All)?; - let groups_avg = groups_avg.as_any().downcast_ref::().unwrap(); - let expected = 4.0; - assert_eq!(groups_avg.len(), 1); - assert!((groups_avg.value(0) - expected).abs() < 0.0001); + assert_eq!( + groups_bool, + create_array!(Boolean, vec![Some(true), Some(false), None]).as_ref() + ); let state = foreign_accum.state(EmitTo::All)?; - assert_eq!(state.len(), 3); + assert_eq!(state.len(), 1); // To verify merging batches works, create a second state to add in // This should cause our average to go down to 25.0 - let second_states = vec![ - make_array(create_array!(UInt64, vec![1]).to_data()), - make_array(create_array!(Float64, vec![30.0]).to_data()), - make_array(create_array!(Float64, vec![64.0]).to_data()), - ]; + let second_states = + vec![make_array(create_array!(Boolean, vec![false]).to_data())]; let opt_filter = create_array!(Boolean, vec![true]); foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?; - let avg = foreign_accum.evaluate(EmitTo::All)?; - assert_eq!(avg.len(), 1); + let groups_bool = foreign_accum.evaluate(EmitTo::All)?; + assert_eq!(groups_bool.len(), 1); + assert_eq!( + groups_bool.as_ref(), + make_array(create_array!(Boolean, vec![false]).to_data()).as_ref() + ); + + let values = create_array!(Boolean, vec![false]); + let opt_filter = create_array!(Boolean, vec![true]); + let groups_bool = + foreign_accum.convert_to_state(&[values], Some(opt_filter.as_ref()))?; + assert_eq!( - avg.as_ref(), - make_array(create_array!(Float64, vec![8.0]).to_data()).as_ref() + groups_bool[0].as_ref(), + make_array(create_array!(Boolean, vec![false]).to_data()).as_ref() ); Ok(()) From 930cecff4626c9259335740028ea416eb4d2f7ed Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 23 Feb 2025 19:58:33 +0100 Subject: [PATCH 09/18] Set feature so we do not get an error warning in stable rustc --- Cargo.toml | 2 +- datafusion/ffi/Cargo.toml | 1 + datafusion/ffi/src/arrow_wrappers.rs | 46 ++++++++++++++++------------ 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index adb3ee23d947c..b557e1c39c481 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -189,5 +189,5 @@ incremental = false large_futures = "warn" [workspace.lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] } +unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)", "cfg(tarpaulin_include)"] } unused_qualifications = "deny" diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 80a16074b6126..db8a92f696cb4 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -57,3 +57,4 @@ doc-comment = { workspace = true } [features] integration-tests = [] +tarpaulin_include = [] # Exists only to prevent warnings on stable and still have accurate coverage diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index d87d47a9b1848..45049f696a4fa 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -21,6 +21,7 @@ use abi_stable::StableAbi; use arrow::{ array::{make_array, ArrayRef}, datatypes::{Schema, SchemaRef}, + error::ArrowError, ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; use log::error; @@ -31,30 +32,37 @@ use log::error; #[derive(Debug, StableAbi)] pub struct WrappedSchema(#[sabi(unsafe_opaque_field)] pub FFI_ArrowSchema); +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. +#[cfg(not(tarpaulin_include))] +fn catch_ffi_schema_error(e: ArrowError) -> FFI_ArrowSchema { + error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); + FFI_ArrowSchema::empty() +} + impl From for WrappedSchema { fn from(value: SchemaRef) -> Self { - let ffi_schema = match FFI_ArrowSchema::try_from(value.as_ref()) { - Ok(s) => s, - Err(e) => { - error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); - FFI_ArrowSchema::empty() - } - }; - - WrappedSchema(ffi_schema) + WrappedSchema( + FFI_ArrowSchema::try_from(value.as_ref()) + .unwrap_or_else(catch_ffi_schema_error), + ) } } +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. +#[cfg(not(tarpaulin_include))] +fn catch_df_schema_error(e: ArrowError) -> Schema { + error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {}", e); + Schema::empty() +} impl From for SchemaRef { fn from(value: WrappedSchema) -> Self { - let schema = match Schema::try_from(&value.0) { - Ok(s) => s, - Err(e) => { - error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {}", e); - Schema::empty() - } - }; - Arc::new(schema) + Arc::new(Schema::try_from(&value.0).unwrap_or_else(catch_df_schema_error)) } } @@ -71,7 +79,7 @@ pub struct WrappedArray { } impl TryFrom for ArrayRef { - type Error = arrow::error::ArrowError; + type Error = ArrowError; fn try_from(value: WrappedArray) -> Result { let data = unsafe { from_ffi(value.array, &value.schema.0)? }; @@ -81,7 +89,7 @@ impl TryFrom for ArrayRef { } impl TryFrom<&ArrayRef> for WrappedArray { - type Error = arrow::error::ArrowError; + type Error = ArrowError; fn try_from(array: &ArrayRef) -> Result { let (array, schema) = to_ffi(&array.to_data())?; From 47fffb4126c785ff24dbb7245fab17a228a6384f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 24 Feb 2025 08:21:33 +0100 Subject: [PATCH 10/18] Add more options to test --- datafusion/ffi/src/plan_properties.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 3592c16b8fab0..1b2b37708b546 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -300,7 +300,10 @@ impl From for EmissionType { #[cfg(test)] mod tests { - use datafusion::physical_plan::Partitioning; + use datafusion::{ + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::Partitioning, + }; use super::*; @@ -311,8 +314,13 @@ mod tests { Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); let original_props = PlanProperties::new( - EquivalenceProperties::new(schema), - Partitioning::UnknownPartitioning(3), + EquivalenceProperties::new(Arc::clone(&schema)).with_reorder( + LexOrdering::new(vec![PhysicalSortExpr { + expr: datafusion::physical_plan::expressions::col("a", &schema)?, + options: Default::default(), + }]), + ), + Partitioning::RoundRobinBatch(3), EmissionType::Incremental, Boundedness::Bounded, ); From 98a0f52d9e8bc014d9e84776ef435281b61b4edf Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 25 Feb 2025 06:54:03 +0100 Subject: [PATCH 11/18] Add unit test for FFI RecordBatchStream --- datafusion/ffi/src/record_batch_stream.rs | 45 +++++++++++++++++++++++ datafusion/ffi/src/util.rs | 2 +- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 939c4050028cb..5663fb12f0e97 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -196,3 +196,48 @@ impl Stream for FFI_RecordBatchStream { } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::{ + common::record_batch, error::Result, execution::SendableRecordBatchStream, + test_util::bounded_stream, + }; + + use super::FFI_RecordBatchStream; + use futures::StreamExt; + + #[tokio::test] + async fn test_round_trip_record_batch_stream() -> Result<()> { + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 3]), + ("b", Float64, vec![Some(4.0), None, Some(5.0)]) + )?; + let original_rbs = bounded_stream(record_batch.clone(), 1); + + let ffi_rbs: FFI_RecordBatchStream = original_rbs.into(); + let mut ffi_rbs: SendableRecordBatchStream = Box::pin(ffi_rbs); + + let schema = ffi_rbs.schema(); + assert_eq!( + schema, + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true) + ])) + ); + + let batch = ffi_rbs.next().await; + assert!(batch.is_some()); + assert!(batch.unwrap().is_ok()); + + // There should only be one batch + let no_batch = ffi_rbs.next().await; + assert!(no_batch.is_none()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index 9d5f2aefe324b..97f6490509ae9 100644 --- a/datafusion/ffi/src/util.rs +++ b/datafusion/ffi/src/util.rs @@ -20,7 +20,7 @@ use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; use crate::arrow_wrappers::WrappedSchema; -/// This macro is a helpful conversion utility to conver from an abi_stable::RResult to a +/// This macro is a helpful conversion utility to convert from an abi_stable::RResult to a /// DataFusion result. #[macro_export] macro_rules! df_result { From f50ff528d85cf58894c139dd555153295e70f565 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 25 Feb 2025 08:57:16 +0100 Subject: [PATCH 12/18] Add a few more args to ffi accumulator test fn --- datafusion/ffi/src/udaf/accumulator_args.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index b3933c2670256..3a25d09c4a550 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -156,23 +156,29 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { #[cfg(test)] mod tests { use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; - use arrow::datatypes::{DataType, Schema}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{ - error::Result, logical_expr::function::AccumulatorArgs, - physical_expr::LexOrdering, + error::Result, + logical_expr::function::AccumulatorArgs, + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::expressions::col, }; #[test] fn test_round_trip_accumulator_args() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); let orig_args = AccumulatorArgs { return_type: &DataType::Float64, - schema: &Schema::empty(), + schema: &schema, ignore_nulls: false, - ordering_req: &LexOrdering::new(vec![]), + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), is_reversed: false, name: "round_trip", is_distinct: true, - exprs: &[], + exprs: &[col("a", &schema)?], }; let orig_str = format!("{:?}", orig_args); @@ -185,6 +191,7 @@ mod tests { // Since AccumulatorArgs doesn't implement Eq, simply compare // the debug strings. assert_eq!(orig_str, round_trip_str); + println!("{}", round_trip_str); Ok(()) } From 77fd002ed1bb48a8137b48664d55edc8bfe89b76 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 27 Feb 2025 07:59:48 +0100 Subject: [PATCH 13/18] Adding more unit tests on ffi aggregate udaf --- datafusion/ffi/src/udaf/accumulator.rs | 12 +- datafusion/ffi/src/udaf/mod.rs | 168 ++++++++++++++++++++++--- 2 files changed, 160 insertions(+), 20 deletions(-) diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index c4f8edfeafc53..a6c007dce8f76 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -307,7 +307,11 @@ mod tests { #[test] fn test_foreign_avg_accumulator() -> Result<()> { - let boxed_accum: Box = Box::new(AvgAccumulator::default()); + let original_accum = AvgAccumulator::default(); + let original_size = original_accum.size(); + let original_supports_retract = original_accum.supports_retract_batch(); + + let boxed_accum: Box = Box::new(original_accum); let ffi_accum: FFI_Accumulator = boxed_accum.into(); let mut foreign_accum: ForeignAccumulator = ffi_accum.into(); @@ -341,6 +345,12 @@ mod tests { let avg = foreign_accum.evaluate()?; assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + assert_eq!(original_size, foreign_accum.size()); + assert_eq!( + original_supports_retract, + foreign_accum.supports_retract_batch() + ); + Ok(()) } } diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 018b2b07c2ebe..cec7cf08a3e6f 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -528,25 +528,6 @@ impl AggregateUDFImpl for ForeignAggregateUDF { } } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_round_trip_udaf() -> Result<()> { - let original_udaf = datafusion::functions_aggregate::sum::Sum::new(); - let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); - - let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); - - let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; - - assert!(original_udaf.name() == foreign_udaf.name()); - - Ok(()) - } -} - #[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] @@ -575,3 +556,152 @@ impl From for FFI_AggregateOrderSensitivity { } } } + +#[cfg(test)] +mod tests { + use arrow::datatypes::Schema; + use datafusion::{ + common::create_array, + functions_aggregate::sum::Sum, + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::expressions::col, + scalar::ScalarValue, + }; + + use super::*; + + fn create_test_foreign_udaf( + original_udaf: impl AggregateUDFImpl + 'static, + ) -> Result { + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + Ok(foreign_udaf.into()) + } + + #[test] + fn test_round_trip_udaf() -> Result<()> { + let original_udaf = Sum::new(); + let original_name = original_udaf.name().to_owned(); + + let foreign_udaf = create_test_foreign_udaf(original_udaf)?; + // let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + // let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + // let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + // let foreign_udaf: AggregateUDF = foreign_udaf.into(); + + assert_eq!(original_name, foreign_udaf.name()); + Ok(()) + } + + #[test] + fn test_foreign_udaf_aliases() -> Result<()> { + let foreign_udaf = + create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]); + + let return_type = foreign_udaf.return_type(&[DataType::Float64])?; + assert_eq!(return_type, DataType::Float64); + Ok(()) + } + + #[test] + fn test_foreign_udaf_accumulator() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let acc_args = AccumulatorArgs { + return_type: &DataType::Float64, + schema: &schema, + ignore_nulls: true, + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + let mut accumulator = foreign_udaf.accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); + + Ok(()) + } + + #[test] + fn test_beneficial_ordering() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf( + datafusion::functions_aggregate::first_last::FirstValue::new(), + )?; + + let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap(); + + assert_eq!( + foreign_udaf.order_sensitivity(), + AggregateOrderSensitivity::Beneficial + ); + + let a_field = Field::new("a", DataType::Float64, true); + let state_fields = foreign_udaf.state_fields(StateFieldsArgs { + name: "a", + input_types: &[DataType::Float64], + return_type: &DataType::Float64, + ordering_fields: &[a_field.clone()], + is_distinct: false, + })?; + + println!("{:#?}", state_fields); + assert_eq!(state_fields.len(), 3); + assert_eq!(state_fields[1], a_field); + Ok(()) + } + + #[test] + fn test_sliding_accumulator() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let acc_args = AccumulatorArgs { + return_type: &DataType::Float64, + schema: &schema, + ignore_nulls: true, + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + + let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); + + Ok(()) + } + + fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) { + let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into(); + let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into(); + + assert_eq!(sensitivity, round_trip_sensitivity); + } + + #[test] + fn test_round_trip_all_order_sensitivities() { + test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial); + } +} From 207c6bad13bf4ad20e28ef7faf8685f67f74246c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 27 Feb 2025 08:14:37 +0100 Subject: [PATCH 14/18] taplo format --- datafusion/ffi/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index db8a92f696cb4..5e6d271e639d6 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -43,9 +43,9 @@ arrow = { workspace = true, features = ["ffi"] } async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } +datafusion-functions-aggregate-common = { workspace = true } datafusion-proto = { workspace = true } datafusion-proto-common = { workspace = true } -datafusion-functions-aggregate-common = { workspace = true } futures = { workspace = true } log = { workspace = true } prost = { workspace = true } From c5d4632e3ec0af54888f9377d7c79b99c7b765be Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 27 Feb 2025 08:15:49 +0100 Subject: [PATCH 15/18] Update code comment --- datafusion/ffi/src/udaf/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index cec7cf08a3e6f..ac59cef8fdf41 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -116,7 +116,7 @@ pub struct FFI_AggregateUDF { /// Performs type coersion. To simply this interface, all UDFs are treated as having /// user defined signatures, which will in turn call coerce_types to be called. This /// call should be transparent to most users as the internal function performs the - /// appropriate calls on the underlying [`ScalarUDF`] + /// appropriate calls on the underlying [`AggregateUDF`] pub coerce_types: unsafe extern "C" fn( udf: &Self, arg_types: RVec, From 4080ca20a3dd97ed2c556f9ddeb76060c33636e0 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 27 Feb 2025 21:50:39 +0100 Subject: [PATCH 16/18] Correct function name --- datafusion/ffi/src/tests/mod.rs | 4 ++-- datafusion/ffi/src/tests/udf_udaf_udwf.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index de9f89e18ad17..3a09c35e4c26a 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -36,7 +36,7 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_avg_func, create_ffi_stddev_func}; +use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_stddev_func, create_ffi_sum_func}; mod async_provider; mod sync_provider; @@ -105,7 +105,7 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { ForeignLibraryModule { create_table: construct_table_provider, create_scalar_udf: create_ffi_abs_func, - create_sum_udaf: create_ffi_avg_func, + create_sum_udaf: create_ffi_sum_func, create_stddev_udaf: create_ffi_stddev_func, version: super::version, } diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index 24247d521ed25..20be262803305 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -30,7 +30,7 @@ pub(crate) extern "C" fn create_ffi_abs_func() -> FFI_ScalarUDF { udf.into() } -pub(crate) extern "C" fn create_ffi_avg_func() -> FFI_AggregateUDF { +pub(crate) extern "C" fn create_ffi_sum_func() -> FFI_AggregateUDF { let udaf: Arc = Arc::new(Sum::new().into()); udaf.into() From e5c948cb586eada6f9203dfbebb0f89fa30371be Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Tue, 1 Apr 2025 10:49:58 -0400 Subject: [PATCH 17/18] Temp fix record batch test dependencies --- .../src/datasource/file_format/parquet.rs | 140 +++++++++++++----- datafusion/core/src/lib.rs | 1 + datafusion/core/src/test_util/mod.rs | 47 ++++++ 3 files changed, 152 insertions(+), 36 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 27a7e7ae3c061..d75a8ea7ba93d 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -104,10 +104,8 @@ pub(crate) mod test_util { mod tests { use std::fmt::{self, Display, Formatter}; - use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use std::task::{Context, Poll}; use std::time::Duration; use crate::datasource::file_format::parquet::test_util::store_parquet; @@ -117,7 +115,7 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use arrow::array::RecordBatch; - use arrow_schema::{Schema, SchemaRef}; + use arrow_schema::Schema; use datafusion_catalog::Session; use datafusion_common::cast::{ as_binary_array, as_binary_view_array, as_boolean_array, as_float32_array, @@ -137,7 +135,7 @@ mod tests { }; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; - use datafusion_execution::{RecordBatchStream, TaskContext}; + use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{collect, ExecutionPlan}; @@ -150,7 +148,7 @@ mod tests { use async_trait::async_trait; use datafusion_datasource::file_groups::FileGroup; use futures::stream::BoxStream; - use futures::{Stream, StreamExt}; + use futures::StreamExt; use insta::assert_snapshot; use log::error; use object_store::local::LocalFileSystem; @@ -166,6 +164,8 @@ mod tests { use parquet::format::FileMetaData; use tokio::fs::File; + use crate::test_util::bounded_stream; + enum ForceViews { Yes, No, @@ -1646,42 +1646,110 @@ mod tests { Ok(()) } - /// Creates an bounded stream for testing purposes. - fn bounded_stream( - batch: RecordBatch, - limit: usize, - ) -> datafusion_execution::SendableRecordBatchStream { - Box::pin(BoundedStream { - count: 0, - limit, - batch, - }) - } + #[tokio::test] + async fn test_memory_reservation_column_parallel() -> Result<()> { + async fn test_memory_reservation(global: ParquetOptions) -> Result<()> { + let field_a = Field::new("a", DataType::Utf8, false); + let field_b = Field::new("b", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let object_store_url = ObjectStoreUrl::local_filesystem(); - struct BoundedStream { - limit: usize, - count: usize, - batch: RecordBatch, - } + let file_sink_config = FileSinkConfig { + original_url: String::default(), + object_store_url: object_store_url.clone(), + file_group: FileGroup::new(vec![PartitionedFile::new( + "/tmp".to_string(), + 1, + )]), + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![], + insert_op: InsertOp::Overwrite, + keep_partition_by_columns: false, + file_extension: "parquet".into(), + }; + let parquet_sink = Arc::new(ParquetSink::new( + file_sink_config, + TableParquetOptions { + key_value_metadata: std::collections::HashMap::from([ + ("my-data".to_string(), Some("stuff".to_string())), + ("my-data-bool-key".to_string(), None), + ]), + global, + ..Default::default() + }, + )); + + // create data + let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); + let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); + let batch = + RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); - impl Stream for BoundedStream { - type Item = Result; + // create task context + let task_context = build_ctx(object_store_url.as_ref()); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no bytes are reserved yet" + ); - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - if self.count >= self.limit { - return Poll::Ready(None); + let mut write_task = FileSink::write_all( + parquet_sink.as_ref(), + Box::pin(RecordBatchStreamAdapter::new( + schema, + bounded_stream(batch, 1000), + )), + &task_context, + ); + + // incrementally poll and check for memory reservation + let mut reserved_bytes = 0; + while futures::poll!(&mut write_task).is_pending() { + reserved_bytes += task_context.memory_pool().reserved(); + tokio::time::sleep(Duration::from_micros(1)).await; } - self.count += 1; - Poll::Ready(Some(Ok(self.batch.clone()))) - } - } + assert!( + reserved_bytes > 0, + "should have bytes reserved during write" + ); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no leaking byte reservation" + ); - impl RecordBatchStream for BoundedStream { - fn schema(&self) -> SchemaRef { - self.batch.schema() + Ok(()) } + + let write_opts = ParquetOptions { + allow_single_file_parallelism: false, + ..Default::default() + }; + test_memory_reservation(write_opts) + .await + .expect("should track for non-parallel writes"); + + let row_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 10, + maximum_buffered_record_batches_per_stream: 1, + ..Default::default() + }; + test_memory_reservation(row_parallel_write_opts) + .await + .expect("should track for row-parallel writes"); + + let col_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 1, + maximum_buffered_record_batches_per_stream: 2, + ..Default::default() + }; + test_memory_reservation(col_parallel_write_opts) + .await + .expect("should track for column-parallel writes"); + + Ok(()) } } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index cc510bc81f1a8..a1afec3d90fa1 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -829,6 +829,7 @@ pub mod test; mod schema_equivalence; pub mod test_util; +// pub use test_util::bounded_stream; #[cfg(doctest)] doc_comment::doctest!("../../../README.md", readme_example_test); diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index d6865ca3d532a..2f8e66a2bbfbb 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -22,12 +22,14 @@ pub mod parquet; pub mod csv; +use futures::Stream; use std::any::Any; use std::collections::HashMap; use std::fs::File; use std::io::Write; use std::path::Path; use std::sync::Arc; +use std::task::{Context, Poll}; use crate::catalog::{TableProvider, TableProviderFactory}; use crate::dataframe::DataFrame; @@ -38,11 +40,13 @@ use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; use crate::prelude::{CsvReadOptions, SessionContext}; +use crate::execution::SendableRecordBatchStream; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_catalog::Session; use datafusion_common::TableReference; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; +use std::pin::Pin; use async_trait::async_trait; @@ -52,6 +56,8 @@ use tempfile::TempDir; pub use datafusion_common::test_util::parquet_test_data; pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; +use crate::execution::RecordBatchStream; + /// Scan an empty data source, mainly used in tests pub fn scan_empty( name: Option<&str>, @@ -234,3 +240,44 @@ pub fn register_unbounded_file_with_ordering( ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } + +/// Creates a bounded stream that emits the same record batch a specified number of times. +/// This is useful for testing purposes. +pub fn bounded_stream( + record_batch: RecordBatch, + limit: usize, +) -> SendableRecordBatchStream { + Box::pin(BoundedStream { + record_batch, + count: 0, + limit, + }) +} + +struct BoundedStream { + record_batch: RecordBatch, + count: usize, + limit: usize, +} + +impl Stream for BoundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.count >= self.limit { + Poll::Ready(None) + } else { + self.count += 1; + Poll::Ready(Some(Ok(self.record_batch.clone()))) + } + } +} + +impl RecordBatchStream for BoundedStream { + fn schema(&self) -> SchemaRef { + self.record_batch.schema() + } +} From f69e656aca1db87ded3ef08912c9d864d675baee Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Tue, 1 Apr 2025 20:36:57 -0400 Subject: [PATCH 18/18] Address some comments --- datafusion/ffi/src/record_batch_stream.rs | 3 +- datafusion/ffi/src/udaf/accumulator.rs | 37 +++++++++++++---------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 5663fb12f0e97..78d65a816fcc2 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -232,7 +232,8 @@ mod tests { let batch = ffi_rbs.next().await; assert!(batch.is_some()); - assert!(batch.unwrap().is_ok()); + assert!(batch.as_ref().unwrap().is_ok()); + assert_eq!(batch.unwrap().unwrap(), record_batch); // There should only be one batch let no_batch = ffi_rbs.next().await; diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index a6c007dce8f76..897cd9f49cc3a 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -31,6 +31,7 @@ use prost::Message; use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; +/// A stable struct for sharing [`Accumulator`] across FFI boundaries. #[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] @@ -75,12 +76,19 @@ pub struct AccumulatorPrivateData { pub accumulator: Box, } +impl FFI_Accumulator { + #[inline] + unsafe fn inner(&self) -> &mut AccumulatorPrivateData { + let private_data = self.private_data as *mut AccumulatorPrivateData; + &mut (*private_data) + } +} + unsafe extern "C" fn update_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, values: RVec, ) -> RResult<(), RString> { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accum_data = accumulator.inner(); let values_arrays = values .into_iter() @@ -94,8 +102,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( unsafe extern "C" fn evaluate_fn_wrapper( accumulator: &FFI_Accumulator, ) -> RResult, RString> { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accum_data = accumulator.inner(); let scalar_result = rresult_return!(accum_data.accumulator.evaluate()); let proto_result: datafusion_proto::protobuf::ScalarValue = @@ -105,10 +112,9 @@ unsafe extern "C" fn evaluate_fn_wrapper( } unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); - - accum_data.accumulator.size() + // let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + // let accum_data = &mut (*private_data); + accumulator.inner().accumulator.size() } unsafe extern "C" fn state_fn_wrapper( @@ -135,8 +141,7 @@ unsafe extern "C" fn merge_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, states: RVec, ) -> RResult<(), RString> { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accum_data = accumulator.inner(); let states = rresult_return!(states .into_iter() @@ -150,15 +155,15 @@ unsafe extern "C" fn retract_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, values: RVec, ) -> RResult<(), RString> { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accum_data = accumulator.inner(); - let values = rresult_return!(values + let values_arrays = values .into_iter() - .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) - .collect::>>()); + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); - rresult!(accum_data.accumulator.retract_batch(&values)) + rresult!(accum_data.accumulator.retract_batch(&values_arrays)) } unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) {