Skip to content

Commit

Permalink
Remove ndarray dependence
Browse files Browse the repository at this point in the history
`ndarray` is mostly useful for higher-dimensional tensors,
but with PBIL sampling one-point-at-a-time,
no optimizer required higher than 1-D tensors.
Standard Rust `Vec` works fine for 1-D tensors,
and now users can work with standard Rust types
instead of pulling in another dependency.
  • Loading branch information
justinlovinger committed Aug 29, 2023
1 parent 249eac4 commit 39e8af4
Show file tree
Hide file tree
Showing 12 changed files with 154 additions and 241 deletions.
76 changes: 0 additions & 76 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ license = "MIT"
github = { repository = "justinlovinger/optimal-rs", workflow = "build" }

[dependencies]
ndarray = "0.15.6"
optimal-core = { path = "optimal-core" }
optimal-pbil = { path = "optimal-pbil" }
optimal-steepest = { path = "optimal-steepest" }
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ println!(
RealDerivativeConfig::start_default_for(
2,
std::iter::repeat(-10.0..=10.0).take(2),
|point| { point.map(|x| x.powi(2)).sum() },
|point| { point.map(|x| 2.0 * x) }
|point| point.iter().map(|x| x.powi(2)).sum(),
|point| point.iter().map(|x| 2.0 * x).collect(),
)
.nth(100)
.unwrap()
Expand Down
1 change: 0 additions & 1 deletion optimal-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ blanket = "0.3.0"
streaming-iterator = "0.1.9"

[dev-dependencies]
ndarray = "0.15.6"
paste = "1.0.14"
replace_with = "0.1.7"
serde = { version = "1.0.185", features = ["derive"] }
Expand Down
4 changes: 1 addition & 3 deletions optimal-pbil/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@ categories = ["science", "mathematics"]
license = "MIT"

[features]
serde = ["dep:serde", "ndarray/serde", "rand/serde1", "rand_xoshiro/serde1"]
serde = ["dep:serde", "rand/serde1", "rand_xoshiro/serde1"]

[dependencies]
default-for = { path = "../default-for" }
derive-bounded = { path = "../derive-bounded" }
derive-getters = "0.3.0"
derive_more = "0.99.17"
ndarray = "0.15.6"
ndarray-rand = "0.14.0"
num-traits = "0.2.16"
once_cell = "1.18.0"
optimal-core = { path = "../optimal-core" }
Expand Down
1 change: 0 additions & 1 deletion optimal-pbil/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ Population-based incremental learning (PBIL).
## Examples

```rust
use ndarray::prelude::*;
use optimal_pbil::*;

println!(
Expand Down
39 changes: 20 additions & 19 deletions optimal-pbil/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
//! # Examples
//!
//! ```
//! use ndarray::prelude::*;
//! use optimal_pbil::*;
//!
//! println!(
Expand All @@ -28,7 +27,6 @@ use std::fmt::Debug;
use default_for::DefaultFor;
use derive_getters::Getters;
use derive_more::IsVariant;
use ndarray::prelude::*;
use once_cell::sync::OnceCell;
pub use optimal_core::prelude::*;
use rand_xoshiro::{SplitMix64, Xoshiro256PlusPlus};
Expand All @@ -47,11 +45,11 @@ pub struct MismatchedLengthError;
/// A type containing an array of probabilities.
pub trait Probabilities {
/// Return probabilities.
fn probabilities(&self) -> &Array1<Probability>;
fn probabilities(&self) -> &[Probability];
}

impl<B, F> Probabilities for Pbil<B, F> {
fn probabilities(&self) -> &Array1<Probability> {
fn probabilities(&self) -> &[Probability] {
self.state().probabilities()
}
}
Expand Down Expand Up @@ -147,11 +145,11 @@ impl<B, F> Pbil<B, F> {

impl<B, F> Pbil<B, F>
where
F: Fn(ArrayView1<bool>) -> B,
F: Fn(&[bool]) -> B,
{
/// Return value of the best point discovered.
pub fn best_point_value(&self) -> B {
(self.obj_func)(self.best_point().view())
(self.obj_func)(&self.best_point())
}

/// Return evaluation of current state,
Expand All @@ -168,7 +166,7 @@ where
impl<B, F> StreamingIterator for Pbil<B, F>
where
B: PartialOrd,
F: Fn(ArrayView1<bool>) -> B,
F: Fn(&[bool]) -> B,
{
type Item = Self;

Expand Down Expand Up @@ -207,7 +205,7 @@ where
}

impl<B, F> Optimizer for Pbil<B, F> {
type Point = Array1<bool>;
type Point = Vec<bool>;

fn best_point(&self) -> Self::Point {
self.state.best_point()
Expand Down Expand Up @@ -286,7 +284,7 @@ impl Config {
/// - `obj_func`: objective function to minimize
pub fn start_default_for<B, F>(num_bits: usize, obj_func: F) -> Pbil<B, F>
where
F: Fn(ArrayView1<bool>) -> B,
F: Fn(&[bool]) -> B,
{
Self::default_for(num_bits).start(num_bits, obj_func)
}
Expand All @@ -302,7 +300,7 @@ impl Config {
/// - `obj_func`: objective function to minimize
pub fn start<B, F>(self, num_bits: usize, obj_func: F) -> Pbil<B, F>
where
F: Fn(ArrayView1<bool>) -> B,
F: Fn(&[bool]) -> B,
{
Pbil::new(State::Ready(Ready::initial(num_bits)), self, obj_func)
}
Expand All @@ -318,7 +316,7 @@ impl Config {
/// - `rng`: source of randomness
pub fn start_using<B, F>(self, num_bits: usize, obj_func: F, rng: &mut SplitMix64) -> Pbil<B, F>
where
F: Fn(ArrayView1<bool>) -> B,
F: Fn(&[bool]) -> B,
{
Pbil::new(
State::Ready(Ready::initial_using(num_bits, rng)),
Expand All @@ -337,15 +335,15 @@ impl Config {
/// - `state`: PBIL state to start from
pub fn start_from<B, F>(self, obj_func: F, state: State<B>) -> Pbil<B, F>
where
F: Fn(ArrayView1<bool>) -> B,
F: Fn(&[bool]) -> B,
{
Pbil::new(state, self, obj_func)
}
}

impl<B> State<B> {
/// Return custom initial state.
pub fn new(probabilities: Array1<Probability>, rng: Xoshiro256PlusPlus) -> Self {
pub fn new(probabilities: Vec<Probability>, rng: Xoshiro256PlusPlus) -> Self {
Self::Ready(Ready::new(probabilities, rng))
}

Expand All @@ -355,22 +353,25 @@ impl<B> State<B> {
}

/// Return data to be evaluated.
pub fn evaluatee(&self) -> Option<ArrayView1<bool>> {
pub fn evaluatee(&self) -> Option<&[bool]> {
match self {
State::Ready(s) => Some(s.sample().into()),
State::Sampling(s) => Some(s.sample().into()),
State::Ready(s) => Some(s.sample()),
State::Sampling(s) => Some(s.sample()),
State::Mutating(_) => None,
}
}

/// Return the best point discovered.
pub fn best_point(&self) -> Array1<bool> {
self.probabilities().map(|p| f64::from(*p) >= 0.5)
pub fn best_point(&self) -> Vec<bool> {
self.probabilities()
.iter()
.map(|p| f64::from(*p) >= 0.5)
.collect()
}
}

impl<B> Probabilities for State<B> {
fn probabilities(&self) -> &Array1<Probability> {
fn probabilities(&self) -> &[Probability] {
match &self {
State::Ready(s) => s.probabilities(),
State::Sampling(s) => s.probabilities(),
Expand Down
Loading

0 comments on commit 39e8af4

Please sign in to comment.