diff --git a/arrow-cast/src/cast/dictionary.rs b/arrow-cast/src/cast/dictionary.rs index c213ac266228..64d77236ccd5 100644 --- a/arrow-cast/src/cast/dictionary.rs +++ b/arrow-cast/src/cast/dictionary.rs @@ -28,111 +28,92 @@ pub(crate) fn dictionary_cast( ) -> Result { use DataType::*; - match to_type { - Dictionary(to_index_type, to_value_type) => { - let dict_array = array - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - ArrowError::ComputeError( - "Internal Error: Cannot cast dictionary to DictionaryArray of expected type".to_string(), - ) - })?; + let array = array.as_dictionary::(); + let from_child_type = array.values().data_type(); + match (from_child_type, to_type) { + (_, Dictionary(to_index_type, to_value_type)) => { + dictionary_to_dictionary_cast(array, to_index_type, to_value_type, cast_options) + } + // `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data + // copy of the value buffer. Fast path which avoids copying underlying values buffer. + // TODO: handle LargeUtf8/LargeBinary -> View (need to check offsets can fit) + // TODO: handle cross types (String -> BinaryView, Binary -> StringView) + // (need to validate utf8?) + (Utf8, Utf8View) => view_from_dict_values::( + array.keys(), + array.values().as_string::(), + ), + (Binary, BinaryView) => view_from_dict_values::( + array.keys(), + array.values().as_binary::(), + ), + _ => unpack_dictionary(array, to_type, cast_options), + } +} - let keys_array: ArrayRef = - Arc::new(PrimitiveArray::::from(dict_array.keys().to_data())); - let values_array = dict_array.values(); - let cast_keys = cast_with_options(&keys_array, to_index_type, cast_options)?; - let cast_values = cast_with_options(values_array, to_value_type, cast_options)?; +fn dictionary_to_dictionary_cast( + array: &DictionaryArray, + to_index_type: &DataType, + to_value_type: &DataType, + cast_options: &CastOptions, +) -> Result { + use DataType::*; - // Failure to cast keys (because they don't fit in the - // target type) results in NULL values; - if cast_keys.null_count() > keys_array.null_count() { - return Err(ArrowError::ComputeError(format!( - "Could not convert {} dictionary indexes from {:?} to {:?}", - cast_keys.null_count() - keys_array.null_count(), - keys_array.data_type(), - to_index_type - ))); - } + let keys_array: ArrayRef = Arc::new(PrimitiveArray::::from(array.keys().to_data())); + let values_array = array.values(); + let cast_keys = cast_with_options(&keys_array, to_index_type, cast_options)?; + let cast_values = cast_with_options(values_array, to_value_type, cast_options)?; - let data = cast_keys.into_data(); - let builder = data - .into_builder() - .data_type(to_type.clone()) - .child_data(vec![cast_values.into_data()]); + // Failure to cast keys (because they don't fit in the + // target type) results in NULL values; + if cast_keys.null_count() > keys_array.null_count() { + return Err(ArrowError::ComputeError(format!( + "Could not convert {} dictionary indexes from {:?} to {:?}", + cast_keys.null_count() - keys_array.null_count(), + keys_array.data_type(), + to_index_type + ))); + } - // Safety - // Cast keys are still valid - let data = unsafe { builder.build_unchecked() }; + let data = cast_keys.into_data(); + let builder = data + .into_builder() + .data_type(Dictionary( + Box::new(to_index_type.clone()), + Box::new(to_value_type.clone()), + )) + .child_data(vec![cast_values.into_data()]); - // create the appropriate array type - let new_array: ArrayRef = match **to_index_type { - Int8 => Arc::new(DictionaryArray::::from(data)), - Int16 => Arc::new(DictionaryArray::::from(data)), - Int32 => Arc::new(DictionaryArray::::from(data)), - Int64 => Arc::new(DictionaryArray::::from(data)), - UInt8 => Arc::new(DictionaryArray::::from(data)), - UInt16 => Arc::new(DictionaryArray::::from(data)), - UInt32 => Arc::new(DictionaryArray::::from(data)), - UInt64 => Arc::new(DictionaryArray::::from(data)), - _ => { - return Err(ArrowError::CastError(format!( - "Unsupported type {to_index_type} for dictionary index" - ))); - } - }; + // Safety + // Cast keys are still valid + let data = unsafe { builder.build_unchecked() }; - Ok(new_array) + // create the appropriate array type + let new_array: ArrayRef = match to_index_type { + Int8 => Arc::new(DictionaryArray::::from(data)), + Int16 => Arc::new(DictionaryArray::::from(data)), + Int32 => Arc::new(DictionaryArray::::from(data)), + Int64 => Arc::new(DictionaryArray::::from(data)), + UInt8 => Arc::new(DictionaryArray::::from(data)), + UInt16 => Arc::new(DictionaryArray::::from(data)), + UInt32 => Arc::new(DictionaryArray::::from(data)), + UInt64 => Arc::new(DictionaryArray::::from(data)), + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported type {to_index_type} for dictionary index" + ))); } - Utf8View => { - // `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data copy of the value buffer. - // we handle it here to avoid the copy. - let dict_array = array - .as_dictionary::() - .downcast_dict::() - .ok_or_else(|| { - ArrowError::ComputeError( - "Internal Error: Cannot cast Utf8View to StringArray of expected type" - .to_string(), - ) - })?; + }; - let string_view = view_from_dict_values::>( - dict_array.values(), - dict_array.keys(), - )?; - Ok(Arc::new(string_view)) - } - BinaryView => { - // `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data copy of the value buffer. - // we handle it here to avoid the copy. - let dict_array = array - .as_dictionary::() - .downcast_dict::() - .ok_or_else(|| { - ArrowError::ComputeError( - "Internal Error: Cannot cast BinaryView to BinaryArray of expected type" - .to_string(), - ) - })?; - - let binary_view = view_from_dict_values::( - dict_array.values(), - dict_array.keys(), - )?; - Ok(Arc::new(binary_view)) - } - _ => unpack_dictionary::(array, to_type, cast_options), - } + Ok(new_array) } -fn view_from_dict_values( - array: &GenericByteArray, +fn view_from_dict_values( keys: &PrimitiveArray, -) -> Result, ArrowError> { - let value_buffer = array.values(); - let value_offsets = array.value_offsets(); + values: &GenericByteArray, +) -> Result { + let value_buffer = values.values(); + let value_offsets = values.value_offsets(); let mut builder = GenericByteViewBuilder::::with_capacity(keys.len()); builder.append_block(value_buffer.clone()); for i in keys.iter() { @@ -157,21 +138,17 @@ fn view_from_dict_values into a flattened array of type to_type -pub(crate) fn unpack_dictionary( - array: &dyn Array, +// Unpack a dictionary into a flattened array of type to_type +pub(crate) fn unpack_dictionary( + array: &DictionaryArray, to_type: &DataType, cast_options: &CastOptions, -) -> Result -where - K: ArrowDictionaryKeyType, -{ - let dict_array = array.as_dictionary::(); - let cast_dict_values = cast_with_options(dict_array.values(), to_type, cast_options)?; - take(cast_dict_values.as_ref(), dict_array.keys(), None) +) -> Result { + let cast_dict_values = cast_with_options(array.values(), to_type, cast_options)?; + take(cast_dict_values.as_ref(), array.keys(), None) } /// Pack a data type into a dictionary array passing the values through a primitive array diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 897a9153cb57..cbe86830877b 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -11866,4 +11866,60 @@ mod tests { // Verify the run-ends were cast correctly (run ends at 3, 6, 9) assert_eq!(run_array.run_ends().values(), &[3i64, 6i64, 9i64]); } + + #[test] + fn test_string_dicts_to_binary_view() { + let expected = BinaryViewArray::from_iter(vec![ + VIEW_TEST_DATA[1], + VIEW_TEST_DATA[0], + None, + VIEW_TEST_DATA[3], + None, + VIEW_TEST_DATA[1], + VIEW_TEST_DATA[4], + ]); + + let values_arrays: [ArrayRef; _] = [ + Arc::new(StringArray::from_iter(VIEW_TEST_DATA)), + Arc::new(StringViewArray::from_iter(VIEW_TEST_DATA)), + Arc::new(LargeStringArray::from_iter(VIEW_TEST_DATA)), + ]; + for values in values_arrays { + let keys = + Int8Array::from_iter([Some(1), Some(0), None, Some(3), None, Some(1), Some(4)]); + let string_dict_array = + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + let casted = cast(&string_dict_array, &DataType::BinaryView).unwrap(); + assert_eq!(casted.as_ref(), &expected); + } + } + + #[test] + fn test_binary_dicts_to_string_view() { + let expected = StringViewArray::from_iter(vec![ + VIEW_TEST_DATA[1], + VIEW_TEST_DATA[0], + None, + VIEW_TEST_DATA[3], + None, + VIEW_TEST_DATA[1], + VIEW_TEST_DATA[4], + ]); + + let values_arrays: [ArrayRef; _] = [ + Arc::new(BinaryArray::from_iter(VIEW_TEST_DATA)), + Arc::new(BinaryViewArray::from_iter(VIEW_TEST_DATA)), + Arc::new(LargeBinaryArray::from_iter(VIEW_TEST_DATA)), + ]; + for values in values_arrays { + let keys = + Int8Array::from_iter([Some(1), Some(0), None, Some(3), None, Some(1), Some(4)]); + let string_dict_array = + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + let casted = cast(&string_dict_array, &DataType::Utf8View).unwrap(); + assert_eq!(casted.as_ref(), &expected); + } + } }