Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/database/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ impl Database {
}

pub(crate) fn expr_to_string(expr: &crate::sql::ast::Expr<'_>) -> Option<String> {
use crate::sql::ast::{BinaryOperator, Expr};
use crate::sql::ast::{BinaryOperator, Expr, UnaryOperator};

match expr {
Expr::BinaryOp { left, op, right } => {
Expand All @@ -228,6 +228,16 @@ impl Database {
};
Some(format!("{} {} {}", left_str, op_str, right_str))
}
Expr::UnaryOp { op, expr: inner } => {
let inner_str = Self::expr_to_string(inner)?;
let op_str = match op {
UnaryOperator::Minus => "-",
UnaryOperator::Plus => "+",
UnaryOperator::Not => "NOT ",
UnaryOperator::BitwiseNot => "~",
};
Some(format!("{}{}", op_str, inner_str))
}
Expr::Column(col_ref) => Some(col_ref.column.to_string()),
Expr::Literal(lit) => match lit {
crate::sql::ast::Literal::Integer(n) => Some(n.to_string()),
Expand Down
276 changes: 232 additions & 44 deletions src/database/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,31 @@ impl SharedDatabase {
}
}

/// Comparison operators for CHECK constraint evaluation.
#[derive(Clone, Copy)]
enum CheckCompareOp {
/// Greater than or equal (>=)
Ge,
/// Less than or equal (<=)
Le,
/// Greater than (>)
Gt,
/// Less than (<)
Lt,
}

impl CheckCompareOp {
/// Compares two f64 values using this operator.
fn compare(self, lhs: f64, rhs: f64) -> bool {
match self {
CheckCompareOp::Ge => lhs >= rhs,
CheckCompareOp::Le => lhs <= rhs,
CheckCompareOp::Gt => lhs > rhs,
CheckCompareOp::Lt => lhs < rhs,
}
}
}

impl Database {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
Self::open_with_recovery(path).map(|(db, _)| db)
Expand Down Expand Up @@ -4377,71 +4402,234 @@ impl Database {
expr_str: &str,
col_name: &str,
col_value: Option<&OwnedValue>,
) -> bool {
) -> Result<bool> {
let Some(value) = col_value else {
return true;
return Ok(true);
};

if value.is_null() {
return true;
return Ok(true);
}

let expr_lower = expr_str.to_lowercase();
let col_lower = col_name.to_lowercase();
Self::eval_check_expr_recursive(expr_str, col_name, value)
}

if expr_lower.contains(&col_lower) {
if let Some(op_idx) = expr_str.find(">=") {
let right_part = expr_str[op_idx + 2..].trim();
if let Ok(threshold) = right_part.parse::<i64>() {
if let OwnedValue::Int(v) = value {
return *v >= threshold;
}
const MAX_CHECK_EXPR_DEPTH: usize = 32;

fn eval_check_expr_recursive(
expr_str: &str,
col_name: &str,
value: &OwnedValue,
) -> Result<bool> {
Self::eval_check_expr_with_depth(expr_str, col_name, value, 0)
}

fn eval_check_expr_with_depth(
expr_str: &str,
col_name: &str,
value: &OwnedValue,
depth: usize,
) -> Result<bool> {
if depth >= Self::MAX_CHECK_EXPR_DEPTH {
bail!(
"CHECK constraint expression exceeds maximum nesting depth of {} for column '{}'",
Self::MAX_CHECK_EXPR_DEPTH,
col_name
);
}

let trimmed = expr_str.trim();

if let Some((left, right)) = Self::split_on_logical_op_case_insensitive(trimmed, b" or ") {
let left_result = Self::eval_check_expr_with_depth(left, col_name, value, depth + 1)?;
let right_result = Self::eval_check_expr_with_depth(right, col_name, value, depth + 1)?;
return Ok(left_result || right_result);
}

if let Some((left, right)) = Self::split_on_logical_op_case_insensitive(trimmed, b" and ") {
let left_result = Self::eval_check_expr_with_depth(left, col_name, value, depth + 1)?;
let right_result = Self::eval_check_expr_with_depth(right, col_name, value, depth + 1)?;
return Ok(left_result && right_result);
}

let stripped = Self::strip_outer_parens(trimmed);
if stripped != trimmed {
return Self::eval_check_expr_with_depth(stripped, col_name, value, depth + 1);
}

Ok(Self::eval_simple_comparison(trimmed, col_name, value))
}

fn split_on_logical_op_case_insensitive<'a>(
original: &'a str,
op: &[u8],
) -> Option<(&'a str, &'a str)> {
let bytes = original.as_bytes();
let mut depth: usize = 0;
let mut i = 0;

while i + op.len() <= bytes.len() {
let c = bytes[i];
if c == b'(' {
depth = depth.saturating_add(1);
} else if c == b')' {
if depth == 0 {
return None;
}
if let Ok(threshold) = right_part.parse::<f64>() {
if let OwnedValue::Float(v) = value {
return *v >= threshold;
}
depth = depth.saturating_sub(1);
} else if depth == 0 && bytes[i..i + op.len()].eq_ignore_ascii_case(op) {
let left = original[..i].trim();
let right = original[i + op.len()..].trim();
if !left.is_empty() && !right.is_empty() {
return Some((left, right));
}
} else if let Some(op_idx) = expr_str.find("<=") {
let right_part = expr_str[op_idx + 2..].trim();
if let Ok(threshold) = right_part.parse::<i64>() {
if let OwnedValue::Int(v) = value {
return *v <= threshold;
}
i += 1;
}
None
}

fn strip_outer_parens(s: &str) -> &str {
if !s.starts_with('(') || !s.ends_with(')') {
return s;
}

let inner = &s[1..s.len() - 1];
let mut depth = 0;
for c in inner.chars() {
match c {
'(' => depth += 1,
')' => {
depth -= 1;
if depth < 0 {
return s;
}
}
if let Ok(threshold) = right_part.parse::<f64>() {
if let OwnedValue::Float(v) = value {
return *v <= threshold;
_ => {}
}
}

if depth == 0 {
inner
} else {
s
}
}

fn eval_simple_comparison(expr_str: &str, col_name: &str, value: &OwnedValue) -> bool {
if !Self::contains_ignore_ascii_case(expr_str, col_name) {
return true;
}

if let Some((op, op_len, op_idx)) = Self::find_comparison_operator(expr_str) {
if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + op_len..]) {
return Self::compare_value_with_threshold(value, threshold, op);
}
}

false
}

fn find_comparison_operator(s: &str) -> Option<(CheckCompareOp, usize, usize)> {
let bytes = s.as_bytes();
let len = bytes.len();
let mut i = 0;

while i < len {
match bytes[i] {
b'>' => {
if i + 1 < len && bytes[i + 1] == b'=' {
return Some((CheckCompareOp::Ge, 2, i));
}
return Some((CheckCompareOp::Gt, 1, i));
}
} else if let Some(op_idx) = expr_str.find('>') {
let right_part = expr_str[op_idx + 1..].trim();
if let Ok(threshold) = right_part.parse::<i64>() {
if let OwnedValue::Int(v) = value {
return *v > threshold;
b'<' => {
if i + 1 < len && bytes[i + 1] == b'=' {
return Some((CheckCompareOp::Le, 2, i));
}
return Some((CheckCompareOp::Lt, 1, i));
}
if let Ok(threshold) = right_part.parse::<f64>() {
if let OwnedValue::Float(v) = value {
return *v > threshold;
}
_ => {}
}
i += 1;
}
None
}

fn contains_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
let haystack_bytes = haystack.as_bytes();
let needle_bytes = needle.as_bytes();

if needle_bytes.is_empty() {
return true;
}
if needle_bytes.len() > haystack_bytes.len() {
return false;
}

haystack_bytes
.windows(needle_bytes.len())
.any(|window| window.eq_ignore_ascii_case(needle_bytes))
}

fn extract_numeric_operand(s: &str) -> Option<f64> {
let bytes = s.trim_start().as_bytes();
if bytes.is_empty() {
return None;
}

let mut i = 0;
let mut has_digit = false;
let mut has_dot = false;

if i < bytes.len() && (bytes[i] == b'-' || bytes[i] == b'+') {
i += 1;
}

while i < bytes.len() {
match bytes[i] {
b'0'..=b'9' => {
has_digit = true;
i += 1;
}
} else if let Some(op_idx) = expr_str.find('<') {
let right_part = expr_str[op_idx + 1..].trim();
if let Ok(threshold) = right_part.parse::<i64>() {
if let OwnedValue::Int(v) = value {
return *v < threshold;
}
b'.' if !has_dot => {
has_dot = true;
i += 1;
}
if let Ok(threshold) = right_part.parse::<f64>() {
if let OwnedValue::Float(v) = value {
return *v < threshold;
_ => break,
}
}

if !has_digit || i == 0 {
return None;
}

std::str::from_utf8(&bytes[..i])
.ok()
.and_then(|num_str| num_str.parse::<f64>().ok())
}

fn compare_value_with_threshold(value: &OwnedValue, threshold: f64, op: CheckCompareOp) -> bool {
match value {
OwnedValue::Int(v) => {
if threshold.fract() == 0.0
&& threshold >= i64::MIN as f64
&& threshold <= i64::MAX as f64
{
let threshold_i64 = threshold as i64;
match op {
CheckCompareOp::Ge => *v >= threshold_i64,
CheckCompareOp::Le => *v <= threshold_i64,
CheckCompareOp::Gt => *v > threshold_i64,
CheckCompareOp::Lt => *v < threshold_i64,
}
} else {
op.compare(*v as f64, threshold)
}
}
OwnedValue::Float(v) => op.compare(*v, threshold),
_ => false,
}

true
}

pub(crate) fn get_or_create_hnsw_index(
Expand Down
2 changes: 1 addition & 1 deletion src/database/dml/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ impl Database {
for constraint in col.constraints() {
if let Constraint::Check(expr_str) = constraint {
let col_value = values.get(col_idx);
if !Database::evaluate_check_expression(expr_str, col.name(), col_value) {
if !Database::evaluate_check_expression(expr_str, col.name(), col_value)? {
bail!(
"CHECK constraint violated on column '{}' in table '{}': {}",
col.name(),
Expand Down
6 changes: 3 additions & 3 deletions src/database/dml/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ impl Database {
expr_str,
col.name(),
col_value,
) {
)? {
bail!(
"CHECK constraint violated on column '{}' in table '{}': {}",
col.name(),
Expand Down Expand Up @@ -649,7 +649,7 @@ impl Database {
for constraint in col.constraints() {
if let Constraint::Check(expr_str) = constraint {
let col_value = row_values.get(col_idx);
if !Self::evaluate_check_expression(expr_str, col.name(), col_value)
if !Self::evaluate_check_expression(expr_str, col.name(), col_value)?
{
bail!(
"CHECK constraint violated on column '{}' in table '{}': {}",
Expand Down Expand Up @@ -1369,7 +1369,7 @@ impl Database {
for constraint in col.constraints() {
if let Constraint::Check(expr_str) = constraint {
let col_value = row_values.get(col_idx);
if !Self::evaluate_check_expression(expr_str, col.name(), col_value)
if !Self::evaluate_check_expression(expr_str, col.name(), col_value)?
{
bail!(
"CHECK constraint violated on column '{}' in table '{}': {}",
Expand Down
Loading