Skip to content

Commit

Permalink
Consolidate example: simplify_udaf_expression.rs into advanced_udaf.rs (
Browse files Browse the repository at this point in the history
  • Loading branch information
takaebato authored Dec 27, 2024
1 parent 9665e09 commit 933fec8
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 229 deletions.
185 changes: 132 additions & 53 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
expr::AggregateFunction,
function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
simplify::SimplifyInfo,
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

Expand Down Expand Up @@ -197,40 +199,6 @@ impl Accumulator for GeometricMean {
}
}

// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float32, false),
]));

// define data in two partitions
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
],
)?;
let batch2 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![64.0])),
Arc::new(Float32Array::from(vec![2.0])),
],
)?;

// declare a new context. In spark API, this corresponds to a new spark SQLsession
let ctx = SessionContext::new();

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
ctx.register_table("t", Arc::new(provider))?;
Ok(ctx)
}

// Define a `GroupsAccumulator` for GeometricMean
/// which handles accumulator state for multiple groups at once.
/// This API is significantly more complicated than `Accumulator`, which manages
Expand Down Expand Up @@ -399,35 +367,146 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
}
}

/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
/// defined aggregate function with a different expression which is defined in the `simplify` method.
#[derive(Debug, Clone)]
struct SimplifiedGeoMeanUdaf {
signature: Signature,
}

impl SimplifiedGeoMeanUdaf {
fn new() -> Self {
Self {
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for SimplifiedGeoMeanUdaf {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"simplified_geo_mean"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
unimplemented!("should not be invoked")
}

fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
unimplemented!("should not be invoked")
}

fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}

fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
unimplemented!("should not get here");
}

/// Optionally replaces a UDAF with another expression during query optimization.
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| {
// Replaces the UDAF with `GeoMeanUdaf` as a placeholder example to demonstrate the `simplify` method.
// In real-world scenarios, you might create UDFs from built-in expressions.
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(AggregateUDF::from(GeoMeanUdaf::new())),
aggregate_function.args,
aggregate_function.distinct,
aggregate_function.filter,
aggregate_function.order_by,
aggregate_function.null_treatment,
)))
};
Some(Box::new(simplify))
}
}

// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float32, false),
]));

// define data in two partitions
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
],
)?;
let batch2 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![64.0])),
Arc::new(Float32Array::from(vec![2.0])),
],
)?;

// declare a new context. In spark API, this corresponds to a new spark SQLsession
let ctx = SessionContext::new();

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
ctx.register_table("t", Arc::new(provider))?;
Ok(ctx)
}

#[tokio::main]
async fn main() -> Result<()> {
let ctx = create_context()?;

// create the AggregateUDF
let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new());
ctx.register_udaf(geometric_mean.clone());
let geo_mean_udf = AggregateUDF::from(GeoMeanUdaf::new());
let simplified_geo_mean_udf = AggregateUDF::from(SimplifiedGeoMeanUdaf::new());

for (udf, udf_name) in [
(geo_mean_udf, "geo_mean"),
(simplified_geo_mean_udf, "simplified_geo_mean"),
] {
ctx.register_udaf(udf.clone());

let sql_df = ctx.sql("SELECT geo_mean(a) FROM t group by b").await?;
sql_df.show().await?;
let sql_df = ctx
.sql(&format!("SELECT {}(a) FROM t GROUP BY b", udf_name))
.await?;
sql_df.show().await?;

// get a DataFrame from the context
// this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
let df = ctx.table("t").await?;
// get a DataFrame from the context
// this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
let df = ctx.table("t").await?;

// perform the aggregation
let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?;
// perform the aggregation
let df = df.aggregate(vec![], vec![udf.call(vec![col("a")])])?;

// note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
// note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.

// execute the query
let results = df.collect().await?;
// execute the query
let results = df.collect().await?;

// downcast the array to the expected type
let result = as_float64_array(results[0].column(0))?;
// downcast the array to the expected type
let result = as_float64_array(results[0].column(0))?;

// verify that the calculation is correct
assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
// verify that the calculation is correct
assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
}

Ok(())
}
176 changes: 0 additions & 176 deletions datafusion-examples/examples/simplify_udaf_expression.rs

This file was deleted.

0 comments on commit 933fec8

Please sign in to comment.