From c01c3356926407646c9046a7aa6f8d4d65919722 Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Thu, 27 Nov 2025 21:36:56 +0800 Subject: [PATCH 1/6] Support simplify not for physical expr --- datafusion/execution/src/config.rs | 4 +- .../physical-expr/src/simplifier/mod.rs | 15 +- .../physical-expr/src/simplifier/not.rs | 379 ++++++++++++++++++ 3 files changed, 394 insertions(+), 4 deletions(-) create mode 100644 datafusion/physical-expr/src/simplifier/not.rs diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 443229a3cb77..3fa602f12554 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -114,10 +114,10 @@ impl Default for SessionConfig { } /// A type map for storing extensions. -/// +/// /// Extensions are indexed by their type `T`. If multiple values of the same type are provided, only the last one /// will be kept. -/// +/// /// Extensions are opaque objects that are unknown to DataFusion itself but can be downcast by optimizer rules, /// execution plans, or other components that have access to the session config. /// They provide a flexible way to attach extra data or behavior to the session config. 
diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 80d6ee0a7b91..4a3ed6a1299e 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -24,8 +24,9 @@ use datafusion_common::{ }; use std::sync::Arc; -use crate::PhysicalExpr; +use crate::{simplifier::not::simplify_not_expr_recursive, PhysicalExpr}; +pub mod not; pub mod unwrap_cast; /// Simplifies physical expressions by applying various optimizations @@ -56,6 +57,11 @@ impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> { type Node = Arc; fn f_up(&mut self, node: Self::Node) -> Result> { + // Apply NOT expression simplification first + let not_simplified = simplify_not_expr_recursive(&node, self.schema)?; + let node = not_simplified.data; + let transformed = not_simplified.transformed; + // Apply unwrap cast optimization #[cfg(test)] let original_type = node.data_type(self.schema).unwrap(); @@ -66,7 +72,12 @@ impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> { original_type, "Simplified expression should have the same data type as the original" ); - Ok(unwrapped) + // Combine transformation results + let final_transformed = transformed || unwrapped.transformed; + Ok(Transformed::new_transformed( + unwrapped.data, + final_transformed, + )) } } diff --git a/datafusion/physical-expr/src/simplifier/not.rs b/datafusion/physical-expr/src/simplifier/not.rs new file mode 100644 index 000000000000..7f448d8cc89b --- /dev/null +++ b/datafusion/physical-expr/src/simplifier/not.rs @@ -0,0 +1,379 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Simplify NOT expressions in physical expressions +//! +//! This module provides optimizations for NOT expressions such as: +//! - Double negation elimination: NOT(NOT(expr)) -> expr +//! - NOT with binary comparisons: NOT(a = b) -> a != b +//! - NOT with IN expressions: NOT(a IN (list)) -> a NOT IN (list) +//! - De Morgan's laws: NOT(A AND B) -> NOT A OR NOT B +//! - Constant folding: NOT(TRUE) -> FALSE, NOT(FALSE) -> TRUE + +use std::sync::Arc; + +use arrow::datatypes::Schema; +use datafusion_common::{tree_node::Transformed, Result, ScalarValue}; +use datafusion_expr::Operator; + +use crate::expressions::{lit, BinaryExpr, Literal, NotExpr}; +use crate::PhysicalExpr; + +/// Attempts to simplify NOT expressions +pub(crate) fn simplify_not_expr( + expr: Arc, + schema: &Schema, +) -> Result>> { + // Check if this is a NOT expression + let not_expr = match expr.as_any().downcast_ref::() { + Some(not_expr) => not_expr, + None => return Ok(Transformed::no(expr)), + }; + + let inner_expr = not_expr.arg(); + + // Handle NOT(NOT(expr)) -> expr (double negation elimination) + if let Some(inner_not) = inner_expr.as_any().downcast_ref::() { + // Recursively simplify the inner expression + let simplified = simplify_not_expr_recursive(inner_not.arg(), schema)?; + // We eliminated double negation, so always return transformed=true + return Ok(Transformed::yes(simplified.data)); + } + + // Handle NOT(literal) -> !literal + if let Some(literal) = inner_expr.as_any().downcast_ref::() { + if let ScalarValue::Boolean(Some(val)) = literal.value() { + return 
Ok(Transformed::yes(lit(ScalarValue::Boolean(Some(!val))))); + } + if let ScalarValue::Boolean(None) = literal.value() { + return Ok(Transformed::yes(lit(ScalarValue::Boolean(None)))); + } + } + + // Handle NOT(binary_expr) where we can flip the operator + if let Some(binary_expr) = inner_expr.as_any().downcast_ref::() { + if let Some(negated_op) = negate_operator(binary_expr.op()) { + // Recursively simplify the left and right expressions first + let left_simplified = + simplify_not_expr_recursive(binary_expr.left(), schema)?; + let right_simplified = + simplify_not_expr_recursive(binary_expr.right(), schema)?; + + let new_binary = Arc::new(BinaryExpr::new( + left_simplified.data, + negated_op, + right_simplified.data, + )); + // We flipped the operator, so always return transformed=true + return Ok(Transformed::yes(new_binary)); + } + + // Handle De Morgan's laws for AND/OR + match binary_expr.op() { + Operator::And => { + // NOT(A AND B) -> NOT A OR NOT B + let not_left: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.left()))); + let not_right: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.right()))); + + // Recursively simplify the NOT expressions + let simplified_left = simplify_not_expr_recursive(¬_left, schema)?; + let simplified_right = simplify_not_expr_recursive(¬_right, schema)?; + + let new_binary = Arc::new(BinaryExpr::new( + simplified_left.data, + Operator::Or, + simplified_right.data, + )); + return Ok(Transformed::yes(new_binary)); + } + Operator::Or => { + // NOT(A OR B) -> NOT A AND NOT B + let not_left: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.left()))); + let not_right: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.right()))); + + // Recursively simplify the NOT expressions + let simplified_left = simplify_not_expr_recursive(¬_left, schema)?; + let simplified_right = simplify_not_expr_recursive(¬_right, schema)?; + + let new_binary = Arc::new(BinaryExpr::new( + simplified_left.data, + Operator::And, + 
simplified_right.data, + )); + return Ok(Transformed::yes(new_binary)); + } + _ => {} + } + } + + // If no simplification possible, return the original expression + Ok(Transformed::no(expr)) +} + +/// Helper function that recursively simplifies expressions, including NOT expressions +pub fn simplify_not_expr_recursive( + expr: &Arc, + schema: &Schema, +) -> Result>> { + // First, try to simplify any NOT expressions in this expression + let not_simplified = simplify_not_expr(Arc::clone(expr), schema)?; + + // If the expression was transformed, we might have created new opportunities for simplification + if not_simplified.transformed { + // Recursively simplify the result + let further_simplified = + simplify_not_expr_recursive(¬_simplified.data, schema)?; + if further_simplified.transformed { + return Ok(Transformed::yes(further_simplified.data)); + } else { + return Ok(not_simplified); + } + } + + // If this expression wasn't a NOT expression, try to simplify its children + // This handles cases where NOT expressions might be nested deeper in the tree + if let Some(binary_expr) = expr.as_any().downcast_ref::() { + let left_simplified = simplify_not_expr_recursive(binary_expr.left(), schema)?; + let right_simplified = simplify_not_expr_recursive(binary_expr.right(), schema)?; + + if left_simplified.transformed || right_simplified.transformed { + let new_binary = Arc::new(BinaryExpr::new( + left_simplified.data, + *binary_expr.op(), + right_simplified.data, + )); + return Ok(Transformed::yes(new_binary)); + } + } + + Ok(not_simplified) +} + +/// Returns the negated version of a comparison operator, if possible +fn negate_operator(op: &Operator) -> Option { + match op { + Operator::Eq => Some(Operator::NotEq), + Operator::NotEq => Some(Operator::Eq), + Operator::Lt => Some(Operator::GtEq), + Operator::LtEq => Some(Operator::Gt), + Operator::Gt => Some(Operator::LtEq), + Operator::GtEq => Some(Operator::Lt), + Operator::IsDistinctFrom => 
Some(Operator::IsNotDistinctFrom), + Operator::IsNotDistinctFrom => Some(Operator::IsDistinctFrom), + // For other operators, we can't directly negate them + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, lit, BinaryExpr, NotExpr}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ScalarValue; + use datafusion_expr::Operator; + + fn test_schema() -> Schema { + Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Int32, false), + ]) + } + + #[test] + fn test_double_negation_elimination() -> Result<()> { + let schema = test_schema(); + + // Create NOT(NOT(b > 5)) + let inner_expr: Arc = Arc::new(BinaryExpr::new( + col("b", &schema)?, + Operator::Gt, + lit(ScalarValue::Int32(Some(5))), + )); + let inner_not = Arc::new(NotExpr::new(Arc::clone(&inner_expr))); + let double_not: Arc = Arc::new(NotExpr::new(inner_not)); + + let result = simplify_not_expr_recursive(&double_not, &schema)?; + + assert!(result.transformed); + // Should be simplified back to the original b > 5 + assert_eq!(result.data.to_string(), inner_expr.to_string()); + Ok(()) + } + + #[test] + fn test_not_literal() -> Result<()> { + let schema = test_schema(); + + // NOT(TRUE) -> FALSE + let not_true = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(true))))); + let result = simplify_not_expr(not_true, &schema)?; + assert!(result.transformed); + + if let Some(literal) = result.data.as_any().downcast_ref::() { + assert_eq!(literal.value(), &ScalarValue::Boolean(Some(false))); + } else { + panic!("Expected literal result"); + } + + // NOT(FALSE) -> TRUE + let not_false: Arc = + Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(false))))); + let result = simplify_not_expr_recursive(¬_false, &schema)?; + assert!(result.transformed); + + if let Some(literal) = result.data.as_any().downcast_ref::() { + assert_eq!(literal.value(), &ScalarValue::Boolean(Some(true))); + } else { + panic!("Expected 
literal result"); + } + + Ok(()) + } + + #[test] + fn test_negate_comparison() -> Result<()> { + let schema = test_schema(); + + // NOT(b = 5) -> b != 5 + let eq_expr = Arc::new(BinaryExpr::new( + col("b", &schema)?, + Operator::Eq, + lit(ScalarValue::Int32(Some(5))), + )); + let not_eq: Arc = Arc::new(NotExpr::new(eq_expr)); + + let result = simplify_not_expr_recursive(¬_eq, &schema)?; + assert!(result.transformed); + + if let Some(binary) = result.data.as_any().downcast_ref::() { + assert_eq!(binary.op(), &Operator::NotEq); + } else { + panic!("Expected binary expression result"); + } + + Ok(()) + } + + #[test] + fn test_demorgans_law_and() -> Result<()> { + let schema = test_schema(); + + // NOT(a AND b) -> NOT a OR NOT b + let and_expr = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::And, + col("b", &schema)?, + )); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let result = simplify_not_expr_recursive(¬_and, &schema)?; + assert!(result.transformed); + + if let Some(binary) = result.data.as_any().downcast_ref::() { + assert_eq!(binary.op(), &Operator::Or); + // Left and right should both be NOT expressions + assert!(binary.left().as_any().downcast_ref::().is_some()); + assert!(binary.right().as_any().downcast_ref::().is_some()); + } else { + panic!("Expected binary expression result"); + } + + Ok(()) + } + + #[test] + fn test_demorgans_law_or() -> Result<()> { + let schema = test_schema(); + + // NOT(a OR b) -> NOT a AND NOT b + let or_expr = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::Or, + col("b", &schema)?, + )); + let not_or: Arc = Arc::new(NotExpr::new(or_expr)); + + let result = simplify_not_expr_recursive(¬_or, &schema)?; + assert!(result.transformed); + + if let Some(binary) = result.data.as_any().downcast_ref::() { + assert_eq!(binary.op(), &Operator::And); + // Left and right should both be NOT expressions + assert!(binary.left().as_any().downcast_ref::().is_some()); + 
assert!(binary.right().as_any().downcast_ref::().is_some()); + } else { + panic!("Expected binary expression result"); + } + + Ok(()) + } + + #[test] + fn test_demorgans_with_comparison_simplification() -> Result<()> { + let schema = test_schema(); + + // NOT(b = 1 AND b = 2) -> b != 1 OR b != 2 + // This tests the combination of De Morgan's law and operator negation + let eq1 = Arc::new(BinaryExpr::new( + col("b", &schema)?, + Operator::Eq, + lit(ScalarValue::Int32(Some(1))), + )); + let eq2 = Arc::new(BinaryExpr::new( + col("b", &schema)?, + Operator::Eq, + lit(ScalarValue::Int32(Some(2))), + )); + let and_expr = Arc::new(BinaryExpr::new(eq1, Operator::And, eq2)); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let result = simplify_not_expr_recursive(¬_and, &schema)?; + assert!(result.transformed, "Expression should be transformed"); + + // Verify the result is an OR expression + if let Some(or_binary) = result.data.as_any().downcast_ref::() { + assert_eq!(or_binary.op(), &Operator::Or, "Top level should be OR"); + + // Verify left side is b != 1 + if let Some(left_binary) = + or_binary.left().as_any().downcast_ref::() + { + assert_eq!(left_binary.op(), &Operator::NotEq, "Left should be NotEq"); + } else { + panic!("Expected left to be a binary expression with !="); + } + + // Verify right side is b != 2 + if let Some(right_binary) = + or_binary.right().as_any().downcast_ref::() + { + assert_eq!(right_binary.op(), &Operator::NotEq, "Right should be NotEq"); + } else { + panic!("Expected right to be a binary expression with !="); + } + } else { + panic!("Expected binary OR expression result"); + } + + Ok(()) + } +} From a0116cec5edbd576440610fe4d137f784235e4e0 Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Mon, 1 Dec 2025 09:57:18 +0800 Subject: [PATCH 2/6] not(not a and not b) --- .../physical-expr/src/simplifier/not.rs | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/datafusion/physical-expr/src/simplifier/not.rs 
b/datafusion/physical-expr/src/simplifier/not.rs index 7f448d8cc89b..e41e4035bd14 100644 --- a/datafusion/physical-expr/src/simplifier/not.rs +++ b/datafusion/physical-expr/src/simplifier/not.rs @@ -376,4 +376,48 @@ mod tests { Ok(()) } + + #[test] + fn test_not_of_not_and_not() -> Result<()> { + let schema = test_schema(); + + // NOT(NOT(a) AND NOT(b)) -> a OR b + // This tests the combination of De Morgan's law and double negation elimination + let not_a = Arc::new(NotExpr::new(col("a", &schema)?)); + let not_b = Arc::new(NotExpr::new(col("b", &schema)?)); + let and_expr = Arc::new(BinaryExpr::new(not_a, Operator::And, not_b)); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let result = simplify_not_expr_recursive(¬_and, &schema)?; + assert!(result.transformed, "Expression should be transformed"); + + // Verify the result is an OR expression + if let Some(or_binary) = result.data.as_any().downcast_ref::() { + assert_eq!(or_binary.op(), &Operator::Or, "Top level should be OR"); + + // Verify left side is just 'a' + assert!( + or_binary + .left() + .as_any() + .downcast_ref::() + .is_none(), + "Left should be simplified to just 'a'" + ); + + // Verify right side is just 'b' + assert!( + or_binary + .right() + .as_any() + .downcast_ref::() + .is_none(), + "Right should be simplified to just 'b'" + ); + } else { + panic!("Expected binary OR expression result"); + } + + Ok(()) + } } From 0d075ba7998167f1af160266e38efab7aeb51e6e Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Mon, 1 Dec 2025 10:30:15 +0800 Subject: [PATCH 3/6] support not list --- .../physical-expr/src/simplifier/mod.rs | 6 +- .../physical-expr/src/simplifier/not.rs | 136 +++++++++++++++--- 2 files changed, 123 insertions(+), 19 deletions(-) diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 4a3ed6a1299e..1059db49f86b 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ 
b/datafusion/physical-expr/src/simplifier/mod.rs @@ -58,9 +58,9 @@ impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> { fn f_up(&mut self, node: Self::Node) -> Result> { // Apply NOT expression simplification first - let not_simplified = simplify_not_expr_recursive(&node, self.schema)?; - let node = not_simplified.data; - let transformed = not_simplified.transformed; + let not_expr_simplified = simplify_not_expr_recursive(&node, self.schema)?; + let node = not_expr_simplified.data; + let transformed = not_expr_simplified.transformed; // Apply unwrap cast optimization #[cfg(test)] diff --git a/datafusion/physical-expr/src/simplifier/not.rs b/datafusion/physical-expr/src/simplifier/not.rs index e41e4035bd14..2d264231a868 100644 --- a/datafusion/physical-expr/src/simplifier/not.rs +++ b/datafusion/physical-expr/src/simplifier/not.rs @@ -30,7 +30,7 @@ use arrow::datatypes::Schema; use datafusion_common::{tree_node::Transformed, Result, ScalarValue}; use datafusion_expr::Operator; -use crate::expressions::{lit, BinaryExpr, Literal, NotExpr}; +use crate::expressions::{in_list, lit, BinaryExpr, InListExpr, Literal, NotExpr}; use crate::PhysicalExpr; /// Attempts to simplify NOT expressions @@ -64,6 +64,19 @@ pub(crate) fn simplify_not_expr( } } + // Handle NOT(IN list) -> NOT IN list + if let Some(in_list_expr) = inner_expr.as_any().downcast_ref::() { + // Create a new InList expression with negated flag flipped + let negated = !in_list_expr.negated(); + let new_in_list = in_list( + Arc::clone(in_list_expr.expr()), + in_list_expr.list().to_vec(), + &negated, + schema, + )?; + return Ok(Transformed::yes(new_in_list)); + } + // Handle NOT(binary_expr) where we can flip the operator if let Some(binary_expr) = inner_expr.as_any().downcast_ref::() { if let Some(negated_op) = negate_operator(binary_expr.op()) { @@ -186,7 +199,7 @@ fn negate_operator(op: &Operator) -> Option { #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit, BinaryExpr, NotExpr}; 
+ use crate::expressions::{col, in_list, lit, BinaryExpr, NotExpr}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; @@ -396,26 +409,117 @@ mod tests { assert_eq!(or_binary.op(), &Operator::Or, "Top level should be OR"); // Verify left side is just 'a' + assert!(or_binary.left().as_any().downcast_ref::().is_none(), + "Left should not be a NOT expression, it should be simplified to just 'a'"); + + // Verify right side is just 'b' + assert!(or_binary.right().as_any().downcast_ref::().is_none(), + "Right should not be a NOT expression, it should be simplified to just 'b'"); + } else { + panic!("Expected binary OR expression result"); + } + + Ok(()) + } + + #[test] + fn test_not_in_list() -> Result<()> { + let schema = test_schema(); + + // NOT(b IN (1, 2, 3)) -> b NOT IN (1, 2, 3) + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + lit(ScalarValue::Int32(Some(3))), + ]; + let in_list_expr = in_list(col("b", &schema)?, list, &false, &schema)?; + let not_in: Arc = Arc::new(NotExpr::new(in_list_expr)); + + let result = simplify_not_expr_recursive(¬_in, &schema)?; + assert!(result.transformed, "Expression should be transformed"); + + // Verify the result is an InList expression with negated=true + if let Some(in_list_result) = result.data.as_any().downcast_ref::() { assert!( - or_binary - .left() - .as_any() - .downcast_ref::() - .is_none(), - "Left should be simplified to just 'a'" + in_list_result.negated(), + "InList should be negated (NOT IN)" + ); + assert_eq!( + in_list_result.list().len(), + 3, + "Should have 3 items in list" ); + } else { + panic!("Expected InListExpr result"); + } - // Verify right side is just 'b' + Ok(()) + } + + #[test] + fn test_not_not_in_list() -> Result<()> { + let schema = test_schema(); + + // NOT(b NOT IN (1, 2, 3)) -> b IN (1, 2, 3) + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + 
lit(ScalarValue::Int32(Some(3))), + ]; + let not_in_list_expr = in_list(col("b", &schema)?, list, &true, &schema)?; + let not_not_in: Arc = Arc::new(NotExpr::new(not_in_list_expr)); + + let result = simplify_not_expr_recursive(¬_not_in, &schema)?; + assert!(result.transformed, "Expression should be transformed"); + + // Verify the result is an InList expression with negated=false + if let Some(in_list_result) = result.data.as_any().downcast_ref::() { assert!( - or_binary - .right() - .as_any() - .downcast_ref::() - .is_none(), - "Right should be simplified to just 'b'" + !in_list_result.negated(), + "InList should not be negated (IN)" + ); + assert_eq!( + in_list_result.list().len(), + 3, + "Should have 3 items in list" ); } else { - panic!("Expected binary OR expression result"); + panic!("Expected InListExpr result"); + } + + Ok(()) + } + + #[test] + fn test_double_not_in_list() -> Result<()> { + let schema = test_schema(); + + // NOT(NOT(b IN (1, 2, 3))) -> b IN (1, 2, 3) + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + lit(ScalarValue::Int32(Some(3))), + ]; + let in_list_expr = in_list(col("b", &schema)?, list, &false, &schema)?; + let not_in = Arc::new(NotExpr::new(in_list_expr)); + let double_not: Arc = Arc::new(NotExpr::new(not_in)); + + let result = simplify_not_expr_recursive(&double_not, &schema)?; + assert!(result.transformed, "Expression should be transformed"); + + // After double negation elimination, we should get back the original IN expression + if let Some(in_list_result) = result.data.as_any().downcast_ref::() { + assert!( + !in_list_result.negated(), + "InList should not be negated (IN)" + ); + assert_eq!( + in_list_result.list().len(), + 3, + "Should have 3 items in list" + ); + } else { + panic!("Expected InListExpr result"); } Ok(()) From ceb85ff1327c78647decf04812b88ba8fdc6185a Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Mon, 1 Dec 2025 12:10:04 +0800 Subject: [PATCH 4/6] replace recursive and 
add test --- .../physical-expr/src/simplifier/mod.rs | 4 +- .../physical-expr/src/simplifier/not.rs | 145 ++++++++++++------ 2 files changed, 96 insertions(+), 53 deletions(-) diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 1059db49f86b..ea4979870664 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -24,7 +24,7 @@ use datafusion_common::{ }; use std::sync::Arc; -use crate::{simplifier::not::simplify_not_expr_recursive, PhysicalExpr}; +use crate::{simplifier::not::simplify_not_expr, PhysicalExpr}; pub mod not; pub mod unwrap_cast; @@ -58,7 +58,7 @@ impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> { fn f_up(&mut self, node: Self::Node) -> Result> { // Apply NOT expression simplification first - let not_expr_simplified = simplify_not_expr_recursive(&node, self.schema)?; + let not_expr_simplified = simplify_not_expr(&node, self.schema)?; let node = not_expr_simplified.data; let transformed = not_expr_simplified.transformed; diff --git a/datafusion/physical-expr/src/simplifier/not.rs b/datafusion/physical-expr/src/simplifier/not.rs index 2d264231a868..286e0794b8d2 100644 --- a/datafusion/physical-expr/src/simplifier/not.rs +++ b/datafusion/physical-expr/src/simplifier/not.rs @@ -34,7 +34,7 @@ use crate::expressions::{in_list, lit, BinaryExpr, InListExpr, Literal, NotExpr} use crate::PhysicalExpr; /// Attempts to simplify NOT expressions -pub(crate) fn simplify_not_expr( +pub(crate) fn simplify_not_expr_impl( expr: Arc, schema: &Schema, ) -> Result>> { @@ -48,10 +48,8 @@ pub(crate) fn simplify_not_expr( // Handle NOT(NOT(expr)) -> expr (double negation elimination) if let Some(inner_not) = inner_expr.as_any().downcast_ref::() { - // Recursively simplify the inner expression - let simplified = simplify_not_expr_recursive(inner_not.arg(), schema)?; // We eliminated double negation, so always return transformed=true - return 
Ok(Transformed::yes(simplified.data)); + return Ok(Transformed::yes(Arc::clone(inner_not.arg()))); } // Handle NOT(literal) -> !literal @@ -81,10 +79,8 @@ pub(crate) fn simplify_not_expr( if let Some(binary_expr) = inner_expr.as_any().downcast_ref::() { if let Some(negated_op) = negate_operator(binary_expr.op()) { // Recursively simplify the left and right expressions first - let left_simplified = - simplify_not_expr_recursive(binary_expr.left(), schema)?; - let right_simplified = - simplify_not_expr_recursive(binary_expr.right(), schema)?; + let left_simplified = simplify_not_expr(binary_expr.left(), schema)?; + let right_simplified = simplify_not_expr(binary_expr.right(), schema)?; let new_binary = Arc::new(BinaryExpr::new( left_simplified.data, @@ -105,8 +101,8 @@ pub(crate) fn simplify_not_expr( Arc::new(NotExpr::new(Arc::clone(binary_expr.right()))); // Recursively simplify the NOT expressions - let simplified_left = simplify_not_expr_recursive(¬_left, schema)?; - let simplified_right = simplify_not_expr_recursive(¬_right, schema)?; + let simplified_left = simplify_not_expr(¬_left, schema)?; + let simplified_right = simplify_not_expr(¬_right, schema)?; let new_binary = Arc::new(BinaryExpr::new( simplified_left.data, @@ -123,8 +119,8 @@ pub(crate) fn simplify_not_expr( Arc::new(NotExpr::new(Arc::clone(binary_expr.right()))); // Recursively simplify the NOT expressions - let simplified_left = simplify_not_expr_recursive(¬_left, schema)?; - let simplified_right = simplify_not_expr_recursive(¬_right, schema)?; + let simplified_left = simplify_not_expr(¬_left, schema)?; + let simplified_right = simplify_not_expr(¬_right, schema)?; let new_binary = Arc::new(BinaryExpr::new( simplified_left.data, @@ -141,43 +137,43 @@ pub(crate) fn simplify_not_expr( Ok(Transformed::no(expr)) } -/// Helper function that recursively simplifies expressions, including NOT expressions -pub fn simplify_not_expr_recursive( +pub fn simplify_not_expr( expr: &Arc, schema: &Schema, ) -> 
Result>> { - // First, try to simplify any NOT expressions in this expression - let not_simplified = simplify_not_expr(Arc::clone(expr), schema)?; - - // If the expression was transformed, we might have created new opportunities for simplification - if not_simplified.transformed { - // Recursively simplify the result - let further_simplified = - simplify_not_expr_recursive(¬_simplified.data, schema)?; - if further_simplified.transformed { - return Ok(Transformed::yes(further_simplified.data)); - } else { - return Ok(not_simplified); + let mut current_expr = Arc::clone(expr); + let mut overall_transformed = false; + + loop { + let not_simplified = simplify_not_expr_impl(Arc::clone(¤t_expr), schema)?; + if not_simplified.transformed { + overall_transformed = true; + current_expr = not_simplified.data; + continue; } - } - // If this expression wasn't a NOT expression, try to simplify its children - // This handles cases where NOT expressions might be nested deeper in the tree - if let Some(binary_expr) = expr.as_any().downcast_ref::() { - let left_simplified = simplify_not_expr_recursive(binary_expr.left(), schema)?; - let right_simplified = simplify_not_expr_recursive(binary_expr.right(), schema)?; + if let Some(binary_expr) = current_expr.as_any().downcast_ref::() { + let left_simplified = simplify_not_expr(binary_expr.left(), schema)?; + let right_simplified = simplify_not_expr(binary_expr.right(), schema)?; - if left_simplified.transformed || right_simplified.transformed { - let new_binary = Arc::new(BinaryExpr::new( - left_simplified.data, - *binary_expr.op(), - right_simplified.data, - )); - return Ok(Transformed::yes(new_binary)); + if left_simplified.transformed || right_simplified.transformed { + let new_binary = Arc::new(BinaryExpr::new( + left_simplified.data, + *binary_expr.op(), + right_simplified.data, + )); + return Ok(Transformed::yes(new_binary)); + } } + + break; } - Ok(not_simplified) + if overall_transformed { + Ok(Transformed::yes(current_expr)) + 
} else { + Ok(Transformed::no(current_expr)) + } } /// Returns the negated version of a comparison operator, if possible @@ -224,7 +220,7 @@ mod tests { let inner_not = Arc::new(NotExpr::new(Arc::clone(&inner_expr))); let double_not: Arc = Arc::new(NotExpr::new(inner_not)); - let result = simplify_not_expr_recursive(&double_not, &schema)?; + let result = simplify_not_expr(&double_not, &schema)?; assert!(result.transformed); // Should be simplified back to the original b > 5 @@ -238,7 +234,7 @@ mod tests { // NOT(TRUE) -> FALSE let not_true = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(true))))); - let result = simplify_not_expr(not_true, &schema)?; + let result = simplify_not_expr_impl(not_true, &schema)?; assert!(result.transformed); if let Some(literal) = result.data.as_any().downcast_ref::() { @@ -250,7 +246,7 @@ mod tests { // NOT(FALSE) -> TRUE let not_false: Arc = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(false))))); - let result = simplify_not_expr_recursive(¬_false, &schema)?; + let result = simplify_not_expr(¬_false, &schema)?; assert!(result.transformed); if let Some(literal) = result.data.as_any().downcast_ref::() { @@ -274,7 +270,7 @@ mod tests { )); let not_eq: Arc = Arc::new(NotExpr::new(eq_expr)); - let result = simplify_not_expr_recursive(¬_eq, &schema)?; + let result = simplify_not_expr(¬_eq, &schema)?; assert!(result.transformed); if let Some(binary) = result.data.as_any().downcast_ref::() { @@ -298,7 +294,7 @@ mod tests { )); let not_and: Arc = Arc::new(NotExpr::new(and_expr)); - let result = simplify_not_expr_recursive(¬_and, &schema)?; + let result = simplify_not_expr(¬_and, &schema)?; assert!(result.transformed); if let Some(binary) = result.data.as_any().downcast_ref::() { @@ -325,7 +321,7 @@ mod tests { )); let not_or: Arc = Arc::new(NotExpr::new(or_expr)); - let result = simplify_not_expr_recursive(¬_or, &schema)?; + let result = simplify_not_expr(¬_or, &schema)?; assert!(result.transformed); if let Some(binary) = 
result.data.as_any().downcast_ref::() { @@ -359,7 +355,7 @@ mod tests { let and_expr = Arc::new(BinaryExpr::new(eq1, Operator::And, eq2)); let not_and: Arc = Arc::new(NotExpr::new(and_expr)); - let result = simplify_not_expr_recursive(¬_and, &schema)?; + let result = simplify_not_expr(¬_and, &schema)?; assert!(result.transformed, "Expression should be transformed"); // Verify the result is an OR expression @@ -401,7 +397,7 @@ mod tests { let and_expr = Arc::new(BinaryExpr::new(not_a, Operator::And, not_b)); let not_and: Arc = Arc::new(NotExpr::new(and_expr)); - let result = simplify_not_expr_recursive(¬_and, &schema)?; + let result = simplify_not_expr(¬_and, &schema)?; assert!(result.transformed, "Expression should be transformed"); // Verify the result is an OR expression @@ -435,7 +431,7 @@ mod tests { let in_list_expr = in_list(col("b", &schema)?, list, &false, &schema)?; let not_in: Arc = Arc::new(NotExpr::new(in_list_expr)); - let result = simplify_not_expr_recursive(¬_in, &schema)?; + let result = simplify_not_expr(¬_in, &schema)?; assert!(result.transformed, "Expression should be transformed"); // Verify the result is an InList expression with negated=true @@ -469,7 +465,7 @@ mod tests { let not_in_list_expr = in_list(col("b", &schema)?, list, &true, &schema)?; let not_not_in: Arc = Arc::new(NotExpr::new(not_in_list_expr)); - let result = simplify_not_expr_recursive(¬_not_in, &schema)?; + let result = simplify_not_expr(¬_not_in, &schema)?; assert!(result.transformed, "Expression should be transformed"); // Verify the result is an InList expression with negated=false @@ -504,7 +500,7 @@ mod tests { let not_in = Arc::new(NotExpr::new(in_list_expr)); let double_not: Arc = Arc::new(NotExpr::new(not_in)); - let result = simplify_not_expr_recursive(&double_not, &schema)?; + let result = simplify_not_expr(&double_not, &schema)?; assert!(result.transformed, "Expression should be transformed"); // After double negation elimination, we should get back the original IN 
expression @@ -524,4 +520,51 @@ mod tests { Ok(()) } + + #[test] + fn test_deeply_nested_not() -> Result<()> { + let schema = test_schema(); + + // Create a deeply nested NOT expression: NOT(NOT(NOT(...NOT(b > 5)...))) + // This tests that we don't get stack overflow with many nested NOTs + let inner_expr: Arc = Arc::new(BinaryExpr::new( + col("b", &schema)?, + Operator::Gt, + lit(ScalarValue::Int32(Some(5))), + )); + + let mut expr = Arc::clone(&inner_expr); + // Create 20000 layers of NOT + for _ in 0..20000 { + expr = Arc::new(NotExpr::new(expr)); + } + + let result = simplify_not_expr(&expr, &schema)?; + + // With 20000 NOTs (even number), should simplify back to the original expression + assert_eq!( + result.data.to_string(), + inner_expr.to_string(), + "Should simplify back to original expression" + ); + + // Manually dismantle the deep input expression to avoid Stack Overflow on Drop + // If we just let `expr` go out of scope, Rust's recursive Drop will blow the stack. + // We peel off layers one by one. + while let Some(not_expr) = expr.as_any().downcast_ref::() { + // Clone the child (Arc increment). + // Now child has 2 refs: one in parent, one in `child`. + let child = Arc::clone(not_expr.arg()); + + // Reassign `expr` to `child`. + // This drops the old `expr` (Parent). + // Parent refcount -> 0, Parent is dropped. + // Parent drops its reference to Child. + // Child refcount decrements 2 -> 1. 
+ // Child is NOT dropped recursively because we still hold it in `expr` + expr = child; + } + + Ok(()) + } } From be077dbe30cd5a9831bafe4488d4b0bf9831f417 Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Tue, 2 Dec 2025 15:39:33 +0800 Subject: [PATCH 5/6] Use transform_data to connect simplify_not_expr & unwrap_cast_in_comparison --- .../physical-expr/src/simplifier/mod.rs | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index ea4979870664..8fe6ccdf81bd 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -57,27 +57,23 @@ impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> { type Node = Arc; fn f_up(&mut self, node: Self::Node) -> Result> { - // Apply NOT expression simplification first - let not_expr_simplified = simplify_not_expr(&node, self.schema)?; - let node = not_expr_simplified.data; - let transformed = not_expr_simplified.transformed; - - // Apply unwrap cast optimization #[cfg(test)] let original_type = node.data_type(self.schema).unwrap(); - let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, self.schema)?; + + // Apply NOT expression simplification first, then unwrap cast optimization + let rewritten = + simplify_not_expr(&node, self.schema)?.transform_data(|node| { + unwrap_cast::unwrap_cast_in_comparison(node, self.schema) + })?; + #[cfg(test)] assert_eq!( - unwrapped.data.data_type(self.schema).unwrap(), + rewritten.data.data_type(self.schema).unwrap(), original_type, "Simplified expression should have the same data type as the original" ); - // Combine transformation results - let final_transformed = transformed || unwrapped.transformed; - Ok(Transformed::new_transformed( - unwrapped.data, - final_transformed, - )) + + Ok(rewritten) } } From 27c4ef18b9bd3ac0beddcc816041972903a47144 Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Tue, 2 Dec 
2025 16:50:47 +0800 Subject: [PATCH 6/6] resolve all comments --- .../physical-expr/src/simplifier/mod.rs | 357 +++++++++++-- .../physical-expr/src/simplifier/not.rs | 490 +----------------- 2 files changed, 350 insertions(+), 497 deletions(-) diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 8fe6ccdf81bd..d06629d03c5e 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -29,6 +29,8 @@ use crate::{simplifier::not::simplify_not_expr, PhysicalExpr}; pub mod not; pub mod unwrap_cast; +const MAX_LOOP_COUNT: usize = 5; + /// Simplifies physical expressions by applying various optimizations /// /// This can be useful after adapting expressions from a table schema @@ -49,7 +51,17 @@ impl<'a> PhysicalExprSimplifier<'a> { &mut self, expr: Arc, ) -> Result> { - Ok(expr.rewrite(self)?.data) + let mut current_expr = expr; + let mut count = 0; + while count < MAX_LOOP_COUNT { + count += 1; + let result = current_expr.rewrite(self)?; + if !result.transformed { + return Ok(result.data); + } + current_expr = result.data; + } + Ok(current_expr) } } @@ -80,7 +92,9 @@ impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> { #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; + use crate::expressions::{ + col, in_list, lit, BinaryExpr, CastExpr, Literal, NotExpr, TryCastExpr, + }; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; @@ -93,6 +107,42 @@ mod tests { ]) } + fn not_test_schema() -> Schema { + Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Int32, false), + ]) + } + + /// Helper function to extract a Literal from a PhysicalExpr + fn as_literal(expr: &Arc) -> &Literal { + expr.as_any() + .downcast_ref::() + .unwrap_or_else(|| 
panic!("Expected Literal, got: {expr}")) + } + + /// Helper function to extract a BinaryExpr from a PhysicalExpr + fn as_binary(expr: &Arc) -> &BinaryExpr { + expr.as_any() + .downcast_ref::() + .unwrap_or_else(|| panic!("Expected BinaryExpr, got: {expr}")) + } + + /// Assert that simplifying `input` produces `expected` + fn assert_not_simplify( + simplifier: &mut PhysicalExprSimplifier, + input: Arc, + expected: Arc, + ) { + let result = simplifier.simplify(Arc::clone(&input)).unwrap(); + assert_eq!( + &result, + &expected, + "Simplification should transform:\n input: {input}\n to: {expected}\n got: {result}" + ); + } + #[test] fn test_simplify() { let schema = test_schema(); @@ -108,7 +158,7 @@ mod tests { // Apply full simplification (uses TreeNodeRewriter) let optimized = simplifier.simplify(binary_expr).unwrap(); - let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + let optimized_binary = as_binary(&optimized); // Should be optimized to: c2 != INT64(99) (c2 is INT64, literal cast to match) let left_expr = optimized_binary.left(); @@ -116,11 +166,7 @@ mod tests { left_expr.as_any().downcast_ref::().is_none() && left_expr.as_any().downcast_ref::().is_none() ); - let right_literal = optimized_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_literal = as_literal(optimized_binary.right()); assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(99))); } @@ -145,14 +191,10 @@ mod tests { // Apply simplification let optimized = simplifier.simplify(or_expr).unwrap(); - let or_binary = optimized.as_any().downcast_ref::().unwrap(); + let or_binary = as_binary(&optimized); // Verify left side: c1 > INT32(5) - let left_binary = or_binary - .left() - .as_any() - .downcast_ref::() - .unwrap(); + let left_binary = as_binary(or_binary.left()); let left_left_expr = left_binary.left(); assert!( left_left_expr.as_any().downcast_ref::().is_none() @@ -161,19 +203,11 @@ mod tests { .downcast_ref::() .is_none() ); - let left_literal = 
left_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let left_literal = as_literal(left_binary.right()); assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(5))); // Verify right side: c2 <= INT64(10) - let right_binary = or_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_binary = as_binary(or_binary.right()); let right_left_expr = right_binary.left(); assert!( right_left_expr @@ -185,11 +219,276 @@ mod tests { .downcast_ref::() .is_none() ); - let right_literal = right_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_literal = as_literal(right_binary.right()); assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(10))); } + + #[test] + fn test_double_negation_elimination() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(NOT(c > 5)) -> c > 5 + let inner_expr: Arc = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Gt, + lit(ScalarValue::Int32(Some(5))), + )); + let inner_not = Arc::new(NotExpr::new(Arc::clone(&inner_expr))); + let double_not: Arc = Arc::new(NotExpr::new(inner_not)); + + let expected = inner_expr; + assert_not_simplify(&mut simplifier, double_not, expected); + Ok(()) + } + + #[test] + fn test_not_literal() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(TRUE) -> FALSE + let not_true = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(true))))); + let expected = lit(ScalarValue::Boolean(Some(false))); + assert_not_simplify(&mut simplifier, not_true, expected); + + // NOT(FALSE) -> TRUE + let not_false = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(false))))); + let expected = lit(ScalarValue::Boolean(Some(true))); + assert_not_simplify(&mut simplifier, not_false, expected); + + Ok(()) + } + + #[test] + fn test_negate_comparison() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = 
PhysicalExprSimplifier::new(&schema); + + // NOT(c = 5) -> c != 5 + let not_eq = Arc::new(NotExpr::new(Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Eq, + lit(ScalarValue::Int32(Some(5))), + )))); + let expected = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::NotEq, + lit(ScalarValue::Int32(Some(5))), + )); + assert_not_simplify(&mut simplifier, not_eq, expected); + + Ok(()) + } + + #[test] + fn test_demorgans_law_and() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(a AND b) -> NOT a OR NOT b + let and_expr = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::And, + col("b", &schema)?, + )); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + Arc::new(NotExpr::new(col("a", &schema)?)), + Operator::Or, + Arc::new(NotExpr::new(col("b", &schema)?)), + )); + assert_not_simplify(&mut simplifier, not_and, expected); + + Ok(()) + } + + #[test] + fn test_demorgans_law_or() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(a OR b) -> NOT a AND NOT b + let or_expr = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::Or, + col("b", &schema)?, + )); + let not_or: Arc = Arc::new(NotExpr::new(or_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + Arc::new(NotExpr::new(col("a", &schema)?)), + Operator::And, + Arc::new(NotExpr::new(col("b", &schema)?)), + )); + assert_not_simplify(&mut simplifier, not_or, expected); + + Ok(()) + } + + #[test] + fn test_demorgans_with_comparison_simplification() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(c = 1 AND c = 2) -> c != 1 OR c != 2 + let eq1 = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Eq, + lit(ScalarValue::Int32(Some(1))), + )); + let eq2 = Arc::new(BinaryExpr::new( + col("c", &schema)?, + 
Operator::Eq, + lit(ScalarValue::Int32(Some(2))), + )); + let and_expr = Arc::new(BinaryExpr::new(eq1, Operator::And, eq2)); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::NotEq, + lit(ScalarValue::Int32(Some(1))), + )), + Operator::Or, + Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::NotEq, + lit(ScalarValue::Int32(Some(2))), + )), + )); + assert_not_simplify(&mut simplifier, not_and, expected); + + Ok(()) + } + + #[test] + fn test_not_of_not_and_not() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(NOT(a) AND NOT(b)) -> a OR b + let not_a = Arc::new(NotExpr::new(col("a", &schema)?)); + let not_b = Arc::new(NotExpr::new(col("b", &schema)?)); + let and_expr = Arc::new(BinaryExpr::new(not_a, Operator::And, not_b)); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::Or, + col("b", &schema)?, + )); + assert_not_simplify(&mut simplifier, not_and, expected); + + Ok(()) + } + + #[test] + fn test_not_in_list() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(c IN (1, 2, 3)) -> c NOT IN (1, 2, 3) + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + lit(ScalarValue::Int32(Some(3))), + ]; + let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?; + let not_in: Arc = Arc::new(NotExpr::new(in_list_expr)); + + let expected = in_list(col("c", &schema)?, list, &true, &schema)?; + assert_not_simplify(&mut simplifier, not_in, expected); + + Ok(()) + } + + #[test] + fn test_not_not_in_list() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(c NOT IN (1, 2, 3)) -> c IN (1, 2, 3) + let list 
= vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + lit(ScalarValue::Int32(Some(3))), + ]; + let not_in_list_expr = in_list(col("c", &schema)?, list.clone(), &true, &schema)?; + let not_not_in: Arc = Arc::new(NotExpr::new(not_in_list_expr)); + + let expected = in_list(col("c", &schema)?, list, &false, &schema)?; + assert_not_simplify(&mut simplifier, not_not_in, expected); + + Ok(()) + } + + #[test] + fn test_double_not_in_list() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(NOT(c IN (1, 2, 3))) -> c IN (1, 2, 3) + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + lit(ScalarValue::Int32(Some(3))), + ]; + let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?; + let not_in = Arc::new(NotExpr::new(in_list_expr)); + let double_not: Arc = Arc::new(NotExpr::new(not_in)); + + let expected = in_list(col("c", &schema)?, list, &false, &schema)?; + assert_not_simplify(&mut simplifier, double_not, expected); + + Ok(()) + } + + #[test] + fn test_deeply_nested_not() -> Result<()> { + let schema = not_test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // Create a deeply nested NOT expression: NOT(NOT(NOT(...NOT(c > 5)...))) + // This tests that we don't get stack overflow with many nested NOTs. + // With recursive_protection enabled (default), this should work by + // automatically growing the stack as needed. 
+ let inner_expr: Arc = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Gt, + lit(ScalarValue::Int32(Some(5))), + )); + + let mut expr = Arc::clone(&inner_expr); + // Create 200 layers of NOT to test deep recursion handling + for _ in 0..200 { + expr = Arc::new(NotExpr::new(expr)); + } + + // With 200 NOTs (even number), should simplify back to the original expression + let expected = inner_expr; + assert_not_simplify(&mut simplifier, Arc::clone(&expr), expected); + + // Manually dismantle the deep input expression to avoid Stack Overflow on Drop + // If we just let `expr` go out of scope, Rust's recursive Drop will blow the stack + // even with recursive_protection, because Drop doesn't use the #[recursive] attribute. + // We peel off layers one by one to avoid deep recursion in Drop. + while let Some(not_expr) = expr.as_any().downcast_ref::() { + // Clone the child (Arc increment). + // Now child has 2 refs: one in parent, one in `child`. + let child = Arc::clone(not_expr.arg()); + + // Reassign `expr` to `child`. + // This drops the old `expr` (Parent). + // Parent refcount -> 0, Parent is dropped. + // Parent drops its reference to Child. + // Child refcount decrements 2 -> 1. + // Child is NOT dropped recursively because we still hold it in `expr` + expr = child; + } + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/simplifier/not.rs b/datafusion/physical-expr/src/simplifier/not.rs index 286e0794b8d2..1ea969f58ff9 100644 --- a/datafusion/physical-expr/src/simplifier/not.rs +++ b/datafusion/physical-expr/src/simplifier/not.rs @@ -23,6 +23,11 @@ //! - NOT with IN expressions: NOT(a IN (list)) -> a NOT IN (list) //! - De Morgan's laws: NOT(A AND B) -> NOT A OR NOT B //! - Constant folding: NOT(TRUE) -> FALSE, NOT(FALSE) -> TRUE +//! +//! This function is designed to work with TreeNodeRewriter's f_up traversal, +//! which means children are already simplified when this function is called. +//! 
The TreeNodeRewriter will automatically call this function repeatedly until +//! no more transformations are possible. use std::sync::Arc; @@ -33,22 +38,25 @@ use datafusion_expr::Operator; use crate::expressions::{in_list, lit, BinaryExpr, InListExpr, Literal, NotExpr}; use crate::PhysicalExpr; -/// Attempts to simplify NOT expressions -pub(crate) fn simplify_not_expr_impl( - expr: Arc, +/// Attempts to simplify NOT expressions by applying one level of transformation +/// +/// This function applies a single simplification rule and returns. When used with +/// TreeNodeRewriter, multiple passes will automatically be applied until no more +/// transformations are possible. +pub fn simplify_not_expr( + expr: &Arc, schema: &Schema, ) -> Result>> { // Check if this is a NOT expression let not_expr = match expr.as_any().downcast_ref::() { Some(not_expr) => not_expr, - None => return Ok(Transformed::no(expr)), + None => return Ok(Transformed::no(Arc::clone(expr))), }; let inner_expr = not_expr.arg(); // Handle NOT(NOT(expr)) -> expr (double negation elimination) if let Some(inner_not) = inner_expr.as_any().downcast_ref::() { - // We eliminated double negation, so always return transformed=true return Ok(Transformed::yes(Arc::clone(inner_not.arg()))); } @@ -64,7 +72,6 @@ pub(crate) fn simplify_not_expr_impl( // Handle NOT(IN list) -> NOT IN list if let Some(in_list_expr) = inner_expr.as_any().downcast_ref::() { - // Create a new InList expression with negated flag flipped let negated = !in_list_expr.negated(); let new_in_list = in_list( Arc::clone(in_list_expr.expr()), @@ -75,19 +82,14 @@ pub(crate) fn simplify_not_expr_impl( return Ok(Transformed::yes(new_in_list)); } - // Handle NOT(binary_expr) where we can flip the operator + // Handle NOT(binary_expr) if let Some(binary_expr) = inner_expr.as_any().downcast_ref::() { - if let Some(negated_op) = negate_operator(binary_expr.op()) { - // Recursively simplify the left and right expressions first - let left_simplified = 
simplify_not_expr(binary_expr.left(), schema)?; - let right_simplified = simplify_not_expr(binary_expr.right(), schema)?; - + if let Some(negated_op) = binary_expr.op().negate() { let new_binary = Arc::new(BinaryExpr::new( - left_simplified.data, + Arc::clone(binary_expr.left()), negated_op, - right_simplified.data, + Arc::clone(binary_expr.right()), )); - // We flipped the operator, so always return transformed=true return Ok(Transformed::yes(new_binary)); } @@ -99,16 +101,8 @@ pub(crate) fn simplify_not_expr_impl( Arc::new(NotExpr::new(Arc::clone(binary_expr.left()))); let not_right: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(Arc::clone(binary_expr.right()))); - - // Recursively simplify the NOT expressions - let simplified_left = simplify_not_expr(&not_left, schema)?; - let simplified_right = simplify_not_expr(&not_right, schema)?; - - let new_binary = Arc::new(BinaryExpr::new( - simplified_left.data, - Operator::Or, - simplified_right.data, - )); + let new_binary = + Arc::new(BinaryExpr::new(not_left, Operator::Or, not_right)); return Ok(Transformed::yes(new_binary)); } Operator::Or => { @@ -117,16 +111,8 @@ pub(crate) fn simplify_not_expr_impl( Arc::new(NotExpr::new(Arc::clone(binary_expr.left()))); let not_right: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(Arc::clone(binary_expr.right()))); - - // Recursively simplify the NOT expressions - let simplified_left = simplify_not_expr(&not_left, schema)?; - let simplified_right = simplify_not_expr(&not_right, schema)?; - - let new_binary = Arc::new(BinaryExpr::new( - simplified_left.data, - Operator::And, - simplified_right.data, - )); + let new_binary = + Arc::new(BinaryExpr::new(not_left, Operator::And, not_right)); return Ok(Transformed::yes(new_binary)); } _ => {} @@ -134,437 +120,5 @@ pub(crate) fn simplify_not_expr_impl( } // If no simplification possible, return the original expression - Ok(Transformed::no(expr)) -} - -pub fn simplify_not_expr( - expr: &Arc<dyn PhysicalExpr>, - schema: &Schema, -) -> Result<Transformed<Arc<dyn PhysicalExpr>>> { - let mut current_expr = Arc::clone(expr); - let mut 
overall_transformed = false; - - loop { - let not_simplified = simplify_not_expr_impl(Arc::clone(&current_expr), schema)?; - if not_simplified.transformed { - overall_transformed = true; - current_expr = not_simplified.data; - continue; - } - - if let Some(binary_expr) = current_expr.as_any().downcast_ref::<BinaryExpr>() { - let left_simplified = simplify_not_expr(binary_expr.left(), schema)?; - let right_simplified = simplify_not_expr(binary_expr.right(), schema)?; - - if left_simplified.transformed || right_simplified.transformed { - let new_binary = Arc::new(BinaryExpr::new( - left_simplified.data, - *binary_expr.op(), - right_simplified.data, - )); - return Ok(Transformed::yes(new_binary)); - } - } - - break; - } - - if overall_transformed { - Ok(Transformed::yes(current_expr)) - } else { - Ok(Transformed::no(current_expr)) - } -} - -/// Returns the negated version of a comparison operator, if possible -fn negate_operator(op: &Operator) -> Option<Operator> { - match op { - Operator::Eq => Some(Operator::NotEq), - Operator::NotEq => Some(Operator::Eq), - Operator::Lt => Some(Operator::GtEq), - Operator::LtEq => Some(Operator::Gt), - Operator::Gt => Some(Operator::LtEq), - Operator::GtEq => Some(Operator::Lt), - Operator::IsDistinctFrom => Some(Operator::IsNotDistinctFrom), - Operator::IsNotDistinctFrom => Some(Operator::IsDistinctFrom), - // For other operators, we can't directly negate them - _ => None, - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::{col, in_list, lit, BinaryExpr, NotExpr}; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; - use datafusion_expr::Operator; - - fn test_schema() -> Schema { - Schema::new(vec![ - Field::new("a", DataType::Boolean, false), - Field::new("b", DataType::Int32, false), - ]) - } - - #[test] - fn test_double_negation_elimination() -> Result<()> { - let schema = test_schema(); - - // Create NOT(NOT(b > 5)) - let inner_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new( - col("b", &schema)?, - 
Operator::Gt, - lit(ScalarValue::Int32(Some(5))), - )); - let inner_not = Arc::new(NotExpr::new(Arc::clone(&inner_expr))); - let double_not: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(inner_not)); - - let result = simplify_not_expr(&double_not, &schema)?; - - assert!(result.transformed); - // Should be simplified back to the original b > 5 - assert_eq!(result.data.to_string(), inner_expr.to_string()); - Ok(()) - } - - #[test] - fn test_not_literal() -> Result<()> { - let schema = test_schema(); - - // NOT(TRUE) -> FALSE - let not_true = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(true))))); - let result = simplify_not_expr_impl(not_true, &schema)?; - assert!(result.transformed); - - if let Some(literal) = result.data.as_any().downcast_ref::<Literal>() { - assert_eq!(literal.value(), &ScalarValue::Boolean(Some(false))); - } else { - panic!("Expected literal result"); - } - - // NOT(FALSE) -> TRUE - let not_false: Arc<dyn PhysicalExpr> = - Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(false))))); - let result = simplify_not_expr(&not_false, &schema)?; - assert!(result.transformed); - - if let Some(literal) = result.data.as_any().downcast_ref::<Literal>() { - assert_eq!(literal.value(), &ScalarValue::Boolean(Some(true))); - } else { - panic!("Expected literal result"); - } - - Ok(()) - } - - #[test] - fn test_negate_comparison() -> Result<()> { - let schema = test_schema(); - - // NOT(b = 5) -> b != 5 - let eq_expr = Arc::new(BinaryExpr::new( - col("b", &schema)?, - Operator::Eq, - lit(ScalarValue::Int32(Some(5))), - )); - let not_eq: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(eq_expr)); - - let result = simplify_not_expr(&not_eq, &schema)?; - assert!(result.transformed); - - if let Some(binary) = result.data.as_any().downcast_ref::<BinaryExpr>() { - assert_eq!(binary.op(), &Operator::NotEq); - } else { - panic!("Expected binary expression result"); - } - - Ok(()) - } - - #[test] - fn test_demorgans_law_and() -> Result<()> { - let schema = test_schema(); - - // NOT(a AND b) -> NOT a OR NOT b - let and_expr = Arc::new(BinaryExpr::new( - 
col("a", &schema)?, - Operator::And, - col("b", &schema)?, - )); - let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr)); - - let result = simplify_not_expr(&not_and, &schema)?; - assert!(result.transformed); - - if let Some(binary) = result.data.as_any().downcast_ref::<BinaryExpr>() { - assert_eq!(binary.op(), &Operator::Or); - // Left and right should both be NOT expressions - assert!(binary.left().as_any().downcast_ref::<NotExpr>().is_some()); - assert!(binary.right().as_any().downcast_ref::<NotExpr>().is_some()); - } else { - panic!("Expected binary expression result"); - } - - Ok(()) - } - - #[test] - fn test_demorgans_law_or() -> Result<()> { - let schema = test_schema(); - - // NOT(a OR b) -> NOT a AND NOT b - let or_expr = Arc::new(BinaryExpr::new( - col("a", &schema)?, - Operator::Or, - col("b", &schema)?, - )); - let not_or: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(or_expr)); - - let result = simplify_not_expr(&not_or, &schema)?; - assert!(result.transformed); - - if let Some(binary) = result.data.as_any().downcast_ref::<BinaryExpr>() { - assert_eq!(binary.op(), &Operator::And); - // Left and right should both be NOT expressions - assert!(binary.left().as_any().downcast_ref::<NotExpr>().is_some()); - assert!(binary.right().as_any().downcast_ref::<NotExpr>().is_some()); - } else { - panic!("Expected binary expression result"); - } - - Ok(()) - } - - #[test] - fn test_demorgans_with_comparison_simplification() -> Result<()> { - let schema = test_schema(); - - // NOT(b = 1 AND b = 2) -> b != 1 OR b != 2 - // This tests the combination of De Morgan's law and operator negation - let eq1 = Arc::new(BinaryExpr::new( - col("b", &schema)?, - Operator::Eq, - lit(ScalarValue::Int32(Some(1))), - )); - let eq2 = Arc::new(BinaryExpr::new( - col("b", &schema)?, - Operator::Eq, - lit(ScalarValue::Int32(Some(2))), - )); - let and_expr = Arc::new(BinaryExpr::new(eq1, Operator::And, eq2)); - let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr)); - - let result = simplify_not_expr(&not_and, &schema)?; - assert!(result.transformed, "Expression should be 
transformed"); - - // Verify the result is an OR expression - if let Some(or_binary) = result.data.as_any().downcast_ref::<BinaryExpr>() { - assert_eq!(or_binary.op(), &Operator::Or, "Top level should be OR"); - - // Verify left side is b != 1 - if let Some(left_binary) = - or_binary.left().as_any().downcast_ref::<BinaryExpr>() - { - assert_eq!(left_binary.op(), &Operator::NotEq, "Left should be NotEq"); - } else { - panic!("Expected left to be a binary expression with !="); - } - - // Verify right side is b != 2 - if let Some(right_binary) = - or_binary.right().as_any().downcast_ref::<BinaryExpr>() - { - assert_eq!(right_binary.op(), &Operator::NotEq, "Right should be NotEq"); - } else { - panic!("Expected right to be a binary expression with !="); - } - } else { - panic!("Expected binary OR expression result"); - } - - Ok(()) - } - - #[test] - fn test_not_of_not_and_not() -> Result<()> { - let schema = test_schema(); - - // NOT(NOT(a) AND NOT(b)) -> a OR b - // This tests the combination of De Morgan's law and double negation elimination - let not_a = Arc::new(NotExpr::new(col("a", &schema)?)); - let not_b = Arc::new(NotExpr::new(col("b", &schema)?)); - let and_expr = Arc::new(BinaryExpr::new(not_a, Operator::And, not_b)); - let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr)); - - let result = simplify_not_expr(&not_and, &schema)?; - assert!(result.transformed, "Expression should be transformed"); - - // Verify the result is an OR expression - if let Some(or_binary) = result.data.as_any().downcast_ref::<BinaryExpr>() { - assert_eq!(or_binary.op(), &Operator::Or, "Top level should be OR"); - - // Verify left side is just 'a' - assert!(or_binary.left().as_any().downcast_ref::<NotExpr>().is_none(), - "Left should not be a NOT expression, it should be simplified to just 'a'"); - - // Verify right side is just 'b' - assert!(or_binary.right().as_any().downcast_ref::<NotExpr>().is_none(), - "Right should not be a NOT expression, it should be simplified to just 'b'"); - } else { - panic!("Expected binary OR expression result"); - } - - 
Ok(()) - } - - #[test] - fn test_not_in_list() -> Result<()> { - let schema = test_schema(); - - // NOT(b IN (1, 2, 3)) -> b NOT IN (1, 2, 3) - let list = vec![ - lit(ScalarValue::Int32(Some(1))), - lit(ScalarValue::Int32(Some(2))), - lit(ScalarValue::Int32(Some(3))), - ]; - let in_list_expr = in_list(col("b", &schema)?, list, &false, &schema)?; - let not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(in_list_expr)); - - let result = simplify_not_expr(&not_in, &schema)?; - assert!(result.transformed, "Expression should be transformed"); - - // Verify the result is an InList expression with negated=true - if let Some(in_list_result) = result.data.as_any().downcast_ref::<InListExpr>() { - assert!( - in_list_result.negated(), - "InList should be negated (NOT IN)" - ); - assert_eq!( - in_list_result.list().len(), - 3, - "Should have 3 items in list" - ); - } else { - panic!("Expected InListExpr result"); - } - - Ok(()) - } - - #[test] - fn test_not_not_in_list() -> Result<()> { - let schema = test_schema(); - - // NOT(b NOT IN (1, 2, 3)) -> b IN (1, 2, 3) - let list = vec![ - lit(ScalarValue::Int32(Some(1))), - lit(ScalarValue::Int32(Some(2))), - lit(ScalarValue::Int32(Some(3))), - ]; - let not_in_list_expr = in_list(col("b", &schema)?, list, &true, &schema)?; - let not_not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(not_in_list_expr)); - - let result = simplify_not_expr(&not_not_in, &schema)?; - assert!(result.transformed, "Expression should be transformed"); - - // Verify the result is an InList expression with negated=false - if let Some(in_list_result) = result.data.as_any().downcast_ref::<InListExpr>() { - assert!( - !in_list_result.negated(), - "InList should not be negated (IN)" - ); - assert_eq!( - in_list_result.list().len(), - 3, - "Should have 3 items in list" - ); - } else { - panic!("Expected InListExpr result"); - } - - Ok(()) - } - - #[test] - fn test_double_not_in_list() -> Result<()> { - let schema = test_schema(); - - // NOT(NOT(b IN (1, 2, 3))) -> b IN (1, 2, 3) - let list = vec![ - 
lit(ScalarValue::Int32(Some(1))), - lit(ScalarValue::Int32(Some(2))), - lit(ScalarValue::Int32(Some(3))), - ]; - let in_list_expr = in_list(col("b", &schema)?, list, &false, &schema)?; - let not_in = Arc::new(NotExpr::new(in_list_expr)); - let double_not: Arc = Arc::new(NotExpr::new(not_in)); - - let result = simplify_not_expr(&double_not, &schema)?; - assert!(result.transformed, "Expression should be transformed"); - - // After double negation elimination, we should get back the original IN expression - if let Some(in_list_result) = result.data.as_any().downcast_ref::() { - assert!( - !in_list_result.negated(), - "InList should not be negated (IN)" - ); - assert_eq!( - in_list_result.list().len(), - 3, - "Should have 3 items in list" - ); - } else { - panic!("Expected InListExpr result"); - } - - Ok(()) - } - - #[test] - fn test_deeply_nested_not() -> Result<()> { - let schema = test_schema(); - - // Create a deeply nested NOT expression: NOT(NOT(NOT(...NOT(b > 5)...))) - // This tests that we don't get stack overflow with many nested NOTs - let inner_expr: Arc = Arc::new(BinaryExpr::new( - col("b", &schema)?, - Operator::Gt, - lit(ScalarValue::Int32(Some(5))), - )); - - let mut expr = Arc::clone(&inner_expr); - // Create 20000 layers of NOT - for _ in 0..20000 { - expr = Arc::new(NotExpr::new(expr)); - } - - let result = simplify_not_expr(&expr, &schema)?; - - // With 20000 NOTs (even number), should simplify back to the original expression - assert_eq!( - result.data.to_string(), - inner_expr.to_string(), - "Should simplify back to original expression" - ); - - // Manually dismantle the deep input expression to avoid Stack Overflow on Drop - // If we just let `expr` go out of scope, Rust's recursive Drop will blow the stack. - // We peel off layers one by one. - while let Some(not_expr) = expr.as_any().downcast_ref::() { - // Clone the child (Arc increment). - // Now child has 2 refs: one in parent, one in `child`. 
- let child = Arc::clone(not_expr.arg()); - - // Reassign `expr` to `child`. - // This drops the old `expr` (Parent). - // Parent refcount -> 0, Parent is dropped. - // Parent drops its reference to Child. - // Child refcount decrements 2 -> 1. - // Child is NOT dropped recursively because we still hold it in `expr` - expr = child; - } - - Ok(()) - } + Ok(Transformed::no(Arc::clone(expr))) }