From 08c91117795256b3826c9b4e0efa3042c99ef82b Mon Sep 17 00:00:00 2001 From: Antonio Bellotta Date: Wed, 2 Oct 2024 14:29:58 +0200 Subject: [PATCH] Fix distribute function so that it redistributes only imbalanced populations --- neurodamus/utils/memory.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/neurodamus/utils/memory.py b/neurodamus/utils/memory.py index c8560554..6e657f54 100644 --- a/neurodamus/utils/memory.py +++ b/neurodamus/utils/memory.py @@ -578,7 +578,7 @@ def validate_inputs_distribute(self, num_ranks, batch_size): # syn_count_metypes = set(self.metype_cell_syn_average) # assert all_metypes <= syn_count_metypes, all_metypes - syn_count_metypes - def check_all_buckets_have_gids(self, bucket_allocation, num_ranks, cycles): + def check_all_buckets_have_gids(self, bucket_allocation, population, num_ranks, cycles): """ Checks if all possible buckets determined by num_ranks and cycles have at least one GID assigned. @@ -586,19 +586,18 @@ def check_all_buckets_have_gids(self, bucket_allocation, num_ranks, cycles): Args: bucket_allocation (dict): The allocation dictionary containing the assignment of GIDs to ranks and cycles. + population (str): The population to check. num_ranks (int): The number of ranks. cycles (int): The number of cycles. Returns: bool: True if all buckets have at least one GID assigned, False otherwise. """ - for pop, rank_allocation in bucket_allocation.items(): - for rank_id in range(num_ranks): - for cycle_id in range(cycles): - if not rank_allocation.get((rank_id, cycle_id)): - logging.warning(f"Bucket ({rank_id}, {cycle_id}) in population '{pop}' " - f"has no GIDs assigned.") - return False + rank_allocation = bucket_allocation.get(population, {}) + for rank_id in range(num_ranks): + for cycle_id in range(cycles): + if not rank_allocation.get((rank_id, cycle_id)): + return False return True @run_only_rank0 @@ -646,12 +645,17 @@ def _calculate_total_elements_per_population(self): bucket_allocation, bucket_memory, metype_memory_usage = self.distribute_cells( num_ranks, cycles, metype_file, batch_size=batch_size ) - valid_distribution = self.check_all_buckets_have_gids(bucket_allocation, - num_ranks, - cycles) + valid_distribution = all( + self.check_all_buckets_have_gids(bucket_allocation, population, num_ranks, cycles) + for population in self.pop_metype_gids.keys() + ) if not valid_distribution: - batch_size = {population: max(0, size - 1) - for population, size in batch_size.items()} + for population, size in batch_size.items(): + if not self.check_population_has_gids_in_all_buckets(bucket_allocation, + population, + num_ranks, + cycles): + batch_size[population] = max(0, size - 1) if all(size == 0 for size in batch_size.values()): raise RuntimeError("Unable to find a valid distribution with the given parameters. "