diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs new file mode 100644 index 000000000000..4c47e904d78d --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{ArrayRef, AsArray}; +use datafusion_common::{internal_err, ScalarValue}; + +#[derive(Clone, Debug)] +pub(super) struct BooleanIndexMap { + true_index: i32, + false_index: i32, + else_index: i32, +} + +impl WhenLiteralIndexMap for BooleanIndexMap { + fn try_new( + unique_non_null_literals: Vec, + else_index: i32, + ) -> datafusion_common::Result + where + Self: Sized, + { + let mut true_index: Option = None; + let mut false_index: Option = None; + + for (index, literal) in unique_non_null_literals.into_iter().enumerate() { + match literal { + ScalarValue::Boolean(Some(true)) => { + if true_index.is_some() { + return internal_err!( + "Duplicate true literal found in literals for BooleanIndexMap" + ); + } + true_index = Some(index as i32); + } + ScalarValue::Boolean(Some(false)) => { + if false_index.is_some() { + return internal_err!( + "Duplicate false literal found in literals for BooleanIndexMap" + ); + } + false_index = Some(index as i32); + } + ScalarValue::Boolean(None) => { + return internal_err!( + "Null literal found in non-null literals for BooleanIndexMap" + ) + } + _ => { + return internal_err!( + "Non-boolean literal found in literals for BooleanIndexMap" + ) + } + } + } + + Ok(Self { + true_index: true_index.unwrap_or(else_index), + false_index: false_index.unwrap_or(else_index), + else_index, + }) + } + + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + Ok(array + .as_boolean() + .into_iter() + .map(|value| match value { + Some(true) => self.true_index, + Some(false) => self.false_index, + None => self.else_index, + }) + .collect::>()) + } +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs new file mode 100644 index 000000000000..10cf34ca7d95 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs @@ -0,0 +1,256 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{ + ArrayIter, ArrayRef, AsArray, FixedSizeBinaryArray, FixedSizeBinaryIter, + GenericByteArray, GenericByteViewArray, TypedDictionaryArray, +}; +use arrow::datatypes::{ArrowDictionaryKeyType, ByteArrayType, ByteViewType}; +use datafusion_common::{exec_datafusion_err, internal_err, HashMap, ScalarValue}; +use std::fmt::Debug; +use std::iter::Map; +use std::marker::PhantomData; + +/// Helper trait to convert various byte-like array types to iterator over byte slices +pub(super) trait BytesMapHelperWrapperTrait: Send + Sync { + /// Iterator over byte slices that will return + type IntoIter<'a>: Iterator> + 'a; + + /// Convert the array to an iterator over byte slices + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result>; +} + +#[derive(Debug, Clone, Default)] +pub(super) struct GenericBytesHelper(PhantomData); + +impl BytesMapHelperWrapperTrait for GenericBytesHelper { + type IntoIter<'a> = Map< + ArrayIter<&'a GenericByteArray>, + fn(Option<&'a ::Native>) -> Option<&[u8]>, + >; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + Ok(array.as_bytes::().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } +} + +#[derive(Debug, Clone, Default)] +pub(super) struct FixedBinaryHelper; + +impl BytesMapHelperWrapperTrait for FixedBinaryHelper { + type IntoIter<'a> = FixedSizeBinaryIter<'a>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + Ok(array.as_fixed_size_binary().into_iter()) + } +} + +#[derive(Debug, Clone, Default)] +pub(super) struct GenericBytesViewHelper(PhantomData); +impl BytesMapHelperWrapperTrait for GenericBytesViewHelper { + type IntoIter<'a> = Map< + ArrayIter<&'a GenericByteViewArray>, + fn(Option<&'a ::Native>) -> Option<&[u8]>, + >; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + Ok(array.as_byte_view::().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } +} + +#[derive(Debug, Clone, Default)] +pub(super) struct BytesDictionaryHelper( + PhantomData<(Key, Value)>, +); + +impl BytesMapHelperWrapperTrait for BytesDictionaryHelper +where + Key: ArrowDictionaryKeyType + Send + Sync, + Value: ByteArrayType, + for<'a> TypedDictionaryArray<'a, Key, GenericByteArray>: + IntoIterator>, +{ + type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array + .as_dictionary::() + .downcast_dict::>() + .ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast dictionary array {} to expected dictionary value {}", + array.data_type(), + Value::DATA_TYPE + ) + })?; + + Ok(dict_array.into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } +} + +#[derive(Debug, Clone, Default)] +pub(super) struct FixedBytesDictionaryHelper( + PhantomData, +); + +impl BytesMapHelperWrapperTrait for FixedBytesDictionaryHelper +where + Key: ArrowDictionaryKeyType + Send + Sync, + for<'a> TypedDictionaryArray<'a, Key, FixedSizeBinaryArray>: + IntoIterator>, +{ + type IntoIter<'a> = + as IntoIterator>::IntoIter; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array + .as_dictionary::() + .downcast_dict::() + .ok_or_else(|| exec_datafusion_err!( + "Failed to downcast dictionary array {} to expected dictionary fixed size binary values", + array.data_type() + ))?; + + Ok(dict_array.into_iter()) + } +} + +#[derive(Debug, Clone, Default)] +pub(super) struct BytesViewDictionaryHelper< + Key: ArrowDictionaryKeyType, + Value: ByteViewType, +>(PhantomData<(Key, Value)>); + +impl BytesMapHelperWrapperTrait for BytesViewDictionaryHelper +where + Key: ArrowDictionaryKeyType + Send + Sync, + Value: ByteViewType, + for<'a> TypedDictionaryArray<'a, Key, GenericByteViewArray>: + IntoIterator>, +{ + type IntoIter<'a> = Map<> as IntoIterator>::IntoIter, fn(Option<&'a ::Native>) -> Option<&[u8]>>; + + fn array_to_iter(array: &ArrayRef) -> datafusion_common::Result> { + let dict_array = array + .as_dictionary::() + .downcast_dict::>() + .ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast dictionary array {} to expected dictionary value {}", + array.data_type(), + Value::DATA_TYPE + ) + })?; + + Ok(dict_array.into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } +} + +/// Map from byte-like literal values to their first occurrence index +/// +/// This is a wrapper for handling different kinds of literal maps +#[derive(Clone)] +pub(super) struct BytesLikeIndexMap { + /// Map from non-null literal value the first occurrence index in the literals + map: HashMap, i32>, + + /// The index to return when no match is found + else_index: i32, + + _phantom_data: PhantomData, +} + +impl Debug for BytesLikeIndexMap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BytesMapHelper") + .field("map", &self.map) + .field("else_index", &self.else_index) + .finish() + } +} + +impl WhenLiteralIndexMap + for BytesLikeIndexMap +{ + fn try_new( + unique_non_null_literals: Vec, + else_index: i32, + ) -> datafusion_common::Result + where + Self: Sized, + { + let input = ScalarValue::iter_to_array(unique_non_null_literals)?; + + // Literals are guaranteed to not contain nulls + if input.null_count() > 0 { + return internal_err!("Literal values for WHEN clauses cannot contain nulls"); + } + + let bytes_iter = Helper::array_to_iter(&input)?; + + let map: HashMap, i32> = bytes_iter + // Flattening Option<&[u8]> to &[u8] as literals cannot contain nulls + .flatten() + .enumerate() + .map(|(map_index, value): (usize, &[u8])| (value.to_vec(), map_index as i32)) + // Because literals are unique we can collect directly, and we can avoid only inserting the first occurrence + .collect(); + + Ok(Self { + map, + else_index, + _phantom_data: Default::default(), + }) + } + + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + let bytes_iter = Helper::array_to_iter(array)?; + let indices = bytes_iter + .map(|value| match value { + Some(value) => self.map.get(value).copied().unwrap_or(self.else_index), + None => self.else_index, + }) + .collect::>(); + + Ok(indices) + } +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs new file mode 100644 index 000000000000..b7e3d63954bf --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -0,0 +1,440 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod boolean_lookup_table; +mod bytes_like_lookup_table; +mod primitive_lookup_table; + +use crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap; +use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::{ + BytesDictionaryHelper, BytesLikeIndexMap, BytesViewDictionaryHelper, + FixedBinaryHelper, FixedBytesDictionaryHelper, GenericBytesHelper, + GenericBytesViewHelper, +}; +use crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveArrayMapHolder; +use crate::expressions::case::WhenThen; +use crate::expressions::Literal; +use arrow::array::{downcast_integer, downcast_primitive, ArrayRef, Int32Array}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, BinaryViewType, DataType, GenericBinaryType, + GenericStringType, StringViewType, +}; +use datafusion_common::DataFusionError; +use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use indexmap::IndexMap; +use std::fmt::Debug; +use std::sync::Arc; + +/// Optimization for CASE expressions with literal WHEN and THEN clauses +/// +/// for this form: +/// ```sql +/// CASE +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// ELSE +/// END +/// ``` +/// +/// # Improvement idea +/// TODO - we should think of unwrapping the `IN` expressions into multiple equality comparisons +/// so it will use this optimization as well, e.g. +/// ```sql +/// -- Before +/// CASE +/// WHEN ( = ) THEN +/// WHEN ( in (, ) THEN +/// WHEN ( = ) THEN +/// ELSE +/// +/// -- After +/// CASE +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// ELSE +/// END +/// ``` +/// +#[derive(Debug)] +pub(in super::super) struct LiteralLookupTable { + /// The lookup table to use for evaluating the CASE expression + lookup: Arc, + + /// [`ArrayRef`] where `array[i] = then_literals[i]` + /// the last value in the array is the else_expr + values_to_take_from: ArrayRef, +} + +impl LiteralLookupTable { + pub(in super::super) fn maybe_new( + when_then_expr: &Vec, + else_expr: &Option>, + ) -> Option { + // We can't use the optimization if we don't have any when then pairs + if when_then_expr.is_empty() { + return None; + } + + // If we only have 1 than this optimization is not useful + if when_then_expr.len() == 1 { + return None; + } + + // Try to downcast all the WHEN/THEN expressions to literals + let when_then_exprs_maybe_literals = when_then_expr + .iter() + .map(|(when, then)| { + let when_maybe_literal = when.as_any().downcast_ref::(); + let then_maybe_literal = then.as_any().downcast_ref::(); + + when_maybe_literal.zip(then_maybe_literal) + }) + .collect::>(); + + // If not all the WHEN/THEN expressions are literals we cannot use this optimization + if when_then_exprs_maybe_literals.contains(&None) { + return None; + } + + let when_then_exprs_scalars = when_then_exprs_maybe_literals + .into_iter() + // Unwrap the options as we have already checked there is no None + .flatten() + .map(|(when_lit, then_lit)| { + (when_lit.value().clone(), then_lit.value().clone()) + }) + // Only keep non-null WHEN literals + // as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE + .filter(|(when_lit, _)| !when_lit.is_null()) + .collect::>(); + + if when_then_exprs_scalars.is_empty() { + // All WHEN literals were nulls, so cannot use optimization + // + // instead, another optimization would be to go straight to the ELSE clause + return None; + } + + // Keep only the first occurrence of each when literal (as the first match is used) + // and remove nulls (as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE) + let (when_literals, then_literals): (Vec, Vec) = { + let mut map = IndexMap::with_capacity(when_then_expr.len()); + + for (when, then) in when_then_exprs_scalars.into_iter() { + // Don't overwrite existing entries as we want to keep the first occurrence + if !map.contains_key(&when) { + map.insert(when, then); + } + } + + map.into_iter().unzip() + }; + + let else_expr: ScalarValue = if let Some(else_expr) = else_expr { + let literal = else_expr.as_any().downcast_ref::()?; + + literal.value().clone() + } else { + let Ok(null_scalar) = + ScalarValue::try_new_null(&then_literals[0].data_type()) + else { + return None; + }; + + null_scalar + }; + + { + let data_type = when_literals[0].data_type(); + + // If not all the WHEN literals are the same data type we cannot use this optimization + if when_literals.iter().any(|l| l.data_type() != data_type) { + return None; + } + } + + { + let data_type = then_literals[0].data_type(); + + // If not all the then and the else literals are the same data type we cannot use this optimization + if then_literals.iter().any(|l| l.data_type() != data_type) { + return None; + } + + if else_expr.data_type() != data_type { + return None; + } + } + + let output_array = ScalarValue::iter_to_array( + then_literals + .iter() + // The else is in the end + .chain(std::iter::once(&else_expr)) + .cloned(), + ) + .ok()?; + + let lookup = try_creating_lookup_table( + when_literals, + // The else expression is in the end + output_array.len() as i32 - 1, + ) + .ok()?; + + Some(Self { + lookup, + values_to_take_from: output_array, + }) + } + + pub(in super::super) fn create_output( + &self, + expr_array: &ArrayRef, + ) -> datafusion_common::Result { + let take_indices = self.lookup.match_values(expr_array)?; + + // Zero-copy conversion + let take_indices = Int32Array::from(take_indices); + + // An optimize version would depend on the type of the values_to_take_from + // For example, if the type is view we can just keep pointing to the same value (similar to dictionary) + // if the type is dictionary we can just use the indices as is (or cast them to the key type) and create a new dictionary array + let output = arrow::compute::take(&self.values_to_take_from, &take_indices, None) + .map_err(|e| arrow_datafusion_err!(e))?; + + Ok(output) + } +} + +/// Lookup table for mapping literal values to their corresponding indices in the THEN clauses +/// +/// The else index is used when a value is not found in the lookup table +pub(super) trait WhenLiteralIndexMap: Debug + Send + Sync { + /// Try creating a new lookup table from the given literals and else index + /// + /// `literals` are guaranteed to be unique and non-nullable + fn try_new( + unique_non_null_literals: Vec, + else_index: i32, + ) -> datafusion_common::Result + where + Self: Sized; + + /// Return indices to take from the literals based on the values in the given array + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result>; +} + +pub(crate) fn try_creating_lookup_table( + unique_non_null_literals: Vec, + else_index: i32, +) -> datafusion_common::Result> { + assert_ne!( + unique_non_null_literals.len(), + 0, + "Must have at least one literal" + ); + match unique_non_null_literals[0].data_type() { + DataType::Boolean => { + let lookup_table = + BooleanIndexMap::try_new(unique_non_null_literals, else_index)?; + Ok(Arc::new(lookup_table)) + } + + data_type if data_type.is_primitive() => { + macro_rules! create_matching_map { + ($t:ty) => {{ + let lookup_table = PrimitiveArrayMapHolder::<$t>::try_new( + unique_non_null_literals, + else_index, + )?; + Ok(Arc::new(lookup_table)) + }}; + } + + downcast_primitive! { + data_type => (create_matching_map), + _ => Err(plan_datafusion_err!( + "Unsupported field type for primitive: {:?}", + data_type + )), + } + } + + DataType::Utf8 => { + let lookup_table = BytesLikeIndexMap::< + GenericBytesHelper>, + >::try_new( + unique_non_null_literals, else_index + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeUtf8 => { + let lookup_table = BytesLikeIndexMap::< + GenericBytesHelper>, + >::try_new( + unique_non_null_literals, else_index + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::Binary => { + let lookup_table = BytesLikeIndexMap::< + GenericBytesHelper>, + >::try_new( + unique_non_null_literals, else_index + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeBinary => { + let lookup_table = BytesLikeIndexMap::< + GenericBytesHelper>, + >::try_new( + unique_non_null_literals, else_index + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::FixedSizeBinary(_) => { + let lookup_table = BytesLikeIndexMap::::try_new( + unique_non_null_literals, + else_index, + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::Utf8View => { + let lookup_table = + BytesLikeIndexMap::>::try_new( + unique_non_null_literals, + else_index, + )?; + Ok(Arc::new(lookup_table)) + } + DataType::BinaryView => { + let lookup_table = + BytesLikeIndexMap::>::try_new( + unique_non_null_literals, + else_index, + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::Dictionary(key, value) => { + macro_rules! downcast_dictionary_array_helper { + ($t:ty) => {{ + create_lookup_table_for_dictionary_input::<$t>( + value.as_ref(), + unique_non_null_literals, + else_index, + ) + }}; + } + + downcast_integer! { + key.as_ref() => (downcast_dictionary_array_helper), + k => unreachable!("unsupported dictionary key type: {}", k) + } + } + _ => Err(plan_datafusion_err!( + "Unsupported data type for lookup table: {}", + unique_non_null_literals[0].data_type() + )), + } +} + +fn create_lookup_table_for_dictionary_input( + value: &DataType, + unique_non_null_literals: Vec, + else_index: i32, +) -> datafusion_common::Result> { + // TODO - optimize dictionary to use different wrapper that takes advantage of it being a dictionary + match value { + DataType::Utf8 => { + let lookup_table = BytesLikeIndexMap::< + BytesDictionaryHelper>, + >::try_new( + unique_non_null_literals, else_index + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeUtf8 => { + let lookup_table = BytesLikeIndexMap::< + BytesDictionaryHelper>, + >::try_new( + unique_non_null_literals, else_index + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::Binary => { + let lookup_table = BytesLikeIndexMap::< + BytesDictionaryHelper>, + >::try_new( + unique_non_null_literals, else_index + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::LargeBinary => { + let lookup_table = BytesLikeIndexMap::< + BytesDictionaryHelper>, + >::try_new( + unique_non_null_literals, else_index + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::FixedSizeBinary(_) => { + let lookup_table = + BytesLikeIndexMap::>::try_new( + unique_non_null_literals, + else_index, + )?; + Ok(Arc::new(lookup_table)) + } + + DataType::Utf8View => { + let lookup_table = BytesLikeIndexMap::< + BytesViewDictionaryHelper, + >::try_new( + unique_non_null_literals, else_index + )?; + Ok(Arc::new(lookup_table)) + } + DataType::BinaryView => { + let lookup_table = BytesLikeIndexMap::< + BytesViewDictionaryHelper, + >::try_new( + unique_non_null_literals, else_index + )?; + Ok(Arc::new(lookup_table)) + } + _ => Err(plan_datafusion_err!( + "Unsupported dictionary value type for lookup table: {}", + value + )), + } +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs new file mode 100644 index 000000000000..bc466b1a4dbe --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray}; +use arrow::datatypes::{i256, IntervalDayTime, IntervalMonthDayNano}; +use datafusion_common::{internal_err, HashMap, ScalarValue}; +use half::f16; +use std::fmt::Debug; +use std::hash::Hash; + +#[derive(Clone)] +pub(super) struct PrimitiveArrayMapHolder +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + /// Literal value to map index + /// + /// If searching this map becomes a bottleneck consider using linear map implementations for small hashmaps + map: HashMap<::HashableKey, i32>, + else_index: i32, +} + +impl Debug for PrimitiveArrayMapHolder +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PrimitiveArrayMapHolder") + .field("map", &self.map) + .field("else_index", &self.else_index) + .finish() + } +} + +impl WhenLiteralIndexMap for PrimitiveArrayMapHolder +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + fn try_new( + unique_non_null_literals: Vec, + else_index: i32, + ) -> datafusion_common::Result + where + Self: Sized, + { + let input = ScalarValue::iter_to_array(unique_non_null_literals)?; + + // Literals are guaranteed to not contain nulls + if input.null_count() > 0 { + return internal_err!("Literal values for WHEN clauses cannot contain nulls"); + } + + let map = input + .as_primitive::() + .values() + .iter() + .enumerate() + // Because literals are unique we can collect directly, and we can avoid only inserting the first occurrence + .map(|(map_index, value)| (value.into_hashable_key(), map_index as i32)) + .collect(); + + Ok(Self { map, else_index }) + } + + fn match_values(&self, array: &ArrayRef) -> datafusion_common::Result> { + let indices = array + .as_primitive::() + .into_iter() + .map(|value| match value { + Some(value) => self + .map + .get(&value.into_hashable_key()) + .copied() + .unwrap_or(self.else_index), + + None => self.else_index, + }) + .collect::>(); + + Ok(indices) + } +} + +// TODO - We need to port it to arrow so that it can be reused in other places + +/// Trait that help convert a value to a key that is hashable and equatable +/// This is needed as some types like f16/f32/f64 do not implement Hash/Eq directly +pub(super) trait ToHashableKey: ArrowNativeTypeOp { + /// The type that is hashable and equatable + /// It must be an Arrow native type but it NOT GUARANTEED to be the same as Self + /// this is just a helper trait so you can reuse the same code for all arrow native types + type HashableKey: Hash + Eq + Debug + Clone + Copy + Send + Sync; + + /// Converts self to a hashable key + /// the result of this value can be used as the key in hash maps/sets + fn into_hashable_key(self) -> Self::HashableKey; +} + +macro_rules! impl_to_hashable_key { + (@single_already_hashable | $t:ty) => { + impl ToHashableKey for $t { + type HashableKey = $t; + + #[inline] + fn into_hashable_key(self) -> Self::HashableKey { + self + } + } + }; + (@already_hashable | $($t:ty),+ $(,)?) => { + $( + impl_to_hashable_key!(@single_already_hashable | $t); + )+ + }; + (@float | $t:ty => $hashable:ty) => { + impl ToHashableKey for $t { + type HashableKey = $hashable; + + #[inline] + fn into_hashable_key(self) -> Self::HashableKey { + self.to_bits() + } + } + }; +} + +impl_to_hashable_key!(@already_hashable | i8, i16, i32, i64, i128, i256, u8, u16, u32, u64, IntervalDayTime, IntervalMonthDayNano); +impl_to_hashable_key!(@float | f16 => u16); +impl_to_hashable_key!(@float | f32 => u32); +impl_to_hashable_key!(@float | f64 => u64); + +#[cfg(test)] +mod tests { + use super::ToHashableKey; + use arrow::array::downcast_primitive; + + // This test ensure that all arrow primitive types implement ToHashableKey + // otherwise the code will not compile + #[test] + fn should_implement_to_hashable_key_for_all_primitives() { + #[derive(Debug, Default)] + struct ExampleSet + where + T: arrow::datatypes::ArrowPrimitiveType, + T::Native: ToHashableKey, + { + _map: std::collections::HashSet<::HashableKey>, + } + + macro_rules! create_matching_set { + ($t:ty) => {{ + let _lookup_table = ExampleSet::<$t> { + _map: Default::default(), + }; + + return; + }}; + } + + let data_type = arrow::datatypes::DataType::Float16; + + downcast_primitive! { + data_type => (create_matching_set), + _ => panic!("not implemented for {data_type}"), + } + } +} diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case/mod.rs similarity index 69% rename from datafusion/physical-expr/src/expressions/case.rs rename to datafusion/physical-expr/src/expressions/case/mod.rs index 0b4c3af1d9c5..037f9eb5034e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case/mod.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use super::{Column, Literal}; -use crate::expressions::case::ResultState::{Complete, Empty, Partial}; -use crate::expressions::try_cast; +mod literal_lookup_table; + +use crate::expressions::{try_cast, Column, Literal}; use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::zip::zip; @@ -31,14 +31,16 @@ use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr_common::datum::compare_with_eq; -use itertools::Itertools; use std::borrow::Cow; use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::{any::Any, sync::Arc}; -type WhenThen = (Arc, Arc); +use crate::expressions::case::literal_lookup_table::LiteralLookupTable; +use datafusion_physical_expr_common::datum::compare_with_eq; +use itertools::Itertools; + +pub(super) type WhenThen = (Arc, Arc); #[derive(Debug, Hash, PartialEq, Eq)] enum EvalMethod { @@ -69,8 +71,27 @@ enum EvalMethod { /// /// CASE WHEN condition THEN expression ELSE expression END ExpressionOrExpression, + + /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals + /// + /// See [`LiteralLookupTable`] for more details + WithExprScalarLookupTable(LiteralLookupTable), +} + +// Implement empty hash as the data is derived from PhysicalExprs which are already hashed +impl Hash for LiteralLookupTable { + fn hash(&self, _state: &mut H) {} } +// Implement always equal as the data is derived from PhysicalExprs which are already compared +impl PartialEq for LiteralLookupTable { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for LiteralLookupTable {} + /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. @@ -378,7 +399,7 @@ impl ResultBuilder { Self { data_type: data_type.clone(), row_count, - state: Empty, + state: ResultState::Empty, } } @@ -457,21 +478,21 @@ impl ResultBuilder { } match &mut self.state { - Empty => { + ResultState::Empty => { let array_index = PartialResultIndex::zero(); let mut indices = vec![PartialResultIndex::none(); self.row_count]; for row_ix in row_indices.as_primitive::().values().iter() { indices[*row_ix as usize] = array_index; } - self.state = Partial { + self.state = ResultState::Partial { arrays: vec![row_values], indices, }; Ok(()) } - Partial { arrays, indices } => { + ResultState::Partial { arrays, indices } => { let array_index = PartialResultIndex::try_new(arrays.len())?; arrays.push(row_values); @@ -489,7 +510,7 @@ impl ResultBuilder { } Ok(()) } - Complete(_) => internal_err!( + ResultState::Complete(_) => internal_err!( "Cannot add a partial result when complete result is already set" ), } @@ -502,23 +523,23 @@ impl ResultBuilder { /// without any merging overhead. fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> { match &self.state { - Empty => { - self.state = Complete(value); + ResultState::Empty => { + self.state = ResultState::Complete(value); Ok(()) } - Partial { .. } => { + ResultState::Partial { .. } => { internal_err!( "Cannot set a complete result when there are already partial results" ) } - Complete(_) => internal_err!("Complete result already set"), + ResultState::Complete(_) => internal_err!("Complete result already set"), } } /// Finishes building the result and returns the final array. fn finish(self) -> Result { match self.state { - Empty => { + ResultState::Empty => { // No complete result and no partial results. // This can happen for case expressions with no else branch where no rows // matched. @@ -526,11 +547,11 @@ impl ResultBuilder { &self.data_type, )?)) } - Partial { arrays, indices } => { + ResultState::Partial { arrays, indices } => { // Merge partial results into a single array. Ok(ColumnarValue::Array(merge(&arrays, &indices)?)) } - Complete(v) => { + ResultState::Complete(v) => { // If we have a complete result, we can just return it. Ok(v) } @@ -558,24 +579,8 @@ impl CaseExpr { if when_then_expr.is_empty() { exec_err!("There must be at least one WHEN clause") } else { - let eval_method = if expr.is_some() { - EvalMethod::WithExpression - } else if when_then_expr.len() == 1 - && is_cheap_and_infallible(&(when_then_expr[0].1)) - && else_expr.is_none() - { - EvalMethod::InfallibleExprOrNull - } else if when_then_expr.len() == 1 - && when_then_expr[0].1.as_any().is::() - && else_expr.is_some() - && else_expr.as_ref().unwrap().as_any().is::() - { - EvalMethod::ScalarOrScalar - } else if when_then_expr.len() == 1 && else_expr.is_some() { - EvalMethod::ExpressionOrExpression - } else { - EvalMethod::NoExpression - }; + let eval_method = + Self::find_best_eval_method(&expr, &when_then_expr, &else_expr); Ok(Self { expr, @@ -586,6 +591,39 @@ impl CaseExpr { } } + fn find_best_eval_method( + expr: &Option>, + when_then_expr: &Vec, + else_expr: &Option>, + ) -> EvalMethod { + if expr.is_some() { + if let Some(mapping) = + LiteralLookupTable::maybe_new(when_then_expr, else_expr) + { + return EvalMethod::WithExprScalarLookupTable(mapping); + } + + return EvalMethod::WithExpression; + } + + if when_then_expr.len() == 1 + && is_cheap_and_infallible(&(when_then_expr[0].1)) + && else_expr.is_none() + { + EvalMethod::InfallibleExprOrNull + } else if when_then_expr.len() == 1 + && when_then_expr[0].1.as_any().is::() + && else_expr.is_some() + && else_expr.as_ref().unwrap().as_any().is::() + { + EvalMethod::ScalarOrScalar + } else if when_then_expr.len() == 1 && else_expr.is_some() { + EvalMethod::ExpressionOrExpression + } else { + EvalMethod::NoExpression + } + } + /// Optional base expression that can be compared to literal values in the "when" expressions pub fn expr(&self) -> Option<&Arc> { self.expr.as_ref() @@ -967,6 +1005,28 @@ impl CaseExpr { Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) } + + fn with_lookup_table( + &self, + batch: &RecordBatch, + scalars_or_null_lookup: &LiteralLookupTable, + ) -> Result { + let expr = self.expr.as_ref().unwrap(); + let evaluated_expression = expr.evaluate(batch)?; + + let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_)); + let evaluated_expression = evaluated_expression.to_array(1)?; + + let output = scalars_or_null_lookup.create_output(&evaluated_expression)?; + + let result = if is_scalar { + ColumnarValue::Scalar(ScalarValue::try_from_array(output.as_ref(), 0)?) + } else { + ColumnarValue::Array(output) + }; + + Ok(result) + } } impl PhysicalExpr for CaseExpr { @@ -1031,6 +1091,9 @@ impl PhysicalExpr for CaseExpr { } EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch), + EvalMethod::WithExprScalarLookupTable(ref e) => { + self.with_lookup_table(batch, e) + } } } @@ -1123,13 +1186,14 @@ mod tests { use crate::expressions::{binary, cast, col, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; - use arrow::datatypes::Field; + use arrow::datatypes::{Field, Int32Type}; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::Operator; use datafusion_physical_expr_common::physical_expr::fmt_sql; + use half::f16; #[test] fn case_with_expr() -> Result<()> { @@ -1954,4 +2018,805 @@ mod tests { Ok(()) } + + // Test Lookup evaluation + + enum AssertLookupEvaluation { + Used, + NotUsed, + } + + fn test_case_when_literal_lookup( + values: ArrayRef, + lookup_map: &[(ScalarValue, ScalarValue)], + else_value: Option, + expected: ArrayRef, + assert_lookup_evaluation: AssertLookupEvaluation, + ) { + // Create lookup + // CASE + // WHEN THEN + // WHEN THEN + // [ ELSE ] + + let schema = Schema::new(vec![Field::new( + "a", + values.data_type().clone(), + values.is_nullable(), + )]); + let schema = Arc::new(schema); + + let batch = RecordBatch::try_new(schema, vec![values]) + .expect("failed to create RecordBatch"); + + let schema = batch.schema_ref(); + let case = col("a", schema).expect("failed to create col"); + + let when_then = lookup_map + .iter() + .map(|(when, then)| { + ( + Arc::new(Literal::new(when.clone())) as _, + Arc::new(Literal::new(then.clone())) as _, + ) + }) + .collect::>(); + + let else_expr = else_value.map(|else_value| { + Arc::new(Literal::new(else_value)) as Arc + }); + let expr = CaseExpr::try_new(Some(case), when_then, else_expr) + .expect("failed to create case"); + + // Assert that we are testing what we intend to assert + match assert_lookup_evaluation { + AssertLookupEvaluation::Used => { + assert!( + matches!(expr.eval_method, EvalMethod::WithExprScalarLookupTable(_)), + "we should use the expected eval method" + ); + } + AssertLookupEvaluation::NotUsed => { + assert!( + !matches!(expr.eval_method, EvalMethod::WithExprScalarLookupTable(_)), + "we should not use lookup evaluation method" + ); + } + } + + let actual = expr + .evaluate(&batch) + .expect("failed to evaluate case") + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + + assert_eq!( + actual.data_type(), + expected.data_type(), + "Data type mismatch" + ); + + assert_eq!( + actual.as_ref(), + expected.as_ref(), + "actual (left) does not match expected (right)" + ); + } + + fn create_lookup( + when_then_pairs: impl IntoIterator, + ) -> Vec<(ScalarValue, ScalarValue)> + where + ScalarValue: From, + ScalarValue: From, + { + when_then_pairs + .into_iter() + .map(|(when, then)| (ScalarValue::from(when), ScalarValue::from(then))) + .collect() + } + + fn create_input_and_expected( + input_and_expected_pairs: impl IntoIterator, + ) -> (Input, Expected) + where + Input: Array + From>, + Expected: Array + From>, + { + let (input_items, expected_items): (Vec, Vec) = + input_and_expected_pairs.into_iter().unzip(); + + (Input::from(input_items), Expected::from(expected_items)) + } + + fn test_lookup_eval_with_and_without_else( + lookup_map: &[(ScalarValue, ScalarValue)], + input_values: ArrayRef, + expected: StringArray, + ) { + // Testing without ELSE should fallback to None + test_case_when_literal_lookup( + Arc::clone(&input_values), + lookup_map, + None, + Arc::new(expected.clone()), + AssertLookupEvaluation::Used, + ); + + // Testing with Else + let else_value = "___fallback___"; + + // Changing each expected None to be fallback + let expected_with_else = expected + .iter() + .map(|item| item.unwrap_or(else_value)) + .map(Some) + .collect::(); + + // Test case + test_case_when_literal_lookup( + input_values, + lookup_map, + Some(ScalarValue::Utf8(Some(else_value.to_string()))), + Arc::new(expected_with_else), + AssertLookupEvaluation::Used, + ); + } + + #[test] + fn test_case_when_literal_lookup_int32_to_string() { + let lookup_map = create_lookup([ + (Some(4), Some("four")), + (Some(2), Some("two")), + (Some(3), Some("three")), + (Some(1), Some("one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, Some("one")), + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_none_case_should_never_match() { + let lookup_map = create_lookup([ + (Some(4), Some("four")), + (None, Some("none")), + (Some(2), Some("two")), + (Some(1), Some("one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (Some(1), Some("one")), + (Some(5), None), // No match in WHEN + (None, None), // None cases are never match in CASE WHEN syntax + (Some(2), Some("two")), + (None, None), // None cases are never match in CASE WHEN syntax + (None, None), // None cases are never match in CASE WHEN syntax + (Some(2), Some("two")), + (Some(5), None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_int32_to_string_with_duplicate_cases() { + let lookup_map = create_lookup([ + (Some(4), Some("four")), + (Some(4), Some("no 4")), + (Some(2), Some("two")), + (Some(2), Some("no 2")), + (Some(3), Some("three")), + (Some(3), Some("no 3")), + (Some(2), Some("no 2")), + (Some(4), Some("no 4")), + (Some(2), Some("no 2")), + (Some(3), Some("no 3")), + (Some(4), Some("no 4")), + (Some(2), Some("no 2")), + (Some(3), Some("no 3")), + (Some(3), Some("no 3")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, None), // No match in WHEN + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f32_to_string_with_special_values_and_duplicate_cases( + ) { + let lookup_map = create_lookup([ + (Some(4.0), Some("four point zero")), + (Some(f32::NAN), Some("NaN")), + (Some(3.2), Some("three point two")), + // Duplicate case to make sure it is not used + (Some(f32::NAN), Some("should not use this NaN branch")), + (Some(f32::INFINITY), Some("Infinity")), + (Some(0.0), Some("zero")), + // Duplicate case to make sure it is not used + ( + Some(f32::INFINITY), + Some("should not use this Infinity branch"), + ), + (Some(1.1), Some("one point one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1.1, Some("one point one")), + (f32::NAN, Some("NaN")), + (3.2, Some("three point two")), + (3.2, Some("three point two")), + (0.0, Some("zero")), + (f32::INFINITY, Some("Infinity")), + (3.2, Some("three point two")), + (f32::NEG_INFINITY, None), // No match in WHEN + (f32::NEG_INFINITY, None), // No match in WHEN + (3.2, Some("three point two")), + (-0.0, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f16_to_string_with_special_values() { + let lookup_map = create_lookup([ + ( + ScalarValue::Float16(Some(f16::from_f32(3.2))), + Some("3 dot 2"), + ), + (ScalarValue::Float16(Some(f16::NAN)), Some("NaN")), + ( + ScalarValue::Float16(Some(f16::from_f32(17.4))), + Some("17 dot 4"), + ), + (ScalarValue::Float16(Some(f16::INFINITY)), Some("Infinity")), + (ScalarValue::Float16(Some(f16::ZERO)), Some("zero")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (f16::from_f32(3.2), Some("3 dot 2")), + (f16::NAN, Some("NaN")), + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::INFINITY, Some("Infinity")), + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::NEG_INFINITY, None), // No match in WHEN + (f16::NEG_INFINITY, None), // No match in WHEN + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::NEG_ZERO, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f32_to_string_with_special_values() { + let lookup_map = create_lookup([ + (3.2, Some("3 dot 2")), + (f32::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (f32::INFINITY, Some("Infinity")), + (f32::ZERO, Some("zero")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (3.2, Some("3 dot 2")), + (f32::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (17.4, Some("17 dot 4")), + (f32::INFINITY, Some("Infinity")), + (17.4, Some("17 dot 4")), + (f32::NEG_INFINITY, None), // No match in WHEN + (f32::NEG_INFINITY, None), // No match in WHEN + (17.4, Some("17 dot 4")), + (-0.0, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f64_to_string_with_special_values() { + let lookup_map = create_lookup([ + (3.2, Some("3 dot 2")), + (f64::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (f64::INFINITY, Some("Infinity")), + (f64::ZERO, Some("zero")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (3.2, Some("3 dot 2")), + (f64::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (17.4, Some("17 dot 4")), + (f64::INFINITY, Some("Infinity")), + (17.4, Some("17 dot 4")), + (f64::NEG_INFINITY, None), // No match in WHEN + (f64::NEG_INFINITY, None), // No match in WHEN + (17.4, Some("17 dot 4")), + (-0.0, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + // Test that we don't lose the decimal precision and scale info + #[test] + fn test_decimal_with_non_default_precision_and_scale() { + let lookup_map = create_lookup([ + (ScalarValue::Decimal32(Some(4), 3, 2), Some("four")), + (ScalarValue::Decimal32(Some(2), 3, 2), Some("two")), + (ScalarValue::Decimal32(Some(3), 3, 2), Some("three")), + (ScalarValue::Decimal32(Some(1), 3, 2), Some("one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, Some("one")), + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + let input_values = input_values + .with_precision_and_scale(3, 2) + .expect("must be able to set precision and scale"); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + // Test that we don't lose the timezone info + #[test] + fn test_timestamp_with_non_default_timezone() { + let timezone: Option> = Some("-10:00".into()); + let lookup_map = create_lookup([ + ( + ScalarValue::TimestampMillisecond(Some(4), timezone.clone()), + Some("four"), + ), + ( + ScalarValue::TimestampMillisecond(Some(2), timezone.clone()), + Some("two"), + ), + ( + ScalarValue::TimestampMillisecond(Some(3), timezone.clone()), + Some("three"), + ), + ( + ScalarValue::TimestampMillisecond(Some(1), timezone.clone()), + Some("one"), + ), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, Some("one")), + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + let input_values = input_values.with_timezone_opt(timezone); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_with_strings_to_int32() { + let lookup_map = create_lookup([ + (Some("why"), Some(42)), + (Some("what"), Some(22)), + (Some("when"), Some(17)), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (Some("why"), Some(42)), + (Some("5"), None), // No match in WHEN + (None, None), // None cases are never match in CASE WHEN syntax + (Some("what"), Some(22)), + (None, None), // None cases are never match in CASE WHEN syntax + (None, None), // None cases are never match in CASE WHEN syntax + (Some("what"), Some(22)), + (Some("5"), None), // No match in WHEN + ]); + + let input_values = Arc::new(input_values) as ArrayRef; + + // Testing without ELSE should fallback to None + test_case_when_literal_lookup( + Arc::clone(&input_values), + &lookup_map, + None, + Arc::new(expected.clone()), + AssertLookupEvaluation::Used, + ); + + // Testing with Else + let else_value = 101; + + // Changing each expected None to be fallback + let expected_with_else = expected + .iter() + .map(|item| item.unwrap_or(else_value)) + .map(Some) + .collect::(); + + // Test case + test_case_when_literal_lookup( + input_values, + &lookup_map, + Some(ScalarValue::Int32(Some(else_value))), + Arc::new(expected_with_else), + AssertLookupEvaluation::Used, + ); + } + + #[test] + fn test_with_bytes_to_string() { + test_string_casted_to_string(DataType::Binary); + } + + #[test] + fn test_with_large_bytes_to_string() { + test_string_casted_to_string(DataType::LargeBinary); + } + + #[test] + fn test_with_fixed_size_bytes_to_string() { + test_fixed_binary_casted_to_string(DataType::FixedSizeBinary(3)); + } + + #[test] + fn test_with_string_view_to_string() { + test_string_casted_to_string(DataType::Utf8View); + } + + #[test] + fn test_with_binary_view_to_string() { + test_string_casted_to_string(DataType::BinaryView); + } + + fn test_string_casted_to_string(input_data_type: DataType) { + let mut lookup_map = create_lookup([ + (Some("why".to_string()), Some("one")), + (Some("what".to_string()), Some("two")), + (Some("when".to_string()), Some("three")), + ]); + + // Cast all when to the input data type + lookup_map.iter_mut().for_each(|(when, _)| { + *when = when + .cast_to(&input_data_type) + .expect("should be able to cast"); + }); + + let (input_values, expected) = + create_input_and_expected::([ + (Some("why"), Some("one")), + (Some("5"), None), // No match in WHEN + (None, None), // None cases are never match in CASE WHEN syntax + (Some("what"), Some("two")), + (None, None), // None cases are never match in CASE WHEN syntax + (None, None), // None cases are never match in CASE WHEN syntax + (Some("what"), Some("two")), + (Some("5"), None), // No match in WHEN + ]); + + let input_values = arrow::compute::cast(&input_values, &input_data_type) + .expect("should be able to cast"); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + fn test_fixed_binary_casted_to_string(input_data_type: DataType) { + let mut lookup_map = create_lookup([ + ( + ScalarValue::FixedSizeBinary(3, Some("why".as_bytes().to_vec())), + Some("one"), + ), + ( + ScalarValue::FixedSizeBinary(3, Some("dad".as_bytes().to_vec())), + Some("two"), + ), + ( + ScalarValue::FixedSizeBinary(3, Some("mom".as_bytes().to_vec())), + Some("three"), + ), + ]); + + // Cast all when to the input data type + lookup_map.iter_mut().for_each(|(when, _)| { + *when = when + .cast_to(&input_data_type) + .expect("should be able to cast"); + }); + + let (input_values, expected) = + create_input_and_expected::([ + (Some(b"why" as &[u8]), Some("one")), + (Some(b"555"), None), // No match in WHEN + (None, None), // None cases are never match in CASE WHEN syntax + (Some(b"dad"), Some("two")), + (None, None), // None cases are never match in CASE WHEN syntax + (None, None), // None cases are never match in CASE WHEN syntax + (Some(b"dad"), Some("two")), + (Some(b"555"), None), // No match in WHEN + ]); + + let input_values = arrow::compute::cast(&input_values, &input_data_type) + .expect("should be able to cast"); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_large_string_to_string() { + test_string_casted_to_string(DataType::LargeUtf8); + } + + #[test] + fn test_different_dictionary_keys_with_string_values_to_string() { + for key_type in [ + DataType::Int8, + DataType::UInt8, + DataType::Int16, + DataType::UInt16, + DataType::Int32, + DataType::UInt32, + DataType::Int64, + DataType::UInt64, + ] { + test_string_casted_to_string(DataType::Dictionary( + Box::new(key_type), + Box::new(DataType::Utf8), + )); + } + } + + #[test] + fn test_string_like_dictionary_to_string() { + for data_type in [DataType::Utf8, DataType::LargeUtf8] { + test_string_casted_to_string( + // test int + DataType::Dictionary(Box::new(DataType::Int32), Box::new(data_type)), + ); + } + } + + #[test] + fn test_binary_like_dictionary_to_string() { + for data_type in [ + DataType::Binary, + DataType::LargeBinary, + DataType::FixedSizeBinary(3), + ] { + test_fixed_binary_casted_to_string( + // test int + DataType::Dictionary(Box::new(DataType::Int32), Box::new(data_type)), + ); + } + } + + fn test_dictionary_view_value_to_string(input_data_type: DataType) { + /// Because casting From String to Dictionary or to Dictionary + /// we need to do manual casting + fn cast_to_dictionary_of_view( + array_to_cast: &dyn Array, + input_data_type: &DataType, + ) -> ArrayRef { + let string_dictionary = arrow::compute::cast( + array_to_cast, + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + ) + .expect("should be able to cast"); + + let string_dictionary = string_dictionary.as_dictionary::(); + let (keys, values) = string_dictionary.clone().into_parts(); + let dictionary_values_casted = arrow::compute::cast(&values, input_data_type) + .expect("should be able to cast"); + + let final_dictionary_array = + DictionaryArray::new(keys, dictionary_values_casted); + + Arc::new(final_dictionary_array) + } + + let mut lookup_map = create_lookup([ + (Some("why".to_string()), Some("one")), + (Some("what".to_string()), Some("two")), + (Some("when".to_string()), Some("three")), + ]); + + // Cast all when to the input data type + lookup_map.iter_mut().for_each(|(when, _)| { + // First cast to dictionary of string + let when_array = when + .to_array_of_size(1) + .expect("should be able to convert scalar to array"); + let casted_array = cast_to_dictionary_of_view(&when_array, &input_data_type); + *when = ScalarValue::try_from_array(&casted_array, 0) + .expect("should be able to convert array to scalar"); + }); + + let (input_values, expected) = + create_input_and_expected::([ + (Some("why"), Some("one")), + (Some("5"), None), // No match in WHEN + (None, None), // None cases are never match in CASE WHEN syntax + (Some("what"), Some("two")), + (None, None), // None cases are never match in CASE WHEN syntax + (None, None), // None cases are never match in CASE WHEN syntax + (Some("what"), Some("two")), + (Some("5"), None), // No match in WHEN + ]); + + let input_values = cast_to_dictionary_of_view(&input_values, &input_data_type); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_string_and_binary_view_dictionary_to_string() { + for data_type in [DataType::Utf8View, DataType::BinaryView] { + test_dictionary_view_value_to_string(data_type); + } + } + + #[test] + fn test_boolean_to_string_lookup_table() { + let lookup_map = create_lookup([ + (Some(true), Some("one")), + (Some(false), Some("two")), + (Some(true), Some("three")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (Some(true), Some("one")), + (None, None), + (Some(false), Some("two")), + (Some(true), Some("one")), + (Some(false), Some("two")), + (Some(false), Some("two")), + (None, None), + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_with_strings_to_int32_with_different_else_type() { + let lookup_map = create_lookup([ + (Some("why"), Some(42)), + (Some("what"), Some(22)), + (Some("when"), Some(17)), + ]); + + let else_value = 101; + + let (input_values, expected) = + create_input_and_expected::([ + (Some("why"), Some(42)), + (Some("5"), Some(else_value)), // No match in WHEN + (None, Some(else_value)), // None cases are never match in CASE WHEN syntax + (Some("what"), Some(22)), + (None, Some(else_value)), // None cases are never match in CASE WHEN syntax + (None, Some(else_value)), // None cases are never match in CASE WHEN syntax + (Some("what"), Some(22)), + (Some("5"), Some(else_value)), // No match in WHEN + ]); + + let input_values = Arc::new(input_values) as ArrayRef; + + // Test case + test_case_when_literal_lookup( + input_values, + &lookup_map, + Some(ScalarValue::Int8(Some(else_value as i8))), + Arc::new(expected), + // Assert not used as the else type is different than the then data types + AssertLookupEvaluation::NotUsed, + ); + } }