
Commit e89f553

Fix: Fix null handling in CometVector implementations (#2643)
1 parent 937cacd commit e89f553

File tree

6 files changed (+152, -71 lines)


common/src/main/java/org/apache/comet/vector/CometListVector.java

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ public CometListVector(
 
   @Override
   public ColumnarArray getArray(int i) {
+    if (isNullAt(i)) return null;
     int start = listVector.getOffsetBuffer().getInt(i * ListVector.OFFSET_WIDTH);
     int end = listVector.getOffsetBuffer().getInt((i + 1) * ListVector.OFFSET_WIDTH);

common/src/main/java/org/apache/comet/vector/CometMapVector.java

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ public CometMapVector(
 
   @Override
   public ColumnarMap getMap(int i) {
+    if (isNullAt(i)) return null;
     int start = mapVector.getOffsetBuffer().getInt(i * MapVector.OFFSET_WIDTH);
     int end = mapVector.getOffsetBuffer().getInt((i + 1) * MapVector.OFFSET_WIDTH);

common/src/main/java/org/apache/comet/vector/CometPlainVector.java

Lines changed: 2 additions & 0 deletions
@@ -123,6 +123,7 @@ public double getDouble(int rowId) {
 
   @Override
   public UTF8String getUTF8String(int rowId) {
+    if (isNullAt(rowId)) return null;
     if (!isBaseFixedWidthVector) {
       BaseVariableWidthVector varWidthVector = (BaseVariableWidthVector) valueVector;
       long offsetBufferAddress = varWidthVector.getOffsetBuffer().memoryAddress();
@@ -147,6 +148,7 @@ public UTF8String getUTF8String(int rowId) {
 
   @Override
   public byte[] getBinary(int rowId) {
+    if (isNullAt(rowId)) return null;
     int offset;
     int length;
     if (valueVector instanceof BaseVariableWidthVector) {

common/src/main/java/org/apache/comet/vector/CometVector.java

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ public boolean isFixedLength() {
 
   @Override
   public Decimal getDecimal(int i, int precision, int scale) {
+    if (isNullAt(i)) return null;
     if (!useDecimal128 && precision <= Decimal.MAX_INT_DIGITS() && type instanceof IntegerType) {
       return createDecimal(getInt(i), precision, scale);
     } else if (precision <= Decimal.MAX_LONG_DIGITS()) {

native/spark-expr/src/array_funcs/array_insert.rs

Lines changed: 120 additions & 71 deletions
@@ -16,11 +16,10 @@
 // under the License.
 
 use arrow::array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait};
-use arrow::datatypes::{DataType, Field, Schema};
+use arrow::datatypes::{DataType, Schema};
 use arrow::{
     array::{as_primitive_array, Capacities, MutableArrayData},
     buffer::{NullBuffer, OffsetBuffer},
-    datatypes::ArrowNativeType,
     record_batch::RecordBatch,
 };
 use datafusion::common::{
@@ -198,114 +197,131 @@ fn array_insert<O: OffsetSizeTrait>(
     pos_array: &ArrayRef,
     legacy_mode: bool,
 ) -> DataFusionResult<ColumnarValue> {
-    // The code is based on the implementation of the array_append from the Apache DataFusion
-    // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513
-    //
-    // This code is also based on the implementation of the array_insert from the Apache Spark
-    // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713
+    // Implementation aligned with Arrow's half-open offset ranges and Spark semantics.
 
     let values = list_array.values();
     let offsets = list_array.offsets();
     let values_data = values.to_data();
     let item_data = items_array.to_data();
+
+    // Estimate capacity (original values + inserted items upper bound)
     let new_capacity = Capacities::Array(values_data.len() + item_data.len());
 
     let mut mutable_values =
         MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);
 
-    let mut new_offsets = vec![O::usize_as(0)];
-    let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());
+    // New offsets and top-level list validity bitmap
+    let mut new_offsets = Vec::with_capacity(list_array.len() + 1);
+    new_offsets.push(O::usize_as(0));
+    let mut list_valid = Vec::<bool>::with_capacity(list_array.len());
 
-    let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions
+    // Spark supports only Int32 position indices
+    let pos_data: &Int32Array = as_primitive_array(&pos_array);
 
-    for (row_index, offset_window) in offsets.windows(2).enumerate() {
-        let pos = pos_data.values()[row_index];
-        let start = offset_window[0].as_usize();
-        let end = offset_window[1].as_usize();
-        let is_item_null = items_array.is_null(row_index);
+    for (row_index, window) in offsets.windows(2).enumerate() {
+        let start = window[0].as_usize();
+        let end = window[1].as_usize();
+        let len = end - start;
+
+        // Return null for the entire row when pos is null (consistent with Spark's behavior)
+        if pos_data.is_null(row_index) {
+            new_offsets.push(new_offsets[row_index]);
+            list_valid.push(false);
+            continue;
+        }
+        let pos = pos_data.value(row_index);
 
         if list_array.is_null(row_index) {
-            // In Spark if value of the array is NULL than nothing happens
-            mutable_values.extend_nulls(1);
-            new_offsets.push(new_offsets[row_index] + O::one());
-            new_nulls.push(false);
+            // Top-level list row is NULL: do not write any child values and do not advance offset
+            new_offsets.push(new_offsets[row_index]);
+            list_valid.push(false);
             continue;
         }
 
         if pos == 0 {
             return Err(DataFusionError::Internal(
-                "Position for array_insert should be greter or less than zero".to_string(),
+                "Position for array_insert should be greater or less than zero".to_string(),
             ));
         }
 
-        if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
-            let corrected_pos = if pos > 0 {
-                (pos - 1).as_usize()
-            } else {
-                end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
-            };
-            let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
-            if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
-                return Err(DataFusionError::Internal(format!(
-                    "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
-                )));
-            }
+        let final_len: usize;
 
-            if (start + corrected_pos) <= end {
-                mutable_values.extend(0, start, start + corrected_pos);
+        if pos > 0 {
+            // Positive index (1-based)
+            let pos1 = pos as usize;
+            if pos1 <= len + 1 {
+                // In-range insertion (including appending to end)
+                let corrected = pos1 - 1; // 0-based insertion point
+                mutable_values.extend(0, start, start + corrected);
                 mutable_values.extend(1, row_index, row_index + 1);
-                mutable_values.extend(0, start + corrected_pos, end);
-                new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
+                mutable_values.extend(0, start + corrected, end);
+                final_len = len + 1;
             } else {
+                // Beyond end: pad with nulls then insert
+                let corrected = pos1 - 1;
+                let padding = corrected - len;
                 mutable_values.extend(0, start, end);
-                mutable_values.extend_nulls(new_array_len - (end - start));
+                mutable_values.extend_nulls(padding);
                 mutable_values.extend(1, row_index, row_index + 1);
-                // In that case spark actualy makes array longer than expected;
-                // For example, if pos is equal to 5, len is eq to 3, than resulted len will be 5
-                new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
+                final_len = corrected + 1; // equals pos1
             }
         } else {
-            // This comment is takes from the Apache Spark source code as is:
-            // special case- if the new position is negative but larger than the current array size
-            // place the new item at start of array, place the current array contents at the end
-            // and fill the newly created array elements inbetween with a null
-            let base_offset = if legacy_mode { 1 } else { 0 };
-            let new_array_len = (-pos + base_offset).as_usize();
-            if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
-                return Err(DataFusionError::Internal(format!(
-                    "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
-                )));
-            }
-            mutable_values.extend(1, row_index, row_index + 1);
-            mutable_values.extend_nulls(new_array_len - (end - start + 1));
-            mutable_values.extend(0, start, end);
-            new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
-        }
-        if is_item_null {
-            if (start == end) || (values.is_null(row_index)) {
-                new_nulls.push(false)
+            // Negative index (1-based from the end)
+            let k = (-pos) as usize;
+
+            if k <= len {
+                // In-range negative insertion
+                // Non-legacy: -1 behaves like append to end (corrected = len - k + 1)
+                // Legacy: -1 behaves like insert before the last element (corrected = len - k)
+                let base_offset = if legacy_mode { 0 } else { 1 };
+                let corrected = len - k + base_offset;
+                mutable_values.extend(0, start, start + corrected);
+                mutable_values.extend(1, row_index, row_index + 1);
+                mutable_values.extend(0, start + corrected, end);
+                final_len = len + 1;
             } else {
-                new_nulls.push(true)
+                // Negative index beyond the start (Spark-specific behavior):
+                // Place item first, then pad with nulls, then append the original array.
+                // Final length = k + base_offset, where base_offset = 1 in legacy mode, otherwise 0.
+                let base_offset = if legacy_mode { 1 } else { 0 };
+                let target_len = k + base_offset;
+                let padding = target_len.saturating_sub(len + 1);
+                mutable_values.extend(1, row_index, row_index + 1); // insert item first
+                mutable_values.extend_nulls(padding); // pad nulls
+                mutable_values.extend(0, start, end); // append original values
+                final_len = target_len;
             }
-        } else {
-            new_nulls.push(true)
         }
+
+        if final_len > MAX_ROUNDED_ARRAY_LENGTH {
+            return Err(DataFusionError::Internal(format!(
+                "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH}, but got {final_len}"
+            )));
+        }
+
+        let prev = new_offsets[row_index].as_usize();
+        new_offsets.push(O::usize_as(prev + final_len));
+        list_valid.push(true);
     }
 
-    let data = make_array(mutable_values.freeze());
-    let data_type = match list_array.data_type() {
-        DataType::List(field) => field.data_type(),
-        DataType::LargeList(field) => field.data_type(),
+    let child = make_array(mutable_values.freeze());
+
+    // Reuse the original list element field (name/type/nullability)
+    let elem_field = match list_array.data_type() {
+        DataType::List(field) => Arc::clone(field),
+        DataType::LargeList(field) => Arc::clone(field),
         _ => unreachable!(),
     };
-    let new_array = GenericListArray::<O>::try_new(
-        Arc::new(Field::new("item", data_type.clone(), true)),
+
+    // Build the resulting list array
+    let new_list = GenericListArray::<O>::try_new(
+        elem_field,
         OffsetBuffer::new(new_offsets.into()),
-        data,
-        Some(NullBuffer::new(new_nulls.into())),
+        child,
+        Some(NullBuffer::new(list_valid.into())),
     )?;
 
-    Ok(ColumnarValue::Array(Arc::new(new_array)))
+    Ok(ColumnarValue::Array(Arc::new(new_list)))
 }
 
 impl Display for ArrayInsert {
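
Taken together, the branches above give the following row-level behavior under the default legacy_mode = false. This worked summary is illustrative only (derived from the comments in the hunk, not part of the commit); legacy_mode only changes the negative-index cases:

    array_insert([1, 2, 3],  2, 9)  ->  [1, 9, 2, 3]          (positive, in range: 1-based insert)
    array_insert([1, 2, 3],  5, 9)  ->  [1, 2, 3, null, 9]    (positive, past the end: pad with nulls, then the item)
    array_insert([1, 2, 3], -1, 9)  ->  [1, 2, 3, 9]          (negative, in range: -1 appends; legacy mode inserts before the last element instead)
    array_insert([1, 2, 3], -5, 9)  ->  [9, null, 1, 2, 3]    (negative, past the start: item first, null padding, then the original values)

A null position or a null top-level list row produces a null result row via the two early-continue checks at the top of the loop.
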
@@ -442,4 +458,37 @@ mod test {
 
         Ok(())
     }
+
+    #[test]
+    fn test_array_insert_bug_repro_null_item_pos1_fixed() -> Result<()> {
+        use arrow::array::{Array, ArrayRef, Int32Array, ListArray};
+        use arrow::datatypes::Int32Type;
+
+        // row0 = [0, null, 0]
+        // row1 = [1, null, 1]
+        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+            Some(vec![Some(0), None, Some(0)]),
+            Some(vec![Some(1), None, Some(1)]),
+        ]);
+
+        let positions = Int32Array::from(vec![1, 1]);
+        let items = Int32Array::from(vec![None, None]);
+
+        let ColumnarValue::Array(result) = array_insert(
+            &list,
+            &(Arc::new(items) as ArrayRef),
+            &(Arc::new(positions) as ArrayRef),
+            false, // legacy_mode = false
+        )?
+        else {
+            unreachable!()
+        };
+
+        let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+            Some(vec![None, Some(0), None, Some(0)]),
+            Some(vec![None, Some(1), None, Some(1)]),
+        ]);
+        assert_eq!(&result.to_data(), &expected.to_data());
+        Ok(())
+    }
 }
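
The two out-of-range branches could be exercised in the same style as the regression test above. The following sketch is illustrative only (it is not part of the commit) and assumes the same array_insert helper, Result alias, and test imports used in this module:

    #[test]
    fn test_array_insert_out_of_range_sketch() -> Result<()> {
        use arrow::array::{Array, ArrayRef, Int32Array, ListArray};
        use arrow::datatypes::Int32Type;

        // Two identical input rows: [1, 2, 3]
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(1), Some(2), Some(3)]),
        ]);

        // Row 0 inserts past the end (pos = 5); row 1 inserts before the start (pos = -5).
        let positions = Int32Array::from(vec![5, -5]);
        let items = Int32Array::from(vec![Some(9), Some(9)]);

        let ColumnarValue::Array(result) = array_insert(
            &list,
            &(Arc::new(items) as ArrayRef),
            &(Arc::new(positions) as ArrayRef),
            false, // legacy_mode = false
        )?
        else {
            unreachable!()
        };

        // Row 0: original values, one null of padding, then the item.
        // Row 1: the item, one null of padding, then the original values.
        let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3), None, Some(9)]),
            Some(vec![Some(9), None, Some(1), Some(2), Some(3)]),
        ]);
        assert_eq!(&result.to_data(), &expected.to_data());
        Ok(())
    }
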

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala

Lines changed: 27 additions & 0 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayRepeat, ArraysOverlap, ArrayUnion}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.ArrayType
 
 import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
 import org.apache.comet.DataTypeSupport.isComplexType
@@ -210,11 +211,13 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
           .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)"))
           .withColumn("arrInsertNegativeIndexResult", expr("array_insert(arr, -1, 1)"))
           .withColumn("arrPosGreaterThanSize", expr("array_insert(arr, 8, 1)"))
+          .withColumn("arrPosIsNull", expr("array_insert(arr, cast(null as int), 1)"))
           .withColumn("arrNegPosGreaterThanSize", expr("array_insert(arr, -8, 1)"))
           .withColumn("arrInsertNone", expr("array_insert(arr, 1, null)"))
         checkSparkAnswerAndOperator(df.select("arrInsertResult"))
         checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult"))
         checkSparkAnswerAndOperator(df.select("arrPosGreaterThanSize"))
+        checkSparkAnswerAndOperator(df.select("arrPosIsNull"))
         checkSparkAnswerAndOperator(df.select("arrNegPosGreaterThanSize"))
         checkSparkAnswerAndOperator(df.select("arrInsertNone"))
       })
@@ -802,4 +805,28 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
         fallbackReason)
     }
   }
+
+  test("array_reverse 2") {
+    // This test validates data correctness for array<binary> columns with nullable elements.
+    // See https://github.com/apache/datafusion-comet/issues/2612
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val schemaOptions =
+          SchemaGenOptions(generateArray = true, generateStruct = false, generateMap = false)
+        val dataOptions = DataGenOptions(allowNull = true, generateNegativeZero = false)
+        ParquetGenerator.makeParquetFile(random, spark, filename, 100, schemaOptions, dataOptions)
+      }
+      withTempView("t1") {
+        val table = spark.read.parquet(filename)
+        table.createOrReplaceTempView("t1")
+        for (field <- table.schema.fields.filter(_.dataType.isInstanceOf[ArrayType])) {
+          val sql = s"SELECT ${field.name}, reverse(${field.name}) FROM t1 ORDER BY ${field.name}"
+          checkSparkAnswer(sql)
+        }
+      }
+    }
+  }
 }