Skip to content

Commit

Permalink
optimal-pbil: simplify MutationChance
Browse files Browse the repository at this point in the history
  • Loading branch information
justinlovinger committed Feb 3, 2024
1 parent ab7f825 commit 1ce2a80
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 137 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions optimal-pbil/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,4 @@ thiserror = "1.0.47"
[dev-dependencies]
criterion = "0.5.1"
proptest = "1.2.0"
serde = "1.0.185"
serde_json = "1.0.111"
test-strategy = "0.3.1"
2 changes: 1 addition & 1 deletion optimal-pbil/src/high_level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl<F, R> PbilWith<F, R> {
R: Rng,
{
mutate_probabilities(
&self.problem.agnostic.mutation_chance,
self.problem.agnostic.mutation_chance,
self.problem.agnostic.mutation_adjust_rate,
adjust_probabilities(
self.problem.agnostic.adjust_rate,
Expand Down
9 changes: 5 additions & 4 deletions optimal-pbil/src/low_level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
//! .collect::<Vec<_>>();
//! while !converged(threshold, probabilities.iter().copied()) {
//! probabilities = mutate_probabilities(
//! &mutation_chance,
//! mutation_chance,
//! mutation_adjust_rate,
//! adjust_probabilities(
//! adjust_rate,
Expand Down Expand Up @@ -77,16 +77,17 @@ where
/// adjust each probability towards a random probability
/// at `adjust_rate`.
pub fn mutate_probabilities<'a, R>(
chance: &'a MutationChance,
chance: MutationChance,
adjust_rate: MutationAdjustRate,
probabilities: impl IntoIterator<Item = Probability> + 'a,
rng: &'a mut R,
) -> impl Iterator<Item = Probability> + 'a
where
R: Rng,
{
let distr = chance.into_distr();
probabilities.into_iter().map(move |p| {
if rng.sample(chance) {
if rng.sample(distr) {
// `Standard` distribution excludes `1`,
// but it more efficient
// than `Uniform::new_inclusive(0., 1.)`.
Expand Down Expand Up @@ -226,7 +227,7 @@ mod tests {
probabilities: Vec<Probability>,
) {
prop_assert!(are_valid(mutate_probabilities(
&mutation_chance,
mutation_chance,
mutation_adjust_rate,
probabilities,
&mut SmallRng::seed_from_u64(seed),
Expand Down
154 changes: 25 additions & 129 deletions optimal-pbil/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
//! Types for PBIL.

use core::convert::TryFrom;
use std::{f64::EPSILON, fmt};
use std::f64::EPSILON;

use derive_more::{Display, Into};
use derive_num_bounded::{
derive_from_str_from_try_into, derive_into_inner, derive_new_from_bounded_float,
derive_new_from_lower_bounded, derive_try_from_from_new,
};
use num_traits::bounds::{LowerBounded, UpperBounded};
use rand::{distributions::Bernoulli, prelude::Distribution};
use rand::distributions::Bernoulli;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -82,154 +82,64 @@ impl Ord for AdjustRate {

/// Probability for each probability to mutate,
/// independently.
#[derive(Clone)]
pub struct MutationChance {
chance: f64,
distribution: Bernoulli,
}

/// Error returned when [`MutationChance`] is given an invalid value.
#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)]
pub enum InvalidMutationChanceError {
/// Value is NaN.
#[error("{0} is NaN")]
IsNan(f64),
/// Value is below lower bound.
#[error("{0} is below lower bound ({})", MutationChance::min_value())]
TooLow(f64),
/// Value is above upper bound.
#[error("{0} is above upper bound ({})", MutationChance::max_value())]
TooHigh(f64),
}

impl fmt::Debug for MutationChance {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("MutationChance").field(&self.chance).finish()
}
}
#[derive(Clone, Copy, Debug, Display, PartialEq, PartialOrd, Into)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(into = "f64"))]
#[cfg_attr(feature = "serde", serde(try_from = "f64"))]
pub struct MutationChance(f64);

impl fmt::Display for MutationChance {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.chance.fmt(f)
}
}
derive_new_from_bounded_float!(MutationChance(f64));
derive_into_inner!(MutationChance(f64));
derive_try_from_from_new!(MutationChance(f64));
derive_from_str_from_try_into!(MutationChance(f64));

impl Eq for MutationChance {}

#[allow(clippy::derive_ord_xor_partial_ord)]
impl Ord for MutationChance {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// `f64` has total ordering for the the range of values allowed by this type.
unsafe { self.partial_cmp(other).unwrap_unchecked() }
}
}

impl PartialEq for MutationChance {
fn eq(&self, other: &Self) -> bool {
self.chance.eq(&other.chance)
}
}

impl PartialOrd for MutationChance {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.chance.partial_cmp(&other.chance)
}
}

impl From<MutationChance> for f64 {
fn from(value: MutationChance) -> Self {
value.chance
}
}

impl LowerBounded for MutationChance {
fn min_value() -> Self {
unsafe { Self::new_unchecked(0.0) }
Self(0.0)
}
}

impl UpperBounded for MutationChance {
fn max_value() -> Self {
unsafe { Self::new_unchecked(1.0) }
Self(1.0)
}
}

impl Distribution<bool> for MutationChance {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> bool {
self.distribution.sample(rng)
}
}

#[cfg(any(feature = "serde", test))]
impl<'de> serde::Deserialize<'de> for MutationChance {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let chance = f64::deserialize(deserializer)?;
match MutationChance::new(chance) {
Ok(x) => Ok(x),
Err(e) => Err(<D::Error as serde::de::Error>::custom(e)),
}
}
}

#[cfg(any(feature = "serde", test))]
impl serde::Serialize for MutationChance {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_f64(self.chance)
impl From<MutationChance> for Bernoulli {
fn from(value: MutationChance) -> Self {
Bernoulli::new(value.0).unwrap()
}
}

derive_try_from_from_new!(MutationChance(f64));
derive_from_str_from_try_into!(MutationChance(f64));

impl MutationChance {
/// Return a new [`MutationChance`] if given a valid value.
pub fn new(value: f64) -> Result<Self, InvalidMutationChanceError> {
match (
value.partial_cmp(&Self::min_value().chance),
value.partial_cmp(&Self::max_value().chance),
) {
(None, _) | (_, None) => Err(InvalidMutationChanceError::IsNan(value)),
(Some(std::cmp::Ordering::Less), _) => Err(InvalidMutationChanceError::TooLow(value)),
(_, Some(std::cmp::Ordering::Greater)) => {
Err(InvalidMutationChanceError::TooHigh(value))
}
_ => Ok(unsafe { Self::new_unchecked(value) }),
}
}

/// Return recommended default mutation chance,
/// average of one mutation per step.
pub fn default_for(len: usize) -> Self {
if len == 0 {
unsafe { Self::new_unchecked(1.0) }
Self(1.0)
} else {
unsafe { Self::new_unchecked(1. / len as f64) }
Self(1. / len as f64)
}
}

/// # Safety
///
/// Value must be within range.
unsafe fn new_unchecked(value: f64) -> Self {
Self {
chance: value,
distribution: Bernoulli::new(value).unwrap(),
}
}

/// Unwrap [`MutationChance`] into inner value.
pub fn into_inner(self) -> f64 {
self.chance
}

/// Return whether no chance to mutate.
pub fn is_zero(&self) -> bool {
self.chance == 0.0
self.0 == 0.0
}

/// Return a distribution for sampling.
pub fn into_distr(self) -> Bernoulli {
self.into()
}
}

Expand Down Expand Up @@ -438,9 +348,6 @@ impl TryFrom<Probability> for ProbabilityThreshold {

#[cfg(test)]
mod tests {
use proptest::prelude::*;
use test_strategy::proptest;

use super::*;

#[test]
Expand All @@ -461,17 +368,6 @@ mod tests {
);
}

#[proptest()]
fn mutation_chance_serializes_correctly(chance: MutationChance) {
prop_assert!(
(serde_json::from_str::<MutationChance>(&serde_json::to_string(&chance).unwrap())
.unwrap()
.into_inner()
- chance.into_inner())
< 1e10
)
}

#[test]
fn mutation_adjust_rate_from_str_returns_correct_result() {
assert_eq!(
Expand Down

0 comments on commit 1ce2a80

Please sign in to comment.