Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions pumpkin-crates/core/src/engine/cp/propagation/constructor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,12 @@ impl PropagatorConstructorContext<'_> {
/// Returns the [`PredicateId`] used by the solver to track the predicate.
#[allow(unused, reason = "will become public API")]
pub(crate) fn register_predicate(&mut self, predicate: Predicate) -> PredicateId {
self.state
.notification_engine
.watch_predicate(predicate, self.propagator_id)
self.state.notification_engine.watch_predicate(
predicate,
self.propagator_id,
&mut self.state.trailed_values,
&self.state.assignments,
)
}

/// Subscribes the propagator to the given [`DomainEvents`] when they are undone during
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use crate::engine::TrailedValues;
use crate::engine::notifications::NotificationEngine;
use crate::engine::notifications::PredicateIdAssignments;
use crate::engine::predicates::predicate::Predicate;
#[cfg(doc)]
use crate::engine::propagation::Propagator;
use crate::engine::propagation::PropagatorId;
use crate::engine::reason::Reason;
use crate::engine::reason::ReasonStore;
Expand Down Expand Up @@ -97,6 +99,19 @@ impl<'a> PropagationContextMut<'a> {
self.notification_engine.get_id(predicate)
}

/// Register the propagator to be enqueued when the provided [`Predicate`] becomes true.
///
/// Returns the [`PredicateId`] assigned to the provided predicate, which will be provided
/// to [`Propagator::notify_predicate_satisfied`].
pub(crate) fn register_predicate(&mut self, predicate: Predicate) -> PredicateId {
self.notification_engine.watch_predicate(
predicate,
self.propagator_id,
self.trailed_values,
self.assignments,
)
}

/// Apply a reification literal to all the explanations that are passed to the context.
pub(crate) fn with_reification(&mut self, reification_literal: Literal) {
pumpkin_assert_simple!(
Expand Down
6 changes: 4 additions & 2 deletions pumpkin-crates/core/src/engine/cp/propagation/propagator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,10 @@ pub(crate) trait Propagator: Downcast + DynClone {

/// Called when a [`PredicateId`] has been satisfied.
///
/// By default, the propagator does nothing when this method is called.
fn notify_predicate_id_satisfied(&mut self, _predicate_id: PredicateId) {}
/// By default, the propagator will be enqueued.
fn notify_predicate_id_satisfied(&mut self, _predicate_id: PredicateId) -> EnqueueDecision {
EnqueueDecision::Enqueue
}

/// Called each time the [`ConstraintSatisfactionSolver`] backtracks, the propagator can then
/// update its internal data structures given the new variable domains.
Expand Down
5 changes: 0 additions & 5 deletions pumpkin-crates/core/src/engine/cp/propagation/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,6 @@ impl PropagatorStore {
None
}
}

#[cfg(test)]
pub(crate) fn keys(&self) -> impl Iterator<Item = PropagatorId> + '_ {
self.propagators.keys()
}
}

impl Index<PropagatorId> for PropagatorStore {
Expand Down
42 changes: 26 additions & 16 deletions pumpkin-crates/core/src/engine/cp/propagator_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::VecDeque;

use crate::containers::HashSet;
use crate::containers::KeyedVec;
use crate::engine::cp::propagation::PropagatorId;
use crate::pumpkin_assert_moderate;

#[derive(Debug, Clone)]
pub(crate) struct PropagatorQueue {
queues: Vec<VecDeque<PropagatorId>>,
present_propagators: HashSet<PropagatorId>,
is_enqueued: KeyedVec<PropagatorId, bool>,
num_enqueued: usize,
present_priorities: BinaryHeap<Reverse<u32>>,
}

Expand All @@ -23,29 +24,28 @@ impl PropagatorQueue {
pub(crate) fn new(num_priority_levels: u32) -> PropagatorQueue {
PropagatorQueue {
queues: vec![VecDeque::new(); num_priority_levels as usize],
present_propagators: HashSet::default(),
is_enqueued: KeyedVec::default(),
num_enqueued: 0,
present_priorities: BinaryHeap::new(),
}
}

pub(crate) fn is_empty(&self) -> bool {
self.present_propagators.is_empty()
}

#[cfg(test)]
pub(crate) fn is_propagator_present(&self, propagator_id: PropagatorId) -> bool {
self.present_propagators.contains(&propagator_id)
self.num_enqueued == 0
}

pub(crate) fn enqueue_propagator(&mut self, propagator_id: PropagatorId, priority: u32) {
pumpkin_assert_moderate!((priority as usize) < self.queues.len());

if !self.is_propagator_enqueued(propagator_id) {
self.is_enqueued.accomodate(propagator_id, false);
self.is_enqueued[propagator_id] = true;
self.num_enqueued += 1;

if self.queues[priority as usize].is_empty() {
self.present_priorities.push(Reverse(priority));
}
self.queues[priority as usize].push_back(propagator_id);
let _ = self.present_propagators.insert(propagator_id);
}
}

Expand All @@ -59,13 +59,15 @@ impl PropagatorQueue {

let next_propagator_id = self.queues[top_priority].pop_front();

next_propagator_id.iter().for_each(|next_propagator_id| {
let _ = self.present_propagators.remove(next_propagator_id);
if let Some(propagator_id) = next_propagator_id {
self.is_enqueued[propagator_id] = false;

if self.queues[top_priority].is_empty() {
let _ = self.present_priorities.pop();
}
});
}

self.num_enqueued -= 1;

next_propagator_id
}
Expand All @@ -76,11 +78,19 @@ impl PropagatorQueue {
pumpkin_assert_moderate!(!self.queues[priority].is_empty());
self.queues[priority].clear();
}
self.present_propagators.clear();

for is_propagator_enqueued in self.is_enqueued.iter_mut() {
*is_propagator_enqueued = false;
}

self.present_priorities.clear();
self.num_enqueued = 0;
}

fn is_propagator_enqueued(&self, propagator_id: PropagatorId) -> bool {
self.present_propagators.contains(&propagator_id)
pub(crate) fn is_propagator_enqueued(&self, propagator_id: PropagatorId) -> bool {
self.is_enqueued
.get(propagator_id)
.copied()
.unwrap_or_default()
}
}
6 changes: 3 additions & 3 deletions pumpkin-crates/core/src/engine/cp/test_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl TestSolver {
&mut self.state.propagators,
&mut propagator_queue,
);
if propagator_queue.is_propagator_present(propagator) {
if propagator_queue.is_propagator_enqueued(propagator) {
EnqueueDecision::Enqueue
} else {
EnqueueDecision::Skip
Expand Down Expand Up @@ -138,7 +138,7 @@ impl TestSolver {
&mut self.state.propagators,
&mut propagator_queue,
);
if propagator_queue.is_propagator_present(propagator) {
if propagator_queue.is_propagator_enqueued(propagator) {
EnqueueDecision::Enqueue
} else {
EnqueueDecision::Skip
Expand Down Expand Up @@ -166,7 +166,7 @@ impl TestSolver {
&mut self.state.propagators,
&mut propagator_queue,
);
if propagator_queue.is_propagator_present(propagator) {
if propagator_queue.is_propagator_enqueued(propagator) {
EnqueueDecision::Enqueue
} else {
EnqueueDecision::Skip
Expand Down
81 changes: 14 additions & 67 deletions pumpkin-crates/core/src/engine/notifications/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ use crate::engine::propagation::PropagatorId;
use crate::engine::propagation::contexts::PropagationContextWithTrailedValues;
use crate::engine::propagation::store::PropagatorStore;
use crate::predicates::Predicate;
use crate::propagators::nogoods::NogoodPropagator;
use crate::pumpkin_assert_extreme;
use crate::pumpkin_assert_simple;
use crate::state::PropagatorHandle;
use crate::variables::DomainId;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -142,13 +140,18 @@ impl NotificationEngine {
&mut self,
predicate: Predicate,
propagator_id: PropagatorId,
trailed_values: &mut TrailedValues,
assignments: &Assignments,
) -> PredicateId {
let predicate_id = self.get_id(predicate);

self.watch_list_predicate_id
.accomodate(predicate_id, vec![]);
self.watch_list_predicate_id[predicate_id].push(propagator_id);

self.predicate_notifier
.track_predicate(predicate_id, trailed_values, assignments);

predicate_id
}

Expand Down Expand Up @@ -264,7 +267,6 @@ impl NotificationEngine {
assignments: &mut Assignments,
trailed_values: &mut TrailedValues,
propagators: &mut PropagatorStore,
nogood_propagator_handle: PropagatorHandle<NogoodPropagator>,
propagator_queue: &mut PropagatorQueue,
) {
// We first take the events because otherwise we get mutability issues when calling methods
Expand All @@ -275,17 +277,6 @@ impl NotificationEngine {
// First we notify the predicate_notifier that a domain has been updated
self.predicate_notifier
.on_update(trailed_values, assignments, event, domain);
// Special case: the nogood propagator is notified about each event.
Self::notify_nogood_propagator(
nogood_propagator_handle,
&mut self.predicate_notifier.predicate_id_assignments,
event,
domain,
propagators,
propagator_queue,
assignments,
trailed_values,
);
// Now notify other propagators subscribed to this event.
#[allow(clippy::unnecessary_to_owned, reason = "Not unnecessary?")]
for propagator_var in self
Expand All @@ -310,7 +301,7 @@ impl NotificationEngine {
self.events = events;

// Then we notify the propagators that a predicate has been satisfied.
self.notify_predicate_id_satisfied(nogood_propagator_handle, propagators);
self.notify_predicate_id_satisfied(propagators, propagator_queue);

self.last_notified_trail_index = assignments.num_trail_entries();
}
Expand Down Expand Up @@ -348,50 +339,25 @@ impl NotificationEngine {
/// Notifies the propagator that certain [`Predicate`]s have been satisfied.
fn notify_predicate_id_satisfied(
&mut self,
nogood_propagator_handle: PropagatorHandle<NogoodPropagator>,
propagators: &mut PropagatorStore,
propagator_queue: &mut PropagatorQueue,
) {
for predicate_id in self.predicate_notifier.drain_satisfied_predicates() {
if let Some(watch_list) = self.watch_list_predicate_id.get(predicate_id) {
let propagators_to_notify = watch_list.iter().copied();

for propagator_id in propagators_to_notify {
propagators[propagator_id].notify_predicate_id_satisfied(predicate_id);
let propagator = &mut propagators[propagator_id];
let enqueue_decision = propagator.notify_predicate_id_satisfied(predicate_id);

if enqueue_decision == EnqueueDecision::Enqueue {
propagator_queue.enqueue_propagator(propagator_id, propagator.priority());
}
}
}

propagators[nogood_propagator_handle.propagator_id()]
.notify_predicate_id_satisfied(predicate_id);
}
}

#[allow(clippy::too_many_arguments, reason = "to be refactored later")]
fn notify_nogood_propagator(
nogood_propagator_id: PropagatorHandle<NogoodPropagator>,
predicate_id_assignments: &mut PredicateIdAssignments,
event: DomainEvent,
domain: DomainId,
propagators: &mut PropagatorStore,
propagator_queue: &mut PropagatorQueue,
assignments: &mut Assignments,
trailed_values: &mut TrailedValues,
) {
// The nogood propagator is implicitly subscribed to every domain event for every variable.
// For this reason, its local id matches the domain id.
// This is special only for the nogood propagator.
let local_id = LocalId::from(domain.id());
Self::notify_propagator(
predicate_id_assignments,
nogood_propagator_id.propagator_id(),
local_id,
event,
propagators,
propagator_queue,
assignments,
trailed_values,
);
}

#[allow(clippy::too_many_arguments, reason = "Should be refactored")]
fn notify_propagator(
predicate_id_assignments: &mut PredicateIdAssignments,
Expand Down Expand Up @@ -425,16 +391,6 @@ impl NotificationEngine {
let _ = self.events.drain();
}

pub(crate) fn track_predicate(
&mut self,
predicate: PredicateId,
trailed_values: &mut TrailedValues,
assignments: &Assignments,
) {
self.predicate_notifier
.track_predicate(predicate, trailed_values, assignments)
}

#[cfg(test)]
pub(crate) fn drain_backtrack_domain_events(
&mut self,
Expand All @@ -459,12 +415,6 @@ impl NotificationEngine {
propagators: &mut PropagatorStore,
propagator_queue: &mut PropagatorQueue,
) {
// There may be a nogood propagator in the store. In that case we need to always
// notify it.
let nogood_propagator_handle = propagators
.keys()
.find_map(|id| propagators.as_propagator_handle::<NogoodPropagator>(id));

// Collect so that we can pass the assignments to the methods within the loop
for (event, domain) in self.events.drain().collect::<Vec<_>>() {
// First we notify the predicate_notifier that a domain has been updated
Expand Down Expand Up @@ -493,10 +443,7 @@ impl NotificationEngine {
}
}

if let Some(handle) = nogood_propagator_handle {
// Then we notify the propagators that a predicate has been satisfied.
self.notify_predicate_id_satisfied(handle, propagators);
}
self.notify_predicate_id_satisfied(propagators, propagator_queue);

self.last_notified_trail_index = assignments.num_trail_entries();
}
Expand Down
2 changes: 0 additions & 2 deletions pumpkin-crates/core/src/engine/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,6 @@ impl State {
&mut self.assignments,
&mut self.trailed_values,
&mut self.propagators,
PropagatorHandle::new(PropagatorId(0)),
&mut self.propagator_queue,
);
pumpkin_assert_extreme!(
Expand Down Expand Up @@ -628,7 +627,6 @@ impl State {
&mut self.assignments,
&mut self.trailed_values,
&mut self.propagators,
PropagatorHandle::new(PropagatorId(0)),
&mut self.propagator_queue,
);

Expand Down
Loading