Adapt GSOM network initial state creation
The old implementation starts with 4 nodes whose weights are set to
original data. With a small per-node storage size, this leads to a high
ratio of data loss before the network is trained.

To address this issue, the rosomaxa/network algorithm is modified to
retain more of the initial data and to pretrain the network on it.
reinterpretcat committed Jan 13, 2025
1 parent 784194b commit ccd6de6
Showing 8 changed files with 792 additions and 442 deletions.
3 changes: 3 additions & 0 deletions rosomaxa/src/algorithms/gsom/mod.rs
@@ -41,6 +41,9 @@ pub trait Storage: Display + Send + Sync {
/// Returns a distance between two input weights.
fn distance(&self, a: &[Float], b: &[Float]) -> Float;

/// Shrinks the storage to the specified size.
fn resize(&mut self, size: usize);

/// Returns size of the storage.
fn size(&self) -> usize;
}
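
For illustration, here is a minimal sketch of how a Vec-backed storage might satisfy the new resize contract. The VecStorage type and its field are hypothetical (not part of this commit), and f64 stands in for the crate's Float alias:

struct VecStorage {
    data: Vec<Vec<f64>>,
}

impl VecStorage {
    /// Shrinks the storage to the specified size by dropping tail items.
    fn resize(&mut self, size: usize) {
        self.data.truncate(size);
    }

    fn size(&self) -> usize {
        self.data.len()
    }
}
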
221 changes: 157 additions & 64 deletions rosomaxa/src/algorithms/gsom/network.rs
@@ -41,7 +41,10 @@ where
}

/// GSOM network configuration.
#[derive(Clone, Debug)]
pub struct NetworkConfig {
/// A size of a node in the storage.
pub node_size: usize,
/// A spread factor.
pub spread_factor: Float,
/// The factor of distribution (FD), used in error distribution stage, 0 < FD < 1
@@ -50,7 +53,7 @@ pub struct NetworkConfig {
pub learning_rate: Float,
/// A rebalance memory.
pub rebalance_memory: usize,
/// If set to true, initial nodes have error set to the value equal to growing threshold.
/// If set to true, initial nodes have error set to the value equal to a growing threshold.
pub has_initial_error: bool,
}
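
As a rough usage sketch of the extended config (all values below are illustrative, not defaults taken from this commit):

let config = NetworkConfig {
    node_size: 2,              // per-node storage size after pretraining
    spread_factor: 0.75,       // must be in (0, 1)
    distribution_factor: 0.25, // must be in (0, 1)
    learning_rate: 0.1,
    rebalance_memory: 100,
    has_initial_error: true,
};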

@@ -65,39 +68,63 @@
F: StorageFactory<C, I, S>,
{
/// Creates a new instance of `Network`.
pub fn new(context: &C, roots: [I; 4], config: NetworkConfig, random: Arc<dyn Random>, storage_factory: F) -> Self {
let dimension = roots[0].weights().len();

assert!(roots.iter().all(|r| r.weights().len() == dimension));
pub fn new<SF>(
context: &C,
initial_data: Vec<I>,
config: NetworkConfig,
random: Arc<dyn Random>,
storage_factory: SF,
) -> GenericResult<Self>
where
SF: Fn(usize) -> F,
{
assert!(!initial_data.is_empty());
let dimension = initial_data[0].weights().len();
let data_size = initial_data.len();
assert!(initial_data.iter().all(|r| r.weights().len() == dimension));
assert!(config.distribution_factor > 0. && config.distribution_factor < 1.);
assert!(config.spread_factor > 0. && config.spread_factor < 1.);

let growing_threshold = -1. * dimension as Float * config.spread_factor.log2();
let initial_error = if config.has_initial_error { growing_threshold } else { 0. };
let noise = Noise::new_with_ratio(1., (0.75, 1.25), random.clone());

// create initial nodes
// note that storage factory creates storage with size up to data_size
// it should help to prevent data loss until the network is rebalanced
let (nodes, min_max_weights) = Self::create_initial_nodes(
context,
roots,
initial_error,
initial_data,
config.rebalance_memory,
&noise,
&storage_factory,
);
&storage_factory(data_size),
// apply small noise to initial weights
Noise::new_with_ratio(1., (0.95, 1.05), random.clone()),
)?;

Self {
// create a network with more aggressive initial parameters
let mut network = Self {
dimension,
growing_threshold,
growing_threshold: -1. * dimension as Float * config.spread_factor.log2(),
distribution_factor: config.distribution_factor,
learning_rate: config.learning_rate,
time: 0,
rebalance_memory: config.rebalance_memory,
min_max_weights,
nodes,
storage_factory,
random,
storage_factory: storage_factory(data_size),
random: random.clone(),
phantom_data: Default::default(),
}
};

// run training loop to balance the network
let allow_growth = true;
let rebalance_count = (data_size / 4).clamp(8, 12);
network.train_loop(context, rebalance_count, allow_growth, |_| ());

// reset to original parameters and make sure that node storages have the desired size
network.storage_factory = storage_factory(config.node_size);
network.time = 0;
network.nodes.iter_mut().for_each(|(_, node)| {
node.storage.resize(config.node_size);
});

Ok(network)
}
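
A hedged sketch of how a caller might construct the network with the new signature; ctx, data, config, random, and MyStorageFactory are placeholders for values and types implementing the corresponding traits:

let network = Network::new(
    &ctx,
    data,           // Vec<I>, must be non-empty
    config,
    random.clone(),
    |node_size| MyStorageFactory::new(node_size), // SF: Fn(usize) -> F
)?;

Note the pretraining pass count: with, say, 100 initial data points, data_size / 4 = 25 is clamped into 8..=12, so 12 rebalance iterations run before the storages are shrunk to config.node_size.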

/// Sets a new learning rate.
@@ -137,19 +164,8 @@ where
where
FM: Fn(&mut I),
{
(0..rebalance_count).for_each(|_| {
let mut data = self.nodes.iter_mut().flat_map(|(_, node)| node.storage.drain(0..)).collect::<Vec<_>>();
data.sort_unstable_by(compare_input);
data.dedup_by(|a, b| compare_input(a, b) == Ordering::Equal);
data.shuffle(&mut self.random.get_rng());
data.iter_mut().for_each(&node_fn);

self.train_on_data(context, data, false);

self.nodes.iter_mut().for_each(|(_, node)| {
node.error = 0.;
})
});
let allow_growth = false;
self.train_loop(context, rebalance_count, allow_growth, node_fn);
}

/// Compacts network. `node_filter` should return false for nodes to be removed.
@@ -204,6 +220,26 @@ where
self.get_nodes().map(|node| node.unified_distance(self, 1)).max_by(|a, b| a.total_cmp(b)).unwrap_or_default()
}

/// Performs training loop multiple times.
fn train_loop<FM>(&mut self, context: &C, rebalance_count: usize, allow_growth: bool, node_fn: FM)
where
FM: Fn(&mut I),
{
(0..rebalance_count).for_each(|_| {
let mut data = self.nodes.iter_mut().flat_map(|(_, node)| node.storage.drain(0..)).collect::<Vec<_>>();
data.sort_unstable_by(compare_input);
data.dedup_by(|a, b| compare_input(a, b) == Ordering::Equal);
data.shuffle(&mut self.random.get_rng());
data.iter_mut().for_each(&node_fn);

self.train_on_data(context, data, allow_growth);

self.nodes.iter_mut().for_each(|(_, node)| {
node.error = 0.;
})
});
}
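
Note that Vec::dedup_by removes only consecutive duplicates, which is why the data is sorted with the same comparator first; the shuffle afterwards restores a random presentation order for training. A tiny standalone illustration of the sort-then-dedup pattern:

let mut xs = vec![3, 1, 3, 2, 1];
xs.sort_unstable(); // [1, 1, 2, 3, 3]: duplicates become adjacent
xs.dedup();         // [1, 2, 3]
assert_eq!(xs, vec![1, 2, 3]);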

/// Trains network on an input.
fn train(&mut self, context: &C, input: I, is_new_input: bool) {
debug_assert!(input.weights().len() == self.dimension);
@@ -419,46 +455,103 @@ where
/// Creates nodes for initial topology.
fn create_initial_nodes(
context: &C,
roots: [I; 4],
initial_error: Float,
data: Vec<I>,
rebalance_memory: usize,
noise: &Noise,
storage_factory: &F,
) -> (NodeHashMap<I, S>, MinMaxWeights) {
let create_node = |coord: Coordinate, input: I| {
let weights = input.weights().iter().map(|&value| noise.generate(value)).collect::<Vec<_>>();
let mut node = Node::<I, S>::new(
coord,
weights.as_slice(),
initial_error,
rebalance_memory,
storage_factory.eval(context),
);
node.storage.add(input);

node
};
noise: Noise,
) -> GenericResult<(NodeHashMap<I, S>, MinMaxWeights)> {
// sample size is 10% of the data, bounded between 4 and 16 nodes
let sample_size = (data.len() as f64 * 0.1).ceil() as usize;
let sample_size = sample_size.clamp(4, 16);

let storage = storage_factory.eval(context);
let initial_node_indices = Self::select_initial_samples(&data, sample_size, &storage, noise.random())
.ok_or_else(|| GenericError::from("cannot select initial samples"))?;

// create initial node coordinates and data assignments (by index)
let grid_size = (initial_node_indices.len() as f64).sqrt().ceil() as i32;
let mut node_assignments: HashMap<Coordinate, Vec<usize>> = initial_node_indices
.iter()
.enumerate()
.map(|(grid_idx, &data_idx)| {
(data_idx, Coordinate((grid_idx as i32) % grid_size, (grid_idx as i32) / grid_size))
})
.collect_group_by_key(|(_, coord)| *coord)
.into_iter()
.map(|(coord, items)| (coord, items.into_iter().map(|(idx, _)| idx).collect()))
.collect();

// assign each remaining data point to the closest initial coordinate based on relative distance
for (idx, item) in data.iter().enumerate() {
if !initial_node_indices.contains(&idx) {
let get_distance_fn = |coord| {
let init_idx = node_assignments[coord][0];
storage.distance(data[init_idx].weights(), item.weights())
};

node_assignments
.keys()
.min_by(|&left, &right| get_distance_fn(left).total_cmp(&get_distance_fn(right)))
.cloned()
.and_then(|closest_coord| node_assignments.get_mut(&closest_coord))
.ok_or_else(|| GenericError::from("cannot find closest node"))?
.push(idx);
}
}

let dimension = roots[0].weights().len();
let [n00, n01, n11, n10] = roots;
let dimension = data[0].weights().len();
let mut min_max_weights = (vec![Float::MAX; dimension], vec![Float::MIN; dimension]);
let mut nodes = NodeHashMap::default();

let n00 = create_node(Coordinate(0, 0), n00);
let n01 = create_node(Coordinate(0, 1), n01);
let n11 = create_node(Coordinate(1, 1), n11);
let n10 = create_node(Coordinate(1, 0), n10);
// first pass: create nodes from the assignments without moving data yet (data items are not cloneable, and indices must stay valid)
for (&coord, indices) in node_assignments.iter() {
let init_idx = indices[0];
let weights: Vec<Float> = data[init_idx].weights().iter().map(|&v| noise.generate(v)).collect();
let node = Node::new(coord, &weights, 0., rebalance_memory, storage_factory.eval(context));
update_min_max(&mut min_max_weights, &weights);

let min_max_weights = [&n00, &n01, &n11, &n10].into_iter().fold(
(vec![Float::MAX; dimension], vec![Float::MIN; dimension]),
|mut min_max_weights, node| {
update_min_max(&mut min_max_weights, node.weights.as_slice());
nodes.insert(coord, node);
}

min_max_weights
},
);
// second pass: move the data into the assigned nodes
for (idx, item) in data.into_iter().enumerate() {
let node = node_assignments
.iter()
.find(|(_, indices)| indices.contains(&idx))
.map(|(coord, _)| coord)
.and_then(|coord| nodes.get_mut(coord))
.ok_or_else(|| GenericError::from("cannot find node for data"))?;

let nodes = [n00, n01, n11, n10].into_iter().map(|node| (node.coordinate, node)).collect::<HashMap<_, _, _>>();
node.storage.add(item);
}

Ok((nodes, min_max_weights))
}
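
A standalone illustration of the row-major grid placement used above (a sample count of 10 is chosen arbitrarily):

let grid_size = (10f64).sqrt().ceil() as i32; // ceil(sqrt(10)) = 4
let coords: Vec<(i32, i32)> = (0..10).map(|i| (i % grid_size, i / grid_size)).collect();
assert_eq!(coords[..5], [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1)]);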

/// Selects initial samples (represented as index in data).
fn select_initial_samples(data: &[I], sample_size: usize, storage: &S, random: &dyn Random) -> Option<Vec<usize>> {
let mut selected_indices = Vec::with_capacity(sample_size);

// select first sample randomly
selected_indices.push(random.uniform_int(0, data.len() as i32 - 1) as usize);

// select remaining samples, maximizing the minimum distance to already selected ones
let dist_fn = |selected_indices: &Vec<usize>, idx: usize| {
selected_indices
.iter()
.map(|&sel_idx| storage.distance(data[sel_idx].weights(), data[idx].weights()))
.min_by(|a, b| a.total_cmp(b))
.unwrap_or_default()
};
while selected_indices.len() < sample_size {
let next_idx = (0..data.len())
.filter(|i| !selected_indices.contains(i))
.max_by(|&i, &j| dist_fn(&selected_indices, i).total_cmp(&dist_fn(&selected_indices, j)))?;

selected_indices.push(next_idx);
}

(nodes, min_max_weights)
Some(selected_indices)
}
}
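
The selection above is a greedy farthest-point strategy: each new sample maximizes its minimum distance to the already selected ones. A self-contained sketch on scalar data, with absolute difference standing in for Storage::distance:

// greedy max-min (farthest-point) selection on scalar data
fn select(data: &[f64], k: usize) -> Vec<usize> {
    let mut picked = vec![0]; // start from an arbitrary point
    while picked.len() < k {
        let next = (0..data.len())
            .filter(|i| !picked.contains(i))
            .max_by(|&i, &j| {
                // minimum distance from candidate x to the picked set
                let d = |x: usize| {
                    picked.iter().map(|&p| (data[p] - data[x]).abs()).fold(f64::MAX, f64::min)
                };
                d(i).total_cmp(&d(j))
            })
            .expect("k must not exceed data size");
        picked.push(next);
    }
    picked
}

For example, select(&[0.0, 0.1, 5.0, 9.9, 10.0], 3) picks indices 0, 4, then 2.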

4 changes: 2 additions & 2 deletions rosomaxa/src/example.rs
@@ -45,8 +45,8 @@ pub struct VectorObjective {
pub struct VectorSolution {
/// Solution payload.
pub data: Vec<Float>,
weights: Vec<Float>,
fitness: Float,
pub(crate) weights: Vec<Float>,
pub(crate) fitness: Float,
}

impl VectorContext {
6 changes: 6 additions & 0 deletions rosomaxa/src/population/elitism.rs
@@ -162,6 +162,12 @@ where
self.individuals.drain(range).collect()
}

/// Shrinks the population to the specified size.
pub fn set_max_population_size(&mut self, max_population_size: usize) {
self.max_population_size = max_population_size;
self.ensure_max_population_size();
}
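
A brief usage note: because the method re-applies the limit eagerly, lowering it trims the population immediately rather than on the next insertion (sketch; elitism is assumed to be an existing instance):

elitism.set_max_population_size(4); // population is truncated right away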

fn sort(&mut self) {
self.individuals.sort_by(|a, b| self.objective.total_order(a, b));
self.individuals.dedup_by(|a, b| (self.dedup_fn)(&self.objective, a, b));
(Diff for the remaining 4 changed files is not shown.)
