Skip to content

Commit

Permalink
Fix distribute function so that it redistributes only imbalanced populations
Browse files Browse the repository at this point in the history
  • Loading branch information
st4rl3ss committed Oct 14, 2024
1 parent 4cbf109 commit 08c9111
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions neurodamus/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,27 +578,26 @@ def validate_inputs_distribute(self, num_ranks, batch_size):
# syn_count_metypes = set(self.metype_cell_syn_average)
# assert all_metypes <= syn_count_metypes, all_metypes - syn_count_metypes

def check_all_buckets_have_gids(self, bucket_allocation, population, num_ranks, cycles):
    """
    Checks if every (rank, cycle) bucket of a single population has at least one GID
    assigned.

    Args:
        bucket_allocation (dict): The allocation dictionary containing the assignment of
            GIDs to ranks and cycles, keyed by population and then by
            ``(rank_id, cycle_id)``.
        population (str): The population to check.
        num_ranks (int): The number of ranks.
        cycles (int): The number of cycles.

    Returns:
        bool: True if all buckets have at least one GID assigned, False otherwise.
    """
    # A population absent from the allocation yields an empty mapping, so every
    # expected bucket lookup below comes back empty.
    rank_allocation = bucket_allocation.get(population, {})
    # A bucket counts as filled only when its entry exists and is non-empty.
    return all(
        rank_allocation.get((rank_id, cycle_id))
        for rank_id in range(num_ranks)
        for cycle_id in range(cycles)
    )

@run_only_rank0
Expand Down Expand Up @@ -646,12 +645,17 @@ def _calculate_total_elements_per_population(self):
bucket_allocation, bucket_memory, metype_memory_usage = self.distribute_cells(
num_ranks, cycles, metype_file, batch_size=batch_size
)
valid_distribution = self.check_all_buckets_have_gids(bucket_allocation,
num_ranks,
cycles)
valid_distribution = all(
self.check_all_buckets_have_gids(bucket_allocation, population, num_ranks, cycles)
for population in self.pop_metype_gids.keys()
)
if not valid_distribution:
batch_size = {population: max(0, size - 1)
for population, size in batch_size.items()}
for population, size in batch_size.items():
if not self.check_population_has_gids_in_all_buckets(bucket_allocation,
population,
num_ranks,
cycles):
batch_size[population] = max(0, size - 1)

if all(size == 0 for size in batch_size.values()):
raise RuntimeError("Unable to find a valid distribution with the given parameters. "
Expand Down

0 comments on commit 08c9111

Please sign in to comment.