From cbf633b598290ca206970f51677f2c3f61c3f5db Mon Sep 17 00:00:00 2001 From: Qi Zhu <821684824@qq.com> Date: Thu, 29 Jan 2026 10:33:35 +0800 Subject: [PATCH] Optimize rewrite performance and SPJ new --- src/rewrite/normal_form.rs | 321 ++++++++++++++++++++++++++----------- 1 file changed, 230 insertions(+), 91 deletions(-) diff --git a/src/rewrite/normal_form.rs b/src/rewrite/normal_form.rs index b492b4f..9cce9d8 100644 --- a/src/rewrite/normal_form.rs +++ b/src/rewrite/normal_form.rs @@ -233,22 +233,11 @@ impl SpjNormalForm { .map(|expr| predicate.normalize_expr(expr)) .collect(); - let mut referenced_tables = vec![]; - original_plan - .apply(|plan| { - if let LogicalPlan::TableScan(scan) = plan { - referenced_tables.push(scan.table_name.clone()); - } - - Ok(TreeNodeRecursion::Continue) - }) - // No chance of error since we never return Err -- this unwrap is safe - .unwrap(); - Ok(Self { output_schema: Arc::clone(original_plan.schema()), output_exprs, - referenced_tables, + // Reuse referenced_tables collected during Predicate::new to avoid extra traversal + referenced_tables: predicate.referenced_tables.clone(), predicate, }) } @@ -258,32 +247,50 @@ impl SpjNormalForm { /// This is useful for rewriting queries to use materialized views. pub fn rewrite_from( &self, - mut other: &Self, + other: &Self, qualifier: TableReference, source: Arc, ) -> Result> { log::trace!("rewriting from {qualifier}"); + + // Cache columns() result to avoid repeated Vec allocation in the loop. + // DFSchema::columns() creates a new Vec on each call. + let output_columns = self.output_schema.columns(); + let mut new_output_exprs = Vec::with_capacity(self.output_exprs.len()); // check that our output exprs are sub-expressions of the other one's output exprs for (i, output_expr) in self.output_exprs.iter().enumerate() { - let new_output_expr = other - .predicate - .normalize_expr(output_expr.clone()) - .rewrite(&mut other)? - .data; - - // Check that all references to the original tables have been replaced. - // All remaining column expressions should be unqualified, which indicates - // that they refer to the output of the sub-plan (in this case the view) - if new_output_expr - .column_refs() - .iter() - .any(|c| c.relation.is_some()) - { - return Ok(None); - } + // Fast path for simple Column expressions (most common case). + // This avoids the expensive normalize_expr transform for columns. + let new_output_expr = if let Expr::Column(col) = output_expr { + let normalized_col = other.predicate.normalize_column(col); + match other.find_output_column(&normalized_col) { + Some(rewritten) => rewritten, + None => return Ok(None), // Column not found, can't rewrite + } + } else { + // Slow path: complex expressions need full transform + let new_output_expr = other + .predicate + .normalize_expr(output_expr.clone()) + .rewrite(&mut &*other)? + .data; + + // Check that all references to the original tables have been replaced. + // All remaining column expressions should be unqualified, which indicates + // that they refer to the output of the sub-plan (in this case the view) + if new_output_expr + .column_refs() + .iter() + .any(|c| c.relation.is_some()) + { + return Ok(None); + } + new_output_expr + }; - let column = &self.output_schema.columns()[i]; + // Use cached columns instead of calling .columns() on each iteration + let column = &output_columns[i]; new_output_exprs.push( new_output_expr.alias_qualified(column.relation.clone(), column.name.clone()), ); @@ -310,7 +317,7 @@ impl SpjNormalForm { .into_iter() .chain(range_filters) .chain(residual_filters) - .map(|expr| expr.rewrite(&mut other).unwrap().data) + .map(|expr| expr.rewrite(&mut &*other).unwrap().data) .reduce(|a, b| a.and(b)); if all_filters @@ -329,6 +336,20 @@ impl SpjNormalForm { builder.project(new_output_exprs)?.build().map(Some) } + + /// Fast path: find a column in output_exprs and return rewritten expression. + /// This avoids full tree traversal for simple column lookups. + #[inline] + fn find_output_column(&self, col: &Column) -> Option { + self.output_exprs + .iter() + .position(|e| matches!(e, Expr::Column(c) if c == col)) + .map(|idx| { + Expr::Column(Column::new_unqualified( + self.output_schema.field(idx).name().clone(), + )) + }) + } } /// Stores information on filters from a Select-Project-Join plan. @@ -344,84 +365,95 @@ struct Predicate { ranges_by_equivalence_class: Vec>, /// Filter expressions that aren't column equality predicates or range filters. residuals: HashSet, + /// Tables referenced in this plan (collected during single-pass traversal) + referenced_tables: Vec, } impl Predicate { + /// Create a new Predicate by analyzing the given logical plan. + /// Uses single-pass traversal to collect schema, columns, filters, and referenced tables. fn new(plan: &LogicalPlan) -> Result { let mut schema = DFSchema::empty(); - plan.apply(|plan| { - if let LogicalPlan::TableScan(scan) = plan { - let new_schema = DFSchema::try_from_qualified_schema( - scan.table_name.clone(), - scan.source.schema().as_ref(), - )?; - schema = if schema.fields().is_empty() { - new_schema - } else { - schema.join(&new_schema)? - } - } + let mut columns_info: Vec<(Column, arrow::datatypes::DataType)> = Vec::new(); + let mut filters: Vec = Vec::new(); + let mut referenced_tables: Vec = Vec::new(); + + // Single traversal to collect everything + plan.apply(|node| { + match node { + LogicalPlan::TableScan(scan) => { + // Collect referenced table + referenced_tables.push(scan.table_name.clone()); - Ok(TreeNodeRecursion::Continue) - })?; + // Build schema + let new_schema = DFSchema::try_from_qualified_schema( + scan.table_name.clone(), + scan.source.schema().as_ref(), + )?; + + // Collect columns with their data types + for (table_ref, field) in new_schema.iter() { + columns_info.push(( + Column::new(table_ref.cloned(), field.name()), + field.data_type().clone(), + )); + } - let mut new = Self { - schema, - eq_classes: vec![], - eq_class_idx_by_column: HashMap::default(), - ranges_by_equivalence_class: vec![], - residuals: HashSet::new(), - }; + // Merge schema + schema = if schema.fields().is_empty() { + new_schema + } else { + schema.join(&new_schema)? + }; - // Collect all referenced columns - plan.apply(|plan| { - if let LogicalPlan::TableScan(scan) = plan { - for (i, (table_ref, field)) in DFSchema::try_from_qualified_schema( - scan.table_name.clone(), - scan.source.schema().as_ref(), - )? - .iter() - .enumerate() - { - let column = Column::new(table_ref.cloned(), field.name()); - let data_type = field.data_type(); - new.eq_classes - .push(ColumnEquivalenceClass::new_singleton(column.clone())); - new.eq_class_idx_by_column.insert(column, i); - new.ranges_by_equivalence_class - .push(Some(Interval::make_unbounded(data_type)?)); + // Collect filters from TableScan + filters.extend(scan.filters.iter().cloned()); + } + LogicalPlan::Filter(filter) => { + filters.push(filter.predicate.clone()); } - } - - Ok(TreeNodeRecursion::Continue) - })?; - - // Collect any filters - plan.apply(|plan| { - let filters = match plan { - LogicalPlan::TableScan(scan) => scan.filters.as_slice(), - LogicalPlan::Filter(filter) => core::slice::from_ref(&filter.predicate), LogicalPlan::Join(_join) => { return Err(DataFusionError::Internal( "joins are not supported yet".to_string(), - )) + )); } - LogicalPlan::Projection(_) => &[], + LogicalPlan::Projection(_) => {} _ => { return Err(DataFusionError::Plan(format!( "unsupported logical plan: {}", - plan.display() - ))) + node.display() + ))); } - }; - - for expr in filters.iter().flat_map(split_conjunction) { - new.insert_conjuct(expr)?; } - Ok(TreeNodeRecursion::Continue) })?; + // Initialize data structures with known capacity + let num_columns = columns_info.len(); + let mut eq_classes = Vec::with_capacity(num_columns); + let mut eq_class_idx_by_column = HashMap::with_capacity(num_columns); + let mut ranges_by_equivalence_class = Vec::with_capacity(num_columns); + + for (i, (column, data_type)) in columns_info.into_iter().enumerate() { + eq_classes.push(ColumnEquivalenceClass::new_singleton(column.clone())); + eq_class_idx_by_column.insert(column, i); + ranges_by_equivalence_class.push(Some(Interval::make_unbounded(&data_type)?)); + } + + let mut new = Self { + schema, + eq_classes, + eq_class_idx_by_column, + ranges_by_equivalence_class, + residuals: HashSet::new(), + referenced_tables, + }; + + // Process all collected filters + for expr in filters.iter().flat_map(split_conjunction) { + new.insert_conjuct(expr)?; + } + Ok(new) } @@ -431,6 +463,17 @@ impl Predicate { .and_then(|&idx| self.eq_classes.get(idx)) } + /// Fast path: normalize a single Column without full tree traversal. + /// This is O(1) lookup instead of O(n) transform. + #[inline] + fn normalize_column(&self, col: &Column) -> Column { + if let Some(eq_class) = self.class_for_column(col) { + eq_class.columns.first().unwrap().clone() + } else { + col.clone() + } + } + /// Add a new column equivalence fn add_equivalence(&mut self, c1: &Column, c2: &Column) -> Result<()> { match ( @@ -792,6 +835,11 @@ impl Predicate { /// Rewrite all expressions in terms of their normal representatives /// with respect to this predicate's equivalence classes. fn normalize_expr(&self, e: Expr) -> Expr { + // Fast path: if it's a simple Column, avoid full transform traversal + if let Expr::Column(ref c) = e { + return Expr::Column(self.normalize_column(c)); + } + e.transform(&|e| { let c = match e { Expr::Column(c) => c, @@ -1163,11 +1211,11 @@ mod test { TestCase { name: "range filter + equality predicate", base: - "SELECT column1, column2 FROM t1 WHERE column1 = column3 AND column1 >= '2022'", + "SELECT column1, column2 FROM t1 WHERE column1 = column3 AND column1 >= '2022'", query: // Since column1 = column3 in the original view, // we are allowed to substitute column1 for column3 and vice versa. - "SELECT column2, column3 FROM t1 WHERE column1 = column3 AND column3 >= '2023'", + "SELECT column2, column3 FROM t1 WHERE column1 = column3 AND column3 >= '2023'", }, TestCase { name: "range filter with inequality on non-discrete type", @@ -1229,4 +1277,95 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_predicate_new_collects_expected_data() -> Result<()> { + let ctx = SessionContext::new(); + + // Create a table with known schema + ctx.sql( + "CREATE TABLE test_table ( + col1 INT, + col2 VARCHAR, + col3 DOUBLE + )", + ) + .await? + .collect() + .await?; + + // Create a plan with filters + let plan = ctx + .sql("SELECT col1, col2 FROM test_table WHERE col1 >= 10 AND col2 = col3") + .await? + .into_optimized_plan()?; + + let normal_form = SpjNormalForm::new(&plan)?; + + // Verify referenced_tables is collected + assert_eq!(normal_form.referenced_tables().len(), 1); + assert_eq!(normal_form.referenced_tables()[0].to_string(), "test_table"); + + // Verify output_exprs matches the projection (2 columns) + assert_eq!(normal_form.output_exprs().len(), 2); + + // Verify schema is preserved + assert_eq!(normal_form.output_schema().fields().len(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_predicate_new_with_join_returns_error() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.sql("CREATE TABLE t1 (a INT, b INT)") + .await? + .collect() + .await?; + ctx.sql("CREATE TABLE t2 (c INT, d INT)") + .await? + .collect() + .await?; + + // Test that join returns an error as it's not supported yet + let plan = ctx + .sql("SELECT t1.a, t2.d FROM t1 JOIN t2 ON t1.b = t2.c WHERE t1.a >= 0 AND t2.d <= 100") + .await? + .into_optimized_plan()?; + + let result = SpjNormalForm::new(&plan); + + // Verify that join returns an error + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("joins are not supported yet")); + + Ok(()) + } + + #[tokio::test] + async fn test_predicate_new_with_range_filters() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.sql("CREATE TABLE range_test (x INT, y INT, z VARCHAR)") + .await? + .collect() + .await?; + + let plan = ctx + .sql("SELECT * FROM range_test WHERE x >= 10 AND x <= 100 AND y = 50") + .await? + .into_optimized_plan()?; + + let normal_form = SpjNormalForm::new(&plan)?; + + // Verify all columns are in output + assert_eq!(normal_form.output_exprs().len(), 3); + assert_eq!(normal_form.referenced_tables().len(), 1); + + Ok(()) + } }