Skip to content

Commit

Permalink
Better load balancing for mpi runs
Browse files Browse the repository at this point in the history
  • Loading branch information
andreicuceu committed Aug 25, 2023
1 parent a573212 commit 8ddea35
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 20 deletions.
45 changes: 36 additions & 9 deletions lya_2pt/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,11 @@ def read_tracers(self, files=None, mpi_rank=0):

if self.num_cpu > 1:
with multiprocessing.Pool(processes=self.num_cpu) as pool:
results = list(tqdm.tqdm(
pool.imap(self.read_tracer1, files), total=len(files), position=mpi_rank
))
if mpi_rank == 0:
results = list(tqdm.tqdm(

Check warning on line 141 in lya_2pt/interface.py

View check run for this annotation

Codecov / codecov/patch

lya_2pt/interface.py#L141

Added line #L141 was not covered by tests
pool.imap(self.read_tracer1, files), total=len(files)))
else:
results = list(pool.imap(self.read_tracer1, files))

Check warning on line 144 in lya_2pt/interface.py

View check run for this annotation

Codecov / codecov/patch

lya_2pt/interface.py#L144

Added line #L144 was not covered by tests
else:
results = [self.read_tracer1(file) for file in files]

Expand Down Expand Up @@ -226,7 +228,7 @@ def reset_global_counter():
globals.counter = multiprocessing.Value('i', 0)
globals.lock = multiprocessing.Lock()

def run(self, healpix_ids=None, mpi_rank=0):
def run(self, healpix_ids=None, mpi_size=1, mpi_rank=0):
"""Run the computation
This can include the correlation function, the distortion matrix,
Expand Down Expand Up @@ -287,12 +289,15 @@ def run(self, healpix_ids=None, mpi_rank=0):
context = multiprocessing.get_context('fork')
with context.Pool(processes=self.num_cpu) as pool:
num_pairs = pool.map(compute_num_pairs, healpix_ids)
healpix_ids = healpix_ids[np.argsort(num_pairs)[::-1]]
local_hp_ids = self.do_load_balance(healpix_ids, num_pairs, mpi_size, mpi_rank)

Check warning on line 292 in lya_2pt/interface.py

View check run for this annotation

Codecov / codecov/patch

lya_2pt/interface.py#L292

Added line #L292 was not covered by tests

results = list(tqdm.tqdm(
pool.imap_unordered(compute_xi_and_fisher, healpix_ids),
total=len(healpix_ids), position=mpi_rank
))
if mpi_rank == 0:
results = list(tqdm.tqdm(

Check warning on line 295 in lya_2pt/interface.py

View check run for this annotation

Codecov / codecov/patch

lya_2pt/interface.py#L295

Added line #L295 was not covered by tests
pool.imap_unordered(compute_xi_and_fisher, local_hp_ids),
total=len(local_hp_ids),
))
else:
results = list(pool.imap_unordered(compute_xi_and_fisher, local_hp_ids))

Check warning on line 300 in lya_2pt/interface.py

View check run for this annotation

Codecov / codecov/patch

lya_2pt/interface.py#L300

Added line #L300 was not covered by tests

for hp_id, res in results:
self.optimal_xi_output[hp_id] = res
Expand Down Expand Up @@ -325,3 +330,25 @@ def write_results(self, mpi_rank=None):
)

# TODO: add other modes

def do_load_balance(self, healpix_ids, num_pairs, mpi_size=1, mpi_rank=0):
sort_idx = np.argsort(num_pairs)[::-1]
local_hp_ids = healpix_ids[sort_idx]
local_weights = np.array(num_pairs)[sort_idx]

Check warning on line 337 in lya_2pt/interface.py

View check run for this annotation

Codecov / codecov/patch

lya_2pt/interface.py#L335-L337

Added lines #L335 - L337 were not covered by tests

if mpi_size > 1:
allocation = self.compute_balanced_chunks(local_weights, mpi_size)
local_hp_ids = local_hp_ids[allocation == mpi_rank]

Check warning on line 341 in lya_2pt/interface.py

View check run for this annotation

Codecov / codecov/patch

lya_2pt/interface.py#L340-L341

Added lines #L340 - L341 were not covered by tests

return local_hp_ids

Check warning on line 343 in lya_2pt/interface.py

View check run for this annotation

Codecov / codecov/patch

lya_2pt/interface.py#L343

Added line #L343 was not covered by tests

@staticmethod
def compute_balanced_chunks(weights, mpi_size):
allocation = np.zeros(weights.size, dtype=int)
proc_sums = np.zeros(mpi_size)

Check warning on line 348 in lya_2pt/interface.py

View check run for this annotation

Codecov / codecov/patch

lya_2pt/interface.py#L347-L348

Added lines #L347 - L348 were not covered by tests
for i, w in enumerate(weights):
aloc_idx = proc_sums.argmin()
allocation[i] = aloc_idx
proc_sums[aloc_idx] += w

Check warning on line 352 in lya_2pt/interface.py

View check run for this annotation

Codecov / codecov/patch

lya_2pt/interface.py#L350-L352

Added lines #L350 - L352 were not covered by tests

return allocation

Check warning on line 354 in lya_2pt/interface.py

View check run for this annotation

Codecov / codecov/patch

lya_2pt/interface.py#L354

Added line #L354 was not covered by tests
19 changes: 8 additions & 11 deletions lya_2pt/scripts/run_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,20 @@ def main():
f"{mpi_size} MPI processes. This is wasteful. "
"Please lower the numper of MPI processes.")

num_tasks_per_proc = len(lya2pt.files) // mpi_size
remainder = len(lya2pt.files) % mpi_size
if mpi_rank < remainder:
start = int(mpi_rank * (num_tasks_per_proc + 1))
stop = int(start + num_tasks_per_proc + 1)
else:
start = int(mpi_rank * num_tasks_per_proc + remainder)
stop = int(start + num_tasks_per_proc)

if mpi_rank == 0:
total_t1 = time.time()
print('Reading tracers...', flush=True)
lya2pt.read_tracers(mpi_rank=mpi_rank)

if mpi_rank == 0:
print('Starting computation...', flush=True)
lya2pt.run(mpi_size=mpi_size, mpi_rank=mpi_rank)

lya2pt.read_tracers(lya2pt.files[start:stop])
lya2pt.run(mpi_rank=mpi_rank)
if mpi_rank == 0:
print('Writing results...', flush=True)
lya2pt.write_results(mpi_rank=mpi_rank)

mpi_comm.Barrier()
if mpi_rank == 0:
total_t2 = time.time()
print(f'Total time: {(total_t2-total_t1):.3f} sec', flush=True)
Expand Down

0 comments on commit 8ddea35

Please sign in to comment.