Skip to content
90 changes: 89 additions & 1 deletion src/database/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ use crate::parsing::{
parse_binary_blob, parse_date, parse_hex_blob, parse_interval, parse_time, parse_timestamp,
parse_uuid, parse_vector,
};
use crate::types::{DataType, OwnedValue, Value};
use crate::types::{ArithmeticOp, DataType, OwnedValue, Value};
use eyre::{bail, Result, WrapErr};
use std::borrow::Cow;

use super::Database;

Expand Down Expand Up @@ -404,6 +405,93 @@ impl Database {
Self::eval_literal_with_type(expr, target_type)
}

pub(crate) fn eval_expr_with_params_and_subqueries<'a>(
expr: &crate::sql::ast::Expr<'_>,
target_type: Option<&crate::records::types::DataType>,
params: Option<&'a [OwnedValue]>,
param_idx: &mut usize,
scalar_subquery_results: &'a crate::sql::context::ScalarSubqueryResults,
) -> Result<Cow<'a, OwnedValue>> {
use crate::sql::ast::{Expr, ParameterRef};

match expr {
Expr::Parameter(param_ref) => {
if let Some(params) = params {
let idx = match param_ref {
ParameterRef::Anonymous => {
let i = *param_idx;
*param_idx += 1;
i
}
ParameterRef::Positional(n) => (*n as usize).saturating_sub(1),
ParameterRef::Named(_) => {
let i = *param_idx;
*param_idx += 1;
i
}
};

if idx >= params.len() {
bail!(
"parameter index {} out of range (only {} parameters bound)",
idx + 1,
params.len()
);
}

Ok(Cow::Borrowed(&params[idx]))
} else {
bail!("parameter placeholder found but no parameters were bound")
}
}
Expr::Subquery(subq) => {
let key = std::ptr::from_ref(*subq) as usize;
scalar_subquery_results
.get(&key)
.map(Cow::Borrowed)
.ok_or_else(|| eyre::eyre!("scalar subquery result not found for key 0x{:x}", key))
}
Expr::BinaryOp { left, op, right } => {
let left_val = Self::eval_expr_with_params_and_subqueries(
left,
target_type,
params,
param_idx,
scalar_subquery_results,
)?;
let right_val = Self::eval_expr_with_params_and_subqueries(
right,
target_type,
params,
param_idx,
scalar_subquery_results,
)?;

use crate::sql::ast::BinaryOperator;
let arith_op = match op {
BinaryOperator::Plus => Some(ArithmeticOp::Plus),
BinaryOperator::Minus => Some(ArithmeticOp::Minus),
BinaryOperator::Multiply => Some(ArithmeticOp::Multiply),
BinaryOperator::Divide => Some(ArithmeticOp::Divide),
_ => None,
};
if let Some(aop) = arith_op {
OwnedValue::eval_arithmetic(left_val.as_ref(), aop, right_val.as_ref())
.map(Cow::Owned)
.ok_or_else(|| {
eyre::eyre!(
"unsupported types or division by zero for {:?} in UPDATE SET",
aop
)
})
} else {
Self::eval_literal_with_type(expr, target_type).map(Cow::Owned)
}
}
_ => Self::eval_literal_with_type(expr, target_type).map(Cow::Owned),
}
}

pub(crate) fn parse_json_string(s: &str) -> Result<OwnedValue> {
let value = Self::parse_json_to_value(s.trim())?;
let bytes = Self::jsonb_value_to_bytes(&value);
Expand Down
4 changes: 2 additions & 2 deletions src/database/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -966,9 +966,9 @@ impl Database {

for subq in subqueries {
let key = std::ptr::from_ref(subq) as usize;
if !scalar_subquery_results.iter().any(|(k, _)| *k == key) {
if !scalar_subquery_results.contains_key(&key) {
let result = execute_scalar_subquery(subq, catalog, file_manager, &arena)?;
scalar_subquery_results.push((key, result));
scalar_subquery_results.insert(key, result);
}
}
}
Expand Down
Loading