Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pumpkin-crates/core/src/api/outputs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,15 @@ pub enum SatisfactionResultUnderAssumptions<'solver, 'brancher, B: Brancher> {

/// The result of a call to [`Solver::optimise`].
#[derive(Debug)]
pub enum OptimisationResult {
pub enum OptimisationResult<Stop> {
/// Indicates that an optimal solution has been found and proven to be optimal. It provides an
/// instance of [`Solution`] which contains the optimal solution.
Optimal(Solution),
/// Indicates that a solution was found and provides an instance of [`Solution`] which contains
/// best known solution by the solver.
Satisfiable(Solution),
/// The optimisation was stopped by the solution callback.
Stopped(Solution, Stop),
/// Indicates that there is no solution to the problem.
Unsatisfiable,
/// Indicates that it is not known whether a solution exists. This is likely due to a
Expand Down
2 changes: 1 addition & 1 deletion pumpkin-crates/core/src/api/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ impl Solver {
brancher: &mut B,
termination: &mut impl TerminationCondition,
mut optimisation_procedure: impl OptimisationProcedure<B, Callback>,
) -> OptimisationResult
) -> OptimisationResult<Callback::Stop>
where
B: Brancher,
Callback: SolutionCallback<B>,
Expand Down
10 changes: 8 additions & 2 deletions pumpkin-crates/core/src/optimisation/linear_sat_unsat.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::ControlFlow;

use super::OptimisationProcedure;
use super::solution_callback::SolutionCallback;
use crate::Solver;
Expand Down Expand Up @@ -46,7 +48,7 @@ where
brancher: &mut B,
termination: &mut impl TerminationCondition,
solver: &mut Solver,
) -> OptimisationResult {
) -> OptimisationResult<Callback::Stop> {
let objective = match self.direction {
OptimisationDirection::Maximise => self.objective.scaled(-1),
OptimisationDirection::Minimise => self.objective.scaled(1),
Expand All @@ -60,12 +62,16 @@ where
};

loop {
self.solution_callback.on_solution_callback(
let callback_result = self.solution_callback.on_solution_callback(
solver,
best_solution.as_reference(),
brancher,
);

if let ControlFlow::Break(stop) = callback_result {
return OptimisationResult::Stopped(best_solution, stop);
}

let best_objective_value = best_solution.get_integer_value(objective.clone());

let conclusion = {
Expand Down
12 changes: 9 additions & 3 deletions pumpkin-crates/core/src/optimisation/linear_unsat_sat.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::num::NonZero;
use std::ops::ControlFlow;

use super::OptimisationProcedure;
use super::solution_callback::SolutionCallback;
Expand Down Expand Up @@ -49,7 +50,7 @@ where
brancher: &mut B,
termination: &mut impl TerminationCondition,
solver: &mut Solver,
) -> OptimisationResult {
) -> OptimisationResult<Callback::Stop> {
let objective = match self.direction {
OptimisationDirection::Maximise => self.objective.scaled(-1),
OptimisationDirection::Minimise => self.objective.scaled(1),
Expand All @@ -62,12 +63,16 @@ where
SatisfactionResult::Unknown(_, _) => return OptimisationResult::Unknown,
};

self.solution_callback.on_solution_callback(
let callback_result = self.solution_callback.on_solution_callback(
solver,
primal_solution.as_reference(),
brancher,
);

if let ControlFlow::Break(stop) = callback_result {
return OptimisationResult::Stopped(primal_solution, stop);
}

let primal_objective = primal_solution.get_integer_value(objective.clone());

// Then, we iterate from the lower bound of the objective until (excluding) the primal
Expand Down Expand Up @@ -108,7 +113,8 @@ where

match conclusion {
Some(OptimisationResult::Optimal(solution)) => {
self.solution_callback.on_solution_callback(
// Optimisation will stop regardless of the result of the callback.
let _ = self.solution_callback.on_solution_callback(
solver,
primal_solution.as_reference(),
brancher,
Expand Down
2 changes: 1 addition & 1 deletion pumpkin-crates/core/src/optimisation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub trait OptimisationProcedure<B: Brancher, Callback: SolutionCallback<B>> {
brancher: &mut B,
termination: &mut impl TerminationCondition,
solver: &mut Solver,
) -> OptimisationResult;
) -> OptimisationResult<Callback::Stop>;
}

/// The type of search which is performed by the solver.
Expand Down
53 changes: 47 additions & 6 deletions pumpkin-crates/core/src/optimisation/solution_callback.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,62 @@
use std::ops::ControlFlow;

use crate::Solver;
use crate::branching::Brancher;
use crate::results::SolutionReference;

/// Called during optimisation with every encountered solution.
///
/// The callback can determine whether to proceed optimising or whether to stop by returning a
/// [`ControlFlow`] value. When [`ControlFlow::Break`] is returned, a value of
/// [`SolutionCallback::Stop`] can be supplied that will be forwarded to the result of the
/// optimisation call.
pub trait SolutionCallback<B: Brancher> {
fn on_solution_callback(&self, solver: &Solver, solution: SolutionReference, brancher: &B);
/// The type of value to return if optimisation should stop.
type Stop;

/// Called when a solution is encountered.
fn on_solution_callback(
&mut self,
solver: &Solver,
solution: SolutionReference,
brancher: &B,
) -> ControlFlow<Self::Stop>;
}

impl<T: Fn(&Solver, SolutionReference, &B), B: Brancher> SolutionCallback<B> for T {
fn on_solution_callback(&self, solver: &Solver, solution: SolutionReference, brancher: &B) {
impl<T, B, R> SolutionCallback<B> for T
where
T: FnMut(&Solver, SolutionReference, &B) -> ControlFlow<R>,
B: Brancher,
{
type Stop = R;

fn on_solution_callback(
&mut self,
solver: &Solver,
solution: SolutionReference,
brancher: &B,
) -> ControlFlow<Self::Stop> {
(self)(solver, solution, brancher)
}
}

impl<T: SolutionCallback<B>, B: Brancher> SolutionCallback<B> for Option<T> {
fn on_solution_callback(&self, solver: &Solver, solution: SolutionReference, brancher: &B) {
impl<T, R, B> SolutionCallback<B> for Option<T>
where
T: SolutionCallback<B, Stop = R>,
B: Brancher,
{
type Stop = R;

fn on_solution_callback(
&mut self,
solver: &Solver,
solution: SolutionReference,
brancher: &B,
) -> ControlFlow<Self::Stop> {
if let Some(callback) = self {
callback.on_solution_callback(solver, solution, brancher)
return callback.on_solution_callback(solver, solution, brancher);
}

ControlFlow::Continue(())
}
}
109 changes: 93 additions & 16 deletions pumpkin-solver-py/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use std::num::NonZero;
use std::ops::ControlFlow;
use std::path::PathBuf;
use std::time::Duration;
use std::time::Instant;

use pumpkin_solver::DefaultBrancher;
use pumpkin_solver::Solver;
use pumpkin_solver::branching::Brancher;
use pumpkin_solver::branching::branchers::warm_start::WarmStart;
use pumpkin_solver::containers::HashMap;
use pumpkin_solver::containers::StorageKey;
use pumpkin_solver::optimisation::OptimisationDirection;
use pumpkin_solver::optimisation::linear_sat_unsat::LinearSatUnsat;
Expand All @@ -19,6 +23,8 @@ use pumpkin_solver::results::SolutionReference;
use pumpkin_solver::termination::Indefinite;
use pumpkin_solver::termination::TerminationCondition;
use pumpkin_solver::termination::TimeBudget;
use pumpkin_solver::variables::AffineView;
use pumpkin_solver::variables::DomainId;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;

Expand All @@ -36,7 +42,7 @@ use crate::variables::Predicate;
#[pyclass(unsendable)]
pub struct Model {
solver: Solver,
brancher: DefaultBrancher,
brancher: PythonBrancher,
}

#[pyclass]
Expand Down Expand Up @@ -82,7 +88,10 @@ impl Model {
};

let solver = Solver::with_options(options);
let brancher = solver.default_brancher();
let brancher = PythonBrancher {
warm_start: WarmStart::new(&[], &[]),
default_brancher: solver.default_brancher(),
};

Ok(Model { solver, brancher })
}
Expand All @@ -102,7 +111,7 @@ impl Model {
self.solver.new_bounded_integer(lower_bound, upper_bound)
};

self.brancher.add_domain(domain_id);
self.brancher.default_brancher.add_domain(domain_id);

domain_id.into()
}
Expand All @@ -117,6 +126,7 @@ impl Model {
};

self.brancher
.default_brancher
.add_domain(literal.get_true_predicate().get_domain());

literal.into()
Expand All @@ -134,7 +144,7 @@ impl Model {
/// A tag should be provided for this link to be identifiable in the proof.
fn boolean_as_integer(&mut self, boolean: BoolExpression, tag: Tag) -> IntExpression {
let new_domain = self.solver.new_bounded_integer(0, 1);
self.brancher.add_domain(new_domain);
self.brancher.default_brancher.add_domain(new_domain);

let boolean_true = boolean.0.get_true_predicate();

Expand Down Expand Up @@ -253,7 +263,18 @@ impl Model {
}
}

#[pyo3(signature = (objective, optimiser=Optimiser::LinearSatUnsat, direction=Direction::Minimise, timeout=None, on_solution=None))]
#[allow(
clippy::too_many_arguments,
reason = "this is common in many Python APIs"
)]
#[pyo3(signature = (
objective,
optimiser=Optimiser::LinearSatUnsat,
direction=Direction::Minimise,
timeout=None,
on_solution=None,
warm_start=HashMap::default(),
))]
fn optimise(
&mut self,
py: Python<'_>,
Expand All @@ -262,7 +283,8 @@ impl Model {
direction: Direction,
timeout: Option<f32>,
on_solution: Option<Py<PyAny>>,
) -> OptimisationResult {
warm_start: HashMap<IntExpression, i32>,
) -> PyResult<OptimisationResult> {
let mut termination = get_termination(timeout);

let direction = match direction {
Expand All @@ -272,16 +294,25 @@ impl Model {

let objective = objective.0;

let callback = move |_: &Solver, solution: SolutionReference<'_>, _: &DefaultBrancher| {
let callback = |_: &Solver, solution: SolutionReference<'_>, _: &PythonBrancher| {
let python_solution = crate::result::Solution::from(solution);

if let Some(on_solution_callback) = on_solution.as_ref() {
let _ = on_solution_callback
.call(py, (python_solution,), None)
.expect("failed to call solution callback");
}
// If there is a solution callback, unpack it.
let Some(on_solution_callback) = on_solution.as_ref() else {
return ControlFlow::Continue(());
};

// Call the callback, and if there is an error, unpack it.
let Err(err) = on_solution_callback.call(py, (python_solution,), None) else {
return ControlFlow::Continue(());
};

// Stop optimising and return the error.
ControlFlow::Break(err)
};

self.update_warm_start(warm_start);

let result = match optimiser {
Optimiser::LinearSatUnsat => self.solver.optimise(
&mut self.brancher,
Expand All @@ -296,20 +327,44 @@ impl Model {
};

match result {
pumpkin_solver::results::OptimisationResult::Stopped(_, err) => Err(err),
pumpkin_solver::results::OptimisationResult::Satisfiable(solution) => {
OptimisationResult::Satisfiable(solution.into())
Ok(OptimisationResult::Satisfiable(solution.into()))
}
pumpkin_solver::results::OptimisationResult::Optimal(solution) => {
OptimisationResult::Optimal(solution.into())
Ok(OptimisationResult::Optimal(solution.into()))
}
pumpkin_solver::results::OptimisationResult::Unsatisfiable => {
OptimisationResult::Unsatisfiable()
Ok(OptimisationResult::Unsatisfiable())
}
pumpkin_solver::results::OptimisationResult::Unknown => {
Ok(OptimisationResult::Unknown())
}
pumpkin_solver::results::OptimisationResult::Unknown => OptimisationResult::Unknown(),
}
}
}

impl Model {
/// Update the warm start in the [`PythonBrancher`].
fn update_warm_start(&mut self, warm_start: HashMap<IntExpression, i32>) {
// First create the slice of variables to give to the WarmStart brancher.
let warm_start_variables: Vec<_> = warm_start.keys().map(|variable| variable.0).collect();

// For every variable collect the value into another slice.
let warm_start_values: Vec<_> = warm_start_variables
.iter()
.map(|variable| {
warm_start
.get(&IntExpression(*variable))
.copied()
.expect("all elements are keys")
})
.collect();

self.brancher.warm_start = WarmStart::new(&warm_start_variables, &warm_start_values);
}
}

fn get_termination(end_time: Option<f32>) -> Box<dyn TerminationCondition> {
end_time
.map(|secs| Instant::now() + Duration::from_secs_f32(secs))
Expand All @@ -319,3 +374,25 @@ fn get_termination(end_time: Option<f32>) -> Box<dyn TerminationCondition> {
})
.unwrap_or(Box::new(Indefinite))
}

struct PythonBrancher {
warm_start: WarmStart<AffineView<DomainId>>,
default_brancher: DefaultBrancher,
}

impl Brancher for PythonBrancher {
fn next_decision(
&mut self,
context: &mut pumpkin_solver::branching::SelectionContext,
) -> Option<pumpkin_solver::predicates::Predicate> {
if let Some(predicate) = self.warm_start.next_decision(context) {
return Some(predicate);
}

self.default_brancher.next_decision(context)
}

fn subscribe_to_events(&self) -> Vec<pumpkin_solver::branching::BrancherEvent> {
self.default_brancher.subscribe_to_events()
}
}
Loading