Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions pumpkin-crates/core/src/propagation/constructor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ pub struct PropagatorConstructorContext<'a> {
/// either a reference or an owned value, to support
/// [`PropagatorConstructorContext::reborrow`].
next_local_id: RefOrOwned<'a, LocalId>,

/// Marker to indicate whether the constructor registered for at least one domain event or
/// predicate becoming assigned. If not, the [`Drop`] implementation will cause a panic.
did_register: RefOrOwned<'a, bool>,
}

impl PropagatorConstructorContext<'_> {
Expand All @@ -65,9 +69,19 @@ impl PropagatorConstructorContext<'_> {
next_local_id: RefOrOwned::Owned(LocalId::from(0)),
propagator_id,
state,
did_register: RefOrOwned::Owned(false),
}
}

/// Indicate that the constructor is deliberately not registering the propagator to be enqueued
/// at any time.
///
/// If this is called and later a registration happens, then the registration will still go
/// through. Calling this function only prevents the crash if no registration happens.
pub fn will_not_register_any_events(&mut self) {
*self.did_register = true;
}

/// Get domain information.
pub fn domains(&mut self) -> Domains<'_> {
Domains::new(&self.state.assignments, &mut self.state.trailed_values)
Expand All @@ -87,6 +101,8 @@ impl PropagatorConstructorContext<'_> {
domain_events: DomainEvents,
local_id: LocalId,
) {
self.will_not_register_any_events();

let propagator_var = PropagatorVarId {
propagator: self.propagator_id,
variable: local_id,
Expand All @@ -101,6 +117,8 @@ impl PropagatorConstructorContext<'_> {
/// Register the propagator to be enqueued when the given [`Predicate`] becomes true.
/// Returns the [`PredicateId`] used by the solver to track the predicate.
pub fn register_predicate(&mut self, predicate: Predicate) -> PredicateId {
self.will_not_register_any_events();

self.state.notification_engine.watch_predicate(
predicate,
self.propagator_id,
Expand Down Expand Up @@ -165,6 +183,10 @@ impl PropagatorConstructorContext<'_> {
RefOrOwned::Ref(next_local_id) => RefOrOwned::Ref(next_local_id),
RefOrOwned::Owned(next_local_id) => RefOrOwned::Ref(next_local_id),
},
did_register: match &mut self.did_register {
RefOrOwned::Ref(did_register) => RefOrOwned::Ref(did_register),
RefOrOwned::Owned(did_register) => RefOrOwned::Ref(did_register),
},
state: self.state,
}
}
Expand All @@ -177,6 +199,18 @@ impl PropagatorConstructorContext<'_> {
}
}

impl Drop for PropagatorConstructorContext<'_> {
fn drop(&mut self) {
if let RefOrOwned::Owned(did_register) = self.did_register
&& !did_register
{
panic!(
"Propagator did not register to be enqueued. If this is intentional, call PropagatorConstructorContext::will_not_register_any_events()."
);
}
}
}

/// Either owns a value or has a mutable reference to a value.
///
/// Used to store data in a reborrowed context that needs to be 'shared' with the original context
Expand Down Expand Up @@ -234,15 +268,47 @@ mod tests {
use super::*;
use crate::variables::DomainId;

#[test]
#[should_panic]
fn panic_when_no_registration_happened() {
let mut state = State::default();
state.notification_engine.grow();

let _c1 = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
}

#[test]
fn do_not_panic_if_told_no_registration_will_happen() {
let mut state = State::default();
state.notification_engine.grow();

let mut ctx = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
ctx.will_not_register_any_events();
}

#[test]
fn do_not_panic_if_no_registration_happens_in_reborrowed() {
let mut state = State::default();
state.notification_engine.grow();

let mut ctx = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
let ctx2 = ctx.reborrow();
drop(ctx2);

ctx.will_not_register_any_events();
}

#[test]
fn reborrowing_remembers_next_local_id() {
let mut state = State::default();
state.notification_engine.grow();

let mut c1 = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
c1.will_not_register_any_events();

let mut c2 = c1.reborrow();
c2.register(DomainId::new(0), DomainEvents::ANY_INT, LocalId::from(1));
drop(c2);

assert_eq!(LocalId::from(2), c1.get_next_local_id());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ impl NogoodPropagatorConstructor {
impl PropagatorConstructor for NogoodPropagatorConstructor {
type PropagatorImpl = NogoodPropagator;

fn create(self, context: PropagatorConstructorContext) -> Self::PropagatorImpl {
fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl {
context.will_not_register_any_events();

NogoodPropagator {
handle: PropagatorHandle::new(context.propagator_id),
parameters: self.parameters,
Expand Down