diff --git a/pumpkin-crates/core/src/engine/conflict_analysis/conflict_analysis_context.rs b/pumpkin-crates/core/src/engine/conflict_analysis/conflict_analysis_context.rs index 23d89d81c..54ff11c4e 100644 --- a/pumpkin-crates/core/src/engine/conflict_analysis/conflict_analysis_context.rs +++ b/pumpkin-crates/core/src/engine/conflict_analysis/conflict_analysis_context.rs @@ -87,10 +87,11 @@ impl ConflictAnalysisContext<'_> { StoredConflictInfo::Propagator(conflict) => { let _ = self.proof_log.log_inference( &self.state.inference_codes, + &mut self.state.constraint_tags, conflict.inference_code, conflict.conjunction.iter().copied(), None, - self.state.variable_names(), + &self.state.variable_names, ); conflict.conjunction @@ -176,19 +177,21 @@ impl ConflictAnalysisContext<'_> { let _ = proof_log.log_inference( &state.inference_codes, + &mut state.constraint_tags, *inference_code, [], Some(predicate), - state.variable_names(), + &state.variable_names, ); } else { // Otherwise we log the inference which was used to derive the nogood let _ = proof_log.log_inference( &state.inference_codes, + &mut state.constraint_tags, inference_code, reason_buffer.as_ref().iter().copied(), Some(predicate), - state.variable_names(), + &state.variable_names, ); } } @@ -219,10 +222,11 @@ impl ConflictAnalysisContext<'_> { // We also need to log this last propagation to the proof log as an inference. let _ = self.proof_log.log_inference( &self.state.inference_codes, + &mut self.state.constraint_tags, conflict.trigger_inference_code, empty_domain_reason.iter().copied(), Some(conflict.trigger_predicate), - self.state.variable_names(), + &self.state.variable_names, ); let old_lower_bound = self.state.lower_bound(conflict_domain); diff --git a/pumpkin-crates/core/src/engine/conflict_analysis/resolvers/resolution_resolver.rs b/pumpkin-crates/core/src/engine/conflict_analysis/resolvers/resolution_resolver.rs index d29031a71..86b8bda78 100644 --- a/pumpkin-crates/core/src/engine/conflict_analysis/resolvers/resolution_resolver.rs +++ b/pumpkin-crates/core/src/engine/conflict_analysis/resolvers/resolution_resolver.rs @@ -101,7 +101,8 @@ impl ConflictResolver for ResolutionResolver { .proof_log .log_deduction( learned_nogood.predicates.iter().copied(), - context.state.variable_names(), + &context.state.variable_names, + &mut context.state.constraint_tags, ) .expect("Failed to write proof log"); diff --git a/pumpkin-crates/core/src/engine/constraint_satisfaction_solver.rs b/pumpkin-crates/core/src/engine/constraint_satisfaction_solver.rs index 1f2a0fe08..53511c1da 100644 --- a/pumpkin-crates/core/src/engine/constraint_satisfaction_solver.rs +++ b/pumpkin-crates/core/src/engine/constraint_satisfaction_solver.rs @@ -335,7 +335,7 @@ impl ConstraintSatisfactionSolver { /// Create a new [`ConstraintTag`]. pub fn new_constraint_tag(&mut self) -> ConstraintTag { - self.internal_parameters.proof_log.new_constraint_tag() + self.state.new_constraint_tag() } pub fn create_new_literal(&mut self, name: Option>) -> Literal { @@ -834,10 +834,11 @@ impl ConstraintSatisfactionSolver { let inference_premises = reason.iter().copied().chain(std::iter::once(!propagated)); let _ = self.internal_parameters.proof_log.log_inference( &self.state.inference_codes, + &mut self.state.constraint_tags, inference_code, inference_premises, None, - self.state.variable_names(), + &self.state.variable_names, ); // Since inference steps are only related to the nogood they directly precede, @@ -869,10 +870,11 @@ impl ConstraintSatisfactionSolver { } // Log the nogood which adds the root-level knowledge to the proof. - let constraint_tag = self - .internal_parameters - .proof_log - .log_deduction([!propagated], self.state.variable_names()); + let constraint_tag = self.internal_parameters.proof_log.log_deduction( + [!propagated], + &self.state.variable_names, + &mut self.state.constraint_tags, + ); if let Ok(constraint_tag) = constraint_tag { let inference_code = self diff --git a/pumpkin-crates/core/src/engine/state.rs b/pumpkin-crates/core/src/engine/state.rs index ca478c73b..df741eb13 100644 --- a/pumpkin-crates/core/src/engine/state.rs +++ b/pumpkin-crates/core/src/engine/state.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use crate::basic_types::PropagatorConflict; +use crate::containers::KeyGenerator; use crate::containers::KeyedVec; use crate::create_statistics_struct; use crate::engine::Assignments; @@ -56,7 +57,7 @@ pub struct State { /// Keep track of trailed values (i.e. values which automatically backtrack). pub(crate) trailed_values: TrailedValues, /// The names of the variables in the solver. - variable_names: VariableNames, + pub(crate) variable_names: VariableNames, /// Dictates the order in which propagators will be called to propagate. pub(crate) propagator_queue: PropagatorQueue, /// Handles storing information about propagation reasons, which are used later to construct @@ -65,7 +66,10 @@ pub struct State { /// Component responsible for providing notifications for changes to the domains of variables /// and/or the polarity [Predicate]s pub(crate) notification_engine: NotificationEngine, + pub(crate) inference_codes: KeyedVec)>, + /// The [`ConstraintTag`]s generated for this proof. + pub(crate) constraint_tags: KeyGenerator, statistics: StateStatistics, } @@ -152,6 +156,7 @@ impl Default for State { notification_engine: NotificationEngine::default(), inference_codes: KeyedVec::default(), statistics: StateStatistics::default(), + constraint_tags: KeyGenerator::default(), }; // As a convention, the assignments contain a dummy domain_id=0, which represents a 0-1 // variable that is assigned to one. We use it to represent predicates that are @@ -204,6 +209,11 @@ impl State { .push((constraint_tag, inference_label.to_str())) } + /// Create a new [`ConstraintTag`]. + pub fn new_constraint_tag(&mut self) -> ConstraintTag { + self.constraint_tags.next_key() + } + /// Creates a new Boolean (0-1) variable. /// /// The name is used in solver traces to identify individual domains. They are required to be diff --git a/pumpkin-crates/core/src/proof/finalizer.rs b/pumpkin-crates/core/src/proof/finalizer.rs index 7115c90f8..4659543cc 100644 --- a/pumpkin-crates/core/src/proof/finalizer.rs +++ b/pumpkin-crates/core/src/proof/finalizer.rs @@ -41,9 +41,11 @@ pub(crate) fn finalize_proof(context: FinalizingContext<'_>) { }) .collect::>(); - let _ = context - .proof_log - .log_deduction(final_nogood, context.state.variable_names()); + let _ = context.proof_log.log_deduction( + final_nogood, + &context.state.variable_names, + &mut context.state.constraint_tags, + ); } pub(crate) struct RootExplanationContext<'a> { @@ -86,9 +88,11 @@ fn get_required_assumptions( // If the predicate is a root-level assignment, add the appropriate inference to the proof. if context.state.assignments.is_initial_bound(predicate) { - let _ = context - .proof_log - .log_domain_inference(predicate, context.state.variable_names()); + let _ = context.proof_log.log_domain_inference( + predicate, + &context.state.variable_names, + &mut context.state.constraint_tags, + ); return vec![]; } @@ -96,10 +100,11 @@ fn get_required_assumptions( if let Some(inference_code) = context.unit_nogood_inference_codes.get(&predicate) { let _ = context.proof_log.log_inference( &context.state.inference_codes, + &mut context.state.constraint_tags, *inference_code, [], Some(predicate), - context.state.variable_names(), + &context.state.variable_names, ); return vec![]; } diff --git a/pumpkin-crates/core/src/proof/mod.rs b/pumpkin-crates/core/src/proof/mod.rs index 2c7095825..091a660c0 100644 --- a/pumpkin-crates/core/src/proof/mod.rs +++ b/pumpkin-crates/core/src/proof/mod.rs @@ -65,7 +65,6 @@ impl ProofLog { propagation_order_hint: if log_hints { Some(vec![]) } else { None }, logged_domain_inferences: HashMap::default(), proof_atomics: ProofAtomics::default(), - constraint_tags: KeyGenerator::default(), }), }) } @@ -82,6 +81,7 @@ impl ProofLog { pub(crate) fn log_inference( &mut self, inference_codes: &KeyedVec)>, + constraint_tags: &mut KeyGenerator, inference_code: InferenceCode, premises: impl IntoIterator, propagated: Option, @@ -90,7 +90,6 @@ impl ProofLog { let Some(ProofImpl::CpProof { writer, propagation_order_hint: Some(propagation_sequence), - constraint_tags, proof_atomics, .. }) = self.internal_proof.as_mut() @@ -127,12 +126,12 @@ impl ProofLog { &mut self, predicate: Predicate, variable_names: &VariableNames, + constraint_tags: &mut KeyGenerator, ) -> std::io::Result { let Some(ProofImpl::CpProof { writer, propagation_order_hint: Some(propagation_sequence), logged_domain_inferences, - constraint_tags, proof_atomics, .. }) = self.internal_proof.as_mut() @@ -180,12 +179,12 @@ impl ProofLog { &mut self, premises: impl IntoIterator, variable_names: &VariableNames, + constraint_tags: &mut KeyGenerator, ) -> std::io::Result { match &mut self.internal_proof { Some(ProofImpl::CpProof { writer, propagation_order_hint, - constraint_tags, proof_atomics, logged_domain_inferences, .. @@ -291,17 +290,6 @@ impl ProofLog { proof_atomics.reify_predicate(literal, predicate); } - /// Create a new constraint tag. - pub(crate) fn new_constraint_tag(&mut self) -> ConstraintTag { - match self.internal_proof { - Some(ProofImpl::CpProof { - ref mut constraint_tags, - .. - }) => constraint_tags.next_key(), - _ => ConstraintTag::create_from_index(0), - } - } - pub(crate) fn is_logging_proof(&self) -> bool { self.internal_proof.is_some() } @@ -345,8 +333,6 @@ impl Write for Sink { enum ProofImpl { CpProof { writer: ProofWriter, - /// The [`ConstraintTag`]s generated for this proof. - constraint_tags: KeyGenerator, // If propagation hints are enabled, this is a buffer used to record propagations in the // order they can be applied to derive the next nogood. //