diff --git a/pumpkin-crates/core/src/api/outputs/mod.rs b/pumpkin-crates/core/src/api/outputs/mod.rs index 413a59158..6db16f8b4 100644 --- a/pumpkin-crates/core/src/api/outputs/mod.rs +++ b/pumpkin-crates/core/src/api/outputs/mod.rs @@ -43,13 +43,15 @@ pub enum SatisfactionResultUnderAssumptions<'solver, 'brancher, B: Brancher> { /// The result of a call to [`Solver::optimise`]. #[derive(Debug)] -pub enum OptimisationResult { +pub enum OptimisationResult { /// Indicates that an optimal solution has been found and proven to be optimal. It provides an /// instance of [`Solution`] which contains the optimal solution. Optimal(Solution), /// Indicates that a solution was found and provides an instance of [`Solution`] which contains /// best known solution by the solver. Satisfiable(Solution), + /// The optimisation was stopped by the solution callback. + Stopped(Solution, Stop), /// Indicates that there is no solution to the problem. Unsatisfiable, /// Indicates that it is not known whether a solution exists. This is likely due to a diff --git a/pumpkin-crates/core/src/api/solver.rs b/pumpkin-crates/core/src/api/solver.rs index 6088df6ec..69599b6ca 100644 --- a/pumpkin-crates/core/src/api/solver.rs +++ b/pumpkin-crates/core/src/api/solver.rs @@ -437,7 +437,7 @@ impl Solver { brancher: &mut B, termination: &mut impl TerminationCondition, mut optimisation_procedure: impl OptimisationProcedure, - ) -> OptimisationResult + ) -> OptimisationResult where B: Brancher, Callback: SolutionCallback, diff --git a/pumpkin-crates/core/src/optimisation/linear_sat_unsat.rs b/pumpkin-crates/core/src/optimisation/linear_sat_unsat.rs index 7f61aea52..4c422cb9e 100644 --- a/pumpkin-crates/core/src/optimisation/linear_sat_unsat.rs +++ b/pumpkin-crates/core/src/optimisation/linear_sat_unsat.rs @@ -1,3 +1,5 @@ +use std::ops::ControlFlow; + use super::OptimisationProcedure; use super::solution_callback::SolutionCallback; use crate::Solver; @@ -46,7 +48,7 @@ where brancher: &mut B, termination: &mut impl TerminationCondition, solver: &mut Solver, - ) -> OptimisationResult { + ) -> OptimisationResult { let objective = match self.direction { OptimisationDirection::Maximise => self.objective.scaled(-1), OptimisationDirection::Minimise => self.objective.scaled(1), @@ -60,12 +62,16 @@ where }; loop { - self.solution_callback.on_solution_callback( + let callback_result = self.solution_callback.on_solution_callback( solver, best_solution.as_reference(), brancher, ); + if let ControlFlow::Break(stop) = callback_result { + return OptimisationResult::Stopped(best_solution, stop); + } + let best_objective_value = best_solution.get_integer_value(objective.clone()); let conclusion = { diff --git a/pumpkin-crates/core/src/optimisation/linear_unsat_sat.rs b/pumpkin-crates/core/src/optimisation/linear_unsat_sat.rs index 7b0d65324..73f028dc7 100644 --- a/pumpkin-crates/core/src/optimisation/linear_unsat_sat.rs +++ b/pumpkin-crates/core/src/optimisation/linear_unsat_sat.rs @@ -1,4 +1,5 @@ use std::num::NonZero; +use std::ops::ControlFlow; use super::OptimisationProcedure; use super::solution_callback::SolutionCallback; @@ -49,7 +50,7 @@ where brancher: &mut B, termination: &mut impl TerminationCondition, solver: &mut Solver, - ) -> OptimisationResult { + ) -> OptimisationResult { let objective = match self.direction { OptimisationDirection::Maximise => self.objective.scaled(-1), OptimisationDirection::Minimise => self.objective.scaled(1), @@ -62,12 +63,16 @@ where SatisfactionResult::Unknown(_, _) => return OptimisationResult::Unknown, }; - self.solution_callback.on_solution_callback( + let callback_result = self.solution_callback.on_solution_callback( solver, primal_solution.as_reference(), brancher, ); + if let ControlFlow::Break(stop) = callback_result { + return OptimisationResult::Stopped(primal_solution, stop); + } + let primal_objective = primal_solution.get_integer_value(objective.clone()); // Then, we iterate from the lower bound of the objective until (excluding) the primal @@ -108,7 +113,8 @@ where match conclusion { Some(OptimisationResult::Optimal(solution)) => { - self.solution_callback.on_solution_callback( + // Optimisation will stop regardless of the result of the callback. + let _ = self.solution_callback.on_solution_callback( solver, primal_solution.as_reference(), brancher, diff --git a/pumpkin-crates/core/src/optimisation/mod.rs b/pumpkin-crates/core/src/optimisation/mod.rs index 2e2ec0730..3a623aecb 100644 --- a/pumpkin-crates/core/src/optimisation/mod.rs +++ b/pumpkin-crates/core/src/optimisation/mod.rs @@ -17,7 +17,7 @@ pub trait OptimisationProcedure> { brancher: &mut B, termination: &mut impl TerminationCondition, solver: &mut Solver, - ) -> OptimisationResult; + ) -> OptimisationResult; } /// The type of search which is performed by the solver. diff --git a/pumpkin-crates/core/src/optimisation/solution_callback.rs b/pumpkin-crates/core/src/optimisation/solution_callback.rs index cb5b4f3a6..073afb966 100644 --- a/pumpkin-crates/core/src/optimisation/solution_callback.rs +++ b/pumpkin-crates/core/src/optimisation/solution_callback.rs @@ -1,21 +1,62 @@ +use std::ops::ControlFlow; + use crate::Solver; use crate::branching::Brancher; use crate::results::SolutionReference; +/// Called during optimisation with every encountered solution. +/// +/// The callback can determine whether to proceed optimising or whether to stop by returning a +/// [`ControlFlow`] value. When [`ControlFlow::Break`] is returned, a value of +/// [`SolutionCallback::Stop`] can be supplied that will be forwarded to the result of the +/// optimisation call. pub trait SolutionCallback { - fn on_solution_callback(&self, solver: &Solver, solution: SolutionReference, brancher: &B); + /// The type of value to return if optimisation should stop. + type Stop; + + /// Called when a solution is encountered. + fn on_solution_callback( + &mut self, + solver: &Solver, + solution: SolutionReference, + brancher: &B, + ) -> ControlFlow; } -impl SolutionCallback for T { - fn on_solution_callback(&self, solver: &Solver, solution: SolutionReference, brancher: &B) { +impl SolutionCallback for T +where + T: FnMut(&Solver, SolutionReference, &B) -> ControlFlow, + B: Brancher, +{ + type Stop = R; + + fn on_solution_callback( + &mut self, + solver: &Solver, + solution: SolutionReference, + brancher: &B, + ) -> ControlFlow { (self)(solver, solution, brancher) } } -impl, B: Brancher> SolutionCallback for Option { - fn on_solution_callback(&self, solver: &Solver, solution: SolutionReference, brancher: &B) { +impl SolutionCallback for Option +where + T: SolutionCallback, + B: Brancher, +{ + type Stop = R; + + fn on_solution_callback( + &mut self, + solver: &Solver, + solution: SolutionReference, + brancher: &B, + ) -> ControlFlow { if let Some(callback) = self { - callback.on_solution_callback(solver, solution, brancher) + return callback.on_solution_callback(solver, solution, brancher); } + + ControlFlow::Continue(()) } } diff --git a/pumpkin-solver-py/src/model.rs b/pumpkin-solver-py/src/model.rs index f8e596372..89a938fb2 100644 --- a/pumpkin-solver-py/src/model.rs +++ b/pumpkin-solver-py/src/model.rs @@ -1,10 +1,14 @@ use std::num::NonZero; +use std::ops::ControlFlow; use std::path::PathBuf; use std::time::Duration; use std::time::Instant; use pumpkin_solver::DefaultBrancher; use pumpkin_solver::Solver; +use pumpkin_solver::branching::Brancher; +use pumpkin_solver::branching::branchers::warm_start::WarmStart; +use pumpkin_solver::containers::HashMap; use pumpkin_solver::containers::StorageKey; use pumpkin_solver::optimisation::OptimisationDirection; use pumpkin_solver::optimisation::linear_sat_unsat::LinearSatUnsat; @@ -19,6 +23,8 @@ use pumpkin_solver::results::SolutionReference; use pumpkin_solver::termination::Indefinite; use pumpkin_solver::termination::TerminationCondition; use pumpkin_solver::termination::TimeBudget; +use pumpkin_solver::variables::AffineView; +use pumpkin_solver::variables::DomainId; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; @@ -36,7 +42,7 @@ use crate::variables::Predicate; #[pyclass(unsendable)] pub struct Model { solver: Solver, - brancher: DefaultBrancher, + brancher: PythonBrancher, } #[pyclass] @@ -82,7 +88,10 @@ impl Model { }; let solver = Solver::with_options(options); - let brancher = solver.default_brancher(); + let brancher = PythonBrancher { + warm_start: WarmStart::new(&[], &[]), + default_brancher: solver.default_brancher(), + }; Ok(Model { solver, brancher }) } @@ -102,7 +111,7 @@ impl Model { self.solver.new_bounded_integer(lower_bound, upper_bound) }; - self.brancher.add_domain(domain_id); + self.brancher.default_brancher.add_domain(domain_id); domain_id.into() } @@ -117,6 +126,7 @@ impl Model { }; self.brancher + .default_brancher .add_domain(literal.get_true_predicate().get_domain()); literal.into() @@ -134,7 +144,7 @@ impl Model { /// A tag should be provided for this link to be identifiable in the proof. fn boolean_as_integer(&mut self, boolean: BoolExpression, tag: Tag) -> IntExpression { let new_domain = self.solver.new_bounded_integer(0, 1); - self.brancher.add_domain(new_domain); + self.brancher.default_brancher.add_domain(new_domain); let boolean_true = boolean.0.get_true_predicate(); @@ -253,7 +263,18 @@ impl Model { } } - #[pyo3(signature = (objective, optimiser=Optimiser::LinearSatUnsat, direction=Direction::Minimise, timeout=None, on_solution=None))] + #[allow( + clippy::too_many_arguments, + reason = "this is common in many Python APIs" + )] + #[pyo3(signature = ( + objective, + optimiser=Optimiser::LinearSatUnsat, + direction=Direction::Minimise, + timeout=None, + on_solution=None, + warm_start=HashMap::default(), + ))] fn optimise( &mut self, py: Python<'_>, @@ -262,7 +283,8 @@ impl Model { direction: Direction, timeout: Option, on_solution: Option>, - ) -> OptimisationResult { + warm_start: HashMap, + ) -> PyResult { let mut termination = get_termination(timeout); let direction = match direction { @@ -272,16 +294,25 @@ impl Model { let objective = objective.0; - let callback = move |_: &Solver, solution: SolutionReference<'_>, _: &DefaultBrancher| { + let callback = |_: &Solver, solution: SolutionReference<'_>, _: &PythonBrancher| { let python_solution = crate::result::Solution::from(solution); - if let Some(on_solution_callback) = on_solution.as_ref() { - let _ = on_solution_callback - .call(py, (python_solution,), None) - .expect("failed to call solution callback"); - } + // If there is a solution callback, unpack it. + let Some(on_solution_callback) = on_solution.as_ref() else { + return ControlFlow::Continue(()); + }; + + // Call the callback, and if there is an error, unpack it. + let Err(err) = on_solution_callback.call(py, (python_solution,), None) else { + return ControlFlow::Continue(()); + }; + + // Stop optimising and return the error. + ControlFlow::Break(err) }; + self.update_warm_start(warm_start); + let result = match optimiser { Optimiser::LinearSatUnsat => self.solver.optimise( &mut self.brancher, @@ -296,20 +327,44 @@ impl Model { }; match result { + pumpkin_solver::results::OptimisationResult::Stopped(_, err) => Err(err), pumpkin_solver::results::OptimisationResult::Satisfiable(solution) => { - OptimisationResult::Satisfiable(solution.into()) + Ok(OptimisationResult::Satisfiable(solution.into())) } pumpkin_solver::results::OptimisationResult::Optimal(solution) => { - OptimisationResult::Optimal(solution.into()) + Ok(OptimisationResult::Optimal(solution.into())) } pumpkin_solver::results::OptimisationResult::Unsatisfiable => { - OptimisationResult::Unsatisfiable() + Ok(OptimisationResult::Unsatisfiable()) + } + pumpkin_solver::results::OptimisationResult::Unknown => { + Ok(OptimisationResult::Unknown()) } - pumpkin_solver::results::OptimisationResult::Unknown => OptimisationResult::Unknown(), } } } +impl Model { + /// Update the warm start in the [`PythonBrancher`]. + fn update_warm_start(&mut self, warm_start: HashMap) { + // First create the slice of variables to give to the WarmStart brancher. + let warm_start_variables: Vec<_> = warm_start.keys().map(|variable| variable.0).collect(); + + // For every variable collect the value into another slice. + let warm_start_values: Vec<_> = warm_start_variables + .iter() + .map(|variable| { + warm_start + .get(&IntExpression(*variable)) + .copied() + .expect("all elements are keys") + }) + .collect(); + + self.brancher.warm_start = WarmStart::new(&warm_start_variables, &warm_start_values); + } +} + fn get_termination(end_time: Option) -> Box { end_time .map(|secs| Instant::now() + Duration::from_secs_f32(secs)) @@ -319,3 +374,25 @@ fn get_termination(end_time: Option) -> Box { }) .unwrap_or(Box::new(Indefinite)) } + +struct PythonBrancher { + warm_start: WarmStart>, + default_brancher: DefaultBrancher, +} + +impl Brancher for PythonBrancher { + fn next_decision( + &mut self, + context: &mut pumpkin_solver::branching::SelectionContext, + ) -> Option { + if let Some(predicate) = self.warm_start.next_decision(context) { + return Some(predicate); + } + + self.default_brancher.next_decision(context) + } + + fn subscribe_to_events(&self) -> Vec { + self.default_brancher.subscribe_to_events() + } +} diff --git a/pumpkin-solver-py/tests/test_optimisation.py b/pumpkin-solver-py/tests/test_optimisation.py index 20af4588e..d56bcb859 100644 --- a/pumpkin-solver-py/tests/test_optimisation.py +++ b/pumpkin-solver-py/tests/test_optimisation.py @@ -26,3 +26,28 @@ def test_linear_sat_unsat_maximisation(): solution = result._0 assert solution.int_value(objective) == 5 + + +def test_warm_start_with_callback(): + model = Model() + + objective = model.new_integer_variable(1, 5, name="objective") + + first_value = None + + def on_solution(solution): + nonlocal first_value + + if first_value is None: + first_value = solution.int_value(objective) + + result = model.optimise( + objective, + direction=Direction.Maximise, + warm_start={objective: 3}, + on_solution=on_solution, + ) + + assert isinstance(result, OptimisationResult.Optimal) + + assert first_value == 3 diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs index 41d92fb92..ced53f939 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs @@ -6,6 +6,7 @@ mod parser; use std::fs::File; use std::io::Read; +use std::ops::ControlFlow; use std::path::Path; use std::time::Duration; use std::time::Instant; @@ -176,19 +177,23 @@ pub(crate) fn solve( } }; - let callback = - |solver: &Solver, solution: SolutionReference<'_>, brancher: &DynamicBrancher| { - solution_callback( - brancher, - Some(objective), - options.all_solutions, - &outputs, - solver, - solution, - options.verbose, - init_time, - ); - }; + let callback = |solver: &Solver, + solution: SolutionReference<'_>, + brancher: &DynamicBrancher| + -> ControlFlow<()> { + solution_callback( + brancher, + Some(objective), + options.all_solutions, + &outputs, + solver, + solution, + options.verbose, + init_time, + ); + + ControlFlow::Continue(()) + }; let result = match options.optimisation_strategy { OptimisationStrategy::LinearSatUnsat => solver.optimise( @@ -204,6 +209,9 @@ pub(crate) fn solve( }; match result { + OptimisationResult::Stopped(_, _) => { + unreachable!("the callback will never return ControlFlow::Break") + } OptimisationResult::Optimal(optimal_solution) => { let objective_value = optimal_solution.get_integer_value(objective) as i64; if !options.all_solutions { diff --git a/pumpkin-solver/src/lib.rs b/pumpkin-solver/src/lib.rs index 3c289cf85..86a2b0537 100644 --- a/pumpkin-solver/src/lib.rs +++ b/pumpkin-solver/src/lib.rs @@ -131,6 +131,8 @@ //! //! Then we can find the optimal solution using [`Solver::optimise`]: //! ```rust +//! # use std::cmp::max; +//! # use std::ops::ControlFlow; //! # use pumpkin_solver::Solver; //! # use pumpkin_solver::results::OptimisationResult; //! # use pumpkin_solver::termination::Indefinite; @@ -139,7 +141,6 @@ //! # use pumpkin_solver::constraints::Constraint; //! # use pumpkin_solver::optimisation::OptimisationDirection; //! # use pumpkin_solver::optimisation::linear_sat_unsat::LinearSatUnsat; -//! # use std::cmp::max; //! # use crate::pumpkin_solver::optimisation::OptimisationProcedure; //! # use pumpkin_solver::results::SolutionReference; //! # use pumpkin_solver::DefaultBrancher; @@ -153,8 +154,12 @@ //! # solver.add_constraint(pumpkin_constraints::maximum(vec![x, y, z], objective, c1)).post(); //! # let mut termination = Indefinite; //! # let mut brancher = solver.default_brancher(); +//! +//! let callback = |_: &Solver, _: SolutionReference, _: &DefaultBrancher| -> ControlFlow<()> { +//! return ControlFlow::Continue(()); +//! }; +//! //! // Then we solve to optimality -//! let callback: fn(&Solver, SolutionReference, &DefaultBrancher) = |_, _, _| {}; //! let result = solver.optimise( //! &mut brancher, //! &mut termination,