Skip to content

Commit

Permalink
fix: fix update/delete using subquery only support boolean data type (d…
Browse files Browse the repository at this point in the history
…atabendlabs#14413)

* fix: subquery data type in delete statement should be boolean

* fix: fix update/delete using subquery only support boolean data type

* fix: fix update/delete using subquery only support boolean data type

* fix: fix update/delete using subquery only support boolean data type

* fix: fix update/delete using subquery only support boolean data type

* fix: fix update/delete using subquery only support boolean data type
  • Loading branch information
lichuang authored Jan 29, 2024
1 parent 35568fd commit a0a63c4
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 16 deletions.
41 changes: 35 additions & 6 deletions src/query/service/src/interpreters/interpreter_delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,26 +389,55 @@ pub async fn subquery_filter(
}))
}

pub fn replace_subquery(
// return false means that doesnot replace a subquery with filter,
// in this case we need to replace subquery's parent with filter.
fn do_replace_subquery(
filters: &mut VecDeque<ScalarExpr>,
selection: &mut ScalarExpr,
) -> Result<()> {
) -> Result<bool> {
let data_type = selection.data_type()?;
let mut replace_selection_with_filter = None;

match selection {
ScalarExpr::FunctionCall(func) => {
for arg in &mut func.arguments {
replace_subquery(filters, arg)?;
if !do_replace_subquery(filters, arg)? {
replace_selection_with_filter = Some(filters.pop_back().unwrap());
break;
}
}
}
ScalarExpr::UDFServerCall(udf) => {
for arg in &mut udf.arguments {
replace_subquery(filters, arg)?;
if !do_replace_subquery(filters, arg)? {
replace_selection_with_filter = Some(filters.pop_back().unwrap());
break;
}
}
}
ScalarExpr::SubqueryExpr { .. } => {
let filter = filters.pop_back().unwrap();
*selection = filter;
if data_type == DataType::Nullable(Box::new(DataType::Boolean)) {
let filter = filters.pop_back().unwrap();
*selection = filter;
} else {
return Ok(false);
}
}
_ => {}
}

if let Some(filter) = replace_selection_with_filter {
*selection = filter;
replace_subquery(filters, selection)?;
}
Ok(true)
}

pub fn replace_subquery(
filters: &mut VecDeque<ScalarExpr>,
selection: &mut ScalarExpr,
) -> Result<()> {
let _ = do_replace_subquery(filters, selection)?;

Ok(())
}
51 changes: 41 additions & 10 deletions src/query/sql/src/planner/binder/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use crate::plans::RelOp;
use crate::plans::RelOperator::Scan;
use crate::plans::SubqueryDesc;
use crate::plans::SubqueryExpr;
use crate::plans::Visitor;
use crate::plans::VisitorWithParent;
use crate::BindContext;
use crate::ScalarExpr;

Expand Down Expand Up @@ -125,22 +125,43 @@ impl Binder {
#[async_backtrace::framed]
async fn process_subquery(
&self,
parent: Option<&ScalarExpr>,
subquery_expr: &SubqueryExpr,
mut table_expr: SExpr,
) -> Result<SubqueryDesc> {
if subquery_expr.data_type() != DataType::Nullable(Box::new(DataType::Boolean)) {
let predicate = if subquery_expr.data_type()
== DataType::Nullable(Box::new(DataType::Boolean))
{
subquery_expr.clone().into()
} else if let Some(scalar) = parent {
if let Ok(data_type) = scalar.data_type() {
if data_type == DataType::Nullable(Box::new(DataType::Boolean)) {
scalar.clone()
} else {
return Err(ErrorCode::from_string(
"subquery data type in delete/update statement should be boolean"
.to_string(),
));
}
} else {
return Err(ErrorCode::from_string(
"subquery data type in delete/update statement should be boolean".to_string(),
));
}
} else {
return Err(ErrorCode::from_string(
"subquery data type in delete statement should be boolean".to_string(),
"subquery data type in delete/update statement should be boolean".to_string(),
));
}
};

let mut outer_columns = Default::default();
if let Some(child_expr) = &subquery_expr.child_expr {
outer_columns = child_expr.used_columns();
};
outer_columns.extend(subquery_expr.outer_columns.iter());

let filter = Filter {
predicates: vec![subquery_expr.clone().into()],
predicates: vec![predicate],
};
debug_assert_eq!(table_expr.plan.rel_op(), RelOp::Scan);
let mut scan = match &*table_expr.plan {
Expand Down Expand Up @@ -195,12 +216,20 @@ impl Binder {
subquery_desc: &mut Vec<SubqueryDesc>,
) -> Result<()> {
struct FindSubqueryVisitor<'a> {
subqueries: Vec<&'a SubqueryExpr>,
subqueries: Vec<(Option<&'a ScalarExpr>, &'a SubqueryExpr)>,
}

impl<'a> Visitor<'a> for FindSubqueryVisitor<'a> {
fn visit_subquery(&mut self, subquery: &'a SubqueryExpr) -> Result<()> {
self.subqueries.push(subquery);
impl<'a> VisitorWithParent<'a> for FindSubqueryVisitor<'a> {
fn visit_subquery(
&mut self,
parent: Option<&'a ScalarExpr>,
current: &'a ScalarExpr,
subquery: &'a SubqueryExpr,
) -> Result<()> {
self.subqueries.push((parent, subquery));
if let Some(child_expr) = subquery.child_expr.as_ref() {
self.visit_with_parent(Some(current), child_expr)?;
}
Ok(())
}
}
Expand All @@ -209,7 +238,9 @@ impl Binder {
find_subquery.visit(scalar)?;

for subquery in find_subquery.subqueries {
let desc = self.process_subquery(subquery, table_expr.clone()).await?;
let desc = self
.process_subquery(subquery.0, subquery.1, table_expr.clone())
.await?;
subquery_desc.push(desc);
}

Expand Down
177 changes: 177 additions & 0 deletions src/query/sql/src/planner/plans/scalar_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,183 @@ pub trait Visitor<'a>: Sized {
}
}

// Any `Visitor` which needs to access parent `ScalarExpr` can implement `VisitorWithParent`
pub trait VisitorWithParent<'a>: Sized {
fn visit(&mut self, expr: &'a ScalarExpr) -> Result<()> {
walk_expr_with_parent(self, None, expr)
}

fn visit_with_parent(
&mut self,
parent: Option<&'a ScalarExpr>,
expr: &'a ScalarExpr,
) -> Result<()> {
walk_expr_with_parent(self, parent, expr)
}

fn visit_bound_column_ref(
&mut self,
_parent: Option<&'a ScalarExpr>,
_col: &'a BoundColumnRef,
) -> Result<()> {
Ok(())
}

fn visit_constant(
&mut self,
_parent: Option<&'a ScalarExpr>,
_constant: &'a ConstantExpr,
) -> Result<()> {
Ok(())
}

fn visit_window_function(
&mut self,
_parent: Option<&'a ScalarExpr>,
current: &'a ScalarExpr,
window: &'a WindowFunc,
) -> Result<()> {
fn walk_window_with_parent<'a, V: VisitorWithParent<'a>>(
visitor: &mut V,
current: &'a ScalarExpr,
window: &'a WindowFunc,
) -> Result<()> {
for expr in &window.partition_by {
visitor.visit_with_parent(Some(current), expr)?;
}
for expr in &window.order_by {
visitor.visit_with_parent(Some(current), &expr.expr)?;
}
match &window.func {
WindowFuncType::Aggregate(func) => {
visitor.visit_aggregate_function(Some(current), current, func)?
}
WindowFuncType::NthValue(func) => {
visitor.visit_with_parent(Some(current), &func.arg)?
}
WindowFuncType::LagLead(func) => {
visitor.visit_with_parent(Some(current), &func.arg)?;
if let Some(default) = func.default.as_ref() {
visitor.visit_with_parent(Some(current), default)?
}
}
WindowFuncType::RowNumber
| WindowFuncType::CumeDist
| WindowFuncType::Rank
| WindowFuncType::DenseRank
| WindowFuncType::PercentRank
| WindowFuncType::Ntile(_) => (),
}
Ok(())
}
walk_window_with_parent(self, current, window)
}

fn visit_aggregate_function(
&mut self,
_parent: Option<&'a ScalarExpr>,
current: &'a ScalarExpr,
aggregate: &'a AggregateFunction,
) -> Result<()> {
for expr in &aggregate.args {
self.visit_with_parent(Some(current), expr)?;
}
Ok(())
}

fn visit_lambda_function(
&mut self,
_parent: Option<&'a ScalarExpr>,
current: &'a ScalarExpr,
lambda: &'a LambdaFunc,
) -> Result<()> {
for expr in &lambda.args {
self.visit_with_parent(Some(current), expr)?;
}
Ok(())
}

fn visit_function_call(
&mut self,
_parent: Option<&'a ScalarExpr>,
current: &'a ScalarExpr,
func: &'a FunctionCall,
) -> Result<()> {
for expr in &func.arguments {
self.visit_with_parent(Some(current), expr)?;
}
Ok(())
}

fn visit_cast(
&mut self,
_parent: Option<&'a ScalarExpr>,
current: &'a ScalarExpr,
cast: &'a CastExpr,
) -> Result<()> {
self.visit_with_parent(Some(current), &cast.argument)?;
Ok(())
}

fn visit_subquery(
&mut self,
_parent: Option<&'a ScalarExpr>,
current: &'a ScalarExpr,
subquery: &'a SubqueryExpr,
) -> Result<()> {
if let Some(child_expr) = subquery.child_expr.as_ref() {
self.visit_with_parent(Some(current), child_expr)?;
}
Ok(())
}

fn visit_udf_server_call(
&mut self,
_parent: Option<&'a ScalarExpr>,
current: &'a ScalarExpr,
udf: &'a UDFServerCall,
) -> Result<()> {
for expr in &udf.arguments {
self.visit_with_parent(Some(current), expr)?;
}
Ok(())
}

fn visit_udf_lambda_call(
&mut self,
_parent: Option<&'a ScalarExpr>,
current: &'a ScalarExpr,
udf: &'a UDFLambdaCall,
) -> Result<()> {
self.visit_with_parent(Some(current), &udf.scalar)
}
}

pub fn walk_expr_with_parent<'a, V: VisitorWithParent<'a>>(
visitor: &mut V,
parent: Option<&'a ScalarExpr>,
current: &'a ScalarExpr,
) -> Result<()> {
match current {
ScalarExpr::BoundColumnRef(expr) => visitor.visit_bound_column_ref(parent, expr),
ScalarExpr::ConstantExpr(expr) => visitor.visit_constant(parent, expr),
ScalarExpr::WindowFunction(win_func) => {
visitor.visit_window_function(parent, current, win_func)
}
ScalarExpr::AggregateFunction(aggregate) => {
visitor.visit_aggregate_function(parent, current, aggregate)
}
ScalarExpr::LambdaFunction(lambda) => {
visitor.visit_lambda_function(parent, current, lambda)
}
ScalarExpr::FunctionCall(func) => visitor.visit_function_call(parent, current, func),
ScalarExpr::CastExpr(cast_expr) => visitor.visit_cast(parent, current, cast_expr),
ScalarExpr::SubqueryExpr(subquery) => visitor.visit_subquery(parent, current, subquery),
ScalarExpr::UDFServerCall(udf) => visitor.visit_udf_server_call(parent, current, udf),
ScalarExpr::UDFLambdaCall(udf) => visitor.visit_udf_lambda_call(parent, current, udf),
}
}

pub fn walk_expr<'a, V: Visitor<'a>>(visitor: &mut V, expr: &'a ScalarExpr) -> Result<()> {
match expr {
ScalarExpr::BoundColumnRef(expr) => visitor.visit_bound_column_ref(expr),
Expand Down
17 changes: 17 additions & 0 deletions tests/sqllogictests/suites/base/03_common/03_0035_update.test
Original file line number Diff line number Diff line change
Expand Up @@ -228,5 +228,22 @@ select * from t;
100
100

statement ok
update t set a = 101 where 200 > (select avg(a) from t);

query I
select * from t;
----
101
101
101

statement ok
delete from t where 200 > (select avg(a) from t);

query I
select * from t;
----

statement ok
DROP DATABASE db1

0 comments on commit a0a63c4

Please sign in to comment.