From 39925fa23dad103c0d95a701291824550e7a5a6c Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Fri, 17 Oct 2025 12:11:53 -0700 Subject: [PATCH 01/18] sum_int_native_support --- native/spark-expr/src/agg_funcs/mod.rs | 1 + native/spark-expr/src/agg_funcs/sum_int.rs | 500 +++++++++++++++++++++ 2 files changed, 501 insertions(+) create mode 100644 native/spark-expr/src/agg_funcs/sum_int.rs diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 252da78890..da6b616e24 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -21,6 +21,7 @@ mod correlation; mod covariance; mod stddev; mod sum_decimal; +mod sum_int; mod variance; pub use avg::Avg; diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs new file mode 100644 index 0000000000..b6b2435575 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -0,0 +1,500 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::{build_bool_state, is_valid_decimal_precision}; +use arrow::array::{ + cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array, + Int64Array, +}; +use arrow::datatypes::{DataType, Field, FieldRef, Int64Type}; +use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer}; +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::Volatility::Immutable; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, +}; +use std::{any::Any, ops::BitAnd, sync::Arc}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SumInteger { + /// Aggregate function signature + signature: Signature, + /// The data type of the SUM result. 
This will always be a decimal type + /// with the same precision and scale as specified in this struct + result_type: DataType, +} + +impl SumInteger { + pub fn try_new(data_type: DataType) -> DFResult { + // The `data_type` is the SUM result type passed from Spark side + if (!data_type.is_integer()) { + return Err(DataFusionError::Internal( + "Invalid data type for SumInteger".into(), + )); + } + + Ok(Self { + signature: Signature::user_defined(Immutable), + result_type: data_type, + }) + } +} + +impl AggregateUDFImpl for SumInteger { + fn as_any(&self) -> &dyn Any { + self + } + + fn accumulator(&self, _args: AccumulatorArgs) -> DFResult> { + Ok(Box::new(SumIntegerAccumulator::new())) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { + let fields = vec![ + Arc::new(Field::new( + self.name(), + self.result_type.clone(), + self.is_nullable(), + )), + Arc::new(Field::new("is_empty", DataType::Boolean, false)), + ]; + Ok(fields) + } + + fn name(&self) -> &str { + "sum" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(self.result_type.clone()) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> DFResult> { + Ok(Box::new(SumIntegerGroupsAccumulator::new(DataType::Null))) + } + + fn default_value(&self, _data_type: &DataType) -> DFResult { + ScalarValue::new_primitive::(None, &DataType::Int64) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn is_nullable(&self) -> bool { + // SumDecimal is always nullable because overflows can cause null values + true + } +} + +#[derive(Debug)] +struct SumIntegerAccumulator { + sum: i128, + is_empty: bool, + is_not_null: bool, +} + +impl SumIntegerAccumulator { + fn new() -> Self { + Self { + sum: 0, + is_empty: true, + is_not_null: true, + } + } + + fn update_single(&mut self, values: &Int64Array, idx: usize) { + let v = unsafe { values.value_unchecked(idx) }; + let (new_sum, is_overflow) = self.sum.overflowing_add(v as i128); + + if is_overflow { + // Overflow: set buffer accumulator to null + self.is_not_null = false; + return; + } + + self.sum = new_sum; + self.is_not_null = true; + } +} + +impl Accumulator for SumIntegerAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { + assert_eq!( + values.len(), + 1, + "Expect only one element in 'values' but found {}", + values.len() + ); + + if !self.is_empty && !self.is_not_null { + // This means there's a overflow in decimal, so we will just skip the rest + // of the computation + return Ok(()); + } + + let values = &values[0]; + let data = values.as_primitive::(); + + self.is_empty = self.is_empty && values.len() == values.null_count(); + + if values.null_count() == 0 { + for i in 0..data.len() { + self.update_single(data, i); + } + } else { + for i in 0..data.len() { + if data.is_null(i) { + continue; + } + self.update_single(data, i); + } + } + + Ok(()) + } + + fn evaluate(&mut self) -> DFResult { + // For each group: + // 1. if `is_empty` is true, it means either there is no value or all values for the group + // are null, in this case we'll return null + // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In + // non-ANSI mode Spark returns null. 
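+        // In other words, the scalar returned below is NULL when the group saw
+        // no non-null input or when the running sum overflowed, and the
+        // accumulated sum otherwise.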
+ if self.is_empty || !self.is_not_null { + ScalarValue::new_primitive::(None, &DataType::Int64) + } else { + ScalarValue::try_new_decimal128(self.sum) + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> DFResult> { + let sum = if self.is_not_null { + ScalarValue::try_new_null(self.sum)? + } else { + ScalarValue::new_primitive::(None, &DataType::Int64)? + }; + Ok(vec![sum, ScalarValue::from(self.is_empty)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { + assert_eq!( + states.len(), + 2, + "Expect two element in 'states' but found {}", + states.len() + ); + assert_eq!(states[0].len(), 1); + assert_eq!(states[1].len(), 1); + + let that_sum = states[0].as_primitive::(); + let that_is_empty = states[1].as_any().downcast_ref::().unwrap(); + + let this_overflow = !self.is_empty && !self.is_not_null; + let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0); + + self.is_not_null = !this_overflow && !that_overflow; + self.is_empty = self.is_empty && that_is_empty.value(0); + + if self.is_not_null { + self.sum += that_sum.value(0); + } + + Ok(()) + } +} + +struct SumIntegerGroupsAccumulator { + // Whether aggregate buffer for a particular group is null. True indicates it is not null. + is_not_null: BooleanBufferBuilder, + is_empty: BooleanBufferBuilder, + sum: Vec, + result_type: DataType, +} + +impl SumIntegerGroupsAccumulator { + fn new(result_type: DataType) -> Self { + Self { + is_not_null: BooleanBufferBuilder::new(0), + is_empty: BooleanBufferBuilder::new(0), + sum: Vec::new(), + result_type, + } + } + + fn is_overflow(&self, index: usize) -> bool { + !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index) + } + + #[inline] + fn update_single(&mut self, group_index: usize, value: i128) { + self.is_empty.set_bit(group_index, false); + let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value); + self.sum[group_index] = new_sum; + } +} + +fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) { + if builder.len() < capacity { + let additional = capacity - builder.len(); + builder.append_n(additional, true); + } +} + +impl GroupsAccumulator for SumIntegerGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + assert_eq!(values.len(), 1); + let values = values[0].as_primitive::(); + let data = values.values(); + + // Update size for the accumulate states + self.sum.resize(total_num_groups, 0); + ensure_bit_capacity(&mut self.is_empty, total_num_groups); + ensure_bit_capacity(&mut self.is_not_null, total_num_groups); + + let iter = group_indices.iter().zip(data.iter()); + if values.null_count() == 0 { + for (&group_index, &value) in iter { + self.update_single(group_index, value); + } + } else { + for (idx, (&group_index, &value)) in iter.enumerate() { + if values.is_null(idx) { + continue; + } + self.update_single(group_index, value); + } + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { + // For each group: + // 1. if `is_empty` is true, it means either there is no value or all values for the group + // are null, in this case we'll return null + // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In + // non-ANSI mode Spark returns null. 
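+        // The validity buffer built below is `(!is_empty) AND is_not_null`: a
+        // group produces a non-null sum only if it received at least one
+        // non-null value and never overflowed.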
+ let result = emit_to.take_needed(&mut self.sum); + result.iter().enumerate().for_each(|(i, &v)| {}); + + let nulls = build_bool_state(&mut self.is_not_null, &emit_to); + let is_empty = build_bool_state(&mut self.is_empty, &emit_to); + let x = (!&is_empty).bitand(&nulls); + + let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x))) + .with_data_type(self.result_type.clone()); + + Ok(Arc::new(result)) + } + + fn state(&mut self, emit_to: EmitTo) -> DFResult> { + let nulls = build_bool_state(&mut self.is_not_null, &emit_to); + let nulls = Some(NullBuffer::new(nulls)); + + let sum = emit_to.take_needed(&mut self.sum); + let sum = Decimal128Array::new(sum.into(), nulls.clone()) + .with_data_type(self.result_type.clone()); + + let is_empty = build_bool_state(&mut self.is_empty, &emit_to); + let is_empty = BooleanArray::new(is_empty, None); + + Ok(vec![ + Arc::new(sum) as ArrayRef, + Arc::new(is_empty) as ArrayRef, + ]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + assert_eq!( + values.len(), + 2, + "Expected two arrays: 'sum' and 'is_empty', but found {}", + values.len() + ); + assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + + // Make sure we have enough capacity for the additional groups + self.sum.resize(total_num_groups, 0); + ensure_bit_capacity(&mut self.is_empty, total_num_groups); + ensure_bit_capacity(&mut self.is_not_null, total_num_groups); + + let that_sum = &values[0]; + let that_sum = that_sum.as_primitive::(); + let that_is_empty = &values[1]; + let that_is_empty = that_is_empty + .as_any() + .downcast_ref::() + .unwrap(); + + group_indices + .iter() + .enumerate() + .for_each(|(idx, &group_index)| unsafe { + let this_overflow = self.is_overflow(group_index); + let that_is_empty = that_is_empty.value_unchecked(idx); + let that_overflow = !that_is_empty && that_sum.is_null(idx); + let is_overflow = this_overflow || that_overflow; + + // This part follows the logic in Spark: + // `org.apache.spark.sql.catalyst.expressions.aggregate.Sum` + self.is_not_null.set_bit(group_index, !is_overflow); + self.is_empty.set_bit( + group_index, + self.is_empty.get_bit(group_index) && that_is_empty, + ); + if !is_overflow { + // .. otherwise, the sum value for this particular index must not be null, + // and thus we merge both values and update this sum. 
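+                    // Once a group has overflowed, `is_not_null` stays false and its
+                    // partial sum is left untouched, so evaluate() emits NULL for it.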
+ self.sum[group_index] += that_sum.value_unchecked(idx); + } + }); + + Ok(()) + } + + fn size(&self) -> usize { + self.sum.capacity() * std::mem::size_of::() + + self.is_empty.capacity() / 8 + + self.is_not_null.capacity() / 8 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::builder::{Decimal128Builder, StringBuilder}; + use arrow::array::RecordBatch; + use arrow::datatypes::*; + use datafusion::common::Result; + use datafusion::datasource::memory::MemorySourceConfig; + use datafusion::datasource::source::DataSourceExec; + use datafusion::execution::TaskContext; + use datafusion::logical_expr::AggregateUDF; + use datafusion::physical_expr::aggregate::AggregateExprBuilder; + use datafusion::physical_expr::expressions::Column; + use datafusion::physical_expr::PhysicalExpr; + use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; + use datafusion::physical_plan::ExecutionPlan; + use futures::StreamExt; + + #[test] + fn invalid_data_type() { + assert!(SumInteger::try_new(DataType::Int32).is_err()); + } + + #[tokio::test] + async fn sum_no_overflow() -> Result<()> { + let num_rows = 8192; + let batch = create_record_batch(num_rows); + let mut batches = Vec::new(); + for _ in 0..10 { + batches.push(batch.clone()); + } + let partitions = &[batches]; + let c0: Arc = Arc::new(Column::new("c0", 0)); + let c1: Arc = Arc::new(Column::new("c1", 1)); + + let data_type = DataType::Decimal128(8, 2); + let schema = Arc::clone(&partitions[0][0].schema()); + let scan: Arc = Arc::new(DataSourceExec::new(Arc::new( + MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap(), + ))); + + let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumInteger::try_new( + data_type.clone(), + )?)); + + let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) + .schema(Arc::clone(&schema)) + .alias("sum") + .with_ignore_nulls(false) + .with_distinct(false) + .build()?; + + let aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]), + vec![aggr_expr.into()], + vec![None], // no filter expressions + scan, + Arc::clone(&schema), + )?); + + let mut stream = aggregate + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + while let Some(batch) = stream.next().await { + let _batch = batch?; + } + + Ok(()) + } + + fn create_record_batch(num_rows: usize) -> RecordBatch { + let mut decimal_builder = Decimal128Builder::with_capacity(num_rows); + let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); + for i in 0..num_rows { + decimal_builder.append_value(i as i128); + string_builder.append_value(format!("this is string #{}", i % 1024)); + } + let decimal_array = Arc::new(decimal_builder.finish()); + let string_array = Arc::new(string_builder.finish()); + + let mut fields = vec![]; + let mut columns: Vec = vec![]; + + // string column + fields.push(Field::new("c0", DataType::Utf8, false)); + columns.push(string_array); + + // decimal column + fields.push(Field::new("c1", DataType::Decimal128(38, 10), false)); + columns.push(decimal_array); + + let schema = Schema::new(fields); + RecordBatch::try_new(Arc::new(schema), columns).unwrap() + } +} From 77dd36f61df850d3a1fba3937dad96bcabcbb89c Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 21 Oct 2025 22:26:41 -0700 Subject: [PATCH 02/18] wip_sum_tests --- native/.DS_Store | Bin 0 -> 6148 bytes native/core/src/execution/planner.rs | 13 +- native/spark-expr/src/agg_funcs/mod.rs | 1 + 
native/spark-expr/src/agg_funcs/sum_int.rs | 488 ++++++++---------- .../org/apache/comet/serde/aggregates.scala | 14 - .../apache/comet/CometExpressionSuite.scala | 22 + 6 files changed, 261 insertions(+), 277 deletions(-) create mode 100644 native/.DS_Store diff --git a/native/.DS_Store b/native/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..0a67b21dd36a455d40c819eb79aa5509ed690c03 GIT binary patch literal 6148 zcmeHKJ5EDE3>=db5ouDU+zU`}gH;qxkP8qAL;(sY2t@fR&XuDv{s<*{p`f5aW67Rf zuV4Ifrdwwh2Zp3eJQl*4+W zq7;w { + // let eval_mode = let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + let eval_mode = if expr.fail_on_error { + EvalMode::Ansi + } else { + EvalMode::Legacy + }; + let func = + AggregateUDF::new_from_impl(SumInteger::try_new(datatype, eval_mode)?); + AggregateExprBuilder::new(Arc::new(func), vec![child]) + } _ => { // cast to the result data type of SUM if necessary, we should not expect // a cast failure since it should have already been checked at Spark side diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index da6b616e24..b1027153e8 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -30,4 +30,5 @@ pub use correlation::Correlation; pub use covariance::Covariance; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; +pub use sum_int::SumInteger; pub use variance::Variance; diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index b6b2435575..fa8df6b90d 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -15,43 +15,41 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::{build_bool_state, is_valid_decimal_precision}; +use crate::EvalMode; use arrow::array::{ - cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array, - Int64Array, + cast::AsArray, Array, ArrayBuilder, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, + BooleanArray, Int64Array, PrimitiveArray, }; -use arrow::datatypes::{DataType, Field, FieldRef, Int64Type}; -use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer}; +use arrow::datatypes::{ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, Int8Type}; use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; -use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::function::AccumulatorArgs; use datafusion::logical_expr::Volatility::Immutable; use datafusion::logical_expr::{ - Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, }; -use std::{any::Any, ops::BitAnd, sync::Arc}; +use std::{any::Any, sync::Arc}; #[derive(Debug, PartialEq, Eq, Hash)] pub struct SumInteger { /// Aggregate function signature signature: Signature, - /// The data type of the SUM result. 
This will always be a decimal type - /// with the same precision and scale as specified in this struct - result_type: DataType, + /// eval mode : ANSI, Legacy, Try + eval_mode: EvalMode, } impl SumInteger { - pub fn try_new(data_type: DataType) -> DFResult { + pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult { // The `data_type` is the SUM result type passed from Spark side - if (!data_type.is_integer()) { - return Err(DataFusionError::Internal( + println!("data type: {:?}", data_type); + match data_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self { + signature: Signature::user_defined(Immutable), + eval_mode, + }), + _ => Err(DataFusionError::Internal( "Invalid data type for SumInteger".into(), - )); + )), } - - Ok(Self { - signature: Signature::user_defined(Immutable), - result_type: data_type, - }) } } @@ -60,22 +58,6 @@ impl AggregateUDFImpl for SumInteger { self } - fn accumulator(&self, _args: AccumulatorArgs) -> DFResult> { - Ok(Box::new(SumIntegerAccumulator::new())) - } - - fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { - let fields = vec![ - Arc::new(Field::new( - self.name(), - self.result_type.clone(), - self.is_nullable(), - )), - Arc::new(Field::new("is_empty", DataType::Boolean, false)), - ]; - Ok(fields) - } - fn name(&self) -> &str { "sum" } @@ -84,193 +66,208 @@ impl AggregateUDFImpl for SumInteger { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> DFResult { - Ok(self.result_type.clone()) + fn return_type(&self, arg_types: &[DataType]) -> DFResult { + Ok(DataType::Int64) } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } + fn accumulator(&self, acc_args: AccumulatorArgs) -> DFResult> { + Ok(Box::new(SumIntegerAccumulator::new())) + } + fn create_groups_accumulator( &self, _args: AccumulatorArgs, ) -> DFResult> { - Ok(Box::new(SumIntegerGroupsAccumulator::new(DataType::Null))) - } - - fn default_value(&self, _data_type: &DataType) -> DFResult { - ScalarValue::new_primitive::(None, &DataType::Int64) - } - - fn reverse_expr(&self) -> ReversedUDAF { - ReversedUDAF::Identical - } - - fn is_nullable(&self) -> bool { - // SumDecimal is always nullable because overflows can cause null values - true + Ok(Box::new(SumDecimalGroupsAccumulator::new(self.eval_mode))) } } #[derive(Debug)] struct SumIntegerAccumulator { - sum: i128, - is_empty: bool, - is_not_null: bool, + sum: i64, + eval_mode: EvalMode, + input_data_type: DataType, } impl SumIntegerAccumulator { fn new() -> Self { Self { sum: 0, - is_empty: true, - is_not_null: true, - } - } - - fn update_single(&mut self, values: &Int64Array, idx: usize) { - let v = unsafe { values.value_unchecked(idx) }; - let (new_sum, is_overflow) = self.sum.overflowing_add(v as i128); - - if is_overflow { - // Overflow: set buffer accumulator to null - self.is_not_null = false; - return; + eval_mode: EvalMode::Legacy, + input_data_type: DataType::Int64, } - - self.sum = new_sum; - self.is_not_null = true; } } impl Accumulator for SumIntegerAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { - assert_eq!( - values.len(), - 1, - "Expect only one element in 'values' but found {}", - values.len() - ); + fn update_sum_internal( + int_array: &PrimitiveArray, + eval_mode: EvalMode, + mut sum: i64, + ) -> Result + where + T: ArrowPrimitiveType, + { + println!("match internal function data type: {:?}", sum); + let len = int_array.len(); + for i in 0..int_array.len() { + if !int_array.is_null(i) { + let 
v = int_array.value(i).to_i64().ok_or_else(|| { + DataFusionError::Internal("Failed to convert value to i64".to_string()) + })?; + match eval_mode { + EvalMode::Legacy | EvalMode::Try => { + sum = v.add_wrapping(sum); + } + EvalMode::Ansi => { + match v.add_checked(sum) { + Ok(v) => sum = v, + Err(e) => { + return Err(DataFusionError::Internal("error".to_string())) + } + }; + } + } + } + } + println!("match internal (AFTER) function data type: {:?}", sum); - if !self.is_empty && !self.is_not_null { - // This means there's a overflow in decimal, so we will just skip the rest - // of the computation - return Ok(()); + Ok(sum) } let values = &values[0]; - let data = values.as_primitive::(); + println!("accumulator data type: {:?}", self.input_data_type); - self.is_empty = self.is_empty && values.len() == values.null_count(); + println!( + "DEBUG: values[0] actual Rust type: {:?}, Arrow dtype: {:?}, len={}", + values.as_any().type_id(), + values.data_type(), + values.len() + ); - if values.null_count() == 0 { - for i in 0..data.len() { - self.update_single(data, i); - } + if values.len() == values.null_count() { + println!( + "ALL NULL in values accumulator data type: {:?}", + self.input_data_type + ); + Ok(()) } else { - for i in 0..data.len() { - if data.is_null(i) { - continue; + match values.data_type() { + DataType::Int64 => { + println!("match data type: {:?}", self.input_data_type); + update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?; } - self.update_single(data, i); - } - } + DataType::Int32 => { + println!("match data type: {:?}", self.input_data_type); + update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?; + } + DataType::Int16 => { + println!("match data type: {:?}", self.input_data_type); + update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?; + } + DataType::Int8 => { + println!("match data type: {:?}", self.input_data_type); + update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?; + } + _ => { + println!("unsupported input data type: {:?}", self.input_data_type); + panic!("Unsupported data type") + } + }; + println!( + "sum updated accumulator data type: {:?}", + self.input_data_type + ); - Ok(()) + Ok(()) + } } fn evaluate(&mut self) -> DFResult { - // For each group: - // 1. if `is_empty` is true, it means either there is no value or all values for the group - // are null, in this case we'll return null - // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In - // non-ANSI mode Spark returns null. - if self.is_empty || !self.is_not_null { - ScalarValue::new_primitive::(None, &DataType::Int64) - } else { - ScalarValue::try_new_decimal128(self.sum) - } + println!( + "evaluate :: accumulator data type: {:?}", + self.input_data_type + ); + Ok(ScalarValue::Int64(Some(self.sum))) } fn size(&self) -> usize { - std::mem::size_of_val(self) + println!("size :: accumulator data type: {:?}", self.input_data_type); + size_of_val(self) } fn state(&mut self) -> DFResult> { - let sum = if self.is_not_null { - ScalarValue::try_new_null(self.sum)? - } else { - ScalarValue::new_primitive::(None, &DataType::Int64)? 
- }; - Ok(vec![sum, ScalarValue::from(self.is_empty)]) + println!("state :: accumulator data type: {:?}", self.input_data_type); + Ok(vec![ScalarValue::Int64(Some(self.sum))]) } fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { - assert_eq!( - states.len(), - 2, - "Expect two element in 'states' but found {}", - states.len() + println!( + "merge batch :: accumulator data type: {:?}", + self.input_data_type ); - assert_eq!(states[0].len(), 1); - assert_eq!(states[1].len(), 1); - - let that_sum = states[0].as_primitive::(); - let that_is_empty = states[1].as_any().downcast_ref::().unwrap(); - - let this_overflow = !self.is_empty && !self.is_not_null; - let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0); - - self.is_not_null = !this_overflow && !that_overflow; - self.is_empty = self.is_empty && that_is_empty.value(0); - - if self.is_not_null { - self.sum += that_sum.value(0); + let that_sum = states[0].as_primitive::(); + match self.eval_mode { + EvalMode::Legacy | EvalMode::Try => { + self.sum.add_wrapping(that_sum.value(0)); + } + EvalMode::Ansi => match self.sum.add_checked(that_sum.value(0)) { + Ok(v) => self.sum = v, + Err(e) => return Err(DataFusionError::Internal("error".to_string())), + }, } - Ok(()) } } -struct SumIntegerGroupsAccumulator { - // Whether aggregate buffer for a particular group is null. True indicates it is not null. - is_not_null: BooleanBufferBuilder, - is_empty: BooleanBufferBuilder, - sum: Vec, - result_type: DataType, +struct SumDecimalGroupsAccumulator { + sums: Vec, + eval_mode: EvalMode, } -impl SumIntegerGroupsAccumulator { - fn new(result_type: DataType) -> Self { +impl SumDecimalGroupsAccumulator { + fn new(eval_mode: EvalMode) -> Self { Self { - is_not_null: BooleanBufferBuilder::new(0), - is_empty: BooleanBufferBuilder::new(0), - sum: Vec::new(), - result_type, + sums: Vec::new(), + eval_mode, } } - - fn is_overflow(&self, index: usize) -> bool { - !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index) - } - - #[inline] - fn update_single(&mut self, group_index: usize, value: i128) { - self.is_empty.set_bit(group_index, false); - let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value); - self.sum[group_index] = new_sum; - } } -fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) { - if builder.len() < capacity { - let additional = capacity - builder.len(); - builder.append_n(additional, true); - } -} - -impl GroupsAccumulator for SumIntegerGroupsAccumulator { +impl GroupsAccumulator for SumDecimalGroupsAccumulator { fn update_batch( &mut self, values: &[ArrayRef], @@ -279,66 +276,54 @@ impl GroupsAccumulator for SumIntegerGroupsAccumulator { total_num_groups: usize, ) -> DFResult<()> { assert!(opt_filter.is_none(), "opt_filter is not supported yet"); - assert_eq!(values.len(), 1); - let values = values[0].as_primitive::(); + let values = values[0].as_primitive::(); let data = values.values(); - - // Update size for the accumulate states - self.sum.resize(total_num_groups, 0); - ensure_bit_capacity(&mut self.is_empty, total_num_groups); - ensure_bit_capacity(&mut self.is_not_null, total_num_groups); + self.sums.resize(total_num_groups, 0); let iter = group_indices.iter().zip(data.iter()); - if values.null_count() == 0 { - for (&group_index, &value) in iter { - self.update_single(group_index, value); - } - } else { - for (idx, (&group_index, &value)) in iter.enumerate() { - if values.is_null(idx) { - continue; + + for (&group_index, &value) in iter { + match self.eval_mode { + 
EvalMode::Legacy | EvalMode::Try => { + self.sums[group_index].add_wrapping(value); + } + EvalMode::Ansi => { + match self.sums[group_index].add_checked(value) { + Ok(v) => v, + Err(e) => { + return Err(DataFusionError::Internal("integer overflow".to_string())) + } + }; } - self.update_single(group_index, value); } } - Ok(()) } fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { - // For each group: - // 1. if `is_empty` is true, it means either there is no value or all values for the group - // are null, in this case we'll return null - // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In - // non-ANSI mode Spark returns null. - let result = emit_to.take_needed(&mut self.sum); - result.iter().enumerate().for_each(|(i, &v)| {}); - - let nulls = build_bool_state(&mut self.is_not_null, &emit_to); - let is_empty = build_bool_state(&mut self.is_empty, &emit_to); - let x = (!&is_empty).bitand(&nulls); - - let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x))) - .with_data_type(self.result_type.clone()); - - Ok(Arc::new(result)) + match emit_to { + // When emitting all groups, return all calculated sums and reset the internal state. + EmitTo::All => { + // Create an Arrow array from the accumulated sums. + let result = Arc::new(Int64Array::from(self.sums.clone())) as ArrayRef; + // Reset the accumulator state for the next use. + self.sums.clear(); + Ok(result) + } + // When emitting the first `n` groups, return the first `n` sums + // and retain the state for the remaining groups. + EmitTo::First(n) => { + // Take the first `n` sums. + let emitted_sums: Vec = self.sums.drain(..n).collect(); + let result = Arc::new(Int64Array::from(emitted_sums)) as ArrayRef; + Ok(result) + } + } } fn state(&mut self, emit_to: EmitTo) -> DFResult> { - let nulls = build_bool_state(&mut self.is_not_null, &emit_to); - let nulls = Some(NullBuffer::new(nulls)); - - let sum = emit_to.take_needed(&mut self.sum); - let sum = Decimal128Array::new(sum.into(), nulls.clone()) - .with_data_type(self.result_type.clone()); - - let is_empty = build_bool_state(&mut self.is_empty, &emit_to); - let is_empty = BooleanArray::new(is_empty, None); - - Ok(vec![ - Arc::new(sum) as ArrayRef, - Arc::new(is_empty) as ArrayRef, - ]) + let state_array = Arc::new(Int64Array::from(self.sums.clone())); + Ok(vec![state_array]) } fn merge_batch( @@ -348,65 +333,43 @@ impl GroupsAccumulator for SumIntegerGroupsAccumulator { opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> DFResult<()> { - assert_eq!( - values.len(), - 2, - "Expected two arrays: 'sum' and 'is_empty', but found {}", - values.len() - ); assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + println!("merge batch : {:?}", values[0]); + let values = values[0].as_primitive::(); + let data = values.values(); + self.sums.resize(total_num_groups, 0); - // Make sure we have enough capacity for the additional groups - self.sum.resize(total_num_groups, 0); - ensure_bit_capacity(&mut self.is_empty, total_num_groups); - ensure_bit_capacity(&mut self.is_not_null, total_num_groups); - - let that_sum = &values[0]; - let that_sum = that_sum.as_primitive::(); - let that_is_empty = &values[1]; - let that_is_empty = that_is_empty - .as_any() - .downcast_ref::() - .unwrap(); + let iter = group_indices.iter().zip(data.iter()); - group_indices - .iter() - .enumerate() - .for_each(|(idx, &group_index)| unsafe { - let this_overflow = self.is_overflow(group_index); - let that_is_empty = that_is_empty.value_unchecked(idx); 
- let that_overflow = !that_is_empty && that_sum.is_null(idx); - let is_overflow = this_overflow || that_overflow; - - // This part follows the logic in Spark: - // `org.apache.spark.sql.catalyst.expressions.aggregate.Sum` - self.is_not_null.set_bit(group_index, !is_overflow); - self.is_empty.set_bit( - group_index, - self.is_empty.get_bit(group_index) && that_is_empty, - ); - if !is_overflow { - // .. otherwise, the sum value for this particular index must not be null, - // and thus we merge both values and update this sum. - self.sum[group_index] += that_sum.value_unchecked(idx); + for (&group_index, &value) in iter { + match self.eval_mode { + EvalMode::Legacy | EvalMode::Try => { + self.sums[group_index].add_wrapping(value); } - }); - + EvalMode::Ansi => { + match self.sums[group_index].add_checked(value) { + Ok(v) => v, + Err(e) => { + return Err(DataFusionError::Internal("integer overflow".to_string())) + } + }; + } + } + } Ok(()) } fn size(&self) -> usize { - self.sum.capacity() * std::mem::size_of::() - + self.is_empty.capacity() / 8 - + self.is_not_null.capacity() / 8 + size_of_val(self) } } #[cfg(test)] mod tests { use super::*; - use arrow::array::builder::{Decimal128Builder, StringBuilder}; - use arrow::array::RecordBatch; + use arrow::array::builder::StringBuilder; + use arrow::array::{Int64Builder, RecordBatch}; + use arrow::datatypes::DataType::Int64; use arrow::datatypes::*; use datafusion::common::Result; use datafusion::datasource::memory::MemorySourceConfig; @@ -422,7 +385,7 @@ mod tests { #[test] fn invalid_data_type() { - assert!(SumInteger::try_new(DataType::Int32).is_err()); + assert!(SumInteger::try_new(DataType::Date32, EvalMode::Legacy).is_err()); } #[tokio::test] @@ -437,14 +400,15 @@ mod tests { let c0: Arc = Arc::new(Column::new("c0", 0)); let c1: Arc = Arc::new(Column::new("c1", 1)); - let data_type = DataType::Decimal128(8, 2); + let data_type = Int64; let schema = Arc::clone(&partitions[0][0].schema()); let scan: Arc = Arc::new(DataSourceExec::new(Arc::new( - MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap(), + MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None)?, ))); let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumInteger::try_new( data_type.clone(), + EvalMode::Legacy, )?)); let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) @@ -474,13 +438,13 @@ mod tests { } fn create_record_batch(num_rows: usize) -> RecordBatch { - let mut decimal_builder = Decimal128Builder::with_capacity(num_rows); + let mut int_builder = Int64Builder::with_capacity(num_rows); let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); for i in 0..num_rows { - decimal_builder.append_value(i as i128); + int_builder.append_value(i as i64); string_builder.append_value(format!("this is string #{}", i % 1024)); } - let decimal_array = Arc::new(decimal_builder.finish()); + let int_array = Arc::new(int_builder.finish()); let string_array = Arc::new(string_builder.finish()); let mut fields = vec![]; @@ -491,8 +455,8 @@ mod tests { columns.push(string_array); // decimal column - fields.push(Field::new("c1", DataType::Decimal128(38, 10), false)); - columns.push(decimal_array); + fields.push(Field::new("c1", DataType::Int64, false)); + columns.push(int_array); let schema = Schema::new(fields); RecordBatch::try_new(Arc::new(schema), columns).unwrap() diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 
4b8a74c15a..a7228c77c5 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -201,20 +201,6 @@ object CometSum extends CometAggregateExpressionSerde[Sum] { return None } - sum.evalMode match { - case EvalMode.ANSI if !CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get() => - withInfo( - aggExpr, - "ANSI mode is not supported. Set " + - s"${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true to allow it anyway") - return None - case EvalMode.TRY => - withInfo(aggExpr, "TRY mode is not supported") - return None - case _ => - // supported - } - val childExpr = exprToProto(sum.child, inputs, binding) val dataType = serializeDataType(sum.dataType) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index fc45d2cb3a..5dda16ccb2 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3003,6 +3003,28 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("ANSI support for SUM function") { + val data = Seq((Int.MaxValue, 10), (1, 1)) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + withParquetTable(data, "tbl") { + val res = spark.sql(""" + |SELECT + | SUM(_1) + | from tbl + | """.stripMargin) + + res.show(10, false) +// checkSparkMaybeThrows(res) match { +// case (Some(sparkExc), Some(cometExc)) => +// assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG)) +// assert(sparkExc.getMessage.contains("overflow")) +// case _ => fail("Exception should be thrown") +// } + } + } + + } + test("test integral divide overflow for decimal") { if (isSpark40Plus) { Seq(true, false) From 4e994c6ed33cad582f3b4562c75dcb8106c99880 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Fri, 24 Oct 2025 12:48:48 -0700 Subject: [PATCH 03/18] conf_bug_fix --- native/spark-expr/src/agg_funcs/sum_int.rs | 115 ++++++++---------- .../apache/comet/CometExpressionSuite.scala | 42 ++++--- 2 files changed, 76 insertions(+), 81 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index fa8df6b90d..d2fc59a0d0 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::EvalMode; +use crate::{arithmetic_overflow_error, EvalMode}; use arrow::array::{ cast::AsArray, Array, ArrayBuilder, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray, Int64Array, PrimitiveArray, @@ -39,8 +39,8 @@ pub struct SumInteger { impl SumInteger { pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult { - // The `data_type` is the SUM result type passed from Spark side - println!("data type: {:?}", data_type); + // The `data_type` is the SUM result type passed from Spark side which should i64 + println!("data type: {:?} eval_mode {:?}", data_type, eval_mode); match data_type { DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self { signature: Signature::user_defined(Immutable), @@ -75,14 +75,14 @@ impl AggregateUDFImpl for SumInteger { } fn accumulator(&self, acc_args: AccumulatorArgs) -> DFResult> { - Ok(Box::new(SumIntegerAccumulator::new())) + Ok(Box::new(SumIntegerAccumulator::new(self.eval_mode))) } fn create_groups_accumulator( &self, _args: AccumulatorArgs, ) -> DFResult> { - Ok(Box::new(SumDecimalGroupsAccumulator::new(self.eval_mode))) + Ok(Box::new(SumIntGroupsAccumulator::new(self.eval_mode))) } } @@ -94,10 +94,10 @@ struct SumIntegerAccumulator { } impl SumIntegerAccumulator { - fn new() -> Self { + fn new(eval_mode: EvalMode) -> Self { Self { sum: 0, - eval_mode: EvalMode::Legacy, + eval_mode, input_data_type: DataType::Int64, } } @@ -113,13 +113,13 @@ impl Accumulator for SumIntegerAccumulator { where T: ArrowPrimitiveType, { - println!("match internal function data type: {:?}", sum); let len = int_array.len(); for i in 0..int_array.len() { if !int_array.is_null(i) { let v = int_array.value(i).to_i64().ok_or_else(|| { DataFusionError::Internal("Failed to convert value to i64".to_string()) })?; + println!("sum : {:?}, v : {:?}", sum, v); match eval_mode { EvalMode::Legacy | EvalMode::Try => { sum = v.add_wrapping(sum); @@ -128,7 +128,7 @@ impl Accumulator for SumIntegerAccumulator { match v.add_checked(sum) { Ok(v) => sum = v, Err(e) => { - return Err(DataFusionError::Internal("error".to_string())) + return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) } }; } @@ -157,53 +157,40 @@ impl Accumulator for SumIntegerAccumulator { ); Ok(()) } else { - match values.data_type() { - DataType::Int64 => { - println!("match data type: {:?}", self.input_data_type); - update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - )?; - } - DataType::Int32 => { - println!("match data type: {:?}", self.input_data_type); - update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - )?; - } - DataType::Int16 => { - println!("match data type: {:?}", self.input_data_type); - update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - )?; - } - DataType::Int8 => { - println!("match data type: {:?}", self.input_data_type); - update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - )?; - } + self.sum = match values.data_type() { + DataType::Int64 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?, + DataType::Int32 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?, + DataType::Int16 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?, + 
DataType::Int8 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?, _ => { - println!("unsupported input data type: {:?}", self.input_data_type); panic!("Unsupported data type") } }; @@ -246,19 +233,19 @@ impl Accumulator for SumIntegerAccumulator { } EvalMode::Ansi => match self.sum.add_checked(that_sum.value(0)) { Ok(v) => self.sum = v, - Err(e) => return Err(DataFusionError::Internal("error".to_string())), + Err(e) => return Err(DataFusionError::from(arithmetic_overflow_error("integer"))), }, } Ok(()) } } -struct SumDecimalGroupsAccumulator { +struct SumIntGroupsAccumulator { sums: Vec, eval_mode: EvalMode, } -impl SumDecimalGroupsAccumulator { +impl SumIntGroupsAccumulator { fn new(eval_mode: EvalMode) -> Self { Self { sums: Vec::new(), @@ -267,7 +254,7 @@ impl SumDecimalGroupsAccumulator { } } -impl GroupsAccumulator for SumDecimalGroupsAccumulator { +impl GroupsAccumulator for SumIntGroupsAccumulator { fn update_batch( &mut self, values: &[ArrayRef], @@ -285,13 +272,13 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { for (&group_index, &value) in iter { match self.eval_mode { EvalMode::Legacy | EvalMode::Try => { - self.sums[group_index].add_wrapping(value); + self.sums[group_index] = self.sums[group_index].add_wrapping(value); } EvalMode::Ansi => { match self.sums[group_index].add_checked(value) { - Ok(v) => v, + Ok(v) => self.sums[group_index] = v, Err(e) => { - return Err(DataFusionError::Internal("integer overflow".to_string())) + return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) } }; } @@ -344,11 +331,11 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator { for (&group_index, &value) in iter { match self.eval_mode { EvalMode::Legacy | EvalMode::Try => { - self.sums[group_index].add_wrapping(value); + self.sums[group_index] = self.sums[group_index].add_wrapping(value); } EvalMode::Ansi => { match self.sums[group_index].add_checked(value) { - Ok(v) => v, + Ok(v) => self.sums[group_index] = v, Err(e) => { return Err(DataFusionError::Internal("integer overflow".to_string())) } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 5dda16ccb2..40e9a6c9e7 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3004,24 +3004,32 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("ANSI support for SUM function") { - val data = Seq((Int.MaxValue, 10), (1, 1)) - withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { - withParquetTable(data, "tbl") { - val res = spark.sql(""" - |SELECT - | SUM(_1) - | from tbl - | """.stripMargin) - - res.show(10, false) -// checkSparkMaybeThrows(res) match { -// case (Some(sparkExc), Some(cometExc)) => -// assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG)) -// assert(sparkExc.getMessage.contains("overflow")) -// case _ => fail("Exception should be thrown") -// } + val batchSize = 10 + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test_sum.parquet") + makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, batchSize) + withParquetTable(path.toString, "tbl") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + spark.table("tbl").printSchema() +// val res = spark.sql( +// """ +// |SELECT +// | SUM(_1) +// | from tbl +// | 
""".stripMargin) +// checkSparkAnswerAndOperator(res) + } + // res.show(10, false) +// checkSparkMaybeThrows(res) match { +// case (Some(sparkExc), Some(cometExc)) => +// assert(cometExc.getMessage.contains("error")) +// assert(sparkExc.getMessage.contains("overflow")) +// case _ => fail("Exception should be thrown") +// } + } + } } - } } From 79ef864872ba15a64aaf87eae222b52b08d7ed10 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Fri, 24 Oct 2025 13:27:21 -0700 Subject: [PATCH 04/18] impl_try_mode --- native/core/src/execution/planner.rs | 7 +-- native/proto/src/proto/expr.proto | 2 +- native/spark-expr/src/agg_funcs/sum_int.rs | 44 ++++++++++++++----- .../org/apache/comet/serde/aggregates.scala | 7 ++- .../apache/comet/CometExpressionSuite.scala | 4 +- 5 files changed, 42 insertions(+), 22 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 23f500b5e5..d043baf919 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1834,12 +1834,7 @@ impl PhysicalPlanner { AggregateExprBuilder::new(Arc::new(func), vec![child]) } DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - // let eval_mode = let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - let eval_mode = if expr.fail_on_error { - EvalMode::Ansi - } else { - EvalMode::Legacy - }; + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; let func = AggregateUDF::new_from_impl(SumInteger::try_new(datatype, eval_mode)?); AggregateExprBuilder::new(Arc::new(func), vec![child]) diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index c9037dcd69..a7736f561a 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -120,7 +120,7 @@ message Count { message Sum { Expr child = 1; DataType datatype = 2; - bool fail_on_error = 3; + EvalMode eval_mode = 3; } message Min { diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index d2fc59a0d0..4cf83fab9c 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -41,6 +41,7 @@ impl SumInteger { pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult { // The `data_type` is the SUM result type passed from Spark side which should i64 println!("data type: {:?} eval_mode {:?}", data_type, eval_mode); + match data_type { DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self { signature: Signature::user_defined(Immutable), @@ -121,14 +122,19 @@ impl Accumulator for SumIntegerAccumulator { })?; println!("sum : {:?}, v : {:?}", sum, v); match eval_mode { - EvalMode::Legacy | EvalMode::Try => { + EvalMode::Legacy => { sum = v.add_wrapping(sum); } - EvalMode::Ansi => { + EvalMode::Ansi | EvalMode::Try => { match v.add_checked(sum) { Ok(v) => sum = v, Err(e) => { - return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) + if (eval_mode == EvalMode::Ansi){ + return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) + } + else { + sum = None.unwrap(); + } } }; } @@ -228,12 +234,18 @@ impl Accumulator for SumIntegerAccumulator { ); let that_sum = states[0].as_primitive::(); match self.eval_mode { - EvalMode::Legacy | EvalMode::Try => { + EvalMode::Legacy => { self.sum.add_wrapping(that_sum.value(0)); } - EvalMode::Ansi => match self.sum.add_checked(that_sum.value(0)) { + EvalMode::Ansi | EvalMode::Try => match self.sum.add_checked(that_sum.value(0)) { Ok(v) => 
self.sum = v, - Err(e) => return Err(DataFusionError::from(arithmetic_overflow_error("integer"))), + Err(e) => + if (self.eval_mode == EvalMode::Ansi){ + return Err(DataFusionError::from(arithmetic_overflow_error("integer"))), + } + else{ + self.sum = None.unwrap(); + } }, } Ok(()) @@ -271,14 +283,19 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { for (&group_index, &value) in iter { match self.eval_mode { - EvalMode::Legacy | EvalMode::Try => { + EvalMode::Legacy => { self.sums[group_index] = self.sums[group_index].add_wrapping(value); } - EvalMode::Ansi => { + EvalMode::Ansi | EvalMode::Try => { match self.sums[group_index].add_checked(value) { Ok(v) => self.sums[group_index] = v, Err(e) => { - return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) + if (self.eval_mode == EvalMode::Ansi){ + return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) + } + else{ + self.sums[group_index] = None.unwrap(); + } } }; } @@ -333,11 +350,16 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { EvalMode::Legacy | EvalMode::Try => { self.sums[group_index] = self.sums[group_index].add_wrapping(value); } - EvalMode::Ansi => { + EvalMode::Ansi | EvalMode::Try => { match self.sums[group_index].add_checked(value) { Ok(v) => self.sums[group_index] = v, Err(e) => { - return Err(DataFusionError::Internal("integer overflow".to_string())) + if (self.eval_mode == EvalMode::Ansi){ + return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) + } + else{ + self.sums[group_index] = None.unwrap(); + } } }; } diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index a7228c77c5..3af5c231b8 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType, LongType, import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType} +import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType} +import org.apache.comet.shims.CometEvalModeUtil object CometMin extends CometAggregateExpressionSerde[Min] { @@ -201,6 +202,8 @@ object CometSum extends CometAggregateExpressionSerde[Sum] { return None } + val evalMode = sum.evalMode + val childExpr = exprToProto(sum.child, inputs, binding) val dataType = serializeDataType(sum.dataType) @@ -208,7 +211,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] { val builder = ExprOuterClass.Sum.newBuilder() builder.setChild(childExpr.get) builder.setDatatype(dataType.get) - builder.setFailOnError(sum.evalMode == EvalMode.ANSI) + builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(evalMode))) Some( ExprOuterClass.AggExpr diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 40e9a6c9e7..f8d23e71ee 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3020,16 +3020,16 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { // | """.stripMargin) // checkSparkAnswerAndOperator(res) } - // res.show(10, false) + // res.show(10, false) // checkSparkMaybeThrows(res) match { // case (Some(sparkExc), 
Some(cometExc)) => // assert(cometExc.getMessage.contains("error")) // assert(sparkExc.getMessage.contains("overflow")) // case _ => fail("Exception should be thrown") // } - } } } + } } From ef6dcae29dd7009a90b77f91f960141e04e36da4 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Fri, 24 Oct 2025 13:33:00 -0700 Subject: [PATCH 05/18] impl_try_mode --- native/spark-expr/src/agg_funcs/sum_int.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index 4cf83fab9c..845d3846fe 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -241,7 +241,7 @@ impl Accumulator for SumIntegerAccumulator { Ok(v) => self.sum = v, Err(e) => if (self.eval_mode == EvalMode::Ansi){ - return Err(DataFusionError::from(arithmetic_overflow_error("integer"))), + return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) } else{ self.sum = None.unwrap(); From 6f166c6ea4dc00d2de952f2799123fea929f7073 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 28 Oct 2025 10:40:33 -0700 Subject: [PATCH 06/18] squash_commits --- native/spark-expr/src/agg_funcs/sum_int.rs | 343 ++++++++++-------- .../apache/comet/CometExpressionSuite.scala | 285 +++++++++++++-- 2 files changed, 453 insertions(+), 175 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index 845d3846fe..962cbdcd8a 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -17,8 +17,8 @@ use crate::{arithmetic_overflow_error, EvalMode}; use arrow::array::{ - cast::AsArray, Array, ArrayBuilder, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, - BooleanArray, Int64Array, PrimitiveArray, + cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray, + Int64Array, PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, Int8Type}; use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; @@ -39,9 +39,6 @@ pub struct SumInteger { impl SumInteger { pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult { - // The `data_type` is the SUM result type passed from Spark side which should i64 - println!("data type: {:?} eval_mode {:?}", data_type, eval_mode); - match data_type { DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self { signature: Signature::user_defined(Immutable), @@ -67,7 +64,7 @@ impl AggregateUDFImpl for SumInteger { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> DFResult { + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { Ok(DataType::Int64) } @@ -75,7 +72,7 @@ impl AggregateUDFImpl for SumInteger { true } - fn accumulator(&self, acc_args: AccumulatorArgs) -> DFResult> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult> { Ok(Box::new(SumIntegerAccumulator::new(self.eval_mode))) } @@ -89,17 +86,15 @@ impl AggregateUDFImpl for SumInteger { #[derive(Debug)] struct SumIntegerAccumulator { - sum: i64, + sum: Option, eval_mode: EvalMode, - input_data_type: DataType, } impl SumIntegerAccumulator { fn new(eval_mode: EvalMode) -> Self { Self { - sum: 0, + sum: Some(0), eval_mode, - input_data_type: DataType::Int64, } } } @@ -109,151 +104,132 @@ impl Accumulator for SumIntegerAccumulator { fn update_sum_internal( int_array: &PrimitiveArray, eval_mode: EvalMode, - mut sum: i64, - ) -> Result + sum: Option, + ) -> 
Result, DataFusionError> where T: ArrowPrimitiveType, { - let len = int_array.len(); + let mut curr_sum = sum.unwrap(); for i in 0..int_array.len() { if !int_array.is_null(i) { let v = int_array.value(i).to_i64().ok_or_else(|| { DataFusionError::Internal("Failed to convert value to i64".to_string()) })?; - println!("sum : {:?}, v : {:?}", sum, v); match eval_mode { EvalMode::Legacy => { - sum = v.add_wrapping(sum); + curr_sum = v.add_wrapping(curr_sum); } EvalMode::Ansi | EvalMode::Try => { - match v.add_checked(sum) { - Ok(v) => sum = v, - Err(e) => { - if (eval_mode == EvalMode::Ansi){ - return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) - } - else { - sum = None.unwrap(); - } + match v.add_checked(curr_sum) { + Ok(v) => curr_sum = v, + Err(_e) => { + return if eval_mode == EvalMode::Ansi { + Err(DataFusionError::from(arithmetic_overflow_error( + "integer", + ))) + } else { + Ok(None) + }; } }; } } } } - println!("match internal (AFTER) function data type: {:?}", sum); - - Ok(sum) + Ok(Some(curr_sum)) } - let values = &values[0]; - println!("accumulator data type: {:?}", self.input_data_type); - - println!( - "DEBUG: values[0] actual Rust type: {:?}, Arrow dtype: {:?}, len={}", - values.as_any().type_id(), - values.data_type(), - values.len() - ); - - if values.len() == values.null_count() { - println!( - "ALL NULL in values accumulator data type: {:?}", - self.input_data_type - ); + if self.sum.is_none() { Ok(()) } else { - self.sum = match values.data_type() { - DataType::Int64 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - )?, - DataType::Int32 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - )?, - DataType::Int16 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - )?, - DataType::Int8 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - )?, - _ => { - panic!("Unsupported data type") - } - }; - println!( - "sum updated accumulator data type: {:?}", - self.input_data_type - ); - - Ok(()) + let values = &values[0]; + if values.len() == values.null_count() { + Ok(()) + } else { + self.sum = match values.data_type() { + DataType::Int64 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?, + DataType::Int32 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?, + DataType::Int16 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?, + DataType::Int8 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + self.sum, + )?, + _ => { + panic!("Unsupported data type") + } + }; + Ok(()) + } } } fn evaluate(&mut self) -> DFResult { - println!( - "evaluate :: accumulator data type: {:?}", - self.input_data_type - ); - Ok(ScalarValue::Int64(Some(self.sum))) + Ok(ScalarValue::Int64(self.sum)) } fn size(&self) -> usize { - println!("size :: accumulator data type: {:?}", self.input_data_type); size_of_val(self) } fn state(&mut self) -> DFResult> { - println!("state :: accumulator data type: {:?}", self.input_data_type); - Ok(vec![ScalarValue::Int64(Some(self.sum))]) + Ok(vec![ScalarValue::Int64(self.sum)]) } fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { - println!( - "merge batch :: accumulator data type: 
{:?}", - self.input_data_type - ); + if self.sum.is_none() { + return Ok(()); + } let that_sum = states[0].as_primitive::(); match self.eval_mode { EvalMode::Legacy => { - self.sum.add_wrapping(that_sum.value(0)); + self.sum = Some(self.sum.unwrap().add_wrapping(that_sum.value(0))); } - EvalMode::Ansi | EvalMode::Try => match self.sum.add_checked(that_sum.value(0)) { - Ok(v) => self.sum = v, - Err(e) => - if (self.eval_mode == EvalMode::Ansi){ - return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) - } - else{ - self.sum = None.unwrap(); + EvalMode::Ansi | EvalMode::Try => { + match self.sum.unwrap().add_checked(that_sum.value(0)) { + Ok(v) => self.sum = Some(v), + Err(_e) => { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error( + "integer", + ))); + } else { + self.sum = None + } + } } - }, + } } Ok(()) } } struct SumIntGroupsAccumulator { - sums: Vec, + sums: Vec>, eval_mode: EvalMode, } @@ -274,33 +250,93 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> DFResult<()> { - assert!(opt_filter.is_none(), "opt_filter is not supported yet"); - let values = values[0].as_primitive::(); - let data = values.values(); - self.sums.resize(total_num_groups, 0); - - let iter = group_indices.iter().zip(data.iter()); - - for (&group_index, &value) in iter { - match self.eval_mode { - EvalMode::Legacy => { - self.sums[group_index] = self.sums[group_index].add_wrapping(value); - } - EvalMode::Ansi | EvalMode::Try => { - match self.sums[group_index].add_checked(value) { - Ok(v) => self.sums[group_index] = v, - Err(e) => { - if (self.eval_mode == EvalMode::Ansi){ - return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) - } - else{ - self.sums[group_index] = None.unwrap(); - } + fn update_groups_sum_internal( + int_array: &PrimitiveArray, + group_indices: &[usize], + sums: &mut [Option], + eval_mode: EvalMode, + ) -> DFResult<()> + where + T: ArrowPrimitiveType, + T::Native: ArrowNativeType, + { + for (i, &group_index) in group_indices.iter().enumerate() { + if sums[group_index].is_some() && !int_array.is_null(i) { + let v = int_array.value(i).to_i64().ok_or_else(|| { + DataFusionError::Internal("Failed to convert value to i64".to_string()) + })?; + match eval_mode { + EvalMode::Legacy => { + sums[group_index] = Some(sums[group_index].unwrap().add_wrapping(v)); + } + EvalMode::Ansi | EvalMode::Try => { + match sums[group_index].unwrap().add_checked(v) { + Ok(new_sum) => sums[group_index] = Some(new_sum), + Err(_) => { + if eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from( + arithmetic_overflow_error("integer"), + )); + } else { + sums[group_index] = None + } + } + }; } - }; + } } } + Ok(()) } + + assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + let values = &values[0]; + self.sums.resize(total_num_groups, Some(0)); + + match values.data_type() { + DataType::Int64 => update_groups_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + group_indices, + &mut self.sums, + self.eval_mode, + )?, + DataType::Int32 => update_groups_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + group_indices, + &mut self.sums, + self.eval_mode, + )?, + DataType::Int16 => update_groups_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + group_indices, + &mut self.sums, + self.eval_mode, + )?, + DataType::Int8 => update_groups_sum_internal( + values + .as_any() + .downcast_ref::>() + 
.unwrap(), + group_indices, + &mut self.sums, + self.eval_mode, + )?, + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported data type for SumIntGroupsAccumulator: {:?}", + values.data_type() + ))) + } + }; Ok(()) } @@ -318,14 +354,14 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { // and retain the state for the remaining groups. EmitTo::First(n) => { // Take the first `n` sums. - let emitted_sums: Vec = self.sums.drain(..n).collect(); - let result = Arc::new(Int64Array::from(emitted_sums)) as ArrayRef; + let result = Arc::new(Int64Array::from(self.sums.drain(..n).collect::>())) + as ArrayRef; Ok(result) } } } - fn state(&mut self, emit_to: EmitTo) -> DFResult> { + fn state(&mut self, _emit_to: EmitTo) -> DFResult> { let state_array = Arc::new(Int64Array::from(self.sums.clone())); Ok(vec![state_array]) } @@ -338,30 +374,33 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { total_num_groups: usize, ) -> DFResult<()> { assert!(opt_filter.is_none(), "opt_filter is not supported yet"); - println!("merge batch : {:?}", values[0]); let values = values[0].as_primitive::(); let data = values.values(); - self.sums.resize(total_num_groups, 0); + self.sums.resize(total_num_groups, Some(0)); let iter = group_indices.iter().zip(data.iter()); for (&group_index, &value) in iter { - match self.eval_mode { - EvalMode::Legacy | EvalMode::Try => { - self.sums[group_index] = self.sums[group_index].add_wrapping(value); - } - EvalMode::Ansi | EvalMode::Try => { - match self.sums[group_index].add_checked(value) { - Ok(v) => self.sums[group_index] = v, - Err(e) => { - if (self.eval_mode == EvalMode::Ansi){ - return Err(DataFusionError::from(arithmetic_overflow_error("integer"))) - } - else{ - self.sums[group_index] = None.unwrap(); + if self.sums[group_index].is_some() { + match self.eval_mode { + EvalMode::Legacy => { + self.sums[group_index] = + Some(self.sums[group_index].unwrap().add_wrapping(value)); + } + EvalMode::Ansi | EvalMode::Try => { + match self.sums[group_index].unwrap().add_checked(value) { + Ok(v) => self.sums[group_index] = Some(v), + Err(_e) => { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error( + "integer", + ))); + } else { + self.sums[group_index] = None + } } - } - }; + }; + } } } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index f8d23e71ee..4a914c25e2 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3003,34 +3003,273 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("ANSI support for SUM function") { - val batchSize = 10 - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test_sum.parquet") - makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, batchSize) - withParquetTable(path.toString, "tbl") { - withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { - spark.table("tbl").printSchema() -// val res = spark.sql( -// """ -// |SELECT -// | SUM(_1) -// | from tbl -// | """.stripMargin) -// checkSparkAnswerAndOperator(res) + test("ANSI support - SUM function") { +// Test long overflow + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + // Test long overflow + withParquetTable(Seq((Long.MaxValue, 1L), (100L, 1L)), "tbl") { + 
val res = sql("SELECT SUM(_1) FROM tbl") + if (ansiEnabled) { + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => fail("Exception should be thrown for Long overflow in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + // Test long underflow + withParquetTable(Seq((Long.MinValue, 1L), (-100L, 1L)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + if (ansiEnabled) { + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => fail("Exception should be thrown for Long underflow in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + // Test Int SUM (should not overflow) + withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 1)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + // Test Int SUM (should not overflow) + withParquetTable( + Seq((Short.MaxValue, 1.toShort), (Short.MaxValue, 1.toShort), (100.toShort, 1.toShort)), + "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + + // Test Byte SUM (should not overflow) + withParquetTable( + Seq((Byte.MaxValue, 1.toByte), (Byte.MaxValue, 1.toByte), (10.toByte, 1.toByte)), + "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + + // Test Long overflow with NULL values + withParquetTable(Seq((Long.MaxValue, 1L), (null, 1L), (100L, 1L)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + if (ansiEnabled) { + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail("Exception should be thrown for Long overflow with NULLs in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) } - // res.show(10, false) -// checkSparkMaybeThrows(res) match { -// case (Some(sparkExc), Some(cometExc)) => -// assert(cometExc.getMessage.contains("error")) -// assert(sparkExc.getMessage.contains("overflow")) -// case _ => fail("Exception should be thrown") -// } + } + + // Test Long underflow with NULL values + withParquetTable(Seq((Long.MinValue, 1L), (null, 1L), (-100L, 1L)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + if (ansiEnabled) { + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail("Exception should be thrown for Long underflow with NULLs in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + + // Test only NULL inputs + withParquetTable(Seq((null, 1L), (null, 1L), (null, 1L)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + + // Test with mix of valid numbers and NULLs (no overflow) + withParquetTable(Seq((100L, 1L), (null, 1L), (200L, 1L)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + + // Test with mix of valid numbers and NULLs (no overflow) - overflow + withParquetTable(Seq((Long.MaxValue, 1L), (null, 1L), (200L, 1L)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) } 
} } + } + test("SUM overflow - GROUP BY") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + // Test Long overflow with GROUP BY to test GroupAccumulator + withParquetTable( + Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (200L, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2) + if (ansiEnabled) { + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + + withParquetTable( + Seq((Long.MinValue, 1), (-100L, 1), (Long.MinValue, 2), (-200L, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + if (ansiEnabled) { + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail("Exception should be thrown for Long underflow with GROUP BY in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + // Test Int with GROUP BY + withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 2), (200, 2)), "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + // Test Short with GROUP BY + withParquetTable( + Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + + // Test Byte with GROUP BY + withParquetTable( + Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + + // Test Long overflow with GROUP BY and NULL values + withParquetTable( + Seq((Long.MaxValue, 1), (null, 1), (100L, 1), (200L, 2), (null, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + if (ansiEnabled) { + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail( + "Exception should be thrown for Long overflow with GROUP BY and NULLs in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + + // Test Long underflow with GROUP BY and NULL values + withParquetTable( + Seq((Long.MinValue, 1), (null, 1), (-100L, 1), (-200L, 2), (null, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + if (ansiEnabled) { + checkSparkMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail( + "Exception should be thrown for Long underflow with GROUP BY and NULLs in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + + // Test GROUP BY with only NULL values + withParquetTable(Seq((null, 1), (null, 1), (null, 2), (null, 2)), "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + + // Test GROUP BY with mix of valid values and NULLs (no overflow) + withParquetTable(Seq((100L, 1), 
(null, 1), (200L, 1), (300L, 2), (null, 2)), "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + } + } + } + + test("try_sum overflow - with GROUP BY") { + // Test Long overflow with GROUP BY - some groups overflow, some don't + withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (200L, 2), (300L, 2)), "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + // Group 1 should return NULL (overflow), Group 2 should return 500 + checkSparkAnswerAndOperator(res) + } + + // Test Long underflow with GROUP BY + withParquetTable(Seq((Long.MinValue, 1), (-100L, 1), (-200L, 2), (-300L, 2)), "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + // Group 1 should return NULL (underflow), Group 2 should return -500 + checkSparkAnswerAndOperator(res) + } + + // Test all groups overflow + withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (100L, 2)), "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + // Both groups should return NULL + checkSparkAnswerAndOperator(res) + } + + // Test with NULL values mixed with overflow + withParquetTable(Seq((Long.MaxValue, 1), (null, 1), (100L, 2), (null, 2)), "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + // Group 1 should return NULL (overflow), Group 2 should return 100 + checkSparkAnswerAndOperator(res) + } + + // Test Int with GROUP BY (should NOT overflow since accumulator is i64) + withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 2), (200, 2)), "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + checkSparkAnswerAndOperator(res) + } + + // Test Short with GROUP BY (should NOT overflow) + withParquetTable( + Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)), + "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + checkSparkAnswerAndOperator(res) + } + + // Test Byte with GROUP BY (should NOT overflow) + withParquetTable( + Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)), + "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + checkSparkAnswerAndOperator(res) + } } test("test integral divide overflow for decimal") { From bf33672808670c38da422d8a19c5412d007a8aca Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 29 Oct 2025 16:54:02 -0700 Subject: [PATCH 07/18] sum_ansi_mode_checks --- native/spark-expr/src/agg_funcs/sum_int.rs | 177 ++++++++++++------ .../apache/comet/CometExpressionSuite.scala | 137 ++------------ 2 files changed, 135 insertions(+), 179 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index 962cbdcd8a..add96eb67f 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -20,9 +20,11 @@ use arrow::array::{ cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray, Int64Array, PrimitiveArray, }; -use arrow::datatypes::{ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, Int8Type}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type, +}; use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; -use 
datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::Volatility::Immutable; use datafusion::logical_expr::{ Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, @@ -68,14 +70,25 @@ impl AggregateUDFImpl for SumInteger { Ok(DataType::Int64) } - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { - true - } - fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult> { Ok(Box::new(SumIntegerAccumulator::new(self.eval_mode))) } + fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { + if self.eval_mode == EvalMode::Try { + Ok(vec![ + Arc::new(Field::new("sum", DataType::Int64, true)), + Arc::new(Field::new("is_null", DataType::Boolean, false)), + ]) + } else { + Ok(vec![Arc::new(Field::new("sum", DataType::Int64, true))]) + } + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + fn create_groups_accumulator( &self, _args: AccumulatorArgs, @@ -86,49 +99,52 @@ impl AggregateUDFImpl for SumInteger { #[derive(Debug)] struct SumIntegerAccumulator { - sum: Option, + sum: i64, eval_mode: EvalMode, + is_null: bool, } impl SumIntegerAccumulator { fn new(eval_mode: EvalMode) -> Self { Self { - sum: Some(0), + sum: 0, eval_mode, + is_null: false, } } } impl Accumulator for SumIntegerAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { + // accumulator internal to add sum and return is_null: true if there is an overflow in Try Eval mode fn update_sum_internal( int_array: &PrimitiveArray, eval_mode: EvalMode, - sum: Option, - ) -> Result, DataFusionError> + mut sum: i64, + is_null: bool, + ) -> Result<(i64, bool), DataFusionError> where T: ArrowPrimitiveType, { - let mut curr_sum = sum.unwrap(); for i in 0..int_array.len() { - if !int_array.is_null(i) { + if !is_null && !int_array.is_null(i) { let v = int_array.value(i).to_i64().ok_or_else(|| { DataFusionError::Internal("Failed to convert value to i64".to_string()) })?; match eval_mode { EvalMode::Legacy => { - curr_sum = v.add_wrapping(curr_sum); + sum = v.add_wrapping(sum); } EvalMode::Ansi | EvalMode::Try => { - match v.add_checked(curr_sum) { - Ok(v) => curr_sum = v, + match v.add_checked(sum) { + Ok(v) => sum = v, Err(_e) => { return if eval_mode == EvalMode::Ansi { Err(DataFusionError::from(arithmetic_overflow_error( "integer", ))) } else { - Ok(None) + return Ok((sum, true)); }; } }; @@ -136,17 +152,17 @@ impl Accumulator for SumIntegerAccumulator { } } } - Ok(Some(curr_sum)) + Ok((sum, false)) } - if self.sum.is_none() { + if self.is_null { Ok(()) } else { let values = &values[0]; if values.len() == values.null_count() { Ok(()) } else { - self.sum = match values.data_type() { + let (sum, is_overflow) = match values.data_type() { DataType::Int64 => update_sum_internal( values .as_any() @@ -154,6 +170,7 @@ impl Accumulator for SumIntegerAccumulator { .unwrap(), self.eval_mode, self.sum, + self.is_null, )?, DataType::Int32 => update_sum_internal( values @@ -162,6 +179,7 @@ impl Accumulator for SumIntegerAccumulator { .unwrap(), self.eval_mode, self.sum, + self.is_null, )?, DataType::Int16 => update_sum_internal( values @@ -170,6 +188,7 @@ impl Accumulator for SumIntegerAccumulator { .unwrap(), self.eval_mode, self.sum, + self.is_null, )?, DataType::Int8 => update_sum_internal( values @@ -178,18 +197,26 @@ impl Accumulator for SumIntegerAccumulator { .unwrap(), self.eval_mode, self.sum, + self.is_null, )?, _ => { panic!("Unsupported 
data type") } }; + + self.sum = sum; + self.is_null = is_overflow; Ok(()) } } } fn evaluate(&mut self) -> DFResult { - Ok(ScalarValue::Int64(self.sum)) + if self.is_null { + Ok(ScalarValue::Int64(None)) + } else { + Ok(ScalarValue::Int64(Some(self.sum))) + } } fn size(&self) -> usize { @@ -197,39 +224,48 @@ impl Accumulator for SumIntegerAccumulator { } fn state(&mut self) -> DFResult> { - Ok(vec![ScalarValue::Int64(self.sum)]) + if self.eval_mode == EvalMode::Try { + Ok(vec![ + ScalarValue::Int64(Some(self.sum)), + ScalarValue::Boolean(Some(self.is_null)), + ]) + } else { + Ok(vec![ScalarValue::Int64(Some(self.sum))]) + } } fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { - if self.sum.is_none() { + if self.is_null { return Ok(()); } let that_sum = states[0].as_primitive::(); + + if self.eval_mode == EvalMode::Try && states[1].as_boolean().value(0) { + return Ok(()); + } + match self.eval_mode { EvalMode::Legacy => { - self.sum = Some(self.sum.unwrap().add_wrapping(that_sum.value(0))); + self.sum = self.sum.add_wrapping(that_sum.value(0)); } - EvalMode::Ansi | EvalMode::Try => { - match self.sum.unwrap().add_checked(that_sum.value(0)) { - Ok(v) => self.sum = Some(v), - Err(_e) => { - if self.eval_mode == EvalMode::Ansi { - return Err(DataFusionError::from(arithmetic_overflow_error( - "integer", - ))); - } else { - self.sum = None - } + EvalMode::Ansi | EvalMode::Try => match self.sum.add_checked(that_sum.value(0)) { + Ok(v) => self.sum = v, + Err(_e) => { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error("integer"))); + } else { + self.is_null = true; } } - } + }, } Ok(()) } } struct SumIntGroupsAccumulator { - sums: Vec>, + sums: Vec, + has_nulls: Vec, eval_mode: EvalMode, } @@ -238,6 +274,7 @@ impl SumIntGroupsAccumulator { Self { sums: Vec::new(), eval_mode, + has_nulls: Vec::new(), } } } @@ -253,7 +290,8 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { fn update_groups_sum_internal( int_array: &PrimitiveArray, group_indices: &[usize], - sums: &mut [Option], + sums: &mut [i64], + has_nulls: &mut [bool], eval_mode: EvalMode, ) -> DFResult<()> where @@ -261,24 +299,24 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { T::Native: ArrowNativeType, { for (i, &group_index) in group_indices.iter().enumerate() { - if sums[group_index].is_some() && !int_array.is_null(i) { + if !has_nulls[group_index] && !int_array.is_null(i) { let v = int_array.value(i).to_i64().ok_or_else(|| { DataFusionError::Internal("Failed to convert value to i64".to_string()) })?; match eval_mode { EvalMode::Legacy => { - sums[group_index] = Some(sums[group_index].unwrap().add_wrapping(v)); + sums[group_index] = sums[group_index].add_wrapping(v); } EvalMode::Ansi | EvalMode::Try => { - match sums[group_index].unwrap().add_checked(v) { - Ok(new_sum) => sums[group_index] = Some(new_sum), + match sums[group_index].add_checked(v) { + Ok(new_sum) => sums[group_index] = new_sum, Err(_) => { if eval_mode == EvalMode::Ansi { return Err(DataFusionError::from( arithmetic_overflow_error("integer"), )); } else { - sums[group_index] = None + has_nulls[group_index] = true } } }; @@ -291,7 +329,8 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { assert!(opt_filter.is_none(), "opt_filter is not supported yet"); let values = &values[0]; - self.sums.resize(total_num_groups, Some(0)); + self.sums.resize(total_num_groups, 0); + self.has_nulls.resize(total_num_groups, false); match values.data_type() { DataType::Int64 => update_groups_sum_internal( @@ -301,6 
+340,7 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { .unwrap(), group_indices, &mut self.sums, + &mut self.has_nulls, self.eval_mode, )?, DataType::Int32 => update_groups_sum_internal( @@ -310,6 +350,7 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { .unwrap(), group_indices, &mut self.sums, + &mut self.has_nulls, self.eval_mode, )?, DataType::Int16 => update_groups_sum_internal( @@ -319,6 +360,7 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { .unwrap(), group_indices, &mut self.sums, + &mut self.has_nulls, self.eval_mode, )?, DataType::Int8 => update_groups_sum_internal( @@ -328,6 +370,7 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { .unwrap(), group_indices, &mut self.sums, + &mut self.has_nulls, self.eval_mode, )?, _ => { @@ -342,28 +385,40 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { match emit_to { - // When emitting all groups, return all calculated sums and reset the internal state. EmitTo::All => { - // Create an Arrow array from the accumulated sums. - let result = Arc::new(Int64Array::from(self.sums.clone())) as ArrayRef; - // Reset the accumulator state for the next use. + // Create an Int64Array with nullability from has_nulls + let result = Arc::new(Int64Array::from_iter( + self.sums + .iter() + .zip(self.has_nulls.iter()) + .map(|(&sum, &is_null)| if is_null { None } else { Some(sum) }), + )) as ArrayRef; + self.sums.clear(); + self.has_nulls.clear(); Ok(result) } - // When emitting the first `n` groups, return the first `n` sums - // and retain the state for the remaining groups. EmitTo::First(n) => { - // Take the first `n` sums. - let result = Arc::new(Int64Array::from(self.sums.drain(..n).collect::>())) - as ArrayRef; + let result = Arc::new(Int64Array::from_iter( + self.sums + .drain(..n) + .zip(self.has_nulls.drain(..n)) + .map(|(sum, is_null)| if is_null { None } else { Some(sum) }), + )) as ArrayRef; Ok(result) } } } fn state(&mut self, _emit_to: EmitTo) -> DFResult> { - let state_array = Arc::new(Int64Array::from(self.sums.clone())); - Ok(vec![state_array]) + if self.eval_mode == EvalMode::Try { + Ok(vec![ + Arc::new(Int64Array::from(self.sums.clone())), + Arc::new(BooleanArray::from(self.has_nulls.clone())), + ]) + } else { + Ok(vec![Arc::new(Int64Array::from(self.sums.clone()))]) + } } fn merge_batch( @@ -376,27 +431,27 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { assert!(opt_filter.is_none(), "opt_filter is not supported yet"); let values = values[0].as_primitive::(); let data = values.values(); - self.sums.resize(total_num_groups, Some(0)); + self.sums.resize(total_num_groups, 0); + self.has_nulls.resize(total_num_groups, false); let iter = group_indices.iter().zip(data.iter()); for (&group_index, &value) in iter { - if self.sums[group_index].is_some() { + if !self.has_nulls[group_index] { match self.eval_mode { EvalMode::Legacy => { - self.sums[group_index] = - Some(self.sums[group_index].unwrap().add_wrapping(value)); + self.sums[group_index] = self.sums[group_index].add_wrapping(value); } EvalMode::Ansi | EvalMode::Try => { - match self.sums[group_index].unwrap().add_checked(value) { - Ok(v) => self.sums[group_index] = Some(v), + match self.sums[group_index].add_checked(value) { + Ok(v) => self.sums[group_index] = v, Err(_e) => { if self.eval_mode == EvalMode::Ansi { return Err(DataFusionError::from(arithmetic_overflow_error( "integer", ))); } else { - self.sums[group_index] = None + self.has_nulls[group_index] = true } } }; diff --git 
a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 4a914c25e2..e9f9255f1d 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3003,8 +3003,21 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("ANSI support for sum - null test group by") { + withParquetTable(Seq((null.asInstanceOf[Long], "a"), (null.asInstanceOf[Long], "b")), "tbl") { + val res = sql("SELECT sum(_1) FROM tbl group by _2") + checkSparkAnswerAndOperator(res) + } + } + + test("ANSI support for sum - null test") { + withParquetTable(Seq((null.asInstanceOf[Long], "a"), (null.asInstanceOf[Long], "b")), "tbl") { + val res = sql("SELECT sum(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + } + test("ANSI support - SUM function") { -// Test long overflow Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { // Test long overflow @@ -3040,7 +3053,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val res = sql("SELECT SUM(_1) FROM tbl") checkSparkAnswerAndOperator(res) } - // Test Int SUM (should not overflow) + // Test Short SUM (should not overflow) withParquetTable( Seq((Short.MaxValue, 1.toShort), (Short.MaxValue, 1.toShort), (100.toShort, 1.toShort)), "tbl") { @@ -3056,63 +3069,14 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkAnswerAndOperator(res) } - // Test Long overflow with NULL values - withParquetTable(Seq((Long.MaxValue, 1L), (null, 1L), (100L, 1L)), "tbl") { - val res = sql("SELECT SUM(_1) FROM tbl") - if (ansiEnabled) { - checkSparkMaybeThrows(res) match { - case (Some(sparkExc), Some(cometExc)) => - assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - case _ => - fail("Exception should be thrown for Long overflow with NULLs in ANSI mode") - } - } else { - checkSparkAnswerAndOperator(res) - } - } - - // Test Long underflow with NULL values - withParquetTable(Seq((Long.MinValue, 1L), (null, 1L), (-100L, 1L)), "tbl") { - val res = sql("SELECT SUM(_1) FROM tbl") - if (ansiEnabled) { - checkSparkMaybeThrows(res) match { - case (Some(sparkExc), Some(cometExc)) => - assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - case _ => - fail("Exception should be thrown for Long underflow with NULLs in ANSI mode") - } - } else { - checkSparkAnswerAndOperator(res) - } - } - - // Test only NULL inputs - withParquetTable(Seq((null, 1L), (null, 1L), (null, 1L)), "tbl") { - val res = sql("SELECT SUM(_1) FROM tbl") - checkSparkAnswerAndOperator(res) - } - - // Test with mix of valid numbers and NULLs (no overflow) - withParquetTable(Seq((100L, 1L), (null, 1L), (200L, 1L)), "tbl") { - val res = sql("SELECT SUM(_1) FROM tbl") - checkSparkAnswerAndOperator(res) - } - - // Test with mix of valid numbers and NULLs (no overflow) - overflow - withParquetTable(Seq((Long.MaxValue, 1L), (null, 1L), (200L, 1L)), "tbl") { - val res = sql("SELECT SUM(_1) FROM tbl") - checkSparkAnswerAndOperator(res) - } } } } - test("SUM overflow - GROUP BY") { + test("ANSI support for SUM - GROUP BY") { + // Test Long overflow with GROUP BY to test GroupAccumulator with ANSI support Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key 
-> ansiEnabled.toString) { - // Test Long overflow with GROUP BY to test GroupAccumulator withParquetTable( Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (200L, 2)), "tbl") { @@ -3166,56 +3130,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") checkSparkAnswerAndOperator(res) } - - // Test Long overflow with GROUP BY and NULL values - withParquetTable( - Seq((Long.MaxValue, 1), (null, 1), (100L, 1), (200L, 2), (null, 2)), - "tbl") { - val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") - if (ansiEnabled) { - checkSparkMaybeThrows(res) match { - case (Some(sparkExc), Some(cometExc)) => - assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - case _ => - fail( - "Exception should be thrown for Long overflow with GROUP BY and NULLs in ANSI mode") - } - } else { - checkSparkAnswerAndOperator(res) - } - } - - // Test Long underflow with GROUP BY and NULL values - withParquetTable( - Seq((Long.MinValue, 1), (null, 1), (-100L, 1), (-200L, 2), (null, 2)), - "tbl") { - val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") - if (ansiEnabled) { - checkSparkMaybeThrows(res) match { - case (Some(sparkExc), Some(cometExc)) => - assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - case _ => - fail( - "Exception should be thrown for Long underflow with GROUP BY and NULLs in ANSI mode") - } - } else { - checkSparkAnswerAndOperator(res) - } - } - - // Test GROUP BY with only NULL values - withParquetTable(Seq((null, 1), (null, 1), (null, 2), (null, 2)), "tbl") { - val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") - checkSparkAnswerAndOperator(res) - } - - // Test GROUP BY with mix of valid values and NULLs (no overflow) - withParquetTable(Seq((100L, 1), (null, 1), (200L, 1), (300L, 2), (null, 2)), "tbl") { - val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") - checkSparkAnswerAndOperator(res) - } } } } @@ -3224,14 +3138,14 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { // Test Long overflow with GROUP BY - some groups overflow, some don't withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (200L, 2), (300L, 2)), "tbl") { val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) - // Group 1 should return NULL (overflow), Group 2 should return 500 + // first group should return NULL (overflow) and group 2 should return 500 checkSparkAnswerAndOperator(res) } // Test Long underflow with GROUP BY withParquetTable(Seq((Long.MinValue, 1), (-100L, 1), (-200L, 2), (-300L, 2)), "tbl") { val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) - // Group 1 should return NULL (underflow), Group 2 should return -500 + // first group should return NULL (underflow), second group should return neg 500 checkSparkAnswerAndOperator(res) } @@ -3242,19 +3156,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkAnswerAndOperator(res) } - // Test with NULL values mixed with overflow - withParquetTable(Seq((Long.MaxValue, 1), (null, 1), (100L, 2), (null, 2)), "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) - // Group 1 should return NULL (overflow), Group 2 should return 100 - checkSparkAnswerAndOperator(res) - } - - // Test Int with GROUP BY (should NOT overflow since accumulator is i64) - 
withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 2), (200, 2)), "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) - checkSparkAnswerAndOperator(res) - } - // Test Short with GROUP BY (should NOT overflow) withParquetTable( Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)), From d27dbe1e205839fa291a6951fad702918f9d4d3a Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 4 Nov 2025 18:46:13 -0800 Subject: [PATCH 08/18] sum_ansi_mode_checks_fix_tests --- native/spark-expr/src/agg_funcs/sum_int.rs | 363 ++++++++++++------ .../apache/comet/CometExpressionSuite.scala | 75 +++- 2 files changed, 298 insertions(+), 140 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index add96eb67f..e007d69abe 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -17,8 +17,8 @@ use crate::{arithmetic_overflow_error, EvalMode}; use arrow::array::{ - cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray, - Int64Array, PrimitiveArray, + cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, + BooleanArray, Int64Array, PrimitiveArray, }; use arrow::datatypes::{ ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type, @@ -78,7 +78,7 @@ impl AggregateUDFImpl for SumInteger { if self.eval_mode == EvalMode::Try { Ok(vec![ Arc::new(Field::new("sum", DataType::Int64, true)), - Arc::new(Field::new("is_null", DataType::Boolean, false)), + Arc::new(Field::new("has_all_nulls", DataType::Boolean, false)), ]) } else { Ok(vec![Arc::new(Field::new("sum", DataType::Int64, true))]) @@ -99,17 +99,26 @@ impl AggregateUDFImpl for SumInteger { #[derive(Debug)] struct SumIntegerAccumulator { - sum: i64, + sum: Option, eval_mode: EvalMode, - is_null: bool, + has_all_nulls: bool, } impl SumIntegerAccumulator { fn new(eval_mode: EvalMode) -> Self { - Self { - sum: 0, - eval_mode, - is_null: false, + if eval_mode == EvalMode::Try { + Self { + // Try mode starts with 0 (because if this is init to None we cant say if it is none due to all nulls or due to an overflow + sum: Some(0), + has_all_nulls: true, // true = no non-null values yet + eval_mode, + } + } else { + Self { + sum: None, // Legacy/ANSI start with None + has_all_nulls: false, // not used for Legacy/ANSI + eval_mode, + } } } } @@ -121,13 +130,12 @@ impl Accumulator for SumIntegerAccumulator { int_array: &PrimitiveArray, eval_mode: EvalMode, mut sum: i64, - is_null: bool, - ) -> Result<(i64, bool), DataFusionError> + ) -> Result, DataFusionError> where T: ArrowPrimitiveType, { for i in 0..int_array.len() { - if !is_null && !int_array.is_null(i) { + if !int_array.is_null(i) { let v = int_array.value(i).to_i64().ok_or_else(|| { DataFusionError::Internal("Failed to convert value to i64".to_string()) })?; @@ -144,7 +152,7 @@ impl Accumulator for SumIntegerAccumulator { "integer", ))) } else { - return Ok((sum, true)); + return Ok(None); }; } }; @@ -152,70 +160,67 @@ impl Accumulator for SumIntegerAccumulator { } } } - Ok((sum, false)) + Ok(Some(sum)) } - if self.is_null { + if self.eval_mode == EvalMode::Try && !self.has_all_nulls && self.sum.is_none() { + // we saw an overflow earlier (Try eval mode). 
Skip processing + return Ok(()); + } + let values = &values[0]; + if values.len() == values.null_count() { Ok(()) } else { - let values = &values[0]; - if values.len() == values.null_count() { - Ok(()) - } else { - let (sum, is_overflow) = match values.data_type() { - DataType::Int64 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - self.is_null, - )?, - DataType::Int32 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - self.is_null, - )?, - DataType::Int16 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - self.is_null, - )?, - DataType::Int8 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), - self.eval_mode, - self.sum, - self.is_null, - )?, - _ => { - panic!("Unsupported data type") - } - }; - - self.sum = sum; - self.is_null = is_overflow; - Ok(()) - } + // No nulls so there should be a non-null sum. (null incase overflow in Try eval) + let running_sum = self.sum.unwrap_or(0); + let sum = match values.data_type() { + DataType::Int64 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + running_sum, + )?, + DataType::Int32 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + running_sum, + )?, + DataType::Int16 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + running_sum, + )?, + DataType::Int8 => update_sum_internal( + values + .as_any() + .downcast_ref::>() + .unwrap(), + self.eval_mode, + running_sum, + )?, + _ => { + panic!("Unsupported data type") + } + }; + self.sum = sum; + self.has_all_nulls = false; + Ok(()) } } fn evaluate(&mut self) -> DFResult { - if self.is_null { + if self.has_all_nulls { Ok(ScalarValue::Int64(None)) } else { - Ok(ScalarValue::Int64(Some(self.sum))) + Ok(ScalarValue::Int64(self.sum)) } } @@ -226,35 +231,65 @@ impl Accumulator for SumIntegerAccumulator { fn state(&mut self) -> DFResult> { if self.eval_mode == EvalMode::Try { Ok(vec![ - ScalarValue::Int64(Some(self.sum)), - ScalarValue::Boolean(Some(self.is_null)), + ScalarValue::Int64(self.sum), + ScalarValue::Boolean(Some(self.has_all_nulls)), ]) } else { - Ok(vec![ScalarValue::Int64(Some(self.sum))]) + Ok(vec![ScalarValue::Int64(self.sum)]) } } fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { - if self.is_null { - return Ok(()); - } - let that_sum = states[0].as_primitive::(); + let that_sum_array = states[0].as_primitive::(); + let that_sum = if that_sum_array.is_null(0) { + None + } else { + Some(that_sum_array.value(0)) + }; - if self.eval_mode == EvalMode::Try && states[1].as_boolean().value(0) { - return Ok(()); + // Check for overflow for early termination + if self.eval_mode == EvalMode::Try { + let that_has_all_nulls = states[1].as_boolean().value(0); + let that_overflowed = !that_has_all_nulls && that_sum.is_none(); + let this_overflowed = !self.has_all_nulls && self.sum.is_none(); + if that_overflowed || this_overflowed { + self.sum = None; + self.has_all_nulls = false; + return Ok(()); + } + self.has_all_nulls = self.has_all_nulls && that_has_all_nulls; + if that_has_all_nulls { + return Ok(()); + } + if self.has_all_nulls { + self.sum = that_sum; + return Ok(()); + } + } else { + if that_sum.is_none() { + return Ok(()); + } + if self.sum.is_none() { + self.sum = that_sum; + return Ok(()); + } } + let left = self.sum.unwrap(); + let 
right = that_sum.unwrap(); + match self.eval_mode { EvalMode::Legacy => { - self.sum = self.sum.add_wrapping(that_sum.value(0)); + self.sum = Some(left.add_wrapping(right)); } - EvalMode::Ansi | EvalMode::Try => match self.sum.add_checked(that_sum.value(0)) { - Ok(v) => self.sum = v, - Err(_e) => { + EvalMode::Ansi | EvalMode::Try => match left.add_checked(right) { + Ok(v) => self.sum = Some(v), + Err(_) => { if self.eval_mode == EvalMode::Ansi { return Err(DataFusionError::from(arithmetic_overflow_error("integer"))); } else { - self.is_null = true; + self.sum = None; + self.has_all_nulls = false; } } }, @@ -264,8 +299,8 @@ impl Accumulator for SumIntegerAccumulator { } struct SumIntGroupsAccumulator { - sums: Vec, - has_nulls: Vec, + sums: Vec>, + has_all_nulls: Vec, eval_mode: EvalMode, } @@ -274,7 +309,7 @@ impl SumIntGroupsAccumulator { Self { sums: Vec::new(), eval_mode, - has_nulls: Vec::new(), + has_all_nulls: Vec::new(), } } } @@ -290,8 +325,8 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { fn update_groups_sum_internal( int_array: &PrimitiveArray, group_indices: &[usize], - sums: &mut [i64], - has_nulls: &mut [bool], + sums: &mut [Option], + has_all_nulls: &mut [bool], eval_mode: EvalMode, ) -> DFResult<()> where @@ -299,29 +334,40 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { T::Native: ArrowNativeType, { for (i, &group_index) in group_indices.iter().enumerate() { - if !has_nulls[group_index] && !int_array.is_null(i) { + if !int_array.is_null(i) { + // there is an overflow in prev group in try eval . Skip processing + if eval_mode == EvalMode::Try + && !has_all_nulls[group_index] + && sums[group_index].is_none() + { + continue; + } let v = int_array.value(i).to_i64().ok_or_else(|| { DataFusionError::Internal("Failed to convert value to i64".to_string()) })?; match eval_mode { EvalMode::Legacy => { - sums[group_index] = sums[group_index].add_wrapping(v); + sums[group_index] = + Some(sums[group_index].unwrap_or(0).add_wrapping(v)); } EvalMode::Ansi | EvalMode::Try => { - match sums[group_index].add_checked(v) { - Ok(new_sum) => sums[group_index] = new_sum, + match sums[group_index].unwrap_or(0).add_checked(v) { + Ok(new_sum) => { + sums[group_index] = Some(new_sum); + } Err(_) => { if eval_mode == EvalMode::Ansi { return Err(DataFusionError::from( arithmetic_overflow_error("integer"), )); } else { - has_nulls[group_index] = true + sums[group_index] = None; } } }; } } + has_all_nulls[group_index] = false } } Ok(()) @@ -329,8 +375,13 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { assert!(opt_filter.is_none(), "opt_filter is not supported yet"); let values = &values[0]; - self.sums.resize(total_num_groups, 0); - self.has_nulls.resize(total_num_groups, false); + if self.eval_mode == EvalMode::Try { + self.sums.resize(total_num_groups, Some(0)); + self.has_all_nulls.resize(total_num_groups, true); + } else { + self.sums.resize(total_num_groups, None); + self.has_all_nulls.resize(total_num_groups, false); + } match values.data_type() { DataType::Int64 => update_groups_sum_internal( @@ -340,7 +391,7 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { .unwrap(), group_indices, &mut self.sums, - &mut self.has_nulls, + &mut self.has_all_nulls, self.eval_mode, )?, DataType::Int32 => update_groups_sum_internal( @@ -350,7 +401,7 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { .unwrap(), group_indices, &mut self.sums, - &mut self.has_nulls, + &mut self.has_all_nulls, self.eval_mode, )?, DataType::Int16 => update_groups_sum_internal( @@ -360,7 +411,7 @@ impl 
GroupsAccumulator for SumIntGroupsAccumulator { .unwrap(), group_indices, &mut self.sums, - &mut self.has_nulls, + &mut self.has_all_nulls, self.eval_mode, )?, DataType::Int8 => update_groups_sum_internal( @@ -370,7 +421,7 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { .unwrap(), group_indices, &mut self.sums, - &mut self.has_nulls, + &mut self.has_all_nulls, self.eval_mode, )?, _ => { @@ -390,20 +441,20 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { let result = Arc::new(Int64Array::from_iter( self.sums .iter() - .zip(self.has_nulls.iter()) - .map(|(&sum, &is_null)| if is_null { None } else { Some(sum) }), + .zip(self.has_all_nulls.iter()) + .map(|(&sum, &is_null)| if is_null { None } else { sum }), )) as ArrayRef; self.sums.clear(); - self.has_nulls.clear(); + self.has_all_nulls.clear(); Ok(result) } EmitTo::First(n) => { let result = Arc::new(Int64Array::from_iter( self.sums .drain(..n) - .zip(self.has_nulls.drain(..n)) - .map(|(sum, is_null)| if is_null { None } else { Some(sum) }), + .zip(self.has_all_nulls.drain(..n)) + .map(|(sum, is_null)| if is_null { None } else { sum }), )) as ArrayRef; Ok(result) } @@ -414,7 +465,7 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { if self.eval_mode == EvalMode::Try { Ok(vec![ Arc::new(Int64Array::from(self.sums.clone())), - Arc::new(BooleanArray::from(self.has_nulls.clone())), + Arc::new(BooleanArray::from(self.has_all_nulls.clone())), ]) } else { Ok(vec![Arc::new(Int64Array::from(self.sums.clone()))]) @@ -429,32 +480,88 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { total_num_groups: usize, ) -> DFResult<()> { assert!(opt_filter.is_none(), "opt_filter is not supported yet"); - let values = values[0].as_primitive::(); - let data = values.values(); - self.sums.resize(total_num_groups, 0); - self.has_nulls.resize(total_num_groups, false); - - let iter = group_indices.iter().zip(data.iter()); - - for (&group_index, &value) in iter { - if !self.has_nulls[group_index] { - match self.eval_mode { - EvalMode::Legacy => { - self.sums[group_index] = self.sums[group_index].add_wrapping(value); - } - EvalMode::Ansi | EvalMode::Try => { - match self.sums[group_index].add_checked(value) { - Ok(v) => self.sums[group_index] = v, - Err(_e) => { - if self.eval_mode == EvalMode::Ansi { - return Err(DataFusionError::from(arithmetic_overflow_error( - "integer", - ))); - } else { - self.has_nulls[group_index] = true - } + + // Extract incoming sums array + let that_sums = values[0].as_primitive::(); + + if self.eval_mode == EvalMode::Try { + self.sums.resize(total_num_groups, Some(0)); + self.has_all_nulls.resize(total_num_groups, true); + } else { + self.sums.resize(total_num_groups, None); + self.has_all_nulls.resize(total_num_groups, false); + } + + let that_sums_is_all_nulls = if self.eval_mode == EvalMode::Try { + Some(values[1].as_boolean()) + } else { + None + }; + + for (idx, &group_index) in group_indices.iter().enumerate() { + // Extract incoming sum value (handle nulls) + let that_sum = if that_sums.is_null(idx) { + None + } else { + Some(that_sums.value(idx)) + }; + + if self.eval_mode == EvalMode::Try { + let that_has_all_nulls = that_sums_is_all_nulls.unwrap().value(idx); + + let that_overflowed = !that_has_all_nulls && that_sum.is_none(); + let this_overflowed = + !self.has_all_nulls[group_index] && self.sums[group_index].is_none(); + + if that_overflowed || this_overflowed { + self.sums[group_index] = None; + self.has_all_nulls[group_index] = false; + continue; + } + + self.has_all_nulls[group_index] = + 
self.has_all_nulls[group_index] && that_has_all_nulls; + + if that_has_all_nulls { + continue; + } + + if self.has_all_nulls[group_index] { + self.sums[group_index] = that_sum; + continue; + } + } else { + if that_sum.is_none() { + continue; + } + if self.sums[group_index].is_none() { + self.sums[group_index] = that_sum; + continue; + } + } + + // Both sides have non-null. Update sums now + let left = self.sums[group_index].unwrap(); + let right = that_sum.unwrap(); + + match self.eval_mode { + EvalMode::Legacy => { + self.sums[group_index] = Some(left.add_wrapping(right)); + } + EvalMode::Ansi | EvalMode::Try => { + match left.add_checked(right) { + Ok(v) => self.sums[group_index] = Some(v), + Err(_) => { + if self.eval_mode == EvalMode::Ansi { + return Err(DataFusionError::from(arithmetic_overflow_error( + "integer", + ))); + } else { + // overflow . update flag accordingly + self.sums[group_index] = None; + self.has_all_nulls[group_index] = false; } - }; + } } } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index e9f9255f1d..2b2b33bbff 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3003,17 +3003,69 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("ANSI support for sum - null test group by") { - withParquetTable(Seq((null.asInstanceOf[Long], "a"), (null.asInstanceOf[Long], "b")), "tbl") { - val res = sql("SELECT sum(_1) FROM tbl group by _2") - checkSparkAnswerAndOperator(res) + test("ANSI support for sum - null test") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")), + "null_tbl") { + val res = sql("SELECT sum(_1) FROM null_tbl") + checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row(null))) + } + } } } - test("ANSI support for sum - null test") { - withParquetTable(Seq((null.asInstanceOf[Long], "a"), (null.asInstanceOf[Long], "b")), "tbl") { - val res = sql("SELECT sum(_1) FROM tbl") - checkSparkAnswerAndOperator(res) + test("ANSI support for try_sum - null test") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")), + "null_tbl") { + val res = sql("SELECT try_sum(_1) FROM null_tbl") + checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row(null))) + } + } + } + } + + test("ANSI support for sum - null test (group by)") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq( + (null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b")), + "tbl") { + val res = sql("SELECT _2, sum(_1) FROM tbl group by 1") + checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row("a", null), Row("b", null))) + } + } + } + } + + test("ANSI support for try_sum - null test (group by)") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq( + 
(null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b")), + "tbl") { + val res = sql("SELECT _2, sum(_1) FROM tbl group by 1") + checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row("a", null), Row("b", null))) + } + } } } @@ -3034,7 +3086,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkAnswerAndOperator(res) } } - // Test long underflow + // Test long underflow withParquetTable(Seq((Long.MinValue, 1L), (-100L, 1L)), "tbl") { val res = sql("SELECT SUM(_1) FROM tbl") if (ansiEnabled) { @@ -3068,7 +3120,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val res = sql("SELECT SUM(_1) FROM tbl") checkSparkAnswerAndOperator(res) } - } } } @@ -3135,7 +3186,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("try_sum overflow - with GROUP BY") { - // Test Long overflow with GROUP BY - some groups overflow, some don't + // Test Long overflow with GROUP BY - some groups overflow while some don't withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (200L, 2), (300L, 2)), "tbl") { val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) // first group should return NULL (overflow) and group 2 should return 500 @@ -3164,7 +3215,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkAnswerAndOperator(res) } - // Test Byte with GROUP BY (should NOT overflow) + // Test Byte with GROUP BY (no overflow) withParquetTable( Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)), "tbl") { From 534c111fff78951d9c919dd260c8542a35ef9c02 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 4 Nov 2025 18:49:26 -0800 Subject: [PATCH 09/18] sum_ansi_mode_checks_fix_tests --- native/spark-expr/src/agg_funcs/sum_int.rs | 99 ---------------------- 1 file changed, 99 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index e007d69abe..2c16901b31 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -573,102 +573,3 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { size_of_val(self) } } - -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::builder::StringBuilder; - use arrow::array::{Int64Builder, RecordBatch}; - use arrow::datatypes::DataType::Int64; - use arrow::datatypes::*; - use datafusion::common::Result; - use datafusion::datasource::memory::MemorySourceConfig; - use datafusion::datasource::source::DataSourceExec; - use datafusion::execution::TaskContext; - use datafusion::logical_expr::AggregateUDF; - use datafusion::physical_expr::aggregate::AggregateExprBuilder; - use datafusion::physical_expr::expressions::Column; - use datafusion::physical_expr::PhysicalExpr; - use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; - use datafusion::physical_plan::ExecutionPlan; - use futures::StreamExt; - - #[test] - fn invalid_data_type() { - assert!(SumInteger::try_new(DataType::Date32, EvalMode::Legacy).is_err()); - } - - #[tokio::test] - async fn sum_no_overflow() -> Result<()> { - let num_rows = 8192; - let batch = create_record_batch(num_rows); - let mut batches = Vec::new(); - for _ in 0..10 { - batches.push(batch.clone()); - } - let partitions = &[batches]; - let c0: 
Arc = Arc::new(Column::new("c0", 0)); - let c1: Arc = Arc::new(Column::new("c1", 1)); - - let data_type = Int64; - let schema = Arc::clone(&partitions[0][0].schema()); - let scan: Arc = Arc::new(DataSourceExec::new(Arc::new( - MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None)?, - ))); - - let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumInteger::try_new( - data_type.clone(), - EvalMode::Legacy, - )?)); - - let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) - .schema(Arc::clone(&schema)) - .alias("sum") - .with_ignore_nulls(false) - .with_distinct(false) - .build()?; - - let aggregate = Arc::new(AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]), - vec![aggr_expr.into()], - vec![None], // no filter expressions - scan, - Arc::clone(&schema), - )?); - - let mut stream = aggregate - .execute(0, Arc::new(TaskContext::default())) - .unwrap(); - while let Some(batch) = stream.next().await { - let _batch = batch?; - } - - Ok(()) - } - - fn create_record_batch(num_rows: usize) -> RecordBatch { - let mut int_builder = Int64Builder::with_capacity(num_rows); - let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); - for i in 0..num_rows { - int_builder.append_value(i as i64); - string_builder.append_value(format!("this is string #{}", i % 1024)); - } - let int_array = Arc::new(int_builder.finish()); - let string_array = Arc::new(string_builder.finish()); - - let mut fields = vec![]; - let mut columns: Vec = vec![]; - - // string column - fields.push(Field::new("c0", DataType::Utf8, false)); - columns.push(string_array); - - // decimal column - fields.push(Field::new("c1", DataType::Int64, false)); - columns.push(int_array); - - let schema = Schema::new(fields); - RecordBatch::try_new(Arc::new(schema), columns).unwrap() - } -} From 910a13fdda0e6cf5482016cb231e8876ea107b85 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 5 Nov 2025 10:55:17 -0800 Subject: [PATCH 10/18] sum_ansi_mode_checks_fix_tests --- native/spark-expr/src/agg_funcs/sum_int.rs | 4 ++-- .../test/scala/org/apache/comet/CometExpressionSuite.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index 2c16901b31..8136acb60d 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -17,8 +17,8 @@ use crate::{arithmetic_overflow_error, EvalMode}; use arrow::array::{ - cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, - BooleanArray, Int64Array, PrimitiveArray, + cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray, + Int64Array, PrimitiveArray, }; use arrow::datatypes::{ ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type, diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 2b2b33bbff..9e54b7d093 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3044,7 +3044,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { "tbl") { val res = sql("SELECT _2, sum(_1) FROM tbl group by 1") checkSparkAnswerAndOperator(res) - assert(res.collect() === Array(Row("a", null), Row("b", null))) + assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), 
Row("b", null))) } } } @@ -3063,7 +3063,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { "tbl") { val res = sql("SELECT _2, sum(_1) FROM tbl group by 1") checkSparkAnswerAndOperator(res) - assert(res.collect() === Array(Row("a", null), Row("b", null))) + assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) } } } From 8fc3fbdb057fd42c6fe2f14f8c29dcc2f312a38b Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 5 Nov 2025 11:44:47 -0800 Subject: [PATCH 11/18] sum_ansi_mode_checks_fix_tests --- native/spark-expr/src/agg_funcs/sum_int.rs | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index 8136acb60d..5e031156fa 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -33,9 +33,7 @@ use std::{any::Any, sync::Arc}; #[derive(Debug, PartialEq, Eq, Hash)] pub struct SumInteger { - /// Aggregate function signature signature: Signature, - /// eval mode : ANSI, Legacy, Try eval_mode: EvalMode, } @@ -110,13 +108,13 @@ impl SumIntegerAccumulator { Self { // Try mode starts with 0 (because if this is init to None we cant say if it is none due to all nulls or due to an overflow sum: Some(0), - has_all_nulls: true, // true = no non-null values yet + has_all_nulls: true, eval_mode, } } else { Self { - sum: None, // Legacy/ANSI start with None - has_all_nulls: false, // not used for Legacy/ANSI + sum: None, + has_all_nulls: false, eval_mode, } } @@ -125,7 +123,7 @@ impl SumIntegerAccumulator { impl Accumulator for SumIntegerAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { - // accumulator internal to add sum and return is_null: true if there is an overflow in Try Eval mode + // accumulator internal to add sum and return null sum (and has_nulls false) if there is an overflow in Try Eval mode fn update_sum_internal( int_array: &PrimitiveArray, eval_mode: EvalMode, @@ -171,7 +169,7 @@ impl Accumulator for SumIntegerAccumulator { if values.len() == values.null_count() { Ok(()) } else { - // No nulls so there should be a non-null sum. (null incase overflow in Try eval) + // No nulls so there should be a non-null sum / null incase overflow in Try eval let running_sum = self.sum.unwrap_or(0); let sum = match values.data_type() { DataType::Int64 => update_sum_internal( @@ -207,7 +205,7 @@ impl Accumulator for SumIntegerAccumulator { running_sum, )?, _ => { - panic!("Unsupported data type") + panic!("Unsupported data type {}", values.data_type()) } }; self.sum = sum; @@ -335,7 +333,7 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { { for (i, &group_index) in group_indices.iter().enumerate() { if !int_array.is_null(i) { - // there is an overflow in prev group in try eval . Skip processing + // there is an overflow in prev group in try eval. 
Skip processing if eval_mode == EvalMode::Try && !has_all_nulls[group_index] && sums[group_index].is_none() @@ -437,7 +435,6 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { match emit_to { EmitTo::All => { - // Create an Int64Array with nullability from has_nulls let result = Arc::new(Int64Array::from_iter( self.sums .iter() @@ -481,7 +478,6 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { ) -> DFResult<()> { assert!(opt_filter.is_none(), "opt_filter is not supported yet"); - // Extract incoming sums array let that_sums = values[0].as_primitive::(); if self.eval_mode == EvalMode::Try { @@ -499,7 +495,6 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { }; for (idx, &group_index) in group_indices.iter().enumerate() { - // Extract incoming sum value (handle nulls) let that_sum = if that_sums.is_null(idx) { None } else { @@ -557,7 +552,7 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { "integer", ))); } else { - // overflow . update flag accordingly + // overflow. update flag accordingly self.sums[group_index] = None; self.has_all_nulls[group_index] = false; } From 0914d5d077e23ae374de56298d3ac606db364426 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Thu, 6 Nov 2025 17:15:08 -0800 Subject: [PATCH 12/18] sum_ansi_mode_checks_fix_tests --- native/spark-expr/src/agg_funcs/sum_int.rs | 6 +++++- .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index 5e031156fa..e159fc59cf 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -27,7 +27,7 @@ use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::Volatility::Immutable; use datafusion::logical_expr::{ - Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, }; use std::{any::Any, sync::Arc}; @@ -93,6 +93,10 @@ impl AggregateUDFImpl for SumInteger { ) -> DFResult> { Ok(Box::new(SumIntGroupsAccumulator::new(self.eval_mode))) } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } } #[derive(Debug)] diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index edae4453a7..cc0efd85d5 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -55,6 +55,8 @@ import org.apache.comet.shims.CometExprShim */ object QueryPlanSerde extends Logging with CometExprShim { + val integerTypes = Seq(ByteType, ShortType, IntegerType, LongType) + /** * Mapping of Spark operator class to Comet operator handler. 
*/ @@ -419,7 +421,7 @@ object QueryPlanSerde extends Logging with CometExprShim { } case s: Sum => if (AggSerde.sumDataTypeSupported(s.dataType) && !s.dataType - .isInstanceOf[DecimalType]) { + .isInstanceOf[DecimalType] && !integerTypes.contains(s.dataType)) { Some(agg) } else { withInfo(windowExpr, s"datatype ${s.dataType} is not supported", expr) From aa01b8429e7f052e0c8362af71824adb38f810b1 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Thu, 6 Nov 2025 19:08:28 -0800 Subject: [PATCH 13/18] sum_ansi_mode_checks_fix_tests --- native/spark-expr/src/agg_funcs/sum_int.rs | 43 +++++++++---------- .../apache/comet/serde/QueryPlanSerde.scala | 2 +- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index e159fc59cf..92156b629e 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -138,9 +138,7 @@ impl Accumulator for SumIntegerAccumulator { { for i in 0..int_array.len() { if !int_array.is_null(i) { - let v = int_array.value(i).to_i64().ok_or_else(|| { - DataFusionError::Internal("Failed to convert value to i64".to_string()) - })?; + let v = int_array.value(i).to_i64().unwrap(); match eval_mode { EvalMode::Legacy => { sum = v.add_wrapping(sum); @@ -209,7 +207,10 @@ impl Accumulator for SumIntegerAccumulator { running_sum, )?, _ => { - panic!("Unsupported data type {}", values.data_type()) + return Err(DataFusionError::Internal(format!( + "unsupported data type: {:?}", + values.data_type() + ))); } }; self.sum = sum; @@ -227,7 +228,7 @@ impl Accumulator for SumIntegerAccumulator { } fn size(&self) -> usize { - size_of_val(self) + std::mem::size_of_val(self) } fn state(&mut self) -> DFResult> { @@ -314,6 +315,16 @@ impl SumIntGroupsAccumulator { has_all_nulls: Vec::new(), } } + + fn resize_helper(&mut self, total_num_groups: usize) { + if self.eval_mode == EvalMode::Try { + self.sums.resize(total_num_groups, Some(0)); + self.has_all_nulls.resize(total_num_groups, true); + } else { + self.sums.resize(total_num_groups, None); + self.has_all_nulls.resize(total_num_groups, false); + } + } } impl GroupsAccumulator for SumIntGroupsAccumulator { @@ -375,15 +386,9 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { Ok(()) } - assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet"); let values = &values[0]; - if self.eval_mode == EvalMode::Try { - self.sums.resize(total_num_groups, Some(0)); - self.has_all_nulls.resize(total_num_groups, true); - } else { - self.sums.resize(total_num_groups, None); - self.has_all_nulls.resize(total_num_groups, false); - } + self.resize_helper(total_num_groups); match values.data_type() { DataType::Int64 => update_groups_sum_internal( @@ -480,17 +485,11 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> DFResult<()> { - assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet"); let that_sums = values[0].as_primitive::(); - if self.eval_mode == EvalMode::Try { - self.sums.resize(total_num_groups, Some(0)); - self.has_all_nulls.resize(total_num_groups, true); - } else { - self.sums.resize(total_num_groups, None); - self.has_all_nulls.resize(total_num_groups, false); - } + self.resize_helper(total_num_groups); let that_sums_is_all_nulls = if self.eval_mode == EvalMode::Try { 
Some(values[1].as_boolean()) @@ -569,6 +568,6 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { } fn size(&self) -> usize { - size_of_val(self) + std::mem::size_of_val(self) } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index cc0efd85d5..59a38dc46c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -55,7 +55,7 @@ import org.apache.comet.shims.CometExprShim */ object QueryPlanSerde extends Logging with CometExprShim { - val integerTypes = Seq(ByteType, ShortType, IntegerType, LongType) + private val integerTypes = Seq(ByteType, ShortType, IntegerType, LongType) /** * Mapping of Spark operator class to Comet operator handler. From 008d85d5202314dabaacc8b05aa2ba3ec9024b0b Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Thu, 6 Nov 2025 19:24:50 -0800 Subject: [PATCH 14/18] sum_ansi_mode_checks_fix_tests --- native/.DS_Store | Bin 6148 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 native/.DS_Store diff --git a/native/.DS_Store b/native/.DS_Store deleted file mode 100644 index 0a67b21dd36a455d40c819eb79aa5509ed690c03..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKJ5EDE3>=db5ouDU+zU`}gH;qxkP8qAL;(sY2t@fR&XuDv{s<*{p`f5aW67Rf zuV4Ifrdwwh2Zp3eJQl*4+W zq7;w Date: Tue, 11 Nov 2025 14:00:56 -0800 Subject: [PATCH 15/18] sum_ansi_mode_checks_fix_tests_rebase_main --- native/spark-expr/src/agg_funcs/sum_int.rs | 4 ++-- .../scala/org/apache/comet/CometExpressionSuite.scala | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index 92156b629e..b1c0741fa7 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -110,7 +110,7 @@ impl SumIntegerAccumulator { fn new(eval_mode: EvalMode) -> Self { if eval_mode == EvalMode::Try { Self { - // Try mode starts with 0 (because if this is init to None we cant say if it is none due to all nulls or due to an overflow + // Try mode starts with 0 (because if this is init to None we cant say if it is none due to all nulls or due to an overflow) sum: Some(0), has_all_nulls: true, eval_mode, @@ -152,7 +152,7 @@ impl Accumulator for SumIntegerAccumulator { "integer", ))) } else { - return Ok(None); + Ok(None) }; } }; diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 9e54b7d093..42b567e185 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3061,7 +3061,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { (null.asInstanceOf[java.lang.Long], "b"), (null.asInstanceOf[java.lang.Long], "b")), "tbl") { - val res = sql("SELECT _2, sum(_1) FROM tbl group by 1") + val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1") checkSparkAnswerAndOperator(res) assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) } @@ -3076,7 +3076,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { withParquetTable(Seq((Long.MaxValue, 1L), (100L, 1L)), "tbl") { val res = sql("SELECT SUM(_1) FROM tbl") if (ansiEnabled) { - checkSparkMaybeThrows(res) match { + 
checkSparkAnswerMaybeThrows(res) match { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) @@ -3090,7 +3090,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { withParquetTable(Seq((Long.MinValue, 1L), (-100L, 1L)), "tbl") { val res = sql("SELECT SUM(_1) FROM tbl") if (ansiEnabled) { - checkSparkMaybeThrows(res) match { + checkSparkAnswerMaybeThrows(res) match { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) @@ -3133,7 +3133,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { "tbl") { val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2) if (ansiEnabled) { - checkSparkMaybeThrows(res) match { + checkSparkAnswerMaybeThrows(res) match { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) @@ -3150,7 +3150,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { "tbl") { val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") if (ansiEnabled) { - checkSparkMaybeThrows(res) match { + checkSparkAnswerMaybeThrows(res) match { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) From 607c9deec35f7ce6d9e7fc4e5b804ba5d59eeed2 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 12 Nov 2025 15:58:56 -0800 Subject: [PATCH 16/18] sum_ansi_mode_checks_fix_tests_rebase_main --- native/spark-expr/src/agg_funcs/sum_int.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index b1c0741fa7..9bb90d8dc9 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -467,14 +467,17 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { } } - fn state(&mut self, _emit_to: EmitTo) -> DFResult> { + fn state(&mut self, emit_to: EmitTo) -> DFResult> { + let sums = emit_to.take_needed(&mut self.sums); + if self.eval_mode == EvalMode::Try { + let has_all_nulls = emit_to.take_needed(&mut self.has_all_nulls); Ok(vec![ - Arc::new(Int64Array::from(self.sums.clone())), - Arc::new(BooleanArray::from(self.has_all_nulls.clone())), + Arc::new(Int64Array::from(sums)), + Arc::new(BooleanArray::from(has_all_nulls)), ]) } else { - Ok(vec![Arc::new(Int64Array::from(self.sums.clone()))]) + Ok(vec![Arc::new(Int64Array::from(sums))]) } } From 5f3464b61be10846f818f87d77087fc25f3c68cb Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Thu, 20 Nov 2025 14:39:51 -0800 Subject: [PATCH 17/18] sum_ansi_mode_checks_fix_tests_rebase_main --- native/spark-expr/src/agg_funcs/sum_int.rs | 64 +++-- .../apache/comet/serde/QueryPlanSerde.scala | 2 - .../org/apache/comet/serde/aggregates.scala | 11 - .../apache/comet/CometExpressionSuite.scala | 221 ----------------- .../comet/exec/CometAggregateSuite.scala | 223 +++++++++++++++++- .../sql/comet/CometPlanStabilitySuite.scala | 1 - 6 files changed, 249 insertions(+), 273 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index 9bb90d8dc9..7d36b012a4 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ 
b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -17,8 +17,8 @@ use crate::{arithmetic_overflow_error, EvalMode}; use arrow::array::{ - cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, BooleanArray, - Int64Array, PrimitiveArray, + as_primitive_array, cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, + BooleanArray, Int64Array, PrimitiveArray, }; use arrow::datatypes::{ ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type, @@ -138,7 +138,12 @@ impl Accumulator for SumIntegerAccumulator { { for i in 0..int_array.len() { if !int_array.is_null(i) { - let v = int_array.value(i).to_i64().unwrap(); + let v = int_array.value(i).to_i64().ok_or_else(|| { + DataFusionError::Internal(format!( + "Failed to convert value {:?} to i64", + int_array.value(i) + )) + })?; match eval_mode { EvalMode::Legacy => { sum = v.add_wrapping(sum); @@ -175,34 +180,22 @@ impl Accumulator for SumIntegerAccumulator { let running_sum = self.sum.unwrap_or(0); let sum = match values.data_type() { DataType::Int64 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), + as_primitive_array::(values), self.eval_mode, running_sum, )?, DataType::Int32 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), + as_primitive_array::(values), self.eval_mode, running_sum, )?, DataType::Int16 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), + as_primitive_array::(values), self.eval_mode, running_sum, )?, DataType::Int8 => update_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), + as_primitive_array::(values), self.eval_mode, running_sum, )?, @@ -278,8 +271,17 @@ impl Accumulator for SumIntegerAccumulator { } } - let left = self.sum.unwrap(); - let right = that_sum.unwrap(); + // safe to unwrap (since we checked nulls above) but handling error just in case state is corrupt + let left = self.sum.ok_or_else(|| { + DataFusionError::Internal( + "Invalid state in merging batch. Current batch's is None".to_string(), + ) + })?; + let right = that_sum.ok_or_else(|| { + DataFusionError::Internal( + "Invalid state in merging batch. 
Incoming sum to is None".to_string(), + ) + })?; match self.eval_mode { EvalMode::Legacy => { @@ -392,40 +394,28 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { match values.data_type() { DataType::Int64 => update_groups_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), + as_primitive_array::(values), group_indices, &mut self.sums, &mut self.has_all_nulls, self.eval_mode, )?, DataType::Int32 => update_groups_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), + as_primitive_array::(values), group_indices, &mut self.sums, &mut self.has_all_nulls, self.eval_mode, )?, DataType::Int16 => update_groups_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), + as_primitive_array::(values), group_indices, &mut self.sums, &mut self.has_all_nulls, self.eval_mode, )?, DataType::Int8 => update_groups_sum_internal( - values - .as_any() - .downcast_ref::>() - .unwrap(), + as_primitive_array::(values), group_indices, &mut self.sums, &mut self.has_all_nulls, diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index deca8121cf..54df2f1688 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -44,8 +44,6 @@ import org.apache.comet.shims.CometExprShim */ object QueryPlanSerde extends Logging with CometExprShim { - private val integerTypes = Seq(ByteType, ShortType, IntegerType, LongType) - private val arrayExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[ArrayAppend] -> CometArrayAppend, classOf[ArrayCompact] -> CometArrayCompact, diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 2e0e8ee34b..a05efaebbc 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -213,17 +213,6 @@ object CometAverage extends CometAggregateExpressionSerde[Average] { object CometSum extends CometAggregateExpressionSerde[Sum] { - override def getSupportLevel(sum: Sum): SupportLevel = { - sum.evalMode match { - case EvalMode.ANSI => - Incompatible(Some("ANSI mode is not supported")) - case EvalMode.TRY => - Incompatible(Some("TRY mode is not supported")) - case _ => - Compatible() - } - } - override def convert( aggExpr: AggregateExpression, sum: Sum, diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index c5f892e7e8..b0c718a2b6 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2998,227 +2998,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("ANSI support for sum - null test") { - Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - withParquetTable( - Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")), - "null_tbl") { - val res = sql("SELECT sum(_1) FROM null_tbl") - checkSparkAnswerAndOperator(res) - assert(res.collect() === Array(Row(null))) - } - } - } - } - - test("ANSI support for try_sum - null test") { - Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - withParquetTable( - 
Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")), - "null_tbl") { - val res = sql("SELECT try_sum(_1) FROM null_tbl") - checkSparkAnswerAndOperator(res) - assert(res.collect() === Array(Row(null))) - } - } - } - } - - test("ANSI support for sum - null test (group by)") { - Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - withParquetTable( - Seq( - (null.asInstanceOf[java.lang.Long], "a"), - (null.asInstanceOf[java.lang.Long], "a"), - (null.asInstanceOf[java.lang.Long], "b"), - (null.asInstanceOf[java.lang.Long], "b"), - (null.asInstanceOf[java.lang.Long], "b")), - "tbl") { - val res = sql("SELECT _2, sum(_1) FROM tbl group by 1") - checkSparkAnswerAndOperator(res) - assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) - } - } - } - } - - test("ANSI support for try_sum - null test (group by)") { - Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - withParquetTable( - Seq( - (null.asInstanceOf[java.lang.Long], "a"), - (null.asInstanceOf[java.lang.Long], "a"), - (null.asInstanceOf[java.lang.Long], "b"), - (null.asInstanceOf[java.lang.Long], "b"), - (null.asInstanceOf[java.lang.Long], "b")), - "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1") - checkSparkAnswerAndOperator(res) - assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) - } - } - } - } - - test("ANSI support - SUM function") { - Seq(true, false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - // Test long overflow - withParquetTable(Seq((Long.MaxValue, 1L), (100L, 1L)), "tbl") { - val res = sql("SELECT SUM(_1) FROM tbl") - if (ansiEnabled) { - checkSparkAnswerMaybeThrows(res) match { - case (Some(sparkExc), Some(cometExc)) => - assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - case _ => fail("Exception should be thrown for Long overflow in ANSI mode") - } - } else { - checkSparkAnswerAndOperator(res) - } - } - // Test long underflow - withParquetTable(Seq((Long.MinValue, 1L), (-100L, 1L)), "tbl") { - val res = sql("SELECT SUM(_1) FROM tbl") - if (ansiEnabled) { - checkSparkAnswerMaybeThrows(res) match { - case (Some(sparkExc), Some(cometExc)) => - assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - case _ => fail("Exception should be thrown for Long underflow in ANSI mode") - } - } else { - checkSparkAnswerAndOperator(res) - } - } - // Test Int SUM (should not overflow) - withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 1)), "tbl") { - val res = sql("SELECT SUM(_1) FROM tbl") - checkSparkAnswerAndOperator(res) - } - // Test Short SUM (should not overflow) - withParquetTable( - Seq((Short.MaxValue, 1.toShort), (Short.MaxValue, 1.toShort), (100.toShort, 1.toShort)), - "tbl") { - val res = sql("SELECT SUM(_1) FROM tbl") - checkSparkAnswerAndOperator(res) - } - - // Test Byte SUM (should not overflow) - withParquetTable( - Seq((Byte.MaxValue, 1.toByte), (Byte.MaxValue, 1.toByte), (10.toByte, 1.toByte)), - "tbl") { - val res = sql("SELECT SUM(_1) FROM tbl") - checkSparkAnswerAndOperator(res) - } - } - } - } - - test("ANSI support for SUM - GROUP BY") { - // Test Long overflow with GROUP BY to test GroupAccumulator with ANSI support - Seq(true, false).foreach { ansiEnabled => - 
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { - withParquetTable( - Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (200L, 2)), - "tbl") { - val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2) - if (ansiEnabled) { - checkSparkAnswerMaybeThrows(res) match { - case (Some(sparkExc), Some(cometExc)) => - assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - case _ => - fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode") - } - } else { - checkSparkAnswerAndOperator(res) - } - } - - withParquetTable( - Seq((Long.MinValue, 1), (-100L, 1), (Long.MinValue, 2), (-200L, 2)), - "tbl") { - val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") - if (ansiEnabled) { - checkSparkAnswerMaybeThrows(res) match { - case (Some(sparkExc), Some(cometExc)) => - assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - case _ => - fail("Exception should be thrown for Long underflow with GROUP BY in ANSI mode") - } - } else { - checkSparkAnswerAndOperator(res) - } - } - // Test Int with GROUP BY - withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 2), (200, 2)), "tbl") { - val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") - checkSparkAnswerAndOperator(res) - } - // Test Short with GROUP BY - withParquetTable( - Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)), - "tbl") { - val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") - checkSparkAnswerAndOperator(res) - } - - // Test Byte with GROUP BY - withParquetTable( - Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)), - "tbl") { - val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") - checkSparkAnswerAndOperator(res) - } - } - } - } - - test("try_sum overflow - with GROUP BY") { - // Test Long overflow with GROUP BY - some groups overflow while some don't - withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (200L, 2), (300L, 2)), "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) - // first group should return NULL (overflow) and group 2 should return 500 - checkSparkAnswerAndOperator(res) - } - - // Test Long underflow with GROUP BY - withParquetTable(Seq((Long.MinValue, 1), (-100L, 1), (-200L, 2), (-300L, 2)), "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) - // first group should return NULL (underflow), second group should return neg 500 - checkSparkAnswerAndOperator(res) - } - - // Test all groups overflow - withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (100L, 2)), "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) - // Both groups should return NULL - checkSparkAnswerAndOperator(res) - } - - // Test Short with GROUP BY (should NOT overflow) - withParquetTable( - Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)), - "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) - checkSparkAnswerAndOperator(res) - } - - // Test Byte with GROUP BY (no overflow) - withParquetTable( - Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)), - "tbl") { - val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) - checkSparkAnswerAndOperator(res) - } - } - test("test integral divide overflow 
for decimal") { if (isSpark40Plus) { Seq(true, false) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 7e577c5fda..06f0c09703 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.optimizer.EliminateSorts import org.apache.spark.sql.comet.CometHashAggregateExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.functions.{avg, count_distinct, sum} +import org.apache.spark.sql.functions.{avg, col, count_distinct, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataTypes, StructField, StructType} @@ -1471,6 +1471,227 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("ANSI support for sum - null test") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")), + "null_tbl") { + val res = sql("SELECT sum(_1) FROM null_tbl") + checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row(null))) + } + } + } + } + + test("ANSI support for try_sum - null test") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")), + "null_tbl") { + val res = sql("SELECT try_sum(_1) FROM null_tbl") + checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row(null))) + } + } + } + } + + test("ANSI support for sum - null test (group by)") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq( + (null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b")), + "tbl") { + val res = sql("SELECT _2, sum(_1) FROM tbl group by 1") + checkSparkAnswerAndOperator(res) + assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) + } + } + } + } + + test("ANSI support for try_sum - null test (group by)") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq( + (null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "a"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b"), + (null.asInstanceOf[java.lang.Long], "b")), + "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1") + checkSparkAnswerAndOperator(res) + assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null))) + } + } + } + } + + test("ANSI support - SUM function") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + // Test long overflow + withParquetTable(Seq((Long.MaxValue, 1L), (100L, 1L)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + if (ansiEnabled) { + checkSparkAnswerMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + 
assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => fail("Exception should be thrown for Long overflow in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + // Test long underflow + withParquetTable(Seq((Long.MinValue, 1L), (-100L, 1L)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + if (ansiEnabled) { + checkSparkAnswerMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => fail("Exception should be thrown for Long underflow in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + // Test Int SUM (should not overflow) + withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 1)), "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + // Test Short SUM (should not overflow) + withParquetTable( + Seq((Short.MaxValue, 1.toShort), (Short.MaxValue, 1.toShort), (100.toShort, 1.toShort)), + "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + + // Test Byte SUM (should not overflow) + withParquetTable( + Seq((Byte.MaxValue, 1.toByte), (Byte.MaxValue, 1.toByte), (10.toByte, 1.toByte)), + "tbl") { + val res = sql("SELECT SUM(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + } + } + } + + test("ANSI support for SUM - GROUP BY") { + // Test Long overflow with GROUP BY to test GroupAccumulator with ANSI support + Seq(true, false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { + withParquetTable( + Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (200L, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2) + if (ansiEnabled) { + checkSparkAnswerMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + + withParquetTable( + Seq((Long.MinValue, 1), (-100L, 1), (Long.MinValue, 2), (-200L, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + if (ansiEnabled) { + checkSparkAnswerMaybeThrows(res) match { + case (Some(sparkExc), Some(cometExc)) => + assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + case _ => + fail("Exception should be thrown for Long underflow with GROUP BY in ANSI mode") + } + } else { + checkSparkAnswerAndOperator(res) + } + } + // Test Int with GROUP BY + withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 2), (200, 2)), "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + // Test Short with GROUP BY + withParquetTable( + Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + + // Test Byte with GROUP BY + withParquetTable( + Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)), + "tbl") { + val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + } + } + } + + test("try_sum overflow - with GROUP 
BY") { + // Test Long overflow with GROUP BY - some groups overflow while some don't + withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (200L, 2), (300L, 2)), "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + // first group should return NULL (overflow) and group 2 should return 500 + checkSparkAnswerAndOperator(res) + } + + // Test Long underflow with GROUP BY + withParquetTable(Seq((Long.MinValue, 1), (-100L, 1), (-200L, 2), (-300L, 2)), "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + // first group should return NULL (underflow), second group should return neg 500 + checkSparkAnswerAndOperator(res) + } + + // Test all groups overflow + withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (100L, 2)), "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + // Both groups should return NULL + checkSparkAnswerAndOperator(res) + } + + // Test Short with GROUP BY (should NOT overflow) + withParquetTable( + Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)), + "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + checkSparkAnswerAndOperator(res) + } + + // Test Byte with GROUP BY (no overflow) + withParquetTable( + Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)), + "tbl") { + val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2")) + checkSparkAnswerAndOperator(res) + } + } + protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = { val df = sql(query) checkSparkAnswer(df) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala index 8f260e2ca8..d852d5f8b6 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala @@ -226,7 +226,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true", // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true", - CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true", // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64 CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") { From 3f9aff7d0ebbfe03f209e377a99a3d54e92a2c9a Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Thu, 20 Nov 2025 18:08:47 -0800 Subject: [PATCH 18/18] sum_ansi_mode_checks_fix_tests_rebase_main --- native/spark-expr/src/agg_funcs/sum_int.rs | 10 ++++------ .../spark/sql/comet/CometPlanStabilitySuite.scala | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs index 7d36b012a4..af56c55fdd 100644 --- a/native/spark-expr/src/agg_funcs/sum_int.rs +++ b/native/spark-expr/src/agg_funcs/sum_int.rs @@ -253,12 +253,12 @@ impl Accumulator for SumIntegerAccumulator { self.has_all_nulls = false; return Ok(()); } - self.has_all_nulls = self.has_all_nulls && that_has_all_nulls; if that_has_all_nulls { return Ok(()); } if self.has_all_nulls { self.sum = that_sum; + 
self.has_all_nulls = false; return Ok(()); } } else { @@ -274,12 +274,12 @@ impl Accumulator for SumIntegerAccumulator { // safe to unwrap (since we checked nulls above) but handling error just in case state is corrupt let left = self.sum.ok_or_else(|| { DataFusionError::Internal( - "Invalid state in merging batch. Current batch's is None".to_string(), + "Invalid state in merging batch. Current batch's sum is None".to_string(), ) })?; let right = that_sum.ok_or_else(|| { DataFusionError::Internal( - "Invalid state in merging batch. Incoming sum to is None".to_string(), + "Invalid state in merging batch. Incoming sum is None".to_string(), ) })?; @@ -510,15 +510,13 @@ impl GroupsAccumulator for SumIntGroupsAccumulator { continue; } - self.has_all_nulls[group_index] = - self.has_all_nulls[group_index] && that_has_all_nulls; - if that_has_all_nulls { continue; } if self.has_all_nulls[group_index] { self.sums[group_index] = that_sum; + self.has_all_nulls[group_index] = false; continue; } } else { diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala index d852d5f8b6..1728ce5b27 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkContext import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} import org.apache.spark.sql.TPCDSBase import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast} -import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum} +import org.apache.spark.sql.catalyst.expressions.aggregate.Average import org.apache.spark.sql.catalyst.util.resourceToString import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
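
Note on the overflow semantics exercised above: the accumulators in sum_int.rs distinguish three Spark eval modes, and the tests in CometAggregateSuite check each one. The following is a minimal, self-contained sketch (not code from this patch) of how the three modes treat 64-bit overflow during an integer SUM; the EvalMode enum and sum_i64 function here are local names chosen for illustration only, and the sketch deliberately ignores the null/empty-group bookkeeping that the real accumulator handles with its Option<i64> sum plus has_all_nulls flag.

// Illustrative sketch, assuming plain slices of i64 input.
#[derive(Debug, Clone, Copy, PartialEq)]
enum EvalMode {
    Legacy, // wrap around silently (Spark's default, non-ANSI behaviour)
    Ansi,   // raise an arithmetic overflow error
    Try,    // return NULL (None) for the overflowing aggregate
}

fn sum_i64(values: &[i64], mode: EvalMode) -> Result<Option<i64>, String> {
    let mut acc: i64 = 0;
    for &v in values {
        match mode {
            // Legacy: two's-complement wrap-around, no error surfaced.
            EvalMode::Legacy => acc = acc.wrapping_add(v),
            // ANSI: checked add, surface an overflow error to the caller.
            EvalMode::Ansi => {
                acc = acc
                    .checked_add(v)
                    .ok_or_else(|| "ARITHMETIC_OVERFLOW".to_string())?;
            }
            // Try: checked add, overflow collapses the whole result to NULL.
            EvalMode::Try => match acc.checked_add(v) {
                Some(s) => acc = s,
                None => return Ok(None),
            },
        }
    }
    Ok(Some(acc))
}

fn main() {
    let data = [i64::MAX, 100];
    // Legacy wraps, ANSI errors, Try yields NULL for the same input.
    assert_eq!(sum_i64(&data, EvalMode::Legacy), Ok(Some(i64::MAX.wrapping_add(100))));
    assert!(sum_i64(&data, EvalMode::Ansi).is_err());
    assert_eq!(sum_i64(&data, EvalMode::Try), Ok(None));
}

In the actual GroupsAccumulator, a per-group sum of None by itself is ambiguous, which is why the patch pairs it with a has_all_nulls flag: a group that saw only nulls keeps has_all_nulls = true, while a Try-mode overflow sets the sum to None and has_all_nulls to false, so the two cases emit NULL for the right reasons and merge correctly across partial aggregates.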