From 3693a97a7280723af3aa07443d682b21cb39fa85 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 5 Feb 2025 17:46:58 +1100 Subject: [PATCH 1/9] c --- crates/polars-plan/src/dsl/expr.rs | 13 +- .../polars-plan/src/plans/conversion/join.rs | 202 ++++++++++++++---- .../polars-plan/src/plans/conversion/mod.rs | 2 +- .../plans/conversion/type_coercion/binary.rs | 14 +- .../src/plans/optimizer/collapse_joins.rs | 60 +----- .../src/plans/optimizer/join_utils.rs | 73 +++++++ crates/polars-plan/src/plans/optimizer/mod.rs | 51 ++--- .../polars-plan/src/plans/python/pyarrow.rs | 2 +- py-polars/tests/unit/operations/test_join.py | 105 ++++++++- 9 files changed, 389 insertions(+), 133 deletions(-) diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index cf1b1af51b81..cfa3a7aecb0e 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -433,14 +433,19 @@ impl Operator { | Self::LtEq | Self::Gt | Self::GtEq - | Self::And - | Self::Or - | Self::Xor | Self::EqValidity | Self::NotEqValidity ) } + pub fn is_bitwise(&self) -> bool { + matches!(self, Self::And | Self::Or | Self::Xor) + } + + pub fn is_comparison_or_bitwise(&self) -> bool { + self.is_comparison() || self.is_bitwise() + } + pub fn swap_operands(self) -> Self { match self { Operator::Eq => Operator::Eq, @@ -465,6 +470,6 @@ impl Operator { } pub fn is_arithmetic(&self) -> bool { - !(self.is_comparison()) + !(self.is_comparison_or_bitwise()) } } diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 04ebb306f25c..056cbd8d41aa 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -141,14 +141,6 @@ pub fn resolve_join( let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena); - // Not a closure to avoid borrow issues because we mutate expr_arena as 
well. - macro_rules! get_dtype { - ($expr:expr, $schema:expr) => { - ctxt.expr_arena - .get($expr.node()) - .get_type($schema, Context::Default, ctxt.expr_arena) - }; - } // # Resolve scalars // // Scalars need to be expanded. We translate them to temporary columns added with @@ -234,6 +226,15 @@ pub fn resolve_join( (schema_left, schema_right) }; + // Not a closure to avoid borrow issues because we mutate expr_arena as well. + macro_rules! get_dtype { + ($expr:expr, $schema:expr) => { + ctxt.expr_arena + .get($expr.node()) + .get_type($schema, Context::Default, ctxt.expr_arena) + }; + } + // # Cast lossless // // If we do a full join and keys are coalesced, the cast keys must be added up front. @@ -249,15 +250,20 @@ pub fn resolve_join( let rtype = get_dtype!(rnode, &schema_right)?; if let Some(dtype) = get_numeric_upcast_supertype_lossless(&ltype, &rtype) { + // We use overflowing cast to allow better optimization as we are casting to a known + // lossless supertype. + // + // We have unique references to these nodes (they are created by this function), + // so we can mutate in-place without causing side effects somewhere else. 
let casted_l = ctxt.expr_arena.add(AExpr::Cast { expr: lnode.node(), dtype: dtype.clone(), - options: CastOptions::Strict, + options: CastOptions::Overflowing, }); let casted_r = ctxt.expr_arena.add(AExpr::Cast { expr: rnode.node(), dtype, - options: CastOptions::Strict, + options: CastOptions::Overflowing, }); if key_cols_coalesced { @@ -400,37 +406,12 @@ fn resolve_join_where( let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt) .map_err(|e| e.context(failed_here!(join left)))?; - let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); - let schema_right = ctxt + let schema_left = ctxt .lp_arena - .get(input_right) + .get(input_left) .schema(ctxt.lp_arena) .into_owned(); - for expr in &predicates { - fn all_in_schema( - schema: &Schema, - other: Option<&Schema>, - left: &Expr, - right: &Expr, - ) -> bool { - let mut iter = - expr_to_leaf_column_names_iter(left).chain(expr_to_leaf_column_names_iter(right)); - iter.all(|name| { - schema.contains(name.as_str()) && other.is_none_or(|s| !s.contains(name.as_str())) - }) - } - - let valid = expr.into_iter().all(|e| match e { - Expr::BinaryExpr { left, op, right } if op.is_comparison() => { - !(all_in_schema(&schema_left, None, left, right) - || all_in_schema(&schema_right, Some(&schema_left), left, right)) - }, - _ => true, - }); - polars_ensure!( valid, InvalidOperation: "'join_where' predicate only refers to columns from a single table") - } - let opts = Arc::make_mut(&mut options); opts.args.how = JoinType::Cross; @@ -444,9 +425,31 @@ fn resolve_join_where( ctxt, )?; + let mut ae_nodes_stack = Vec::new(); + + let schema_merged = ctxt + .lp_arena + .get(last_node) + .schema(ctxt.lp_arena) + .into_owned(); + let schema_merged = schema_merged.as_ref(); + for e in predicates { let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?; + debug_assert!(ae_nodes_stack.is_empty()); + ae_nodes_stack.clear(); + ae_nodes_stack.push(predicate.node()); + + process_join_where_predicate( + &mut 
ae_nodes_stack, + 0, + schema_left.as_ref(), + schema_merged, + ctxt.expr_arena, + &mut ExprOrigin::None, + )?; + ctxt.conversion_optimizer .push_scratch(predicate.node(), ctxt.expr_arena); @@ -464,3 +467,128 @@ fn resolve_join_where( Ok((last_node, join_node)) } + +/// Performs validation and type-coercion on join_where predicates. +/// +/// Validates for all comparison expressions / subexpressions, that: +/// 1. They reference columns from both sides. +/// 2. The dtypes of the LHS and RHS are match, or can be casted to a lossless +/// supertype (and inserts the necessary casting). +fn process_join_where_predicate( + stack: &mut Vec, + binary_expr_stack_offset: usize, + schema_left: &Schema, + schema_merged: &Schema, + expr_arena: &mut Arena, + column_origins: &mut ExprOrigin, +) -> PolarsResult<()> { + while stack.len() > binary_expr_stack_offset { + let ae_node = stack.pop().unwrap(); + let ae = expr_arena.get(ae_node).clone(); + + match ae { + AExpr::Column(ref name) => { + let origin = if schema_left.contains(name) { + ExprOrigin::Left + } else if schema_merged.contains(name) { + ExprOrigin::Right + } else { + polars_bail!(ColumnNotFound: "{}", name); + }; + + *column_origins |= origin; + }, + AExpr::BinaryExpr { + left: left_node, + op, + right: right_node, + } if op.is_comparison_or_bitwise() => { + { + let new_stack_offset = stack.len(); + stack.extend([right_node, left_node]); + + // Reset `column_origins` to a `None` state. We will only have 2 possible return states from + // this point: + // * Ok(()), with column_origins @ ExprOrigin::Both + // * Err(_), in which case the value of column_origins doesn't matter. 
+ *column_origins = ExprOrigin::None; + + process_join_where_predicate( + stack, + new_stack_offset, + schema_left, + schema_merged, + expr_arena, + column_origins, + )?; + + if *column_origins != ExprOrigin::Both { + polars_bail!( + InvalidOperation: + "'join_where' predicate only refers to columns from a single table: {}", + node_to_expr(ae_node, expr_arena), + ) + } + } + + // Fetch them again in case they were rewritten. + let left = expr_arena.get(left_node).clone(); + let right = expr_arena.get(right_node).clone(); + + let resolve_dtype = |ae: &AExpr, node: Node| -> PolarsResult { + ae.to_dtype(schema_merged, Context::Default, expr_arena) + .map_err(|e| { + e.context( + format!( + "could not resolve dtype of join_where predicate (expr: {})", + node_to_expr(node, expr_arena), + ) + .into(), + ) + }) + }; + + let dtype_left = resolve_dtype(&left, left_node)?; + let dtype_right = resolve_dtype(&right, right_node)?; + + if let Some(dtype) = + get_numeric_upcast_supertype_lossless(&dtype_left, &dtype_right) + .filter(|_| op.is_comparison()) + { + // We have unique references to these nodes (they are created by this function), + // so we can mutate in-place without causing side effects somewhere else. 
+ let expr = expr_arena.add(expr_arena.get(left_node).clone()); + expr_arena.replace( + left_node, + AExpr::Cast { + expr, + dtype: dtype.clone(), + options: CastOptions::Overflowing, + }, + ); + + let expr = expr_arena.add(expr_arena.get(right_node).clone()); + expr_arena.replace( + right_node, + AExpr::Cast { + expr, + dtype, + options: CastOptions::Overflowing, + }, + ); + } else { + polars_ensure!( + dtype_left == dtype_right, + SchemaMismatch: + "datatypes of join_where comparison don't match - {} on left does not match {} on right \ + (expr: {})", + dtype_left, dtype_right, node_to_expr(ae_node, expr_arena), + ) + } + }, + ae => ae.inputs_rev(stack), + } + } + + Ok(()) +} diff --git a/crates/polars-plan/src/plans/conversion/mod.rs b/crates/polars-plan/src/plans/conversion/mod.rs index 987dfc89cb37..3d7173ae9d6f 100644 --- a/crates/polars-plan/src/plans/conversion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/mod.rs @@ -10,7 +10,7 @@ mod ir_to_dsl; feature = "json" ))] mod scans; -mod stack_opt; +pub(crate) mod stack_opt; use std::borrow::Cow; use std::sync::{Arc, Mutex}; diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs index 12e9f236cebd..89f167cec953 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs @@ -16,7 +16,7 @@ macro_rules! 
unpack { fn compares_cat_to_string(type_left: &DataType, type_right: &DataType, op: Operator) -> bool { #[cfg(feature = "dtype-categorical")] { - op.is_comparison() + op.is_comparison_or_bitwise() && matches_any_order!( type_left, type_right, @@ -167,40 +167,40 @@ pub(super) fn process_binary( match (&type_left, &type_right, op) { #[cfg(not(feature = "dtype-categorical"))] (DataType::String, dt, op) | (dt, DataType::String, op) - if op.is_comparison() && dt.is_primitive_numeric() => + if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() => { return Ok(None) }, #[cfg(feature = "dtype-categorical")] (String | Unknown(UnknownKind::Str) | Categorical(_, _), dt, op) | (dt, Unknown(UnknownKind::Str) | String | Categorical(_, _), op) - if op.is_comparison() && dt.is_primitive_numeric() => + if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() => { return Ok(None) }, #[cfg(feature = "dtype-categorical")] (Unknown(UnknownKind::Str) | String | Enum(_, _), dt, op) | (dt, Unknown(UnknownKind::Str) | String | Enum(_, _), op) - if op.is_comparison() && dt.is_primitive_numeric() => + if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() => { return Ok(None) }, #[cfg(feature = "dtype-date")] (Date, String | Unknown(UnknownKind::Str), op) | (String | Unknown(UnknownKind::Str), Date, op) - if op.is_comparison() => + if op.is_comparison_or_bitwise() => { err_date_str_compare()? }, #[cfg(feature = "dtype-datetime")] (Datetime(_, _), String | Unknown(UnknownKind::Str), op) | (String | Unknown(UnknownKind::Str), Datetime(_, _), op) - if op.is_comparison() => + if op.is_comparison_or_bitwise() => { err_date_str_compare()? }, #[cfg(feature = "dtype-time")] - (Time | Unknown(UnknownKind::Str), String, op) if op.is_comparison() => { + (Time | Unknown(UnknownKind::Str), String, op) if op.is_comparison_or_bitwise() => { err_date_str_compare()? }, // structs can be arbitrarily nested, leave the complexity to the caller for now. 
diff --git a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs index de6418bbad08..5092e96b6ccc 100644 --- a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs +++ b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs @@ -12,60 +12,11 @@ use polars_ops::frame::{JoinCoalesce, JoinType, MaintainOrderJoin}; use polars_utils::arena::{Arena, Node}; use polars_utils::pl_str::PlSmallStr; -use super::{aexpr_to_leaf_names_iter, AExpr, JoinOptions, IR}; +use super::{aexpr_to_leaf_names_iter, AExpr, ExprOrigin, JoinOptions, IR}; use crate::dsl::{JoinTypeOptionsIR, Operator}; use crate::plans::visitor::{AexprNode, RewriteRecursion, RewritingVisitor, TreeWalker}; use crate::plans::{ExprIR, OutputName}; -/// Join origin of an expression -#[derive(Debug, Clone, Copy)] -enum ExprOrigin { - /// Utilizes no columns - None, - /// Utilizes columns from the left side of the join - Left, - /// Utilizes columns from the right side of the join - Right, - /// Utilizes columns from both sides of the join - Both, -} - -fn get_origin( - root: Node, - expr_arena: &Arena, - left_schema: &SchemaRef, - right_schema: &SchemaRef, - suffix: &str, -) -> ExprOrigin { - let mut expr_origin = ExprOrigin::None; - - for name in aexpr_to_leaf_names_iter(root, expr_arena) { - let in_left = left_schema.contains(name.as_str()); - let in_right = right_schema.contains(name.as_str()); - let has_suffix = name.as_str().ends_with(suffix); - let in_right = in_right - | (has_suffix && right_schema.contains(&name.as_str()[..name.len() - suffix.len()])); - - let name_origin = match (in_left, in_right, has_suffix) { - (true, false, _) | (true, true, false) => ExprOrigin::Left, - (false, true, _) | (true, true, true) => ExprOrigin::Right, - (false, false, _) => { - unreachable!("Invalid filter column should have been filtered before") - }, - }; - - use ExprOrigin as O; - expr_origin = match (expr_origin, name_origin) { - (O::None, other) | 
(other, O::None) => other, - (O::Left, O::Left) => O::Left, - (O::Right, O::Right) => O::Right, - _ => O::Both, - }; - } - - expr_origin -} - fn remove_suffix<'a>( exprs: &mut Vec, expr_arena: &mut Arena, @@ -279,7 +230,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &mut Arena, expr_arena: &mut ArenaIR conversion. + let join_ir = IR::Join { input_left, input_right, diff --git a/crates/polars-plan/src/plans/optimizer/join_utils.rs b/crates/polars-plan/src/plans/optimizer/join_utils.rs index 2fe1fa0c61b5..9a279dd4d6ed 100644 --- a/crates/polars-plan/src/plans/optimizer/join_utils.rs +++ b/crates/polars-plan/src/plans/optimizer/join_utils.rs @@ -1,3 +1,76 @@ +use polars_core::schema::*; +#[cfg(feature = "iejoin")] +use polars_utils::arena::{Arena, Node}; + +use super::{aexpr_to_leaf_names_iter, AExpr}; + +/// Join origin of an expression +#[derive(Debug, Clone, PartialEq, Copy)] +#[repr(u8)] +pub(crate) enum ExprOrigin { + // Note: There is a merge() function implemented on this enum that relies + // on this exact u8 repr layout. 
+ // + /// Utilizes no columns + None = 0b00, + /// Utilizes columns from the left side of the join + Left = 0b10, + /// Utilizes columns from the right side of the join + Right = 0b01, + /// Utilizes columns from both sides of the join + Both = 0b11, +} + +impl ExprOrigin { + pub(crate) fn get_expr_origin( + root: Node, + expr_arena: &Arena, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + suffix: &str, + ) -> ExprOrigin { + let mut expr_origin = ExprOrigin::None; + + for name in aexpr_to_leaf_names_iter(root, expr_arena) { + let in_left = left_schema.contains(name.as_str()); + let in_right = right_schema.contains(name.as_str()); + let has_suffix = name.as_str().ends_with(suffix); + let in_right = in_right + | (has_suffix + && right_schema.contains(&name.as_str()[..name.len() - suffix.len()])); + + let name_origin = match (in_left, in_right, has_suffix) { + (true, false, _) | (true, true, false) => ExprOrigin::Left, + (false, true, _) | (true, true, true) => ExprOrigin::Right, + (false, false, _) => { + unreachable!("Invalid filter column should have been filtered before") + }, + }; + + use ExprOrigin as O; + expr_origin = match (expr_origin, name_origin) { + (O::None, other) | (other, O::None) => other, + (O::Left, O::Left) => O::Left, + (O::Right, O::Right) => O::Right, + _ => O::Both, + }; + } + + expr_origin + } + + /// Logical OR with another [`ExprOrigin`] + fn merge(&mut self, other: Self) { + *self = unsafe { std::mem::transmute::(*self as u8 | other as u8) } + } +} + +impl std::ops::BitOrAssign for ExprOrigin { + fn bitor_assign(&mut self, rhs: Self) { + self.merge(rhs) + } +} + pub(super) fn split_suffix<'a>(name: &'a str, suffix: &str) -> &'a str { let (original, _) = name.split_at(name.len() - suffix.len()); original diff --git a/crates/polars-plan/src/plans/optimizer/mod.rs b/crates/polars-plan/src/plans/optimizer/mod.rs index 3da51ced4406..c8630781e066 100644 --- a/crates/polars-plan/src/plans/optimizer/mod.rs +++ 
b/crates/polars-plan/src/plans/optimizer/mod.rs @@ -16,6 +16,7 @@ mod flatten_union; #[cfg(feature = "fused")] mod fused; mod join_utils; +pub(crate) use join_utils::ExprOrigin; mod predicate_pushdown; mod projection_pushdown; mod set_order; @@ -61,7 +62,7 @@ pub(crate) fn init_hashmap(max_len: Option) -> PlHashMap { pub fn optimize( logical_plan: DslPlan, - mut opt_state: OptFlags, + mut opt_flags: OptFlags, lp_arena: &mut Arena, expr_arena: &mut Arena, scratch: &mut Vec, @@ -70,7 +71,7 @@ pub fn optimize( #[allow(dead_code)] let verbose = verbose(); - if opt_state.streaming() { + if opt_flags.streaming() { polars_warn!( Deprecation, "\ @@ -91,24 +92,24 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/ // This can be turned on again during ir-conversion. #[allow(clippy::eq_op)] #[cfg(feature = "cse")] - if opt_state.contains(OptFlags::EAGER) { - opt_state &= !(OptFlags::COMM_SUBEXPR_ELIM | OptFlags::COMM_SUBEXPR_ELIM); + if opt_flags.contains(OptFlags::EAGER) { + opt_flags &= !(OptFlags::COMM_SUBEXPR_ELIM | OptFlags::COMM_SUBEXPR_ELIM); } - let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena, &mut opt_state)?; + let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena, &mut opt_flags)?; // Don't run optimizations that don't make sense on a single node. // This keeps eager execution more snappy. 
#[cfg(feature = "cse")] - let comm_subplan_elim = opt_state.contains(OptFlags::COMM_SUBPLAN_ELIM); + let comm_subplan_elim = opt_flags.contains(OptFlags::COMM_SUBPLAN_ELIM); #[cfg(feature = "cse")] - let comm_subexpr_elim = opt_state.contains(OptFlags::COMM_SUBEXPR_ELIM); + let comm_subexpr_elim = opt_flags.contains(OptFlags::COMM_SUBEXPR_ELIM); #[cfg(not(feature = "cse"))] let comm_subexpr_elim = false; #[allow(unused_variables)] let agg_scan_projection = - opt_state.contains(OptFlags::FILE_CACHING) && !opt_state.streaming() && !opt_state.eager(); + opt_flags.contains(OptFlags::FILE_CACHING) && !opt_flags.streaming() && !opt_flags.eager(); // During debug we check if the optimizations have not modified the final schema. #[cfg(debug_assertions)] @@ -116,18 +117,18 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/ // Collect members for optimizations that need it. let mut members = MemberCollector::new(); - if !opt_state.eager() && (comm_subexpr_elim || opt_state.projection_pushdown()) { + if !opt_flags.eager() && (comm_subexpr_elim || opt_flags.projection_pushdown()) { members.collect(lp_top, lp_arena, expr_arena) } // Run before slice pushdown - if opt_state.contains(OptFlags::CHECK_ORDER_OBSERVE) + if opt_flags.contains(OptFlags::CHECK_ORDER_OBSERVE) && members.has_group_by | members.has_sort | members.has_distinct { set_order_flags(lp_top, lp_arena, expr_arena, scratch); } - if opt_state.simplify_expr() { + if opt_flags.simplify_expr() { #[cfg(feature = "fused")] rules.push(Box::new(fused::FusedArithmetic {})); } @@ -155,8 +156,8 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/ let _cse_plan_changed = false; // Should be run before predicate pushdown. 
- if opt_state.projection_pushdown() { - let mut projection_pushdown_opt = ProjectionPushDown::new(opt_state.new_streaming()); + if opt_flags.projection_pushdown() { + let mut projection_pushdown_opt = ProjectionPushDown::new(opt_flags.new_streaming()); let alp = lp_arena.take(lp_top); let alp = projection_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; lp_arena.replace(lp_top, alp); @@ -167,36 +168,36 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/ } } - if opt_state.predicate_pushdown() { + if opt_flags.predicate_pushdown() { let mut predicate_pushdown_opt = PredicatePushDown::new(expr_eval); let alp = lp_arena.take(lp_top); let alp = predicate_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; lp_arena.replace(lp_top, alp); } - if opt_state.cluster_with_columns() { + if opt_flags.cluster_with_columns() { cluster_with_columns::optimize(lp_top, lp_arena, expr_arena) } // Make sure it is after predicate pushdown - if opt_state.collapse_joins() && members.has_filter_with_join_input { - collapse_joins::optimize(lp_top, lp_arena, expr_arena) + if opt_flags.collapse_joins() && members.has_filter_with_join_input { + collapse_joins::optimize(lp_top, lp_arena, expr_arena); } // Make sure its before slice pushdown. 
- if opt_state.fast_projection() { + if opt_flags.fast_projection() { rules.push(Box::new(SimpleProjectionAndCollapse::new( - opt_state.eager(), + opt_flags.eager(), ))); } - if !opt_state.eager() { + if !opt_flags.eager() { rules.push(Box::new(DelayRechunk::new())); } - if opt_state.slice_pushdown() { + if opt_flags.slice_pushdown() { let mut slice_pushdown_opt = - SlicePushDown::new(opt_state.streaming(), opt_state.new_streaming()); + SlicePushDown::new(opt_flags.streaming(), opt_flags.new_streaming()); let alp = lp_arena.take(lp_top); let alp = slice_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; @@ -207,11 +208,11 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/ } // This optimization removes branches, so we must do it when type coercion // is completed. - if opt_state.simplify_expr() { + if opt_flags.simplify_expr() { rules.push(Box::new(SimplifyBooleanRule {})); } - if !opt_state.eager() { + if !opt_flags.eager() { rules.push(Box::new(FlattenUnionRule {})); } @@ -226,7 +227,7 @@ More information on the new streaming engine: https://github.com/pola-rs/polars/ scratch, expr_eval, verbose, - opt_state.new_streaming(), + opt_flags.new_streaming(), )?; } diff --git a/crates/polars-plan/src/plans/python/pyarrow.rs b/crates/polars-plan/src/plans/python/pyarrow.rs index 78fcc20cc453..b3920b066440 100644 --- a/crates/polars-plan/src/plans/python/pyarrow.rs +++ b/crates/polars-plan/src/plans/python/pyarrow.rs @@ -29,7 +29,7 @@ pub fn predicate_to_pa( ) -> Option { match expr_arena.get(predicate) { AExpr::BinaryExpr { left, right, op } => { - if op.is_comparison() { + if op.is_comparison_or_bitwise() { let left = predicate_to_pa(*left, expr_arena, args)?; let right = predicate_to_pa(*right, expr_arena, args)?; Some(format!("({left} {op} {right})")) diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index 8a2e9349d9fc..e2d2b913097a 100644 --- 
a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -1374,7 +1374,7 @@ def test_join_preserve_order_full() -> None: ], ) # fmt: skip @pytest.mark.parametrize("swap", [True, False]) -def test_join_numeric_type_upcast_15338( +def test_join_numeric_key_upcast_15338( dtypes: tuple[str, str, str], swap: bool ) -> None: supertype, ltype, rtype = (getattr(pl, x) for x in dtypes) @@ -1415,17 +1415,112 @@ def test_join_numeric_type_upcast_15338( pl.select(a=pl.Series([1, 1]).cast(ltype)), ) + # join_where + for no_optimization in [True, False]: + assert_frame_equal( + left.join_where(right, pl.col("a") == pl.col("a_right")).collect( + no_optimization=no_optimization + ), + pl.select( + a=pl.Series([1, 1]).cast(ltype), + a_right=pl.lit(1, dtype=rtype), + b=pl.Series(["A", "A"]), + ), + ) + -def test_join_numeric_type_upcast_forbid_float_int() -> None: +def test_join_numeric_key_upcast_forbid_float_int() -> None: ltype = pl.Float64 - rtype = pl.Int32 + rtype = pl.Int128 - left = pl.LazyFrame(schema={"a": ltype}) - right = pl.LazyFrame(schema={"a": rtype}) + left = pl.LazyFrame({"a": [1.0, 0.0]}, schema={"a": ltype}) + right = pl.LazyFrame({"a": [1, 2]}, schema={"a": rtype}) + + # Establish baseline: In a non-join context, comparisons between ltype and + # rtype succeed even if the upcast is lossy. 
+ assert_frame_equal( + left.with_columns(right.collect()["a"].alias("a_right")) + .select(pl.col("a") == pl.col("a_right")) + .collect(), + pl.DataFrame({"a": [True, False]}), + ) with pytest.raises(SchemaError, match="datatypes of join keys don't match"): left.join(right, on="a", how="left").collect() + for no_optimization in [True, False]: + with pytest.raises( + SchemaError, match="datatypes of join_where comparison don't match" + ): + left.join_where(right, pl.col("a") == pl.col("a_right")).collect( + no_optimization=no_optimization + ) + + with pytest.raises( + SchemaError, match="datatypes of join_where comparison don't match" + ): + left.join_where( + right, pl.col("a") == (pl.col("a") == pl.col("a_right")) + ).collect(no_optimization=no_optimization) + + +def test_join_numeric_key_upcast_order() -> None: + # E.g. when we are joining on this expression: + # * col('a') + 127 + # + # and we want to upcast, ensure that we upcast like this: + # * ( col('a') + 127 ) .cast() + # + # and *not* like this: + # * ( col('a').cast() + lit(127).cast() ) + # + # as otherwise the results would be different. 
+ + left = pl.select(pl.Series("a", [1], dtype=pl.Int8)).lazy() + right = pl.select( + pl.Series("a", [1, 128, -128], dtype=pl.Int64), b=pl.lit("A") + ).lazy() + + # col('a') in `left` is Int8, the result will overflow to become -128 + left_expr = pl.col("a") + 127 + + assert_frame_equal( + left.join(right, left_on=left_expr, right_on="a", how="inner").collect(), + pl.DataFrame( + { + "a": pl.Series([1], dtype=pl.Int8), + "a_right": pl.Series([-128], dtype=pl.Int64), + "b": "A", + } + ), + ) + + assert_frame_equal( + left.join_where(right, left_expr == pl.col("a_right")).collect(), + pl.DataFrame( + { + "a": pl.Series([1], dtype=pl.Int8), + "a_right": pl.Series([-128], dtype=pl.Int64), + "b": "A", + } + ), + ) + + assert_frame_equal( + ( + left.join(right, left_on=left_expr, right_on="a", how="full") + .collect() + .sort(pl.all()) + ), + pl.DataFrame( + { + "a": pl.Series([1, None, None], dtype=pl.Int8), + "a_right": pl.Series([-128, 1, 128], dtype=pl.Int64), + "b": ["A", "A", "A"], + } + ).sort(pl.all()), + ) + def test_no_collapse_join_when_maintain_order_20725() -> None: df1 = pl.LazyFrame({"Fraction_1": [0, 25, 50, 75, 100]}) From f97179d222a5eab7a750039f7662e0a614ab2b8b Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 5 Feb 2025 17:52:25 +1100 Subject: [PATCH 2/9] c --- crates/polars-plan/src/plans/conversion/join.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 056cbd8d41aa..c368ff63ad5e 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -474,6 +474,9 @@ fn resolve_join_where( /// 1. They reference columns from both sides. /// 2. The dtypes of the LHS and RHS are match, or can be casted to a lossless /// supertype (and inserts the necessary casting). 
+/// +/// This function can be understood as a general iterative type check / coercion +/// pass with a hint of recursion to validate column-references. fn process_join_where_predicate( stack: &mut Vec<Node>, binary_expr_stack_offset: usize, From f46a4f08d4462e94a74782e82c53cad91decbb57 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 5 Feb 2025 17:55:17 +1100 Subject: [PATCH 3/9] c --- crates/polars-plan/src/plans/optimizer/join_utils.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/polars-plan/src/plans/optimizer/join_utils.rs b/crates/polars-plan/src/plans/optimizer/join_utils.rs index 9a279dd4d6ed..20c78dfd7a2e 100644 --- a/crates/polars-plan/src/plans/optimizer/join_utils.rs +++ b/crates/polars-plan/src/plans/optimizer/join_utils.rs @@ -1,5 +1,4 @@ use polars_core::schema::*; -#[cfg(feature = "iejoin")] use polars_utils::arena::{Arena, Node}; use super::{aexpr_to_leaf_names_iter, AExpr}; From a66fdd08d185e9ab1890eabb0de3ee37c7c3cdec Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 5 Feb 2025 17:56:08 +1100 Subject: [PATCH 4/9] c --- crates/polars-plan/src/plans/conversion/join.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index c368ff63ad5e..054fae439cd5 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -479,13 +479,13 @@ fn resolve_join_where( /// pass with a hint of recursion to validate column-references. 
fn process_join_where_predicate( stack: &mut Vec<Node>, - binary_expr_stack_offset: usize, + comparison_expr_stack_offset: usize, schema_left: &Schema, schema_merged: &Schema, expr_arena: &mut Arena<AExpr>, column_origins: &mut ExprOrigin, ) -> PolarsResult<()> { - while stack.len() > binary_expr_stack_offset { + while stack.len() > comparison_expr_stack_offset { let ae_node = stack.pop().unwrap(); let ae = expr_arena.get(ae_node).clone(); From b5bac04268e5bf230e9e0e80620c6027cb7f0510 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 5 Feb 2025 17:57:10 +1100 Subject: [PATCH 5/9] c --- crates/polars-plan/src/plans/conversion/join.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 054fae439cd5..afab7428ab12 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -479,13 +479,13 @@ fn resolve_join_where( /// pass with a hint of recursion to validate column-references. 
fn process_join_where_predicate( stack: &mut Vec<Node>, - comparison_expr_stack_offset: usize, + prev_comparison_expr_stack_offset: usize, schema_left: &Schema, schema_merged: &Schema, expr_arena: &mut Arena<AExpr>, column_origins: &mut ExprOrigin, ) -> PolarsResult<()> { - while stack.len() > comparison_expr_stack_offset { + while stack.len() > prev_comparison_expr_stack_offset { let ae_node = stack.pop().unwrap(); let ae = expr_arena.get(ae_node).clone(); From 53fe6e5bf6ac42483ced7c63e559f6d03f2e6c8a Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 5 Feb 2025 18:01:16 +1100 Subject: [PATCH 6/9] c --- crates/polars-plan/src/plans/conversion/join.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index afab7428ab12..ccdd9daa2ad7 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -475,8 +475,7 @@ fn resolve_join_where( /// 2. The dtypes of the LHS and RHS are match, or can be casted to a lossless /// supertype (and inserts the necessary casting). /// -/// This function can be understood as a general iterative type check / coercion -/// pass with a hint of recursion to validate column-references. +/// We perform (1) by recursing whenever we encounter a comparison expression. 
fn process_join_where_predicate( stack: &mut Vec, prev_comparison_expr_stack_offset: usize, From 0ccf01be02cc08faba9a156f5e0068303b8546db Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 5 Feb 2025 18:23:32 +1100 Subject: [PATCH 7/9] c --- crates/polars-plan/src/plans/conversion/join.rs | 5 +++++ py-polars/tests/unit/operations/test_inequality_join.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index ccdd9daa2ad7..1ee6137ff8a3 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -500,6 +500,11 @@ fn process_join_where_predicate( *column_origins |= origin; }, + // This is not actually Origin::Both, but we set this because the test suite expects + // this predicate to pass: + // * `pl.col("flag_right") == 1` + // Observe that it only has a column from one side because it is comparing to a literal. + AExpr::Literal(_) => *column_origins = ExprOrigin::Both, AExpr::BinaryExpr { left: left_node, op, diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index ce8a2289817f..43ee653ae3d0 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -662,7 +662,7 @@ def test_join_where_literal_20061() -> None: assert df_left.join_where( df_right, pl.col("value_left") > pl.col("value_right"), - pl.col("flag_right").cast(pl.Int32) == 1, + pl.col("flag_right") == 1, ).sort("id").to_dict(as_series=False) == { "id": [1, 2, 3, 3], "value_left": [10, 20, 30, 30], From bbd7aa5d64dfe959e6c03042dc8d6a7d9726c016 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 5 Feb 2025 18:39:07 +1100 Subject: [PATCH 8/9] c --- py-polars/tests/unit/operations/test_inequality_join.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 43ee653ae3d0..3e2bf381ed89 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -662,7 +662,7 @@ def test_join_where_literal_20061() -> None: assert df_left.join_where( df_right, pl.col("value_left") > pl.col("value_right"), - pl.col("flag_right") == 1, + pl.col("flag_right") == pl.lit(1, dtype=pl.Int8), ).sort("id").to_dict(as_series=False) == { "id": [1, 2, 3, 3], "value_left": [10, 20, 30, 30], From 1453baa3d19ffe048966e5480d16219fd009d8d9 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 5 Feb 2025 20:47:22 +1100 Subject: [PATCH 9/9] undo pub mod and comment --- crates/polars-plan/src/plans/conversion/join.rs | 3 +++ crates/polars-plan/src/plans/conversion/mod.rs | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 1ee6137ff8a3..a7aa88058621 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -558,6 +558,9 @@ fn process_join_where_predicate( let dtype_left = resolve_dtype(&left, left_node)?; let dtype_right = resolve_dtype(&right, right_node)?; + // Note: We only upcast the sides if the expr output dtype is Boolean (i.e. `op` is + // a comparison), otherwise the output may change. 
+ if let Some(dtype) = get_numeric_upcast_supertype_lossless(&dtype_left, &dtype_right) .filter(|_| op.is_comparison()) diff --git a/crates/polars-plan/src/plans/conversion/mod.rs b/crates/polars-plan/src/plans/conversion/mod.rs index 3d7173ae9d6f..987dfc89cb37 100644 --- a/crates/polars-plan/src/plans/conversion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/mod.rs @@ -10,7 +10,7 @@ mod ir_to_dsl; feature = "json" ))] mod scans; -pub(crate) mod stack_opt; +mod stack_opt; use std::borrow::Cow; use std::sync::{Arc, Mutex};