Skip to content

Commit

Permalink
Refactor into fine-grained state-machines
Browse files Browse the repository at this point in the history
- Remove evaluation caches
  now that evaluation is stored in state.
  For the same reason,
  evaluation can now be accessed as
  `o.state().evaluation()`
  instead of `o.evaluation()`.
- Replace type-state pattern
  with simpler types
  for state-machines.
- Use `derive_getters::Dissolve`
  and rename `into_inner()` to `into_parts()`,
  to better fit Rust conventions.
- optimal-steepest: add fixed-step-size steepest state-machine.
- Use tuple-structs for `State` wrappers.
  • Loading branch information
justinlovinger committed Sep 7, 2023
1 parent ded6ad3 commit 8be75b2
Show file tree
Hide file tree
Showing 9 changed files with 1,118 additions and 458 deletions.
15 changes: 7 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion optimal-pbil/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ derive-bounded = { path = "../derive-bounded" }
derive-getters = "0.3.0"
derive_more = "0.99.17"
num-traits = "0.2.16"
once_cell = "1.18.0"
optimal-core = { path = "../optimal-core" }
partial-min-max = "0.4.0"
rand = "0.8.5"
rand_xoshiro = "0.6.0"
replace_with = "0.1.7"
Expand Down
178 changes: 105 additions & 73 deletions optimal-pbil/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ mod types;
mod until_probabilities_converged;

use default_for::DefaultFor;
use derive_getters::Getters;
use derive_getters::{Dissolve, Getters};
use derive_more::IsVariant;
use once_cell::sync::OnceCell;
pub use optimal_core::prelude::*;
use rand::prelude::*;
use rand_xoshiro::{SplitMix64, Xoshiro256PlusPlus};
Expand All @@ -44,8 +43,9 @@ use serde::{Deserialize, Serialize};
pub struct MismatchedLengthError;

/// A running PBIL optimizer.
#[derive(Clone, Debug, Getters)]
#[derive(Clone, Debug, Getters, Dissolve)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[dissolve(rename = "into_parts")]
pub struct Pbil<B, F> {
/// Optimizer configuration.
config: Config,
Expand All @@ -55,10 +55,6 @@ pub struct Pbil<B, F> {

/// Objective function to minimize.
obj_func: F,

#[getter(skip)]
#[cfg_attr(feature = "serde", serde(skip))]
evaluation_cache: OnceCell<Evaluation<B>>,
}

/// PBIL configuration parameters.
Expand All @@ -82,40 +78,36 @@ pub struct Config {
/// PBIL state.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct State<B> {
inner: DynState<B>,
}
pub struct State<B>(DynState<B>);

/// PBIL state kind.
#[derive(Clone, Debug, PartialEq, IsVariant)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StateKind {
/// Ready to start sampling.
Ready,
/// For sampling
/// and adjusting probabilities
/// based on samples.
Sampling,
/// For mutating probabilities.
Mutating,
/// Iteration started.
Started,
/// Sample generated.
Sampled,
/// Sample evaluated.
Evaluated,
/// Samples compared.
Compared,
/// Probabilities adjusted.
Adjusted,
/// Probabilities mutated.
Mutated,
/// Iteration finished.
Finished,
}

type Evaluation<B> = Option<B>;

impl<B, F> Pbil<B, F> {
fn new(state: State<B>, config: Config, obj_func: F) -> Self {
Self {
config,
obj_func,
state,
evaluation_cache: OnceCell::new(),
}
}

/// Return configuration, state, and problem parameters.
pub fn into_inner(self) -> (Config, State<B>, F) {
(self.config, self.state, self.obj_func)
}
}

impl<B, F> Pbil<B, F>
Expand All @@ -126,16 +118,6 @@ where
pub fn best_point_value(&self) -> B {
(self.obj_func)(&self.best_point())
}

/// Return evaluation of current state,
/// evaluating and caching if necessary.
pub fn evaluation(&self) -> &Evaluation<B> {
self.evaluation_cache.get_or_init(|| self.evaluate())
}

fn evaluate(&self) -> Evaluation<B> {
self.state.evaluatee().map(|xs| (self.obj_func)(xs))
}
}

impl<B, F> StreamingIterator for Pbil<B, F>
Expand All @@ -146,31 +128,35 @@ where
type Item = Self;

fn advance(&mut self) {
let evaluation = self
.evaluation_cache
.take()
.unwrap_or_else(|| self.evaluate());
replace_with::replace_with_or_abort(&mut self.state.inner, |state| match state {
// `evaluation.unwrap_unchecked()` is safe
// because we always have a sample to evaluate
// when it is called.
DynState::Ready(s) => {
DynState::Sampling(s.to_sampling(unsafe { evaluation.unwrap_unchecked() }))
replace_with::replace_with_or_abort(&mut self.state.0, |state| match state {
DynState::Started(x) => {
DynState::SampledFirst(x.into_initialized_sampling().into_sampled_first())
}
DynState::Sampling(s) => {
let value = unsafe { evaluation.unwrap_unchecked() };
if s.samples_generated() < self.config.num_samples.into_inner() {
DynState::Sampling(s.to_sampling(value))
} else if self.config.mutation_chance.is_zero() {
DynState::Ready(s.to_ready(self.config.adjust_rate, value))
DynState::SampledFirst(x) => {
DynState::EvaluatedFirst(x.into_evaluated_first(&self.obj_func))
}
DynState::EvaluatedFirst(x) => DynState::Sampled(x.into_sampled()),
DynState::Sampled(x) => DynState::Evaluated(x.into_evaluated(&self.obj_func)),
DynState::Evaluated(x) => DynState::Compared(x.into_compared()),
DynState::Compared(x) => {
if x.samples_generated < self.config.num_samples.into_inner() {
DynState::Sampled(x.into_sampled())
} else {
DynState::Mutating(s.to_mutating(self.config.adjust_rate, value))
DynState::Adjusted(x.into_adjusted(self.config.adjust_rate))
}
}
DynState::Mutating(s) => DynState::Ready(s.to_ready(
self.config.mutation_chance,
self.config.mutation_adjust_rate,
)),
DynState::Adjusted(x) => {
if self.config.mutation_chance.into_inner() > 0.0 {
DynState::Mutated(x.into_mutated(
self.config.mutation_chance,
self.config.mutation_adjust_rate,
))
} else {
DynState::Finished(x.into_finished())
}
}
DynState::Mutated(x) => DynState::Finished(x.into_finished()),
DynState::Finished(x) => DynState::Started(x.into_started()),
});
}

Expand Down Expand Up @@ -309,9 +295,7 @@ impl<B> State<B> {

/// Return custom initial state.
pub fn new(probabilities: Vec<Probability>, rng: Xoshiro256PlusPlus) -> Self {
Self {
inner: DynState::new(probabilities, rng),
}
Self(DynState::new(probabilities, rng))
}

/// Return number of bits being optimized.
Expand All @@ -322,10 +306,46 @@ impl<B> State<B> {

/// Return data to be evaluated.
pub fn evaluatee(&self) -> Option<&[bool]> {
match &self.inner {
DynState::Ready(s) => Some(s.sample()),
DynState::Sampling(s) => Some(s.sample()),
DynState::Mutating(_) => None,
match &self.0 {
DynState::Started(_) => None,
DynState::SampledFirst(x) => Some(&x.sample),
DynState::EvaluatedFirst(_) => None,
DynState::Sampled(x) => Some(&x.sample),
DynState::Evaluated(_) => None,
DynState::Compared(_) => None,
DynState::Adjusted(_) => None,
DynState::Mutated(_) => None,
DynState::Finished(_) => None,
}
}

/// Return result of evaluation.
pub fn evaluation(&self) -> Option<&B> {
match &self.0 {
DynState::Started(_) => None,
DynState::SampledFirst(_) => None,
DynState::EvaluatedFirst(x) => Some(x.sample.value()),
DynState::Sampled(_) => None,
DynState::Evaluated(x) => Some(x.sample.value()),
DynState::Compared(_) => None,
DynState::Adjusted(_) => None,
DynState::Mutated(_) => None,
DynState::Finished(_) => None,
}
}

/// Return sample if stored.
pub fn sample(&self) -> Option<&[bool]> {
match &self.0 {
DynState::Started(_) => None,
DynState::SampledFirst(x) => Some(&x.sample),
DynState::EvaluatedFirst(x) => Some(x.sample.x()),
DynState::Sampled(x) => Some(&x.sample),
DynState::Evaluated(x) => Some(x.sample.x()),
DynState::Compared(_) => None,
DynState::Adjusted(_) => None,
DynState::Mutated(_) => None,
DynState::Finished(_) => None,
}
}

Expand All @@ -339,20 +359,32 @@ impl<B> State<B> {

/// Return kind of state of inner state-machine.
pub fn kind(&self) -> StateKind {
match self.inner {
DynState::Ready(_) => StateKind::Ready,
DynState::Sampling(_) => StateKind::Sampling,
DynState::Mutating(_) => StateKind::Mutating,
match self.0 {
DynState::Started(_) => StateKind::Started,
DynState::SampledFirst(_) => StateKind::Sampled,
DynState::EvaluatedFirst(_) => StateKind::Evaluated,
DynState::Sampled(_) => StateKind::Sampled,
DynState::Evaluated(_) => StateKind::Evaluated,
DynState::Compared(_) => StateKind::Compared,
DynState::Adjusted(_) => StateKind::Adjusted,
DynState::Mutated(_) => StateKind::Mutated,
DynState::Finished(_) => StateKind::Finished,
}
}
}

impl<B> Probabilities for State<B> {
fn probabilities(&self) -> &[Probability] {
match &self.inner {
DynState::Ready(s) => s.probabilities(),
DynState::Sampling(s) => s.probabilities(),
DynState::Mutating(s) => s.probabilities(),
match &self.0 {
DynState::Started(x) => &x.probabilities,
DynState::SampledFirst(x) => x.probabilities.probabilities(),
DynState::EvaluatedFirst(x) => x.probabilities.probabilities(),
DynState::Sampled(x) => x.probabilities.probabilities(),
DynState::Evaluated(x) => x.probabilities.probabilities(),
DynState::Compared(x) => x.probabilities.probabilities(),
DynState::Adjusted(x) => &x.probabilities,
DynState::Mutated(x) => &x.probabilities,
DynState::Finished(x) => &x.probabilities,
}
}
}
Expand All @@ -370,7 +402,7 @@ mod tests {
mutation_adjust_rate: MutationAdjustRate::default(),
}
.start(16, |point| point.iter().filter(|x| **x).count())
.inspect(|x| assert!(!x.state().kind().is_mutating()))
.inspect(|x| assert!(!x.state().kind().is_mutated()))
.nth(100);
}
}
Loading

0 comments on commit 8be75b2

Please sign in to comment.