diff --git a/rosomaxa/src/algorithms/gsom/mod.rs b/rosomaxa/src/algorithms/gsom/mod.rs index 0cc196c7..22f08f18 100644 --- a/rosomaxa/src/algorithms/gsom/mod.rs +++ b/rosomaxa/src/algorithms/gsom/mod.rs @@ -41,6 +41,9 @@ pub trait Storage: Display + Send + Sync { /// Returns a distance between two input weights. fn distance(&self, a: &[Float], b: &[Float]) -> Float; + /// Shrinks the storage to the specified size. + fn resize(&mut self, size: usize); + /// Returns size of the storage. fn size(&self) -> usize; } diff --git a/rosomaxa/src/algorithms/gsom/network.rs b/rosomaxa/src/algorithms/gsom/network.rs index b82f901d..1875e5ac 100644 --- a/rosomaxa/src/algorithms/gsom/network.rs +++ b/rosomaxa/src/algorithms/gsom/network.rs @@ -41,7 +41,10 @@ where } /// GSOM network configuration. +#[derive(Clone, Debug)] pub struct NetworkConfig { + /// The size of a node's storage. + pub node_size: usize, /// A spread factor. pub spread_factor: Float, /// The factor of distribution (FD), used in error distribution stage, 0 < FD < 1 @@ -50,7 +53,7 @@ pub struct NetworkConfig { pub distribution_factor: Float, /// A learning rate. pub learning_rate: Float, /// A rebalance memory. pub rebalance_memory: usize, - /// If set to true, initial nodes have error set to the value equal to growing threshold. + /// If set to true, initial nodes have their error set to the value of the growing threshold. pub has_initial_error: bool, } @@ -65,39 +68,63 @@ where F: StorageFactory<C, I, S>, { /// Creates a new instance of `Network`. - pub fn new(context: &C, roots: [I; 4], config: NetworkConfig, random: Arc<dyn Random>, storage_factory: F) -> Self { - let dimension = roots[0].weights().len(); - - assert!(roots.iter().all(|r| r.weights().len() == dimension)); + pub fn new( + context: &C, + initial_data: Vec<I>, + config: NetworkConfig, + random: Arc<dyn Random>, + storage_factory: SF, + ) -> GenericResult<Self> + where + SF: Fn(usize) -> F, + { + assert!(!initial_data.is_empty()); + let dimension = initial_data[0].weights().len(); + let data_size = initial_data.len(); + assert!(initial_data.iter().all(|r| r.weights().len() == dimension)); assert!(config.distribution_factor > 0. && config.distribution_factor < 1.); assert!(config.spread_factor > 0. && config.spread_factor < 1.); - let growing_threshold = -1. * dimension as Float * config.spread_factor.log2(); - let initial_error = if config.has_initial_error { growing_threshold } else { 0. }; - let noise = Noise::new_with_ratio(1., (0.75, 1.25), random.clone()); - + // create initial nodes + // note that the storage factory creates storages with a size up to data_size; + // this helps to prevent data loss until the network is rebalanced let (nodes, min_max_weights) = Self::create_initial_nodes( context, - roots, - initial_error, + initial_data, config.rebalance_memory, - &noise, - &storage_factory, - ); + &storage_factory(data_size), + // apply small noise to initial weights + Noise::new_with_ratio(1., (0.95, 1.05), random.clone()), + )?; - Self { + // create a network with more aggressive initial parameters + let mut network = Self { dimension, - growing_threshold, + growing_threshold: -1.
* dimension as Float * config.spread_factor.log2(), distribution_factor: config.distribution_factor, learning_rate: config.learning_rate, time: 0, rebalance_memory: config.rebalance_memory, min_max_weights, nodes, - storage_factory, - random, + storage_factory: storage_factory(data_size), + random: random.clone(), phantom_data: Default::default(), - } + }; + + // run the training loop to balance the network + let allow_growth = true; + let rebalance_count = (data_size / 4).clamp(8, 12); + network.train_loop(context, rebalance_count, allow_growth, |_| ()); + + // reset to original parameters and make sure that node storages have the desired size + network.storage_factory = storage_factory(config.node_size); + network.time = 0; + network.nodes.iter_mut().for_each(|(_, node)| { + node.storage.resize(config.node_size); + }); + + Ok(network) } /// Sets a new learning rate. @@ -137,19 +164,8 @@ where where FM: Fn(&mut I), { - (0..rebalance_count).for_each(|_| { - let mut data = self.nodes.iter_mut().flat_map(|(_, node)| node.storage.drain(0..)).collect::<Vec<_>>(); - data.sort_unstable_by(compare_input); - data.dedup_by(|a, b| compare_input(a, b) == Ordering::Equal); - data.shuffle(&mut self.random.get_rng()); - data.iter_mut().for_each(&node_fn); - - self.train_on_data(context, data, false); - - self.nodes.iter_mut().for_each(|(_, node)| { - node.error = 0.; - }) - }); + let allow_growth = false; + self.train_loop(context, rebalance_count, allow_growth, node_fn); } /// Compacts network. `node_filter` should return false for nodes to be removed. @@ -204,6 +220,26 @@ where self.get_nodes().map(|node| node.unified_distance(self, 1)).max_by(|a, b| a.total_cmp(b)).unwrap_or_default() } + /// Performs the training loop multiple times. + fn train_loop(&mut self, context: &C, rebalance_count: usize, allow_growth: bool, node_fn: FM) + where + FM: Fn(&mut I), + { + (0..rebalance_count).for_each(|_| { + let mut data = self.nodes.iter_mut().flat_map(|(_, node)| node.storage.drain(0..)).collect::<Vec<_>>(); + data.sort_unstable_by(compare_input); + data.dedup_by(|a, b| compare_input(a, b) == Ordering::Equal); + data.shuffle(&mut self.random.get_rng()); + data.iter_mut().for_each(&node_fn); + + self.train_on_data(context, data, allow_growth); + + self.nodes.iter_mut().for_each(|(_, node)| { + node.error = 0.; + }) + }); + } + /// Trains network on an input. fn train(&mut self, context: &C, input: I, is_new_input: bool) { debug_assert!(input.weights().len() == self.dimension); @@ -419,46 +455,103 @@ where /// Creates nodes for initial topology.
fn create_initial_nodes( context: &C, - roots: [I; 4], - initial_error: Float, + data: Vec<I>, rebalance_memory: usize, - noise: &Noise, storage_factory: &F, - ) -> (NodeHashMap<I, S>, MinMaxWeights) { - let create_node = |coord: Coordinate, input: I| { - let weights = input.weights().iter().map(|&value| noise.generate(value)).collect::<Vec<_>>(); - let mut node = Node::<I, S>::new( - coord, - weights.as_slice(), - initial_error, - rebalance_memory, - storage_factory.eval(context), - ); - node.storage.add(input); - - node - }; + noise: Noise, + ) -> GenericResult<(NodeHashMap<I, S>, MinMaxWeights)> { + // sample size is 10% of the data, bounded between 4 and 16 nodes + let sample_size = (data.len() as f64 * 0.1).ceil() as usize; + let sample_size = sample_size.clamp(4, 16); + + let storage = storage_factory.eval(context); + let initial_node_indices = Self::select_initial_samples(&data, sample_size, &storage, noise.random()) + .ok_or_else(|| GenericError::from("cannot select initial samples"))?; + + // create initial node coordinates and data assignments (by index) + let grid_size = (initial_node_indices.len() as f64).sqrt().ceil() as i32; + let mut node_assignments: HashMap<Coordinate, Vec<usize>> = initial_node_indices + .iter() + .enumerate() + .map(|(grid_idx, &data_idx)| { + (data_idx, Coordinate((grid_idx as i32) % grid_size, (grid_idx as i32) / grid_size)) + }) + .collect_group_by_key(|(_, coord)| *coord) + .into_iter() + .map(|(coord, items)| (coord, items.into_iter().map(|(idx, _)| idx).collect())) + .collect(); + + // assign each remaining data point to the closest initial coordinate based on distance + for (idx, item) in data.iter().enumerate() { + if !initial_node_indices.contains(&idx) { + let get_distance_fn = |coord| { + let init_idx = node_assignments[coord][0]; + storage.distance(data[init_idx].weights(), item.weights()) + }; + + node_assignments + .keys() + .min_by(|&left, &right| get_distance_fn(left).total_cmp(&get_distance_fn(right))) + .cloned() + .and_then(|closest_coord| node_assignments.get_mut(&closest_coord)) + .ok_or_else(|| GenericError::from("cannot find closest node"))?
+ .push(idx); + } + } - let dimension = roots[0].weights().len(); - let [n00, n01, n11, n10] = roots; + let dimension = data[0].weights().len(); + let mut min_max_weights = (vec![Float::MAX; dimension], vec![Float::MIN; dimension]); + let mut nodes = NodeHashMap::default(); - let n00 = create_node(Coordinate(0, 0), n00); - let n01 = create_node(Coordinate(0, 1), n01); - let n11 = create_node(Coordinate(1, 1), n11); - let n10 = create_node(Coordinate(1, 0), n10); + // first pass: create nodes from the assignments, without data yet (the data is not cloneable, and indices must stay valid) + for (&coord, indices) in node_assignments.iter() { + let init_idx = indices[0]; + let weights: Vec<Float> = data[init_idx].weights().iter().map(|&v| noise.generate(v)).collect(); + let node = Node::new(coord, &weights, 0., rebalance_memory, storage_factory.eval(context)); + update_min_max(&mut min_max_weights, &weights); - let min_max_weights = [&n00, &n01, &n11, &n10].into_iter().fold( - (vec![Float::MAX; dimension], vec![Float::MIN; dimension]), - |mut min_max_weights, node| { - update_min_max(&mut min_max_weights, node.weights.as_slice()); + nodes.insert(coord, node); + } - min_max_weights - }, - ); + // second pass: drain the data into the assigned nodes + for (idx, item) in data.into_iter().enumerate() { + let node = node_assignments + .iter() + .find(|(_, indices)| indices.contains(&idx)) + .map(|(coord, _)| coord) + .and_then(|coord| nodes.get_mut(coord)) + .ok_or_else(|| GenericError::from("cannot find node for data"))?; - let nodes = [n00, n01, n11, n10].into_iter().map(|node| (node.coordinate, node)).collect::<NodeHashMap<_, _>>(); + node.storage.add(item); + } + + Ok((nodes, min_max_weights)) + } + + /// Selects initial samples (represented as indices into the data). + fn select_initial_samples(data: &[I], sample_size: usize, storage: &S, random: &dyn Random) -> Option<Vec<usize>> { + let mut selected_indices = Vec::with_capacity(sample_size); + + // select the first sample randomly + selected_indices.push(random.uniform_int(0, data.len() as i32 - 1) as usize); + + // select remaining samples by maximizing the minimum distance to the already selected ones + let dist_fn = |selected_indices: &Vec<usize>, idx: usize| { + selected_indices + .iter() + .map(|&sel_idx| storage.distance(data[sel_idx].weights(), data[idx].weights())) + .min_by(|a, b| a.total_cmp(b)) + .unwrap_or_default() + }; + while selected_indices.len() < sample_size { + let next_idx = (0..data.len()) + .filter(|i| !selected_indices.contains(i)) + .max_by(|&i, &j| dist_fn(&selected_indices, i).total_cmp(&dist_fn(&selected_indices, j)))?; + + selected_indices.push(next_idx); + } - (nodes, min_max_weights) + Some(selected_indices) } } diff --git a/rosomaxa/src/example.rs b/rosomaxa/src/example.rs index 9667b53b..84ea4061 100644 --- a/rosomaxa/src/example.rs +++ b/rosomaxa/src/example.rs @@ -45,8 +45,8 @@ pub struct VectorObjective { pub struct VectorSolution { /// Solution payload. pub data: Vec<Float>, - weights: Vec<Float>, - fitness: Float, + pub(crate) weights: Vec<Float>, + pub(crate) fitness: Float, } impl VectorContext { diff --git a/rosomaxa/src/population/elitism.rs b/rosomaxa/src/population/elitism.rs index 357ee2fc..a8a2780d 100644 --- a/rosomaxa/src/population/elitism.rs +++ b/rosomaxa/src/population/elitism.rs @@ -162,6 +162,12 @@ where self.individuals.drain(range).collect() } + /// Shrinks the population to the specified size.
+ pub fn set_max_population_size(&mut self, max_population_size: usize) { + self.max_population_size = max_population_size; + self.ensure_max_population_size(); + } + fn sort(&mut self) { self.individuals.sort_by(|a, b| self.objective.total_order(a, b)); self.individuals.dedup_by(|a, b| (self.dedup_fn)(&self.objective, a, b)); diff --git a/rosomaxa/src/population/rosomaxa.rs b/rosomaxa/src/population/rosomaxa.rs index 5e0680c8..6121d895 100644 --- a/rosomaxa/src/population/rosomaxa.rs +++ b/rosomaxa/src/population/rosomaxa.rs @@ -8,8 +8,6 @@ use crate::algorithms::math::relative_distance; use crate::population::elitism::{Alternative, DedupFn}; use crate::utils::{parallel_into_collect, Environment, Random}; use rand::prelude::SliceRandom; -use rayon::iter::Either; -use std::convert::TryInto; use std::f64::consts::{E, PI}; use std::fmt::Formatter; use std::ops::RangeBounds; @@ -17,6 +15,8 @@ use std::sync::Arc; /// Specifies rosomaxa configuration settings. pub struct RosomaxaConfig { + /// Initial population size. + pub initial_size: usize, /// Selection size. pub selection_size: usize, /// Elite population size. @@ -38,6 +38,7 @@ impl RosomaxaConfig { /// account data parallelism settings. pub fn new_with_defaults(selection_size: usize) -> Self { Self { + initial_size: 32, selection_size, elite_size: 2, node_size: 2, @@ -160,12 +161,12 @@ where self.elite .select() .take(elite_explore_size) - .chain(coordinates.iter().flat_map(move |coordinate| { - network - .find(coordinate) - .map(|node| Either::Left(node.storage.population.select().take(node_explore_size))) - .unwrap_or_else(|| Either::Right(std::iter::empty())) - })) + .chain( + coordinates + .iter() + .filter_map(move |coordinate| network.find(coordinate)) + .flat_map(move |node| node.storage.population.select().take(node_explore_size)), + ) .take(*selection_size), ) } @@ -242,24 +243,32 @@ where } }; + let exploration_ratio = match statistics.speed { + HeuristicSpeed::Unknown | HeuristicSpeed::Moderate { .. } => self.config.exploration_ratio, + HeuristicSpeed::Slow { ratio, .. } => self.config.exploration_ratio * ratio, + }; + match &mut self.phase { RosomaxaPhases::Initial { solutions: individuals } => { - if individuals.len() >= self.config.selection_size { - let mut network = Self::create_network( + if statistics.termination_estimate > exploration_ratio { + (self.environment.logger)("skip exploration phase"); + self.phase = RosomaxaPhases::Exploitation { selection_size } + } else if individuals.len() >= self.config.initial_size { + let network = Self::create_network( &self.external_ctx, self.objective.clone(), self.environment.clone(), &self.config, - individuals.drain(0..4).collect(), - ); + std::mem::take(individuals), + ) + // TODO: avoid panic here + .expect("cannot create network"); - std::mem::take(individuals).into_iter().for_each(|individual| { - network.store(&self.external_ctx, init_individual(&self.external_ctx, individual), 0) - }); + let coordinates = network.get_coordinates().collect(); self.phase = RosomaxaPhases::Exploration { network, - coordinates: vec![], + coordinates, statistics: statistics.clone(), selection_size, }; @@ -271,11 +280,6 @@ where statistics: old_statistics, selection_size: old_selection_size, } => { - let exploration_ratio = match old_statistics.speed { - HeuristicSpeed::Unknown | HeuristicSpeed::Moderate { .. } => self.config.exploration_ratio, - HeuristicSpeed::Slow { ratio, .. 
} => self.config.exploration_ratio * ratio, - }; - if statistics.termination_estimate < exploration_ratio { *old_statistics = statistics.clone(); *old_selection_size = selection_size; @@ -343,22 +347,14 @@ where environment: Arc<Environment>, config: &RosomaxaConfig, individuals: Vec<S>, - ) -> IndividualNetwork<C, O, S> { + ) -> GenericResult<IndividualNetwork<C, O, S>> { let inputs_vec = parallel_into_collect(individuals, |i| init_individual(context, i)); - let inputs_slice = inputs_vec.into_boxed_slice(); - let inputs_array: Box<[S; 4]> = match inputs_slice.try_into() { - Ok(ba) => ba, - Err(o) => panic!("expected individuals of length {} but it was {}", 4, o.len()), - }; - - let storage_factory = - IndividualStorageFactory { node_size: config.node_size, random: environment.random.clone(), objective }; - Network::new( context, - *inputs_array, + inputs_vec, NetworkConfig { + node_size: config.node_size, spread_factor: config.spread_factor, distribution_factor: config.distribution_factor, learning_rate: 0.1, @@ -366,7 +362,15 @@ where has_initial_error: true, }, environment.random.clone(), - storage_factory, + { + let objective = objective.clone(); + let random = environment.random.clone(); + move |node_size| IndividualStorageFactory { + node_size, + random: random.clone(), + objective: objective.clone(), + } + }, ) } } @@ -487,6 +491,10 @@ where relative_distance(a.iter().cloned(), b.iter().cloned()) } + fn resize(&mut self, size: usize) { + self.population.set_max_population_size(size); + } + fn size(&self) -> usize { self.population.size() } diff --git a/rosomaxa/tests/helpers/algorithms/gsom/mod.rs b/rosomaxa/tests/helpers/algorithms/gsom/mod.rs index 393592f3..730664c8 100644 --- a/rosomaxa/tests/helpers/algorithms/gsom/mod.rs +++ b/rosomaxa/tests/helpers/algorithms/gsom/mod.rs @@ -1,11 +1,10 @@ use crate::algorithms::gsom::{Input, Network, NetworkConfig, Storage, StorageFactory}; -use crate::algorithms::math::relative_distance; use crate::utils::{DefaultRandom, Float}; use std::fmt::{Display, Formatter}; use std::ops::RangeBounds; use std::sync::Arc; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Data { pub values: Vec<Float>, } @@ -27,11 +26,16 @@ pub struct DataStorage { pub data: Vec<Data>, } +impl DataStorage { + pub fn cartesian(a: &[Float], b: &[Float]) -> Float { + a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum::<Float>().sqrt() + } +} + impl Storage for DataStorage { type Item = Data; fn add(&mut self, input: Self::Item) { - self.data.clear(); self.data.push(input); } @@ -47,7 +51,11 @@ impl Storage for DataStorage { } fn distance(&self, a: &[Float], b: &[Float]) -> Float { - relative_distance(a.iter().cloned(), b.iter().cloned()) + Self::cartesian(a, b) + } + + fn resize(&mut self, size: usize) { + self.data.truncate(size); } fn size(&self) -> usize { @@ -72,13 +80,14 @@ impl StorageFactory<(), Data, DataStorage> for DataStorageFactory { pub fn create_test_network(has_initial_error: bool) -> Network<(), Data, DataStorage, DataStorageFactory> { Network::new( &(), - [ + vec![ Data::new(0.230529, 0.956665, 0.482008), Data::new(0.400775, 0.142917, 0.555519), Data::new(0.260272, 0.175342, 0.193711), Data::new(0.186712, 0.166380, 0.773621), ], NetworkConfig { + node_size: 2, spread_factor: 0.25, distribution_factor: 0.25, learning_rate: 0.1, @@ -86,6 +95,7 @@ pub fn create_test_network(has_initial_error: bool) -> Network<(), Data, DataSto has_initial_error, }, Arc::new(DefaultRandom::default()), - DataStorageFactory, + |_| DataStorageFactory, ) + .expect("cannot create network") } diff --git
a/rosomaxa/tests/unit/algorithms/gsom/network_test.rs b/rosomaxa/tests/unit/algorithms/gsom/network_test.rs index b007557d..bc2eb7ff 100644 --- a/rosomaxa/tests/unit/algorithms/gsom/network_test.rs +++ b/rosomaxa/tests/unit/algorithms/gsom/network_test.rs @@ -1,290 +1,419 @@ -use crate::algorithms::gsom::{Coordinate, Network}; +use super::*; use crate::helpers::algorithms::gsom::{Data, DataStorage, DataStorageFactory}; -use crate::utils::{Float, Random}; +use crate::helpers::utils::create_test_random; +use crate::utils::Float; +use std::collections::HashSet; type NetworkType = Network<(), Data, DataStorage, DataStorageFactory>; -mod common { - use super::*; - use crate::helpers::algorithms::gsom::create_test_network; - use crate::utils::DefaultRandom; +fn create_config(node_size: usize) -> NetworkConfig { + // NOTE: these numbers are used in the rosomaxa population + NetworkConfig { + node_size, + spread_factor: 0.75, + distribution_factor: 0.75, + rebalance_memory: 100, + learning_rate: 0.1, + has_initial_error: true, + } +} - #[test] - fn can_train_network() { - let mut network = create_test_network(false); - let samples = [Data::new(1.0, 0.0, 0.0), Data::new(0.0, 1.0, 0.0), Data::new(0.0, 0.0, 1.0)]; +fn count_data_stored(nodes: &NodeHashMap<Data, DataStorage>) -> usize { + nodes.values().map(|node| node.storage.size()).sum::<usize>() +} - // train - let random = DefaultRandom::default(); - for j in 1..4 { - network.smooth(&(), 4, |_| ()); +fn count_non_empty_nodes(nodes: &NodeHashMap<Data, DataStorage>) -> usize { + nodes.values().filter(|node| node.storage.iter().next().is_some()).count() +} - for i in 1..500 { - let idx = random.uniform_int(0, samples.len() as i32 - 1) as usize; - network.store(&(), samples[idx].clone(), j * i + i); +fn distance(i: usize, j: usize, data: &[Data]) -> Float { + DataStorage::cartesian(data[i].weights(), data[j].weights()) +} + +fn create_3d_data_grid(size: usize, step: Float) -> Vec<Data> { + let mut data = Vec::new(); + for x in 0..size { + for y in 0..size { + for z in 0..size { + data.push(Data::new(x as Float * step, y as Float * step, z as Float * step)); } } - - assert!(!network.nodes.len() >= 4); - samples.iter().for_each(|sample| { - let node = network.find_bmu(sample); - - assert_eq!(node.storage.data.first().unwrap().values, sample.values); - assert_eq!(node.weights.iter().map(|v| v.round()).collect::<Vec<Float>>(), sample.values); - }); } + data +} - parameterized_test! {can_use_initial_error_parameter, (has_initial_error, size), { - can_use_initial_error_parameter_impl(has_initial_error, size); - }} - - can_use_initial_error_parameter!
{ - case01: (false, 4), - case02: (true, 6), - } - - fn can_use_initial_error_parameter_impl(has_initial_error: bool, size: usize) { - let mut network = create_test_network(has_initial_error); - - network.train(&(), Data::new(1.0, 0.0, 0.0), true); - - assert_eq!(network.size(), size); - } +fn create_random_3d_data(size: usize, range: Float) -> Vec<Data> { + let random = create_test_random(); + (0..size) + .map(|_| { + Data::new( + random.uniform_real(-range, range), + random.uniform_real(-range, range), + random.uniform_real(-range, range), + ) + }) + .collect() +} - fn get_coord_data(coord: (i32, i32), offset: (i32, i32), network: &NetworkType) -> (Coordinate, Vec<Float>) { - let node = network.nodes.get(&Coordinate(coord.0 + offset.0, coord.1 + offset.1)).unwrap(); - let coordinate = node.coordinate; - let weights = node.weights.clone(); +fn create_spiral_data(points: usize, revolutions: Float) -> Vec<Data> { + (0..points) + .map(|i| { + let t = i as Float / points as Float * revolutions * 2.0 * std::f64::consts::PI; + let r = t / (2.0 * std::f64::consts::PI); + Data::new(r * t.cos(), r * t.sin(), t / (2.0 * std::f64::consts::PI)) + }) + .collect() +} - (coordinate, weights) +#[test] +fn can_create_network() { + // Setup test data + let initial_data = vec![ + Data::new(1., 2., 0.), + Data::new(2., 3., 0.), + Data::new(3., 4., 0.), + Data::new(4., 5., 0.), + Data::new(5., 6., 0.), + ]; + let config = create_config(10); + let random = create_test_random(); + + let network = + NetworkType::new(&(), initial_data.clone(), config.clone(), random.clone(), |_| DataStorageFactory).unwrap(); + + // Verify network properties + assert_eq!(network.dimension, 3); + assert!((network.growing_threshold - -3. * 0.75_f64.log2()).abs() < 1e-6); + assert_eq!(network.distribution_factor, config.distribution_factor); + assert_eq!(network.learning_rate, config.learning_rate); + assert_eq!(network.rebalance_memory, config.rebalance_memory); + + // Verify initial nodes setup + assert!(network.size() >= 4); // Should have at least 4 initial nodes + assert!(network.size() <= 16); // Should not exceed 16 initial nodes + assert_eq!(count_data_stored(&network.nodes), initial_data.len()); + + // Check node properties + for node in network.get_nodes() { + assert_eq!(node.weights.len(), 3); + assert!(node.error >= 0.); } +} - fn add_node(x: i32, y: i32, network: &mut NetworkType) { - network.insert(&(), Coordinate(x, y), &[x as Float, y as Float]); +#[test] +fn can_create_initial_nodes() { + let context = (); + let data = vec![ + Data::new(1., 1., 0.), // + Data::new(2., 2., 0.), + Data::new(3., 3., 0.), // + Data::new(4., 4., 0.), + ]; + let rebalance_memory = 5; + let storage_factory = DataStorageFactory; + let random = create_test_random(); + let noise = Noise::new_with_ratio(1.0, (1., 1.), random); + + let (nodes, min_max_weights) = + NetworkType::create_initial_nodes(&context, data.clone(), rebalance_memory, &storage_factory, noise).unwrap(); + + // Verify nodes + assert!(nodes.len() >= 4); + assert!(nodes.len() <= 16); + + // Check min-max weights + assert_eq!(min_max_weights.0.len(), 3); + assert_eq!(min_max_weights.1.len(), 3); + assert!(min_max_weights.0[0] <= 1.0); // Min values + assert!(min_max_weights.0[1] <= 1.0); + assert!(min_max_weights.1[0] >= 4.0); // Max values + assert!(min_max_weights.1[1] >= 4.0); + + // Verify node properties + for node in nodes.values() { + assert_eq!(node.weights.len(), 3, "weight dimension"); + assert!(node.storage.size() <= rebalance_memory, "storage size"); + + // Check coordinate bounds based
on grid size + let grid_size = (nodes.len() as f64).sqrt().ceil() as i32; + assert!(node.coordinate.0 >= 0 && node.coordinate.0 < grid_size); + assert!(node.coordinate.1 >= 0 && node.coordinate.1 < grid_size); } - fn update_zero_neighborhood(network: &mut NetworkType) { - add_node(-1, 1, network); - add_node(-1, 0, network); - add_node(-1, -1, network); - add_node(0, -1, network); - add_node(1, -1, network); - } + // Verify all data points are assigned to nodes + assert_eq!(count_data_stored(&nodes), data.len()); +} - #[test] - fn can_insert_initial_node_neighborhood() { - let network = create_test_network(false); - assert_eq!(network.nodes.len(), 4); +#[test] +fn can_select_initial_samples() { + let sample_size = 4; + let data = vec![ + Data::new(0.0, 0.0, 0.), + Data::new(1.0, 0.0, 0.), + Data::new(0.0, 1.0, 0.), + Data::new(1.0, 1.0, 0.), + Data::new(0.5, 0.5, 0.), + Data::new(0.2, 0.8, 0.), + Data::new(0.8, 0.2, 0.), + Data::new(0.3, 0.7, 0.), + Data::new(0.7, 0.3, 0.), + Data::new(0.4, 0.6, 0.), + ]; + let random = create_test_random(); + + let selected = + NetworkType::select_initial_samples(&data, sample_size, &DataStorage::default(), random.as_ref()).unwrap(); + + // Verify sample size constraints + assert_eq!(selected.len(), ((data.len() as f64 * 0.1).ceil() as usize).clamp(4, 16)); + // Verify uniqueness + let unique_indices: HashSet<_> = selected.iter().collect(); + assert_eq!(unique_indices.len(), selected.len()); + // Verify all indices are valid + assert!(selected.iter().all(|&idx| idx < data.len())); + + // Verify distance maximization + let min_distances = (0..selected.len()) + .flat_map(|i| { + (i + 1..selected.len()).map({ + let data = &data; + let selected = &selected; + move |j| distance(selected[i], selected[j], data) + }) + }) + .collect::<Vec<_>>(); + // Check that selected points maintain some minimum distance from each other + assert!(min_distances.iter().all(|&dist| dist >= 0.5)); +} - assert_eq!(get_coord_data((0, 0), (1, 0), &network).0, Coordinate(1, 0)); - assert_eq!(get_coord_data((0, 0), (0, 1), &network).0, Coordinate(0, 1)); +#[test] +fn can_create_network_large_grid() { + let random = create_test_random(); + for _ in 1..10 { + let initial_data = create_3d_data_grid(3, 0.5); // 27 points + let config = create_config(2); + let network = NetworkType::new(&(), initial_data.clone(), config, random.clone(), |_| DataStorageFactory) + .expect("Network creation failed"); + + let non_empty_nodes = count_non_empty_nodes(&network.nodes); + assert_eq!(network.dimension, 3); + assert!(network.size() >= 4, "too small {}", network.size()); + assert!(network.size() <= 100, "too big {}", network.size()); + assert!(non_empty_nodes > 4, "empty nodes: {} from {}", network.size() - non_empty_nodes, network.size()); + } +} - assert_eq!(get_coord_data((1, 0), (-1, 0), &network).0, Coordinate(0, 0)); - assert_eq!(get_coord_data((1, 0), (0, 1), &network).0, Coordinate(1, 1)); +#[test] +fn can_create_network_random_data() { + let initial_data = create_random_3d_data(50, 10.); + let config = create_config(50); + let random = create_test_random(); + + let network = NetworkType::new(&(), initial_data.clone(), config, random, |_| DataStorageFactory) + .expect("Network creation failed"); + + // Verify node coverage + assert_eq!(count_data_stored(&network.nodes), initial_data.len()); + // Check network metrics + assert!(network.mean_distance() > 0.); + assert!(network.mse() > 0.); + assert!(network.max_unified_distance() > 0.); +} - assert_eq!(get_coord_data((1, 1), (-1, 0), &network).0,
Coordinate(0, 1)); - assert_eq!(get_coord_data((1, 1), (0, -1), &network).0, Coordinate(1, 0)); +#[test] +fn can_create_network_with_spiral_distribution() { + let size = 100; + let initial_data = create_spiral_data(size, 3.0); + let config = create_config(size); + let random = create_test_random(); + let network = NetworkType::new(&(), initial_data.clone(), config, random, |_| DataStorageFactory).unwrap(); + + let non_empty_nodes = count_non_empty_nodes(&network.nodes); + assert!(non_empty_nodes >= (size / 3), "too sparse {}", non_empty_nodes); + assert!(network.size() <= (size * 2), "too big {}", network.size()); + let distances: Vec<_> = network + .get_nodes() + .flat_map(|node| node.storage.data.iter().map(|data| DataStorage::cartesian(&node.weights, data.weights()))) + .collect(); + let avg_distance = distances.iter().sum::<Float>() / distances.len() as Float; + assert!(avg_distance < 0.5, "too big average: {}", avg_distance); +} - assert_eq!(get_coord_data((0, 1), (0, -1), &network).0, Coordinate(0, 0)); - assert_eq!(get_coord_data((0, 1), (1, 0), &network).0, Coordinate(1, 1)); +#[test] +fn can_create_initial_nodes_uniform_distribution() { + let size = 4; + let data = create_3d_data_grid(size, 1.0); // 64 points + let noise = Noise::new_with_ratio(0., (1., 1.), create_test_random()); + + let (nodes, min_max) = NetworkType::create_initial_nodes(&(), data.clone(), 5, &DataStorageFactory, noise) + .expect("Failed to create initial nodes"); + + // Verify min-max bounds + assert_eq!(min_max.0.len(), 3); + assert_eq!(min_max.1.len(), 3); + assert!(min_max.0.iter().all(|&x| x >= 0.)); + assert!(min_max.1.iter().all(|&x| x <= (size - 1) as f64)); + + // Check node distribution + let mut coord_set = HashSet::new(); + for node in nodes.values() { + coord_set.insert(node.coordinate); + assert_eq!(node.weights.len(), 3); + assert!(!node.storage.data.is_empty()); } - #[test] - fn can_create_and_update_extended_neighbourhood() { - let mut network = create_test_network(false); - update_zero_neighborhood(&mut network); - network - .nodes - .get(&Coordinate(0, 0)) - .unwrap() - .neighbours(&network, 1) - .filter_map(|(coord, _)| coord) - .collect::<Vec<_>>() - .into_iter() - .for_each(|coord| { - let node = network.nodes.get_mut(&coord).unwrap(); - node.error = 42.; - }); - - // -1+1 0+1 +1+1 - // -1+0 0 0 +1 0 - // -1-1 0-1 +1-1 - assert_eq!(network.nodes.len(), 9); - network.nodes.iter().filter(|(coord, _)| **coord != Coordinate(0, 0)).for_each(|(coord, node)| { - if node.error != 42.
{ - unreachable!("node is not updated: ({},{}), value: {}", coord.0, coord.1, node.error); - } - }); - [ - (1, (0, 0), 8), - (1, (0, -1), 5), - (1, (0, 1), 5), - (1, (1, 0), 5), - (1, (-1, 0), 5), - (1, (-1, 1), 3), - (1, (1, 1), 3), - (1, (-1, -1), 3), - (1, (1, -1), 3), - ] - .into_iter() - .for_each(|(radius, (x, y), expected_count)| { - let count = network - .nodes - .get(&Coordinate(x, y)) - .unwrap() - .neighbours(&network, radius) - .filter(|(node, _)| node.is_some()) - .count(); - if count != expected_count { - unreachable!("unexpected neighbourhood for: ({},{}), {} vs {}", x, y, count, expected_count) - } - }); - } + // Verify grid arrangement + let grid_size = (nodes.len() as f64).sqrt().ceil() as i32; + assert!(coord_set.iter().all(|c| c.0 < grid_size && c.1 < grid_size)); } -mod node_growing { - use super::*; - use crate::algorithms::gsom::{NetworkConfig, Node}; - use crate::prelude::RandomGen; - use std::sync::Arc; - - fn create_trivial_network(has_initial_error: bool) -> NetworkType { - struct DummyRandom {} - impl Random for DummyRandom { - fn uniform_int(&self, _: i32, _: i32) -> i32 { - unreachable!() - } - - fn uniform_real(&self, _: Float, _: Float) -> Float { - unreachable!() - } +#[test] +fn can_create_initial_nodes_with_outliers() { + let mut data = vec![ + Data::new(0.0, 0.0, 0.0), + Data::new(0.1, 0.1, 0.1), + Data::new(0.2, 0.2, 0.2), + Data::new(10., 10., 10.), + Data::new(-10., -10., -10.), + ]; + data.extend((0..20).map(|i| { + let x = i as Float * 0.1; + Data::new(x, x * x, x * x * x) + })); + + let noise = Noise::new_with_ratio(0., (1., 1.), create_test_random()); + + let (nodes, min_max) = + NetworkType::create_initial_nodes(&(), data.clone(), 10, &DataStorageFactory, noise).unwrap(); + + assert!(min_max.0.iter().zip(min_max.1.iter()).all(|(&min, &max)| min < max)); + assert!(nodes.values().all(|node| !node.storage.data.is_empty())); + + let find_fn = |threshold| { + nodes.values().flat_map(|node| node.storage.data.iter()).any(|data| data.values.iter().all(|&w| w == threshold)) + }; + assert!(find_fn(10.), "cannot handle max outlier"); + assert!(find_fn(-10.), "cannot handle min outlier"); +} - fn is_head_not_tails(&self) -> bool { - unreachable!() - } +parameterized_test! {can_select_initial_samples_edge_cases, (data, sampling), { + can_select_initial_samples_edge_cases_impl(data, sampling) +}} + +can_select_initial_samples_edge_cases! 
{ + case01_single_cluster: (vec![ + Data::new(0.0, 0.0, 0.0), + Data::new(0.1, 0.1, 0.1), + Data::new(0.2, 0.2, 0.2), + Data::new(0.15, 0.15, 0.15), + Data::new(0.05, 0.05, 0.05), + Data::new(0.25, 0.25, 0.25), + Data::new(0.12, 0.12, 0.12), + Data::new(0.18, 0.18, 0.18), + ], (4, 0.05)), + case02_two_distinct_clusters: (vec![ + Data::new(0.0, 0.0, 0.0), + Data::new(0.1, 0.1, 0.1), + Data::new(0.15, 0.15, 0.15), + Data::new(0.2, 0.2, 0.2), + Data::new(10.0, 10.0, 10.0), + Data::new(10.1, 10.1, 10.1), + Data::new(10.15, 10.15, 10.15), + Data::new(10.2, 10.2, 10.2), + ], (2, 17.)), + case03_points_on_axes_and_origin: (vec![ + Data::new(1.0, 0.0, 0.0), + Data::new(0.0, 1.0, 0.0), + Data::new(0.0, 0.0, 1.0), + Data::new(0.0, 0.0, 0.0), + ], (4, 1.)), +} - fn is_hit(&self, _: Float) -> bool { - false - } +fn can_select_initial_samples_edge_cases_impl(data: Vec<Data>, sampling: (usize, f64)) { + let (sampling_size, expected_min_distance) = sampling; + + let random = create_test_random(); + let selected = NetworkType::select_initial_samples(&data, sampling_size, &DataStorage::default(), random.as_ref()) + .expect("Failed to select samples"); + + assert_eq!(selected.len(), sampling_size); + assert_eq!(HashSet::<_>::from_iter(selected.iter().copied()).len(), selected.len()); + + // Verify minimum distance between selected samples + let min_distance = selected + .iter() + .enumerate() + .flat_map(|(i, &idx1)| { + selected.iter().skip(i + 1).map({ + let data = &data; + move |&idx2| distance(idx1, idx2, data) + }) + }) + .min_by(|a, b| a.total_cmp(b)) + .unwrap_or(f64::MAX); + + assert!( + min_distance >= expected_min_distance, + "Minimum distance {} is less than expected threshold {} for test case", + min_distance, + expected_min_distance + ); +} - fn weighted(&self, _: &[usize]) -> usize { - unreachable!() - } #[test] +fn can_select_initial_samples_with_duplicates() { + let sample_size = 4; + let data = + vec![Data::new(1.0, 1.0, 1.0), Data::new(1.0, 1.0, 1.0), Data::new(2.0, 2.0, 2.0), Data::new(2.0, 2.0, 2.0)]; + + let random = create_test_random(); + let selected = + NetworkType::select_initial_samples(&data, sample_size, &DataStorage::default(), random.as_ref()).unwrap(); + + let unique_points: HashSet<_> = selected + .iter() + .map(|&idx| { + let weights = data[idx].weights(); + // NOTE: scale up and round to collect into a hash set, as Float doesn't implement Eq + ( + (weights[0] * 1000.).round() as i64, + (weights[1] * 1000.).round() as i64, + (weights[2] * 1000.).round() as i64, + ) + }) + .collect(); + + assert!(unique_points.len() >= 2); + assert!(selected.len() >= 4); +} - fn get_rng(&self) -> RandomGen { - RandomGen::new_repeatable() +#[test] +fn can_create_new_network_empty_regions() { + let mut initial_data = Vec::new(); + // Create clusters with empty regions between them + for cluster in &[(-5.0, -5.0), (5.0, 5.0), (-5.0, 5.0), (5.0, -5.0)] { + for dx in -1..=1 { + for dy in -1..=1 { + let x = cluster.0 + dx as Float * 0.1; + let y = cluster.1 + dy as Float * 0.1; + initial_data.push(Data::new(x, y, (x * x + y * y).sqrt())); } } - Network::new( - &(), - [ - Data::new(1., 4., 8.), // n00 - Data::new(2., 5., 9.), // n01 - Data::new(3., 8., 7.), // n11 - Data::new(9., 3., 2.), // n10 - ], - NetworkConfig { - spread_factor: 0.25, - distribution_factor: 0.25, - learning_rate: 0.1, - rebalance_memory: 500, - has_initial_error, - }, - Arc::new(DummyRandom {}), - DataStorageFactory, - ) - } - - fn get_node(coord: (i32, i32), network: &NetworkType) -> Option<&Node<Data, DataStorage>> { - network.nodes.get(&Coordinate(coord.0,
coord.1)) - } - - fn round_weights(weights: &[Float]) -> Vec<Float> { - weights.iter().map(|w| (w * 1000.).round() / 1000.).collect() - } - - parameterized_test! {can_grow_initial_nodes_properly, (target_coord, expected_new_nodes), { - can_grow_initial_nodes_properly_impl(target_coord, expected_new_nodes); - }} - - can_grow_initial_nodes_properly! { - case01: ((0, 0), vec![((-1, 0), vec![-6.623, 4.874, 13.497]), ((0, -1), vec![0.073, 2.963, 6.817])]), - case02: ((0, 1), vec![((-1, 0), vec![1.042, 2.0, 10.623]), ((0, 1), vec![2.963, 5.853, 9.707])]), - case03: ((1, 0), vec![((1, 0), vec![16.45, 2.0, -3.78]), ((0, -1), vec![14.455, -1.832, -2.791])]), - case04: ((1, 1), vec![((1, 0), vec![3.927, 10.67, 4.89]), ((0, 1), vec![-2.791, 12.539, 11.581])]), - } - - fn can_grow_initial_nodes_properly_impl( - target_coord: (i32, i32), - expected_new_nodes: Vec<((i32, i32), Vec<Float>)>, - ) { - let mut network = create_trivial_network(true); - - network.update(&(), &Coordinate(target_coord.0, target_coord.1), &Data::new(2., 2., 2.), 2., true); - - assert_eq!(network.nodes.len(), 6); - expected_new_nodes.into_iter().for_each(|((offset_x, offset_y), weights)| { - let node = get_node((target_coord.0 + offset_x, target_coord.1 + offset_y), &network).unwrap(); - assert_eq!(node.error, 0.); - assert_eq!(round_weights(node.weights.as_slice()), weights); - }); } - - #[test] - fn can_grow_new_nodes_properly() { - let w1_coord = Coordinate(1, 2); - let mut network = create_trivial_network(true); - network.insert(&(), w1_coord, &[3., 6., 10.]); - - network.update(&(), &Coordinate(w1_coord.0, w1_coord.1), &Data::new(2., 2., 2.), 6., true); - - [ - ((2, 2), vec![2.948, 3.895, 12.423]), - ((0, 2), vec![2.917, 3.833, 12.083]), - ((1, 3), vec![2.929, 3.858, 12.222]), - ] - .into_iter() - .for_each(|(coord, weights)| { - let node = get_node(coord, &network).unwrap(); - let actual = round_weights(node.weights.as_slice()); - assert_eq!(actual, weights); + } + let config = create_config(12); + let random = create_test_random(); + + let network = NetworkType::new(&(), initial_data, config, random, |_| DataStorageFactory).unwrap(); + + let nodes_vec: Vec<_> = network.get_nodes().collect(); + (0..nodes_vec.len()) + .flat_map(|i| { + (i + 1..nodes_vec.len()).map({ + let nodes_vec = &nodes_vec; + move |j| DataStorage::cartesian(&nodes_vec[i].weights, &nodes_vec[j].weights) + }) + }) + .for_each(|dist| { + assert!(dist > 1., "distance between nodes is too small: {}", dist); }); - } - - #[test] - fn can_calculate_mse() { - let mut network = create_trivial_network(false); - let mse = network.mse(); - assert_eq!(mse, 0.); - - network.smooth(&(), 1, |_| ()); - let mse = network.mse(); - assert!((mse - 0.0001138).abs() < 1E7); - } - - parameterized_test! {can_grow_nodes_with_proper_weights, (coord, expected), { - can_grow_nodes_with_proper_weights_impl(coord, expected); - }} - - can_grow_nodes_with_proper_weights!
{ - case01_a_case_left_down: ((0, 0), vec![(Coordinate(-1, 0), vec![-7., 5., 14.]), (Coordinate(0, -1), vec![0., 3., 7.])]), - case02_a_case_right_top: ((1, 1), vec![(Coordinate(1, 2), vec![-3., 13., 12.]), (Coordinate(2, 1), vec![4., 11., 5.])]), - case03_ac_cases: ((1, -1), vec![(Coordinate(0, -1), vec![1., -1., 4.]), (Coordinate(1, -2), vec![1., -1., 4.]), (Coordinate(2, -1), vec![1., -1., 4.])]), - case04_ba_cases_left: ((0, 1), vec![(Coordinate(-1, 1), vec![1.5, 4., 11.5]), (Coordinate(0, 2), vec![3., 6., 10.])]), - case05_bd_cases: ((-2, 1), vec![(Coordinate(-3, 1), vec![5., 4.5, 8.]), (Coordinate(-2, 0), vec![5., 4.5, 8.]), (Coordinate(-2, 2), vec![5., 4.5, 8.]), (Coordinate(-1, 1), vec![1.5, 4., 11.5])]), - } - - fn can_grow_nodes_with_proper_weights_impl(coord: (i32, i32), expected: Vec<(Coordinate, Vec<Float>)>) { - // n(-2,1)(1., 3., 14.) xx n01(2., 5., 9.) n11(3., 8., 7.) - // n00(1., 4., 8.) n10(9., 3., 2.) - // n1-1(5., 1., 3.) - let mut network = create_trivial_network(false); - network.insert(&(), Coordinate(-2, 1), &[1., 3., 14.]); - network.insert(&(), Coordinate(1, -1), &[5., 1., 3.]); - - let mut nodes = network.grow_nodes(&Coordinate(coord.0, coord.1)); - nodes.sort_by(|a, b| a.0.cmp(&b.0)); - - assert_eq!(nodes, expected); - } } diff --git a/rosomaxa/tests/unit/population/rosomaxa_test.rs b/rosomaxa/tests/unit/population/rosomaxa_test.rs index 04c103b2..2b00ebc4 100644 --- a/rosomaxa/tests/unit/population/rosomaxa_test.rs +++ b/rosomaxa/tests/unit/population/rosomaxa_test.rs @@ -2,104 +2,205 @@ use super::*; use crate::example::*; use crate::helpers::example::create_example_objective; -fn create_rosomaxa( - rebalance_memory: usize, -) -> (Arc<VectorObjective>, Rosomaxa<VectorRosomaxaContext, VectorObjective, VectorSolution>) { - let mut config = RosomaxaConfig::new_with_defaults(4); - config.rebalance_memory = rebalance_memory; +type RosomaxaType = Rosomaxa<VectorRosomaxaContext, VectorObjective, VectorSolution>; - let objective = create_example_objective(); - let population = - Rosomaxa::new(VectorRosomaxaContext, objective.clone(), Arc::new(Environment::default()), config).unwrap(); +mod selection { + use super::*; - (objective, population) -} + fn create_rosomaxa(initial_size: usize) -> RosomaxaType { + let env = Arc::new(Environment::default()); + let config = RosomaxaConfig { initial_size, ..RosomaxaConfig::new_with_defaults(4) }; + let objective = create_example_objective(); -fn create_statistics(termination_estimate: Float, generation: usize) -> HeuristicStatistics { - HeuristicStatistics { - termination_estimate, - generation, - improvement_1000_ratio: 0.5, - ..HeuristicStatistics::default() + Rosomaxa::new(VectorRosomaxaContext, objective, env, config).unwrap() } -} -fn get_network( - rosomaxa: &Rosomaxa<VectorRosomaxaContext, VectorObjective, VectorSolution>, -) -> &IndividualNetwork<VectorRosomaxaContext, VectorObjective, VectorSolution> { - match &rosomaxa.phase { - RosomaxaPhases::Exploration { network, ..
} => network, - _ => unreachable!(), + #[test] + fn can_handle_initial_population() { + let initial_size = 4; + let selection_size = 4; + let elite_size = 2; + let mut rosomaxa = create_rosomaxa(initial_size); + + // Add initial solutions + for i in 0..initial_size { + let solution = VectorSolution { data: vec![i as Float], weights: vec![i as Float], fitness: -(i as Float) }; + rosomaxa.add(solution); + } + + assert_eq!(rosomaxa.selection_phase(), SelectionPhase::Initial); + assert_eq!(rosomaxa.select().count(), selection_size); + assert_eq!(rosomaxa.size(), elite_size); } -} -#[test] -fn can_switch_phases() { - let (objective, mut rosomaxa) = create_rosomaxa(10); + #[test] + fn can_handle_exploration_phase() { + let initial_size = 4; + let selection_size = 4; + let mut rosomaxa = create_rosomaxa(initial_size); - (0..4).for_each(|_| { - assert_eq!(rosomaxa.selection_phase(), SelectionPhase::Initial); - rosomaxa.add_all(vec![VectorSolution::new_with_objective(vec![-1., -1.], objective.as_ref())]); - rosomaxa.update_phase(&create_statistics(0., 0)) - }); - - rosomaxa.add(VectorSolution::new_with_objective(vec![-1., -1.], objective.as_ref())); - assert_eq!(rosomaxa.selection_phase(), SelectionPhase::Exploration); - - for (idx, (termination_estimate, phase)) in - ([(0.7, SelectionPhase::Exploration), (0.9, SelectionPhase::Exploitation)]).iter().enumerate() - { - rosomaxa.update_phase(&create_statistics(*termination_estimate, idx)); - assert_eq!(rosomaxa.selection_phase(), *phase); + // Add solutions to trigger exploration phase + for i in 0..=initial_size { + let solution = VectorSolution { data: vec![i as Float], weights: vec![i as Float], fitness: -(i as Float) }; + rosomaxa.add(solution); + } + + // Force exploration phase + rosomaxa.on_generation(&HeuristicStatistics { termination_estimate: 0.5, ..HeuristicStatistics::default() }); + assert_eq!(rosomaxa.selection_phase(), SelectionPhase::Exploration); + assert_eq!(rosomaxa.select().count(), selection_size); } -} -#[test] -fn can_select_individuals_in_different_phases() { - let (objective, mut rosomaxa) = create_rosomaxa(10); - (0..10).for_each(|idx| { - rosomaxa.add_all(vec![VectorSolution::new_with_objective(vec![-1., -1.], objective.as_ref())]); - rosomaxa.update_phase(&create_statistics(0.75, idx)) - }); - - let individuals = rosomaxa.select(); - assert_eq!(individuals.count(), 4); - assert_eq!(rosomaxa.selection_phase(), SelectionPhase::Exploration); - - rosomaxa.update_phase(&create_statistics(0.95, 10)); - let individuals = rosomaxa.select(); - assert_eq!(individuals.count(), 4); - assert_eq!(rosomaxa.selection_phase(), SelectionPhase::Exploitation); -} + #[test] + fn can_handle_exploitation_phase() { + let initial_size = 4; + let selection_size = 4; + let mut rosomaxa = create_rosomaxa(initial_size); -#[test] -fn can_optimize_network() { - let termination_estimate = 0.75; - let (objective, mut rosomaxa) = create_rosomaxa(2); - (0..10).for_each(|idx| { - let value = idx as Float - 5.; - rosomaxa.add_all(vec![VectorSolution::new_with_objective(vec![value, value], objective.as_ref())]); - rosomaxa.update_phase(&create_statistics(termination_estimate, idx)) - }); + // Add initial solutions + for i in 0..initial_size { + let solution = VectorSolution { data: vec![i as Float], weights: vec![i as Float], fitness: -(i as Float) }; + rosomaxa.add(solution); + } - rosomaxa.add(VectorSolution::new_with_objective(vec![0.5, 0.5], objective.as_ref())); - rosomaxa.update_phase(&create_statistics(termination_estimate, 10)); + // Force exploitation 
phase + rosomaxa.on_generation(&HeuristicStatistics { termination_estimate: 0.95, ..HeuristicStatistics::default() }); - assert!(get_network(&rosomaxa).get_nodes().next().is_some()); -} + assert_eq!(rosomaxa.selection_phase(), SelectionPhase::Exploitation); + assert_eq!(rosomaxa.select().count(), selection_size); + } -#[test] -fn can_handle_empty_population() { - let (_, mut rosomaxa) = create_rosomaxa(10); + #[test] + fn can_handle_all_phases() { + let initial_size = 4; + let selection_size = 4; + let mut rosomaxa = create_rosomaxa(initial_size); + + // initial phase + for i in 0..(initial_size - 1) { + let solution = VectorSolution { data: vec![i as Float], weights: vec![i as Float], fitness: -(i as Float) }; + rosomaxa.add(solution); + rosomaxa.on_generation(&HeuristicStatistics { termination_estimate: 0., ..HeuristicStatistics::default() }); + assert_eq!(rosomaxa.selection_phase(), SelectionPhase::Initial); + assert_eq!(rosomaxa.select().count(), selection_size.min(i + 1)); + } + + // exploration phase + rosomaxa.add(VectorSolution { + data: vec![initial_size as Float], + weights: vec![initial_size as Float], + fitness: 0., + }); + rosomaxa.on_generation(&HeuristicStatistics { termination_estimate: 0.5, ..HeuristicStatistics::default() }); + assert_eq!(rosomaxa.selection_phase(), SelectionPhase::Exploration); + assert_eq!(rosomaxa.select().count(), selection_size); + + // stays once in exploration and switches to exploitation + for (termination_estimate, phase) in + [(0.7, SelectionPhase::Exploration), (0.9, SelectionPhase::Exploitation)].into_iter() + { + rosomaxa.on_generation(&HeuristicStatistics { termination_estimate, ..HeuristicStatistics::default() }); + assert_eq!(rosomaxa.selection_phase(), phase); + assert_eq!(rosomaxa.select().count(), selection_size); + } + } - for (phase, estimate) in - [(SelectionPhase::Initial, None), (SelectionPhase::Initial, Some(0.7)), (SelectionPhase::Initial, Some(0.95))] - { - if let Some(estimate) = estimate { - rosomaxa.update_phase(&create_statistics(estimate, 10)); + #[test] + fn can_handle_empty_population() { + let initial_size = 4; + let mut rosomaxa = create_rosomaxa(initial_size); + + // here we stay in the initial phase for a long time and then go directly to exploitation, + // as we lack solutions for exploration + for (phase, termination_estimate) in [ + (SelectionPhase::Initial, None), + (SelectionPhase::Initial, Some(0.7)), + (SelectionPhase::Exploitation, Some(0.95)), + ] { + if let Some(termination_estimate) = termination_estimate { + rosomaxa.on_generation(&HeuristicStatistics { termination_estimate, ..HeuristicStatistics::default() }); + } + + assert!(rosomaxa.select().next().is_none()); + assert_eq!(rosomaxa.selection_phase(), phase) + } + } + + #[test] + fn can_handle_solution_deduplication() { + let initial_size = 4; + let mut rosomaxa = create_rosomaxa(initial_size); + + // Add duplicate solutions + let solution = VectorSolution { data: vec![1.0], weights: vec![], fitness: -1.0 }; + + rosomaxa.add(solution.clone()); + rosomaxa.add(solution); + + assert_eq!(rosomaxa.size(), 1); + } +} + +mod auxiliary { + use super::*; + + #[test] + fn can_create_dedup_fn() { + let objective = create_example_objective(); + let dedup_fn = create_dedup_fn::<VectorObjective, VectorSolution>(0.1); + + // Test equal fitness + let solution1 = VectorSolution { data: vec![1.0], weights: vec![1.0], fitness: -1.0 }; + let solution2 = VectorSolution { data: vec![1.0], weights: vec![1.0], fitness: -1.0 }; + assert!(dedup_fn(objective.as_ref(), &solution1, &solution2)); + + // Test similar
weights but different fitness + let solution3 = VectorSolution { data: vec![1.05], weights: vec![1.05], fitness: -1.5 }; + assert!(dedup_fn(objective.as_ref(), &solution1, &solution3)); + + // Test different weights + let solution4 = VectorSolution { data: vec![2.0], weights: vec![2.0], fitness: -2.0 }; + assert!(!dedup_fn(objective.as_ref(), &solution1, &solution4)); + } + + #[test] + fn can_get_keep_size() { + let rebalance_memory = 100; + + // early phase + let size_early = get_keep_size(rebalance_memory, 0.0); + assert!(size_early > rebalance_memory * 2); + + // mid phase + let size_mid = get_keep_size(rebalance_memory, 0.5); + assert!(size_mid > rebalance_memory); + assert!(size_mid < size_early); + + // late phase + let size_late = get_keep_size(rebalance_memory, 0.8); + assert!(size_late >= rebalance_memory); + assert!(size_late < size_mid); + } - assert!(rosomaxa.select().next().is_none()); - assert_eq!(rosomaxa.selection_phase(), phase) + #[test] + fn can_get_learning_rate() { + // test learning rate boundaries + assert!(get_learning_rate(0.0) >= 0.1); + assert!(get_learning_rate(1.0) >= 0.1); + + // test cosine annealing pattern + let rate1 = get_learning_rate(0.0); + let rate2 = get_learning_rate(0.125); + let rate3 = get_learning_rate(0.25); + + // rate should decrease initially + assert!(rate1 > rate2); + // rate should increase towards the end of period + assert!(rate2 < rate3); + + // test period cycling + let rate_period1 = get_learning_rate(0.1); + let rate_period2 = get_learning_rate(0.35); + assert!((rate_period1 - rate_period2).abs() < 0.01); } }
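Usage note: the `Network::new` entry point now takes a flat, non-empty `Vec` of initial samples plus a storage-factory closure, instead of exactly four roots, and returns a `GenericResult` so creation can fail gracefully. Below is a minimal sketch of the new call shape, mirroring the `create_test_network` helper changed in this diff; the `Data`/`DataStorage`/`DataStorageFactory` types come from the test helpers, and the wrapper function name and elided imports are illustrative assumptions, not part of the patch.

fn build_test_network(initial_data: Vec<Data>) -> GenericResult<Network<(), Data, DataStorage, DataStorageFactory>> {
    Network::new(
        &(), // trivial context, as in the tests
        initial_data, // any non-empty sample set; the old API required exactly four roots
        NetworkConfig {
            node_size: 2,
            spread_factor: 0.25,
            distribution_factor: 0.25,
            learning_rate: 0.1,
            rebalance_memory: 500,
            has_initial_error: true,
        },
        Arc::new(DefaultRandom::default()),
        // the factory is now a closure over the node size, so `new` can build oversized
        // storages for the initial rebalance and shrink them to `node_size` afterwards
        |_node_size| DataStorageFactory,
    )
}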