|
16 | 16 | // under the License. |
17 | 17 |
|
18 | 18 | use arrow::array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait}; |
19 | | -use arrow::datatypes::{DataType, Field, Schema}; |
| 19 | +use arrow::datatypes::{DataType, Schema}; |
20 | 20 | use arrow::{ |
21 | 21 | array::{as_primitive_array, Capacities, MutableArrayData}, |
22 | 22 | buffer::{NullBuffer, OffsetBuffer}, |
23 | | - datatypes::ArrowNativeType, |
24 | 23 | record_batch::RecordBatch, |
25 | 24 | }; |
26 | 25 | use datafusion::common::{ |
@@ -198,114 +197,131 @@ fn array_insert<O: OffsetSizeTrait>( |
198 | 197 | pos_array: &ArrayRef, |
199 | 198 | legacy_mode: bool, |
200 | 199 | ) -> DataFusionResult<ColumnarValue> { |
201 | | - // The code is based on the implementation of the array_append from the Apache DataFusion |
202 | | - // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513 |
203 | | - // |
204 | | - // This code is also based on the implementation of the array_insert from the Apache Spark |
205 | | - // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713 |
| 200 | + // Implementation aligned with Arrow's half-open offset ranges and Spark semantics. |
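| | + // For example, offsets [0, 3, 7] describe row 0 = values[0..3) and row 1 = values[3..7). |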
206 | 201 |
|
207 | 202 | let values = list_array.values(); |
208 | 203 | let offsets = list_array.offsets(); |
209 | 204 | let values_data = values.to_data(); |
210 | 205 | let item_data = items_array.to_data(); |
| 206 | + |
| 207 | + // Estimate capacity (original values + inserted items upper bound) |
211 | 208 | let new_capacity = Capacities::Array(values_data.len() + item_data.len()); |
212 | 209 |
|
213 | 210 | let mut mutable_values = |
214 | 211 | MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity); |
215 | 212 |
|
216 | | - let mut new_offsets = vec![O::usize_as(0)]; |
217 | | - let mut new_nulls = Vec::<bool>::with_capacity(list_array.len()); |
| 213 | + // New offsets and top-level list validity bitmap |
| 214 | + let mut new_offsets = Vec::with_capacity(list_array.len() + 1); |
| 215 | + new_offsets.push(O::usize_as(0)); |
| 216 | + let mut list_valid = Vec::<bool>::with_capacity(list_array.len()); |
218 | 217 |
|
219 | | - let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions |
| 218 | + // Spark supports only Int32 position indices |
| 219 | + let pos_data: &Int32Array = as_primitive_array(&pos_array); |
220 | 220 |
|
221 | | - for (row_index, offset_window) in offsets.windows(2).enumerate() { |
222 | | - let pos = pos_data.values()[row_index]; |
223 | | - let start = offset_window[0].as_usize(); |
224 | | - let end = offset_window[1].as_usize(); |
225 | | - let is_item_null = items_array.is_null(row_index); |
| 221 | + for (row_index, window) in offsets.windows(2).enumerate() { |
| 222 | + let start = window[0].as_usize(); |
| 223 | + let end = window[1].as_usize(); |
| 224 | + let len = end - start; |
| 225 | + |
| 226 | + // Return null for the entire row when pos is null (consistent with Spark's behavior) |
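| | + // For example, with pos = NULL the output row is NULL regardless of the list or item values. |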
| 227 | + if pos_data.is_null(row_index) { |
| 228 | + new_offsets.push(new_offsets[row_index]); |
| 229 | + list_valid.push(false); |
| 230 | + continue; |
| 231 | + } |
| 232 | + let pos = pos_data.value(row_index); |
226 | 233 |
|
227 | 234 | if list_array.is_null(row_index) { |
228 | | - // In Spark if value of the array is NULL than nothing happens |
229 | | - mutable_values.extend_nulls(1); |
230 | | - new_offsets.push(new_offsets[row_index] + O::one()); |
231 | | - new_nulls.push(false); |
| 235 | + // Top-level list row is NULL: do not write any child values and do not advance offset |
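| | + // For example, a NULL input list with pos = 2 produces a NULL output row (no child values written). |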
| 236 | + new_offsets.push(new_offsets[row_index]); |
| 237 | + list_valid.push(false); |
232 | 238 | continue; |
233 | 239 | } |
234 | 240 |
|
235 | 241 | if pos == 0 { |
236 | 242 | return Err(DataFusionError::Internal( |
237 | | - "Position for array_insert should be greter or less than zero".to_string(), |
| 243 | + "Position for array_insert should be greater or less than zero".to_string(), |
238 | 244 | )); |
239 | 245 | } |
240 | 246 |
|
241 | | - if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) { |
242 | | - let corrected_pos = if pos > 0 { |
243 | | - (pos - 1).as_usize() |
244 | | - } else { |
245 | | - end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 } |
246 | | - }; |
247 | | - let new_array_len = std::cmp::max(end - start + 1, corrected_pos); |
248 | | - if new_array_len > MAX_ROUNDED_ARRAY_LENGTH { |
249 | | - return Err(DataFusionError::Internal(format!( |
250 | | - "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}" |
251 | | - ))); |
252 | | - } |
| 247 | + let final_len: usize; |
253 | 248 |
|
254 | | - if (start + corrected_pos) <= end { |
255 | | - mutable_values.extend(0, start, start + corrected_pos); |
| 249 | + if pos > 0 { |
| 250 | + // Positive index (1-based) |
| 251 | + let pos1 = pos as usize; |
| 252 | + if pos1 <= len + 1 { |
| 253 | + // In-range insertion (including appending to end) |
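| | + // For example, [1, 2, 3] with pos = 2 becomes [1, item, 2, 3]. |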
| 254 | + let corrected = pos1 - 1; // 0-based insertion point |
| 255 | + mutable_values.extend(0, start, start + corrected); |
256 | 256 | mutable_values.extend(1, row_index, row_index + 1); |
257 | | - mutable_values.extend(0, start + corrected_pos, end); |
258 | | - new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len)); |
| 257 | + mutable_values.extend(0, start + corrected, end); |
| 258 | + final_len = len + 1; |
259 | 259 | } else { |
| 260 | + // Beyond end: pad with nulls then insert |
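| | + // For example, [1, 2, 3] with pos = 5 becomes [1, 2, 3, null, item] (final length 5). |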
| 261 | + let corrected = pos1 - 1; |
| 262 | + let padding = corrected - len; |
260 | 263 | mutable_values.extend(0, start, end); |
261 | | - mutable_values.extend_nulls(new_array_len - (end - start)); |
| 264 | + mutable_values.extend_nulls(padding); |
262 | 265 | mutable_values.extend(1, row_index, row_index + 1); |
263 | | - // In that case spark actualy makes array longer than expected; |
264 | | - // For example, if pos is equal to 5, len is eq to 3, than resulted len will be 5 |
265 | | - new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one()); |
| 266 | + final_len = corrected + 1; // equals pos1 |
266 | 267 | } |
267 | 268 | } else { |
268 | | - // This comment is takes from the Apache Spark source code as is: |
269 | | - // special case- if the new position is negative but larger than the current array size |
270 | | - // place the new item at start of array, place the current array contents at the end |
271 | | - // and fill the newly created array elements inbetween with a null |
272 | | - let base_offset = if legacy_mode { 1 } else { 0 }; |
273 | | - let new_array_len = (-pos + base_offset).as_usize(); |
274 | | - if new_array_len > MAX_ROUNDED_ARRAY_LENGTH { |
275 | | - return Err(DataFusionError::Internal(format!( |
276 | | - "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}" |
277 | | - ))); |
278 | | - } |
279 | | - mutable_values.extend(1, row_index, row_index + 1); |
280 | | - mutable_values.extend_nulls(new_array_len - (end - start + 1)); |
281 | | - mutable_values.extend(0, start, end); |
282 | | - new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len)); |
283 | | - } |
284 | | - if is_item_null { |
285 | | - if (start == end) || (values.is_null(row_index)) { |
286 | | - new_nulls.push(false) |
| 269 | + // Negative index (1-based from the end) |
| 270 | + let k = (-pos) as usize; |
| 271 | + |
| 272 | + if k <= len { |
| 273 | + // In-range negative insertion |
| 274 | + // Non-legacy: -1 behaves like append to end (corrected = len - k + 1) |
| 275 | + // Legacy: -1 behaves like insert before the last element (corrected = len - k) |
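| | + // For example, [1, 2, 3] with pos = -1: non-legacy gives [1, 2, 3, item], legacy gives [1, 2, item, 3]. |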
| 276 | + let base_offset = if legacy_mode { 0 } else { 1 }; |
| 277 | + let corrected = len - k + base_offset; |
| 278 | + mutable_values.extend(0, start, start + corrected); |
| 279 | + mutable_values.extend(1, row_index, row_index + 1); |
| 280 | + mutable_values.extend(0, start + corrected, end); |
| 281 | + final_len = len + 1; |
287 | 282 | } else { |
288 | | - new_nulls.push(true) |
| 283 | + // Negative index beyond the start (Spark-specific behavior): |
| 284 | + // Place item first, then pad with nulls, then append the original array. |
| 285 | + // Final length = k + base_offset, where base_offset = 1 in legacy mode, otherwise 0. |
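| | + // For example, [1, 2, 3] with pos = -5: non-legacy gives [item, null, 1, 2, 3] (length 5), legacy gives [item, null, null, 1, 2, 3] (length 6). |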
| 286 | + let base_offset = if legacy_mode { 1 } else { 0 }; |
| 287 | + let target_len = k + base_offset; |
| 288 | + let padding = target_len.saturating_sub(len + 1); |
| 289 | + mutable_values.extend(1, row_index, row_index + 1); // insert item first |
| 290 | + mutable_values.extend_nulls(padding); // pad nulls |
| 291 | + mutable_values.extend(0, start, end); // append original values |
| 292 | + final_len = target_len; |
289 | 293 | } |
290 | | - } else { |
291 | | - new_nulls.push(true) |
292 | 294 | } |
| 295 | + |
| 296 | + if final_len > MAX_ROUNDED_ARRAY_LENGTH { |
| 297 | + return Err(DataFusionError::Internal(format!( |
| 298 | + "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH}, but got {final_len}" |
| 299 | + ))); |
| 300 | + } |
| 301 | + |
| 302 | + let prev = new_offsets[row_index].as_usize(); |
| 303 | + new_offsets.push(O::usize_as(prev + final_len)); |
| 304 | + list_valid.push(true); |
293 | 305 | } |
294 | 306 |
|
295 | | - let data = make_array(mutable_values.freeze()); |
296 | | - let data_type = match list_array.data_type() { |
297 | | - DataType::List(field) => field.data_type(), |
298 | | - DataType::LargeList(field) => field.data_type(), |
| 307 | + let child = make_array(mutable_values.freeze()); |
| 308 | + |
| 309 | + // Reuse the original list element field (name/type/nullability) |
| 310 | + let elem_field = match list_array.data_type() { |
| 311 | + DataType::List(field) => Arc::clone(field), |
| 312 | + DataType::LargeList(field) => Arc::clone(field), |
299 | 313 | _ => unreachable!(), |
300 | 314 | }; |
301 | | - let new_array = GenericListArray::<O>::try_new( |
302 | | - Arc::new(Field::new("item", data_type.clone(), true)), |
| 315 | + |
| 316 | + // Build the resulting list array |
| 317 | + let new_list = GenericListArray::<O>::try_new( |
| 318 | + elem_field, |
303 | 319 | OffsetBuffer::new(new_offsets.into()), |
304 | | - data, |
305 | | - Some(NullBuffer::new(new_nulls.into())), |
| 320 | + child, |
| 321 | + Some(NullBuffer::new(list_valid.into())), |
306 | 322 | )?; |
307 | 323 |
|
308 | | - Ok(ColumnarValue::Array(Arc::new(new_array))) |
| 324 | + Ok(ColumnarValue::Array(Arc::new(new_list))) |
309 | 325 | } |
310 | 326 |
|
311 | 327 | impl Display for ArrayInsert { |
@@ -442,4 +458,37 @@ mod test { |
442 | 458 |
|
443 | 459 | Ok(()) |
444 | 460 | } |
| 461 | + |
| 462 | + #[test] |
| 463 | + fn test_array_insert_bug_repro_null_item_pos1_fixed() -> Result<()> { |
| 464 | + use arrow::array::{Array, ArrayRef, Int32Array, ListArray}; |
| 465 | + use arrow::datatypes::Int32Type; |
| 466 | + |
| 467 | + // row0 = [0, null, 0] |
| 468 | + // row1 = [1, null, 1] |
| 469 | + let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![ |
| 470 | + Some(vec![Some(0), None, Some(0)]), |
| 471 | + Some(vec![Some(1), None, Some(1)]), |
| 472 | + ]); |
| 473 | + |
| 474 | + let positions = Int32Array::from(vec![1, 1]); |
| 475 | + let items = Int32Array::from(vec![None, None]); |
| 476 | + |
| 477 | + let ColumnarValue::Array(result) = array_insert( |
| 478 | + &list, |
| 479 | + &(Arc::new(items) as ArrayRef), |
| 480 | + &(Arc::new(positions) as ArrayRef), |
| 481 | + false, // legacy_mode = false |
| 482 | + )? |
| 483 | + else { |
| 484 | + unreachable!() |
| 485 | + }; |
| 486 | + |
| 487 | + let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![ |
| 488 | + Some(vec![None, Some(0), None, Some(0)]), |
| 489 | + Some(vec![None, Some(1), None, Some(1)]), |
| 490 | + ]); |
| 491 | + assert_eq!(&result.to_data(), &expected.to_data()); |
| 492 | + Ok(()) |
| 493 | + } |
445 | 494 | } |