Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Validate/coerce types for comparisons within join_where predicates #21049

Merged
merged 10 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,14 +433,19 @@ impl Operator {
| Self::LtEq
| Self::Gt
| Self::GtEq
| Self::And
| Self::Or
| Self::Xor
| Self::EqValidity
| Self::NotEqValidity
)
}

/// Returns `true` if this operator is one of `And`, `Or` or `Xor`.
pub fn is_bitwise(&self) -> bool {
matches!(self, Self::And | Self::Or | Self::Xor)
}

/// Returns `true` if either `is_comparison()` or `is_bitwise()` holds for this operator.
pub fn is_comparison_or_bitwise(&self) -> bool {
self.is_comparison() || self.is_bitwise()
}

pub fn swap_operands(self) -> Self {
match self {
Operator::Eq => Operator::Eq,
Expand All @@ -465,6 +470,6 @@ impl Operator {
}

pub fn is_arithmetic(&self) -> bool {
!(self.is_comparison())
!(self.is_comparison_or_bitwise())
}
}
212 changes: 175 additions & 37 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,6 @@ pub fn resolve_join(
let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);

// Not a closure to avoid borrow issues because we mutate expr_arena as well.
macro_rules! get_dtype {
($expr:expr, $schema:expr) => {
ctxt.expr_arena
.get($expr.node())
.get_type($schema, Context::Default, ctxt.expr_arena)
};
}
// # Resolve scalars
//
// Scalars need to be expanded. We translate them to temporary columns added with
Expand Down Expand Up @@ -234,6 +226,15 @@ pub fn resolve_join(
(schema_left, schema_right)
};

// Not a closure to avoid borrow issues because we mutate expr_arena as well.
macro_rules! get_dtype {
($expr:expr, $schema:expr) => {
ctxt.expr_arena
.get($expr.node())
.get_type($schema, Context::Default, ctxt.expr_arena)
};
}

// # Cast lossless
//
// If we do a full join and keys are coalesced, the cast keys must be added up front.
Expand All @@ -249,15 +250,20 @@ pub fn resolve_join(
let rtype = get_dtype!(rnode, &schema_right)?;

if let Some(dtype) = get_numeric_upcast_supertype_lossless(&ltype, &rtype) {
// We use overflowing cast to allow better optimization as we are casting to a known
// lossless supertype.
//
// We have unique references to these nodes (they are created by this function),
// so we can mutate in-place without causing side effects somewhere else.
let casted_l = ctxt.expr_arena.add(AExpr::Cast {
expr: lnode.node(),
dtype: dtype.clone(),
options: CastOptions::Strict,
options: CastOptions::Overflowing,
});
let casted_r = ctxt.expr_arena.add(AExpr::Cast {
expr: rnode.node(),
dtype,
options: CastOptions::Strict,
options: CastOptions::Overflowing,
});

if key_cols_coalesced {
Expand Down Expand Up @@ -400,37 +406,12 @@ fn resolve_join_where(
let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt)
.map_err(|e| e.context(failed_here!(join left)))?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt
let schema_left = ctxt
.lp_arena
.get(input_right)
.get(input_left)
.schema(ctxt.lp_arena)
.into_owned();

for expr in &predicates {
fn all_in_schema(
schema: &Schema,
other: Option<&Schema>,
left: &Expr,
right: &Expr,
) -> bool {
let mut iter =
expr_to_leaf_column_names_iter(left).chain(expr_to_leaf_column_names_iter(right));
iter.all(|name| {
schema.contains(name.as_str()) && other.is_none_or(|s| !s.contains(name.as_str()))
})
}

let valid = expr.into_iter().all(|e| match e {
Expr::BinaryExpr { left, op, right } if op.is_comparison() => {
!(all_in_schema(&schema_left, None, left, right)
|| all_in_schema(&schema_right, Some(&schema_left), left, right))
},
_ => true,
});
polars_ensure!( valid, InvalidOperation: "'join_where' predicate only refers to columns from a single table")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check was refactored into `process_join_where_predicate()` below.

}

let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::Cross;

Expand All @@ -444,9 +425,31 @@ fn resolve_join_where(
ctxt,
)?;

let mut ae_nodes_stack = Vec::new();

let schema_merged = ctxt
.lp_arena
.get(last_node)
.schema(ctxt.lp_arena)
.into_owned();
let schema_merged = schema_merged.as_ref();

for e in predicates {
let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?;

debug_assert!(ae_nodes_stack.is_empty());
ae_nodes_stack.clear();
ae_nodes_stack.push(predicate.node());

process_join_where_predicate(
&mut ae_nodes_stack,
0,
schema_left.as_ref(),
schema_merged,
ctxt.expr_arena,
&mut ExprOrigin::None,
)?;

ctxt.conversion_optimizer
.push_scratch(predicate.node(), ctxt.expr_arena);

Expand All @@ -464,3 +467,138 @@ fn resolve_join_where(

Ok((last_node, join_node))
}

/// Performs validation and type-coercion on join_where predicates.
///
/// Validates for all comparison (or bitwise) expressions / subexpressions, that:
/// 1. They reference columns from both sides.
/// 2. The dtypes of the LHS and RHS match, or can be cast to a lossless
/// supertype (in which case the necessary casts are inserted).
///
/// We perform (1) by recursing whenever we encounter a comparison expression.
///
/// # Arguments
/// * `stack` - Work stack of expression nodes still to visit. Nodes are popped until the
/// stack length drops back to `prev_comparison_expr_stack_offset`.
/// * `prev_comparison_expr_stack_offset` - Stack length at which this invocation stops; the
/// entries below it belong to the caller.
/// * `schema_left` - Schema used to decide whether a column originates from the left side.
/// * `schema_merged` - Schema used to recognise remaining (right-side) columns and to resolve
/// the dtypes of both operands.
/// * `expr_arena` - Arena holding the expression nodes; operand nodes may be replaced in-place
/// by `Cast` nodes.
/// * `column_origins` - Accumulator of the origins seen so far; it is reset before each
/// comparison / bitwise subexpression is visited.
///
/// # Errors
/// * `ColumnNotFound` if a column is in neither `schema_left` nor `schema_merged`.
/// * `InvalidOperation` if a comparison / bitwise expression does not end up with
/// `ExprOrigin::Both`.
/// * `SchemaMismatch` if the operand dtypes differ and no cast was inserted.
/// * Any error raised while resolving an operand dtype, with added context.
fn process_join_where_predicate(
stack: &mut Vec<Node>,
prev_comparison_expr_stack_offset: usize,
schema_left: &Schema,
schema_merged: &Schema,
expr_arena: &mut Arena<AExpr>,
column_origins: &mut ExprOrigin,
) -> PolarsResult<()> {
// Only consume the entries pushed at or above the caller-provided offset.
while stack.len() > prev_comparison_expr_stack_offset {
let ae_node = stack.pop().unwrap();
// Cloned so that `expr_arena` can be mutated further down while `ae` is still in use.
let ae = expr_arena.get(ae_node).clone();

match ae {
AExpr::Column(ref name) => {
// `schema_left` is checked first; a name found only in `schema_merged` is
// classified as coming from the right side.
let origin = if schema_left.contains(name) {
ExprOrigin::Left
} else if schema_merged.contains(name) {
ExprOrigin::Right
} else {
polars_bail!(ColumnNotFound: "{}", name);
};

*column_origins |= origin;
},
// This is not actually Origin::Both, but we set this because the test suite expects
// this predicate to pass:
// * `pl.col("flag_right") == 1`
// Observe that it only has a column from one side because it is comparing to a literal.
AExpr::Literal(_) => *column_origins = ExprOrigin::Both,
AExpr::BinaryExpr {
left: left_node,
op,
right: right_node,
} if op.is_comparison_or_bitwise() => {
{
// The recursive call below drains exactly the two operands pushed here (and
// whatever they expand to), leaving the caller's entries untouched.
let new_stack_offset = stack.len();
stack.extend([right_node, left_node]);

// Reset `column_origins` to a `None` state. We will only have 2 possible return states from
// this point:
// * Ok(()), with column_origins @ ExprOrigin::Both
// * Err(_), in which case the value of column_origins doesn't matter.
*column_origins = ExprOrigin::None;

process_join_where_predicate(
stack,
new_stack_offset,
schema_left,
schema_merged,
expr_arena,
column_origins,
)?;

if *column_origins != ExprOrigin::Both {
polars_bail!(
InvalidOperation:
"'join_where' predicate only refers to columns from a single table: {}",
node_to_expr(ae_node, expr_arena),
)
}
}

// Fetch them again in case they were rewritten.
let left = expr_arena.get(left_node).clone();
let right = expr_arena.get(right_node).clone();

// Resolves an operand dtype against the merged schema, attaching the offending
// expression to the error on failure.
let resolve_dtype = |ae: &AExpr, node: Node| -> PolarsResult<DataType> {
ae.to_dtype(schema_merged, Context::Default, expr_arena)
.map_err(|e| {
e.context(
format!(
"could not resolve dtype of join_where predicate (expr: {})",
node_to_expr(node, expr_arena),
)
.into(),
)
})
};

let dtype_left = resolve_dtype(&left, left_node)?;
let dtype_right = resolve_dtype(&right, right_node)?;

// Note: We only upcast the sides if the expr output dtype is Boolean (i.e. `op` is
// a comparison), otherwise the output may change.

if let Some(dtype) =
get_numeric_upcast_supertype_lossless(&dtype_left, &dtype_right)
.filter(|_| op.is_comparison())
{
// We have unique references to these nodes (they are created by this function),
// so we can mutate in-place without causing side effects somewhere else.
// NOTE(review): the root node is created by the caller via
// `to_expr_ir_ignore_alias` rather than by this function itself - confirm that the
// uniqueness assumption above still holds for the operand nodes.
//
// The current operand is copied to a fresh arena slot and the original slot is
// overwritten with a `Cast` pointing at that copy, so the parent `BinaryExpr`
// keeps referring to the same node ids.
let expr = expr_arena.add(expr_arena.get(left_node).clone());
expr_arena.replace(
left_node,
AExpr::Cast {
expr,
dtype: dtype.clone(),
options: CastOptions::Overflowing,
},
);

let expr = expr_arena.add(expr_arena.get(right_node).clone());
expr_arena.replace(
right_node,
AExpr::Cast {
expr,
dtype,
options: CastOptions::Overflowing,
},
);
} else {
// Reached when no lossless supertype was found, or when `op` is bitwise (the
// `filter` above discards the supertype in that case): no cast is inserted, so
// the operand dtypes must already be equal.
polars_ensure!(
dtype_left == dtype_right,
SchemaMismatch:
"datatypes of join_where comparison don't match - {} on left does not match {} on right \
(expr: {})",
dtype_left, dtype_right, node_to_expr(ae_node, expr_arena),
)
}
},
// Any other expression: push its inputs onto the stack so they are visited too.
ae => ae.inputs_rev(stack),
}
}

Ok(())
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ macro_rules! unpack {
fn compares_cat_to_string(type_left: &DataType, type_right: &DataType, op: Operator) -> bool {
#[cfg(feature = "dtype-categorical")]
{
op.is_comparison()
op.is_comparison_or_bitwise()
&& matches_any_order!(
type_left,
type_right,
Expand Down Expand Up @@ -167,40 +167,40 @@ pub(super) fn process_binary(
match (&type_left, &type_right, op) {
#[cfg(not(feature = "dtype-categorical"))]
(DataType::String, dt, op) | (dt, DataType::String, op)
if op.is_comparison() && dt.is_primitive_numeric() =>
if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() =>
{
return Ok(None)
},
#[cfg(feature = "dtype-categorical")]
(String | Unknown(UnknownKind::Str) | Categorical(_, _), dt, op)
| (dt, Unknown(UnknownKind::Str) | String | Categorical(_, _), op)
if op.is_comparison() && dt.is_primitive_numeric() =>
if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() =>
{
return Ok(None)
},
#[cfg(feature = "dtype-categorical")]
(Unknown(UnknownKind::Str) | String | Enum(_, _), dt, op)
| (dt, Unknown(UnknownKind::Str) | String | Enum(_, _), op)
if op.is_comparison() && dt.is_primitive_numeric() =>
if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() =>
{
return Ok(None)
},
#[cfg(feature = "dtype-date")]
(Date, String | Unknown(UnknownKind::Str), op)
| (String | Unknown(UnknownKind::Str), Date, op)
if op.is_comparison() =>
if op.is_comparison_or_bitwise() =>
{
err_date_str_compare()?
},
#[cfg(feature = "dtype-datetime")]
(Datetime(_, _), String | Unknown(UnknownKind::Str), op)
| (String | Unknown(UnknownKind::Str), Datetime(_, _), op)
if op.is_comparison() =>
if op.is_comparison_or_bitwise() =>
{
err_date_str_compare()?
},
#[cfg(feature = "dtype-time")]
(Time | Unknown(UnknownKind::Str), String, op) if op.is_comparison() => {
(Time | Unknown(UnknownKind::Str), String, op) if op.is_comparison_or_bitwise() => {
err_date_str_compare()?
},
// structs can be arbitrarily nested, leave the complexity to the caller for now.
Expand Down
Loading
Loading