diff --git a/src/database/convert.rs b/src/database/convert.rs index 7ae7fbf..6307908 100644 --- a/src/database/convert.rs +++ b/src/database/convert.rs @@ -204,7 +204,7 @@ impl Database { } pub(crate) fn expr_to_string(expr: &crate::sql::ast::Expr<'_>) -> Option { - use crate::sql::ast::{BinaryOperator, Expr}; + use crate::sql::ast::{BinaryOperator, Expr, UnaryOperator}; match expr { Expr::BinaryOp { left, op, right } => { @@ -228,6 +228,16 @@ impl Database { }; Some(format!("{} {} {}", left_str, op_str, right_str)) } + Expr::UnaryOp { op, expr: inner } => { + let inner_str = Self::expr_to_string(inner)?; + let op_str = match op { + UnaryOperator::Minus => "-", + UnaryOperator::Plus => "+", + UnaryOperator::Not => "NOT ", + UnaryOperator::BitwiseNot => "~", + }; + Some(format!("{}{}", op_str, inner_str)) + } Expr::Column(col_ref) => Some(col_ref.column.to_string()), Expr::Literal(lit) => match lit { crate::sql::ast::Literal::Integer(n) => Some(n.to_string()), diff --git a/src/database/database.rs b/src/database/database.rs index abf84cb..ba4598f 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -184,6 +184,31 @@ impl SharedDatabase { } } +/// Comparison operators for CHECK constraint evaluation. +#[derive(Clone, Copy)] +enum CheckCompareOp { + /// Greater than or equal (>=) + Ge, + /// Less than or equal (<=) + Le, + /// Greater than (>) + Gt, + /// Less than (<) + Lt, +} + +impl CheckCompareOp { + /// Compares two f64 values using this operator. + fn compare(self, lhs: f64, rhs: f64) -> bool { + match self { + CheckCompareOp::Ge => lhs >= rhs, + CheckCompareOp::Le => lhs <= rhs, + CheckCompareOp::Gt => lhs > rhs, + CheckCompareOp::Lt => lhs < rhs, + } + } +} + impl Database { pub fn open>(path: P) -> Result { Self::open_with_recovery(path).map(|(db, _)| db) @@ -4377,71 +4402,234 @@ impl Database { expr_str: &str, col_name: &str, col_value: Option<&OwnedValue>, - ) -> bool { + ) -> Result { let Some(value) = col_value else { - return true; + return Ok(true); }; if value.is_null() { - return true; + return Ok(true); } - let expr_lower = expr_str.to_lowercase(); - let col_lower = col_name.to_lowercase(); + Self::eval_check_expr_recursive(expr_str, col_name, value) + } - if expr_lower.contains(&col_lower) { - if let Some(op_idx) = expr_str.find(">=") { - let right_part = expr_str[op_idx + 2..].trim(); - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Int(v) = value { - return *v >= threshold; - } + const MAX_CHECK_EXPR_DEPTH: usize = 32; + + fn eval_check_expr_recursive( + expr_str: &str, + col_name: &str, + value: &OwnedValue, + ) -> Result { + Self::eval_check_expr_with_depth(expr_str, col_name, value, 0) + } + + fn eval_check_expr_with_depth( + expr_str: &str, + col_name: &str, + value: &OwnedValue, + depth: usize, + ) -> Result { + if depth >= Self::MAX_CHECK_EXPR_DEPTH { + bail!( + "CHECK constraint expression exceeds maximum nesting depth of {} for column '{}'", + Self::MAX_CHECK_EXPR_DEPTH, + col_name + ); + } + + let trimmed = expr_str.trim(); + + if let Some((left, right)) = Self::split_on_logical_op_case_insensitive(trimmed, b" or ") { + let left_result = Self::eval_check_expr_with_depth(left, col_name, value, depth + 1)?; + let right_result = Self::eval_check_expr_with_depth(right, col_name, value, depth + 1)?; + return Ok(left_result || right_result); + } + + if let Some((left, right)) = Self::split_on_logical_op_case_insensitive(trimmed, b" and ") { + let left_result = Self::eval_check_expr_with_depth(left, col_name, value, depth + 1)?; + let right_result = Self::eval_check_expr_with_depth(right, col_name, value, depth + 1)?; + return Ok(left_result && right_result); + } + + let stripped = Self::strip_outer_parens(trimmed); + if stripped != trimmed { + return Self::eval_check_expr_with_depth(stripped, col_name, value, depth + 1); + } + + Ok(Self::eval_simple_comparison(trimmed, col_name, value)) + } + + fn split_on_logical_op_case_insensitive<'a>( + original: &'a str, + op: &[u8], + ) -> Option<(&'a str, &'a str)> { + let bytes = original.as_bytes(); + let mut depth: usize = 0; + let mut i = 0; + + while i + op.len() <= bytes.len() { + let c = bytes[i]; + if c == b'(' { + depth = depth.saturating_add(1); + } else if c == b')' { + if depth == 0 { + return None; } - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Float(v) = value { - return *v >= threshold; - } + depth = depth.saturating_sub(1); + } else if depth == 0 && bytes[i..i + op.len()].eq_ignore_ascii_case(op) { + let left = original[..i].trim(); + let right = original[i + op.len()..].trim(); + if !left.is_empty() && !right.is_empty() { + return Some((left, right)); } - } else if let Some(op_idx) = expr_str.find("<=") { - let right_part = expr_str[op_idx + 2..].trim(); - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Int(v) = value { - return *v <= threshold; + } + i += 1; + } + None + } + + fn strip_outer_parens(s: &str) -> &str { + if !s.starts_with('(') || !s.ends_with(')') { + return s; + } + + let inner = &s[1..s.len() - 1]; + let mut depth = 0; + for c in inner.chars() { + match c { + '(' => depth += 1, + ')' => { + depth -= 1; + if depth < 0 { + return s; } } - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Float(v) = value { - return *v <= threshold; + _ => {} + } + } + + if depth == 0 { + inner + } else { + s + } + } + + fn eval_simple_comparison(expr_str: &str, col_name: &str, value: &OwnedValue) -> bool { + if !Self::contains_ignore_ascii_case(expr_str, col_name) { + return true; + } + + if let Some((op, op_len, op_idx)) = Self::find_comparison_operator(expr_str) { + if let Some(threshold) = Self::extract_numeric_operand(&expr_str[op_idx + op_len..]) { + return Self::compare_value_with_threshold(value, threshold, op); + } + } + + false + } + + fn find_comparison_operator(s: &str) -> Option<(CheckCompareOp, usize, usize)> { + let bytes = s.as_bytes(); + let len = bytes.len(); + let mut i = 0; + + while i < len { + match bytes[i] { + b'>' => { + if i + 1 < len && bytes[i + 1] == b'=' { + return Some((CheckCompareOp::Ge, 2, i)); } + return Some((CheckCompareOp::Gt, 1, i)); } - } else if let Some(op_idx) = expr_str.find('>') { - let right_part = expr_str[op_idx + 1..].trim(); - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Int(v) = value { - return *v > threshold; + b'<' => { + if i + 1 < len && bytes[i + 1] == b'=' { + return Some((CheckCompareOp::Le, 2, i)); } + return Some((CheckCompareOp::Lt, 1, i)); } - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Float(v) = value { - return *v > threshold; - } + _ => {} + } + i += 1; + } + None + } + + fn contains_ignore_ascii_case(haystack: &str, needle: &str) -> bool { + let haystack_bytes = haystack.as_bytes(); + let needle_bytes = needle.as_bytes(); + + if needle_bytes.is_empty() { + return true; + } + if needle_bytes.len() > haystack_bytes.len() { + return false; + } + + haystack_bytes + .windows(needle_bytes.len()) + .any(|window| window.eq_ignore_ascii_case(needle_bytes)) + } + + fn extract_numeric_operand(s: &str) -> Option { + let bytes = s.trim_start().as_bytes(); + if bytes.is_empty() { + return None; + } + + let mut i = 0; + let mut has_digit = false; + let mut has_dot = false; + + if i < bytes.len() && (bytes[i] == b'-' || bytes[i] == b'+') { + i += 1; + } + + while i < bytes.len() { + match bytes[i] { + b'0'..=b'9' => { + has_digit = true; + i += 1; } - } else if let Some(op_idx) = expr_str.find('<') { - let right_part = expr_str[op_idx + 1..].trim(); - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Int(v) = value { - return *v < threshold; - } + b'.' if !has_dot => { + has_dot = true; + i += 1; } - if let Ok(threshold) = right_part.parse::() { - if let OwnedValue::Float(v) = value { - return *v < threshold; + _ => break, + } + } + + if !has_digit || i == 0 { + return None; + } + + std::str::from_utf8(&bytes[..i]) + .ok() + .and_then(|num_str| num_str.parse::().ok()) + } + + fn compare_value_with_threshold(value: &OwnedValue, threshold: f64, op: CheckCompareOp) -> bool { + match value { + OwnedValue::Int(v) => { + if threshold.fract() == 0.0 + && threshold >= i64::MIN as f64 + && threshold <= i64::MAX as f64 + { + let threshold_i64 = threshold as i64; + match op { + CheckCompareOp::Ge => *v >= threshold_i64, + CheckCompareOp::Le => *v <= threshold_i64, + CheckCompareOp::Gt => *v > threshold_i64, + CheckCompareOp::Lt => *v < threshold_i64, } + } else { + op.compare(*v as f64, threshold) } } + OwnedValue::Float(v) => op.compare(*v, threshold), + _ => false, } - - true } pub(crate) fn get_or_create_hnsw_index( diff --git a/src/database/dml/insert.rs b/src/database/dml/insert.rs index 07c3ff7..40d2dbb 100644 --- a/src/database/dml/insert.rs +++ b/src/database/dml/insert.rs @@ -567,7 +567,7 @@ impl Database { for constraint in col.constraints() { if let Constraint::Check(expr_str) = constraint { let col_value = values.get(col_idx); - if !Database::evaluate_check_expression(expr_str, col.name(), col_value) { + if !Database::evaluate_check_expression(expr_str, col.name(), col_value)? { bail!( "CHECK constraint violated on column '{}' in table '{}': {}", col.name(), diff --git a/src/database/dml/update.rs b/src/database/dml/update.rs index d656e9c..0c4839f 100644 --- a/src/database/dml/update.rs +++ b/src/database/dml/update.rs @@ -490,7 +490,7 @@ impl Database { expr_str, col.name(), col_value, - ) { + )? { bail!( "CHECK constraint violated on column '{}' in table '{}': {}", col.name(), @@ -649,7 +649,7 @@ impl Database { for constraint in col.constraints() { if let Constraint::Check(expr_str) = constraint { let col_value = row_values.get(col_idx); - if !Self::evaluate_check_expression(expr_str, col.name(), col_value) + if !Self::evaluate_check_expression(expr_str, col.name(), col_value)? { bail!( "CHECK constraint violated on column '{}' in table '{}': {}", @@ -1369,7 +1369,7 @@ impl Database { for constraint in col.constraints() { if let Constraint::Check(expr_str) = constraint { let col_value = row_values.get(col_idx); - if !Self::evaluate_check_expression(expr_str, col.name(), col_value) + if !Self::evaluate_check_expression(expr_str, col.name(), col_value)? { bail!( "CHECK constraint violated on column '{}' in table '{}': {}", diff --git a/src/database/mod.rs b/src/database/mod.rs index c2cac1a..2c2fd5f 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1408,6 +1408,47 @@ mod tests { ); } + #[test] + fn test_evaluate_check_expression_with_negative_threshold() { + use crate::types::OwnedValue; + + let result = Database::evaluate_check_expression( + "temp >= -273.15", + "temp", + Some(&OwnedValue::Float(-100.0)), + ) + .unwrap(); + assert!(result, "-100.0 >= -273.15 should be true"); + + let result = Database::evaluate_check_expression( + "temp >= -273.15", + "temp", + Some(&OwnedValue::Float(-300.0)), + ) + .unwrap(); + assert!(!result, "-300.0 >= -273.15 should be false"); + } + + #[test] + fn test_check_constraint_with_negative_threshold_rejects_value() { + let dir = tempdir().unwrap(); + let db_path = dir.path().join("test_db"); + + let db = Database::create(&db_path).unwrap(); + + db.execute("CREATE TABLE temps (id INT, temp REAL CHECK(temp >= -273.15))") + .unwrap(); + + db.execute("INSERT INTO temps VALUES (1, -100.0)") + .expect("Insert with -100.0 (above -273.15) should succeed"); + + let result = db.execute("INSERT INTO temps VALUES (2, -300.0)"); + assert!( + result.is_err(), + "CHECK constraint should reject -300.0 (below -273.15)" + ); + } + #[test] fn test_foreign_key_rejects_missing_reference_on_insert() { let dir = tempdir().unwrap(); diff --git a/tests/prepared_statement_constraints.rs b/tests/prepared_statement_constraints.rs index 3e8047d..4ff1bdc 100644 --- a/tests/prepared_statement_constraints.rs +++ b/tests/prepared_statement_constraints.rs @@ -375,6 +375,267 @@ mod check_constraint_tests { result.unwrap_err() ); } + + #[test] + fn check_constraint_handles_negative_numbers() { + let (_dir, db) = create_test_db(); + + db.execute( + "CREATE TABLE temperatures ( + id INTEGER PRIMARY KEY, + temp REAL CHECK(temp >= -273.15) + )", + ) + .expect("Failed to create table"); + + let stmt = db + .prepare("INSERT INTO temperatures VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_valid = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Float(-100.0)) + .execute(&db); + + assert!( + result_valid.is_ok(), + "INSERT with -100.0 (above -273.15) should succeed" + ); + + let stmt2 = db + .prepare("INSERT INTO temperatures VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_invalid = stmt2 + .bind(OwnedValue::Int(2)) + .bind(OwnedValue::Float(-300.0)) + .execute(&db); + + assert!( + result_invalid.is_err(), + "INSERT with -300.0 (below -273.15) should fail" + ); + } + + #[test] + fn check_constraint_handles_float_range() { + let (_dir, db) = create_test_db(); + + db.execute( + "CREATE TABLE ratios ( + id INTEGER PRIMARY KEY, + ratio REAL CHECK(ratio >= 0.0 AND ratio <= 1.0) + )", + ) + .expect("Failed to create table"); + + let stmt = db + .prepare("INSERT INTO ratios VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_valid = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Float(0.5)) + .execute(&db); + + assert!( + result_valid.is_ok(), + "INSERT with 0.5 should succeed" + ); + + let stmt2 = db + .prepare("INSERT INTO ratios VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_invalid = stmt2 + .bind(OwnedValue::Int(2)) + .bind(OwnedValue::Float(1.5)) + .execute(&db); + + assert!( + result_invalid.is_err(), + "INSERT with 1.5 (above 1.0) should fail" + ); + } + + #[test] + fn check_constraint_handles_deeply_nested_parentheses() { + let (_dir, db) = create_test_db(); + + db.execute( + "CREATE TABLE nested_check ( + id INTEGER PRIMARY KEY, + age INTEGER CHECK((((age >= 0)))) + )", + ) + .expect("Failed to create table"); + + let stmt = db + .prepare("INSERT INTO nested_check VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_valid = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Int(25)) + .execute(&db); + + assert!( + result_valid.is_ok(), + "INSERT with valid value through nested parens should succeed" + ); + + let stmt2 = db + .prepare("INSERT INTO nested_check VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_invalid = stmt2 + .bind(OwnedValue::Int(2)) + .bind(OwnedValue::Int(-5)) + .execute(&db); + + assert!( + result_invalid.is_err(), + "INSERT with invalid value through nested parens should fail" + ); + } + + #[test] + fn check_constraint_handles_mixed_case_operators() { + let (_dir, db) = create_test_db(); + + db.execute( + "CREATE TABLE mixed_case ( + id INTEGER PRIMARY KEY, + age INTEGER CHECK(age >= 0 AnD age <= 150) + )", + ) + .expect("Failed to create table"); + + let stmt = db + .prepare("INSERT INTO mixed_case VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_valid = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Int(75)) + .execute(&db); + + assert!( + result_valid.is_ok(), + "INSERT should succeed with mixed case AND operator" + ); + + let stmt2 = db + .prepare("INSERT INTO mixed_case VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_high = stmt2 + .bind(OwnedValue::Int(2)) + .bind(OwnedValue::Int(200)) + .execute(&db); + + assert!( + result_high.is_err(), + "INSERT with value above 150 should fail" + ); + } + + #[test] + fn check_constraint_handles_multiple_or_conditions() { + let (_dir, db) = create_test_db(); + + db.execute( + "CREATE TABLE status_codes ( + id INTEGER PRIMARY KEY, + code INTEGER CHECK(code >= 200 AND code <= 299 OR code >= 400 AND code <= 499) + )", + ) + .expect("Failed to create table"); + + let stmt = db + .prepare("INSERT INTO status_codes VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_200 = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Int(200)) + .execute(&db); + + assert!( + result_200.is_ok(), + "INSERT with 200 should succeed" + ); + + let stmt2 = db + .prepare("INSERT INTO status_codes VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_404 = stmt2 + .bind(OwnedValue::Int(2)) + .bind(OwnedValue::Int(404)) + .execute(&db); + + assert!( + result_404.is_ok(), + "INSERT with 404 should succeed" + ); + + let stmt3 = db + .prepare("INSERT INTO status_codes VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result_300 = stmt3 + .bind(OwnedValue::Int(3)) + .bind(OwnedValue::Int(300)) + .execute(&db); + + assert!( + result_300.is_err(), + "INSERT with 300 should fail (not in valid ranges)" + ); + } + + #[test] + fn check_constraint_max_depth_returns_error() { + let (_dir, db) = create_test_db(); + + let mut nested_expr = "age >= 0".to_string(); + for i in 1..35 { + nested_expr = format!("({} AND age <= {})", nested_expr, 100 + i); + } + + let create_sql = format!( + "CREATE TABLE deep_nesting ( + id INTEGER PRIMARY KEY, + age INTEGER CHECK({}) + )", + nested_expr + ); + + db.execute(&create_sql) + .expect("Failed to create table"); + + let stmt = db + .prepare("INSERT INTO deep_nesting VALUES (?, ?)") + .expect("Failed to prepare statement"); + + let result = stmt + .bind(OwnedValue::Int(1)) + .bind(OwnedValue::Int(25)) + .execute(&db); + + assert!( + result.is_err(), + "INSERT with deeply nested AND expression (>32 levels) should return error" + ); + + let err_msg = result.unwrap_err().to_string().to_lowercase(); + assert!( + err_msg.contains("depth") || err_msg.contains("nesting") || err_msg.contains("exceed"), + "Error should mention depth/nesting exceeded, got: {}", + err_msg + ); + } } mod not_null_tests {