From 9189203c7ca48a884d59a3cd895e4887eb52db75 Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 13:13:59 +0800 Subject: [PATCH 1/5] fix(constraint): handle compound CHECK expressions with AND/OR operators The CHECK constraint evaluator was failing for compound expressions like `CHECK (age >= 0 AND age <= 150)` because it only looked for the first comparison operator and tried to parse everything after it as a number. This fix adds proper handling for: - AND expressions: recursively evaluate both sides with && - OR expressions: recursively evaluate both sides with || - Parentheses: strip outer parens and recurse - Numeric operands: extract only the numeric portion before operators The implementation respects parenthesis nesting to correctly split on top-level logical operators. Fixes #24 Co-Authored-By: Claude Opus 4.5 --- src/database/database.rs | 173 ++++++++++++++++++++++++++++----------- 1 file changed, 125 insertions(+), 48 deletions(-) diff --git a/src/database/database.rs b/src/database/database.rs index abf84cb..c3df039 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -4386,64 +4386,141 @@ impl Database { return true; } - let expr_lower = expr_str.to_lowercase(); - let col_lower = col_name.to_lowercase(); + Self::eval_check_expr_recursive(expr_str, col_name, value) + } - if expr_lower.contains(&col_lower) { - if let Some(op_idx) = expr_str.find(">=") { - let right_part = expr_str[op_idx + 2..].trim(); - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Int(v) = value { - return *v >= threshold; - } - } - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Float(v) = value { - return *v >= threshold; - } - } - } else if let Some(op_idx) = expr_str.find("<=") { - let right_part = expr_str[op_idx + 2..].trim(); - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Int(v) = value { - return *v <= threshold; - } - } - if let Ok(threshold) = right_part.parse::() { - if let 
OwnedValue::Float(v) = value { - return *v <= threshold; - } - } - } else if let Some(op_idx) = expr_str.find('>') { - let right_part = expr_str[op_idx + 1..].trim(); - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Int(v) = value { - return *v > threshold; - } - } - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Float(v) = value { - return *v > threshold; - } - } - } else if let Some(op_idx) = expr_str.find('<') { - let right_part = expr_str[op_idx + 1..].trim(); - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Int(v) = value { - return *v < threshold; - } + fn eval_check_expr_recursive(expr_str: &str, col_name: &str, value: &OwnedValue) -> bool { + let trimmed = expr_str.trim(); + let lower = trimmed.to_lowercase(); + + if let Some((left, right)) = Self::split_on_logical_op(&lower, trimmed, " and ") { + return Self::eval_check_expr_recursive(left, col_name, value) + && Self::eval_check_expr_recursive(right, col_name, value); + } + + if let Some((left, right)) = Self::split_on_logical_op(&lower, trimmed, " or ") { + return Self::eval_check_expr_recursive(left, col_name, value) + || Self::eval_check_expr_recursive(right, col_name, value); + } + + let stripped = Self::strip_outer_parens(trimmed); + if stripped != trimmed { + return Self::eval_check_expr_recursive(stripped, col_name, value); + } + + Self::eval_simple_comparison(trimmed, col_name, value) + } + + fn split_on_logical_op<'a>( + lower: &str, + original: &'a str, + op: &str, + ) -> Option<(&'a str, &'a str)> { + let mut depth = 0; + let mut i = 0; + let lower_bytes = lower.as_bytes(); + let op_bytes = op.as_bytes(); + + while i + op.len() <= lower.len() { + let c = lower_bytes[i]; + if c == b'(' { + depth += 1; + } else if c == b')' { + depth -= 1; + } else if depth == 0 && lower_bytes[i..].starts_with(op_bytes) { + let left = original[..i].trim(); + let right = original[i + op.len()..].trim(); + if !left.is_empty() && !right.is_empty() { + return 
Some((left, right)); } - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Float(v) = value { - return *v < threshold; + } + i += 1; + } + None + } + + fn strip_outer_parens(s: &str) -> &str { + let trimmed = s.trim(); + if !trimmed.starts_with('(') || !trimmed.ends_with(')') { + return trimmed; + } + + let inner = &trimmed[1..trimmed.len() - 1]; + let mut depth = 0; + for c in inner.chars() { + match c { + '(' => depth += 1, + ')' => { + depth -= 1; + if depth < 0 { + return trimmed; } } + _ => {} + } + } + + if depth == 0 { + inner + } else { + trimmed + } + } + + fn eval_simple_comparison(expr_str: &str, col_name: &str, value: &OwnedValue) -> bool { + let expr_lower = expr_str.to_lowercase(); + let col_lower = col_name.to_lowercase(); + + if !expr_lower.contains(&col_lower) { + return true; + } + + if let Some(op_idx) = expr_str.find(">=") { + if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 2..]) { + return Self::compare_value_with_threshold(value, threshold, |v, t| v >= t); + } + } else if let Some(op_idx) = expr_str.find("<=") { + if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 2..]) { + return Self::compare_value_with_threshold(value, threshold, |v, t| v <= t); + } + } else if let Some(op_idx) = expr_str.find('>') { + if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 1..]) { + return Self::compare_value_with_threshold(value, threshold, |v, t| v > t); + } + } else if let Some(op_idx) = expr_str.find('<') { + if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 1..]) { + return Self::compare_value_with_threshold(value, threshold, |v, t| v < t); } } true } + fn extract_numeric_operand(s: &str) -> Option { + let trimmed = s.trim(); + let numeric_part: String = trimmed + .chars() + .take_while(|c| c.is_ascii_digit() || *c == '.' 
|| *c == '-' || *c == '+') + .collect(); + + if numeric_part.is_empty() { + return None; + } + + numeric_part.parse::().ok() + } + + fn compare_value_with_threshold(value: &OwnedValue, threshold: f64, cmp: F) -> bool + where + F: Fn(f64, f64) -> bool, + { + match value { + OwnedValue::Int(v) => cmp(*v as f64, threshold), + OwnedValue::Float(v) => cmp(*v, threshold), + _ => true, + } + } + pub(crate) fn get_or_create_hnsw_index( &self, schema: &str, From d93bd77e3fd294be709ca7ac3488b73b0e671507 Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 13:35:22 +0800 Subject: [PATCH 2/5] fix(constraint): address code review - zero-alloc and precedence Address critical issues from code review: 1. Zero-allocation violations fixed: - Replace to_lowercase() with byte-level case-insensitive comparison - Replace chars().collect() with byte slicing for numeric extraction - Add contains_ignore_ascii_case() using eq_ignore_ascii_case on windows 2. Operator precedence fixed: - Check OR first (lower precedence), then AND (higher precedence) - "a OR b AND c" now correctly evaluates as "a OR (b AND c)" 3. 
Performance impact eliminated: - Previously: 7-9 allocations per CHECK evaluation - Now: 0 allocations (all byte-level operations) Co-Authored-By: Claude Opus 4.5 --- src/database/database.rs | 81 ++++++++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 28 deletions(-) diff --git a/src/database/database.rs b/src/database/database.rs index c3df039..c3e9a32 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -4391,16 +4391,16 @@ impl Database { fn eval_check_expr_recursive(expr_str: &str, col_name: &str, value: &OwnedValue) -> bool { let trimmed = expr_str.trim(); - let lower = trimmed.to_lowercase(); + let bytes = trimmed.as_bytes(); - if let Some((left, right)) = Self::split_on_logical_op(&lower, trimmed, " and ") { + if let Some((left, right)) = Self::split_on_logical_op_bytes(bytes, trimmed, b" or ", b" OR ") { return Self::eval_check_expr_recursive(left, col_name, value) - && Self::eval_check_expr_recursive(right, col_name, value); + || Self::eval_check_expr_recursive(right, col_name, value); } - if let Some((left, right)) = Self::split_on_logical_op(&lower, trimmed, " or ") { + if let Some((left, right)) = Self::split_on_logical_op_bytes(bytes, trimmed, b" and ", b" AND ") { return Self::eval_check_expr_recursive(left, col_name, value) - || Self::eval_check_expr_recursive(right, col_name, value); + && Self::eval_check_expr_recursive(right, col_name, value); } let stripped = Self::strip_outer_parens(trimmed); @@ -4411,27 +4411,29 @@ impl Database { Self::eval_simple_comparison(trimmed, col_name, value) } - fn split_on_logical_op<'a>( - lower: &str, + fn split_on_logical_op_bytes<'a>( + bytes: &[u8], original: &'a str, - op: &str, + op_lower: &[u8], + op_upper: &[u8], ) -> Option<(&'a str, &'a str)> { let mut depth = 0; let mut i = 0; - let lower_bytes = lower.as_bytes(); - let op_bytes = op.as_bytes(); - while i + op.len() <= lower.len() { - let c = lower_bytes[i]; + while i + op_lower.len() <= bytes.len() { + let c = 
bytes[i]; if c == b'(' { depth += 1; } else if c == b')' { depth -= 1; - } else if depth == 0 && lower_bytes[i..].starts_with(op_bytes) { - let left = original[..i].trim(); - let right = original[i + op.len()..].trim(); - if !left.is_empty() && !right.is_empty() { - return Some((left, right)); + } else if depth == 0 { + let slice = &bytes[i..]; + if slice.starts_with(op_lower) || slice.starts_with(op_upper) { + let left = original[..i].trim(); + let right = original[i + op_lower.len()..].trim(); + if !left.is_empty() && !right.is_empty() { + return Some((left, right)); + } } } i += 1; @@ -4468,10 +4470,7 @@ impl Database { } fn eval_simple_comparison(expr_str: &str, col_name: &str, value: &OwnedValue) -> bool { - let expr_lower = expr_str.to_lowercase(); - let col_lower = col_name.to_lowercase(); - - if !expr_lower.contains(&col_lower) { + if !Self::contains_ignore_ascii_case(expr_str, col_name) { return true; } @@ -4496,18 +4495,44 @@ impl Database { true } + fn contains_ignore_ascii_case(haystack: &str, needle: &str) -> bool { + let haystack_bytes = haystack.as_bytes(); + let needle_bytes = needle.as_bytes(); + + if needle_bytes.is_empty() { + return true; + } + if needle_bytes.len() > haystack_bytes.len() { + return false; + } + + haystack_bytes + .windows(needle_bytes.len()) + .any(|window| window.eq_ignore_ascii_case(needle_bytes)) + } + fn extract_numeric_operand(s: &str) -> Option { - let trimmed = s.trim(); - let numeric_part: String = trimmed - .chars() - .take_while(|c| c.is_ascii_digit() || *c == '.' || *c == '-' || *c == '+') - .collect(); + let bytes = s.trim_start().as_bytes(); + if bytes.is_empty() { + return None; + } + + let mut end = 0; + for &b in bytes { + if b.is_ascii_digit() || b == b'.' 
|| b == b'-' || b == b'+' { + end += 1; + } else { + break; + } + } - if numeric_part.is_empty() { + if end == 0 { return None; } - numeric_part.parse::().ok() + std::str::from_utf8(&bytes[..end]) + .ok() + .and_then(|s| s.parse::().ok()) } fn compare_value_with_threshold(value: &OwnedValue, threshold: f64, cmp: F) -> bool From ab213cde5bcb28d5f8cf19f5a32dcad1a9f02954 Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 13:53:42 +0800 Subject: [PATCH 3/5] fix(constraint): address second code review round 1. Replace closure with CheckCompareOp enum - Closure |v, t| v >= t replaced with enum-based comparison - Enum uses match with #[derive(Clone, Copy)] for zero overhead 2. Fix case sensitivity for AND/OR - Use eq_ignore_ascii_case for keyword matching - Handles all case variants: 'Or', 'AnD', 'AND', etc. 3. Fail-closed semantics - Return false when expression cannot be evaluated - Data integrity: malformed constraints now reject inserts 4. Add recursion depth limit - MAX_CHECK_EXPR_DEPTH = 32 - Prevents stack overflow from deeply nested expressions Co-Authored-By: Claude Opus 4.5 --- src/database/database.rs | 84 ++++++++++++++++++++++++++-------------- 1 file changed, 56 insertions(+), 28 deletions(-) diff --git a/src/database/database.rs b/src/database/database.rs index c3e9a32..838299d 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -184,6 +184,25 @@ impl SharedDatabase { } } +#[derive(Clone, Copy)] +enum CheckCompareOp { + Ge, + Le, + Gt, + Lt, +} + +impl CheckCompareOp { + fn compare(self, lhs: f64, rhs: f64) -> bool { + match self { + CheckCompareOp::Ge => lhs >= rhs, + CheckCompareOp::Le => lhs <= rhs, + CheckCompareOp::Gt => lhs > rhs, + CheckCompareOp::Lt => lhs < rhs, + } + } +} + impl Database { pub fn open>(path: P) -> Result { Self::open_with_recovery(path).map(|(db, _)| db) @@ -4389,53 +4408,65 @@ impl Database { Self::eval_check_expr_recursive(expr_str, col_name, value) } + const MAX_CHECK_EXPR_DEPTH: usize = 
32; + fn eval_check_expr_recursive(expr_str: &str, col_name: &str, value: &OwnedValue) -> bool { + Self::eval_check_expr_with_depth(expr_str, col_name, value, 0) + } + + fn eval_check_expr_with_depth( + expr_str: &str, + col_name: &str, + value: &OwnedValue, + depth: usize, + ) -> bool { + if depth >= Self::MAX_CHECK_EXPR_DEPTH { + return false; + } + let trimmed = expr_str.trim(); let bytes = trimmed.as_bytes(); - if let Some((left, right)) = Self::split_on_logical_op_bytes(bytes, trimmed, b" or ", b" OR ") { - return Self::eval_check_expr_recursive(left, col_name, value) - || Self::eval_check_expr_recursive(right, col_name, value); + if let Some((left, right)) = Self::split_on_logical_op_case_insensitive(bytes, trimmed, b" or ") { + return Self::eval_check_expr_with_depth(left, col_name, value, depth + 1) + || Self::eval_check_expr_with_depth(right, col_name, value, depth + 1); } - if let Some((left, right)) = Self::split_on_logical_op_bytes(bytes, trimmed, b" and ", b" AND ") { - return Self::eval_check_expr_recursive(left, col_name, value) - && Self::eval_check_expr_recursive(right, col_name, value); + if let Some((left, right)) = Self::split_on_logical_op_case_insensitive(bytes, trimmed, b" and ") { + return Self::eval_check_expr_with_depth(left, col_name, value, depth + 1) + && Self::eval_check_expr_with_depth(right, col_name, value, depth + 1); } let stripped = Self::strip_outer_parens(trimmed); if stripped != trimmed { - return Self::eval_check_expr_recursive(stripped, col_name, value); + return Self::eval_check_expr_with_depth(stripped, col_name, value, depth + 1); } Self::eval_simple_comparison(trimmed, col_name, value) } - fn split_on_logical_op_bytes<'a>( + fn split_on_logical_op_case_insensitive<'a>( bytes: &[u8], original: &'a str, - op_lower: &[u8], - op_upper: &[u8], + op: &[u8], ) -> Option<(&'a str, &'a str)> { let mut depth = 0; let mut i = 0; - while i + op_lower.len() <= bytes.len() { + while i + op.len() <= bytes.len() { let c = bytes[i]; if c 
== b'(' { depth += 1; } else if c == b')' { depth -= 1; - } else if depth == 0 { - let slice = &bytes[i..]; - if slice.starts_with(op_lower) || slice.starts_with(op_upper) { + } else if depth == 0 && bytes[i..].len() >= op.len() + && bytes[i..i + op.len()].eq_ignore_ascii_case(op) { let left = original[..i].trim(); - let right = original[i + op_lower.len()..].trim(); + let right = original[i + op.len()..].trim(); if !left.is_empty() && !right.is_empty() { return Some((left, right)); } } - } i += 1; } None @@ -4476,23 +4507,23 @@ impl Database { if let Some(op_idx) = expr_str.find(">=") { if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 2..]) { - return Self::compare_value_with_threshold(value, threshold, |v, t| v >= t); + return Self::compare_value_with_threshold(value, threshold, CheckCompareOp::Ge); } } else if let Some(op_idx) = expr_str.find("<=") { if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 2..]) { - return Self::compare_value_with_threshold(value, threshold, |v, t| v <= t); + return Self::compare_value_with_threshold(value, threshold, CheckCompareOp::Le); } } else if let Some(op_idx) = expr_str.find('>') { if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 1..]) { - return Self::compare_value_with_threshold(value, threshold, |v, t| v > t); + return Self::compare_value_with_threshold(value, threshold, CheckCompareOp::Gt); } } else if let Some(op_idx) = expr_str.find('<') { if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 1..]) { - return Self::compare_value_with_threshold(value, threshold, |v, t| v < t); + return Self::compare_value_with_threshold(value, threshold, CheckCompareOp::Lt); } } - true + false } fn contains_ignore_ascii_case(haystack: &str, needle: &str) -> bool { @@ -4535,14 +4566,11 @@ impl Database { .and_then(|s| s.parse::().ok()) } - fn compare_value_with_threshold(value: &OwnedValue, threshold: f64, cmp: F) -> bool - where - F: Fn(f64, f64) 
-> bool, - { + fn compare_value_with_threshold(value: &OwnedValue, threshold: f64, op: CheckCompareOp) -> bool { match value { - OwnedValue::Int(v) => cmp(*v as f64, threshold), - OwnedValue::Float(v) => cmp(*v, threshold), - _ => true, + OwnedValue::Int(v) => op.compare(*v as f64, threshold), + OwnedValue::Float(v) => op.compare(*v, threshold), + _ => false, } } From 0c11083fc29703f29f12679449aac129b7bb076a Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 14:17:03 +0800 Subject: [PATCH 4/5] fix(constraint): address third code review issues - Fix unbalanced parenthesis handling: return None when depth goes negative - Fix integer precision loss: compare as i64 when threshold is a whole number - Remove redundant bytes parameter from split_on_logical_op_case_insensitive - Remove double trim in strip_outer_parens (caller already trims) - Remove redundant bounds check in split function Co-Authored-By: Claude Opus 4.5 --- src/database/database.rs | 53 ++++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/src/database/database.rs b/src/database/database.rs index 838299d..2e9e813 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -4425,14 +4425,13 @@ impl Database { } let trimmed = expr_str.trim(); - let bytes = trimmed.as_bytes(); - if let Some((left, right)) = Self::split_on_logical_op_case_insensitive(bytes, trimmed, b" or ") { + if let Some((left, right)) = Self::split_on_logical_op_case_insensitive(trimmed, b" or ") { return Self::eval_check_expr_with_depth(left, col_name, value, depth + 1) || Self::eval_check_expr_with_depth(right, col_name, value, depth + 1); } - if let Some((left, right)) = Self::split_on_logical_op_case_insensitive(bytes, trimmed, b" and ") { + if let Some((left, right)) = Self::split_on_logical_op_case_insensitive(trimmed, b" and ") { return Self::eval_check_expr_with_depth(left, col_name, value, depth + 1) && Self::eval_check_expr_with_depth(right, 
col_name, value, depth + 1); } @@ -4446,11 +4445,11 @@ impl Database { } fn split_on_logical_op_case_insensitive<'a>( - bytes: &[u8], original: &'a str, op: &[u8], ) -> Option<(&'a str, &'a str)> { - let mut depth = 0; + let bytes = original.as_bytes(); + let mut depth: i32 = 0; let mut i = 0; while i + op.len() <= bytes.len() { @@ -4458,27 +4457,28 @@ impl Database { if c == b'(' { depth += 1; } else if c == b')' { + if depth == 0 { + return None; + } depth -= 1; - } else if depth == 0 && bytes[i..].len() >= op.len() - && bytes[i..i + op.len()].eq_ignore_ascii_case(op) { - let left = original[..i].trim(); - let right = original[i + op.len()..].trim(); - if !left.is_empty() && !right.is_empty() { - return Some((left, right)); - } + } else if depth == 0 && bytes[i..i + op.len()].eq_ignore_ascii_case(op) { + let left = original[..i].trim(); + let right = original[i + op.len()..].trim(); + if !left.is_empty() && !right.is_empty() { + return Some((left, right)); } + } i += 1; } None } fn strip_outer_parens(s: &str) -> &str { - let trimmed = s.trim(); - if !trimmed.starts_with('(') || !trimmed.ends_with(')') { - return trimmed; + if !s.starts_with('(') || !s.ends_with(')') { + return s; } - let inner = &trimmed[1..trimmed.len() - 1]; + let inner = &s[1..s.len() - 1]; let mut depth = 0; for c in inner.chars() { match c { @@ -4486,7 +4486,7 @@ impl Database { ')' => { depth -= 1; if depth < 0 { - return trimmed; + return s; } } _ => {} @@ -4496,7 +4496,7 @@ impl Database { if depth == 0 { inner } else { - trimmed + s } } @@ -4568,7 +4568,22 @@ impl Database { fn compare_value_with_threshold(value: &OwnedValue, threshold: f64, op: CheckCompareOp) -> bool { match value { - OwnedValue::Int(v) => op.compare(*v as f64, threshold), + OwnedValue::Int(v) => { + if threshold.fract() == 0.0 + && threshold >= i64::MIN as f64 + && threshold <= i64::MAX as f64 + { + let threshold_i64 = threshold as i64; + match op { + CheckCompareOp::Ge => *v >= threshold_i64, + CheckCompareOp::Le => 
*v <= threshold_i64, + CheckCompareOp::Gt => *v > threshold_i64, + CheckCompareOp::Lt => *v < threshold_i64, + } + } else { + op.compare(*v as f64, threshold) + } + } OwnedValue::Float(v) => op.compare(*v, threshold), _ => false, } From 28d41e04bcd3993eb6cff9059002bcefb8a27cc6 Mon Sep 17 00:00:00 2001 From: Mohammad Julfikar Date: Tue, 20 Jan 2026 15:13:36 +0800 Subject: [PATCH 5/5] fix(constraint): address fourth code review round - Return Result<bool> instead of bool when max depth exceeded - Use usize with saturating arithmetic for depth counter - Implement single-pass operator scanning instead of 4 sequential scans - Strengthen numeric validation with state machine - Add doc comments to CheckCompareOp enum and helpers - Support UnaryOp expressions (negative numbers) in expr_to_string - Add comprehensive edge case tests: - Negative numbers: CHECK (temp >= -273.15) - Float comparisons with range - Deeply nested parentheses - Mixed case operators (AnD) - Max depth limit with nested AND - Multiple OR conditions Co-Authored-By: Claude Opus 4.5 --- src/database/convert.rs | 12 +- src/database/database.rs | 119 +++++++---- src/database/dml/insert.rs | 2 +- src/database/dml/update.rs | 6 +- src/database/mod.rs | 41 ++++ tests/prepared_statement_constraints.rs | 261 ++++++++++++++++++++++++ 6 files changed, 398 insertions(+), 43 deletions(-) diff --git a/src/database/convert.rs b/src/database/convert.rs index 7ae7fbf..6307908 100644 --- a/src/database/convert.rs +++ b/src/database/convert.rs @@ -204,7 +204,7 @@ impl Database { } pub(crate) fn expr_to_string(expr: &crate::sql::ast::Expr<'_>) -> Option { - use crate::sql::ast::{BinaryOperator, Expr}; + use crate::sql::ast::{BinaryOperator, Expr, UnaryOperator}; match expr { Expr::BinaryOp { left, op, right } => { @@ -228,6 +228,16 @@ impl Database { }; Some(format!("{} {} {}", left_str, op_str, right_str)) } + Expr::UnaryOp { op, expr: inner } => { + let inner_str = Self::expr_to_string(inner)?; + let op_str = match op { + 
UnaryOperator::Minus => "-", + UnaryOperator::Plus => "+", + UnaryOperator::Not => "NOT ", + UnaryOperator::BitwiseNot => "~", + }; + Some(format!("{}{}", op_str, inner_str)) + } Expr::Column(col_ref) => Some(col_ref.column.to_string()), Expr::Literal(lit) => match lit { crate::sql::ast::Literal::Integer(n) => Some(n.to_string()), diff --git a/src/database/database.rs b/src/database/database.rs index 2e9e813..ba4598f 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -184,15 +184,21 @@ impl SharedDatabase { } } +/// Comparison operators for CHECK constraint evaluation. #[derive(Clone, Copy)] enum CheckCompareOp { + /// Greater than or equal (>=) Ge, + /// Less than or equal (<=) Le, + /// Greater than (>) Gt, + /// Less than (<) Lt, } impl CheckCompareOp { + /// Compares two f64 values using this operator. fn compare(self, lhs: f64, rhs: f64) -> bool { match self { CheckCompareOp::Ge => lhs >= rhs, @@ -4396,13 +4402,13 @@ impl Database { expr_str: &str, col_name: &str, col_value: Option<&OwnedValue>, - ) -> bool { + ) -> Result { let Some(value) = col_value else { - return true; + return Ok(true); }; if value.is_null() { - return true; + return Ok(true); } Self::eval_check_expr_recursive(expr_str, col_name, value) @@ -4410,7 +4416,11 @@ impl Database { const MAX_CHECK_EXPR_DEPTH: usize = 32; - fn eval_check_expr_recursive(expr_str: &str, col_name: &str, value: &OwnedValue) -> bool { + fn eval_check_expr_recursive( + expr_str: &str, + col_name: &str, + value: &OwnedValue, + ) -> Result { Self::eval_check_expr_with_depth(expr_str, col_name, value, 0) } @@ -4419,21 +4429,27 @@ impl Database { col_name: &str, value: &OwnedValue, depth: usize, - ) -> bool { + ) -> Result { if depth >= Self::MAX_CHECK_EXPR_DEPTH { - return false; + bail!( + "CHECK constraint expression exceeds maximum nesting depth of {} for column '{}'", + Self::MAX_CHECK_EXPR_DEPTH, + col_name + ); } let trimmed = expr_str.trim(); if let Some((left, right)) = 
Self::split_on_logical_op_case_insensitive(trimmed, b" or ") { - return Self::eval_check_expr_with_depth(left, col_name, value, depth + 1) - || Self::eval_check_expr_with_depth(right, col_name, value, depth + 1); + let left_result = Self::eval_check_expr_with_depth(left, col_name, value, depth + 1)?; + let right_result = Self::eval_check_expr_with_depth(right, col_name, value, depth + 1)?; + return Ok(left_result || right_result); } if let Some((left, right)) = Self::split_on_logical_op_case_insensitive(trimmed, b" and ") { - return Self::eval_check_expr_with_depth(left, col_name, value, depth + 1) - && Self::eval_check_expr_with_depth(right, col_name, value, depth + 1); + let left_result = Self::eval_check_expr_with_depth(left, col_name, value, depth + 1)?; + let right_result = Self::eval_check_expr_with_depth(right, col_name, value, depth + 1)?; + return Ok(left_result && right_result); } let stripped = Self::strip_outer_parens(trimmed); @@ -4441,7 +4457,7 @@ impl Database { return Self::eval_check_expr_with_depth(stripped, col_name, value, depth + 1); } - Self::eval_simple_comparison(trimmed, col_name, value) + Ok(Self::eval_simple_comparison(trimmed, col_name, value)) } fn split_on_logical_op_case_insensitive<'a>( @@ -4449,18 +4465,18 @@ impl Database { op: &[u8], ) -> Option<(&'a str, &'a str)> { let bytes = original.as_bytes(); - let mut depth: i32 = 0; + let mut depth: usize = 0; let mut i = 0; while i + op.len() <= bytes.len() { let c = bytes[i]; if c == b'(' { - depth += 1; + depth = depth.saturating_add(1); } else if c == b')' { if depth == 0 { return None; } - depth -= 1; + depth = depth.saturating_sub(1); } else if depth == 0 && bytes[i..i + op.len()].eq_ignore_ascii_case(op) { let left = original[..i].trim(); let right = original[i + op.len()..].trim(); @@ -4505,27 +4521,41 @@ impl Database { return true; } - if let Some(op_idx) = expr_str.find(">=") { - if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 2..]) { - return 
Self::compare_value_with_threshold(value, threshold, CheckCompareOp::Ge); - } - } else if let Some(op_idx) = expr_str.find("<=") { - if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 2..]) { - return Self::compare_value_with_threshold(value, threshold, CheckCompareOp::Le); - } - } else if let Some(op_idx) = expr_str.find('>') { - if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 1..]) { - return Self::compare_value_with_threshold(value, threshold, CheckCompareOp::Gt); - } - } else if let Some(op_idx) = expr_str.find('<') { - if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + 1..]) { - return Self::compare_value_with_threshold(value, threshold, CheckCompareOp::Lt); + if let Some((op, op_len, op_idx)) = Self::find_comparison_operator(expr_str) { + if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + op_len..]) { + return Self::compare_value_with_threshold(value, threshold, op); } } false } + fn find_comparison_operator(s: &str) -> Option<(CheckCompareOp, usize, usize)> { + let bytes = s.as_bytes(); + let len = bytes.len(); + let mut i = 0; + + while i < len { + match bytes[i] { + b'>' => { + if i + 1 < len && bytes[i + 1] == b'=' { + return Some((CheckCompareOp::Ge, 2, i)); + } + return Some((CheckCompareOp::Gt, 1, i)); + } + b'<' => { + if i + 1 < len && bytes[i + 1] == b'=' { + return Some((CheckCompareOp::Le, 2, i)); + } + return Some((CheckCompareOp::Lt, 1, i)); + } + _ => {} + } + i += 1; + } + None + } + fn contains_ignore_ascii_case(haystack: &str, needle: &str) -> bool { let haystack_bytes = haystack.as_bytes(); let needle_bytes = needle.as_bytes(); @@ -4548,22 +4578,35 @@ impl Database { return None; } - let mut end = 0; - for &b in bytes { - if b.is_ascii_digit() || b == b'.' 
|| b == b'-' || b == b'+' { - end += 1; - } else { - break; + let mut i = 0; + let mut has_digit = false; + let mut has_dot = false; + + if i < bytes.len() && (bytes[i] == b'-' || bytes[i] == b'+') { + i += 1; + } + + while i < bytes.len() { + match bytes[i] { + b'0'..=b'9' => { + has_digit = true; + i += 1; + } + b'.' if !has_dot => { + has_dot = true; + i += 1; + } + _ => break, } } - if end == 0 { + if !has_digit || i == 0 { return None; } - std::str::from_utf8(&bytes[..end]) + std::str::from_utf8(&bytes[..i]) .ok() - .and_then(|s| s.parse::().ok()) + .and_then(|num_str| num_str.parse::().ok()) } fn compare_value_with_threshold(value: &OwnedValue, threshold: f64, op: CheckCompareOp) -> bool { diff --git a/src/database/dml/insert.rs b/src/database/dml/insert.rs index 07c3ff7..40d2dbb 100644 --- a/src/database/dml/insert.rs +++ b/src/database/dml/insert.rs @@ -567,7 +567,7 @@ impl Database { for constraint in col.constraints() { if let Constraint::Check(expr_str) = constraint { let col_value = values.get(col_idx); - if !Database::evaluate_check_expression(expr_str, col.name(), col_value) { + if !Database::evaluate_check_expression(expr_str, col.name(), col_value)? { bail!( "CHECK constraint violated on column '{}' in table '{}': {}", col.name(), diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index d656e9c..0c4839f 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -490,7 +490,7 @@ impl Database { expr_str, col.name(), col_value, - ) { + )? { bail!( "CHECK constraint violated on column '{}' in table '{}': {}", col.name(), @@ -649,7 +649,7 @@ impl Database { for constraint in col.constraints() { if let Constraint::Check(expr_str) = constraint { let col_value = row_values.get(col_idx); - if !Self::evaluate_check_expression(expr_str, col.name(), col_value) + if !Self::evaluate_check_expression(expr_str, col.name(), col_value)? 
{ bail!( "CHECK constraint violated on column '{}' in table '{}': {}", @@ -1369,7 +1369,7 @@ impl Database { for constraint in col.constraints() { if let Constraint::Check(expr_str) = constraint { let col_value = row_values.get(col_idx); - if !Self::evaluate_check_expression(expr_str, col.name(), col_value) + if !Self::evaluate_check_expression(expr_str, col.name(), col_value)? { bail!( "CHECK constraint violated on column '{}' in table '{}': {}", diff --git a/src/database/mod.rs b/src/database/mod.rs index c2cac1a..2c2fd5f 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1408,6 +1408,47 @@ mod tests { ); } + #[test] + fn test_evaluate_check_expression_with_negative_threshold() { + use crate::types::OwnedValue; + + let result = Database::evaluate_check_expression( + "temp >= -273.15", + "temp", + Some(&OwnedValue::Float(-100.0)), + ) + .unwrap(); + assert!(result, "-100.0 >= -273.15 should be true"); + + let result = Database::evaluate_check_expression( + "temp >= -273.15", + "temp", + Some(&OwnedValue::Float(-300.0)), + ) + .unwrap(); + assert!(!result, "-300.0 >= -273.15 should be false"); + } + + #[test] + fn test_check_constraint_with_negative_threshold_rejects_value() { + let dir = tempdir().unwrap(); + let db_path = dir.path().join("test_db"); + + let db = Database::create(&db_path).unwrap(); + + db.execute("CREATE TABLE temps (id INT, temp REAL CHECK(temp >= -273.15))") + .unwrap(); + + db.execute("INSERT INTO temps VALUES (1, -100.0)") + .expect("Insert with -100.0 (above -273.15) should succeed"); + + let result = db.execute("INSERT INTO temps VALUES (2, -300.0)"); + assert!( + result.is_err(), + "CHECK constraint should reject -300.0 (below -273.15)" + ); + } + #[test] fn test_foreign_key_rejects_missing_reference_on_insert() { let dir = tempdir().unwrap(); diff --git a/tests/prepared_statement_constraints.rs b/tests/prepared_statement_constraints.rs index 3e8047d..4ff1bdc 100644 --- a/tests/prepared_statement_constraints.rs +++ 
b/tests/prepared_statement_constraints.rs @@ -375,6 +375,267 @@ mod check_constraint_tests { result.unwrap_err() ); } + + #[test] + fn check_constraint_handles_negative_numbers() { + let (_dir, db) = create_test_db(); + + db.execute( + "CREATE TABLE temperatures ( + id INTEGER PRIMARY KEY, + temp REAL CHECK(temp >= -273.15) + )", + ) + .expect("Failed to create table"); + + let stmt = db + .prepare("INSERT INTO temperatures VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_valid = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Float(-100.0)) + .execute(&db); + + assert!( + result_valid.is_ok(), + "INSERT with -100.0 (above -273.15) should succeed" + ); + + let stmt2 = db + .prepare("INSERT INTO temperatures VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_invalid = stmt2 + .bind(OwnedValue::Int(2)) + .bind(OwnedValue::Float(-300.0)) + .execute(&db); + + assert!( + result_invalid.is_err(), + "INSERT with -300.0 (below -273.15) should fail" + ); + } + + #[test] + fn check_constraint_handles_float_range() { + let (_dir, db) = create_test_db(); + + db.execute( + "CREATE TABLE ratios ( + id INTEGER PRIMARY KEY, + ratio REAL CHECK(ratio >= 0.0 AND ratio <= 1.0) + )", + ) + .expect("Failed to create table"); + + let stmt = db + .prepare("INSERT INTO ratios VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_valid = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Float(0.5)) + .execute(&db); + + assert!( + result_valid.is_ok(), + "INSERT with 0.5 should succeed" + ); + + let stmt2 = db + .prepare("INSERT INTO ratios VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_invalid = stmt2 + .bind(OwnedValue::Int(2)) + .bind(OwnedValue::Float(1.5)) + .execute(&db); + + assert!( + result_invalid.is_err(), + "INSERT with 1.5 (above 1.0) should fail" + ); + } + + #[test] + fn check_constraint_handles_deeply_nested_parentheses() { + let (_dir, db) = create_test_db(); + + db.execute( + 
"CREATE TABLE nested_check ( + id INTEGER PRIMARY KEY, + age INTEGER CHECK((((age >= 0)))) + )", + ) + .expect("Failed to create table"); + + let stmt = db + .prepare("INSERT INTO nested_check VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_valid = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Int(25)) + .execute(&db); + + assert!( + result_valid.is_ok(), + "INSERT with valid value through nested parens should succeed" + ); + + let stmt2 = db + .prepare("INSERT INTO nested_check VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_invalid = stmt2 + .bind(OwnedValue::Int(2)) + .bind(OwnedValue::Int(-5)) + .execute(&db); + + assert!( + result_invalid.is_err(), + "INSERT with invalid value through nested parens should fail" + ); + } + + #[test] + fn check_constraint_handles_mixed_case_operators() { + let (_dir, db) = create_test_db(); + + db.execute( + "CREATE TABLE mixed_case ( + id INTEGER PRIMARY KEY, + age INTEGER CHECK(age >= 0 AnD age <= 150) + )", + ) + .expect("Failed to create table"); + + let stmt = db + .prepare("INSERT INTO mixed_case VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_valid = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Int(75)) + .execute(&db); + + assert!( + result_valid.is_ok(), + "INSERT should succeed with mixed case AND operator" + ); + + let stmt2 = db + .prepare("INSERT INTO mixed_case VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_high = stmt2 + .bind(OwnedValue::Int(2)) + .bind(OwnedValue::Int(200)) + .execute(&db); + + assert!( + result_high.is_err(), + "INSERT with value above 150 should fail" + ); + } + + #[test] + fn check_constraint_handles_multiple_or_conditions() { + let (_dir, db) = create_test_db(); + + db.execute( + "CREATE TABLE status_codes ( + id INTEGER PRIMARY KEY, + code INTEGER CHECK(code >= 200 AND code <= 299 OR code >= 400 AND code <= 499) + )", + ) + .expect("Failed to create table"); + + let stmt = db 
+ .prepare("INSERT INTO status_codes VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_200 = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Int(200)) + .execute(&db); + + assert!( + result_200.is_ok(), + "INSERT with 200 should succeed" + ); + + let stmt2 = db + .prepare("INSERT INTO status_codes VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_404 = stmt2 + .bind(OwnedValue::Int(2)) + .bind(OwnedValue::Int(404)) + .execute(&db); + + assert!( + result_404.is_ok(), + "INSERT with 404 should succeed" + ); + + let stmt3 = db + .prepare("INSERT INTO status_codes VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_300 = stmt3 + .bind(OwnedValue::Int(3)) + .bind(OwnedValue::Int(300)) + .execute(&db); + + assert!( + result_300.is_err(), + "INSERT with 300 should fail (not in valid ranges)" + ); + } + + #[test] + fn check_constraint_max_depth_returns_error() { + let (_dir, db) = create_test_db(); + + let mut nested_expr = "age >= 0".to_string(); + for i in 1..35 { + nested_expr = format!("({} AND age <= {})", nested_expr, 100 + i); + } + + let create_sql = format!( + "CREATE TABLE deep_nesting ( + id INTEGER PRIMARY KEY, + age INTEGER CHECK({}) + )", + nested_expr + ); + + db.execute(&create_sql) + .expect("Failed to create table"); + + let stmt = db + .prepare("INSERT INTO deep_nesting VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Int(25)) + .execute(&db); + + assert!( + result.is_err(), + "INSERT with deeply nested AND expression (>32 levels) should return error" + ); + + let err_msg = result.unwrap_err().to_string().to_lowercase(); + assert!( + err_msg.contains("depth") || err_msg.contains("nesting") || err_msg.contains("exceed"), + "Error should mention depth/nesting exceeded, got: {}", + err_msg + ); + } } mod not_null_tests {