diff --git a/explainability/src/explain/explanation.rs b/explainability/src/explain/explanation.rs index 39ab8966..8b2d3c80 100644 --- a/explainability/src/explain/explanation.rs +++ b/explainability/src/explain/explanation.rs @@ -4,16 +4,12 @@ use std::sync::Arc; use aries::core::Lit; use aries::model::{Label, Model}; -// "Essence" vs "Counterfactual" ? "Premise" ? -#[derive(Debug, PartialEq, Eq, Hash)] -pub struct ExplEssence(pub BTreeSet, pub BTreeSet); +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Essence(pub BTreeSet, pub BTreeSet); -// support (best alternative) ? justification ? argument ? cause ? -// "contradiction" vs "modelling" ? (but a counterexample could also be seen as one ?) -// just "example" vs "counterexample", maybe ? -#[derive(Debug, PartialEq, Eq, Hash)] -pub enum ExplSubstance { - Modelling(BTreeSet), +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum Substance { + ModelConstraints(BTreeSet), CounterExample(BTreeSet), } @@ -37,8 +33,8 @@ impl ExplanationFilter { pub struct Explanation { pub models: Vec>>, - pub essences: Vec, - pub substances: Vec, + pub essences: Vec, + pub substances: Vec, pub table: BTreeMap<(EssenceIndex, SubstanceIndex), BTreeSet>, pub filter: ExplanationFilter, } diff --git a/explainability/src/explain/presupposition.rs b/explainability/src/explain/presupposition.rs index 329b6a59..c0a7e1bb 100644 --- a/explainability/src/explain/presupposition.rs +++ b/explainability/src/explain/presupposition.rs @@ -54,11 +54,36 @@ pub fn check_presupposition( cached_solver: Option<&mut Solver>, ) -> Result<(), UnmetPresupposition> { let solver = if let Some(s) = cached_solver { + // If we (the caller of the function) have supplied a cached solver to use, then use it. s } else { - &mut create_solver((*presupposition.model).clone()) + // If no cached solver has been supplied, then create one and use it. + &mut { + let model = (*presupposition.model).clone(); + let stn_config = StnConfig { + theory_propagation: TheoryPropagationLevel::Full, + ..Default::default() + }; + let mut solver = Solver::::new(model); + solver.reasoners.diff.config = stn_config; + solver + } }; - if skip_model_situ_sat_check { + + if !skip_model_situ_sat_check { + // We need to make sure `model` /\ `situ` is indeed SAT. + match solver.solve_with_assumptions(presupposition.situ.iter().cloned()) { + Ok(_) => solver.restore(DecLvl::from(presupposition.situ.len())), + Err(_) => { + return Err(UnmetPresupposition { + presupposition, + cause: UnmetPresuppositionCause::ModelSituUnsat, + }) + } + } + } else { + // If we (the caller of the function) want to skip checking `model` /\ `situ` is SAT + // (because we know that it's already the case), we only do the initial propagation and assumptions. debug_assert!(solver.current_decision_level() == DecLvl::ROOT); match solver.propagate_and_backtrack_to_consistent(solver.current_decision_level()) { Ok(_) => (), // expected, @@ -69,27 +94,18 @@ pub fn check_presupposition( Ok(_) => (), // expected Err(_) => debug_assert!(false), } - } - } else { - match solver.solve_with_assumptions(presupposition.situ.clone()) { - Ok(_) => solver.restore(DecLvl::from(presupposition.situ.len())), - Err(_) => { - return Err(UnmetPresupposition { - presupposition, - cause: UnmetPresuppositionCause::ModelSituUnsat, - }) - } - } + } } - // Remember, `situ` is already assumed (we backtracked to the latest assumption). + // !!! Remember, at this point `situ` is already assumed + // !!! so, we will just use `query` (or `query_neg`) + // !!! in `solve_with_assumptions` calls below (incremental solving). debug_assert!(solver.current_decision_level() == DecLvl::from(presupposition.situ.len())); - // And so, we will just use `query` (or `query_neg`) in `solve_with_assumptions` calls below (incremental solving). let res = match presupposition.kind { PresuppositionKind::ModelSituUnsatWithQuery => { match solver - .solve_with_assumptions(presupposition.query.clone()) + .solve_with_assumptions(presupposition.query.iter().cloned()) .expect("Solver interrupted.") { Ok(_) => Err(UnmetPresupposition { @@ -101,7 +117,7 @@ pub fn check_presupposition( } PresuppositionKind::ModelSituSatWithQuery => { match solver - .solve_with_assumptions(presupposition.query.clone()) + .solve_with_assumptions(presupposition.query.iter().cloned()) .expect("Solver interrupted.") { Ok(_) => Ok(()), @@ -114,12 +130,12 @@ pub fn check_presupposition( PresuppositionKind::ModelSituNotEntailQuery => { let dl = DecLvl::from(presupposition.query.len()); match solver - .solve_with_assumptions(presupposition.query.clone()) + .solve_with_assumptions(presupposition.query.iter().cloned()) .expect("Solver interrupted.") { Ok(_) => { solver.restore(dl); - let query_neg = presupposition.query.iter().map(|&l| !l).collect_vec(); + let query_neg = presupposition.query.iter().map(|&l| !l); match solver.solve_with_assumptions(query_neg).expect("Solver interrupted.") { Ok(_) => Err(UnmetPresupposition { presupposition, @@ -135,8 +151,8 @@ pub fn check_presupposition( } } PresuppositionKind::ModelSituEntailQuery => { - let neg_query = presupposition.query.iter().map(|&l| !l).collect_vec(); - match solver.solve_with_assumptions(neg_query).expect("Solver interrupted.") { + let query_neg = presupposition.query.iter().map(|&l| !l); + match solver.solve_with_assumptions(query_neg).expect("Solver interrupted.") { Ok(_) => Ok(()), Err(_) => Err(UnmetPresupposition { presupposition, @@ -145,17 +161,8 @@ pub fn check_presupposition( } } }; - // necessary if the solver was a cached one (given as parameter), to ensure it can be safely reused somewhere else. + // necessary if the solver was a cached one (given as parameter), + // to ensure it can be safely reused somewhere else. solver.reset(); res } - -fn create_solver(model: Model) -> Solver { - let stn_config = StnConfig { - theory_propagation: TheoryPropagationLevel::Full, - ..Default::default() - }; - let mut solver = Solver::::new(model); - solver.reasoners.diff.config = stn_config; - solver -} diff --git a/explainability/src/explain/why/unsat.rs b/explainability/src/explain/why/unsat.rs index 74a78336..3a8eae40 100644 --- a/explainability/src/explain/why/unsat.rs +++ b/explainability/src/explain/why/unsat.rs @@ -5,7 +5,7 @@ use aries::core::Lit; use aries::model::{Label, Model}; use crate::explain::explanation::{ - EssenceIndex, ExplEssence, ExplSubstance, Explanation, ExplanationFilter, ModelIndex, SubstanceIndex, + EssenceIndex, Essence, Substance, Explanation, ExplanationFilter, ModelIndex, SubstanceIndex, }; use crate::explain::presupposition::{check_presupposition, Presupposition, PresuppositionKind, UnmetPresupposition}; use crate::explain::{Query, Question, Situation, Vocab}; @@ -51,8 +51,8 @@ impl Question for QwhyUnsat { ); let muses = simple_marco.run().muses.unwrap(); - let mut essences = Vec::::new(); - let mut substances = Vec::::new(); + let mut essences = Vec::::new(); + let mut substances = Vec::::new(); let mut table = BTreeMap::<(EssenceIndex, SubstanceIndex), BTreeSet>::new(); let filter = ExplanationFilter { map: None, @@ -62,7 +62,7 @@ impl Question for QwhyUnsat { let _situ_set = BTreeSet::from_iter(self.situ.iter().cloned()); for (mus_idx, mus) in muses.into_iter().enumerate() { - essences.push(ExplEssence( + essences.push(Essence( mus.difference(&_situ_set).cloned().collect::>(), mus.intersection(&_situ_set).cloned().collect::>(), )); @@ -78,7 +78,7 @@ impl Question for QwhyUnsat { ); let mcses = simple_marco.run().mcses.unwrap(); for mcs in mcses { - let sub = ExplSubstance::Modelling(mcs); + let sub = Substance::ModelConstraints(mcs); let sub_idx = substances.iter().position(|s| s == &sub); match sub_idx { Some(i) => table.insert((mus_idx, i), BTreeSet::from_iter([0])), @@ -110,7 +110,7 @@ mod tests { use aries::model::lang::expr::{and, implies}; use aries::model::lang::linear::LinearSum; - use crate::explain::explanation::{ExplEssence, ExplSubstance}; + use crate::explain::explanation::{Essence, Substance}; use super::Question; @@ -167,32 +167,32 @@ mod tests { let expl = question.try_answer().unwrap(); - let essences: HashSet = expl.essences.into_iter().collect::>(); + let essences: HashSet = expl.essences.iter().cloned().collect::>(); debug_assert_eq!( essences, HashSet::from_iter([ - ExplEssence(BTreeSet::from_iter([p_a, p_b]), BTreeSet::from_iter([p_d])), - ExplEssence(BTreeSet::from_iter([p_b, p_c]), BTreeSet::from_iter([p_d])), + Essence(BTreeSet::from_iter([p_a, p_b]), BTreeSet::from_iter([p_d])), + Essence(BTreeSet::from_iter([p_b, p_c]), BTreeSet::from_iter([p_d])), ]), ); - let substances = expl.substances.into_iter().collect::>(); + let substances = expl.substances.iter().cloned().collect::>(); debug_assert_eq!( substances, HashSet::from_iter([ - ExplSubstance::Modelling(BTreeSet::from_iter([voc[0]])), - ExplSubstance::Modelling(BTreeSet::from_iter([voc[1]])), - ExplSubstance::Modelling(BTreeSet::from_iter([voc[2]])), - ExplSubstance::Modelling(BTreeSet::from_iter([voc[4]])), + Substance::ModelConstraints(BTreeSet::from_iter([voc[0]])), + Substance::ModelConstraints(BTreeSet::from_iter([voc[1]])), + Substance::ModelConstraints(BTreeSet::from_iter([voc[2]])), + Substance::ModelConstraints(BTreeSet::from_iter([voc[4]])), ]), ); - let idxe0 = essences.iter().position(|e| *e == ExplEssence(BTreeSet::from_iter([p_a, p_b]), BTreeSet::from_iter([p_d]))).unwrap(); - let idxe1 = essences.iter().position(|e| *e == ExplEssence(BTreeSet::from_iter([p_b, p_c]), BTreeSet::from_iter([p_d]))).unwrap(); - let idxs0 = substances.iter().position(|s| *s == ExplSubstance::Modelling(BTreeSet::from_iter([voc[0]]))).unwrap(); - let idxs1 = substances.iter().position(|s| *s == ExplSubstance::Modelling(BTreeSet::from_iter([voc[1]]))).unwrap(); - let idxs2 = substances.iter().position(|s| *s == ExplSubstance::Modelling(BTreeSet::from_iter([voc[2]]))).unwrap(); - let idxs3 = substances.iter().position(|s| *s == ExplSubstance::Modelling(BTreeSet::from_iter([voc[4]]))).unwrap(); + let idxe0 = expl.essences.iter().position(|e| *e == Essence(BTreeSet::from_iter([p_a, p_b]), BTreeSet::from_iter([p_d]))).unwrap(); + let idxe1 = expl.essences.iter().position(|e| *e == Essence(BTreeSet::from_iter([p_b, p_c]), BTreeSet::from_iter([p_d]))).unwrap(); + let idxs0 = expl.substances.iter().position(|s| *s == Substance::ModelConstraints(BTreeSet::from_iter([voc[0]]))).unwrap(); + let idxs1 = expl.substances.iter().position(|s| *s == Substance::ModelConstraints(BTreeSet::from_iter([voc[1]]))).unwrap(); + let idxs2 = expl.substances.iter().position(|s| *s == Substance::ModelConstraints(BTreeSet::from_iter([voc[2]]))).unwrap(); + let idxs3 = expl.substances.iter().position(|s| *s == Substance::ModelConstraints(BTreeSet::from_iter([voc[4]]))).unwrap(); let table = expl.table; debug_assert_eq!(