diff --git a/Cargo.lock b/Cargo.lock index d37948010..9b4c97cb1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -75,9 +75,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.98" +version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" [[package]] name = "ar_archive_writer" @@ -166,9 +166,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.40" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" +checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" dependencies = [ "clap_builder", "clap_derive", @@ -176,9 +176,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.40" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" +checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" dependencies = [ "anstream", "anstyle", @@ -188,9 +188,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.40" +version = "4.5.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" +checksum = "2a0b5487afeab2deb2ff4e03a807ad1a03ac532ff5a2cee5d86884440c7f7671" dependencies = [ "heck", "proc-macro2", @@ -717,6 +717,18 @@ dependencies = [ "cc", ] +[[package]] +name = "pumpkin-checker" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "drcp-format", + "flate2", + "fzn-rs", + "thiserror", +] + [[package]] name = "pumpkin-core" version = "0.2.2" @@ -1070,18 +1082,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.12" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.12" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 6e279b2c0..1c1b84625 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["./pumpkin-solver", "./drcp-format", "./pumpkin-solver-py", "./pumpkin-macros", "./drcp-debugger", "./pumpkin-crates/*", "./fzn-rs", "./fzn-rs-derive"] +members = ["./pumpkin-solver", "./pumpkin-checker", "./drcp-format", "./pumpkin-solver-py", "./pumpkin-macros", "./drcp-debugger", "./pumpkin-crates/*", "./fzn-rs", "./fzn-rs-derive"] resolver = "2" [workspace.package] diff --git a/pumpkin-checker/Cargo.toml b/pumpkin-checker/Cargo.toml new file mode 100644 index 000000000..21f817aa5 --- /dev/null +++ b/pumpkin-checker/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "pumpkin-checker" +version = "0.1.0" +repository.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true + +[dependencies] +anyhow = "1.0.99" +clap = { version = "4.5.47", features = ["derive"] } +drcp-format = { version = "0.3.0", path = "../drcp-format" } +flate2 = "1.1.2" +fzn-rs = { version = "0.1.0", path = "../fzn-rs" } +thiserror = "2.0.16" + +[lints] +workspace = true diff --git a/pumpkin-checker/src/deductions.rs b/pumpkin-checker/src/deductions.rs new file mode 100644 index 000000000..d49217e3c --- /dev/null +++ b/pumpkin-checker/src/deductions.rs @@ -0,0 +1,115 @@ +use std::collections::BTreeMap; +use std::rc::Rc; + +use drcp_format::ConstraintId; +use drcp_format::IntAtomic; + +use crate::inferences::Fact; +use crate::model::Nogood; +use crate::state::VariableState; + +/// An inference that was ignored when checking a deduction. +#[derive(Clone, Debug)] +pub struct IgnoredInference { + /// The ID of the ignored inference. + pub constraint_id: ConstraintId, + + /// The premises that were not satisfied when the inference was evaluated. + pub unsatisfied_premises: Vec>, +} + +/// A deduction is rejected by the checker. +#[derive(thiserror::Error, Debug)] +#[error("invalid deduction")] +pub enum InvalidDeduction { + /// The constraint ID of the deduction is already used by an existing constraint. + #[error("constraint id {0} already in use")] + DuplicateConstraintId(ConstraintId), + + /// An inference in the deduction sequence does not exist in the proof stage. + #[error("inference {0} does not exist")] + UnknownInference(ConstraintId), + + /// The inferences in the proof stage do not derive an empty domain or an explicit + /// conflict. + #[error("no conflict was derived after applying all inferences")] + NoConflict(Vec), + + /// The premise contains mutually exclusive atomic constraints. + #[error("the deduction contains inconsistent premises")] + InconsistentPremises, +} + +/// Verify that a deduction is valid given the inferences in the proof stage. +pub fn verify_deduction( + deduction: &drcp_format::Deduction, i32>, + facts_in_proof_stage: &BTreeMap, +) -> Result { + // To verify a deduction, we assume that the premises are true. Then we go over all the + // facts in the sequence, and if all the premises are satisfied, we apply the consequent. + // At some point, this should either reach a fact without a consequent or derive an + // inconsistent domain. + + let mut variable_state = + VariableState::prepare_for_conflict_check(&Fact::nogood(deduction.premises.clone())) + .ok_or(InvalidDeduction::InconsistentPremises)?; + + let mut unused_inferences = Vec::new(); + + for constraint_id in deduction.sequence.iter() { + // Get the fact associated with the constraint ID from the sequence. + let fact = facts_in_proof_stage + .get(constraint_id) + .ok_or(InvalidDeduction::UnknownInference(*constraint_id))?; + + // Collect all premises that do not evaluate to `true` under the current variable + // state. + let unsatisfied_premises: Vec> = fact + .premises + .iter() + .filter_map::, _>(|premise| { + if variable_state.is_true(premise) { + None + } else { + // We need to convert the premise name from a `Rc` to a + // `String`. The former does not implement `Send`, but that is + // required for our error type to be used with anyhow. + Some(IntAtomic { + name: String::from(premise.name.as_ref()), + comparison: premise.comparison, + value: premise.value, + }) + } + }) + .collect::>(); + + // If at least one premise is unassigned, this fact is ignored for the conflict + // check and recorded as unused. + if !unsatisfied_premises.is_empty() { + unused_inferences.push(IgnoredInference { + constraint_id: *constraint_id, + unsatisfied_premises, + }); + + continue; + } + + // At this point the premises are satisfied so we handle the consequent of the + // inference. + match &fact.consequent { + Some(consequent) => { + if !variable_state.apply(consequent) { + // If applying the consequent yields an empty domain for a + // variable, then the deduction is valid. + return Ok(Nogood::from(deduction.premises.clone())); + } + } + // If the consequent is explicitly false, then the deduction is valid. + None => return Ok(Nogood::from(deduction.premises.clone())), + } + } + + // Reaching this point means that the conjunction of inferences did not yield to a + // conflict. Therefore the deduction is invalid. + Err(InvalidDeduction::NoConflict(unused_inferences)) +} diff --git a/pumpkin-checker/src/inferences/all_different.rs b/pumpkin-checker/src/inferences/all_different.rs new file mode 100644 index 000000000..d1dd3681c --- /dev/null +++ b/pumpkin-checker/src/inferences/all_different.rs @@ -0,0 +1,51 @@ +use std::collections::HashSet; + +use super::Fact; +use crate::inferences::InvalidInference; +use crate::model::Constraint; +use crate::state::VariableState; + +/// Verify an `all_different` inference. +/// +/// The checker tests that the premises and the negation of the consequent form a hall-set. If that +/// is the case, the inference is accepted. Otherwise, the inference is rejected. +/// +/// The checker will reject inferences with redundant atomic constraints. +pub(crate) fn verify_all_different( + fact: &Fact, + constraint: &Constraint, +) -> Result<(), InvalidInference> { + // This checker takes the union of the domains of the variables in the constraint. If there + // are fewer values in the union of the domain than there are variables, then there is a + // conflict and the inference is valid. + + let Constraint::AllDifferent(all_different) = constraint else { + return Err(InvalidInference::ConstraintLabelMismatch); + }; + + let variable_state = VariableState::prepare_for_conflict_check(fact) + .ok_or(InvalidInference::InconsistentPremises)?; + + // Collect all values present in at least one of the domains. + let union_of_domains = all_different + .variables + .iter() + .filter_map(|variable| variable_state.iter_domain(variable)) + .flatten() + .collect::>(); + + // Collect the variables mentioned in the fact. Here we ignore variables with a domain + // equal to all integers, as they are not mentioned in the fact. Therefore they do not + // contribute in the hall-set reasoning. + let num_variables = all_different + .variables + .iter() + .filter(|variable| variable_state.iter_domain(variable).is_some()) + .count(); + + if union_of_domains.len() < num_variables { + Ok(()) + } else { + Err(InvalidInference::Unsound) + } +} diff --git a/pumpkin-checker/src/inferences/arithmetic.rs b/pumpkin-checker/src/inferences/arithmetic.rs new file mode 100644 index 000000000..fb3ff9a51 --- /dev/null +++ b/pumpkin-checker/src/inferences/arithmetic.rs @@ -0,0 +1,127 @@ +use std::collections::BTreeSet; +use std::rc::Rc; + +use drcp_format::IntComparison; +use fzn_rs::VariableExpr; + +use super::Fact; +use crate::inferences::InvalidInference; +use crate::model::AllDifferent; +use crate::model::Atomic; +use crate::model::Constraint; +use crate::state::I32Ext; +use crate::state::VariableState; + +/// Verify a `binary_equals` inference. +/// +/// The checker accepts inferences for binary equality constraints. The difference with the general +/// `linear_bounds` inference is that in the binary case, we can certify holes in the domain as +/// well. +pub(crate) fn verify_binary_equals( + fact: &Fact, + constraint: &Constraint, +) -> Result<(), InvalidInference> { + // To check this inference we expect the intersection of both domains to be empty. + + let Constraint::LinearEq(linear) = constraint else { + return Err(InvalidInference::ConstraintLabelMismatch); + }; + + // For now, this inference only works for constraints over two variables. + if linear.terms.len() != 2 { + return Err(InvalidInference::Unsound); + } + + let (weight_a, variable_a) = &linear.terms[0]; + let (weight_b, variable_b) = &linear.terms[1]; + + // TODO: Generalize this rule to work with non-unit weights. + // At the moment we expect one term to have weight `-1` and the other term to have weight + // `1`. + if weight_a + weight_b != 0 || weight_a.abs() != 1 || weight_b.abs() != 1 { + return Err(InvalidInference::Unsound); + } + + let mut variable_state = VariableState::prepare_for_conflict_check(fact) + .ok_or(InvalidInference::InconsistentPremises)?; + + // We apply the domain of variable 2 to variable 1. If the state remains consistent, then + // the step is unsound! + let state_is_consistent = match variable_a { + VariableExpr::Identifier(var1) => { + let mut consistent = true; + + if let I32Ext::I32(value) = variable_state.upper_bound(variable_b) { + consistent &= variable_state.apply(&Atomic { + name: Rc::clone(var1), + comparison: IntComparison::LessEqual, + value: linear.bound + value, + }); + } + + if let I32Ext::I32(value) = variable_state.lower_bound(variable_b) { + consistent &= variable_state.apply(&Atomic { + name: Rc::clone(var1), + comparison: IntComparison::GreaterEqual, + value: linear.bound + value, + }); + } + + for value in variable_state.holes(variable_b).collect::>() { + consistent &= variable_state.apply(&Atomic { + name: Rc::clone(var1), + comparison: IntComparison::NotEqual, + value: linear.bound + value, + }); + } + + consistent + } + + VariableExpr::Constant(value) => match variable_b { + VariableExpr::Identifier(var2) => variable_state.apply(&Atomic { + name: Rc::clone(var2), + comparison: IntComparison::NotEqual, + value: linear.bound + *value, + }), + VariableExpr::Constant(_) => panic!("Binary equals over two constants is unexpected."), + }, + }; + + if state_is_consistent { + // The intersection of the domains should yield an inconsistent state for the + // inference to be sound. + return Err(InvalidInference::Unsound); + } + + Ok(()) +} + +/// Verify a `binary_not_equals` inference. +/// +/// Tests that the premise of the inference and the negation of the consequent force the linear sum +/// to equal the right-hand side of the not equals constraint. +pub(crate) fn verify_binary_not_equals( + fact: &Fact, + constraint: &Constraint, +) -> Result<(), InvalidInference> { + let Constraint::AllDifferent(AllDifferent { variables }) = constraint else { + return Err(InvalidInference::ConstraintLabelMismatch); + }; + + let variable_state = VariableState::prepare_for_conflict_check(fact) + .ok_or(InvalidInference::InconsistentPremises)?; + + let mut values = BTreeSet::new(); + for variable in variables { + let Some(value) = variable_state.fixed_value(variable) else { + continue; + }; + + if !values.insert(value) { + return Ok(()); + } + } + + Err(InvalidInference::Unsound) +} diff --git a/pumpkin-checker/src/inferences/linear.rs b/pumpkin-checker/src/inferences/linear.rs new file mode 100644 index 000000000..44996d80c --- /dev/null +++ b/pumpkin-checker/src/inferences/linear.rs @@ -0,0 +1,114 @@ +use crate::inferences::Fact; +use crate::inferences::InvalidInference; +use crate::model::Constraint; +use crate::model::Linear; +use crate::state::I32Ext; +use crate::state::VariableState; + +/// Verify a `linear_bounds` inference. +/// +/// The inference is sound for linear inequalites and linear equalities. +pub(super) fn verify_linear_bounds( + fact: &Fact, + generated_by: &Constraint, +) -> Result<(), InvalidInference> { + match generated_by { + Constraint::LinearLeq(linear) => verify_linear_inference(linear, fact), + + Constraint::LinearEq(linear) => { + let try_upper_bound = verify_linear_inference(linear, fact); + + let inverted_linear = Linear { + terms: linear + .terms + .iter() + .map(|(weight, variable)| (-weight, variable.clone())) + .collect(), + bound: -linear.bound, + }; + let try_lower_bound = verify_linear_inference(&inverted_linear, fact); + + match (try_lower_bound, try_upper_bound) { + (Ok(_), Ok(_)) => panic!("This should not happen."), + (Ok(fact), Err(_)) | (Err(_), Ok(fact)) => Ok(fact), + (Err(_), Err(_)) => Err(InvalidInference::Unsound), + } + } + + _ => Err(InvalidInference::ConstraintLabelMismatch), + } +} + +fn verify_linear_inference(linear: &Linear, fact: &Fact) -> Result<(), InvalidInference> { + let variable_state = VariableState::prepare_for_conflict_check(fact) + .ok_or(InvalidInference::InconsistentPremises)?; + + // Next, we evaluate the linear inequality. The lower bound of the + // left-hand side must exceed the bound in the constraint. + let left_hand_side = linear.terms.iter().fold(None, |acc, (weight, variable)| { + let lower_bound = if *weight >= 0 { + variable_state.lower_bound(variable) + } else { + variable_state.upper_bound(variable) + }; + + match acc { + None => match lower_bound { + I32Ext::I32(value) => Some(weight * value), + I32Ext::NegativeInf => None, + I32Ext::PositiveInf => None, + }, + + Some(v1) => match lower_bound { + I32Ext::I32(v2) => Some(v1 + weight * v2), + I32Ext::NegativeInf => Some(v1), + I32Ext::PositiveInf => Some(v1), + }, + } + }); + + if left_hand_side.is_some_and(|value| value > linear.bound) { + Ok(()) + } else { + Err(InvalidInference::Unsound) + } +} + +#[cfg(test)] +mod tests { + use drcp_format::IntComparison::*; + use fzn_rs::VariableExpr::*; + + use super::*; + use crate::model::Atomic; + + #[test] + fn linear_1() { + // x1 - x2 <= -7 + let linear = Linear { + terms: vec![(1, Identifier("x1".into())), (-1, Identifier("x2".into()))], + bound: -7, + }; + + let premises = vec![Atomic { + name: "x2".into(), + comparison: LessEqual, + value: 37, + }]; + + let consequent = Some(Atomic { + name: "x1".into(), + comparison: LessEqual, + value: 30, + }); + + verify_linear_inference( + &linear, + &Fact { + premises, + consequent, + }, + ) + .expect("valid inference"); + } +} diff --git a/pumpkin-checker/src/inferences/mod.rs b/pumpkin-checker/src/inferences/mod.rs new file mode 100644 index 000000000..61e5416a0 --- /dev/null +++ b/pumpkin-checker/src/inferences/mod.rs @@ -0,0 +1,128 @@ +mod all_different; +mod arithmetic; +mod linear; +mod nogood; +mod time_table; + +use crate::model::Atomic; +use crate::model::Model; + +#[derive(Clone, Debug)] +pub struct Fact { + pub premises: Vec, + pub consequent: Option, +} + +impl Fact { + /// Create a fact `premises -> false`. + pub fn nogood(premises: Vec) -> Self { + Fact { + premises, + consequent: None, + } + } +} + +/// The reasons an inference can be rejected. +#[derive(Clone, Copy, thiserror::Error, Debug)] +#[error("invalid inference")] +pub enum InvalidInference { + /// The inference is not annotated with a label when we expect it to be. + #[error("inference does not have a label")] + MissingLabel, + + /// The label of the inference is not recognized by the checker. + #[error("inference is not supported")] + UnsupportedLabel, + + /// The inference label is not sound for the constraint that generated the inference. + #[error("indicated constraint cannot generate inferences with the indicated label")] + ConstraintLabelMismatch, + + /// The constraint that generated the inference does not exist in the model. + #[error("generated by undefined constraint")] + UndefinedConstraint, + + /// The inference does not state which constraint generated it. + #[error("missing constraint hint")] + MissingConstraint, + + /// The premises of the inference are inconsistent. + #[error("inconsistent premises")] + InconsistentPremises, + + /// The inference is unsound for the constraint. + #[error("inference is unsound")] + Unsound, +} + +pub(crate) fn verify_inference( + model: &Model, + inference: &drcp_format::Inference, i32, std::rc::Rc>, +) -> Result { + let fact = Fact { + premises: inference.premises.clone(), + consequent: inference.consequent.clone(), + }; + + let label = inference + .label + .as_ref() + .map(|label| label.as_ref()) + .ok_or(InvalidInference::MissingLabel)?; + + // The initial domain inference is handled separately since it is the only inference that + // does not expect a constraint hint. + if label == "initial_domain" { + let Some(atomic) = inference.consequent.clone() else { + // The initial domain inference requires a consequent. + return Err(InvalidInference::Unsound); + }; + + if !model.is_trivially_true(atomic.clone()) { + // If the consequent is not trivially true in the model then the inference + // is unsound. + return Err(InvalidInference::Unsound); + } + + return Ok(fact); + } + + // Get the constraint that generated the inference from the model. + let generated_by_constraint_id = inference + .generated_by + .ok_or(InvalidInference::MissingConstraint)?; + let generated_by = model + .get_constraint(generated_by_constraint_id) + .ok_or(InvalidInference::MissingConstraint)?; + + match label { + "linear_bounds" => { + linear::verify_linear_bounds(&fact, generated_by)?; + } + + "nogood" => { + nogood::verify_nogood(&fact, generated_by)?; + } + + "time_table" => { + time_table::verify_time_table(&fact, generated_by)?; + } + + "all_different" => { + all_different::verify_all_different(&fact, generated_by)?; + } + + "binary_equals" => { + arithmetic::verify_binary_equals(&fact, generated_by)?; + } + + "binary_not_equals" => { + arithmetic::verify_binary_not_equals(&fact, generated_by)?; + } + + _ => return Err(InvalidInference::UnsupportedLabel), + } + + Ok(fact) +} diff --git a/pumpkin-checker/src/inferences/nogood.rs b/pumpkin-checker/src/inferences/nogood.rs new file mode 100644 index 000000000..c715ad162 --- /dev/null +++ b/pumpkin-checker/src/inferences/nogood.rs @@ -0,0 +1,24 @@ +use crate::inferences::Fact; +use crate::inferences::InvalidInference; +use crate::model::Constraint; +use crate::state::VariableState; + +/// Verifies a `nogood` inference. +/// +/// This inference is used to rewrite a nogood `L /\ p -> false` to `L -> not p`. +pub(crate) fn verify_nogood(fact: &Fact, constraint: &Constraint) -> Result<(), InvalidInference> { + let Constraint::Nogood(nogood) = constraint else { + return Err(InvalidInference::ConstraintLabelMismatch); + }; + + let variable_state = VariableState::prepare_for_conflict_check(fact) + .ok_or(InvalidInference::InconsistentPremises)?; + + let is_implied_by_nogood = nogood.iter().all(|atomic| variable_state.is_true(atomic)); + + if is_implied_by_nogood { + Ok(()) + } else { + Err(InvalidInference::Unsound) + } +} diff --git a/pumpkin-checker/src/inferences/time_table.rs b/pumpkin-checker/src/inferences/time_table.rs new file mode 100644 index 000000000..4a5c87842 --- /dev/null +++ b/pumpkin-checker/src/inferences/time_table.rs @@ -0,0 +1,48 @@ +use std::collections::BTreeMap; + +use super::Fact; +use crate::inferences::InvalidInference; +use crate::model::Constraint; +use crate::state::VariableState; + +/// Verifies a `time_table` inference for the cumulative constraint. +/// +/// The premises and negation of the consequent should lead to an overflow of the resource +/// capacity. +pub(crate) fn verify_time_table( + fact: &Fact, + constraint: &Constraint, +) -> Result<(), InvalidInference> { + let Constraint::Cumulative(cumulative) = constraint else { + return Err(InvalidInference::ConstraintLabelMismatch); + }; + + let variable_state = VariableState::prepare_for_conflict_check(fact) + .ok_or(InvalidInference::InconsistentPremises)?; + + // The profile is a key-value store. The keys correspond to time-points, and the values to the + // relative change in resource consumption. A BTreeMap is used to maintain a sorted order of + // the time points. + let mut profile = BTreeMap::new(); + + for task in cumulative.tasks.iter() { + let lst = variable_state.upper_bound(&task.start_time); + let ect = variable_state.lower_bound(&task.start_time) + task.duration; + + if ect <= lst { + *profile.entry(ect).or_insert(0) += task.resource_usage; + *profile.entry(lst).or_insert(0) -= task.resource_usage; + } + } + + let mut usage = 0; + for delta in profile.values() { + usage += delta; + + if usage > cumulative.capacity { + return Ok(()); + } + } + + Err(InvalidInference::Unsound) +} diff --git a/pumpkin-checker/src/lib.rs b/pumpkin-checker/src/lib.rs new file mode 100644 index 000000000..4cfee913e --- /dev/null +++ b/pumpkin-checker/src/lib.rs @@ -0,0 +1,130 @@ +use std::collections::BTreeMap; +use std::io::BufRead; +use std::rc::Rc; + +use drcp_format::ConstraintId; +use drcp_format::reader::ProofReader; + +pub mod deductions; +pub mod inferences; +mod state; + +pub mod model; + +use model::*; + +/// The errors that can be returned by the checker. +#[derive(Debug, thiserror::Error)] +pub enum CheckError { + /// The inference with the given [`ConstraintId`] is invalid due to + /// [`inferences::InvalidInference`]. + #[error("inference {0} is invalid: {1}")] + InvalidInference(ConstraintId, inferences::InvalidInference), + + /// The inference with the given [`ConstraintId`] is invalid due to + /// [`deductions::InvalidDeduction`]. + #[error("deduction {0} is invalid: {1}")] + InvalidDeduction(ConstraintId, deductions::InvalidDeduction), + + /// The proof did not contain a conclusion line. + #[error("the proof was not terminated with a conclusion")] + MissingConclusion, + + /// The conclusion does not follow from any deduction in the proof. + #[error("the conclusion is not present as a proof step")] + InvalidConclusion, + + /// An I/O error prevented us from reading all the input. + #[error("failed to read next proof line: {0}")] + ProofReadError(#[from] drcp_format::reader::Error), +} + +/// Verify whether the given proof is valid w.r.t. the model. +pub fn verify_proof( + mut model: Model, + mut proof: ProofReader, +) -> Result<(), CheckError> { + // To check a proof we iterate over every step. + // - If the step is an inference, it is checked. If it is valid, then the inference is stored in + // the fact database to be used in the next deduction. Otherwise, an error is returned + // indicating that the step is invalid. + // - If the step is a deduction, it is checked with respect to all the inferences in the fact + // database. If the deduction is valid, the fact database is cleared and the deduction is + // added to the model to be used in future inferences. + + let mut fact_database = BTreeMap::new(); + + loop { + let next_step = proof.next_step()?; + + let Some(step) = next_step else { + // The loop stops when a conclusion is found, so at this point we know the proof does + // not contain a conclusion. + return Err(CheckError::MissingConclusion); + }; + + match step { + drcp_format::Step::Inference(inference) => { + let fact = inferences::verify_inference(&model, &inference) + .map_err(|err| CheckError::InvalidInference(inference.constraint_id, err))?; + + let _ = fact_database.insert(inference.constraint_id, fact); + } + + drcp_format::Step::Deduction(deduction) => { + let derived_constraint = deductions::verify_deduction(&deduction, &fact_database) + .map_err(|err| { + CheckError::InvalidDeduction(deduction.constraint_id, err) + })?; + + let new_constraint_added = model.add_constraint( + deduction.constraint_id, + Constraint::Nogood(derived_constraint), + ); + + if !new_constraint_added { + return Err(CheckError::InvalidDeduction( + deduction.constraint_id, + deductions::InvalidDeduction::DuplicateConstraintId( + deduction.constraint_id, + ), + )); + } + + // Forget the stored inferences. + fact_database.clear(); + } + + drcp_format::Step::Conclusion(conclusion) => { + if verify_conclusion(&model, &conclusion) { + return Ok(()); + } else { + return Err(CheckError::InvalidConclusion); + } + } + } + } +} + +fn verify_conclusion(model: &Model, conclusion: &drcp_format::Conclusion, i32>) -> bool { + // First we ensure the conclusion type matches the solve item in the model. + match (&model.objective, conclusion) { + (Some(_), drcp_format::Conclusion::Unsat) + | (None, drcp_format::Conclusion::DualBound(_)) => return false, + + _ => {} + } + + // We iterate in reverse order, since it is likely that the conclusion is based on a constraint + // towards the end of the proof. + model.iter_constraints().rev().any(|(_, constraint)| { + let Constraint::Nogood(nogood) = constraint else { + return false; + }; + + match conclusion { + drcp_format::Conclusion::Unsat => nogood.as_ref().is_empty(), + drcp_format::Conclusion::DualBound(atomic) => nogood.as_ref() == [!atomic.clone()], + } + }) +} diff --git a/pumpkin-checker/src/main.rs b/pumpkin-checker/src/main.rs new file mode 100644 index 000000000..90f2f7fc7 --- /dev/null +++ b/pumpkin-checker/src/main.rs @@ -0,0 +1,268 @@ +use std::fs::File; +use std::io::BufRead; +use std::io::BufReader; +use std::num::NonZero; +use std::path::Path; +use std::path::PathBuf; +use std::rc::Rc; +use std::time::Instant; + +use clap::Parser; +use drcp_format::reader::ProofReader; +use pumpkin_checker::CheckError; +use pumpkin_checker::deductions::IgnoredInference; +use pumpkin_checker::deductions::InvalidDeduction; +use pumpkin_checker::model::Model; +use pumpkin_checker::model::Objective; +use pumpkin_checker::model::Task; + +#[derive(Parser)] +struct Cli { + /// Path to the model file (.fzn). + model_path: PathBuf, + + /// Path to the proof file. + /// + /// If the path ends in `.gz`, we assume it is GZipped and the checker will unzip the file + /// on-the-fly. + proof_path: PathBuf, +} + +fn main() -> anyhow::Result<()> { + let cli = Cli::parse(); + + let parse_start = Instant::now(); + let model = parse_model(&cli.model_path)?; + println!("parse-flatzinc: {}s", parse_start.elapsed().as_secs_f32()); + + let proof_reader = create_proof_reader(&cli.proof_path)?; + println!("parse-proof: 0s"); + + let verify_start = Instant::now(); + + pumpkin_checker::verify_proof(model, proof_reader).inspect_err(|err| { + print_check_error_info(err); + println!("validate: {}s", verify_start.elapsed().as_secs_f32()); + })?; + + println!("validate: {}s", verify_start.elapsed().as_secs_f32()); + + println!("Proof is valid!"); + + Ok(()) +} + +/// If the error is an invalid deduction, here we print additional info why the deduction is +/// invalid. In particular, it prints any inferences which were ignored because the premise was not +/// satisfied. +fn print_check_error_info(error: &CheckError) { + let CheckError::InvalidDeduction( + constraint_id, + InvalidDeduction::NoConflict(unused_inferences), + ) = error + else { + return; + }; + + eprintln!("Deduction {constraint_id} is invalid."); + + if unused_inferences.is_empty() { + eprintln!(" Failed to derive conflict after applying all inferences."); + } else { + eprintln!(" Could not apply the following inferences:"); + + for unused_inference in unused_inferences { + let IgnoredInference { + constraint_id, + unsatisfied_premises, + } = unused_inference; + + eprint!(" - {constraint_id}:"); + + for premise in unsatisfied_premises { + eprint!(" {premise}"); + } + + eprintln!(); + } + } +} + +/// The constraints supported by the checker. +#[derive(Debug, fzn_rs::FlatZincConstraint)] +enum FlatZincConstraints { + #[name("int_lin_le")] + LinearLeq { + weights: fzn_rs::ArrayExpr, + variables: fzn_rs::ArrayExpr>, + bound: i32, + }, + #[name("int_lin_eq")] + LinearEq { + weights: fzn_rs::ArrayExpr, + variables: fzn_rs::ArrayExpr>, + bound: i32, + }, + #[name("pumpkin_cumulative")] + Cumulative { + start_times: fzn_rs::ArrayExpr>, + durations: fzn_rs::ArrayExpr, + resource_usages: fzn_rs::ArrayExpr, + capacity: i32, + }, + #[name("pumpkin_all_different")] + AllDifferent(fzn_rs::ArrayExpr>), +} + +type FlatZincModel = fzn_rs::TypedInstance; + +/// Parse a FlatZinc file to a checker [`Model`]. +fn parse_model(path: impl AsRef) -> anyhow::Result { + let model_source = std::fs::read_to_string(path)?; + + // TODO: For now the error handling shortcuts here. Ideally the `FznError` type returns + // something that can be converted to an owned type, but for now we have to work around the + // error holding a reference to the source. + let fzn_ast = fzn_rs::fzn::parse(&model_source).map_err(|err| anyhow::anyhow!("{err}"))?; + + let fzn_model = FlatZincModel::from_ast(fzn_ast)?; + + let mut model = Model::default(); + model.objective = match &fzn_model.solve.method.node { + fzn_rs::Method::Satisfy => None, + fzn_rs::Method::Optimize { + direction: fzn_rs::ast::OptimizationDirection::Minimize, + objective, + } => Some(Objective::Minimize(objective.clone())), + fzn_rs::Method::Optimize { + direction: fzn_rs::ast::OptimizationDirection::Maximize, + objective, + } => Some(Objective::Maximize(objective.clone())), + }; + + for (name, variable) in fzn_model.variables.iter() { + model.add_variable(Rc::clone(name), variable.domain.node.clone()); + } + + for (idx, annotated_constraint) in fzn_model.constraints.iter().enumerate() { + let constraint_id = NonZero::new(idx as u32 + 1).expect( + "we always add one, and idx is at least zero, constraint_id is always non-zero", + ); + + let constraint = match &annotated_constraint.constraint.node { + FlatZincConstraints::LinearLeq { + weights, + variables, + bound, + } => { + let weights = fzn_model.resolve_array(weights)?; + let variables = fzn_model.resolve_array(variables)?; + + let mut terms = vec![]; + + for (weight, variable) in weights.zip(variables) { + let weight = weight?; + let variable = variable?; + + terms.push((weight, variable)); + } + + pumpkin_checker::model::Constraint::LinearLeq(pumpkin_checker::model::Linear { + terms, + bound: *bound, + }) + } + + FlatZincConstraints::LinearEq { + weights, + variables, + bound, + } => { + let weights = fzn_model.resolve_array(weights)?; + let variables = fzn_model.resolve_array(variables)?; + + let mut terms = vec![]; + + for (weight, variable) in weights.zip(variables) { + let weight = weight?; + let variable = variable?; + + terms.push((weight, variable)); + } + + pumpkin_checker::model::Constraint::LinearEq(pumpkin_checker::model::Linear { + terms, + bound: *bound, + }) + } + + FlatZincConstraints::Cumulative { + start_times, + durations, + resource_usages, + capacity, + } => { + let start_times = fzn_model.resolve_array(start_times)?; + let durations = fzn_model.resolve_array(durations)?; + let resource_usages = fzn_model.resolve_array(resource_usages)?; + + let tasks = start_times + .zip(durations) + .zip(resource_usages) + .map( + |((maybe_start_time, maybe_duration), maybe_resource_usage)| { + let start_time = maybe_start_time?; + let duration = maybe_duration?; + let resource_usage = maybe_resource_usage?; + + Ok(Task { + start_time, + duration, + resource_usage, + }) + }, + ) + .collect::, fzn_rs::InstanceError>>()?; + + pumpkin_checker::model::Constraint::Cumulative(pumpkin_checker::model::Cumulative { + tasks, + capacity: *capacity, + }) + } + + FlatZincConstraints::AllDifferent(variables) => { + let variables = fzn_model + .resolve_array(variables)? + .collect::, _>>()?; + + pumpkin_checker::model::Constraint::AllDifferent( + pumpkin_checker::model::AllDifferent { variables }, + ) + } + }; + + let _ = model.add_constraint(constraint_id, constraint); + } + + Ok(model) +} + +/// Create a reader for the proof file. +/// +/// GZipped proofs are decompressed on-demand. +fn create_proof_reader( + path: impl AsRef, +) -> anyhow::Result, i32>> { + let file = File::open(path.as_ref())?; + + if path.as_ref().extension().is_some_and(|ext| ext == "gz") { + let decoder = flate2::read::GzDecoder::new(file); + let buf_reader = BufReader::new(decoder); + + Ok(ProofReader::new(Box::new(buf_reader))) + } else { + let buf_reader = BufReader::new(file); + + Ok(ProofReader::new(Box::new(buf_reader))) + } +} diff --git a/pumpkin-checker/src/model.rs b/pumpkin-checker/src/model.rs new file mode 100644 index 000000000..722e1d9da --- /dev/null +++ b/pumpkin-checker/src/model.rs @@ -0,0 +1,146 @@ +//! Defines what models the checker can check proofs for. +//! +//! The main component of the model are the constraints that the checker supports. + +use std::collections::BTreeMap; +use std::ops::Deref; +use std::rc::Rc; + +use drcp_format::ConstraintId; +use drcp_format::IntAtomic; +use fzn_rs::VariableExpr; +use fzn_rs::ast::Domain; + +#[derive(Clone, Debug)] +pub enum Constraint { + Nogood(Nogood), + LinearLeq(Linear), + LinearEq(Linear), + Cumulative(Cumulative), + AllDifferent(AllDifferent), +} + +pub type Atomic = IntAtomic, i32>; + +#[derive(Clone, Debug)] +pub struct Nogood(Vec); + +impl From for Nogood +where + T: IntoIterator, +{ + fn from(value: T) -> Self { + Nogood(value.into_iter().collect()) + } +} + +impl Deref for Nogood { + type Target = [Atomic]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Clone, Debug)] +pub struct Linear { + pub terms: Vec<(i32, VariableExpr)>, + pub bound: i32, +} + +#[derive(Clone, Debug)] +pub struct Task { + pub start_time: VariableExpr, + pub duration: i32, + pub resource_usage: i32, +} + +#[derive(Clone, Debug)] +pub struct Cumulative { + pub tasks: Vec, + pub capacity: i32, +} + +#[derive(Clone, Debug)] +pub struct AllDifferent { + pub variables: Vec>, +} + +#[derive(Clone, Debug)] +pub enum Objective { + Maximize(VariableExpr), + Minimize(VariableExpr), +} + +#[derive(Clone, Debug, Default)] +pub struct Model { + variables: BTreeMap, Domain>, + constraints: BTreeMap, + pub objective: Option, +} + +impl Model { + /// Add a new variable to the model. + pub fn add_variable(&mut self, name: Rc, domain: Domain) { + let _ = self.variables.insert(name, domain); + } + + /// Add a new constraint to the model. + /// + /// If a constraint with the given ID already exists, this returns false. Otherwise, the + /// function returns true. + pub fn add_constraint(&mut self, constraint_id: ConstraintId, constraint: Constraint) -> bool { + self.constraints.insert(constraint_id, constraint).is_none() + } + + /// Iterate over the constraints in the map, ordered by [`ConstraintId`]. + pub fn iter_constraints( + &self, + ) -> std::collections::btree_map::Iter<'_, ConstraintId, Constraint> { + self.constraints.iter() + } + + /// Get the constraint with the given ID if it exists. + pub fn get_constraint(&self, constraint_id: ConstraintId) -> Option<&Constraint> { + self.constraints.get(&constraint_id) + } + + /// Test whether the atomic is true in the initial domains of the variables. + /// + /// Returns false if the atomic is over a variable that is not in the model. + pub fn is_trivially_true(&self, atomic: Atomic) -> bool { + let Some(domain) = self.variables.get(&atomic.name) else { + return false; + }; + + match domain { + Domain::UnboundedInt => false, + Domain::Int(dom) => match atomic.comparison { + drcp_format::IntComparison::GreaterEqual => { + *dom.lower_bound() >= atomic.value as i64 + } + drcp_format::IntComparison::LessEqual => *dom.upper_bound() <= atomic.value as i64, + drcp_format::IntComparison::Equal => { + *dom.lower_bound() >= atomic.value as i64 + && *dom.upper_bound() <= atomic.value as i64 + } + drcp_format::IntComparison::NotEqual => { + if *dom.lower_bound() >= atomic.value as i64 { + return true; + } + + if *dom.upper_bound() <= atomic.value as i64 { + return true; + } + + if dom.is_continuous() { + return false; + } + + dom.into_iter().all(|value| value != atomic.value as i64) + } + }, + Domain::Bool => todo!("boolean variables are not yet supported"), + } + } +} diff --git a/pumpkin-checker/src/state.rs b/pumpkin-checker/src/state.rs new file mode 100644 index 000000000..a37c801ed --- /dev/null +++ b/pumpkin-checker/src/state.rs @@ -0,0 +1,514 @@ +use std::cmp::Ordering; +use std::collections::BTreeMap; +use std::collections::BTreeSet; +use std::ops::Add; +use std::rc::Rc; + +use crate::inferences::Fact; +use crate::model::Atomic; + +/// The domains of all variables in the problem. +/// +/// Domains can be reduced through [`VariableState::apply`]. By default, the domain of every +/// variable is infinite. +#[derive(Clone, Debug, Default)] +pub(crate) struct VariableState { + domains: BTreeMap, Domain>, +} + +impl VariableState { + /// Create a variable state that applies all the premises and, if present, the negation of the + /// consequent. + /// + /// Used by inference checkers if they want to identify a conflict by negating the consequent. + pub(crate) fn prepare_for_conflict_check(fact: &Fact) -> Option { + let mut variable_state = VariableState::default(); + + let negated_consequent = fact + .consequent + .as_ref() + .map(|consequent| !consequent.clone()); + + // Apply all the premises and the negation of the consequent to the state. + if !fact + .premises + .iter() + .chain(negated_consequent.as_ref()) + .all(|premise| variable_state.apply(premise)) + { + return None; + } + + Some(variable_state) + } + + /// Get the lower bound of a variable. + pub(crate) fn lower_bound(&self, variable: &fzn_rs::VariableExpr) -> I32Ext { + let name = match variable { + fzn_rs::VariableExpr::Identifier(name) => name, + fzn_rs::VariableExpr::Constant(value) => return I32Ext::I32(*value), + }; + + self.domains + .get(name) + .map(|domain| domain.lower_bound) + .unwrap_or(I32Ext::NegativeInf) + } + + /// Get the upper bound of a variable. + pub(crate) fn upper_bound(&self, variable: &fzn_rs::VariableExpr) -> I32Ext { + let name = match variable { + fzn_rs::VariableExpr::Identifier(name) => name, + fzn_rs::VariableExpr::Constant(value) => return I32Ext::I32(*value), + }; + + self.domains + .get(name) + .map(|domain| domain.upper_bound) + .unwrap_or(I32Ext::PositiveInf) + } + + /// Get the holes within the lower and upper bound of the variable expression. + pub(crate) fn holes( + &self, + variable: &fzn_rs::VariableExpr, + ) -> impl Iterator + '_ { + #[allow(trivial_casts, reason = "without it we get a type error")] + let name = match variable { + fzn_rs::VariableExpr::Identifier(name) => name, + fzn_rs::VariableExpr::Constant(_) => { + return Box::new(std::iter::empty()) as Box>; + } + }; + + #[allow(trivial_casts, reason = "without it we get a type error")] + self.domains + .get(name) + .map(|domain| Box::new(domain.holes.iter().copied()) as Box>) + .unwrap_or_else(|| Box::new(std::iter::empty())) + } + + /// Get the fixed value of this variable, if it is fixed. + pub(crate) fn fixed_value(&self, variable: &fzn_rs::VariableExpr) -> Option { + let name = match variable { + fzn_rs::VariableExpr::Identifier(name) => name, + fzn_rs::VariableExpr::Constant(value) => return Some(*value), + }; + + let domain = self.domains.get(name)?; + + if domain.lower_bound == domain.upper_bound { + let I32Ext::I32(value) = domain.lower_bound else { + panic!( + "lower can only equal upper if they are integers, otherwise the sign of infinity makes them different" + ); + }; + + Some(value) + } else { + None + } + } + + /// Obtain an iterator over the domain of the variable. + /// + /// If the domain is unbounded, then `None` is returned. + pub(crate) fn iter_domain( + &self, + variable: &fzn_rs::VariableExpr, + ) -> Option> { + match variable { + fzn_rs::VariableExpr::Identifier(name) => { + let domain = self.domains.get(name)?; + + let I32Ext::I32(lower_bound) = domain.lower_bound else { + // If there is no lower bound, then the domain is unbounded. + return None; + }; + + // Ensure there is also an upper bound. + if !matches!(domain.upper_bound, I32Ext::I32(_)) { + return None; + } + + Some(DomainIterator(DomainIteratorImpl::Domain { + domain, + next_value: lower_bound, + })) + } + + fzn_rs::VariableExpr::Constant(value) => { + Some(DomainIterator(DomainIteratorImpl::Constant { + value: *value, + finished: false, + })) + } + } + } + + /// Apply the given [`Atomic`] to the state. + /// + /// Returns true if the state remains consistent, or false if the atomic cannot be true in + /// conjunction with previously applied atomics. + pub(crate) fn apply(&mut self, atomic: &Atomic) -> bool { + let domain = self + .domains + .entry(Rc::clone(&atomic.name)) + .or_insert(Domain::new()); + + match atomic.comparison { + drcp_format::IntComparison::GreaterEqual => { + domain.tighten_lower_bound(atomic.value); + } + + drcp_format::IntComparison::LessEqual => { + domain.tighten_upper_bound(atomic.value); + } + + drcp_format::IntComparison::Equal => { + domain.tighten_lower_bound(atomic.value); + domain.tighten_upper_bound(atomic.value); + } + + drcp_format::IntComparison::NotEqual => { + if domain.lower_bound == atomic.value { + domain.tighten_lower_bound(atomic.value + 1); + } + + if domain.upper_bound == atomic.value { + domain.tighten_upper_bound(atomic.value - 1); + } + + if domain.lower_bound < atomic.value && domain.upper_bound > atomic.value { + let _ = domain.holes.insert(atomic.value); + } + } + } + + domain.is_consistent() + } + + /// Is the given atomic true in the current state. + pub(crate) fn is_true(&self, atomic: &Atomic) -> bool { + let Some(domain) = self.domains.get(&atomic.name) else { + return false; + }; + + match atomic.comparison { + drcp_format::IntComparison::GreaterEqual => domain.lower_bound >= atomic.value, + + drcp_format::IntComparison::LessEqual => domain.upper_bound <= atomic.value, + + drcp_format::IntComparison::Equal => { + domain.lower_bound >= atomic.value && domain.upper_bound <= atomic.value + } + + drcp_format::IntComparison::NotEqual => { + if domain.lower_bound >= atomic.value { + return true; + } + + if domain.upper_bound <= atomic.value { + return true; + } + + if domain.holes.contains(&atomic.value) { + return true; + } + + false + } + } + } +} + +#[derive(Clone, Debug)] +struct Domain { + lower_bound: I32Ext, + upper_bound: I32Ext, + holes: BTreeSet, +} + +impl Domain { + fn new() -> Domain { + Domain { + lower_bound: I32Ext::NegativeInf, + upper_bound: I32Ext::PositiveInf, + holes: BTreeSet::default(), + } + } + + fn tighten_lower_bound(&mut self, bound: i32) { + if self.lower_bound >= bound { + return; + } + + self.lower_bound = I32Ext::I32(bound); + self.holes = self.holes.split_off(&bound); + + // Take care of the condition where the new bound is already a hole in the domain. + if self.holes.contains(&bound) { + self.tighten_lower_bound(bound + 1); + } + } + + fn tighten_upper_bound(&mut self, bound: i32) { + if self.upper_bound <= bound { + return; + } + + self.upper_bound = I32Ext::I32(bound); + + // Note the '+ 1' to keep the elements <= the upper bound instead of < + // the upper bound. + let _ = self.holes.split_off(&(bound + 1)); + + // Take care of the condition where the new bound is already a hole in the domain. + if self.holes.contains(&bound) { + self.tighten_upper_bound(bound - 1); + } + } + + fn is_consistent(&self) -> bool { + // No need to check holes, as the invariant of `Domain` specifies the bounds are as tight + // as possible, taking holes into account. + + self.lower_bound <= self.upper_bound + } +} + +/// An `i32` or infinity. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum I32Ext { + I32(i32), + NegativeInf, + PositiveInf, +} + +impl PartialEq for I32Ext { + fn eq(&self, other: &i32) -> bool { + match self { + I32Ext::I32(v1) => v1 == other, + I32Ext::NegativeInf | I32Ext::PositiveInf => false, + } + } +} + +impl PartialOrd for I32Ext { + fn partial_cmp(&self, other: &I32Ext) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for I32Ext { + fn cmp(&self, other: &Self) -> Ordering { + match self { + I32Ext::I32(v1) => match other { + I32Ext::I32(v2) => v1.cmp(v2), + I32Ext::NegativeInf => Ordering::Greater, + I32Ext::PositiveInf => Ordering::Less, + }, + I32Ext::NegativeInf => match other { + I32Ext::I32(_) => Ordering::Less, + I32Ext::PositiveInf => Ordering::Less, + I32Ext::NegativeInf => Ordering::Equal, + }, + I32Ext::PositiveInf => match other { + I32Ext::I32(_) => Ordering::Greater, + I32Ext::NegativeInf => Ordering::Greater, + I32Ext::PositiveInf => Ordering::Greater, + }, + } + } +} + +impl PartialOrd for I32Ext { + fn partial_cmp(&self, other: &i32) -> Option { + match self { + I32Ext::I32(v1) => v1.partial_cmp(other), + I32Ext::NegativeInf => Some(Ordering::Less), + I32Ext::PositiveInf => Some(Ordering::Greater), + } + } +} + +impl Add for I32Ext { + type Output = I32Ext; + + fn add(self, rhs: i32) -> Self::Output { + match self { + I32Ext::I32(lhs) => I32Ext::I32(lhs + rhs), + I32Ext::NegativeInf => I32Ext::NegativeInf, + I32Ext::PositiveInf => I32Ext::PositiveInf, + } + } +} + +/// An iterator over the values in the domain of a variable. +pub(crate) struct DomainIterator<'a>(DomainIteratorImpl<'a>); + +enum DomainIteratorImpl<'a> { + Constant { value: i32, finished: bool }, + Domain { domain: &'a Domain, next_value: i32 }, +} + +impl Iterator for DomainIterator<'_> { + type Item = i32; + + fn next(&mut self) -> Option { + match self.0 { + // Iterating over a contant means only yielding the value once, and then + // never again. + DomainIteratorImpl::Constant { + value, + ref mut finished, + } => { + if *finished { + None + } else { + *finished = true; + Some(value) + } + } + + DomainIteratorImpl::Domain { + domain, + ref mut next_value, + } => { + let I32Ext::I32(upper_bound) = domain.upper_bound else { + panic!("Only finite domains can be iterated.") + }; + + loop { + // We have completed iterating the domain. + if *next_value > upper_bound { + return None; + } + + let value = *next_value; + *next_value += 1; + + // The next value is not part of the domain. + if domain.holes.contains(&value) { + continue; + } + + // Here the value is part of the domain, so we yield it. + return Some(value); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use drcp_format::IntAtomic; + use drcp_format::IntComparison; + + use super::*; + + #[test] + fn domain_iterator_unbounded() { + let state = VariableState::default(); + let iterator = state.iter_domain(&fzn_rs::VariableExpr::Identifier(Rc::from("x1"))); + + assert!(iterator.is_none()); + } + + #[test] + fn domain_iterator_unbounded_lower_bound() { + let mut state = VariableState::default(); + + let variable_name = Rc::from("x1"); + let variable = fzn_rs::VariableExpr::Identifier(Rc::clone(&variable_name)); + + let _ = state.apply(&IntAtomic { + name: variable_name, + comparison: IntComparison::LessEqual, + value: 5, + }); + + let iterator = state.iter_domain(&variable); + + assert!(iterator.is_none()); + } + + #[test] + fn domain_iterator_unbounded_upper_bound() { + let mut state = VariableState::default(); + + let variable_name = Rc::from("x1"); + let variable = fzn_rs::VariableExpr::Identifier(Rc::clone(&variable_name)); + + let _ = state.apply(&IntAtomic { + name: variable_name, + comparison: IntComparison::GreaterEqual, + value: 5, + }); + + let iterator = state.iter_domain(&variable); + + assert!(iterator.is_none()); + } + + #[test] + fn domain_iterator_bounded_no_holes() { + let mut state = VariableState::default(); + + let variable_name = Rc::from("x1"); + let variable = fzn_rs::VariableExpr::Identifier(Rc::clone(&variable_name)); + + let _ = state.apply(&IntAtomic { + name: Rc::clone(&variable_name), + comparison: IntComparison::GreaterEqual, + value: 5, + }); + + let _ = state.apply(&IntAtomic { + name: variable_name, + comparison: IntComparison::LessEqual, + value: 10, + }); + + let values = state + .iter_domain(&variable) + .expect("the domain is bounded") + .collect::>(); + + assert_eq!(values, vec![5, 6, 7, 8, 9, 10]); + } + + #[test] + fn domain_iterator_bounded_with_holes() { + let mut state = VariableState::default(); + + let variable_name = Rc::from("x1"); + let variable = fzn_rs::VariableExpr::Identifier(Rc::clone(&variable_name)); + + let _ = state.apply(&IntAtomic { + name: Rc::clone(&variable_name), + comparison: IntComparison::GreaterEqual, + value: 5, + }); + + let _ = state.apply(&IntAtomic { + name: Rc::clone(&variable_name), + comparison: IntComparison::NotEqual, + value: 7, + }); + + let _ = state.apply(&IntAtomic { + name: variable_name, + comparison: IntComparison::LessEqual, + value: 10, + }); + + let values = state + .iter_domain(&variable) + .expect("the domain is bounded") + .collect::>(); + + assert_eq!(values, vec![5, 6, 8, 9, 10]); + } +}