15 changes: 10 additions & 5 deletions native/core/src/execution/planner.rs
@@ -1893,19 +1893,24 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;

let builder = match datatype {
DataType::Decimal128(_, _) => {
let func =
AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype));
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
_ => {
// cast to the result data type of AVG if the result data type is different
// from the input type, e.g. AVG(Int32). We should not expect a cast
// failure since it should have already been checked at Spark side.
// For all other numeric types (Int8/16/32/64, Float32/64):
// Cast to Float64 for accumulation
let child: Arc<dyn PhysicalExpr> =
Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype));
Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None));
let func = AggregateUDF::new_from_impl(Avg::new(
"avg",
DataType::Float64,
eval_mode,
));
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
};
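For context on the non-decimal branch above: the child expression is now cast to Float64 and the average is accumulated as f64, with the Spark eval mode passed through. A minimal standalone sketch of the resulting semantics (plain Rust with an illustrative `avg_i32` helper, not the DataFusion/Comet API):

```rust
// Sketch: AVG over an integer column after the planner's cast to Float64.
// Casting i32 -> f64 is lossless and f64 accumulation cannot raise an error,
// so Legacy, Try, and ANSI modes all produce the same result on this path.
fn avg_i32(values: &[Option<i32>]) -> Option<f64> {
    let mut sum = 0.0_f64;
    let mut count = 0_i64;
    for v in values.iter().flatten() {
        sum += *v as f64;
        count += 1;
    }
    // All-null input leaves count at 0 and yields NULL, matching Spark's Average.
    if count == 0 {
        None
    } else {
        Some(sum / count as f64)
    }
}

fn main() {
    assert_eq!(avg_i32(&[Some(1), Some(2), None, Some(3)]), Some(2.0));
    assert_eq!(avg_i32(&[None, None]), None);
}
```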
2 changes: 1 addition & 1 deletion native/proto/src/proto/expr.proto
@@ -137,7 +137,7 @@ message Avg {
Expr child = 1;
DataType datatype = 2;
DataType sum_datatype = 3;
bool fail_on_error = 4; // currently unused (useful for deciding Ansi vs Legacy mode)
EvalMode eval_mode = 4;
}

message First {
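For readers who have not seen the eval-mode plumbing: the proto field now carries Spark's evaluation mode rather than a single boolean. A hedged sketch of the three modes and what they imply for AVG (the real definitions live in expr.proto and the native spark-expr crate; the enum below is written out only for illustration):

```rust
/// Illustrative copy of Spark's three evaluation modes as the proto field models them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EvalMode {
    /// Spark's default behaviour: never raise; overflow follows legacy semantics.
    Legacy,
    /// try_* functions: return NULL where ANSI mode would raise an error.
    Try,
    /// ANSI SQL mode: raise an error on overflow or invalid input.
    Ansi,
}

fn main() {
    // AVG over Float64 cannot fail, so all three modes agree for this aggregate;
    // the mode is still threaded through so integer/decimal paths can consult it.
    for mode in [EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
        println!("{mode:?}: avg on f64 never errors");
    }
}
```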
52 changes: 32 additions & 20 deletions native/spark-expr/src/agg_funcs/avg.rs
@@ -15,11 +15,12 @@
// specific language governing permissions and limitations
// under the License.

use crate::EvalMode;
use arrow::array::{
builder::PrimitiveBuilder,
cast::AsArray,
types::{Float64Type, Int64Type},
Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray,
Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, Int64Array, PrimitiveArray,
};
use arrow::compute::sum;
use arrow::datatypes::{DataType, Field, FieldRef};
@@ -31,45 +32,43 @@ use datafusion::logical_expr::{
use datafusion::physical_expr::expressions::format_state_name;
use std::{any::Any, sync::Arc};

use arrow::array::ArrowNativeTypeOp;
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::Volatility::Immutable;
use DataType::*;

/// AVG aggregate expression
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Avg {
name: String,
signature: Signature,
// expr: Arc<dyn PhysicalExpr>,
input_data_type: DataType,
result_data_type: DataType,
eval_mode: EvalMode,
}

impl Avg {
/// Create a new AVG aggregate function
pub fn new(name: impl Into<String>, data_type: DataType) -> Self {
pub fn new(name: impl Into<String>, data_type: DataType, eval_mode: EvalMode) -> Self {
let result_data_type = avg_return_type("avg", &data_type).unwrap();

Self {
name: name.into(),
signature: Signature::user_defined(Immutable),
input_data_type: data_type,
result_data_type,
eval_mode,
}
}
}

impl AggregateUDFImpl for Avg {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
// instantiate specialized accumulator based for the type
// All numeric types use Float64 accumulation after casting
match (&self.input_data_type, &self.result_data_type) {
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
(Float64, Float64) => Ok(Box::new(AvgAccumulator::new(self.eval_mode))),
_ => not_impl_err!(
"AvgAccumulator for ({} --> {})",
self.input_data_type,
@@ -109,10 +108,10 @@ impl AggregateUDFImpl for Avg {
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
// instantiate specialized accumulator based for the type
match (&self.input_data_type, &self.result_data_type) {
(Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
&self.input_data_type,
self.eval_mode,
|sum: f64, count: i64| Ok(sum / count as f64),
))),

@@ -137,11 +136,22 @@
}
}

/// An accumulator to compute the average
#[derive(Debug, Default)]
#[derive(Debug)]
pub struct AvgAccumulator {
sum: Option<f64>,
count: i64,
#[allow(dead_code)]
eval_mode: EvalMode,
}

impl AvgAccumulator {
pub fn new(eval_mode: EvalMode) -> Self {
Self {
sum: None,
count: 0,
eval_mode,
}
}
}

impl Accumulator for AvgAccumulator {
@@ -166,7 +176,7 @@ impl Accumulator for AvgAccumulator {
// counts are summed
self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();

// sums are summed
// sums are summed - no overflow checking
if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
let v = self.sum.get_or_insert(0.);
*v += x;
@@ -176,8 +186,6 @@

fn evaluate(&mut self) -> Result<ScalarValue> {
if self.count == 0 {
// If all input are nulls, count will be 0 and we will get null after the division.
// This is consistent with Spark Average implementation.
Ok(ScalarValue::Float64(None))
} else {
Ok(ScalarValue::Float64(
@@ -192,7 +200,7 @@ impl Accumulator for AvgAccumulator {
}

/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
/// Stores values as native types, and does overflow checking
/// Stores values as native types.
///
/// F: Function that calculates the average value from a sum of
/// T::Native and a total count
@@ -211,6 +219,10 @@ where
/// Sums per group, stored as the native type
sums: Vec<T::Native>,

/// Evaluation mode (stored but not used for Float64)
#[allow(dead_code)]
eval_mode: EvalMode,

/// Function that computes the final average (value / count)
avg_fn: F,
}
@@ -220,11 +232,12 @@ where
T: ArrowNumericType + Send,
F: Fn(T::Native, i64) -> Result<T::Native> + Send,
{
pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
pub fn new(return_data_type: &DataType, eval_mode: EvalMode, avg_fn: F) -> Self {
Self {
return_data_type: return_data_type.clone(),
counts: vec![],
sums: vec![],
eval_mode,
avg_fn,
}
}
@@ -254,6 +267,7 @@ where
if values.null_count() == 0 {
for (&group_index, &value) in iter {
let sum = &mut self.sums[group_index];
// No overflow checking - INFINITY is a valid result
*sum = (*sum).add_wrapping(value);
self.counts[group_index] += 1;
}
@@ -264,7 +278,6 @@ where
}
let sum = &mut self.sums[group_index];
*sum = (*sum).add_wrapping(value);

self.counts[group_index] += 1;
}
}
@@ -280,17 +293,17 @@ where
total_num_groups: usize,
) -> Result<()> {
assert_eq!(values.len(), 2, "two arguments to merge_batch");
// first batch is partial sums, second is counts
let partial_sums = values[0].as_primitive::<T>();
let partial_counts = values[1].as_primitive::<Int64Type>();

// update counts with partial counts
self.counts.resize(total_num_groups, 0);
let iter1 = group_indices.iter().zip(partial_counts.values().iter());
for (&group_index, &partial_count) in iter1 {
self.counts[group_index] += partial_count;
}

// update sums
// update sums - no overflow checking
self.sums.resize(total_num_groups, T::default_value());
let iter2 = group_indices.iter().zip(partial_sums.values().iter());
for (&group_index, &new_value) in iter2 {
@@ -319,7 +332,6 @@ where
Ok(Arc::new(array))
}

// return arrays for sums and counts
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let counts = emit_to.take_needed(&mut self.counts);
let counts = Int64Array::new(counts.into(), None);
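Two properties the Float64 accumulators above rely on are worth spelling out: an f64 sum saturates to infinity instead of raising an error, and the partial state exchanged between stages is just a (sum, count) pair per group. A self-contained sketch of both ideas (illustrative values, not the arrow/DataFusion API):

```rust
// Why no overflow checks are needed: IEEE-754 addition saturates to infinity,
// and INFINITY is a legal AVG result in every eval mode.
fn main() {
    assert!((f64::MAX + f64::MAX).is_infinite());

    // Partial aggregation state is (sum, count) per group; merging partials from
    // two partitions is element-wise addition, then the final value is sum / count.
    let partial_a = (10.0_f64, 2_i64); // e.g. partial state over [4.0, 6.0]
    let partial_b = (20.0_f64, 2_i64); // e.g. partial state over [8.0, 12.0]
    let (sum, count) = (partial_a.0 + partial_b.0, partial_a.1 + partial_b.1);
    let avg = if count == 0 { None } else { Some(sum / count as f64) };
    assert_eq!(avg, Some(7.5));
    println!("merged avg = {avg:?}");
}
```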
16 changes: 3 additions & 13 deletions spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType
import org.apache.comet.CometConf
import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
import org.apache.comet.shims.CometEvalModeUtil

object CometMin extends CometAggregateExpressionSerde[Min] {

@@ -150,17 +151,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] {

object CometAverage extends CometAggregateExpressionSerde[Average] {

override def getSupportLevel(avg: Average): SupportLevel = {
avg.evalMode match {
case EvalMode.ANSI =>
Incompatible(Some("ANSI mode is not supported"))
case EvalMode.TRY =>
Incompatible(Some("TRY mode is not supported"))
case _ =>
Compatible()
}
}

override def convert(
aggExpr: AggregateExpression,
avg: Average,
@@ -192,7 +182,7 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
val builder = ExprOuterClass.Avg.newBuilder()
builder.setChild(childExpr.get)
builder.setDatatype(dataType.get)
builder.setFailOnError(avg.evalMode == EvalMode.ANSI)
builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(avg.evalMode)))
builder.setSumDatatype(sumDataType.get)

Some(
@@ -1471,6 +1471,42 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("AVG and try_avg - basic functionality") {
withParquetTable(
Seq(
(10L, 1),
(20L, 1),
(null.asInstanceOf[Long], 1),
(100L, 2),
(200L, 2),
(null.asInstanceOf[Long], 3)),
"tbl") {

Seq(true, false).foreach({ k =>
// without GROUP BY
withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) {
val res = sql("SELECT avg(_1) FROM tbl")
checkSparkAnswerAndOperator(res)
}

// with GROUP BY
withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) {
val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
checkSparkAnswerAndOperator(res)
}

})

// try_avg without GROUP BY
val resTry = sql("SELECT try_avg(_1) FROM tbl")
checkSparkAnswerAndOperator(resTry)

// try_avg with GROUP BY
val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
checkSparkAnswerAndOperator(resTryGroup)
}
}

protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)
@@ -29,7 +29,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
import org.apache.spark.sql.TPCDSBase
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.util.resourceToString
import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
@@ -225,7 +225,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBase
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
// Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support
CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true",
CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true",
// as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",