diff --git a/pumpkin-crates/core/src/propagation/constructor.rs b/pumpkin-crates/core/src/propagation/constructor.rs index 290f625fc..f68e5518b 100644 --- a/pumpkin-crates/core/src/propagation/constructor.rs +++ b/pumpkin-crates/core/src/propagation/constructor.rs @@ -54,6 +54,10 @@ pub struct PropagatorConstructorContext<'a> { /// either a reference or an owned value, to support /// [`PropagatorConstructorContext::reborrow`]. next_local_id: RefOrOwned<'a, LocalId>, + + /// Marker to indicate whether the constructor registered for at least one domain event or + /// predicate becoming assigned. If not, the [`Drop`] implementation will cause a panic. + did_register: RefOrOwned<'a, bool>, } impl PropagatorConstructorContext<'_> { @@ -65,9 +69,19 @@ impl PropagatorConstructorContext<'_> { next_local_id: RefOrOwned::Owned(LocalId::from(0)), propagator_id, state, + did_register: RefOrOwned::Owned(false), } } + /// Indicate that the constructor is deliberately not registering the propagator to be enqueued + /// at any time. + /// + /// If this is called and later a registration happens, then the registration will still go + /// through. Calling this function only prevents the crash if no registration happens. + pub fn will_not_register_any_events(&mut self) { + *self.did_register = true; + } + /// Get domain information. pub fn domains(&mut self) -> Domains<'_> { Domains::new(&self.state.assignments, &mut self.state.trailed_values) @@ -87,6 +101,8 @@ impl PropagatorConstructorContext<'_> { domain_events: DomainEvents, local_id: LocalId, ) { + self.will_not_register_any_events(); + let propagator_var = PropagatorVarId { propagator: self.propagator_id, variable: local_id, @@ -101,6 +117,8 @@ impl PropagatorConstructorContext<'_> { /// Register the propagator to be enqueued when the given [`Predicate`] becomes true. /// Returns the [`PredicateId`] used by the solver to track the predicate. pub fn register_predicate(&mut self, predicate: Predicate) -> PredicateId { + self.will_not_register_any_events(); + self.state.notification_engine.watch_predicate( predicate, self.propagator_id, @@ -165,6 +183,10 @@ impl PropagatorConstructorContext<'_> { RefOrOwned::Ref(next_local_id) => RefOrOwned::Ref(next_local_id), RefOrOwned::Owned(next_local_id) => RefOrOwned::Ref(next_local_id), }, + did_register: match &mut self.did_register { + RefOrOwned::Ref(did_register) => RefOrOwned::Ref(did_register), + RefOrOwned::Owned(did_register) => RefOrOwned::Ref(did_register), + }, state: self.state, } } @@ -177,6 +199,18 @@ impl PropagatorConstructorContext<'_> { } } +impl Drop for PropagatorConstructorContext<'_> { + fn drop(&mut self) { + if let RefOrOwned::Owned(did_register) = self.did_register + && !did_register + { + panic!( + "Propagator did not register to be enqueued. If this is intentional, call PropagatorConstructorContext::will_not_register_any_events()." + ); + } + } +} + /// Either owns a value or has a mutable reference to a value. /// /// Used to store data in a reborrowed context that needs to be 'shared' with the original context @@ -234,15 +268,47 @@ mod tests { use super::*; use crate::variables::DomainId; + #[test] + #[should_panic] + fn panic_when_no_registration_happened() { + let mut state = State::default(); + state.notification_engine.grow(); + + let _c1 = PropagatorConstructorContext::new(PropagatorId(0), &mut state); + } + + #[test] + fn do_not_panic_if_told_no_registration_will_happen() { + let mut state = State::default(); + state.notification_engine.grow(); + + let mut ctx = PropagatorConstructorContext::new(PropagatorId(0), &mut state); + ctx.will_not_register_any_events(); + } + + #[test] + fn do_not_panic_if_no_registration_happens_in_reborrowed() { + let mut state = State::default(); + state.notification_engine.grow(); + + let mut ctx = PropagatorConstructorContext::new(PropagatorId(0), &mut state); + let ctx2 = ctx.reborrow(); + drop(ctx2); + + ctx.will_not_register_any_events(); + } + #[test] fn reborrowing_remembers_next_local_id() { let mut state = State::default(); state.notification_engine.grow(); let mut c1 = PropagatorConstructorContext::new(PropagatorId(0), &mut state); + c1.will_not_register_any_events(); let mut c2 = c1.reborrow(); c2.register(DomainId::new(0), DomainEvents::ANY_INT, LocalId::from(1)); + drop(c2); assert_eq!(LocalId::from(2), c1.get_next_local_id()); } diff --git a/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs b/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs index 7cfe1cfbb..e32729071 100644 --- a/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs +++ b/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs @@ -103,7 +103,9 @@ impl NogoodPropagatorConstructor { impl PropagatorConstructor for NogoodPropagatorConstructor { type PropagatorImpl = NogoodPropagator; - fn create(self, context: PropagatorConstructorContext) -> Self::PropagatorImpl { + fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { + context.will_not_register_any_events(); + NogoodPropagator { handle: PropagatorHandle::new(context.propagator_id), parameters: self.parameters,