diff --git a/src/query/service/src/interpreters/interpreter_delete.rs b/src/query/service/src/interpreters/interpreter_delete.rs index ac9168212296..9370bab85400 100644 --- a/src/query/service/src/interpreters/interpreter_delete.rs +++ b/src/query/service/src/interpreters/interpreter_delete.rs @@ -389,26 +389,55 @@ pub async fn subquery_filter( })) } -pub fn replace_subquery( +// return false means that doesnot replace a subquery with filter, +// in this case we need to replace subquery's parent with filter. +fn do_replace_subquery( filters: &mut VecDeque, selection: &mut ScalarExpr, -) -> Result<()> { +) -> Result { + let data_type = selection.data_type()?; + let mut replace_selection_with_filter = None; + match selection { ScalarExpr::FunctionCall(func) => { for arg in &mut func.arguments { - replace_subquery(filters, arg)?; + if !do_replace_subquery(filters, arg)? { + replace_selection_with_filter = Some(filters.pop_back().unwrap()); + break; + } } } ScalarExpr::UDFServerCall(udf) => { for arg in &mut udf.arguments { - replace_subquery(filters, arg)?; + if !do_replace_subquery(filters, arg)? { + replace_selection_with_filter = Some(filters.pop_back().unwrap()); + break; + } } } ScalarExpr::SubqueryExpr { .. } => { - let filter = filters.pop_back().unwrap(); - *selection = filter; + if data_type == DataType::Nullable(Box::new(DataType::Boolean)) { + let filter = filters.pop_back().unwrap(); + *selection = filter; + } else { + return Ok(false); + } } _ => {} } + + if let Some(filter) = replace_selection_with_filter { + *selection = filter; + replace_subquery(filters, selection)?; + } + Ok(true) +} + +pub fn replace_subquery( + filters: &mut VecDeque, + selection: &mut ScalarExpr, +) -> Result<()> { + let _ = do_replace_subquery(filters, selection)?; + Ok(()) } diff --git a/src/query/sql/src/planner/binder/delete.rs b/src/query/sql/src/planner/binder/delete.rs index 4e56965e95c9..e35e2edd90f6 100644 --- a/src/query/sql/src/planner/binder/delete.rs +++ b/src/query/sql/src/planner/binder/delete.rs @@ -35,7 +35,7 @@ use crate::plans::RelOp; use crate::plans::RelOperator::Scan; use crate::plans::SubqueryDesc; use crate::plans::SubqueryExpr; -use crate::plans::Visitor; +use crate::plans::VisitorWithParent; use crate::BindContext; use crate::ScalarExpr; @@ -125,14 +125,35 @@ impl Binder { #[async_backtrace::framed] async fn process_subquery( &self, + parent: Option<&ScalarExpr>, subquery_expr: &SubqueryExpr, mut table_expr: SExpr, ) -> Result { - if subquery_expr.data_type() != DataType::Nullable(Box::new(DataType::Boolean)) { + let predicate = if subquery_expr.data_type() + == DataType::Nullable(Box::new(DataType::Boolean)) + { + subquery_expr.clone().into() + } else if let Some(scalar) = parent { + if let Ok(data_type) = scalar.data_type() { + if data_type == DataType::Nullable(Box::new(DataType::Boolean)) { + scalar.clone() + } else { + return Err(ErrorCode::from_string( + "subquery data type in delete/update statement should be boolean" + .to_string(), + )); + } + } else { + return Err(ErrorCode::from_string( + "subquery data type in delete/update statement should be boolean".to_string(), + )); + } + } else { return Err(ErrorCode::from_string( - "subquery data type in delete statement should be boolean".to_string(), + "subquery data type in delete/update statement should be boolean".to_string(), )); - } + }; + let mut outer_columns = Default::default(); if let Some(child_expr) = &subquery_expr.child_expr { outer_columns = child_expr.used_columns(); @@ -140,7 +161,7 @@ impl Binder { outer_columns.extend(subquery_expr.outer_columns.iter()); let filter = Filter { - predicates: vec![subquery_expr.clone().into()], + predicates: vec![predicate], }; debug_assert_eq!(table_expr.plan.rel_op(), RelOp::Scan); let mut scan = match &*table_expr.plan { @@ -195,12 +216,20 @@ impl Binder { subquery_desc: &mut Vec, ) -> Result<()> { struct FindSubqueryVisitor<'a> { - subqueries: Vec<&'a SubqueryExpr>, + subqueries: Vec<(Option<&'a ScalarExpr>, &'a SubqueryExpr)>, } - impl<'a> Visitor<'a> for FindSubqueryVisitor<'a> { - fn visit_subquery(&mut self, subquery: &'a SubqueryExpr) -> Result<()> { - self.subqueries.push(subquery); + impl<'a> VisitorWithParent<'a> for FindSubqueryVisitor<'a> { + fn visit_subquery( + &mut self, + parent: Option<&'a ScalarExpr>, + current: &'a ScalarExpr, + subquery: &'a SubqueryExpr, + ) -> Result<()> { + self.subqueries.push((parent, subquery)); + if let Some(child_expr) = subquery.child_expr.as_ref() { + self.visit_with_parent(Some(current), child_expr)?; + } Ok(()) } } @@ -209,7 +238,9 @@ impl Binder { find_subquery.visit(scalar)?; for subquery in find_subquery.subqueries { - let desc = self.process_subquery(subquery, table_expr.clone()).await?; + let desc = self + .process_subquery(subquery.0, subquery.1, table_expr.clone()) + .await?; subquery_desc.push(desc); } diff --git a/src/query/sql/src/planner/plans/scalar_expr.rs b/src/query/sql/src/planner/plans/scalar_expr.rs index 39d3e82fb5ba..f03ca0bc9b04 100644 --- a/src/query/sql/src/planner/plans/scalar_expr.rs +++ b/src/query/sql/src/planner/plans/scalar_expr.rs @@ -663,6 +663,183 @@ pub trait Visitor<'a>: Sized { } } +// Any `Visitor` which needs to access parent `ScalarExpr` can implement `VisitorWithParent` +pub trait VisitorWithParent<'a>: Sized { + fn visit(&mut self, expr: &'a ScalarExpr) -> Result<()> { + walk_expr_with_parent(self, None, expr) + } + + fn visit_with_parent( + &mut self, + parent: Option<&'a ScalarExpr>, + expr: &'a ScalarExpr, + ) -> Result<()> { + walk_expr_with_parent(self, parent, expr) + } + + fn visit_bound_column_ref( + &mut self, + _parent: Option<&'a ScalarExpr>, + _col: &'a BoundColumnRef, + ) -> Result<()> { + Ok(()) + } + + fn visit_constant( + &mut self, + _parent: Option<&'a ScalarExpr>, + _constant: &'a ConstantExpr, + ) -> Result<()> { + Ok(()) + } + + fn visit_window_function( + &mut self, + _parent: Option<&'a ScalarExpr>, + current: &'a ScalarExpr, + window: &'a WindowFunc, + ) -> Result<()> { + fn walk_window_with_parent<'a, V: VisitorWithParent<'a>>( + visitor: &mut V, + current: &'a ScalarExpr, + window: &'a WindowFunc, + ) -> Result<()> { + for expr in &window.partition_by { + visitor.visit_with_parent(Some(current), expr)?; + } + for expr in &window.order_by { + visitor.visit_with_parent(Some(current), &expr.expr)?; + } + match &window.func { + WindowFuncType::Aggregate(func) => { + visitor.visit_aggregate_function(Some(current), current, func)? + } + WindowFuncType::NthValue(func) => { + visitor.visit_with_parent(Some(current), &func.arg)? + } + WindowFuncType::LagLead(func) => { + visitor.visit_with_parent(Some(current), &func.arg)?; + if let Some(default) = func.default.as_ref() { + visitor.visit_with_parent(Some(current), default)? + } + } + WindowFuncType::RowNumber + | WindowFuncType::CumeDist + | WindowFuncType::Rank + | WindowFuncType::DenseRank + | WindowFuncType::PercentRank + | WindowFuncType::Ntile(_) => (), + } + Ok(()) + } + walk_window_with_parent(self, current, window) + } + + fn visit_aggregate_function( + &mut self, + _parent: Option<&'a ScalarExpr>, + current: &'a ScalarExpr, + aggregate: &'a AggregateFunction, + ) -> Result<()> { + for expr in &aggregate.args { + self.visit_with_parent(Some(current), expr)?; + } + Ok(()) + } + + fn visit_lambda_function( + &mut self, + _parent: Option<&'a ScalarExpr>, + current: &'a ScalarExpr, + lambda: &'a LambdaFunc, + ) -> Result<()> { + for expr in &lambda.args { + self.visit_with_parent(Some(current), expr)?; + } + Ok(()) + } + + fn visit_function_call( + &mut self, + _parent: Option<&'a ScalarExpr>, + current: &'a ScalarExpr, + func: &'a FunctionCall, + ) -> Result<()> { + for expr in &func.arguments { + self.visit_with_parent(Some(current), expr)?; + } + Ok(()) + } + + fn visit_cast( + &mut self, + _parent: Option<&'a ScalarExpr>, + current: &'a ScalarExpr, + cast: &'a CastExpr, + ) -> Result<()> { + self.visit_with_parent(Some(current), &cast.argument)?; + Ok(()) + } + + fn visit_subquery( + &mut self, + _parent: Option<&'a ScalarExpr>, + current: &'a ScalarExpr, + subquery: &'a SubqueryExpr, + ) -> Result<()> { + if let Some(child_expr) = subquery.child_expr.as_ref() { + self.visit_with_parent(Some(current), child_expr)?; + } + Ok(()) + } + + fn visit_udf_server_call( + &mut self, + _parent: Option<&'a ScalarExpr>, + current: &'a ScalarExpr, + udf: &'a UDFServerCall, + ) -> Result<()> { + for expr in &udf.arguments { + self.visit_with_parent(Some(current), expr)?; + } + Ok(()) + } + + fn visit_udf_lambda_call( + &mut self, + _parent: Option<&'a ScalarExpr>, + current: &'a ScalarExpr, + udf: &'a UDFLambdaCall, + ) -> Result<()> { + self.visit_with_parent(Some(current), &udf.scalar) + } +} + +pub fn walk_expr_with_parent<'a, V: VisitorWithParent<'a>>( + visitor: &mut V, + parent: Option<&'a ScalarExpr>, + current: &'a ScalarExpr, +) -> Result<()> { + match current { + ScalarExpr::BoundColumnRef(expr) => visitor.visit_bound_column_ref(parent, expr), + ScalarExpr::ConstantExpr(expr) => visitor.visit_constant(parent, expr), + ScalarExpr::WindowFunction(win_func) => { + visitor.visit_window_function(parent, current, win_func) + } + ScalarExpr::AggregateFunction(aggregate) => { + visitor.visit_aggregate_function(parent, current, aggregate) + } + ScalarExpr::LambdaFunction(lambda) => { + visitor.visit_lambda_function(parent, current, lambda) + } + ScalarExpr::FunctionCall(func) => visitor.visit_function_call(parent, current, func), + ScalarExpr::CastExpr(cast_expr) => visitor.visit_cast(parent, current, cast_expr), + ScalarExpr::SubqueryExpr(subquery) => visitor.visit_subquery(parent, current, subquery), + ScalarExpr::UDFServerCall(udf) => visitor.visit_udf_server_call(parent, current, udf), + ScalarExpr::UDFLambdaCall(udf) => visitor.visit_udf_lambda_call(parent, current, udf), + } +} + pub fn walk_expr<'a, V: Visitor<'a>>(visitor: &mut V, expr: &'a ScalarExpr) -> Result<()> { match expr { ScalarExpr::BoundColumnRef(expr) => visitor.visit_bound_column_ref(expr), diff --git a/tests/sqllogictests/suites/base/03_common/03_0035_update.test b/tests/sqllogictests/suites/base/03_common/03_0035_update.test index 2ab12a95ed81..51a590ec35ac 100644 --- a/tests/sqllogictests/suites/base/03_common/03_0035_update.test +++ b/tests/sqllogictests/suites/base/03_common/03_0035_update.test @@ -228,5 +228,22 @@ select * from t; 100 100 +statement ok +update t set a = 101 where 200 > (select avg(a) from t); + +query I +select * from t; +---- +101 +101 +101 + +statement ok +delete from t where 200 > (select avg(a) from t); + +query I +select * from t; +---- + statement ok DROP DATABASE db1