diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 82dcd8121..684e61132 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: target/ key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - uses: dtolnay/rust-toolchain@stable - - run: cargo test --release --no-fail-fast + - run: cargo test --release --no-fail-fast --features pumpkin-core/check-propagations wasm-test: name: Test Suite for pumpkin-core in WebAssembly diff --git a/Cargo.lock b/Cargo.lock index d667a5e7b..2e86b272a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -210,15 +210,6 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" -[[package]] -name = "convert_case" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec182b0ca2f35d8fc196cf3404988fd8b8c739a4d270ff118a398feb0cbec1ca" -dependencies = [ - "unicode-segmentation", -] - [[package]] name = "convert_case" version = "0.8.0" @@ -414,7 +405,7 @@ dependencies = [ name = "fzn-rs-derive" version = "0.1.0" dependencies = [ - "convert_case 0.8.0", + "convert_case", "fzn-rs", "proc-macro2", "quote", @@ -732,9 +723,19 @@ dependencies = [ "drcp-format", "flate2", "fzn-rs", + "pumpkin-checking", + "pumpkin-core", + "pumpkin-propagators", "thiserror", ] +[[package]] +name = "pumpkin-checking" +version = "0.2.2" +dependencies = [ + "dyn-clone", +] + [[package]] name = "pumpkin-constraints" version = "0.2.2" @@ -750,7 +751,7 @@ dependencies = [ "bitfield", "bitfield-struct", "clap", - "convert_case 0.6.0", + "convert_case", "downcast-rs", "drcp-format", "dyn-clone", @@ -764,6 +765,7 @@ dependencies = [ "log", "num", "once_cell", + "pumpkin-checking", "pumpkin-constraints", "rand", "thiserror", @@ -788,8 +790,9 @@ version = "0.2.2" dependencies = [ "bitfield-struct", "clap", - "convert_case 0.6.0", + "convert_case", "enumset", + "pumpkin-checking", "pumpkin-constraints", "pumpkin-core", ] diff --git a/Cargo.toml b/Cargo.toml index 1c1b84625..3077fad18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,18 @@ [workspace] -members = ["./pumpkin-solver", "./pumpkin-checker", "./drcp-format", "./pumpkin-solver-py", "./pumpkin-macros", "./drcp-debugger", "./pumpkin-crates/*", "./fzn-rs", "./fzn-rs-derive"] resolver = "2" +members = [ + # Libraries used by Pumpkin but in principle are independent + "./drcp-format", + "./drcp-debugger", + "./fzn-rs", + "./fzn-rs-derive", + + "./pumpkin-crates/*", # Core libraries of the solver + "./pumpkin-solver", # The solver binary + "./pumpkin-checker", # The uncertified proof checker + "./pumpkin-solver-py", # The python interface + "./pumpkin-macros", # Proc-macros used by the pumpkin source (unpublished) +] [workspace.package] repository = "https://github.com/consol-lab/pumpkin" diff --git a/clippy.toml b/clippy.toml index 937f96fe2..bff8f0240 100644 --- a/clippy.toml +++ b/clippy.toml @@ -1 +1,4 @@ -allowed-duplicate-crates = ["regex-automata", "regex-syntax"] +allowed-duplicate-crates = [ + "hashbrown", + "windows-sys", +] diff --git a/pumpkin-checker/Cargo.toml b/pumpkin-checker/Cargo.toml index 21f817aa5..17e25c99a 100644 --- a/pumpkin-checker/Cargo.toml +++ b/pumpkin-checker/Cargo.toml @@ -7,6 +7,9 @@ license.workspace = true authors.workspace = true [dependencies] +pumpkin-core = { version = "0.2.2", path = "../pumpkin-crates/core/" } +pumpkin-checking = { version = "0.2.2", path = "../pumpkin-crates/checking/" } +pumpkin-propagators = { version = "0.2.2", path = "../pumpkin-crates/propagators/" } anyhow = "1.0.99" clap = { version = "4.5.47", features = ["derive"] } drcp-format = { version = "0.3.0", path = "../drcp-format" } diff --git a/pumpkin-checker/src/deductions.rs b/pumpkin-checker/src/deductions.rs index d49217e3c..0a2d3a7b1 100644 --- a/pumpkin-checker/src/deductions.rs +++ b/pumpkin-checker/src/deductions.rs @@ -3,10 +3,11 @@ use std::rc::Rc; use drcp_format::ConstraintId; use drcp_format::IntAtomic; +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::VariableState; use crate::inferences::Fact; use crate::model::Nogood; -use crate::state::VariableState; /// An inference that was ignored when checking a deduction. #[derive(Clone, Debug)] @@ -50,9 +51,11 @@ pub fn verify_deduction( // At some point, this should either reach a fact without a consequent or derive an // inconsistent domain. - let mut variable_state = - VariableState::prepare_for_conflict_check(&Fact::nogood(deduction.premises.clone())) - .ok_or(InvalidDeduction::InconsistentPremises)?; + let mut variable_state = VariableState::prepare_for_conflict_check( + deduction.premises.iter().cloned().map(Into::into), + None, + ) + .ok_or(InvalidDeduction::InconsistentPremises)?; let mut unused_inferences = Vec::new(); @@ -75,9 +78,22 @@ pub fn verify_deduction( // `String`. The former does not implement `Send`, but that is // required for our error type to be used with anyhow. Some(IntAtomic { - name: String::from(premise.name.as_ref()), - comparison: premise.comparison, - value: premise.value, + name: String::from(premise.identifier().as_ref()), + comparison: match premise.comparison() { + pumpkin_checking::Comparison::GreaterEqual => { + drcp_format::IntComparison::GreaterEqual + } + pumpkin_checking::Comparison::LessEqual => { + drcp_format::IntComparison::LessEqual + } + pumpkin_checking::Comparison::Equal => { + drcp_format::IntComparison::Equal + } + pumpkin_checking::Comparison::NotEqual => { + drcp_format::IntComparison::NotEqual + } + }, + value: premise.value(), }) } }) diff --git a/pumpkin-checker/src/inferences/all_different.rs b/pumpkin-checker/src/inferences/all_different.rs index d1dd3681c..6887cb250 100644 --- a/pumpkin-checker/src/inferences/all_different.rs +++ b/pumpkin-checker/src/inferences/all_different.rs @@ -1,9 +1,12 @@ use std::collections::HashSet; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::VariableState; + use super::Fact; use crate::inferences::InvalidInference; +use crate::model::Atomic; use crate::model::Constraint; -use crate::state::VariableState; /// Verify an `all_different` inference. /// @@ -12,8 +15,9 @@ use crate::state::VariableState; /// /// The checker will reject inferences with redundant atomic constraints. pub(crate) fn verify_all_different( - fact: &Fact, + _: &Fact, constraint: &Constraint, + state: VariableState, ) -> Result<(), InvalidInference> { // This checker takes the union of the domains of the variables in the constraint. If there // are fewer values in the union of the domain than there are variables, then there is a @@ -23,14 +27,11 @@ pub(crate) fn verify_all_different( return Err(InvalidInference::ConstraintLabelMismatch); }; - let variable_state = VariableState::prepare_for_conflict_check(fact) - .ok_or(InvalidInference::InconsistentPremises)?; - // Collect all values present in at least one of the domains. let union_of_domains = all_different .variables .iter() - .filter_map(|variable| variable_state.iter_domain(variable)) + .filter_map(|variable| variable.iter_induced_domain(&state)) .flatten() .collect::>(); @@ -40,7 +41,7 @@ pub(crate) fn verify_all_different( let num_variables = all_different .variables .iter() - .filter(|variable| variable_state.iter_domain(variable).is_some()) + .filter(|variable| variable.iter_induced_domain(&state).is_some()) .count(); if union_of_domains.len() < num_variables { diff --git a/pumpkin-checker/src/inferences/arithmetic.rs b/pumpkin-checker/src/inferences/arithmetic.rs index fb3ff9a51..a42b7fda2 100644 --- a/pumpkin-checker/src/inferences/arithmetic.rs +++ b/pumpkin-checker/src/inferences/arithmetic.rs @@ -1,16 +1,15 @@ use std::collections::BTreeSet; -use std::rc::Rc; -use drcp_format::IntComparison; -use fzn_rs::VariableExpr; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::VariableState; +use pumpkin_propagators::arithmetic::BinaryEqualsChecker; use super::Fact; use crate::inferences::InvalidInference; use crate::model::AllDifferent; use crate::model::Atomic; use crate::model::Constraint; -use crate::state::I32Ext; -use crate::state::VariableState; /// Verify a `binary_equals` inference. /// @@ -20,6 +19,7 @@ use crate::state::VariableState; pub(crate) fn verify_binary_equals( fact: &Fact, constraint: &Constraint, + state: VariableState, ) -> Result<(), InvalidInference> { // To check this inference we expect the intersection of both domains to be empty. @@ -32,69 +32,16 @@ pub(crate) fn verify_binary_equals( return Err(InvalidInference::Unsound); } - let (weight_a, variable_a) = &linear.terms[0]; - let (weight_b, variable_b) = &linear.terms[1]; + let lhs = linear.terms[0].clone(); + let rhs = linear.terms[1].clone(); - // TODO: Generalize this rule to work with non-unit weights. - // At the moment we expect one term to have weight `-1` and the other term to have weight - // `1`. - if weight_a + weight_b != 0 || weight_a.abs() != 1 || weight_b.abs() != 1 { - return Err(InvalidInference::Unsound); - } - - let mut variable_state = VariableState::prepare_for_conflict_check(fact) - .ok_or(InvalidInference::InconsistentPremises)?; - - // We apply the domain of variable 2 to variable 1. If the state remains consistent, then - // the step is unsound! - let state_is_consistent = match variable_a { - VariableExpr::Identifier(var1) => { - let mut consistent = true; + let checker = BinaryEqualsChecker { lhs, rhs }; - if let I32Ext::I32(value) = variable_state.upper_bound(variable_b) { - consistent &= variable_state.apply(&Atomic { - name: Rc::clone(var1), - comparison: IntComparison::LessEqual, - value: linear.bound + value, - }); - } - - if let I32Ext::I32(value) = variable_state.lower_bound(variable_b) { - consistent &= variable_state.apply(&Atomic { - name: Rc::clone(var1), - comparison: IntComparison::GreaterEqual, - value: linear.bound + value, - }); - } - - for value in variable_state.holes(variable_b).collect::>() { - consistent &= variable_state.apply(&Atomic { - name: Rc::clone(var1), - comparison: IntComparison::NotEqual, - value: linear.bound + value, - }); - } - - consistent - } - - VariableExpr::Constant(value) => match variable_b { - VariableExpr::Identifier(var2) => variable_state.apply(&Atomic { - name: Rc::clone(var2), - comparison: IntComparison::NotEqual, - value: linear.bound + *value, - }), - VariableExpr::Constant(_) => panic!("Binary equals over two constants is unexpected."), - }, - }; - - if state_is_consistent { - // The intersection of the domains should yield an inconsistent state for the - // inference to be sound. - return Err(InvalidInference::Unsound); + if checker.check(state, &fact.premises, fact.consequent.as_ref()) { + Ok(()) + } else { + Err(InvalidInference::Unsound) } - - Ok(()) } /// Verify a `binary_not_equals` inference. @@ -102,19 +49,17 @@ pub(crate) fn verify_binary_equals( /// Tests that the premise of the inference and the negation of the consequent force the linear sum /// to equal the right-hand side of the not equals constraint. pub(crate) fn verify_binary_not_equals( - fact: &Fact, + _: &Fact, constraint: &Constraint, + state: VariableState, ) -> Result<(), InvalidInference> { let Constraint::AllDifferent(AllDifferent { variables }) = constraint else { return Err(InvalidInference::ConstraintLabelMismatch); }; - let variable_state = VariableState::prepare_for_conflict_check(fact) - .ok_or(InvalidInference::InconsistentPremises)?; - let mut values = BTreeSet::new(); for variable in variables { - let Some(value) = variable_state.fixed_value(variable) else { + let Some(value) = variable.induced_fixed_value(&state) else { continue; }; diff --git a/pumpkin-checker/src/inferences/linear.rs b/pumpkin-checker/src/inferences/linear.rs index 44996d80c..4cd522d49 100644 --- a/pumpkin-checker/src/inferences/linear.rs +++ b/pumpkin-checker/src/inferences/linear.rs @@ -1,9 +1,13 @@ +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::VariableState; +use pumpkin_propagators::arithmetic::LinearLessOrEqualInferenceChecker; + use crate::inferences::Fact; use crate::inferences::InvalidInference; +use crate::model::Atomic; use crate::model::Constraint; use crate::model::Linear; -use crate::state::I32Ext; -use crate::state::VariableState; +use crate::model::Term; /// Verify a `linear_bounds` inference. /// @@ -11,22 +15,26 @@ use crate::state::VariableState; pub(super) fn verify_linear_bounds( fact: &Fact, generated_by: &Constraint, + state: VariableState, ) -> Result<(), InvalidInference> { match generated_by { - Constraint::LinearLeq(linear) => verify_linear_inference(linear, fact), + Constraint::LinearLeq(linear) => verify_linear_inference(linear, fact, state), Constraint::LinearEq(linear) => { - let try_upper_bound = verify_linear_inference(linear, fact); + let try_upper_bound = verify_linear_inference(linear, fact, state.clone()); let inverted_linear = Linear { terms: linear .terms .iter() - .map(|(weight, variable)| (-weight, variable.clone())) + .map(|term| Term { + weight: -term.weight, + variable: term.variable.clone(), + }) .collect(), bound: -linear.bound, }; - let try_lower_bound = verify_linear_inference(&inverted_linear, fact); + let try_lower_bound = verify_linear_inference(&inverted_linear, fact, state); match (try_lower_bound, try_upper_bound) { (Ok(_), Ok(_)) => panic!("This should not happen."), @@ -39,35 +47,14 @@ pub(super) fn verify_linear_bounds( } } -fn verify_linear_inference(linear: &Linear, fact: &Fact) -> Result<(), InvalidInference> { - let variable_state = VariableState::prepare_for_conflict_check(fact) - .ok_or(InvalidInference::InconsistentPremises)?; - - // Next, we evaluate the linear inequality. The lower bound of the - // left-hand side must exceed the bound in the constraint. - let left_hand_side = linear.terms.iter().fold(None, |acc, (weight, variable)| { - let lower_bound = if *weight >= 0 { - variable_state.lower_bound(variable) - } else { - variable_state.upper_bound(variable) - }; - - match acc { - None => match lower_bound { - I32Ext::I32(value) => Some(weight * value), - I32Ext::NegativeInf => None, - I32Ext::PositiveInf => None, - }, - - Some(v1) => match lower_bound { - I32Ext::I32(v2) => Some(v1 + weight * v2), - I32Ext::NegativeInf => Some(v1), - I32Ext::PositiveInf => Some(v1), - }, - } - }); +fn verify_linear_inference( + linear: &Linear, + fact: &Fact, + state: VariableState, +) -> Result<(), InvalidInference> { + let checker = LinearLessOrEqualInferenceChecker::new(linear.terms.clone().into(), linear.bound); - if left_hand_side.is_some_and(|value| value > linear.bound) { + if checker.check(state, &fact.premises, fact.consequent.as_ref()) { Ok(()) } else { Err(InvalidInference::Unsound) @@ -76,6 +63,10 @@ fn verify_linear_inference(linear: &Linear, fact: &Fact) -> Result<(), InvalidIn #[cfg(test)] mod tests { + use std::num::NonZero; + use std::rc::Rc; + + use drcp_format::IntAtomic; use drcp_format::IntComparison::*; use fzn_rs::VariableExpr::*; @@ -86,21 +77,34 @@ mod tests { fn linear_1() { // x1 - x2 <= -7 let linear = Linear { - terms: vec![(1, Identifier("x1".into())), (-1, Identifier("x2".into()))], + terms: vec![ + Term { + weight: NonZero::new(1).unwrap(), + variable: Identifier(Rc::from("x1")).into(), + }, + Term { + weight: NonZero::new(-1).unwrap(), + variable: Identifier(Rc::from("x2")).into(), + }, + ], bound: -7, }; - let premises = vec![Atomic { + let premises = vec![Atomic::IntAtomic(IntAtomic { name: "x2".into(), comparison: LessEqual, value: 37, - }]; + })]; - let consequent = Some(Atomic { + let consequent = Some(Atomic::IntAtomic(IntAtomic { name: "x1".into(), comparison: LessEqual, value: 30, - }); + })); + + let variable_state = + VariableState::prepare_for_conflict_check(premises.clone(), consequent.clone()) + .expect("no mutually exclusive atomics"); verify_linear_inference( &linear, @@ -108,6 +112,7 @@ mod tests { premises, consequent, }, + variable_state, ) .expect("valid inference"); } diff --git a/pumpkin-checker/src/inferences/mod.rs b/pumpkin-checker/src/inferences/mod.rs index 61e5416a0..41d5887a2 100644 --- a/pumpkin-checker/src/inferences/mod.rs +++ b/pumpkin-checker/src/inferences/mod.rs @@ -4,6 +4,8 @@ mod linear; mod nogood; mod time_table; +use pumpkin_checking::VariableState; + use crate::model::Atomic; use crate::model::Model; @@ -61,8 +63,8 @@ pub(crate) fn verify_inference( inference: &drcp_format::Inference, i32, std::rc::Rc>, ) -> Result { let fact = Fact { - premises: inference.premises.clone(), - consequent: inference.consequent.clone(), + premises: inference.premises.iter().cloned().map(Into::into).collect(), + consequent: inference.consequent.clone().map(Into::into), }; let label = inference @@ -79,7 +81,9 @@ pub(crate) fn verify_inference( return Err(InvalidInference::Unsound); }; - if !model.is_trivially_true(atomic.clone()) { + let atomic: Atomic = atomic.into(); + + if !model.is_trivially_true(&atomic) { // If the consequent is not trivially true in the model then the inference // is unsound. return Err(InvalidInference::Unsound); @@ -88,6 +92,11 @@ pub(crate) fn verify_inference( return Ok(fact); } + // Setup the state for a conflict check. + let variable_state = + VariableState::prepare_for_conflict_check(fact.premises.clone(), fact.consequent.clone()) + .ok_or(InvalidInference::InconsistentPremises)?; + // Get the constraint that generated the inference from the model. let generated_by_constraint_id = inference .generated_by @@ -98,27 +107,27 @@ pub(crate) fn verify_inference( match label { "linear_bounds" => { - linear::verify_linear_bounds(&fact, generated_by)?; + linear::verify_linear_bounds(&fact, generated_by, variable_state)?; } "nogood" => { - nogood::verify_nogood(&fact, generated_by)?; + nogood::verify_nogood(&fact, generated_by, variable_state)?; } "time_table" => { - time_table::verify_time_table(&fact, generated_by)?; + time_table::verify_time_table(&fact, generated_by, variable_state)?; } "all_different" => { - all_different::verify_all_different(&fact, generated_by)?; + all_different::verify_all_different(&fact, generated_by, variable_state)?; } "binary_equals" => { - arithmetic::verify_binary_equals(&fact, generated_by)?; + arithmetic::verify_binary_equals(&fact, generated_by, variable_state)?; } "binary_not_equals" => { - arithmetic::verify_binary_not_equals(&fact, generated_by)?; + arithmetic::verify_binary_not_equals(&fact, generated_by, variable_state)?; } _ => return Err(InvalidInference::UnsupportedLabel), diff --git a/pumpkin-checker/src/inferences/nogood.rs b/pumpkin-checker/src/inferences/nogood.rs index c715ad162..2451a2fe1 100644 --- a/pumpkin-checker/src/inferences/nogood.rs +++ b/pumpkin-checker/src/inferences/nogood.rs @@ -1,22 +1,31 @@ +use std::ops::Deref; + +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::VariableState; +use pumpkin_core::propagators::nogoods::NogoodChecker; + use crate::inferences::Fact; use crate::inferences::InvalidInference; +use crate::model::Atomic; use crate::model::Constraint; -use crate::state::VariableState; /// Verifies a `nogood` inference. /// /// This inference is used to rewrite a nogood `L /\ p -> false` to `L -> not p`. -pub(crate) fn verify_nogood(fact: &Fact, constraint: &Constraint) -> Result<(), InvalidInference> { +pub(crate) fn verify_nogood( + fact: &Fact, + constraint: &Constraint, + state: VariableState, +) -> Result<(), InvalidInference> { let Constraint::Nogood(nogood) = constraint else { return Err(InvalidInference::ConstraintLabelMismatch); }; - let variable_state = VariableState::prepare_for_conflict_check(fact) - .ok_or(InvalidInference::InconsistentPremises)?; - - let is_implied_by_nogood = nogood.iter().all(|atomic| variable_state.is_true(atomic)); + let checker = NogoodChecker { + nogood: nogood.deref().into(), + }; - if is_implied_by_nogood { + if checker.check(state, &fact.premises, fact.consequent.as_ref()) { Ok(()) } else { Err(InvalidInference::Unsound) diff --git a/pumpkin-checker/src/inferences/time_table.rs b/pumpkin-checker/src/inferences/time_table.rs index 4a5c87842..0300bb895 100644 --- a/pumpkin-checker/src/inferences/time_table.rs +++ b/pumpkin-checker/src/inferences/time_table.rs @@ -1,9 +1,12 @@ -use std::collections::BTreeMap; +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::VariableState; +use pumpkin_propagators::cumulative::time_table::CheckerTask; +use pumpkin_propagators::cumulative::time_table::TimeTableChecker; use super::Fact; use crate::inferences::InvalidInference; +use crate::model::Atomic; use crate::model::Constraint; -use crate::state::VariableState; /// Verifies a `time_table` inference for the cumulative constraint. /// @@ -12,37 +15,28 @@ use crate::state::VariableState; pub(crate) fn verify_time_table( fact: &Fact, constraint: &Constraint, + state: VariableState, ) -> Result<(), InvalidInference> { let Constraint::Cumulative(cumulative) = constraint else { return Err(InvalidInference::ConstraintLabelMismatch); }; - let variable_state = VariableState::prepare_for_conflict_check(fact) - .ok_or(InvalidInference::InconsistentPremises)?; - - // The profile is a key-value store. The keys correspond to time-points, and the values to the - // relative change in resource consumption. A BTreeMap is used to maintain a sorted order of - // the time points. - let mut profile = BTreeMap::new(); - - for task in cumulative.tasks.iter() { - let lst = variable_state.upper_bound(&task.start_time); - let ect = variable_state.lower_bound(&task.start_time) + task.duration; - - if ect <= lst { - *profile.entry(ect).or_insert(0) += task.resource_usage; - *profile.entry(lst).or_insert(0) -= task.resource_usage; - } - } - - let mut usage = 0; - for delta in profile.values() { - usage += delta; + let checker = TimeTableChecker { + tasks: cumulative + .tasks + .iter() + .map(|task| CheckerTask { + start_time: task.start_time.clone(), + resource_usage: task.resource_usage, + processing_time: task.duration, + }) + .collect(), + capacity: cumulative.capacity, + }; - if usage > cumulative.capacity { - return Ok(()); - } + if checker.check(state, &fact.premises, fact.consequent.as_ref()) { + Ok(()) + } else { + Err(InvalidInference::Unsound) } - - Err(InvalidInference::Unsound) } diff --git a/pumpkin-checker/src/lib.rs b/pumpkin-checker/src/lib.rs index cac7b8fc9..47409dc4c 100644 --- a/pumpkin-checker/src/lib.rs +++ b/pumpkin-checker/src/lib.rs @@ -12,10 +12,10 @@ use drcp_format::reader::ProofReader; pub mod deductions; pub mod inferences; -mod state; - pub mod model; +pub(crate) mod math; + use model::*; /// The errors that can be returned by the checker. @@ -150,11 +150,11 @@ fn parse_model(path: impl AsRef) -> anyhow::Result { fzn_rs::Method::Optimize { direction: fzn_rs::ast::OptimizationDirection::Minimize, objective, - } => Some(Objective::Minimize(objective.clone())), + } => Some(Objective::Minimize(objective.clone().into())), fzn_rs::Method::Optimize { direction: fzn_rs::ast::OptimizationDirection::Maximize, objective, - } => Some(Objective::Maximize(objective.clone())), + } => Some(Objective::Maximize(objective.clone().into())), }; for (name, variable) in fzn_model.variables.iter() { @@ -181,7 +181,12 @@ fn parse_model(path: impl AsRef) -> anyhow::Result { let weight = weight?; let variable = variable?; - terms.push((weight, variable)); + terms.push(Term { + weight: weight + .try_into() + .expect("flatzinc does not have 0-weight terms"), + variable: variable.into(), + }); } Constraint::LinearLeq(Linear { @@ -204,7 +209,12 @@ fn parse_model(path: impl AsRef) -> anyhow::Result { let weight = weight?; let variable = variable?; - terms.push((weight, variable)); + terms.push(Term { + weight: weight + .try_into() + .expect("flatzinc does not have 0-weight terms"), + variable: variable.into(), + }); } Constraint::LinearEq(Linear { @@ -233,7 +243,7 @@ fn parse_model(path: impl AsRef) -> anyhow::Result { let resource_usage = maybe_resource_usage?; Ok(Task { - start_time, + start_time: start_time.into(), duration, resource_usage, }) @@ -250,6 +260,7 @@ fn parse_model(path: impl AsRef) -> anyhow::Result { FlatZincConstraints::AllDifferent(variables) => { let variables = fzn_model .resolve_array(variables)? + .map(|maybe_variable| maybe_variable.map(Variable::from)) .collect::, _>>()?; Constraint::AllDifferent(AllDifferent { variables }) @@ -367,7 +378,9 @@ fn verify_conclusion(model: &Model, conclusion: &drcp_format::Conclusion match conclusion { drcp_format::Conclusion::Unsat => nogood.as_ref().is_empty(), - drcp_format::Conclusion::DualBound(atomic) => nogood.as_ref() == [!atomic.clone()], + drcp_format::Conclusion::DualBound(atomic) => { + nogood.as_ref() == [Atomic::from(!atomic.clone())] + } } }) } diff --git a/pumpkin-checker/src/math.rs b/pumpkin-checker/src/math.rs new file mode 100644 index 000000000..44e2aba3b --- /dev/null +++ b/pumpkin-checker/src/math.rs @@ -0,0 +1,23 @@ +pub(crate) fn div_ceil(lhs: i32, other: i32) -> i32 { + // TODO: The source is taken from the standard library nightly implementation of this + // function and div_floor. Once they are stabilized, these definitions can be removed. + // Tracking issue: https://github.com/rust-lang/rust/issues/88581 + let d = lhs / other; + let r = lhs % other; + if (r > 0 && other > 0) || (r < 0 && other < 0) { + d + 1 + } else { + d + } +} + +pub(crate) fn div_floor(lhs: i32, other: i32) -> i32 { + // TODO: See todo in `div_ceil`. + let d = lhs / other; + let r = lhs % other; + if (r > 0 && other < 0) || (r < 0 && other > 0) { + d - 1 + } else { + d + } +} diff --git a/pumpkin-checker/src/model.rs b/pumpkin-checker/src/model.rs index 722e1d9da..f0cae2484 100644 --- a/pumpkin-checker/src/model.rs +++ b/pumpkin-checker/src/model.rs @@ -3,6 +3,7 @@ //! The main component of the model are the constraints that the checker supports. use std::collections::BTreeMap; +use std::num::NonZero; use std::ops::Deref; use std::rc::Rc; @@ -10,6 +11,14 @@ use drcp_format::ConstraintId; use drcp_format::IntAtomic; use fzn_rs::VariableExpr; use fzn_rs::ast::Domain; +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::Comparison; +use pumpkin_checking::IntExt; +use pumpkin_checking::VariableState; + +use crate::math::div_ceil; +use crate::math::div_floor; #[derive(Clone, Debug)] pub enum Constraint { @@ -20,17 +29,79 @@ pub enum Constraint { AllDifferent(AllDifferent), } -pub type Atomic = IntAtomic, i32>; +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Atomic { + True, + False, + IntAtomic(IntAtomic, i32>), +} + +impl From, i32>> for Atomic { + fn from(value: IntAtomic, i32>) -> Self { + Atomic::IntAtomic(value) + } +} + +impl From for Atomic { + fn from(value: bool) -> Self { + if value { Atomic::True } else { Atomic::False } + } +} + +impl AtomicConstraint for Atomic { + type Identifier = Rc; + + fn identifier(&self) -> Self::Identifier { + match self { + Atomic::True => Rc::from("true"), + Atomic::False => Rc::from("false"), + Atomic::IntAtomic(int_atomic) => Rc::clone(&int_atomic.name), + } + } + + fn comparison(&self) -> Comparison { + let Atomic::IntAtomic(int_atomic) = self else { + return Comparison::Equal; + }; + + match int_atomic.comparison { + drcp_format::IntComparison::GreaterEqual => Comparison::GreaterEqual, + drcp_format::IntComparison::LessEqual => Comparison::LessEqual, + drcp_format::IntComparison::Equal => Comparison::Equal, + drcp_format::IntComparison::NotEqual => Comparison::NotEqual, + } + } + + fn value(&self) -> i32 { + match self { + Atomic::True => 1, + Atomic::False => 0, + Atomic::IntAtomic(int_atomic) => int_atomic.value, + } + } + + fn negate(&self) -> Self { + match self { + Atomic::True => Atomic::False, + Atomic::False => Atomic::True, + Atomic::IntAtomic(int_atomic) => { + let owned = int_atomic.clone(); + Atomic::IntAtomic(!owned) + } + } + } +} #[derive(Clone, Debug)] pub struct Nogood(Vec); -impl From for Nogood +impl From for Nogood where - T: IntoIterator, + T: IntoIterator, + A: Into, { fn from(value: T) -> Self { - Nogood(value.into_iter().collect()) + Nogood(value.into_iter().map(Into::into).collect()) } } @@ -42,15 +113,275 @@ impl Deref for Nogood { } } +/// A checker variable that can be used with [`pumpkin_checking::VariableState`]. +#[derive(Clone, Debug)] +pub struct Variable(VariableExpr); + +impl From> for Variable { + fn from(value: VariableExpr) -> Self { + Variable(value) + } +} + +impl CheckerVariable for Variable { + fn does_atomic_constrain_self(&self, atomic: &Atomic) -> bool { + let Variable(VariableExpr::Identifier(ident)) = self else { + return false; + }; + + let Atomic::IntAtomic(atomic) = atomic else { + return false; + }; + + &atomic.name == ident + } + + fn atomic_less_than(&self, value: i32) -> Atomic { + match self.0 { + VariableExpr::Identifier(ref name) => Atomic::from(IntAtomic { + name: Rc::clone(name), + comparison: drcp_format::IntComparison::LessEqual, + value, + }), + VariableExpr::Constant(constant) => (constant <= value).into(), + } + } + + fn atomic_greater_than(&self, value: i32) -> Atomic { + match self.0 { + VariableExpr::Identifier(ref name) => Atomic::from(IntAtomic { + name: Rc::clone(name), + comparison: drcp_format::IntComparison::GreaterEqual, + value, + }), + VariableExpr::Constant(constant) => (constant >= value).into(), + } + } + + fn atomic_equal(&self, value: i32) -> Atomic { + match self.0 { + VariableExpr::Identifier(ref name) => Atomic::from(IntAtomic { + name: Rc::clone(name), + comparison: drcp_format::IntComparison::Equal, + value, + }), + VariableExpr::Constant(constant) => (constant == value).into(), + } + } + + fn atomic_not_equal(&self, value: i32) -> Atomic { + match self.0 { + VariableExpr::Identifier(ref name) => Atomic::from(IntAtomic { + name: Rc::clone(name), + comparison: drcp_format::IntComparison::NotEqual, + value, + }), + VariableExpr::Constant(constant) => (constant != value).into(), + } + } + + fn induced_lower_bound(&self, variable_state: &VariableState) -> IntExt { + match self.0 { + VariableExpr::Identifier(ref ident) => variable_state.lower_bound(ident), + VariableExpr::Constant(value) => value.into(), + } + } + + fn induced_upper_bound(&self, variable_state: &VariableState) -> IntExt { + match self.0 { + VariableExpr::Identifier(ref ident) => variable_state.upper_bound(ident), + VariableExpr::Constant(value) => value.into(), + } + } + + fn induced_fixed_value(&self, variable_state: &VariableState) -> Option { + match self.0 { + VariableExpr::Identifier(ref ident) => variable_state.fixed_value(ident), + VariableExpr::Constant(value) => value.into(), + } + } + + fn induced_holes<'this, 'state>( + &'this self, + variable_state: &'state VariableState, + ) -> impl Iterator + 'state + where + 'this: 'state, + { + match self.0 { + #[allow( + trivial_casts, + reason = "without it the compiler does not coerce to Box" + )] + VariableExpr::Identifier(ref ident) => { + Box::new(variable_state.holes(ident)) as Box> + } + VariableExpr::Constant(_) => Box::new(std::iter::empty()), + } + } + + fn iter_induced_domain<'this, 'state>( + &'this self, + variable_state: &'state VariableState, + ) -> Option + 'state> + where + 'this: 'state, + { + match self.0 { + #[allow( + trivial_casts, + reason = "without it the compiler does not coerce to Box" + )] + VariableExpr::Identifier(ref ident) => variable_state + .iter_domain(ident) + .map(|iter| Box::new(iter) as Box>), + VariableExpr::Constant(value) => Some(Box::new(std::iter::once(value))), + } + } + + fn induced_domain_contains(&self, variable_state: &VariableState, value: i32) -> bool { + match self.0 { + VariableExpr::Identifier(ref ident) => variable_state.contains(ident, value), + VariableExpr::Constant(constant_value) => constant_value == value, + } + } +} + #[derive(Clone, Debug)] pub struct Linear { - pub terms: Vec<(i32, VariableExpr)>, + pub terms: Vec, pub bound: i32, } +#[derive(Clone, Debug)] +pub struct Term { + pub weight: NonZero, + pub variable: Variable, +} + +impl Term { + /// Apply the inverse transformation of this view on a value, to go from the value in the domain + /// of `self` to a value in the domain of `self.inner`. + fn invert(&self, value: i32, rounding: Rounding) -> i32 { + match rounding { + Rounding::Up => div_ceil(value, self.weight.get()), + Rounding::Down => div_floor(value, self.weight.get()), + } + } +} + +enum Rounding { + Up, + Down, +} + +impl CheckerVariable for Term { + fn does_atomic_constrain_self(&self, atomic: &Atomic) -> bool { + self.variable.does_atomic_constrain_self(atomic) + } + + fn atomic_less_than(&self, value: i32) -> Atomic { + if self.weight.is_negative() { + let inverted_value = self.invert(value, Rounding::Up); + self.variable.atomic_greater_than(inverted_value) + } else { + let inverted_value = self.invert(value, Rounding::Down); + self.variable.atomic_less_than(inverted_value) + } + } + + fn atomic_greater_than(&self, value: i32) -> Atomic { + if self.weight.is_negative() { + let inverted_value = self.invert(value, Rounding::Down); + self.variable.atomic_less_than(inverted_value) + } else { + let inverted_value = self.invert(value, Rounding::Up); + self.variable.atomic_greater_than(inverted_value) + } + } + + fn atomic_equal(&self, value: i32) -> Atomic { + if value % self.weight.get() == 0 { + let inverted_value = self.invert(value, Rounding::Up); + self.variable.atomic_equal(inverted_value) + } else { + Atomic::False + } + } + + fn atomic_not_equal(&self, value: i32) -> Atomic { + if value % self.weight.get() == 0 { + let inverted_value = self.invert(value, Rounding::Up); + self.variable.atomic_not_equal(inverted_value) + } else { + Atomic::True + } + } + + fn induced_lower_bound(&self, variable_state: &VariableState) -> IntExt { + if self.weight.is_positive() { + self.variable.induced_lower_bound(variable_state) * self.weight.get() + } else { + self.variable.induced_upper_bound(variable_state) * self.weight.get() + } + } + + fn induced_upper_bound(&self, variable_state: &VariableState) -> IntExt { + if self.weight.is_positive() { + self.variable.induced_upper_bound(variable_state) * self.weight.get() + } else { + self.variable.induced_lower_bound(variable_state) * self.weight.get() + } + } + + fn induced_fixed_value(&self, variable_state: &VariableState) -> Option { + self.variable + .induced_fixed_value(variable_state) + .map(|value| value * self.weight.get()) + } + + fn induced_holes<'this, 'state>( + &'this self, + _variable_state: &'state VariableState, + ) -> impl Iterator + 'state + where + 'this: 'state, + { + todo!("how to compute holes in a scaled domain?"); + + #[allow( + unreachable_code, + reason = "otherwise the function does not return an impl Iterator" + )] + std::iter::empty() + } + + fn iter_induced_domain<'this, 'state>( + &'this self, + variable_state: &'state VariableState, + ) -> Option + 'state> + where + 'this: 'state, + { + self.variable + .iter_induced_domain(variable_state) + .map(|iter| iter.map(|value| value * self.weight.get())) + } + + fn induced_domain_contains(&self, variable_state: &VariableState, value: i32) -> bool { + if value % self.weight.get() == 0 { + let inverted = self.invert(value, Rounding::Up); + self.variable + .induced_domain_contains(variable_state, inverted) + } else { + false + } + } +} + #[derive(Clone, Debug)] pub struct Task { - pub start_time: VariableExpr, + pub start_time: Variable, pub duration: i32, pub resource_usage: i32, } @@ -63,13 +394,13 @@ pub struct Cumulative { #[derive(Clone, Debug)] pub struct AllDifferent { - pub variables: Vec>, + pub variables: Vec, } #[derive(Clone, Debug)] pub enum Objective { - Maximize(VariableExpr), - Minimize(VariableExpr), + Maximize(Variable), + Minimize(Variable), } #[derive(Clone, Debug, Default)] @@ -108,28 +439,26 @@ impl Model { /// Test whether the atomic is true in the initial domains of the variables. /// /// Returns false if the atomic is over a variable that is not in the model. - pub fn is_trivially_true(&self, atomic: Atomic) -> bool { - let Some(domain) = self.variables.get(&atomic.name) else { + pub fn is_trivially_true(&self, atomic: &Atomic) -> bool { + let Some(domain) = self.variables.get(&atomic.identifier()) else { return false; }; match domain { Domain::UnboundedInt => false, - Domain::Int(dom) => match atomic.comparison { - drcp_format::IntComparison::GreaterEqual => { - *dom.lower_bound() >= atomic.value as i64 - } - drcp_format::IntComparison::LessEqual => *dom.upper_bound() <= atomic.value as i64, - drcp_format::IntComparison::Equal => { - *dom.lower_bound() >= atomic.value as i64 - && *dom.upper_bound() <= atomic.value as i64 + Domain::Int(dom) => match atomic.comparison() { + Comparison::GreaterEqual => *dom.lower_bound() >= atomic.value() as i64, + Comparison::LessEqual => *dom.upper_bound() <= atomic.value() as i64, + Comparison::Equal => { + *dom.lower_bound() >= atomic.value() as i64 + && *dom.upper_bound() <= atomic.value() as i64 } - drcp_format::IntComparison::NotEqual => { - if *dom.lower_bound() >= atomic.value as i64 { + Comparison::NotEqual => { + if *dom.lower_bound() >= atomic.value() as i64 { return true; } - if *dom.upper_bound() <= atomic.value as i64 { + if *dom.upper_bound() <= atomic.value() as i64 { return true; } @@ -137,7 +466,7 @@ impl Model { return false; } - dom.into_iter().all(|value| value != atomic.value as i64) + dom.into_iter().all(|value| value != atomic.value() as i64) } }, Domain::Bool => todo!("boolean variables are not yet supported"), diff --git a/pumpkin-checker/src/state.rs b/pumpkin-checker/src/state.rs deleted file mode 100644 index a37c801ed..000000000 --- a/pumpkin-checker/src/state.rs +++ /dev/null @@ -1,514 +0,0 @@ -use std::cmp::Ordering; -use std::collections::BTreeMap; -use std::collections::BTreeSet; -use std::ops::Add; -use std::rc::Rc; - -use crate::inferences::Fact; -use crate::model::Atomic; - -/// The domains of all variables in the problem. -/// -/// Domains can be reduced through [`VariableState::apply`]. By default, the domain of every -/// variable is infinite. -#[derive(Clone, Debug, Default)] -pub(crate) struct VariableState { - domains: BTreeMap, Domain>, -} - -impl VariableState { - /// Create a variable state that applies all the premises and, if present, the negation of the - /// consequent. - /// - /// Used by inference checkers if they want to identify a conflict by negating the consequent. - pub(crate) fn prepare_for_conflict_check(fact: &Fact) -> Option { - let mut variable_state = VariableState::default(); - - let negated_consequent = fact - .consequent - .as_ref() - .map(|consequent| !consequent.clone()); - - // Apply all the premises and the negation of the consequent to the state. - if !fact - .premises - .iter() - .chain(negated_consequent.as_ref()) - .all(|premise| variable_state.apply(premise)) - { - return None; - } - - Some(variable_state) - } - - /// Get the lower bound of a variable. - pub(crate) fn lower_bound(&self, variable: &fzn_rs::VariableExpr) -> I32Ext { - let name = match variable { - fzn_rs::VariableExpr::Identifier(name) => name, - fzn_rs::VariableExpr::Constant(value) => return I32Ext::I32(*value), - }; - - self.domains - .get(name) - .map(|domain| domain.lower_bound) - .unwrap_or(I32Ext::NegativeInf) - } - - /// Get the upper bound of a variable. - pub(crate) fn upper_bound(&self, variable: &fzn_rs::VariableExpr) -> I32Ext { - let name = match variable { - fzn_rs::VariableExpr::Identifier(name) => name, - fzn_rs::VariableExpr::Constant(value) => return I32Ext::I32(*value), - }; - - self.domains - .get(name) - .map(|domain| domain.upper_bound) - .unwrap_or(I32Ext::PositiveInf) - } - - /// Get the holes within the lower and upper bound of the variable expression. - pub(crate) fn holes( - &self, - variable: &fzn_rs::VariableExpr, - ) -> impl Iterator + '_ { - #[allow(trivial_casts, reason = "without it we get a type error")] - let name = match variable { - fzn_rs::VariableExpr::Identifier(name) => name, - fzn_rs::VariableExpr::Constant(_) => { - return Box::new(std::iter::empty()) as Box>; - } - }; - - #[allow(trivial_casts, reason = "without it we get a type error")] - self.domains - .get(name) - .map(|domain| Box::new(domain.holes.iter().copied()) as Box>) - .unwrap_or_else(|| Box::new(std::iter::empty())) - } - - /// Get the fixed value of this variable, if it is fixed. - pub(crate) fn fixed_value(&self, variable: &fzn_rs::VariableExpr) -> Option { - let name = match variable { - fzn_rs::VariableExpr::Identifier(name) => name, - fzn_rs::VariableExpr::Constant(value) => return Some(*value), - }; - - let domain = self.domains.get(name)?; - - if domain.lower_bound == domain.upper_bound { - let I32Ext::I32(value) = domain.lower_bound else { - panic!( - "lower can only equal upper if they are integers, otherwise the sign of infinity makes them different" - ); - }; - - Some(value) - } else { - None - } - } - - /// Obtain an iterator over the domain of the variable. - /// - /// If the domain is unbounded, then `None` is returned. - pub(crate) fn iter_domain( - &self, - variable: &fzn_rs::VariableExpr, - ) -> Option> { - match variable { - fzn_rs::VariableExpr::Identifier(name) => { - let domain = self.domains.get(name)?; - - let I32Ext::I32(lower_bound) = domain.lower_bound else { - // If there is no lower bound, then the domain is unbounded. - return None; - }; - - // Ensure there is also an upper bound. - if !matches!(domain.upper_bound, I32Ext::I32(_)) { - return None; - } - - Some(DomainIterator(DomainIteratorImpl::Domain { - domain, - next_value: lower_bound, - })) - } - - fzn_rs::VariableExpr::Constant(value) => { - Some(DomainIterator(DomainIteratorImpl::Constant { - value: *value, - finished: false, - })) - } - } - } - - /// Apply the given [`Atomic`] to the state. - /// - /// Returns true if the state remains consistent, or false if the atomic cannot be true in - /// conjunction with previously applied atomics. - pub(crate) fn apply(&mut self, atomic: &Atomic) -> bool { - let domain = self - .domains - .entry(Rc::clone(&atomic.name)) - .or_insert(Domain::new()); - - match atomic.comparison { - drcp_format::IntComparison::GreaterEqual => { - domain.tighten_lower_bound(atomic.value); - } - - drcp_format::IntComparison::LessEqual => { - domain.tighten_upper_bound(atomic.value); - } - - drcp_format::IntComparison::Equal => { - domain.tighten_lower_bound(atomic.value); - domain.tighten_upper_bound(atomic.value); - } - - drcp_format::IntComparison::NotEqual => { - if domain.lower_bound == atomic.value { - domain.tighten_lower_bound(atomic.value + 1); - } - - if domain.upper_bound == atomic.value { - domain.tighten_upper_bound(atomic.value - 1); - } - - if domain.lower_bound < atomic.value && domain.upper_bound > atomic.value { - let _ = domain.holes.insert(atomic.value); - } - } - } - - domain.is_consistent() - } - - /// Is the given atomic true in the current state. - pub(crate) fn is_true(&self, atomic: &Atomic) -> bool { - let Some(domain) = self.domains.get(&atomic.name) else { - return false; - }; - - match atomic.comparison { - drcp_format::IntComparison::GreaterEqual => domain.lower_bound >= atomic.value, - - drcp_format::IntComparison::LessEqual => domain.upper_bound <= atomic.value, - - drcp_format::IntComparison::Equal => { - domain.lower_bound >= atomic.value && domain.upper_bound <= atomic.value - } - - drcp_format::IntComparison::NotEqual => { - if domain.lower_bound >= atomic.value { - return true; - } - - if domain.upper_bound <= atomic.value { - return true; - } - - if domain.holes.contains(&atomic.value) { - return true; - } - - false - } - } - } -} - -#[derive(Clone, Debug)] -struct Domain { - lower_bound: I32Ext, - upper_bound: I32Ext, - holes: BTreeSet, -} - -impl Domain { - fn new() -> Domain { - Domain { - lower_bound: I32Ext::NegativeInf, - upper_bound: I32Ext::PositiveInf, - holes: BTreeSet::default(), - } - } - - fn tighten_lower_bound(&mut self, bound: i32) { - if self.lower_bound >= bound { - return; - } - - self.lower_bound = I32Ext::I32(bound); - self.holes = self.holes.split_off(&bound); - - // Take care of the condition where the new bound is already a hole in the domain. - if self.holes.contains(&bound) { - self.tighten_lower_bound(bound + 1); - } - } - - fn tighten_upper_bound(&mut self, bound: i32) { - if self.upper_bound <= bound { - return; - } - - self.upper_bound = I32Ext::I32(bound); - - // Note the '+ 1' to keep the elements <= the upper bound instead of < - // the upper bound. - let _ = self.holes.split_off(&(bound + 1)); - - // Take care of the condition where the new bound is already a hole in the domain. - if self.holes.contains(&bound) { - self.tighten_upper_bound(bound - 1); - } - } - - fn is_consistent(&self) -> bool { - // No need to check holes, as the invariant of `Domain` specifies the bounds are as tight - // as possible, taking holes into account. - - self.lower_bound <= self.upper_bound - } -} - -/// An `i32` or infinity. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub(crate) enum I32Ext { - I32(i32), - NegativeInf, - PositiveInf, -} - -impl PartialEq for I32Ext { - fn eq(&self, other: &i32) -> bool { - match self { - I32Ext::I32(v1) => v1 == other, - I32Ext::NegativeInf | I32Ext::PositiveInf => false, - } - } -} - -impl PartialOrd for I32Ext { - fn partial_cmp(&self, other: &I32Ext) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for I32Ext { - fn cmp(&self, other: &Self) -> Ordering { - match self { - I32Ext::I32(v1) => match other { - I32Ext::I32(v2) => v1.cmp(v2), - I32Ext::NegativeInf => Ordering::Greater, - I32Ext::PositiveInf => Ordering::Less, - }, - I32Ext::NegativeInf => match other { - I32Ext::I32(_) => Ordering::Less, - I32Ext::PositiveInf => Ordering::Less, - I32Ext::NegativeInf => Ordering::Equal, - }, - I32Ext::PositiveInf => match other { - I32Ext::I32(_) => Ordering::Greater, - I32Ext::NegativeInf => Ordering::Greater, - I32Ext::PositiveInf => Ordering::Greater, - }, - } - } -} - -impl PartialOrd for I32Ext { - fn partial_cmp(&self, other: &i32) -> Option { - match self { - I32Ext::I32(v1) => v1.partial_cmp(other), - I32Ext::NegativeInf => Some(Ordering::Less), - I32Ext::PositiveInf => Some(Ordering::Greater), - } - } -} - -impl Add for I32Ext { - type Output = I32Ext; - - fn add(self, rhs: i32) -> Self::Output { - match self { - I32Ext::I32(lhs) => I32Ext::I32(lhs + rhs), - I32Ext::NegativeInf => I32Ext::NegativeInf, - I32Ext::PositiveInf => I32Ext::PositiveInf, - } - } -} - -/// An iterator over the values in the domain of a variable. -pub(crate) struct DomainIterator<'a>(DomainIteratorImpl<'a>); - -enum DomainIteratorImpl<'a> { - Constant { value: i32, finished: bool }, - Domain { domain: &'a Domain, next_value: i32 }, -} - -impl Iterator for DomainIterator<'_> { - type Item = i32; - - fn next(&mut self) -> Option { - match self.0 { - // Iterating over a contant means only yielding the value once, and then - // never again. - DomainIteratorImpl::Constant { - value, - ref mut finished, - } => { - if *finished { - None - } else { - *finished = true; - Some(value) - } - } - - DomainIteratorImpl::Domain { - domain, - ref mut next_value, - } => { - let I32Ext::I32(upper_bound) = domain.upper_bound else { - panic!("Only finite domains can be iterated.") - }; - - loop { - // We have completed iterating the domain. - if *next_value > upper_bound { - return None; - } - - let value = *next_value; - *next_value += 1; - - // The next value is not part of the domain. - if domain.holes.contains(&value) { - continue; - } - - // Here the value is part of the domain, so we yield it. - return Some(value); - } - } - } - } -} - -#[cfg(test)] -mod tests { - use drcp_format::IntAtomic; - use drcp_format::IntComparison; - - use super::*; - - #[test] - fn domain_iterator_unbounded() { - let state = VariableState::default(); - let iterator = state.iter_domain(&fzn_rs::VariableExpr::Identifier(Rc::from("x1"))); - - assert!(iterator.is_none()); - } - - #[test] - fn domain_iterator_unbounded_lower_bound() { - let mut state = VariableState::default(); - - let variable_name = Rc::from("x1"); - let variable = fzn_rs::VariableExpr::Identifier(Rc::clone(&variable_name)); - - let _ = state.apply(&IntAtomic { - name: variable_name, - comparison: IntComparison::LessEqual, - value: 5, - }); - - let iterator = state.iter_domain(&variable); - - assert!(iterator.is_none()); - } - - #[test] - fn domain_iterator_unbounded_upper_bound() { - let mut state = VariableState::default(); - - let variable_name = Rc::from("x1"); - let variable = fzn_rs::VariableExpr::Identifier(Rc::clone(&variable_name)); - - let _ = state.apply(&IntAtomic { - name: variable_name, - comparison: IntComparison::GreaterEqual, - value: 5, - }); - - let iterator = state.iter_domain(&variable); - - assert!(iterator.is_none()); - } - - #[test] - fn domain_iterator_bounded_no_holes() { - let mut state = VariableState::default(); - - let variable_name = Rc::from("x1"); - let variable = fzn_rs::VariableExpr::Identifier(Rc::clone(&variable_name)); - - let _ = state.apply(&IntAtomic { - name: Rc::clone(&variable_name), - comparison: IntComparison::GreaterEqual, - value: 5, - }); - - let _ = state.apply(&IntAtomic { - name: variable_name, - comparison: IntComparison::LessEqual, - value: 10, - }); - - let values = state - .iter_domain(&variable) - .expect("the domain is bounded") - .collect::>(); - - assert_eq!(values, vec![5, 6, 7, 8, 9, 10]); - } - - #[test] - fn domain_iterator_bounded_with_holes() { - let mut state = VariableState::default(); - - let variable_name = Rc::from("x1"); - let variable = fzn_rs::VariableExpr::Identifier(Rc::clone(&variable_name)); - - let _ = state.apply(&IntAtomic { - name: Rc::clone(&variable_name), - comparison: IntComparison::GreaterEqual, - value: 5, - }); - - let _ = state.apply(&IntAtomic { - name: Rc::clone(&variable_name), - comparison: IntComparison::NotEqual, - value: 7, - }); - - let _ = state.apply(&IntAtomic { - name: variable_name, - comparison: IntComparison::LessEqual, - value: 10, - }); - - let values = state - .iter_domain(&variable) - .expect("the domain is bounded") - .collect::>(); - - assert_eq!(values, vec![5, 6, 8, 9, 10]); - } -} diff --git a/pumpkin-crates/checking/Cargo.toml b/pumpkin-crates/checking/Cargo.toml new file mode 100644 index 000000000..ba8c8c976 --- /dev/null +++ b/pumpkin-crates/checking/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "pumpkin-checking" +version = "0.2.2" +repository.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true + +[dependencies] +dyn-clone = "1.0.20" + +[lints] +workspace = true diff --git a/pumpkin-crates/checking/src/atomic_constraint.rs b/pumpkin-crates/checking/src/atomic_constraint.rs new file mode 100644 index 000000000..6be4a9508 --- /dev/null +++ b/pumpkin-crates/checking/src/atomic_constraint.rs @@ -0,0 +1,89 @@ +use std::fmt::Debug; +use std::fmt::Display; +use std::hash::Hash; + +/// Captures the data associated with an atomic constraint. +/// +/// An atomic constraint has the form `[identifier op value]`, where: +/// - `identifier` identifies a variable, +/// - `op` is a [`Comparison`], +/// - and `value` is an integer. +pub trait AtomicConstraint: Sized + Debug { + /// The type of identifier used for variables. + type Identifier: Hash + Eq; + + /// The identifier of this atomic constraint. + fn identifier(&self) -> Self::Identifier; + + /// The [`Comparison`] used for this atomic constraint. + fn comparison(&self) -> Comparison; + + /// The value on the right-hand side of this atomic constraint. + fn value(&self) -> i32; + + /// The strongest atomic constraint that is mutually exclusive with self. + fn negate(&self) -> Self; +} + +/// An arithmetic comparison between two integers. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Comparison { + GreaterEqual, + LessEqual, + Equal, + NotEqual, +} + +impl Display for Comparison { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + Comparison::GreaterEqual => ">=", + Comparison::LessEqual => "<=", + Comparison::Equal => "==", + Comparison::NotEqual => "!=", + }; + + write!(f, "{s}") + } +} + +/// A simple implementation of an [`AtomicConstraint`]. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct TestAtomic { + pub name: &'static str, + pub comparison: Comparison, + pub value: i32, +} + +impl AtomicConstraint for TestAtomic { + type Identifier = &'static str; + + fn identifier(&self) -> Self::Identifier { + self.name + } + + fn comparison(&self) -> Comparison { + self.comparison + } + + fn value(&self) -> i32 { + self.value + } + + fn negate(&self) -> Self { + TestAtomic { + name: self.name, + comparison: match self.comparison { + Comparison::GreaterEqual => Comparison::LessEqual, + Comparison::LessEqual => Comparison::GreaterEqual, + Comparison::Equal => Comparison::NotEqual, + Comparison::NotEqual => Comparison::Equal, + }, + value: match self.comparison { + Comparison::GreaterEqual => self.value - 1, + Comparison::LessEqual => self.value + 1, + Comparison::NotEqual | Comparison::Equal => self.value, + }, + } + } +} diff --git a/pumpkin-crates/checking/src/int_ext.rs b/pumpkin-crates/checking/src/int_ext.rs new file mode 100644 index 000000000..896843d03 --- /dev/null +++ b/pumpkin-crates/checking/src/int_ext.rs @@ -0,0 +1,336 @@ +use std::cmp::Ordering; +use std::fmt::Debug; +use std::iter::Sum; +use std::ops::Add; +use std::ops::AddAssign; +use std::ops::Mul; +use std::ops::Neg; + +/// An [`i32`] or positive/negative infinity. +/// +/// # Notes on arithmetic operations: +/// - The result of the operation `infty + -infty` is undetermined, and if evaluated will cause a +/// panic. +/// - Multiplying [`IntExt::PositiveInf`] or [`IntExt::NegativeInf`] with `IntExt::I32(0)` will +/// yield `IntExt::I32(0)`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum IntExt { + Int(Int), + NegativeInf, + PositiveInf, +} + +impl IntExt { + pub fn floor_div(&self, other: &IntExt) -> Option> { + match (self, other) { + (IntExt::Int(inner), IntExt::Int(inner_other)) => { + let inner = *inner as f64; + let inner_other = *inner_other as f64; + + Some(IntExt::Int((inner / inner_other).floor() as i32)) + } + (IntExt::NegativeInf, IntExt::Int(inner)) => { + if inner.is_positive() { + Some(IntExt::NegativeInf) + } else { + Some(IntExt::PositiveInf) + } + } + (IntExt::PositiveInf, IntExt::Int(inner)) => { + if inner.is_positive() { + Some(IntExt::PositiveInf) + } else { + Some(IntExt::NegativeInf) + } + } + (IntExt::PositiveInf, IntExt::NegativeInf) => None, + (IntExt::PositiveInf, IntExt::PositiveInf) => None, + (IntExt::NegativeInf, IntExt::NegativeInf) => None, + (IntExt::NegativeInf, IntExt::PositiveInf) => None, + (IntExt::Int(_), IntExt::NegativeInf) => Some(IntExt::Int(0)), + (IntExt::Int(_), IntExt::PositiveInf) => Some(IntExt::Int(0)), + } + } + + pub fn ceil_div(&self, other: &IntExt) -> Option> { + match (self, other) { + (IntExt::Int(inner), IntExt::Int(inner_other)) => { + let inner = *inner as f64; + let inner_other = *inner_other as f64; + + Some(IntExt::Int((inner / inner_other).ceil() as i32)) + } + (IntExt::NegativeInf, IntExt::Int(inner)) => { + if inner.is_positive() { + Some(IntExt::NegativeInf) + } else { + Some(IntExt::PositiveInf) + } + } + (IntExt::PositiveInf, IntExt::Int(inner)) => { + if inner.is_positive() { + Some(IntExt::PositiveInf) + } else { + Some(IntExt::NegativeInf) + } + } + (IntExt::PositiveInf, IntExt::NegativeInf) => None, + (IntExt::PositiveInf, IntExt::PositiveInf) => None, + (IntExt::NegativeInf, IntExt::NegativeInf) => None, + (IntExt::NegativeInf, IntExt::PositiveInf) => None, + (IntExt::Int(_), IntExt::NegativeInf) => Some(IntExt::Int(0)), + (IntExt::Int(_), IntExt::PositiveInf) => Some(IntExt::Int(0)), + } + } +} + +impl From for IntExt { + fn from(value: i32) -> Self { + IntExt::Int(value) + } +} + +impl From> for IntExt { + fn from(value: IntExt) -> Self { + match value { + IntExt::Int(int) => IntExt::Int(int.into()), + IntExt::NegativeInf => IntExt::NegativeInf, + IntExt::PositiveInf => IntExt::PositiveInf, + } + } +} + +// TODO: This is not a great pattern, but for now I do not want to touch this. +impl TryInto for IntExt { + type Error = (); + + fn try_into(self) -> Result { + match self { + IntExt::Int(inner) => Ok(inner), + IntExt::NegativeInf | IntExt::PositiveInf => Err(()), + } + } +} + +impl PartialEq for IntExt { + fn eq(&self, other: &Int) -> bool { + match self { + IntExt::Int(v1) => v1 == other, + IntExt::NegativeInf | IntExt::PositiveInf => false, + } + } +} + +impl PartialEq for i32 { + fn eq(&self, other: &IntExt) -> bool { + other.eq(self) + } +} + +impl PartialOrd for i32 { + fn partial_cmp(&self, other: &IntExt) -> Option { + other.neg().partial_cmp(&self.neg()) + } +} + +impl PartialOrd for IntExt { + fn partial_cmp(&self, other: &IntExt) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for IntExt { + fn cmp(&self, other: &Self) -> Ordering { + match self { + IntExt::Int(v1) => match other { + IntExt::Int(v2) => v1.cmp(v2), + IntExt::NegativeInf => Ordering::Greater, + IntExt::PositiveInf => Ordering::Less, + }, + IntExt::NegativeInf => match other { + IntExt::Int(_) => Ordering::Less, + IntExt::PositiveInf => Ordering::Less, + IntExt::NegativeInf => Ordering::Equal, + }, + IntExt::PositiveInf => match other { + IntExt::Int(_) => Ordering::Greater, + IntExt::NegativeInf => Ordering::Greater, + IntExt::PositiveInf => Ordering::Greater, + }, + } + } +} + +impl PartialOrd for IntExt { + fn partial_cmp(&self, other: &i32) -> Option { + match self { + IntExt::Int(v1) => v1.partial_cmp(other), + IntExt::NegativeInf => Some(Ordering::Less), + IntExt::PositiveInf => Some(Ordering::Greater), + } + } +} + +impl PartialOrd for IntExt { + fn partial_cmp(&self, other: &i64) -> Option { + match self { + IntExt::Int(v1) => v1.partial_cmp(other), + IntExt::NegativeInf => Some(Ordering::Less), + IntExt::PositiveInf => Some(Ordering::Greater), + } + } +} + +impl Add for IntExt { + type Output = IntExt; + + fn add(self, rhs: i32) -> Self::Output { + self + IntExt::Int(rhs) + } +} + +impl + Debug> Add for IntExt { + type Output = IntExt; + + fn add(self, rhs: IntExt) -> Self::Output { + match (self, rhs) { + (IntExt::Int(lhs), IntExt::Int(rhs)) => IntExt::Int(lhs + rhs), + + (IntExt::Int(_), Self::NegativeInf) => Self::NegativeInf, + (IntExt::Int(_), Self::PositiveInf) => Self::PositiveInf, + (Self::NegativeInf, IntExt::Int(_)) => Self::NegativeInf, + (Self::PositiveInf, IntExt::Int(_)) => Self::PositiveInf, + + (IntExt::NegativeInf, IntExt::NegativeInf) => IntExt::NegativeInf, + (IntExt::PositiveInf, IntExt::PositiveInf) => IntExt::PositiveInf, + + (lhs @ IntExt::NegativeInf, rhs @ IntExt::PositiveInf) + | (lhs @ IntExt::PositiveInf, rhs @ IntExt::NegativeInf) => { + panic!("the result of {lhs:?} + {rhs:?} is indeterminate") + } + } + } +} + +impl AddAssign for IntExt +where + Int: AddAssign, +{ + fn add_assign(&mut self, rhs: Int) { + match self { + IntExt::Int(value) => { + value.add_assign(rhs); + } + + IntExt::NegativeInf | IntExt::PositiveInf => {} + } + } +} + +impl Mul for IntExt { + type Output = IntExt; + + fn mul(self, rhs: i32) -> Self::Output { + self * IntExt::Int(rhs) + } +} + +impl Mul for IntExt { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (IntExt::Int(lhs), IntExt::Int(rhs)) => IntExt::Int(lhs * rhs), + + // Multiplication with 0 will always yield 0. + (IntExt::Int(0), Self::NegativeInf) + | (IntExt::Int(0), Self::PositiveInf) + | (Self::NegativeInf, IntExt::Int(0)) + | (Self::PositiveInf, IntExt::Int(0)) => IntExt::Int(0), + + (IntExt::Int(value), IntExt::NegativeInf) + | (IntExt::NegativeInf, IntExt::Int(value)) => { + if value >= 0 { + IntExt::NegativeInf + } else { + IntExt::PositiveInf + } + } + + (IntExt::Int(value), IntExt::PositiveInf) + | (IntExt::PositiveInf, IntExt::Int(value)) => { + if value >= 0 { + IntExt::PositiveInf + } else { + IntExt::NegativeInf + } + } + + (IntExt::NegativeInf, IntExt::NegativeInf) + | (IntExt::PositiveInf, IntExt::PositiveInf) => IntExt::PositiveInf, + + (IntExt::NegativeInf, IntExt::PositiveInf) + | (IntExt::PositiveInf, IntExt::NegativeInf) => IntExt::NegativeInf, + } + } +} + +impl Neg for IntExt { + type Output = Self; + + fn neg(self) -> Self::Output { + match self { + IntExt::Int(value) => IntExt::Int(-value), + IntExt::NegativeInf => IntExt::PositiveInf, + IntExt::PositiveInf => Self::NegativeInf, + } + } +} + +impl Sum for IntExt { + fn sum>(iter: I) -> Self { + iter.fold(IntExt::Int(0), |acc, value| acc + value) + } +} + +impl Sum for IntExt { + fn sum>(iter: I) -> Self { + iter.fold(IntExt::Int(0), |acc, value| acc + value) + } +} + +#[cfg(test)] +mod tests { + use IntExt::*; + + use super::*; + + #[test] + fn ordering_of_i32_with_i32_ext() { + assert!(Int(2) < 3); + assert!(Int(-1) < 3); + assert!(Int(-10) < -1); + } + + #[test] + fn ordering_of_i32_ext_with_i32() { + assert!(1 < Int(2)); + assert!(-10 < Int(-1)); + assert!(-11 < Int(-10)); + } + + #[test] + fn test_adding_i32s() { + assert_eq!(Int(3) + Int(4), Int(7)); + } + + #[test] + fn test_adding_negative_inf() { + assert_eq!(Int(3) + NegativeInf, NegativeInf); + } + + #[test] + fn test_adding_positive_inf() { + assert_eq!(Int(3) + PositiveInf, PositiveInf); + } +} diff --git a/pumpkin-crates/checking/src/lib.rs b/pumpkin-crates/checking/src/lib.rs new file mode 100644 index 000000000..44b2963e5 --- /dev/null +++ b/pumpkin-crates/checking/src/lib.rs @@ -0,0 +1,60 @@ +//! Exposes a common interface used to check inferences. +//! +//! The main exposed type is the [`InferenceChecker`], which can be implemented to verify whether +//! inferences are sound w.r.t. an inference rule. + +mod atomic_constraint; +mod int_ext; +mod variable; +mod variable_state; + +use std::fmt::Debug; + +pub use atomic_constraint::*; +use dyn_clone::DynClone; +pub use int_ext::*; +pub use variable::*; +pub use variable_state::*; + +/// An inference checker tests whether the given state is a conflict under the sematics of an +/// inference rule. +pub trait InferenceChecker: Debug + DynClone { + /// Returns `true` if `state` is a conflict, and `false` if not. + /// + /// For the conflict check, all the premises are true in the state and the consequent, if + /// present, if false. + fn check( + &self, + state: VariableState, + premises: &[Atomic], + consequent: Option<&Atomic>, + ) -> bool; +} + +/// Wrapper around `Box>` that implements [`Clone`]. +#[derive(Debug)] +pub struct BoxedChecker(Box>); + +impl Clone for BoxedChecker { + fn clone(&self) -> Self { + BoxedChecker(dyn_clone::clone_box(&*self.0)) + } +} + +impl From>> for BoxedChecker { + fn from(value: Box>) -> Self { + BoxedChecker(value) + } +} + +impl BoxedChecker { + /// See [`InferenceChecker::check`]. + pub fn check( + &self, + variable_state: VariableState, + premises: &[Atomic], + consequent: Option<&Atomic>, + ) -> bool { + self.0.check(variable_state, premises, consequent) + } +} diff --git a/pumpkin-crates/checking/src/variable.rs b/pumpkin-crates/checking/src/variable.rs new file mode 100644 index 000000000..8bcca0f59 --- /dev/null +++ b/pumpkin-crates/checking/src/variable.rs @@ -0,0 +1,133 @@ +use std::fmt::Debug; + +use crate::AtomicConstraint; +use crate::Comparison; +use crate::IntExt; +use crate::TestAtomic; +use crate::VariableState; + +/// A variable in a constraint satisfaction problem. +pub trait CheckerVariable: Debug + Clone { + /// Tests whether the given atomic is a statement over the variable `self`. + fn does_atomic_constrain_self(&self, atomic: &Atomic) -> bool; + + /// Get the atomic constraint `[self <= value]`. + fn atomic_less_than(&self, value: i32) -> Atomic; + + /// Get the atomic constraint `[self <= value]`. + fn atomic_greater_than(&self, value: i32) -> Atomic; + + /// Get the atomic constraint `[self == value]`. + fn atomic_equal(&self, value: i32) -> Atomic; + + /// Get the atomic constraint `[self != value]`. + fn atomic_not_equal(&self, value: i32) -> Atomic; + + /// Get the lower bound of the domain. + fn induced_lower_bound(&self, variable_state: &VariableState) -> IntExt; + + /// Get the upper bound of the domain. + fn induced_upper_bound(&self, variable_state: &VariableState) -> IntExt; + + /// Get the value the variable is fixed to, if the variable is fixed. + fn induced_fixed_value(&self, variable_state: &VariableState) -> Option; + + /// Returns whether the value is in the domain. + fn induced_domain_contains(&self, variable_state: &VariableState, value: i32) -> bool; + + /// Get the holes in the domain. + fn induced_holes<'this, 'state>( + &'this self, + variable_state: &'state VariableState, + ) -> impl Iterator + 'state + where + 'this: 'state; + + /// Iterate the domain of the variable. + /// + /// The order of the values is unspecified. + fn iter_induced_domain<'this, 'state>( + &'this self, + variable_state: &'state VariableState, + ) -> Option + 'state> + where + 'this: 'state; +} + +impl CheckerVariable for &'static str { + fn does_atomic_constrain_self(&self, atomic: &TestAtomic) -> bool { + &atomic.name == self + } + + fn atomic_less_than(&self, value: i32) -> TestAtomic { + TestAtomic { + name: self, + comparison: Comparison::LessEqual, + value, + } + } + + fn atomic_greater_than(&self, value: i32) -> TestAtomic { + TestAtomic { + name: self, + comparison: Comparison::GreaterEqual, + value, + } + } + + fn atomic_equal(&self, value: i32) -> TestAtomic { + TestAtomic { + name: self, + comparison: Comparison::Equal, + value, + } + } + + fn atomic_not_equal(&self, value: i32) -> TestAtomic { + TestAtomic { + name: self, + comparison: Comparison::NotEqual, + value, + } + } + + fn induced_lower_bound(&self, variable_state: &VariableState) -> IntExt { + variable_state.lower_bound(self) + } + + fn induced_upper_bound(&self, variable_state: &VariableState) -> IntExt { + variable_state.upper_bound(self) + } + + fn induced_fixed_value(&self, variable_state: &VariableState) -> Option { + variable_state.fixed_value(self) + } + + fn induced_domain_contains( + &self, + variable_state: &VariableState, + value: i32, + ) -> bool { + variable_state.contains(self, value) + } + + fn induced_holes<'this, 'state>( + &'this self, + variable_state: &'state VariableState, + ) -> impl Iterator + 'state + where + 'this: 'state, + { + variable_state.holes(self) + } + + fn iter_induced_domain<'this, 'state>( + &'this self, + variable_state: &'state VariableState, + ) -> Option + 'state> + where + 'this: 'state, + { + variable_state.iter_domain(self) + } +} diff --git a/pumpkin-crates/checking/src/variable_state.rs b/pumpkin-crates/checking/src/variable_state.rs new file mode 100644 index 000000000..f225da98b --- /dev/null +++ b/pumpkin-crates/checking/src/variable_state.rs @@ -0,0 +1,411 @@ +use std::collections::BTreeSet; +use std::collections::HashMap; +use std::hash::Hash; + +use crate::AtomicConstraint; +use crate::Comparison; +#[cfg(doc)] +use crate::InferenceChecker; +use crate::IntExt; + +/// The domains of all variables in the problem. +/// +/// Domains are initially unbounded. This is why bounds are represented as [`IntExt`]. +/// +/// Domains can be reduced through [`VariableState::apply`]. By default, the domain of every +/// variable is infinite. +#[derive(Clone, Debug)] +pub struct VariableState { + domains: HashMap, +} + +impl Default for VariableState { + fn default() -> Self { + Self { + domains: Default::default(), + } + } +} + +impl VariableState +where + Ident: Hash + Eq, + Atomic: AtomicConstraint, +{ + /// Create a variable state that applies all the premises and, if present, the negation of the + /// consequent. + /// + /// If `premises /\ !consequent` contain mutually exclusive atomic constraints (e.g., `[x >= + /// 5]` and `[x <= 2]`) then `None` is returned. + /// + /// An [`InferenceChecker`] will receive a [`VariableState`] that conforms to this description. + pub fn prepare_for_conflict_check( + premises: impl IntoIterator, + consequent: Option, + ) -> Option { + let mut variable_state = VariableState::default(); + + let negated_consequent = consequent.as_ref().map(AtomicConstraint::negate); + + // Apply all the premises and the negation of the consequent to the state. + if !premises + .into_iter() + .chain(negated_consequent) + .all(|premise| variable_state.apply(&premise)) + { + return None; + } + + Some(variable_state) + } + + /// Get the lower bound of a variable. + pub fn lower_bound(&self, identifier: &Ident) -> IntExt { + self.domains + .get(identifier) + .map(|domain| domain.lower_bound) + .unwrap_or(IntExt::NegativeInf) + } + + /// Get the upper bound of a variable. + pub fn upper_bound(&self, identifier: &Ident) -> IntExt { + self.domains + .get(identifier) + .map(|domain| domain.upper_bound) + .unwrap_or(IntExt::PositiveInf) + } + + pub fn contains(&self, identifier: &Ident, value: i32) -> bool { + self.domains + .get(identifier) + .map(|domain| { + value >= domain.lower_bound + && value <= domain.upper_bound + && !domain.holes.contains(&value) + }) + .unwrap_or(false) + } + + /// Get the holes within the lower and upper bound of the variable expression. + pub fn holes<'a>(&'a self, identifier: &Ident) -> impl Iterator + 'a + where + Ident: 'a, + { + self.domains + .get(identifier) + .map(|domain| domain.holes.iter().copied()) + .into_iter() + .flatten() + } + + /// Get the fixed value of this variable, if it is fixed. + pub fn fixed_value(&self, identifier: &Ident) -> Option { + let domain = self.domains.get(identifier)?; + + if domain.lower_bound == domain.upper_bound { + let IntExt::Int(value) = domain.lower_bound else { + panic!( + "lower can only equal upper if they are integers, otherwise the sign of infinity makes them different" + ); + }; + + Some(value) + } else { + None + } + } + + /// Obtain an iterator over the domain of the variable. + /// + /// If the domain is unbounded, then `None` is returned. + pub fn iter_domain<'a>(&'a self, identifier: &Ident) -> Option> + where + Ident: 'a, + { + let domain = self.domains.get(identifier)?; + + let IntExt::Int(lower_bound) = domain.lower_bound else { + // If there is no lower bound, then the domain is unbounded. + return None; + }; + + // Ensure there is also an upper bound. + if !matches!(domain.upper_bound, IntExt::Int(_)) { + return None; + } + + Some(DomainIterator { + domain, + next_value: lower_bound, + }) + } + + /// Apply the given `Atomic` to the state. + /// + /// Returns true if the state remains consistent, or false if the atomic cannot be true in + /// conjunction with previously applied atomics. + pub fn apply(&mut self, atomic: &Atomic) -> bool { + let identifier = atomic.identifier(); + let domain = self.domains.entry(identifier).or_insert(Domain::new()); + + match atomic.comparison() { + Comparison::GreaterEqual => { + domain.tighten_lower_bound(atomic.value()); + } + + Comparison::LessEqual => { + domain.tighten_upper_bound(atomic.value()); + } + + Comparison::Equal => { + domain.tighten_lower_bound(atomic.value()); + domain.tighten_upper_bound(atomic.value()); + } + + Comparison::NotEqual => { + if domain.lower_bound == atomic.value() { + domain.tighten_lower_bound(atomic.value() + 1); + } + + if domain.upper_bound == atomic.value() { + domain.tighten_upper_bound(atomic.value() - 1); + } + + if domain.lower_bound < atomic.value() && domain.upper_bound > atomic.value() { + let _ = domain.holes.insert(atomic.value()); + } + } + } + + domain.is_consistent() + } + + /// Is the given atomic true in the current state. + pub fn is_true(&self, atomic: &Atomic) -> bool { + let Some(domain) = self.domains.get(&atomic.identifier()) else { + return false; + }; + + match atomic.comparison() { + Comparison::GreaterEqual => domain.lower_bound >= atomic.value(), + + Comparison::LessEqual => domain.upper_bound <= atomic.value(), + + Comparison::Equal => { + domain.lower_bound >= atomic.value() && domain.upper_bound <= atomic.value() + } + + Comparison::NotEqual => { + if domain.lower_bound >= atomic.value() { + return true; + } + + if domain.upper_bound <= atomic.value() { + return true; + } + + if domain.holes.contains(&atomic.value()) { + return true; + } + + false + } + } + } +} + +/// A domain inside the variable state. +#[derive(Clone, Debug)] +struct Domain { + lower_bound: IntExt, + upper_bound: IntExt, + holes: BTreeSet, +} + +impl Domain { + fn new() -> Domain { + Domain { + lower_bound: IntExt::NegativeInf, + upper_bound: IntExt::PositiveInf, + holes: BTreeSet::default(), + } + } + + /// Tighten the lower bound and remove any holes that are no longer strictly larger than the + /// lower bound. + fn tighten_lower_bound(&mut self, bound: i32) { + if self.lower_bound >= bound { + return; + } + + self.lower_bound = IntExt::Int(bound); + self.holes = self.holes.split_off(&bound); + + // Take care of the condition where the new bound is already a hole in the domain. + if self.holes.contains(&bound) { + self.tighten_lower_bound(bound + 1); + } + } + + /// Tighten the upper bound and remove any holes that are no longer strictly smaller than the + /// upper bound. + fn tighten_upper_bound(&mut self, bound: i32) { + if self.upper_bound <= bound { + return; + } + + self.upper_bound = IntExt::Int(bound); + + // Note the '+ 1' to keep the elements <= the upper bound instead of < + // the upper bound. + let _ = self.holes.split_off(&(bound + 1)); + + // Take care of the condition where the new bound is already a hole in the domain. + if self.holes.contains(&bound) { + self.tighten_upper_bound(bound - 1); + } + } + + /// Returns true if the domain contains at least one value. + fn is_consistent(&self) -> bool { + // No need to check holes, as the invariant of `Domain` specifies the bounds are as tight + // as possible, taking holes into account. + + self.lower_bound <= self.upper_bound + } +} + +/// An iterator over the values in the domain of a variable. +#[derive(Debug)] +pub struct DomainIterator<'a> { + domain: &'a Domain, + next_value: i32, +} + +impl Iterator for DomainIterator<'_> { + type Item = i32; + + fn next(&mut self) -> Option { + let DomainIterator { domain, next_value } = self; + + let IntExt::Int(upper_bound) = domain.upper_bound else { + panic!("Only finite domains can be iterated.") + }; + + loop { + // We have completed iterating the domain. + if *next_value > upper_bound { + return None; + } + + let value = *next_value; + *next_value += 1; + + // The next value is not part of the domain. + if domain.holes.contains(&value) { + continue; + } + + // Here the value is part of the domain, so we yield it. + return Some(value); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestAtomic; + + #[test] + fn domain_iterator_unbounded() { + let state = VariableState::::default(); + let iterator = state.iter_domain(&"x1"); + + assert!(iterator.is_none()); + } + + #[test] + fn domain_iterator_unbounded_lower_bound() { + let mut state = VariableState::default(); + + let _ = state.apply(&TestAtomic { + name: "x1", + comparison: Comparison::LessEqual, + value: 5, + }); + + let iterator = state.iter_domain(&"x1"); + + assert!(iterator.is_none()); + } + + #[test] + fn domain_iterator_unbounded_upper_bound() { + let mut state = VariableState::default(); + + let _ = state.apply(&TestAtomic { + name: "x1", + comparison: Comparison::GreaterEqual, + value: 5, + }); + + let iterator = state.iter_domain(&"x1"); + + assert!(iterator.is_none()); + } + + #[test] + fn domain_iterator_bounded_no_holes() { + let mut state = VariableState::default(); + + let _ = state.apply(&TestAtomic { + name: "x1", + comparison: Comparison::GreaterEqual, + value: 5, + }); + + let _ = state.apply(&TestAtomic { + name: "x1", + comparison: Comparison::LessEqual, + value: 10, + }); + + let values = state + .iter_domain(&"x1") + .expect("the domain is bounded") + .collect::>(); + + assert_eq!(values, vec![5, 6, 7, 8, 9, 10]); + } + + #[test] + fn domain_iterator_bounded_with_holes() { + let mut state = VariableState::default(); + + let _ = state.apply(&TestAtomic { + name: "x1", + comparison: Comparison::GreaterEqual, + value: 5, + }); + + let _ = state.apply(&TestAtomic { + name: "x1", + comparison: Comparison::NotEqual, + value: 7, + }); + + let _ = state.apply(&TestAtomic { + name: "x1", + comparison: Comparison::LessEqual, + value: 10, + }); + + let values = state + .iter_domain(&"x1") + .expect("the domain is bounded") + .collect::>(); + + assert_eq!(values, vec![5, 6, 8, 9, 10]); + } +} diff --git a/pumpkin-crates/core/Cargo.toml b/pumpkin-crates/core/Cargo.toml index 94d304d59..1dc43e6b0 100644 --- a/pumpkin-crates/core/Cargo.toml +++ b/pumpkin-crates/core/Cargo.toml @@ -11,6 +11,7 @@ description = "The core of the Pumpkin constraint programming solver." workspace = true [dependencies] +pumpkin-checking = { version = "0.2.2", path = "../checking" } thiserror = "2.0.12" log = "0.4.17" bitfield = "0.14.0" @@ -20,7 +21,7 @@ rand = { version = "0.8.5", features = [ "small_rng", "alloc" ] } once_cell = "1.19.0" downcast-rs = "1.2.1" drcp-format = { version = "0.3.0", path = "../../drcp-format" } -convert_case = "0.6.0" +convert_case = "0.8.0" itertools = "0.13.0" bitfield-struct = "0.9.2" num = "0.4.3" @@ -30,6 +31,7 @@ indexmap = "2.10.0" dyn-clone = "1.0.20" flate2 = { version = "1.1.2" } + [dev-dependencies] pumpkin-constraints = { version = "0.2.2", path = "../constraints", features=["clap"] } @@ -41,5 +43,6 @@ getrandom = { version = "0.2", features = ["js"] } wasm-bindgen-test = "0.3" [features] +check-propagations = [] debug-checks = [] clap = ["dep:clap"] diff --git a/pumpkin-crates/core/src/engine/conflict_analysis/resolvers/resolution_resolver.rs b/pumpkin-crates/core/src/engine/conflict_analysis/resolvers/resolution_resolver.rs index 3c9a77e57..f6ce2dd66 100644 --- a/pumpkin-crates/core/src/engine/conflict_analysis/resolvers/resolution_resolver.rs +++ b/pumpkin-crates/core/src/engine/conflict_analysis/resolvers/resolution_resolver.rs @@ -19,6 +19,7 @@ use crate::proof::InferenceCode; use crate::proof::RootExplanationContext; use crate::proof::explain_root_assignment; use crate::propagation::CurrentNogood; +use crate::propagators::nogoods::NogoodChecker; use crate::propagators::nogoods::NogoodPropagator; use crate::pumpkin_assert_advanced; use crate::pumpkin_assert_moderate; @@ -125,6 +126,13 @@ impl ConflictResolver for ResolutionResolver { .average_learned_nogood_length .add_term(learned_nogood.predicates.len() as u64); + context.state.add_inference_checker( + inference_code.clone(), + Box::new(NogoodChecker { + nogood: learned_nogood.predicates.clone().into(), + }), + ); + self.add_learned_nogood(context, learned_nogood, inference_code); } } @@ -239,6 +247,9 @@ impl ResolutionResolver { learned_nogood: LearnedNogood, inference_code: InferenceCode, ) { + #[cfg(feature = "check-propagations")] + let trail_len_before_nogood = context.state.trail_len(); + let (nogood_propagator, mut propagation_context) = context .state .get_propagator_mut_with_context(self.nogood_propagator_handle); @@ -251,6 +262,9 @@ impl ResolutionResolver { &mut propagation_context, context.counters, ); + + #[cfg(feature = "check-propagations")] + context.state.check_propagations(trail_len_before_nogood); } /// Clears all data structures to prepare for the new conflict analysis. diff --git a/pumpkin-crates/core/src/engine/constraint_satisfaction_solver.rs b/pumpkin-crates/core/src/engine/constraint_satisfaction_solver.rs index 828ed1de7..27f4bc5ce 100644 --- a/pumpkin-crates/core/src/engine/constraint_satisfaction_solver.rs +++ b/pumpkin-crates/core/src/engine/constraint_satisfaction_solver.rs @@ -50,6 +50,7 @@ use crate::proof::explain_root_assignment; use crate::proof::finalize_proof; use crate::propagation::PropagatorConstructor; use crate::propagation::store::PropagatorHandle; +use crate::propagators::nogoods::NogoodChecker; use crate::propagators::nogoods::NogoodPropagator; use crate::propagators::nogoods::NogoodPropagatorConstructor; use crate::pumpkin_assert_eq_simple; @@ -948,6 +949,13 @@ impl ConstraintSatisfactionSolver { pumpkin_assert_eq_simple!(self.get_checkpoint(), 0); let num_trail_entries = self.state.trail_len(); + self.state.add_inference_checker( + inference_code.clone(), + Box::new(NogoodChecker { + nogood: nogood.clone().into(), + }), + ); + let (nogood_propagator, mut context) = self .state .get_propagator_mut_with_context(self.nogood_propagator_handle); @@ -1241,13 +1249,15 @@ mod tests { fn create_instance1() -> (ConstraintSatisfactionSolver, Vec) { let mut solver = ConstraintSatisfactionSolver::default(); - let constraint_tag = solver.new_constraint_tag(); + let c1 = solver.new_constraint_tag(); + let c2 = solver.new_constraint_tag(); + let c3 = solver.new_constraint_tag(); let lit1 = solver.create_new_literal(None).get_true_predicate(); let lit2 = solver.create_new_literal(None).get_true_predicate(); - let _ = solver.add_clause([lit1, lit2], constraint_tag); - let _ = solver.add_clause([lit1, !lit2], constraint_tag); - let _ = solver.add_clause([!lit1, lit2], constraint_tag); + let _ = solver.add_clause([lit1, lit2], c1); + let _ = solver.add_clause([lit1, !lit2], c2); + let _ = solver.add_clause([!lit1, lit2], c3); (solver, vec![lit1, lit2]) } @@ -1313,13 +1323,14 @@ mod tests { } fn create_instance2() -> (ConstraintSatisfactionSolver, Vec) { let mut solver = ConstraintSatisfactionSolver::default(); - let constraint_tag = solver.new_constraint_tag(); + let c1 = solver.new_constraint_tag(); + let c2 = solver.new_constraint_tag(); let lit1 = solver.create_new_literal(None).get_true_predicate(); let lit2 = solver.create_new_literal(None).get_true_predicate(); let lit3 = solver.create_new_literal(None).get_true_predicate(); - let _ = solver.add_clause([lit1, lit2, lit3], constraint_tag); - let _ = solver.add_clause([lit1, !lit2, lit3], constraint_tag); + let _ = solver.add_clause([lit1, lit2, lit3], c1); + let _ = solver.add_clause([lit1, !lit2, lit3], c2); (solver, vec![lit1, lit2, lit3]) } diff --git a/pumpkin-crates/core/src/engine/cp/test_solver.rs b/pumpkin-crates/core/src/engine/cp/test_solver.rs index f7bfaeee1..50853a8eb 100644 --- a/pumpkin-crates/core/src/engine/cp/test_solver.rs +++ b/pumpkin-crates/core/src/engine/cp/test_solver.rs @@ -2,6 +2,8 @@ //! setting up specific scenarios under which to test the various operations of a propagator. use std::fmt::Debug; +use pumpkin_checking::InferenceChecker; + use super::PropagatorQueue; use crate::containers::KeyGenerator; use crate::engine::EmptyDomain; @@ -14,6 +16,7 @@ use crate::options::LearningOptions; use crate::predicate; use crate::predicates::PropositionalConjunction; use crate::proof::ConstraintTag; +use crate::proof::InferenceCode; use crate::propagation::Domains; use crate::propagation::EnqueueDecision; use crate::propagation::ExplanationContext; @@ -53,6 +56,25 @@ impl Default for TestSolver { #[deprecated = "Will be replaced by the state API"] impl TestSolver { + pub fn accept_inferences_by(&mut self, inference_code: InferenceCode) { + #[derive(Debug, Clone, Copy)] + struct Checker; + + impl InferenceChecker for Checker { + fn check( + &self, + _: pumpkin_checking::VariableState, + _: &[Predicate], + _: Option<&Predicate>, + ) -> bool { + true + } + } + + self.state + .add_inference_checker(inference_code, Box::new(Checker)); + } + pub fn new_variable(&mut self, lb: i32, ub: i32) -> DomainId { self.state.new_interval_variable(lb, ub, None) } diff --git a/pumpkin-crates/core/src/engine/predicates/predicate.rs b/pumpkin-crates/core/src/engine/predicates/predicate.rs index 5f3acec88..2e4826d86 100644 --- a/pumpkin-crates/core/src/engine/predicates/predicate.rs +++ b/pumpkin-crates/core/src/engine/predicates/predicate.rs @@ -1,3 +1,5 @@ +use pumpkin_checking::AtomicConstraint; + use crate::engine::Assignments; use crate::engine::variables::DomainId; use crate::predicate; @@ -231,6 +233,31 @@ impl std::fmt::Debug for Predicate { } } +impl AtomicConstraint for Predicate { + type Identifier = DomainId; + + fn identifier(&self) -> Self::Identifier { + self.get_domain() + } + + fn comparison(&self) -> pumpkin_checking::Comparison { + match self.get_predicate_type() { + PredicateType::LowerBound => pumpkin_checking::Comparison::GreaterEqual, + PredicateType::UpperBound => pumpkin_checking::Comparison::LessEqual, + PredicateType::NotEqual => pumpkin_checking::Comparison::NotEqual, + PredicateType::Equal => pumpkin_checking::Comparison::Equal, + } + } + + fn value(&self) -> i32 { + self.get_right_hand_side() + } + + fn negate(&self) -> Self { + !*self + } +} + #[cfg(test)] mod test { use super::Predicate; diff --git a/pumpkin-crates/core/src/engine/state.rs b/pumpkin-crates/core/src/engine/state.rs index 2e5b6b936..386bedd57 100644 --- a/pumpkin-crates/core/src/engine/state.rs +++ b/pumpkin-crates/core/src/engine/state.rs @@ -1,6 +1,12 @@ use std::sync::Arc; +use pumpkin_checking::BoxedChecker; +use pumpkin_checking::InferenceChecker; +#[cfg(feature = "check-propagations")] +use pumpkin_checking::VariableState; + use crate::basic_types::PropagatorConflict; +use crate::containers::HashMap; use crate::containers::KeyGenerator; use crate::create_statistics_struct; use crate::engine::Assignments; @@ -23,6 +29,8 @@ use crate::proof::ProofLog; use crate::propagation::CurrentNogood; use crate::propagation::Domains; use crate::propagation::ExplanationContext; +#[cfg(feature = "check-propagations")] +use crate::propagation::InferenceCheckers; use crate::propagation::PropagationContext; use crate::propagation::Propagator; use crate::propagation::PropagatorConstructor; @@ -69,6 +77,9 @@ pub struct State { pub(crate) constraint_tags: KeyGenerator, statistics: StateStatistics, + + /// Inference checkers to run in the propagation loop. + checkers: HashMap>>, } create_statistics_struct!(StateStatistics { @@ -153,6 +164,7 @@ impl Default for State { notification_engine: NotificationEngine::default(), statistics: StateStatistics::default(), constraint_tags: KeyGenerator::default(), + checkers: HashMap::default(), }; // As a convention, the assignments contain a dummy domain_id=0, which represents a 0-1 // variable that is assigned to one. We use it to represent predicates that are @@ -347,8 +359,12 @@ impl State { Constructor: PropagatorConstructor, Constructor::PropagatorImpl: 'static, { + #[cfg(feature = "check-propagations")] + constructor.add_inference_checkers(InferenceCheckers::new(self)); + let original_handle: PropagatorHandle = self.propagators.new_propagator().key(); + let constructor_context = PropagatorConstructorContext::new(original_handle.propagator_id(), self); let propagator = constructor.create(constructor_context); @@ -370,6 +386,22 @@ impl State { handle } + + /// Add an inference checker to the state. + /// + /// The inference checker will be used to check propagations performed during + /// [`Self::propagate_to_fixed_point`], if the `check-propagations` feature is enabled. + /// + /// Multiple inference checkers may be added for the same inference code. In that case, if + /// any checker accepts the inference, the inference is accepted. + pub fn add_inference_checker( + &mut self, + inference_code: InferenceCode, + checker: Box>, + ) { + let checkers = self.checkers.entry(inference_code).or_default(); + checkers.push(BoxedChecker::from(checker)); + } } /// Operations for retrieving propagators. @@ -551,6 +583,9 @@ impl State { propagator.propagate(context) }; + #[cfg(feature = "check-propagations")] + self.check_propagations(num_trail_entries_before); + match propagation_status { Ok(_) => { // Notify other propagators of the propagations and continue. @@ -575,6 +610,9 @@ impl State { ); } Err(conflict) => { + #[cfg(feature = "check-propagations")] + self.check_conflict(&conflict); + self.statistics.num_conflicts += 1; if let Conflict::Propagator(inner) = &conflict { pumpkin_assert_advanced!(DebugHelper::debug_reported_failure( @@ -593,6 +631,62 @@ impl State { Ok(()) } + /// Check the inference that triggered the given conflict. + /// + /// Does nothing when the conflict is an empty domain. + /// + /// Panics when the inference checker rejects the conflict. + #[cfg(feature = "check-propagations")] + fn check_conflict(&mut self, conflict: &Conflict) { + if let Conflict::Propagator(propagator_conflict) = conflict { + self.run_checker( + propagator_conflict.conjunction.clone(), + None, + &propagator_conflict.inference_code, + ); + } + } + + /// For every item on the trail starting at index `first_propagation_index`, run the + /// inference checker for it. + /// + /// This method should be called after every propagator invocation, so all elements on the + /// trail starting at `first_propagation_index` should be propagations. Otherwise this function + /// will panic. + /// + /// If the checker rejects the inference, this method panics. + #[cfg(feature = "check-propagations")] + pub(crate) fn check_propagations(&mut self, first_propagation_index: usize) { + let mut reason_buffer = vec![]; + + for trail_index in first_propagation_index..self.assignments.num_trail_entries() { + let entry = self.assignments.get_trail_entry(trail_index); + + let (reason_ref, inference_code) = entry + .reason + .expect("propagations should only be checked after propagations"); + + reason_buffer.clear(); + let reason_exists = self.reason_store.get_or_compute( + reason_ref, + ExplanationContext::without_working_nogood( + &self.assignments, + trail_index, + &mut self.notification_engine, + ), + &mut self.propagators, + &mut reason_buffer, + ); + assert!(reason_exists, "all propagations have reasons"); + + self.run_checker( + std::mem::take(&mut reason_buffer), + Some(entry.predicate), + &inference_code, + ); + } + } + /// Performs fixed-point propagation using the propagators defined in the [`State`]. /// /// The posted [`Predicate`]s (using [`State::post`]) and added propagators (using @@ -634,6 +728,49 @@ impl State { } } +#[cfg(feature = "check-propagations")] +impl State { + /// Run the checker for the given inference code on the given inference. + fn run_checker( + &self, + premises: impl IntoIterator, + consequent: Option, + inference_code: &InferenceCode, + ) { + let premises: Vec<_> = premises.into_iter().collect(); + + let checkers = self + .checkers + .get(inference_code) + .map(|vec| vec.as_slice()) + .unwrap_or(&[]); + + assert!( + !checkers.is_empty(), + "missing checker for inference code {inference_code:?}" + ); + + let any_checker_accepts_inference = checkers.iter().any(|checker| { + // Construct the variable state for the conflict check. + let variable_state = + VariableState::prepare_for_conflict_check(premises.clone(), consequent) + .unwrap_or_else(|| { + panic!("inconsistent atomics in inference by {inference_code:?}") + }); + + checker.check(variable_state, &premises, consequent.as_ref()) + }); + + assert!( + any_checker_accepts_inference, + "checker for inference code {:?} fails on inference {:?} -> {:?}", + inference_code, + premises.into_iter().collect::>(), + consequent, + ); + } +} + impl State { /// This is a temporary accessor to help refactoring. pub(crate) fn get_solution_reference(&self) -> SolutionReference<'_> { diff --git a/pumpkin-crates/core/src/engine/variables/affine_view.rs b/pumpkin-crates/core/src/engine/variables/affine_view.rs index aa60c4836..8565880d2 100644 --- a/pumpkin-crates/core/src/engine/variables/affine_view.rs +++ b/pumpkin-crates/core/src/engine/variables/affine_view.rs @@ -1,6 +1,8 @@ use std::cmp::Ordering; use enumset::EnumSet; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::IntExt; use super::TransformableVariable; use crate::engine::Assignments; @@ -48,6 +50,123 @@ impl AffineView { } } +impl CheckerVariable for AffineView { + fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool { + self.inner.does_atomic_constrain_self(atomic) + } + + fn atomic_less_than(&self, value: i32) -> Predicate { + use crate::predicate; + + predicate![self <= value] + } + + fn atomic_greater_than(&self, value: i32) -> Predicate { + use crate::predicate; + + predicate![self >= value] + } + + fn atomic_equal(&self, value: i32) -> Predicate { + use crate::predicate; + + predicate![self == value] + } + + fn atomic_not_equal(&self, value: i32) -> Predicate { + use crate::predicate; + + predicate![self != value] + } + + fn induced_lower_bound( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> IntExt { + if self.scale.is_positive() { + match self.inner.induced_lower_bound(variable_state) { + IntExt::Int(value) => IntExt::Int(self.map(value)), + bound => bound, + } + } else { + match self.inner.induced_upper_bound(variable_state) { + IntExt::Int(value) => IntExt::Int(self.map(value)), + IntExt::NegativeInf => IntExt::PositiveInf, + IntExt::PositiveInf => IntExt::NegativeInf, + } + } + } + + fn induced_upper_bound( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> IntExt { + if self.scale.is_positive() { + match self.inner.induced_upper_bound(variable_state) { + IntExt::Int(value) => IntExt::Int(self.map(value)), + bound => bound, + } + } else { + match self.inner.induced_lower_bound(variable_state) { + IntExt::Int(value) => IntExt::Int(self.map(value)), + IntExt::NegativeInf => IntExt::PositiveInf, + IntExt::PositiveInf => IntExt::NegativeInf, + } + } + } + + fn induced_fixed_value( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> Option { + self.inner + .induced_fixed_value(variable_state) + .map(|value| self.map(value)) + } + + fn induced_holes<'this, 'state>( + &'this self, + variable_state: &'state pumpkin_checking::VariableState, + ) -> impl Iterator + 'state + where + 'this: 'state, + { + if self.scale == 1 || self.scale == -1 { + return self + .inner + .induced_holes(variable_state) + .map(|value| self.map(value)); + } + + todo!("how to iterate holes of a scaled domain"); + } + + fn iter_induced_domain<'this, 'state>( + &'this self, + variable_state: &'state pumpkin_checking::VariableState, + ) -> Option + 'state> + where + 'this: 'state, + { + self.inner + .iter_induced_domain(variable_state) + .map(|iter| iter.map(|value| self.map(value))) + } + + fn induced_domain_contains( + &self, + variable_state: &pumpkin_checking::VariableState, + value: i32, + ) -> bool { + if (value - self.offset) % self.scale == 0 { + let inverted = self.invert(value, Rounding::Up); + self.inner.induced_domain_contains(variable_state, inverted) + } else { + false + } + } +} + impl IntegerVariable for AffineView where View: IntegerVariable, diff --git a/pumpkin-crates/core/src/engine/variables/domain_id.rs b/pumpkin-crates/core/src/engine/variables/domain_id.rs index 53014e6d2..b64eb55f1 100644 --- a/pumpkin-crates/core/src/engine/variables/domain_id.rs +++ b/pumpkin-crates/core/src/engine/variables/domain_id.rs @@ -1,4 +1,5 @@ use enumset::EnumSet; +use pumpkin_checking::CheckerVariable; use super::TransformableVariable; use crate::containers::StorageKey; @@ -8,6 +9,7 @@ use crate::engine::notifications::OpaqueDomainEvent; use crate::engine::notifications::Watchers; use crate::engine::variables::AffineView; use crate::engine::variables::IntegerVariable; +use crate::predicates::Predicate; use crate::pumpkin_assert_simple; /// A structure which represents the most basic [`IntegerVariable`]; it is simply the id which links @@ -28,6 +30,85 @@ impl DomainId { } } +impl CheckerVariable for DomainId { + fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool { + atomic.get_domain() == *self + } + + fn atomic_less_than(&self, value: i32) -> Predicate { + use crate::predicate; + + predicate![self <= value] + } + + fn atomic_greater_than(&self, value: i32) -> Predicate { + use crate::predicate; + + predicate![self >= value] + } + + fn atomic_equal(&self, value: i32) -> Predicate { + use crate::predicate; + + predicate![self == value] + } + + fn atomic_not_equal(&self, value: i32) -> Predicate { + use crate::predicate; + + predicate![self != value] + } + + fn induced_lower_bound( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> pumpkin_checking::IntExt { + variable_state.lower_bound(self) + } + + fn induced_upper_bound( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> pumpkin_checking::IntExt { + variable_state.upper_bound(self) + } + + fn induced_fixed_value( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> Option { + variable_state.fixed_value(self) + } + + fn induced_holes<'this, 'state>( + &'this self, + variable_state: &'state pumpkin_checking::VariableState, + ) -> impl Iterator + 'state + where + 'this: 'state, + { + variable_state.holes(self) + } + + fn iter_induced_domain<'this, 'state>( + &'this self, + variable_state: &'state pumpkin_checking::VariableState, + ) -> Option + 'state> + where + 'this: 'state, + { + variable_state.iter_domain(self) + } + + fn induced_domain_contains( + &self, + variable_state: &pumpkin_checking::VariableState, + value: i32, + ) -> bool { + variable_state.contains(self, value) + } +} + impl IntegerVariable for DomainId { type AffineView = AffineView; diff --git a/pumpkin-crates/core/src/engine/variables/integer_variable.rs b/pumpkin-crates/core/src/engine/variables/integer_variable.rs index 61f4791d7..280fca407 100644 --- a/pumpkin-crates/core/src/engine/variables/integer_variable.rs +++ b/pumpkin-crates/core/src/engine/variables/integer_variable.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; use enumset::EnumSet; +use pumpkin_checking::CheckerVariable; use super::TransformableVariable; use crate::engine::Assignments; @@ -8,11 +9,16 @@ use crate::engine::notifications::DomainEvent; use crate::engine::notifications::OpaqueDomainEvent; use crate::engine::notifications::Watchers; use crate::engine::predicates::predicate_constructor::PredicateConstructor; +use crate::predicates::Predicate; /// A trait specifying the required behaviour of an integer variable such as retrieving a /// lower-bound ([`IntegerVariable::lower_bound`]). pub trait IntegerVariable: - Clone + PredicateConstructor + TransformableVariable + Debug + Clone + + PredicateConstructor + + TransformableVariable + + Debug + + CheckerVariable { type AffineView: IntegerVariable; diff --git a/pumpkin-crates/core/src/engine/variables/literal.rs b/pumpkin-crates/core/src/engine/variables/literal.rs index 6ecee24b9..3dd8263f9 100644 --- a/pumpkin-crates/core/src/engine/variables/literal.rs +++ b/pumpkin-crates/core/src/engine/variables/literal.rs @@ -1,6 +1,9 @@ use std::ops::Not; use enumset::EnumSet; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::IntExt; +use pumpkin_checking::VariableState; use super::DomainId; use super::IntegerVariable; @@ -51,6 +54,56 @@ impl Not for Literal { } } +/// Forwards a function implementation to the field on self. +macro_rules! forward { + ( + $field:ident, + fn $(<$($lt:lifetime),+>)? $name:ident( + & $($lt_self:lifetime)? self, + $($param_name:ident : $param_type:ty),* + ) -> $return_type:ty + $(where $($where_clause:tt)*)? + ) => { + fn $name$(<$($lt),+>)?( + & $($lt_self)? self, + $($param_name: $param_type),* + ) -> $return_type $(where $($where_clause)*)? { + self.$field.$name($($param_name),*) + } + } +} + +impl CheckerVariable for Literal { + forward!(integer_variable, fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool); + forward!(integer_variable, fn atomic_less_than(&self, value: i32) -> Predicate); + forward!(integer_variable, fn atomic_greater_than(&self, value: i32) -> Predicate); + forward!(integer_variable, fn atomic_not_equal(&self, value: i32) -> Predicate); + forward!(integer_variable, fn atomic_equal(&self, value: i32) -> Predicate); + + forward!(integer_variable, fn induced_lower_bound(&self, variable_state: &VariableState) -> IntExt); + forward!(integer_variable, fn induced_upper_bound(&self, variable_state: &VariableState) -> IntExt); + forward!(integer_variable, fn induced_fixed_value(&self, variable_state: &VariableState) -> Option); + forward!(integer_variable, fn induced_domain_contains(&self, variable_state: &VariableState, value: i32) -> bool); + forward!( + integer_variable, + fn <'this, 'state> induced_holes( + &'this self, + variable_state: &'state VariableState + ) -> impl Iterator + 'state + where + 'this: 'state, + ); + forward!( + integer_variable, + fn <'this, 'state> iter_induced_domain( + &'this self, + variable_state: &'state VariableState + ) -> Option + 'state> + where + 'this: 'state, + ); +} + impl IntegerVariable for Literal { type AffineView = AffineView; diff --git a/pumpkin-crates/core/src/lib.rs b/pumpkin-crates/core/src/lib.rs index 244d6755a..36138d519 100644 --- a/pumpkin-crates/core/src/lib.rs +++ b/pumpkin-crates/core/src/lib.rs @@ -16,7 +16,7 @@ pub mod constraints; pub mod optimisation; pub mod proof; pub mod propagation; -pub(crate) mod propagators; +pub mod propagators; pub mod statistics; pub use convert_case; diff --git a/pumpkin-crates/core/src/proof/mod.rs b/pumpkin-crates/core/src/proof/mod.rs index 284ac9617..6fe243cf8 100644 --- a/pumpkin-crates/core/src/proof/mod.rs +++ b/pumpkin-crates/core/src/proof/mod.rs @@ -25,7 +25,6 @@ use proof_atomics::ProofAtomics; use crate::Solver; use crate::containers::HashMap; use crate::containers::KeyGenerator; -use crate::containers::StorageKey; use crate::engine::variable_names::VariableNames; use crate::predicates::Predicate; use crate::variables::Literal; @@ -84,6 +83,8 @@ impl ProofLog { propagated: Option, variable_names: &VariableNames, ) -> std::io::Result { + let inference_tag = constraint_tags.next_key(); + let Some(ProofImpl::CpProof { writer, propagation_order_hint: Some(propagation_sequence), @@ -91,11 +92,9 @@ impl ProofLog { .. }) = self.internal_proof.as_mut() else { - return Ok(ConstraintTag::create_from_index(0)); + return Ok(inference_tag); }; - let inference_tag = constraint_tags.next_key(); - let inference = Inference { constraint_id: inference_tag.into(), premises: premises @@ -123,6 +122,8 @@ impl ProofLog { variable_names: &VariableNames, constraint_tags: &mut KeyGenerator, ) -> std::io::Result { + let inference_tag = constraint_tags.next_key(); + let Some(ProofImpl::CpProof { writer, propagation_order_hint: Some(propagation_sequence), @@ -131,7 +132,7 @@ impl ProofLog { .. }) = self.internal_proof.as_mut() else { - return Ok(ConstraintTag::create_from_index(0)); + return Ok(inference_tag); }; if let Some(hint_idx) = logged_domain_inferences.get(&predicate).copied() { @@ -145,8 +146,6 @@ impl ProofLog { return Ok(tag); } - let inference_tag = constraint_tags.next_key(); - let inference = Inference { constraint_id: inference_tag.into(), premises: vec![], @@ -176,6 +175,8 @@ impl ProofLog { variable_names: &VariableNames, constraint_tags: &mut KeyGenerator, ) -> std::io::Result { + let constraint_tag = constraint_tags.next_key(); + match &mut self.internal_proof { Some(ProofImpl::CpProof { writer, @@ -187,8 +188,6 @@ impl ProofLog { // Reset the logged domain inferences. logged_domain_inferences.clear(); - let constraint_tag = constraint_tags.next_key(); - let deduction = Deduction { constraint_id: constraint_tag.into(), premises: premises @@ -219,10 +218,10 @@ impl ProofLog { Some(ProofImpl::DimacsProof(writer)) => { let clause = premises.into_iter().map(|predicate| !predicate); writer.learned_clause(clause, variable_names)?; - Ok(ConstraintTag::create_from_index(0)) + Ok(constraint_tag) } - None => Ok(ConstraintTag::create_from_index(0)), + None => Ok(constraint_tag), } } diff --git a/pumpkin-crates/core/src/propagation/constructor.rs b/pumpkin-crates/core/src/propagation/constructor.rs index b6e6652c5..5a13e6a84 100644 --- a/pumpkin-crates/core/src/propagation/constructor.rs +++ b/pumpkin-crates/core/src/propagation/constructor.rs @@ -1,6 +1,8 @@ use std::ops::Deref; use std::ops::DerefMut; +use pumpkin_checking::InferenceChecker; + use super::Domains; use super::LocalId; use super::Propagator; @@ -18,10 +20,13 @@ use crate::engine::variables::AffineView; #[cfg(doc)] use crate::engine::variables::DomainId; use crate::predicates::Predicate; +use crate::proof::InferenceCode; #[cfg(doc)] use crate::propagation::DomainEvent; use crate::propagation::DomainEvents; +use crate::propagators::reified_propagator::ReifiedChecker; use crate::variables::IntegerVariable; +use crate::variables::Literal; /// A propagator constructor creates a fully initialized instance of a [`Propagator`]. /// @@ -33,10 +38,59 @@ pub trait PropagatorConstructor { /// The propagator that is produced by this constructor. type PropagatorImpl: Propagator + Clone; + /// Add inference checkers to the solver if applicable. + /// + /// If the `check-propagations` feature is turned on, then the inference checker will be used + /// to verify the propagations done by this propagator are correct. + /// + /// See [`InferenceChecker`] for more information. + fn add_inference_checkers(&self, _checkers: InferenceCheckers<'_>) {} + /// Create the propagator instance from `Self`. fn create(self, context: PropagatorConstructorContext) -> Self::PropagatorImpl; } +/// Interface used to add [`InferenceChecker`]s to the [`State`]. +#[derive(Debug)] +pub struct InferenceCheckers<'state> { + state: &'state mut State, + reification_literal: Option, +} + +impl<'state> InferenceCheckers<'state> { + #[cfg(feature = "check-propagations")] + pub(crate) fn new(state: &'state mut State) -> Self { + InferenceCheckers { + state, + reification_literal: None, + } + } +} + +impl InferenceCheckers<'_> { + /// Forwards to [`State::add_inference_checker`]. + pub fn add_inference_checker( + &mut self, + inference_code: InferenceCode, + checker: Box>, + ) { + if let Some(reification_literal) = self.reification_literal { + let reification_checker = ReifiedChecker { + inner: checker.into(), + reification_literal, + }; + self.state + .add_inference_checker(inference_code, Box::new(reification_checker)); + } else { + self.state.add_inference_checker(inference_code, checker); + } + } + + pub fn with_reification_literal(&mut self, literal: Literal) { + self.reification_literal = Some(literal) + } +} + /// [`PropagatorConstructorContext`] is used when [`Propagator`]s are initialised after creation. /// /// It represents a communication point between the [`Solver`] and the [`Propagator`]. @@ -177,6 +231,18 @@ impl PropagatorConstructorContext<'_> { } } + /// Add an inference checker for inferences produced by the propagator. + /// + /// If the `check-propagations` feature is not enabled, adding an [`InferenceChecker`] will not + /// do anything. + pub fn add_inference_checker( + &mut self, + inference_code: InferenceCode, + checker: Box>, + ) { + self.state.add_inference_checker(inference_code, checker); + } + /// Set the next local id to be at least one more than the largest encountered local id. fn update_next_local_id(&mut self, local_id: LocalId) { let next_local_id = (*self.next_local_id.deref()).max(LocalId::from(local_id.unpack() + 1)); diff --git a/pumpkin-crates/core/src/propagators/mod.rs b/pumpkin-crates/core/src/propagators/mod.rs index 6f7842ca0..2d88c10e4 100644 --- a/pumpkin-crates/core/src/propagators/mod.rs +++ b/pumpkin-crates/core/src/propagators/mod.rs @@ -1,2 +1,4 @@ -pub(crate) mod nogoods; +pub mod nogoods; pub(crate) mod reified_propagator; + +pub use reified_propagator::*; diff --git a/pumpkin-crates/core/src/propagators/nogoods/checker.rs b/pumpkin-crates/core/src/propagators/nogoods/checker.rs new file mode 100644 index 000000000..700ee6a2c --- /dev/null +++ b/pumpkin-crates/core/src/propagators/nogoods/checker.rs @@ -0,0 +1,23 @@ +use std::fmt::Debug; + +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::InferenceChecker; + +#[derive(Debug, Clone)] +pub struct NogoodChecker { + pub nogood: Box<[Atomic]>, +} + +impl InferenceChecker for NogoodChecker +where + Atomic: AtomicConstraint + Clone + Debug, +{ + fn check( + &self, + state: pumpkin_checking::VariableState, + _: &[Atomic], + _: Option<&Atomic>, + ) -> bool { + self.nogood.iter().all(|atomic| state.is_true(atomic)) + } +} diff --git a/pumpkin-crates/core/src/propagators/nogoods/mod.rs b/pumpkin-crates/core/src/propagators/nogoods/mod.rs index 417df106c..e04fd791a 100644 --- a/pumpkin-crates/core/src/propagators/nogoods/mod.rs +++ b/pumpkin-crates/core/src/propagators/nogoods/mod.rs @@ -1,9 +1,11 @@ mod arena_allocator; +mod checker; mod learning_options; mod nogood_id; mod nogood_info; mod nogood_propagator; +pub use checker::*; pub use learning_options::*; pub(crate) use nogood_id::*; pub(crate) use nogood_info::*; diff --git a/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs b/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs index 9bec2334e..6fe77f0a1 100644 --- a/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs +++ b/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs @@ -621,6 +621,13 @@ impl NogoodPropagator { // // The preprocessing ensures that all predicates are unassigned. else { + #[cfg(feature = "check-propagations")] + let nogood = input_nogood + .iter() + .map(|predicate| context.get_id(*predicate)) + .collect::>(); + + #[cfg(not(feature = "check-propagations"))] let nogood = nogood .iter() .map(|predicate| context.get_id(*predicate)) diff --git a/pumpkin-crates/core/src/propagators/reified_propagator.rs b/pumpkin-crates/core/src/propagators/reified_propagator.rs index 8047addb9..09e163c25 100644 --- a/pumpkin-crates/core/src/propagators/reified_propagator.rs +++ b/pumpkin-crates/core/src/propagators/reified_propagator.rs @@ -1,3 +1,8 @@ +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::BoxedChecker; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; + use crate::basic_types::PropagationStatusCP; use crate::engine::notifications::OpaqueDomainEvent; use crate::predicates::Predicate; @@ -5,6 +10,7 @@ use crate::propagation::DomainEvents; use crate::propagation::Domains; use crate::propagation::EnqueueDecision; use crate::propagation::ExplanationContext; +use crate::propagation::InferenceCheckers; use crate::propagation::LocalId; use crate::propagation::NotificationContext; use crate::propagation::Priority; @@ -56,6 +62,12 @@ where reason_buffer: vec![], } } + + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.with_reification_literal(self.reification_literal); + + self.propagator.add_inference_checkers(checkers); + } } /// Propagator for the constraint `r -> p`, where `r` is a Boolean literal and `p` is an arbitrary @@ -221,6 +233,37 @@ impl ReifiedPropagator { } } +#[derive(Debug, Clone)] +pub struct ReifiedChecker { + pub inner: BoxedChecker, + pub reification_literal: Var, +} + +impl> InferenceChecker + for ReifiedChecker +{ + fn check( + &self, + state: pumpkin_checking::VariableState, + premises: &[Atomic], + consequent: Option<&Atomic>, + ) -> bool { + if self.reification_literal.induced_domain_contains(&state, 0) { + return false; + } + + if let Some(consequent) = consequent + && self + .reification_literal + .does_atomic_constrain_self(consequent) + { + self.inner.check(state, premises, None) + } else { + self.inner.check(state, premises, consequent) + } + } +} + #[allow(deprecated, reason = "Will be refactored")] #[cfg(test)] mod tests { @@ -248,6 +291,7 @@ mod tests { let t2 = triggered_conflict.clone(); let inference_code = InferenceCode::unknown_label(ConstraintTag::create_from_index(0)); + solver.accept_inferences_by(inference_code.clone()); let i1 = inference_code.clone(); let i2 = inference_code.clone(); @@ -324,6 +368,7 @@ mod tests { let var = solver.new_variable(1, 1); let inference_code = InferenceCode::unknown_label(ConstraintTag::create_from_index(0)); + solver.accept_inferences_by(inference_code.clone()); let inconsistency = solver .new_propagator(ReifiedPropagatorArgs { @@ -364,6 +409,7 @@ mod tests { let var = solver.new_variable(1, 5); let inference_code = InferenceCode::unknown_label(ConstraintTag::create_from_index(0)); + solver.accept_inferences_by(inference_code.clone()); let propagator = solver .new_propagator(ReifiedPropagatorArgs { diff --git a/pumpkin-crates/core/src/statistics/statistic_logging.rs b/pumpkin-crates/core/src/statistics/statistic_logging.rs index 55eeb98e5..8bcae72aa 100644 --- a/pumpkin-crates/core/src/statistics/statistic_logging.rs +++ b/pumpkin-crates/core/src/statistics/statistic_logging.rs @@ -21,7 +21,7 @@ pub struct StatisticOptions<'a> { // A closing line which is printed after all of the statistics have been printed after_statistics: Option<&'a str>, // The casing of the name of the statistic - statistics_casing: Option, + statistics_casing: Option>, // The writer to which the statistics are written statistics_writer: Box, } @@ -48,7 +48,7 @@ static STATISTIC_OPTIONS: OnceLock> = OnceLock::new(); pub fn configure_statistic_logging( prefix: &'static str, after: Option<&'static str>, - casing: Option, + casing: Option>, writer: Option>, ) { let _ = STATISTIC_OPTIONS.get_or_init(|| { diff --git a/pumpkin-crates/propagators/Cargo.toml b/pumpkin-crates/propagators/Cargo.toml index 4c7e40f4a..11679e519 100644 --- a/pumpkin-crates/propagators/Cargo.toml +++ b/pumpkin-crates/propagators/Cargo.toml @@ -12,9 +12,10 @@ workspace = true [dependencies] pumpkin-core = { version = "0.2.2", path = "../core" } +pumpkin-checking = { version = "0.2.2", path = "../checking" } enumset = "1.1.2" bitfield-struct = "0.9.2" -convert_case = "0.6.0" +convert_case = "0.8.0" clap = { version = "4.5.40", optional = true, features=["derive"]} [dev-dependencies] diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/absolute_value.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/absolute_value.rs index eac1df128..95b6e334a 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/absolute_value.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/absolute_value.rs @@ -1,9 +1,14 @@ +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::IntExt; use pumpkin_core::conjunction; use pumpkin_core::declare_inference_label; use pumpkin_core::predicate; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::Priority; use pumpkin_core::propagation::PropagationContext; @@ -30,6 +35,16 @@ where { type PropagatorImpl = AbsoluteValuePropagator; + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, AbsoluteValue), + Box::new(AbsoluteValueChecker { + signed: self.signed.clone(), + absolute: self.absolute.clone(), + }), + ); + } + fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let AbsoluteValueArgs { signed, @@ -146,6 +161,50 @@ where } } +#[derive(Clone, Debug)] +pub struct AbsoluteValueChecker { + signed: VA, + absolute: VB, +} + +impl InferenceChecker for AbsoluteValueChecker +where + VA: CheckerVariable, + VB: CheckerVariable, + Atomic: AtomicConstraint, +{ + fn check( + &self, + state: pumpkin_checking::VariableState, + _: &[Atomic], + _: Option<&Atomic>, + ) -> bool { + let signed_lower = self.signed.induced_lower_bound(&state); + let signed_upper = self.signed.induced_upper_bound(&state); + let absolute_lower = self.absolute.induced_lower_bound(&state); + let absolute_upper = self.absolute.induced_upper_bound(&state); + + if absolute_lower < 0 { + // The absolute value cannot have negative values. + return true; + } + + // Now we compute the interval for |signed| based on the domain of signed. + let (computed_signed_lower, computed_signed_upper) = if signed_lower >= 0 { + (signed_lower, signed_upper) + } else if signed_upper <= 0 { + (-signed_upper, -signed_lower) + } else if signed_lower < 0 && 0_i32 < signed_upper { + (IntExt::Int(0), std::cmp::max(-signed_lower, signed_upper)) + } else { + unreachable!() + }; + + // The intervals should not match, otherwise there is no conflict. + computed_signed_lower != absolute_lower || computed_signed_upper != absolute_upper + } +} + #[allow(deprecated, reason = "Will be refactored")] #[cfg(test)] mod tests { diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_equals.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_equals.rs index 56ee42373..964388ae9 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_equals.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_equals.rs @@ -2,6 +2,10 @@ use std::slice; use bitfield_struct::bitfield; +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::IntExt; use pumpkin_core::asserts::pumpkin_assert_advanced; use pumpkin_core::conjunction; use pumpkin_core::containers::HashSet; @@ -17,6 +21,7 @@ use pumpkin_core::propagation::DomainEvents; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; use pumpkin_core::propagation::ExplanationContext; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -48,6 +53,16 @@ where { type PropagatorImpl = BinaryEqualsPropagator; + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, BinaryEquals), + Box::new(BinaryEqualsChecker { + lhs: self.a.clone(), + rhs: self.b.clone(), + }), + ); + } + fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let BinaryEqualsPropagatorArgs { a, @@ -383,6 +398,47 @@ struct BinaryEqualsPropagation { __: u16, } +#[derive(Clone, Debug)] +pub struct BinaryEqualsChecker { + pub lhs: Lhs, + pub rhs: Rhs, +} + +impl InferenceChecker for BinaryEqualsChecker +where + Atomic: AtomicConstraint, + Lhs: CheckerVariable, + Rhs: CheckerVariable, +{ + fn check( + &self, + mut state: pumpkin_checking::VariableState, + _: &[Atomic], + _: Option<&Atomic>, + ) -> bool { + // We apply the domain of variable 2 to variable 1. If the state remains consistent, then + // the step is unsound! + let mut consistent = true; + + if let IntExt::Int(value) = self.rhs.induced_upper_bound(&state) { + let atomic = self.lhs.atomic_less_than(value); + consistent &= state.apply(&atomic); + } + + if let IntExt::Int(value) = self.rhs.induced_lower_bound(&state) { + let atomic = self.lhs.atomic_greater_than(value); + consistent &= state.apply(&atomic); + } + + for value in self.rhs.induced_holes(&state).collect::>() { + let atomic = self.lhs.atomic_not_equal(value); + consistent &= state.apply(&atomic); + } + + !consistent + } +} + #[allow(deprecated, reason = "Will be refactored")] #[cfg(test)] mod tests { diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_not_equals.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_not_equals.rs index e99bccac9..efe6d8768 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_not_equals.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_not_equals.rs @@ -1,3 +1,6 @@ +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; use pumpkin_core::conjunction; use pumpkin_core::declare_inference_label; use pumpkin_core::predicate; @@ -5,6 +8,7 @@ use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; use pumpkin_core::propagation::Domains; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::Priority; use pumpkin_core::propagation::PropagationContext; @@ -33,6 +37,16 @@ where { type PropagatorImpl = BinaryNotEqualsPropagator; + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, BinaryNotEquals), + Box::new(BinaryNotEqualsChecker { + lhs: self.a.clone(), + rhs: self.b.clone(), + }), + ); + } + fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let BinaryNotEqualsPropagatorArgs { a, @@ -168,6 +182,30 @@ where } } +#[derive(Clone, Debug)] +pub struct BinaryNotEqualsChecker { + pub lhs: Lhs, + pub rhs: Rhs, +} + +impl InferenceChecker for BinaryNotEqualsChecker +where + Atomic: AtomicConstraint, + Lhs: CheckerVariable, + Rhs: CheckerVariable, +{ + fn check( + &self, + state: pumpkin_checking::VariableState, + _: &[Atomic], + _: Option<&Atomic>, + ) -> bool { + // There is a conflict if both variables are fixed to the same values. + + self.lhs.induced_fixed_value(&state) == self.rhs.induced_fixed_value(&state) + } +} + #[allow(deprecated, reason = "Will be refactored")] #[cfg(test)] mod tests { diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/integer_division.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/integer_division.rs index 4155a532a..27c4f121b 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/integer_division.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/integer_division.rs @@ -1,3 +1,6 @@ +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; use pumpkin_core::asserts::pumpkin_assert_simple; use pumpkin_core::conjunction; use pumpkin_core::declare_inference_label; @@ -5,6 +8,7 @@ use pumpkin_core::predicate; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::Priority; use pumpkin_core::propagation::PropagationContext; @@ -64,6 +68,17 @@ where inference_code, } } + + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, Division), + Box::new(IntegerDivisionChecker { + numerator: self.numerator.clone(), + denominator: self.denominator.clone(), + rhs: self.rhs.clone(), + }), + ); + } } /// A propagator for maintaining the constraint `numerator / denominator = rhs`; note that this @@ -377,6 +392,70 @@ fn propagate_signs { + pub numerator: VA, + pub denominator: VB, + pub rhs: VC, +} + +impl InferenceChecker for IntegerDivisionChecker +where + Atomic: AtomicConstraint, + VA: CheckerVariable, + VB: CheckerVariable, + VC: CheckerVariable, +{ + fn check( + &self, + state: pumpkin_checking::VariableState, + _premises: &[Atomic], + _consequent: Option<&Atomic>, + ) -> bool { + // We apply interval arithmetic to determine that the computed interval `a div b` + // does not intersect with the domain of `c`. + // + // See https://en.wikipedia.org/wiki/Interval_arithmetic#Interval_operators. + + let x1 = self.numerator.induced_lower_bound(&state); + let x2 = self.numerator.induced_upper_bound(&state); + let y1 = self.denominator.induced_lower_bound(&state); + let y2 = self.denominator.induced_upper_bound(&state); + + assert!( + y2 < 0 || y1 > 0, + "Currentl, the checker does not contain inferences where the denominator spans 0" + ); + + let floor_x1y1 = x1.floor_div(&y1); + let floor_x1y2 = x1.floor_div(&y2); + let floor_x2y1 = x2.floor_div(&y1); + let floor_x2y2 = x2.floor_div(&y2); + + let ceil_x1y1 = x1.ceil_div(&y1); + let ceil_x1y2 = x1.ceil_div(&y2); + let ceil_x2y1 = x2.ceil_div(&y1); + let ceil_x2y2 = x2.ceil_div(&y2); + + // TODO: Can we just ignore these options? + let computed_c_lower = [ceil_x1y1, ceil_x1y2, ceil_x2y1, ceil_x2y2] + .into_iter() + .flatten() + .min() + .expect("Expected at least one element to be defined"); + let computed_c_upper = [floor_x1y1, floor_x1y2, floor_x2y1, floor_x2y2] + .into_iter() + .flatten() + .max() + .expect("Expected at least one element to be defined"); + + let c_lower = self.rhs.induced_lower_bound(&state); + let c_upper = self.rhs.induced_upper_bound(&state); + + computed_c_upper < c_lower || computed_c_lower > c_upper + } +} + #[allow(deprecated, reason = "Will be refactored")] #[cfg(test)] mod tests { diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/integer_multiplication.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/integer_multiplication.rs index d6cb8a456..45410a37e 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/integer_multiplication.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/integer_multiplication.rs @@ -1,3 +1,6 @@ +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; use pumpkin_core::asserts::pumpkin_assert_simple; use pumpkin_core::conjunction; use pumpkin_core::declare_inference_label; @@ -5,6 +8,7 @@ use pumpkin_core::predicate; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::Priority; use pumpkin_core::propagation::PropagationContext; @@ -35,6 +39,17 @@ where { type PropagatorImpl = IntegerMultiplicationPropagator; + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, IntegerMultiplication), + Box::new(IntegerMultiplicationChecker { + a: self.a.clone(), + b: self.b.clone(), + c: self.c.clone(), + }), + ); + } + fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let IntegerMultiplicationArgs { a, @@ -357,6 +372,51 @@ fn div_ceil_pos(numerator: i32, denominator: i32) -> i32 { numerator / denominator + (numerator % denominator).signum() } +#[derive(Clone, Debug)] +pub struct IntegerMultiplicationChecker { + pub a: VA, + pub b: VB, + pub c: VC, +} + +impl InferenceChecker for IntegerMultiplicationChecker +where + Atomic: AtomicConstraint, + VA: CheckerVariable, + VB: CheckerVariable, + VC: CheckerVariable, +{ + fn check( + &self, + state: pumpkin_checking::VariableState, + _: &[Atomic], + _: Option<&Atomic>, + ) -> bool { + // We apply interval arithmetic to determine that the computed interval `a times b` + // does not intersect with the domain of `c`. + // + // See https://en.wikipedia.org/wiki/Interval_arithmetic#Interval_operators. + + let x1 = self.a.induced_lower_bound(&state); + let x2 = self.a.induced_upper_bound(&state); + let y1 = self.b.induced_lower_bound(&state); + let y2 = self.b.induced_upper_bound(&state); + + let c_lower = self.c.induced_lower_bound(&state); + let c_upper = self.c.induced_upper_bound(&state); + + let x1y1 = x1 * y1; + let x1y2 = x1 * y2; + let x2y1 = x2 * y1; + let x2y2 = x2 * y2; + + let computed_c_lower = x1y1.min(x1y2).min(x2y1).min(x2y2); + let computed_c_upper = x1y1.max(x1y2).max(x2y1).max(x2y2); + + computed_c_upper < c_lower || computed_c_lower > c_upper + } +} + #[allow(deprecated, reason = "Will be refactored")] #[cfg(test)] mod tests { diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs index b2537cc02..e1e616574 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs @@ -1,3 +1,8 @@ +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::IntExt; +use pumpkin_checking::VariableState; use pumpkin_core::asserts::pumpkin_assert_simple; use pumpkin_core::declare_inference_label; use pumpkin_core::predicate; @@ -9,6 +14,7 @@ use pumpkin_core::propagation::DomainEvents; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; use pumpkin_core::propagation::ExplanationContext; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -39,6 +45,16 @@ where { type PropagatorImpl = LinearLessOrEqualPropagator; + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, LinearBounds), + Box::new(LinearLessOrEqualInferenceChecker::new( + self.x.clone(), + self.c, + )), + ); + } + fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let LinearLessOrEqualPropagatorArgs { x, @@ -271,6 +287,43 @@ where } } +#[derive(Debug, Clone)] +pub struct LinearLessOrEqualInferenceChecker { + terms: Box<[Var]>, + bound: i32, +} + +impl LinearLessOrEqualInferenceChecker { + pub fn new(terms: Box<[Var]>, bound: i32) -> Self { + LinearLessOrEqualInferenceChecker { terms, bound } + } +} + +impl InferenceChecker for LinearLessOrEqualInferenceChecker +where + Var: CheckerVariable, + Atomic: AtomicConstraint, +{ + fn check( + &self, + variable_state: VariableState, + _: &[Atomic], + _: Option<&Atomic>, + ) -> bool { + // Next, we evaluate the linear inequality. The lower bound of the + // left-hand side must exceed the bound in the constraint. Note that the accumulator is an + // IntExt, and if the lower bound of one of the terms is -infty, then the left-hand side + // will be -infty regardless of the other terms. + let left_hand_side: IntExt = self + .terms + .iter() + .map(|variable| variable.induced_lower_bound(&variable_state).into()) + .sum(); + + left_hand_side > i64::from(self.bound) + } +} + #[allow(deprecated, reason = "Will be refactored")] #[cfg(test)] mod tests { diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_not_equal.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_not_equal.rs index 3b01e76d2..31f8a96e3 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_not_equal.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_not_equal.rs @@ -1,6 +1,11 @@ use std::rc::Rc; use enumset::enum_set; +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::IntExt; +use pumpkin_checking::VariableState; use pumpkin_core::asserts::pumpkin_assert_extreme; use pumpkin_core::asserts::pumpkin_assert_moderate; use pumpkin_core::asserts::pumpkin_assert_simple; @@ -13,6 +18,7 @@ use pumpkin_core::propagation::DomainEvent; use pumpkin_core::propagation::DomainEvents; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -44,6 +50,16 @@ where { type PropagatorImpl = LinearNotEqualPropagator; + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, LinearNotEquals), + Box::new(LinearNotEqualChecker { + terms: self.terms.as_ref().into(), + bound: self.rhs, + }), + ); + } + fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let LinearNotEqualPropagatorArgs { terms, @@ -358,6 +374,34 @@ impl LinearNotEqualPropagator { } } +#[derive(Debug, Clone)] +pub struct LinearNotEqualChecker { + pub terms: Box<[Var]>, + pub bound: i32, +} + +impl InferenceChecker for LinearNotEqualChecker +where + Var: CheckerVariable, + Atomic: AtomicConstraint, +{ + fn check(&self, state: VariableState, _: &[Atomic], _: Option<&Atomic>) -> bool { + // We evaluate the linear sum. It should be fixed to the bound for a conflict to + // exist. + let mut left_hand_side = IntExt::Int(0); + + for term in self.terms.iter() { + let Some(value) = term.induced_fixed_value(&state) else { + return false; + }; + + left_hand_side += i64::from(value); + } + + left_hand_side == i64::from(self.bound) + } +} + #[allow(deprecated, reason = "Will be refactored")] #[cfg(test)] mod tests { diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/maximum.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/maximum.rs index 22c8b86a1..d1f927697 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/maximum.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/maximum.rs @@ -1,3 +1,7 @@ +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::IntExt; use pumpkin_core::conjunction; use pumpkin_core::declare_inference_label; use pumpkin_core::predicate; @@ -5,6 +9,7 @@ use pumpkin_core::predicates::PropositionalConjunction; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::Priority; use pumpkin_core::propagation::PropagationContext; @@ -31,6 +36,16 @@ where { type PropagatorImpl = MaximumPropagator; + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, Maximum), + Box::new(MaximumChecker { + array: self.array.clone(), + rhs: self.rhs.clone(), + }), + ); + } + fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let MaximumArgs { array, @@ -170,6 +185,45 @@ impl Prop } } +#[derive(Clone, Debug)] +pub struct MaximumChecker { + pub array: Box<[ElementVar]>, + pub rhs: Rhs, +} + +impl InferenceChecker for MaximumChecker +where + Atomic: AtomicConstraint, + ElementVar: CheckerVariable, + Rhs: CheckerVariable, +{ + fn check( + &self, + state: pumpkin_checking::VariableState, + _: &[Atomic], + _: Option<&Atomic>, + ) -> bool { + let lowest_maximum = self + .array + .iter() + .map(|element| element.induced_lower_bound(&state)) + .max() + .unwrap_or(IntExt::NegativeInf); + let highest_maximum = self + .array + .iter() + .map(|element| element.induced_upper_bound(&state)) + .max() + .unwrap_or(IntExt::PositiveInf); + + // If the intersection between the domain of `rhs` and `[lowest_maximum, + // highest_maximum]` is empty, there is a conflict. + + lowest_maximum > self.rhs.induced_upper_bound(&state) + || highest_maximum < self.rhs.induced_lower_bound(&state) + } +} + #[allow(deprecated, reason = "Will be refactored")] #[cfg(test)] mod tests { diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/checker.rs b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/checker.rs new file mode 100644 index 000000000..83790b842 --- /dev/null +++ b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/checker.rs @@ -0,0 +1,576 @@ +use std::collections::BTreeMap; + +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::IntExt; +use pumpkin_checking::VariableState; + +use crate::cumulative::time_table::time_table_util::has_overlap_with_interval; + +#[derive(Clone, Debug)] +pub struct TimeTableChecker { + pub tasks: Box<[CheckerTask]>, + pub capacity: i32, +} + +#[derive(Clone, Debug)] +pub struct CheckerTask { + pub start_time: Var, + pub resource_usage: i32, + pub processing_time: i32, +} + +fn lower_bound_can_be_propagated_by_profile< + Var: CheckerVariable, + Atomic: AtomicConstraint, +>( + context: &VariableState, + lower_bound: i32, + task: &CheckerTask, + start: i32, + end: i32, + height: i32, + capacity: i32, +) -> bool { + let upper_bound = task + .start_time + .induced_upper_bound(context) + .try_into() + .unwrap(); + + height + task.resource_usage > capacity + && !(upper_bound < (lower_bound + task.processing_time) + && has_overlap_with_interval( + upper_bound, + lower_bound + task.processing_time, + start, + end, + )) + && has_overlap_with_interval(lower_bound, upper_bound + task.processing_time, start, end) + && (lower_bound + task.processing_time) > start + && lower_bound <= end +} + +fn upper_bound_can_be_propagated_by_profile< + Var: CheckerVariable, + Atomic: AtomicConstraint, +>( + context: &VariableState, + upper_bound: i32, + task: &CheckerTask, + start: i32, + end: i32, + height: i32, + capacity: i32, +) -> bool { + let lower_bound = task + .start_time + .induced_lower_bound(context) + .try_into() + .unwrap(); + + height + task.resource_usage > capacity + && !(upper_bound < (lower_bound + task.processing_time) + && has_overlap_with_interval( + upper_bound, + lower_bound + task.processing_time, + start, + end, + )) + && has_overlap_with_interval(lower_bound, upper_bound + task.processing_time, start, end) + && (upper_bound + task.processing_time) > end + && upper_bound <= end +} + +impl InferenceChecker for TimeTableChecker +where + Var: CheckerVariable, + Atomic: AtomicConstraint, +{ + fn check( + &self, + state: VariableState, + _: &[Atomic], + consequent: Option<&Atomic>, + ) -> bool { + // The profile is a key-value store. The keys correspond to time-points, and the values to + // the relative change in resource consumption. A BTreeMap is used to maintain a + // sorted order of the time points. + let mut profile = BTreeMap::new(); + + for task in self.tasks.iter() { + if task.start_time.induced_lower_bound(&state) == IntExt::NegativeInf + || task.start_time.induced_upper_bound(&state) == IntExt::PositiveInf + { + continue; + } + + let lst: i32 = task + .start_time + .induced_upper_bound(&state) + .try_into() + .unwrap(); + let est: i32 = task + .start_time + .induced_lower_bound(&state) + .try_into() + .unwrap(); + + if lst < est + task.processing_time { + *profile.entry(lst).or_insert(0) += task.resource_usage; + *profile.entry(est + task.processing_time).or_insert(0) -= task.resource_usage; + } + } + + let mut profiles = Vec::new(); + let mut current_usage = 0; + let mut previous_time_point = *profile + .first_key_value() + .expect("Expected at least one mandatory part") + .0; + for (time_point, usage) in profile.iter() { + if current_usage > 0 && *time_point != previous_time_point { + profiles.push((previous_time_point, *time_point - 1, current_usage)) + } + + current_usage += *usage; + + if current_usage > self.capacity { + return true; + } + + previous_time_point = *time_point; + } + + if let Some(propagating_task) = consequent.map(|consequent| { + self.tasks + .iter() + .find(|task| task.start_time.does_atomic_constrain_self(consequent)) + .expect("If there is a consequent, then there should be a propagating task") + }) { + let mut lower_bound: i32 = propagating_task + .start_time + .induced_lower_bound(&state) + .try_into() + .unwrap(); + for (start, end_inclusive, height) in profiles.iter() { + if lower_bound_can_be_propagated_by_profile( + &state, + lower_bound, + propagating_task, + *start, + *end_inclusive, + *height, + self.capacity, + ) { + lower_bound = end_inclusive + 1; + } + } + if lower_bound > propagating_task.start_time.induced_upper_bound(&state) { + return true; + } + + let mut upper_bound: i32 = propagating_task + .start_time + .induced_upper_bound(&state) + .try_into() + .unwrap(); + for (start, end_inclusive, height) in profiles.iter().rev() { + if upper_bound_can_be_propagated_by_profile( + &state, + upper_bound, + propagating_task, + *start, + *end_inclusive, + *height, + self.capacity, + ) { + upper_bound = start - propagating_task.processing_time; + } + } + if upper_bound < propagating_task.start_time.induced_lower_bound(&state) { + return true; + } + } + false + } +} + +#[cfg(test)] +mod tests { + use pumpkin_checking::TestAtomic; + use pumpkin_checking::VariableState; + + use super::*; + + #[test] + fn conflict() { + let premises = [ + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::Equal, + value: 1, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::Equal, + value: 1, + }, + ]; + + let state = VariableState::prepare_for_conflict_check(premises, None) + .expect("no conflicting atomics"); + + let checker = TimeTableChecker { + tasks: vec![ + CheckerTask { + start_time: "x1", + resource_usage: 1, + processing_time: 1, + }, + CheckerTask { + start_time: "x2", + resource_usage: 1, + processing_time: 1, + }, + ] + .into(), + capacity: 1, + }; + + assert!(checker.check(state, &premises, None)); + } + + #[test] + fn hole_in_domain() { + let premises = [TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::Equal, + value: 6, + }]; + + let consequent = Some(TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::NotEqual, + value: 2, + }); + let state = VariableState::prepare_for_conflict_check(premises, consequent) + .expect("no conflicting atomics"); + + let checker = TimeTableChecker { + tasks: vec![ + CheckerTask { + start_time: "x1", + resource_usage: 3, + processing_time: 2, + }, + CheckerTask { + start_time: "x2", + resource_usage: 2, + processing_time: 5, + }, + ] + .into(), + capacity: 4, + }; + + assert!(checker.check(state, &premises, consequent.as_ref())); + } + + #[test] + fn lower_bound_chain() { + let premises = [ + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::Equal, + value: 1, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::Equal, + value: 6, + }, + TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 0, + }, + ]; + + let consequent = Some(TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 16, + }); + let state = VariableState::prepare_for_conflict_check(premises, consequent) + .expect("no conflicting atomics"); + + let checker = TimeTableChecker { + tasks: vec![ + CheckerTask { + start_time: "x1", + resource_usage: 3, + processing_time: 2, + }, + CheckerTask { + start_time: "x2", + resource_usage: 3, + processing_time: 10, + }, + CheckerTask { + start_time: "x3", + resource_usage: 2, + processing_time: 5, + }, + ] + .into(), + capacity: 4, + }; + + assert!(checker.check(state, &premises, consequent.as_ref())); + } + + #[test] + fn upper_bound_chain() { + let premises = [ + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::Equal, + value: 1, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::Equal, + value: 6, + }, + TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 15, + }, + ]; + + let consequent = Some(TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::LessEqual, + value: -4, + }); + let state = VariableState::prepare_for_conflict_check(premises, consequent) + .expect("no conflicting atomics"); + + let checker = TimeTableChecker { + tasks: vec![ + CheckerTask { + start_time: "x1", + resource_usage: 3, + processing_time: 2, + }, + CheckerTask { + start_time: "x2", + resource_usage: 3, + processing_time: 10, + }, + CheckerTask { + start_time: "x3", + resource_usage: 2, + processing_time: 5, + }, + ] + .into(), + capacity: 4, + }; + + assert!(checker.check(state, &premises, consequent.as_ref())); + } + + #[test] + fn hole_in_domain_not_accepted() { + let premises = [TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::Equal, + value: 6, + }]; + + let consequent = Some(TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::NotEqual, + value: 1, + }); + let state = VariableState::prepare_for_conflict_check(premises, consequent) + .expect("no conflicting atomics"); + + let checker = TimeTableChecker { + tasks: vec![ + CheckerTask { + start_time: "x1", + resource_usage: 3, + processing_time: 2, + }, + CheckerTask { + start_time: "x2", + resource_usage: 2, + processing_time: 5, + }, + ] + .into(), + capacity: 4, + }; + + assert!(!checker.check(state, &premises, consequent.as_ref())); + } + + #[test] + fn lower_bound_chain_not_accepted() { + let premises = [ + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::Equal, + value: 1, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::Equal, + value: 8, + }, + TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 0, + }, + ]; + + let consequent = Some(TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 16, + }); + let state = VariableState::prepare_for_conflict_check(premises, consequent) + .expect("no conflicting atomics"); + + let checker = TimeTableChecker { + tasks: vec![ + CheckerTask { + start_time: "x1", + resource_usage: 3, + processing_time: 2, + }, + CheckerTask { + start_time: "x2", + resource_usage: 3, + processing_time: 10, + }, + CheckerTask { + start_time: "x3", + resource_usage: 2, + processing_time: 5, + }, + ] + .into(), + capacity: 4, + }; + + assert!(!checker.check(state, &premises, consequent.as_ref())); + } + + #[test] + fn upper_bound_chain_not_accepted() { + let premises = [ + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::Equal, + value: 1, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::Equal, + value: 8, + }, + TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 15, + }, + ]; + + let consequent = Some(TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::LessEqual, + value: -4, + }); + let state = VariableState::prepare_for_conflict_check(premises, consequent) + .expect("no conflicting atomics"); + + let checker = TimeTableChecker { + tasks: vec![ + CheckerTask { + start_time: "x1", + resource_usage: 3, + processing_time: 2, + }, + CheckerTask { + start_time: "x2", + resource_usage: 3, + processing_time: 10, + }, + CheckerTask { + start_time: "x3", + resource_usage: 2, + processing_time: 5, + }, + ] + .into(), + capacity: 4, + }; + + assert!(!checker.check(state, &premises, consequent.as_ref())); + } + + #[test] + fn simple_test() { + let premises = [ + TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 5, + }, + TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 6, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 7, + }, + ]; + + let consequent = Some(TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 4, + }); + let state = VariableState::prepare_for_conflict_check(premises, consequent) + .expect("no conflicting atomics"); + + let checker = TimeTableChecker { + tasks: vec![ + CheckerTask { + start_time: "x3", + resource_usage: 1, + processing_time: 3, + }, + CheckerTask { + start_time: "x2", + resource_usage: 2, + processing_time: 2, + }, + ] + .into(), + capacity: 2, + }; + + assert!(checker.check(state, &premises, consequent.as_ref())); + } +} diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/mod.rs b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/mod.rs index 30a4d2c77..b542aa54e 100644 --- a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/mod.rs +++ b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/mod.rs @@ -56,6 +56,7 @@ //! Conference, CP 2015, Cork, Ireland, August 31--September 4, 2015, Proceedings 21, 2015, pp. //! 149–157. +mod checker; mod explanations; mod over_interval_incremental_propagator; mod per_point_incremental_propagator; @@ -63,6 +64,8 @@ mod propagation_handler; mod time_table_over_interval; mod time_table_per_point; mod time_table_util; + +pub use checker::*; pub use explanations::CumulativeExplanationType; pub use over_interval_incremental_propagator::*; pub use per_point_incremental_propagator::*; diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs index fc6ff7f9c..3661767a7 100644 --- a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs +++ b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs @@ -11,6 +11,7 @@ use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvent; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -29,6 +30,8 @@ use super::removal; use crate::cumulative::options::CumulativePropagatorOptions; use crate::cumulative::time_table::create_time_table_over_interval_from_scratch; use crate::cumulative::time_table::propagate_from_scratch_time_table_interval; +use crate::cumulative::time_table::CheckerTask; +use crate::cumulative::time_table::TimeTableChecker; use crate::cumulative::util::check_bounds_equal_at_propagation; use crate::cumulative::util::create_tasks; use crate::cumulative::util::register_tasks; @@ -108,6 +111,25 @@ impl PropagatorConstruc { type PropagatorImpl = Self; + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, TimeTable), + Box::new(TimeTableChecker { + tasks: self + .parameters + .tasks + .iter() + .map(|task| CheckerTask { + start_time: task.start_variable.clone(), + processing_time: task.processing_time, + resource_usage: task.resource_usage, + }) + .collect(), + capacity: self.parameters.capacity, + }), + ); + } + fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { // We only register for notifications of backtrack events if incremental backtracking is // enabled diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs index db3a00614..cd3db0d92 100644 --- a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs +++ b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs @@ -11,6 +11,7 @@ use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvent; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -31,7 +32,9 @@ use crate::cumulative::ResourceProfile; use crate::cumulative::Task; use crate::cumulative::UpdatableStructures; use crate::cumulative::options::CumulativePropagatorOptions; +use crate::cumulative::time_table::CheckerTask; use crate::cumulative::time_table::PerPointTimeTableType; +use crate::cumulative::time_table::TimeTableChecker; #[cfg(doc)] use crate::cumulative::time_table::TimeTablePerPointPropagator; use crate::cumulative::time_table::create_time_table_per_point_from_scratch; @@ -105,6 +108,25 @@ impl Propagator { type PropagatorImpl = Self; + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, TimeTable), + Box::new(TimeTableChecker { + tasks: self + .parameters + .tasks + .iter() + .map(|task| CheckerTask { + start_time: task.start_variable.clone(), + processing_time: task.processing_time, + resource_usage: task.resource_usage, + }) + .collect(), + capacity: self.parameters.capacity, + }), + ); + } + fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { register_tasks(&self.parameters.tasks, context.reborrow(), true); self.updatable_structures diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_over_interval.rs b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_over_interval.rs index 6bdf00a08..663b27216 100644 --- a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_over_interval.rs +++ b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_over_interval.rs @@ -9,6 +9,7 @@ use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvent; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -32,6 +33,8 @@ use crate::cumulative::ResourceProfile; use crate::cumulative::Task; use crate::cumulative::UpdatableStructures; use crate::cumulative::options::CumulativePropagatorOptions; +use crate::cumulative::time_table::CheckerTask; +use crate::cumulative::time_table::TimeTableChecker; #[cfg(doc)] use crate::cumulative::time_table::TimeTablePerPointPropagator; use crate::cumulative::util::create_tasks; @@ -107,6 +110,25 @@ impl PropagatorConstructor { type PropagatorImpl = Self; + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, TimeTable), + Box::new(TimeTableChecker { + tasks: self + .parameters + .tasks + .iter() + .map(|task| CheckerTask { + start_time: task.start_variable.clone(), + processing_time: task.processing_time, + resource_usage: task.resource_usage, + }) + .collect(), + capacity: self.parameters.capacity, + }), + ); + } + fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { self.updatable_structures .initialise_bounds_and_remove_fixed(context.domains(), &self.parameters); diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_per_point.rs b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_per_point.rs index 1edfda6be..24b896191 100644 --- a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_per_point.rs +++ b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_per_point.rs @@ -12,6 +12,7 @@ use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvent; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -34,6 +35,8 @@ use crate::cumulative::CumulativeParameters; use crate::cumulative::ResourceProfile; use crate::cumulative::UpdatableStructures; use crate::cumulative::options::CumulativePropagatorOptions; +use crate::cumulative::time_table::CheckerTask; +use crate::cumulative::time_table::TimeTableChecker; use crate::cumulative::util::create_tasks; use crate::cumulative::util::register_tasks; use crate::cumulative::util::update_bounds_task; @@ -98,6 +101,25 @@ impl TimeTablePerPointPropagator { impl PropagatorConstructor for TimeTablePerPointPropagator { type PropagatorImpl = Self; + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, TimeTable), + Box::new(TimeTableChecker { + tasks: self + .parameters + .tasks + .iter() + .map(|task| CheckerTask { + start_time: task.start_variable.clone(), + processing_time: task.processing_time, + resource_usage: task.resource_usage, + }) + .collect(), + capacity: self.parameters.capacity, + }), + ); + } + fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { self.updatable_structures .initialise_bounds_and_remove_fixed(context.domains(), &self.parameters); diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_util.rs b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_util.rs index 7da67427b..dbf6abbd4 100644 --- a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_util.rs +++ b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_util.rs @@ -421,6 +421,7 @@ fn propagate_sequence_of_profiles<'a, Var: IntegerVariable + 'static>( profile.start < context.upper_bound(&task.start_variable) + task.processing_time }); for profile in &time_table[lower_bound_index..upper_bound_index] { + propagation_handler.next_profile(); // Check whether this profile can cause an update if can_be_updated_by_profile(context.domains(), task, profile, parameters.capacity) { diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/utils/structs/resource_profile.rs b/pumpkin-crates/propagators/src/propagators/cumulative/utils/structs/resource_profile.rs index 9997c8377..975316590 100644 --- a/pumpkin-crates/propagators/src/propagators/cumulative/utils/structs/resource_profile.rs +++ b/pumpkin-crates/propagators/src/propagators/cumulative/utils/structs/resource_profile.rs @@ -27,6 +27,7 @@ impl Debug for ResourceProfile { .field("start", &self.start) .field("end", &self.end) .field("height", &self.height) + .field("profile_tasks", &self.profile_tasks) .finish() } } diff --git a/pumpkin-crates/propagators/src/propagators/disjunctive/checker.rs b/pumpkin-crates/propagators/src/propagators/disjunctive/checker.rs new file mode 100644 index 000000000..f9d18f51f --- /dev/null +++ b/pumpkin-crates/propagators/src/propagators/disjunctive/checker.rs @@ -0,0 +1,496 @@ +use std::cmp::max; +use std::cmp::min; +use std::marker::PhantomData; + +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::IntExt; +use pumpkin_checking::VariableState; +use pumpkin_core::containers::KeyedVec; +use pumpkin_core::containers::StorageKey; +use pumpkin_core::propagation::LocalId; + +use crate::disjunctive::ArgDisjunctiveTask; +use crate::disjunctive::disjunctive_task::DisjunctiveTask; +use crate::disjunctive::theta_lambda_tree::Node; + +#[derive(Clone, Debug)] +pub struct DisjunctiveEdgeFindingChecker { + pub tasks: Box<[ArgDisjunctiveTask]>, +} + +impl InferenceChecker for DisjunctiveEdgeFindingChecker +where + Var: CheckerVariable, + Atomic: AtomicConstraint, +{ + fn check( + &self, + state: VariableState, + _premises: &[Atomic], + consequent: Option<&Atomic>, + ) -> bool { + // Recall the following: + // - For conflict detection, the explanation represents a set omega with the following + // property: `p_omega > lct_omega - est_omega`. + // + // We simply need to check whether the interval [est_omega, lct_omega] is overloaded + // - For propagation, the explanation represents a set omega (and omega') such that the + // following holds: `min(est_i, est_omega) + p_omega + p_i > lct_omega -> [s_i >= + // ect_omega]`. + let mut lb_interval = i32::MAX; + let mut ub_interval = i32::MIN; + let mut p = 0; + let mut propagating_task = None; + let mut theta = Vec::new(); + + // We go over all of the tasks + for task in self.tasks.iter() { + // Only if they are present in the explanation, do we actually process them + // - For tasks in omega, both bounds should be present to define the interval + // - For the propagating task, the lower-bound should be present, and the negation of + // the consequent ensures that an upper-bound is present + if task.start_time.induced_lower_bound(&state) != IntExt::NegativeInf + && task.start_time.induced_upper_bound(&state) != IntExt::PositiveInf + { + // Now we calculate the durations of tasks + let est_task: i32 = task + .start_time + .induced_lower_bound(&state) + .try_into() + .unwrap(); + let lst_task = + >::try_into(task.start_time.induced_upper_bound(&state)) + .unwrap(); + + let is_propagating_task = if let Some(consequent) = consequent { + task.start_time.does_atomic_constrain_self(consequent) + } else { + false + }; + if !is_propagating_task { + theta.push(task.clone()); + p += task.processing_time; + lb_interval = lb_interval.min(est_task); + ub_interval = ub_interval.max(lst_task + task.processing_time); + } else { + propagating_task = Some(task.clone()); + } + } + } + + if consequent.is_some() { + let propagating_task = propagating_task + .expect("If there is a consequent then there should be a propagating task"); + + let est_task = propagating_task + .start_time + .induced_lower_bound(&state) + .try_into() + .unwrap(); + + let mut theta_lambda_tree = CheckerThetaLambdaTree::new( + &theta + .iter() + .enumerate() + .map(|(index, task)| DisjunctiveTask { + start_time: task.start_time.clone(), + processing_time: task.processing_time, + id: LocalId::from(index as u32), + }) + .collect::>(), + ); + theta_lambda_tree.update(&state); + for (index, task) in theta.iter().enumerate() { + theta_lambda_tree.add_to_theta( + &DisjunctiveTask { + start_time: task.start_time.clone(), + processing_time: task.processing_time, + id: LocalId::from(index as u32), + }, + &state, + ); + } + + min(est_task, lb_interval) + p + propagating_task.processing_time > ub_interval + && theta_lambda_tree.ect() > propagating_task.start_time.induced_upper_bound(&state) + } else { + // We simply check whether the interval is overloaded + p > (ub_interval - lb_interval) + } + } +} + +#[cfg(test)] +mod tests { + use pumpkin_checking::TestAtomic; + use pumpkin_checking::VariableState; + + use super::*; + + #[test] + fn test_simple_propagation() { + let premises = [ + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 0, + }, + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 7, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 5, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 6, + }, + TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 0, + }, + ]; + + let consequent = Some(TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 8, + }); + let state = VariableState::prepare_for_conflict_check(premises, consequent) + .expect("no conflicting atomics"); + + let checker = DisjunctiveEdgeFindingChecker { + tasks: vec![ + ArgDisjunctiveTask { + start_time: "x1", + processing_time: 2, + }, + ArgDisjunctiveTask { + start_time: "x2", + processing_time: 3, + }, + ArgDisjunctiveTask { + start_time: "x3", + processing_time: 5, + }, + ] + .into(), + }; + + assert!(checker.check(state, &premises, consequent.as_ref())); + } + + #[test] + fn test_conflict() { + let premises = [ + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 0, + }, + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 1, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 0, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 1, + }, + ]; + + let state = VariableState::prepare_for_conflict_check(premises, None) + .expect("no conflicting atomics"); + + let checker = DisjunctiveEdgeFindingChecker { + tasks: vec![ + ArgDisjunctiveTask { + start_time: "x1", + processing_time: 2, + }, + ArgDisjunctiveTask { + start_time: "x2", + processing_time: 3, + }, + ] + .into(), + }; + + assert!(checker.check(state, &premises, None)); + } + + #[test] + fn test_simple_propagation_not_accepted() { + let premises = [ + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 0, + }, + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 7, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 5, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 6, + }, + TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 0, + }, + ]; + + let consequent = Some(TestAtomic { + name: "x3", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 9, + }); + let state = VariableState::prepare_for_conflict_check(premises, consequent) + .expect("no conflicting atomics"); + + let checker = DisjunctiveEdgeFindingChecker { + tasks: vec![ + ArgDisjunctiveTask { + start_time: "x1", + processing_time: 2, + }, + ArgDisjunctiveTask { + start_time: "x2", + processing_time: 3, + }, + ArgDisjunctiveTask { + start_time: "x3", + processing_time: 5, + }, + ] + .into(), + }; + + assert!(!checker.check(state, &premises, consequent.as_ref())); + } + + #[test] + fn test_conflict_not_accepted() { + let premises = [ + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 0, + }, + TestAtomic { + name: "x1", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 1, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::GreaterEqual, + value: 0, + }, + TestAtomic { + name: "x2", + comparison: pumpkin_checking::Comparison::LessEqual, + value: 2, + }, + ]; + + let state = VariableState::prepare_for_conflict_check(premises, None) + .expect("no conflicting atomics"); + + let checker = DisjunctiveEdgeFindingChecker { + tasks: vec![ + ArgDisjunctiveTask { + start_time: "x1", + processing_time: 2, + }, + ArgDisjunctiveTask { + start_time: "x2", + processing_time: 3, + }, + ] + .into(), + }; + + assert!(!checker.check(state, &premises, None)); + } +} + +#[derive(Debug, Clone)] +pub(super) struct CheckerThetaLambdaTree { + pub(super) nodes: Vec, + /// Then we keep track of a mapping from the [`LocalId`] to its position in the tree since the + /// methods take as input tasks with [`LocalId`]s. + mapping: KeyedVec, + /// The number of internal nodes in the tree; used to calculate the leaf node index based on + /// the index in the tree + number_of_internal_nodes: usize, + /// The tasks which are stored in the leaves of the tree. + /// + /// These tasks are sorted based on non-decreasing start time. + sorted_tasks: Vec>, + phantom_data: PhantomData, +} + +impl> CheckerThetaLambdaTree { + /// Initialises the theta-lambda tree. + /// + /// Note that [`Self::update`] should be called to actually create the tree itself. + pub(super) fn new(tasks: &[DisjunctiveTask]) -> Self { + // Calculate the number of internal nodes which are required to create the binary tree + let mut number_of_internal_nodes = 1; + while number_of_internal_nodes < tasks.len() { + number_of_internal_nodes <<= 1; + } + + CheckerThetaLambdaTree { + nodes: Default::default(), + mapping: KeyedVec::default(), + number_of_internal_nodes: number_of_internal_nodes - 1, + sorted_tasks: tasks.to_vec(), + phantom_data: PhantomData, + } + } + + /// Update the theta-lambda tree based on the provided `context`. + /// + /// It resets theta and lambda to be the empty set. + pub(super) fn update(&mut self, context: &VariableState) { + // First we sort the tasks by lower-bound/earliest start time. + self.sorted_tasks + .sort_by_key(|task| task.start_time.induced_lower_bound(context)); + + // Then we keep track of a mapping from the [`LocalId`] to its position in the tree and a + // reverse mapping + self.mapping.clear(); + for (index, task) in self.sorted_tasks.iter().enumerate() { + while self.mapping.len() <= task.id.index() { + let _ = self.mapping.push(usize::MAX); + } + self.mapping[task.id] = index; + } + + // Finally, we reset the entire tree to be empty + self.nodes.clear(); + for _ in 0..=2 * self.number_of_internal_nodes { + self.nodes.push(Node::empty()) + } + } + + /// Returns the earliest completion time of Theta + pub(super) fn ect(&self) -> i32 { + assert!(!self.nodes.is_empty()); + self.nodes[0].ect + } + + /// Add the provided task to Theta + pub(super) fn add_to_theta( + &mut self, + task: &DisjunctiveTask, + context: &VariableState, + ) { + // We need to find the leaf node index; note that there are |nodes| / 2 leaves + let position = self.nodes.len() / 2 + self.mapping[task.id]; + let ect = task.start_time.induced_lower_bound(context) + task.processing_time; + + self.nodes[position] = Node::new_white_node( + ect.try_into().expect("Should have bounds"), + task.processing_time, + ); + self.upheap(position) + } + + /// Returns the index of the left child of the provided index + fn get_left_child_index(index: usize) -> usize { + 2 * index + 1 + } + + /// Returns the index of the right child of the provided index + fn get_right_child_index(index: usize) -> usize { + 2 * index + 2 + } + + /// Returns the index of the parent of the provided index + fn get_parent(index: usize) -> usize { + assert!(index > 0); + (index - 1) / 2 + } + + /// Calculate the new values for the ancestors of the provided index + pub(super) fn upheap(&mut self, mut index: usize) { + while index != 0 { + let parent = Self::get_parent(index); + let left_child_of_parent = Self::get_left_child_index(parent); + let right_child_of_parent = Self::get_right_child_index(parent); + assert!(left_child_of_parent == index || right_child_of_parent == index); + + // The sum of processing times is the sum of processing times in the left child + the + // sum of processing times in right child + self.nodes[parent].sum_of_processing_times = self.nodes[left_child_of_parent] + .sum_of_processing_times + + self.nodes[right_child_of_parent].sum_of_processing_times; + + // The ECT is either the ECT of the left child node + the processing times of the right + // child or it is the ECT of the right child (we do not know whether the processing + // times of the left child influence the processing times of the right child) + let ect_left = self.nodes[left_child_of_parent].ect + + self.nodes[right_child_of_parent].sum_of_processing_times; + self.nodes[parent].ect = max(self.nodes[right_child_of_parent].ect, ect_left); + + // The sum of processing times (including one element of lambda) is either: + // 1) The sum of processing times of the right child + the sum of processing times of + // the left child including one element of lambda + // 2) The sum of processing times of the left child + the sum of processing times of the + // right child include one element of lambda + let sum_of_processing_times_left_child_lambda = self.nodes[left_child_of_parent] + .sum_of_processing_times_bar + + self.nodes[right_child_of_parent].sum_of_processing_times; + let sum_of_processing_times_right_child_lambda = self.nodes[left_child_of_parent] + .sum_of_processing_times + + self.nodes[right_child_of_parent].sum_of_processing_times_bar; + self.nodes[parent].sum_of_processing_times_bar = max( + sum_of_processing_times_left_child_lambda, + sum_of_processing_times_right_child_lambda, + ); + + // The earliest completion time (including one element of lambda) is either: + // 1) The earliest completion time including one element of lambda from the right child + // 2) The earliest completion time of the right child + the sum of processing times + // including one element of lambda of the right child + // 2) The earliest completion time of the left child + the sum of processing times + // including one element of lambda of the left child + let ect_right_child_lambda = self.nodes[left_child_of_parent].ect + + self.nodes[right_child_of_parent].sum_of_processing_times_bar; + let ect_left_child_lambda = self.nodes[left_child_of_parent].ect_bar + + self.nodes[right_child_of_parent].sum_of_processing_times; + self.nodes[parent].ect_bar = max( + self.nodes[right_child_of_parent].ect_bar, + max(ect_right_child_lambda, ect_left_child_lambda), + ); + + index = parent; + } + } +} diff --git a/pumpkin-crates/propagators/src/propagators/disjunctive/disjunctive_propagator.rs b/pumpkin-crates/propagators/src/propagators/disjunctive/disjunctive_propagator.rs index b34cf84fa..8799b938c 100644 --- a/pumpkin-crates/propagators/src/propagators/disjunctive/disjunctive_propagator.rs +++ b/pumpkin-crates/propagators/src/propagators/disjunctive/disjunctive_propagator.rs @@ -8,6 +8,7 @@ use pumpkin_core::predicates::PropositionalConjunction; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; +use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::PropagationContext; use pumpkin_core::propagation::Propagator; @@ -22,6 +23,7 @@ use pumpkin_core::variables::IntegerVariable; use super::disjunctive_task::ArgDisjunctiveTask; use super::disjunctive_task::DisjunctiveTask; use super::theta_lambda_tree::ThetaLambdaTree; +use crate::disjunctive::checker::DisjunctiveEdgeFindingChecker; use crate::propagators::disjunctive::DisjunctiveEdgeFinding; /// [`Propagator`] responsible for using disjunctive reasoning to propagate the [Disjunctive](https://sofdem.github.io/gccat/gccat/Cdisjunctive.html) constraint. @@ -105,6 +107,22 @@ impl PropagatorConstructor for DisjunctiveConstr inference_code, } } + + fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { + checkers.add_inference_checker( + InferenceCode::new(self.constraint_tag, DisjunctiveEdgeFinding), + Box::new(DisjunctiveEdgeFindingChecker { + tasks: self + .tasks + .iter() + .map(|task| ArgDisjunctiveTask { + start_time: task.start_time.clone(), + processing_time: task.processing_time, + }) + .collect(), + }), + ); + } } impl Propagator for DisjunctivePropagator { diff --git a/pumpkin-crates/propagators/src/propagators/disjunctive/mod.rs b/pumpkin-crates/propagators/src/propagators/disjunctive/mod.rs index 2b2de5801..df24ffae4 100644 --- a/pumpkin-crates/propagators/src/propagators/disjunctive/mod.rs +++ b/pumpkin-crates/propagators/src/propagators/disjunctive/mod.rs @@ -10,5 +10,7 @@ mod theta_tree; pub use disjunctive_propagator::DisjunctiveConstructor; pub use disjunctive_propagator::DisjunctivePropagator; pub use disjunctive_task::ArgDisjunctiveTask; +pub(crate) mod checker; +pub use checker::*; declare_inference_label!(DisjunctiveEdgeFinding); diff --git a/pumpkin-crates/propagators/src/propagators/disjunctive/theta_lambda_tree.rs b/pumpkin-crates/propagators/src/propagators/disjunctive/theta_lambda_tree.rs index b631e7688..e8c8e2834 100644 --- a/pumpkin-crates/propagators/src/propagators/disjunctive/theta_lambda_tree.rs +++ b/pumpkin-crates/propagators/src/propagators/disjunctive/theta_lambda_tree.rs @@ -19,20 +19,20 @@ use super::disjunctive_task::DisjunctiveTask; #[derive(Debug, Clone, PartialEq, Eq)] pub(super) struct Node { /// The earliest completion time of the set of tasks represented by this node. - ect: i32, + pub(super) ect: i32, /// The sum of the processing times of the set of tasks represented by this node. - sum_of_processing_times: i32, + pub(super) sum_of_processing_times: i32, /// The earliest completion time of the set of tasks represented by this node if a single grey /// task can be added to the set of tasks. - ect_bar: i32, + pub(super) ect_bar: i32, /// The sum of processing times of the set of tasks represented by this node if a single grey /// task can be added to the set of tasks. - sum_of_processing_times_bar: i32, + pub(super) sum_of_processing_times_bar: i32, } impl Node { // Constructs an empty node - fn empty() -> Self { + pub(super) fn empty() -> Self { Self { ect: i32::MIN, sum_of_processing_times: 0, @@ -42,7 +42,7 @@ impl Node { } // Construct a new white node with the provided value - fn new_white_node(ect: i32, sum_of_processing_times: i32) -> Self { + pub(super) fn new_white_node(ect: i32, sum_of_processing_times: i32) -> Self { Self { ect, sum_of_processing_times, @@ -52,7 +52,7 @@ impl Node { } // Construct a new gray node with the provided value - fn new_gray_node(ect: i32, sum_of_processing_times: i32) -> Self { + pub(super) fn new_gray_node(ect: i32, sum_of_processing_times: i32) -> Self { Self { ect: i32::MIN, sum_of_processing_times: 0, diff --git a/pumpkin-solver/Cargo.toml b/pumpkin-solver/Cargo.toml index e31497273..498bd1ad2 100644 --- a/pumpkin-solver/Cargo.toml +++ b/pumpkin-solver/Cargo.toml @@ -33,6 +33,7 @@ workspace = true [features] debug-checks = ["pumpkin-core/debug-checks"] +check-propagations = ["pumpkin-core/check-propagations"] [build-dependencies] cc = "1.1.30"