Skip to content

Commit

Permalink
Store solver callbacks in the solver struct
Browse files Browse the repository at this point in the history
  • Loading branch information
Dekker1 committed Jun 5, 2024
1 parent ea9c76f commit 6695c90
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 98 deletions.
42 changes: 26 additions & 16 deletions crates/pindakaas-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ struct IpasirOpts {
#[darling(default)]
learn_callback: bool,
#[darling(default)]
learn_callback_ident: Option<Ident>,
#[darling(default)]
term_callback: bool,
#[darling(default)]
term_callback_ident: Option<Ident>,
#[darling(default)]
ipasir_up: bool,
#[darling(default = "default_true")]
has_default: bool,
Expand Down Expand Up @@ -88,30 +92,33 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
};

let term_callback = if opts.term_callback {
let term_cb = match opts.term_callback_ident {
Some(x) => quote! { self. #x },
None => quote! { self.term_cb },
};
quote! {
impl crate::solver::TermCallback for #ident {
fn set_terminate_callback<F: FnMut() -> crate::solver::SlvTermSignal>(
fn set_terminate_callback<F: FnMut() -> crate::solver::SlvTermSignal + 'static>(
&mut self,
cb: Option<F>,
) {
if let Some(mut cb) = cb {
let mut wrapped_cb = move || -> std::ffi::c_int {
#term_cb = crate::solver::libloading::TermCB::new(move || -> std::ffi::c_int {
match cb() {
crate::solver::SlvTermSignal::Continue => std::ffi::c_int::from(0),
crate::solver::SlvTermSignal::Terminate => std::ffi::c_int::from(1),
}
};
let trampoline = crate::solver::libloading::get_trampoline0(&wrapped_cb);
// WARNING: Any data in the callback now exists forever
let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void;
});

unsafe {
#krate::ipasir_set_terminate(
#ptr,
data,
Some(trampoline),
#term_cb .as_ptr(),
Some(crate::solver::libloading::TermCB::exec_callback),
)
}
} else {
#term_cb = Default::default();
unsafe { #krate::ipasir_set_terminate(#ptr, std::ptr::null_mut(), None) }
}
}
Expand All @@ -122,31 +129,34 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
};

let learn_callback = if opts.learn_callback {
let learn_cb = match opts.learn_callback_ident {
Some(x) => quote! { self. #x },
None => quote! { self.learn_cb },
};
quote! {
impl crate::solver::LearnCallback for #ident {
fn set_learn_callback<F: FnMut(&mut dyn Iterator<Item = crate::Lit>)>(
fn set_learn_callback<F: FnMut(&mut dyn Iterator<Item = crate::Lit>) + 'static>(
&mut self,
cb: Option<F>,
) {
const MAX_LEN: std::ffi::c_int = 512;
if let Some(mut cb) = cb {
let mut wrapped_cb = move |clause: *const i32| {
#learn_cb = crate::solver::libloading::LearnCB::new(move |clause: *const i32| {
let mut iter = crate::solver::libloading::ExplIter(clause)
.map(|i: i32| crate::Lit(std::num::NonZeroI32::new(i).unwrap()));
cb(&mut iter)
};
let trampoline = crate::solver::libloading::get_trampoline1(&wrapped_cb);
// WARNING: Any data in the callback now exists forever
let data = Box::leak(Box::new(wrapped_cb)) as *mut _ as *mut std::ffi::c_void;
});

unsafe {
#krate::ipasir_set_learn(
#ptr,
data,
#learn_cb .as_ptr(),
MAX_LEN,
Some(trampoline),
Some(crate::solver::libloading::LearnCB::exec_callback),
)
}
} else {
#learn_cb = Default::default();
unsafe { #krate::ipasir_set_learn(#ptr, std::ptr::null_mut(), MAX_LEN, None) }
}
}
Expand Down
15 changes: 5 additions & 10 deletions crates/pindakaas/src/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,10 @@ pub trait LearnCallback: Solver {
///
/// Subsequent calls to this method override the previously set
/// callback function.
///
/// For IPASIR connected through C, the callback and any objects contained
/// within it might be leaked to satisfy the FFI requirements. Note that
/// [`Drop`] implementations might not be called on these objects.
fn set_learn_callback<F: FnMut(&mut dyn Iterator<Item = Lit>)>(&mut self, cb: Option<F>);
fn set_learn_callback<F: FnMut(&mut dyn Iterator<Item = Lit>) + 'static>(
&mut self,
cb: Option<F>,
);
}

pub trait TermCallback: Solver {
Expand All @@ -91,11 +90,7 @@ pub trait TermCallback: Solver {
///
/// Subsequent calls to this method override the previously set
/// callback function.
///
/// For IPASIR connected through C, the callback and any objects contained
/// within it might be leaked to satisfy the FFI requirements. Note that
/// [`Drop`] implementations might not be called on these objects.
fn set_terminate_callback<F: FnMut() -> SlvTermSignal>(&mut self, cb: Option<F>);
fn set_terminate_callback<F: FnMut() -> SlvTermSignal + 'static>(&mut self, cb: Option<F>);
}

#[cfg(feature = "ipasir-up")]
Expand Down
64 changes: 17 additions & 47 deletions crates/pindakaas/src/solver/cadical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,25 @@ use pindakaas_cadical::{ccadical_copy, ccadical_phase, ccadical_unphase};
use pindakaas_derive::IpasirSolver;

use super::VarFactory;
use crate::Lit;
use crate::{
solver::libloading::{LearnCB, TermCB},
Lit,
};

#[derive(IpasirSolver)]
#[ipasir(krate = pindakaas_cadical, assumptions, learn_callback, term_callback, ipasir_up)]
pub struct Cadical {
/// The raw pointer to the Cadical solver.
ptr: *mut std::ffi::c_void,
/// The variable factory for this solver.
vars: VarFactory,
/// The callback used when a clause is learned.
learn_cb: LearnCB,
/// The callback used to check whether the solver should terminate.
term_cb: TermCB,

#[cfg(feature = "ipasir-up")]
/// The external propagator called by the solver
prop: Option<Box<CadicalProp>>,
}

Expand All @@ -20,6 +31,8 @@ impl Default for Cadical {
Self {
ptr: unsafe { pindakaas_cadical::ipasir_init() },
vars: VarFactory::default(),
learn_cb: LearnCB::default(),
term_cb: TermCB::default(),
#[cfg(feature = "ipasir-up")]
prop: None,
}
Expand All @@ -32,6 +45,8 @@ impl Clone for Cadical {
Self {
ptr,
vars: self.vars,
learn_cb: LearnCB::default(),
term_cb: TermCB::default(),
#[cfg(feature = "ipasir-up")]
prop: None,
}
Expand Down Expand Up @@ -77,7 +92,7 @@ mod tests {
use super::*;
use crate::{
linear::LimitComp,
solver::{LearnCallback, SlvTermSignal, SolveResult, Solver, TermCallback},
solver::{SolveResult, Solver},
CardinalityOne, ClauseDatabase, Encoder, PairwiseEncoder, Valuation,
};

Expand Down Expand Up @@ -114,51 +129,6 @@ mod tests {
});
}

#[test]
fn test_cadical_cb_no_drop() {
let mut slv = Cadical::default();

let a = slv.new_var().into();
let b = slv.new_var().into();
PairwiseEncoder::default()
.encode(
&mut slv,
&CardinalityOne {
lits: vec![a, b],
cmp: LimitComp::Equal,
},
)
.unwrap();

struct NoDrop(i32);
impl NoDrop {
fn seen(&mut self) {
self.0 += 1;
eprintln!("seen {}", self.0);
}
}
impl Drop for NoDrop {
fn drop(&mut self) {
panic!("I have been dropped {}", self.0);
}
}

{
let mut nodrop = NoDrop(0);
slv.set_terminate_callback(Some(move || {
nodrop.seen();
SlvTermSignal::Continue
}));
}
{
let mut nodrop = NoDrop(0);
slv.set_learn_callback(Some(move |_: &mut dyn Iterator<Item = Lit>| {
nodrop.seen();
}));
}
assert_eq!(slv.solve(|_| {}), SolveResult::Sat);
}

#[cfg(feature = "ipasir-up")]
#[test]
fn test_ipasir_up() {
Expand Down
9 changes: 9 additions & 0 deletions crates/pindakaas/src/solver/intel_sat.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
use pindakaas_derive::IpasirSolver;

use super::VarFactory;
use crate::solver::libloading::{LearnCB, TermCB};

#[derive(Debug, IpasirSolver)]
#[ipasir(krate = pindakaas_intel_sat, assumptions, learn_callback, term_callback)]
pub struct IntelSat {
/// The raw pointer to the Intel SAT solver.
ptr: *mut std::ffi::c_void,
/// The variable factory for this solver.
vars: VarFactory,
/// The callback used when a clause is learned.
learn_cb: LearnCB,
/// The callback used to check whether the solver should terminate.
term_cb: TermCB,
}

impl Default for IntelSat {
fn default() -> Self {
Self {
ptr: unsafe { pindakaas_intel_sat::ipasir_init() },
vars: VarFactory::default(),
term_cb: TermCB::default(),
learn_cb: LearnCB::default(),
}
}
}
Expand Down
Loading

0 comments on commit 6695c90

Please sign in to comment.