diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs
index 5f690e9a6734..c0e1dac0bdb6 100644
--- a/arrow-row/src/lib.rs
+++ b/arrow-row/src/lib.rs
@@ -458,6 +458,9 @@ enum Codec {
     List(RowConverter),
     /// A row converter for the values array of a run-end encoded array
     RunEndEncoded(RowConverter),
+    /// Row converters for each union field (indexed by type_id)
+    /// and the encoding of null rows for each field
+    Union(Vec<RowConverter>, Vec<OwnedRow>),
 }
 
 impl Codec {
@@ -524,6 +527,35 @@ impl Codec {
 
                 Ok(Self::Struct(converter, owned))
             }
+            DataType::Union(fields, _mode) => {
+                // similar to dictionaries and lists, we set descending to false and negate nulls_first
+                // since the encoded contents will be inverted if descending is set
+                let options = SortOptions {
+                    descending: false,
+                    nulls_first: sort_field.options.nulls_first != sort_field.options.descending,
+                };
+
+                let mut converters = Vec::with_capacity(fields.len());
+                let mut null_rows = Vec::with_capacity(fields.len());
+
+                for (_type_id, field) in fields.iter() {
+                    let sort_field =
+                        SortField::new_with_options(field.data_type().clone(), options);
+                    let converter = RowConverter::new(vec![sort_field])?;
+
+                    let null_array = new_null_array(field.data_type(), 1);
+                    let nulls = converter.convert_columns(&[null_array])?;
+                    let owned = OwnedRow {
+                        data: nulls.buffer.into(),
+                        config: nulls.config,
+                    };
+
+                    converters.push(converter);
+                    null_rows.push(owned);
+                }
+
+                Ok(Self::Union(converters, null_rows))
+            }
             _ => Err(ArrowError::NotYetImplemented(format!(
                 "not yet implemented: {:?}",
                 sort_field.data_type
@@ -592,6 +624,28 @@
                 let rows = converter.convert_columns(std::slice::from_ref(values))?;
                 Ok(Encoder::RunEndEncoded(rows))
             }
+            Codec::Union(converters, _) => {
+                let union_array = array
+                    .as_any()
+                    .downcast_ref::<UnionArray>()
+                    .expect("expected Union array");
+
+                let type_ids = union_array.type_ids().clone();
+                let offsets = union_array.offsets().cloned();
+
+                let mut child_rows = Vec::with_capacity(converters.len());
+                for (type_id, converter) in converters.iter().enumerate() {
+                    let child_array = union_array.child(type_id as i8);
+                    let rows = converter.convert_columns(std::slice::from_ref(child_array))?;
+                    child_rows.push(rows);
+                }
+
+                Ok(Encoder::Union {
+                    child_rows,
+                    type_ids,
+                    offsets,
+                })
+            }
         }
     }
 
@@ -602,6 +656,10 @@
             Codec::Struct(converter, nulls) => converter.size() + nulls.data.len(),
             Codec::List(converter) => converter.size(),
             Codec::RunEndEncoded(converter) => converter.size(),
+            Codec::Union(converters, null_rows) => {
+                converters.iter().map(|c| c.size()).sum::<usize>()
+                    + null_rows.iter().map(|n| n.data.len()).sum::<usize>()
+            }
         }
     }
 }
@@ -622,6 +680,12 @@ enum Encoder<'a> {
     List(Rows),
     /// The row encoding of the values array
     RunEndEncoded(Rows),
+    /// The row encoding of each union field's child array, type_ids buffer, and offsets buffer (for Dense)
+    Union {
+        child_rows: Vec<Rows>,
+        type_ids: ScalarBuffer<i8>,
+        offsets: Option<ScalarBuffer<i32>>,
+    },
 }
 
 /// Configure the data type and sort order for a given column
@@ -681,6 +745,9 @@
             }
             DataType::Struct(f) => f.iter().all(|x| Self::supports_datatype(x.data_type())),
             DataType::RunEndEncoded(_, values) => Self::supports_datatype(values.data_type()),
+            DataType::Union(fs, _mode) => fs
+                .iter()
+                .all(|(_, f)| Self::supports_datatype(f.data_type())),
             _ => false,
         }
     }
@@ -1523,6 +1590,27 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker {
                 },
                 _ => unreachable!(),
             },
+            Encoder::Union {
+                child_rows,
+                type_ids,
+                offsets,
+            } => {
+                let union_array = array
+                    .as_any()
+                    .downcast_ref::<UnionArray>()
+                    .expect("expected UnionArray");
+
+                let lengths = (0..union_array.len()).map(|i| {
+                    let type_id = type_ids[i];
+                    let child_row_i = offsets.as_ref().map(|o| o[i] as usize).unwrap_or(i);
+                    let child_row = child_rows[type_id as usize].row(child_row_i);
+
+                    // length: 1 byte type_id + child row bytes
+                    1 + child_row.as_ref().len()
+                });
+
+                tracker.push_variable(lengths);
+            }
         }
     }
 
@@ -1637,6 +1725,36 @@ fn encode_column(
                 },
                 _ => unreachable!(),
             },
+            Encoder::Union {
+                child_rows,
+                type_ids,
+                offsets: offsets_buf,
+            } => {
+                offsets
+                    .iter_mut()
+                    .skip(1)
+                    .enumerate()
+                    .for_each(|(i, offset)| {
+                        let type_id = type_ids[i];
+
+                        let child_row_idx = offsets_buf.as_ref().map(|o| o[i] as usize).unwrap_or(i);
+                        let child_row = child_rows[type_id as usize].row(child_row_idx);
+                        let child_bytes = child_row.as_ref();
+
+                        let type_id_byte = if opts.descending {
+                            !(type_id as u8)
+                        } else {
+                            type_id as u8
+                        };
+                        data[*offset] = type_id_byte;
+
+                        let child_start = *offset + 1;
+                        let child_end = child_start + child_bytes.len();
+                        data[child_start..child_end].copy_from_slice(child_bytes);
+
+                        *offset = child_end;
+                    });
+            }
         }
     }
 
@@ -1762,6 +1880,111 @@ unsafe fn decode_column(
             },
             _ => unreachable!(),
         },
+        Codec::Union(converters, null_rows) => {
+            let len = rows.len();
+
+            let DataType::Union(union_fields, mode) = &field.data_type else {
+                unreachable!()
+            };
+
+            let mut type_ids = Vec::with_capacity(len);
+            let mut rows_by_field: Vec<Vec<(usize, &[u8])>> = vec![Vec::new(); converters.len()];
+
+            for (idx, row) in rows.iter_mut().enumerate() {
+                let mut cursor = 0;
+
+                let type_id_byte = {
+                    let id = row[cursor];
+                    cursor += 1;
+
+                    if options.descending { !id } else { id }
+                };
+
+                let type_id = type_id_byte as i8;
+                type_ids.push(type_id);
+
+                let field_idx = type_id as usize;
+
+                let child_row = &row[cursor..];
+                rows_by_field[field_idx].push((idx, child_row));
+
+                *row = &row[row.len()..];
+            }
+
+            let mut child_arrays: Vec<ArrayRef> = Vec::with_capacity(converters.len());
+
+            let mut offsets = (*mode == UnionMode::Dense).then(|| Vec::with_capacity(len));
+
+            for (field_idx, converter) in converters.iter().enumerate() {
+                let field_rows = &rows_by_field[field_idx];
+
+                match &mode {
+                    UnionMode::Dense => {
+                        if field_rows.is_empty() {
+                            let (_, field) = union_fields.iter().nth(field_idx).unwrap();
+                            child_arrays.push(arrow_array::new_empty_array(field.data_type()));
+                            continue;
+                        }
+
+                        let mut child_data = field_rows
+                            .iter()
+                            .map(|(_, bytes)| *bytes)
+                            .collect::<Vec<_>>();
+
+                        let child_array =
+                            unsafe { converter.convert_raw(&mut child_data, validate_utf8) }?;
+
+                        child_arrays.push(child_array.into_iter().next().unwrap());
+                    }
+                    UnionMode::Sparse => {
+                        let mut sparse_data: Vec<&[u8]> = Vec::with_capacity(len);
+                        let mut field_row_iter = field_rows.iter().peekable();
+                        let null_row_bytes: &[u8] = &null_rows[field_idx].data;
+
+                        for idx in 0..len {
+                            if let Some((next_idx, bytes)) = field_row_iter.peek() {
+                                if *next_idx == idx {
+                                    sparse_data.push(*bytes);
+
+                                    field_row_iter.next();
+                                    continue;
+                                }
+                            }
+                            sparse_data.push(null_row_bytes);
+                        }
+
+                        let child_array =
+                            unsafe { converter.convert_raw(&mut sparse_data, validate_utf8) }?;
+                        child_arrays.push(child_array.into_iter().next().unwrap());
+                    }
+                }
+            }
+
+            // build offsets for dense unions
+            if let Some(ref mut offsets_vec) = offsets {
+                let mut count = vec![0i32; converters.len()];
+                for type_id in &type_ids {
+                    let field_idx = *type_id as usize;
+                    offsets_vec.push(count[field_idx]);
+
+                    count[field_idx] += 1;
+                }
+            }
+
+            let type_ids_buffer = ScalarBuffer::from(type_ids);
+            let offsets_buffer = offsets.map(ScalarBuffer::from);
+
+            let union_array = UnionArray::try_new(
+                union_fields.clone(),
+                type_ids_buffer,
+                offsets_buffer,
+                child_arrays,
+            )?;
+
+            // note: union arrays don't support physical null buffers;
+            // nulls are represented logically through child arrays
+            Arc::new(union_array)
+        }
     };
     Ok(array)
 }
@@ -3598,4 +3821,237 @@ mod tests {
         assert_eq!(unchecked_values_len, 13);
         assert!(checked_values_len > unchecked_values_len);
     }
+
+    #[test]
+    fn test_sparse_union() {
+        // create a sparse union with Int32 (type_id = 0) and Utf8 (type_id = 1)
+        let int_array = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+        let str_array = StringArray::from(vec![None, Some("b"), None, Some("d"), None]);
+
+        // [1, "b", 3, "d", 5]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, false))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            None,
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            assert_eq!(union_array.type_id(i), back_union.type_id(i));
+        }
+    }
+
+    #[test]
+    fn test_sparse_union_with_nulls() {
+        // create a sparse union with Int32 (type_id = 0) and Utf8 (type_id = 1)
+        let int_array = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]);
+        let str_array = StringArray::from(vec![None::<&str>; 5]);
+
+        // [1, null (both children null), 3, null (both children null), 5]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, true))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, true))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            None,
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            let expected_null = union_array.is_null(i);
+            let actual_null = back_union.is_null(i);
+            assert_eq!(expected_null, actual_null, "Null mismatch at index {i}");
+            if !expected_null {
+                assert_eq!(union_array.type_id(i), back_union.type_id(i));
+            }
+        }
+    }
+
+    #[test]
+    fn test_dense_union() {
+        // create a dense union with Int32 (type_id = 0) and use Utf8 (type_id = 1)
+        let int_array = Int32Array::from(vec![1, 3, 5]);
+        let str_array = StringArray::from(vec!["a", "b"]);
+
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+
+        // [1, "a", 3, "b", 5]
+        let offsets = vec![0, 0, 1, 1, 2].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, false))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            Some(offsets), // Dense mode
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            assert_eq!(union_array.type_id(i), back_union.type_id(i));
+        }
+    }
+
+    #[test]
+    fn test_dense_union_with_nulls() {
+        // create a dense union with Int32 (type_id = 0) and Utf8 (type_id = 1)
+        let int_array = Int32Array::from(vec![Some(1), None, Some(5)]);
+        let str_array = StringArray::from(vec![Some("a"), None]);
+
+        // [1, "a", 5, null (str null), null (int null)]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+        let offsets = vec![0, 0, 1, 1, 2].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, true))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, true))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            Some(offsets),
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter
+            .convert_columns(&[Arc::new(union_array.clone())])
+            .unwrap();
+
+        // round trip
+        let back = converter.convert_rows(&rows).unwrap();
+        let back_union = back[0].as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(union_array.len(), back_union.len());
+        for i in 0..union_array.len() {
+            let expected_null = union_array.is_null(i);
+            let actual_null = back_union.is_null(i);
+            assert_eq!(expected_null, actual_null, "Null mismatch at index {i}");
+            if !expected_null {
+                assert_eq!(union_array.type_id(i), back_union.type_id(i));
+            }
+        }
+    }
+
+    #[test]
+    fn test_union_ordering() {
+        let int_array = Int32Array::from(vec![100, 5, 20]);
+        let str_array = StringArray::from(vec!["z", "a"]);
+
+        // [100, "z", 5, "a", 20]
+        let type_ids = vec![0, 1, 0, 1, 0].into();
+        let offsets = vec![0, 0, 1, 1, 2].into();
+
+        let union_fields = [
+            (0, Arc::new(Field::new("int", DataType::Int32, false))),
+            (1, Arc::new(Field::new("str", DataType::Utf8, false))),
+        ]
+        .into_iter()
+        .collect();
+
+        let union_array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            Some(offsets),
+            vec![Arc::new(int_array) as ArrayRef, Arc::new(str_array)],
+        )
+        .unwrap();
+
+        let union_type = union_array.data_type().clone();
+        let converter = RowConverter::new(vec![SortField::new(union_type)]).unwrap();
+
+        let rows = converter.convert_columns(&[Arc::new(union_array)]).unwrap();
+
+        /*
+            expected ordering
+
+            row 2: 5 - type_id 0
+            row 4: 20 - type_id 0
+            row 0: 100 - type id 0
+            row 3: "a" - type id 1
+            row 1: "z" - type id 1
+        */
+
+        // 5 < "z"
+        assert!(rows.row(2) < rows.row(1));
+
+        // 100 < "a"
+        assert!(rows.row(0) < rows.row(3));
+
+        // among ints
+        // 5 < 20
+        assert!(rows.row(2) < rows.row(4));
+        // 20 < 100
+        assert!(rows.row(4) < rows.row(0));
+
+        // among strings
+        // "a" < "z"
+        assert!(rows.row(3) < rows.row(1));
+    }
 }