Skip to content

Commit

Permalink
Fix a problem where callbacks given to IPASIR solver where dropped early
Browse files Browse the repository at this point in the history
  • Loading branch information
Dekker1 committed May 23, 2024
1 parent e7087d6 commit 8fb3c6d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 4 deletions.
8 changes: 6 additions & 2 deletions crates/pindakaas-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
cb: Option<F>,
) {
if let Some(mut cb) = cb {
let mut wrapped_cb = || -> std::ffi::c_int {
let mut wrapped_cb = move || -> std::ffi::c_int {
match cb() {
crate::solver::SlvTermSignal::Continue => std::ffi::c_int::from(0),
crate::solver::SlvTermSignal::Terminate => std::ffi::c_int::from(1),
Expand All @@ -109,6 +109,8 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
Some(crate::solver::libloading::get_trampoline0(&wrapped_cb)),
)
}
// WARNING: Any data in the callback now exists forever
std::mem::forget(wrapped_cb);
} else {
unsafe { #krate::ipasir_set_terminate(#ptr, std::ptr::null_mut(), None) }
}
Expand All @@ -128,7 +130,7 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
) {
const MAX_LEN: std::ffi::c_int = 512;
if let Some(mut cb) = cb {
let mut wrapped_cb = |clause: *const i32| {
let mut wrapped_cb = move |clause: *const i32| {
let mut iter = crate::solver::libloading::ExplIter(clause)
.map(|i: i32| crate::Lit(std::num::NonZeroI32::new(i).unwrap()));
cb(&mut iter)
Expand All @@ -142,6 +144,8 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
Some(crate::solver::libloading::get_trampoline1(&wrapped_cb)),
)
}
// WARNING: Any data in the callback now exists forever
std::mem::forget(wrapped_cb);
} else {
unsafe { #krate::ipasir_set_learn(#ptr, std::ptr::null_mut(), MAX_LEN, None) }
}
Expand Down
17 changes: 16 additions & 1 deletion crates/pindakaas/src/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,14 @@ pub trait LearnCallback: Solver {
/// Set a callback function used to extract learned clauses up to a given
/// length from the solver.
///
/// WARNING: Subsequent calls to this method override the previously set
/// # Warning
///
/// Subsequent calls to this method override the previously set
/// callback function.
///
/// For IPASIR connected through C, the callback and any objects contained
/// within it might be leaked to satisfy the FFI requirements. Note that
/// [`Drop`] implementations might not be called on these objects.
fn set_learn_callback<F: FnMut(&mut dyn Iterator<Item = Lit>)>(&mut self, cb: Option<F>);
}

Expand All @@ -80,6 +86,15 @@ pub trait TermCallback: Solver {
/// The solver will periodically call this function and check its return value
/// during the search. Subsequent calls to this method override the previously
/// set callback function.
///
/// # Warning
///
/// Subsequent calls to this method override the previously set
/// callback function.
///
/// For IPASIR connected through C, the callback and any objects contained
/// within it might be leaked to satisfy the FFI requirements. Note that
/// [`Drop`] implementations might not be called on these objects.
fn set_terminate_callback<F: FnMut() -> SlvTermSignal>(&mut self, cb: Option<F>);
}

Expand Down
47 changes: 46 additions & 1 deletion crates/pindakaas/src/solver/cadical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ mod tests {
use super::*;
use crate::{
linear::LimitComp,
solver::{SolveResult, Solver},
solver::{LearnCallback, SlvTermSignal, SolveResult, Solver, TermCallback},
CardinalityOne, ClauseDatabase, Encoder, PairwiseEncoder, Valuation,
};

Expand Down Expand Up @@ -102,6 +102,51 @@ mod tests {
});
}

#[test]
fn test_cadical_cb_no_drop() {
let mut slv = Cadical::default();

let a = slv.new_var().into();
let b = slv.new_var().into();
PairwiseEncoder::default()
.encode(
&mut slv,
&CardinalityOne {
lits: vec![a, b],
cmp: LimitComp::Equal,
},
)
.unwrap();

struct NoDrop(i32);
impl NoDrop {
fn seen(&mut self) {
self.0 += 1;
eprintln!("seen {}", self.0);
}
}
impl Drop for NoDrop {
fn drop(&mut self) {
panic!("I have been dropped {}", self.0);
}
}

{
let mut nodrop = NoDrop(0);
slv.set_terminate_callback(Some(move || {
nodrop.seen();
SlvTermSignal::Continue
}));
}
{
let mut nodrop = NoDrop(0);
slv.set_learn_callback(Some(move |_: &mut dyn Iterator<Item = Lit>| {
nodrop.seen();
}));
}
assert_eq!(slv.solve(|_| {}), SolveResult::Sat);
}

#[cfg(feature = "ipasir-up")]
#[test]
fn test_ipasir_up() {
Expand Down

0 comments on commit 8fb3c6d

Please sign in to comment.