Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,12 @@ impl ScalarValue {
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))),
DataType::Float32 => ScalarValue::Float32(Some(0.0)),
DataType::Float64 => ScalarValue::Float64(Some(0.0)),
DataType::Decimal32(precision, scale) => {
ScalarValue::Decimal32(Some(0), *precision, *scale)
}
DataType::Decimal64(precision, scale) => {
ScalarValue::Decimal64(Some(0), *precision, *scale)
}
DataType::Decimal128(precision, scale) => {
ScalarValue::Decimal128(Some(0), *precision, *scale)
}
Expand Down
122 changes: 110 additions & 12 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,16 @@ impl<'a> BinaryTypeCoercer<'a> {

// TODO Move the rest inside of BinaryTypeCoercer

fn is_decimal(data_type: &DataType) -> bool {
matches!(
data_type,
DataType::Decimal32(..)
| DataType::Decimal64(..)
| DataType::Decimal128(..)
| DataType::Decimal256(..)
)
}

/// Coercion rules for mathematics operators between decimal and non-decimal types.
fn math_decimal_coercion(
lhs_type: &DataType,
Expand Down Expand Up @@ -357,6 +367,15 @@ fn math_decimal_coercion(
| (Decimal256(_, _), Decimal256(_, _)) => {
Some((lhs_type.clone(), rhs_type.clone()))
}
// Cross-variant decimal coercion - choose larger variant with appropriate precision/scale
(lhs, rhs)
if is_decimal(lhs)
&& is_decimal(rhs)
&& std::mem::discriminant(lhs) != std::mem::discriminant(rhs) =>
{
let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?;
Some((coerced_type.clone(), coerced_type))
}
// Unlike with comparison we don't coerce to a decimal in the case of floating point
// numbers, falling back to floating point arithmetic instead
(
Expand Down Expand Up @@ -953,21 +972,92 @@ pub fn binary_numeric_coercion(
pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;

// Prefer decimal data type over floating point for comparison operation
match (lhs_type, rhs_type) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be cleaner like so:

/// Decimal coercion rules.
pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    // Prefer decimal data type over floating point for comparison operation
    match (lhs_type, rhs_type) {
        // Same decimal types
        (lhs_type, rhs_type)
            if std::mem::discriminant(lhs_type) == std::mem::discriminant(rhs_type) =>
        {
            get_wider_decimal_type(lhs_type, rhs_type)
        }
        // Mismatched decimal types
        (lhs_type, rhs_type)
            if is_decimal(lhs_type)
                && is_decimal(rhs_type)
                && std::mem::discriminant(lhs_type)
                    != std::mem::discriminant(rhs_type) =>
        {
            get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
        }
        // Decimal + non-decimal types
        (Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), _)
        | (_, Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _)) => {
            get_common_decimal_type(lhs_type, rhs_type)
        }
        (_, _) => None,
    }
}

Following what was done above

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops forgot the is_decimal() checks for the first branch

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I've added them locally :) should have something soon

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done as part of 4145a04

// Prefer decimal data type over floating point for comparison operation
(Decimal128(_, _), Decimal128(_, _)) => {
// Same decimal types
(lhs_type, rhs_type)
if is_decimal(lhs_type)
&& is_decimal(rhs_type)
&& std::mem::discriminant(lhs_type)
== std::mem::discriminant(rhs_type) =>
{
get_wider_decimal_type(lhs_type, rhs_type)
}
(Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
(_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
(Decimal256(_, _), Decimal256(_, _)) => {
get_wider_decimal_type(lhs_type, rhs_type)
// Mismatched decimal types
(lhs_type, rhs_type)
if is_decimal(lhs_type)
&& is_decimal(rhs_type)
&& std::mem::discriminant(lhs_type)
!= std::mem::discriminant(rhs_type) =>
{
get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
}
// Decimal + non-decimal types
(Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), _) => {
get_common_decimal_type(lhs_type, rhs_type)
}
(_, Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _)) => {
get_common_decimal_type(rhs_type, lhs_type)
}
(Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
(_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
(_, _) => None,
}
}
/// Handle cross-variant decimal widening by choosing the larger variant
fn get_wider_decimal_type_cross_variant(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;

let (p1, s1) = match lhs_type {
Decimal32(p, s) => (*p, *s),
Decimal64(p, s) => (*p, *s),
Decimal128(p, s) => (*p, *s),
Decimal256(p, s) => (*p, *s),
_ => return None,
};

let (p2, s2) = match rhs_type {
Decimal32(p, s) => (*p, *s),
Decimal64(p, s) => (*p, *s),
Decimal128(p, s) => (*p, *s),
Decimal256(p, s) => (*p, *s),
_ => return None,
};

// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
let s = s1.max(s2);
let range = (p1 as i8 - s1).max(p2 as i8 - s2);
let required_precision = (range + s) as u8;
Comment on lines +1029 to +1031
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if we have:

Decimal256 with precision 76 (max) and scale 0, and Decimal128 with precision 38 (max) with scale 1;

So s = 1, range = 76, required_precision = 76 + 1 -> overflow?

Is this a valid case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think an overflow is valid, I'll have to think about it and maybe look into solutions in other systems.
We can also just return None in that case, which should force the user to add an explicit cast to one side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked around a bit, and what I could find is:

  1. DataFusion already has multiple issues regarding cast overflow/precision loss ("decimal calculate overflow but not throw error" #16406, "Datafusion downcasts decimal loosing precision" #13492), which I'm happy to take on but which are unrelated here.
  2. Spark (which seems to be the main inspiration for this code) has a configuration to control how it handles these cases (here and here).

I'm not sure what the desired behavior regarding precision loss is (should it be configurable? Is there currently an accepted desired behavior?). I think for this PR it should be fine to just return None if the precision overflows, and to take the bigger conversation into an issue where people can weigh in; I'll be glad to take that forward. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think returning None in cases like this for this PR is fine 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done as part of fd1f043

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cheers; left another minor comment related to the check below. Also would be nice if we had a test for this edge case.


// Choose the larger variant between the two input types, while making sure we don't overflow the precision.
match (lhs_type, rhs_type) {
(Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _))
if required_precision <= DECIMAL64_MAX_PRECISION =>
{
Some(Decimal64(required_precision, s))
}
(Decimal32(_, _), Decimal128(_, _))
| (Decimal128(_, _), Decimal32(_, _))
| (Decimal64(_, _), Decimal128(_, _))
| (Decimal128(_, _), Decimal64(_, _))
if required_precision <= DECIMAL128_MAX_PRECISION =>
{
Some(Decimal128(required_precision, s))
}
(Decimal32(_, _), Decimal256(_, _))
| (Decimal256(_, _), Decimal32(_, _))
| (Decimal64(_, _), Decimal256(_, _))
| (Decimal256(_, _), Decimal64(_, _))
| (Decimal128(_, _), Decimal256(_, _))
| (Decimal256(_, _), Decimal128(_, _))
if required_precision <= DECIMAL256_MAX_PRECISION =>
{
Some(Decimal256(required_precision, s))
}
_ => None,
}
}

/// Coerce `lhs_type` and `rhs_type` to a common type.
fn get_common_decimal_type(
Expand All @@ -976,7 +1066,15 @@ fn get_common_decimal_type(
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match decimal_type {
Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) => {
Decimal32(_, _) => {
let other_decimal_type = coerce_numeric_type_to_decimal32(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
Decimal64(_, _) => {
let other_decimal_type = coerce_numeric_type_to_decimal64(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
Decimal128(_, _) => {
let other_decimal_type = coerce_numeric_type_to_decimal128(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
Expand All @@ -988,7 +1086,7 @@ fn get_common_decimal_type(
}
}

/// Returns a `DataType::Decimal128` that can store any value from either
/// Returns a decimal [`DataType`] variant that can store any value from either
/// `lhs_decimal_type` or `rhs_decimal_type`
///
/// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`.
Expand Down Expand Up @@ -1209,14 +1307,14 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTy
}

fn create_decimal32_type(precision: u8, scale: i8) -> DataType {
DataType::Decimal128(
DataType::Decimal32(
DECIMAL32_MAX_PRECISION.min(precision),
DECIMAL32_MAX_SCALE.min(scale),
)
}

fn create_decimal64_type(precision: u8, scale: i8) -> DataType {
DataType::Decimal128(
DataType::Decimal64(
DECIMAL64_MAX_PRECISION.min(precision),
DECIMAL64_MAX_SCALE.min(scale),
)
Expand Down
130 changes: 130 additions & 0 deletions datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,133 @@ fn test_coercion_arithmetic_decimal() -> Result<()> {

Ok(())
}

/// Checks math coercion between two *different* decimal variants: both
/// operands are expected to be coerced to the same type, in either operand
/// order.
#[test]
fn test_coercion_arithmetic_decimal_cross_variant() -> Result<()> {
    // Each entry is `(one operand, other operand, common type)`; the common
    // type is the one expected on both sides after coercion.
    let pairs = [
        (
            DataType::Decimal32(5, 2),
            DataType::Decimal64(10, 3),
            DataType::Decimal64(10, 3),
        ),
        (
            DataType::Decimal32(7, 1),
            DataType::Decimal128(15, 4),
            DataType::Decimal128(15, 4),
        ),
        (
            DataType::Decimal32(9, 0),
            DataType::Decimal256(20, 5),
            DataType::Decimal256(20, 5),
        ),
        (
            DataType::Decimal64(12, 3),
            DataType::Decimal128(18, 2),
            DataType::Decimal128(19, 3),
        ),
        (
            DataType::Decimal64(15, 4),
            DataType::Decimal256(25, 6),
            DataType::Decimal256(25, 6),
        ),
        (
            DataType::Decimal128(20, 5),
            DataType::Decimal256(30, 8),
            DataType::Decimal256(30, 8),
        ),
    ];

    // First pass: operands in the order listed above.
    for (first, second, common) in &pairs {
        test_math_decimal_coercion_rule(
            first.clone(),
            second.clone(),
            common.clone(),
            common.clone(),
        );
    }

    // Second pass: the same pairs with the operands swapped; the expected
    // common type does not depend on operand order.
    for (first, second, common) in &pairs {
        test_math_decimal_coercion_rule(
            second.clone(),
            first.clone(),
            common.clone(),
            common.clone(),
        );
    }

    Ok(())
}

/// Checks that cross-variant widening yields `None` when the required
/// precision, `max(s1, s2) + max(p1 - s1, p2 - s2)`, does not fit in the
/// wider of the two variants, and `Some` when it does.
#[test]
fn test_decimal_precision_overflow_cross_variant() -> Result<()> {
    let overflowing = [
        // s = max(0, 1) = 1, range = max(76-0, 38-1) = 76, required_precision = 77 (overflow > 76)
        (DataType::Decimal256(76, 0), DataType::Decimal128(38, 1)),
        // s = max(0, 10) = 10, range = max(9-0, 18-10) = 9, required_precision = 19 (overflow > 18)
        (DataType::Decimal32(9, 0), DataType::Decimal64(18, 10)),
        // s = max(5, 26) = 26, range = max(18-5, 38-26) = 13, required_precision = 39 (overflow > 38)
        (DataType::Decimal64(18, 5), DataType::Decimal128(38, 26)),
        // s = max(10, 49) = 49, range = max(38-10, 76-49) = 28, required_precision = 77 (overflow > 76)
        (DataType::Decimal128(38, 10), DataType::Decimal256(76, 49)),
    ];

    for (lhs, rhs) in &overflowing {
        let result = get_wider_decimal_type_cross_variant(lhs, rhs);
        assert!(result.is_none());
    }

    // s = max(2, 3) = 3, range = max(5-2, 10-3) = 7, required_precision = 10 (valid <= 18)
    let result = get_wider_decimal_type_cross_variant(
        &DataType::Decimal32(5, 2),
        &DataType::Decimal64(10, 3),
    );
    assert!(result.is_some());

    Ok(())
}
Loading