diff --git a/lya_2pt/interface.py b/lya_2pt/interface.py
index 20a52ca..1be9fed 100644
--- a/lya_2pt/interface.py
+++ b/lya_2pt/interface.py
@@ -137,9 +137,11 @@ def read_tracers(self, files=None, mpi_rank=0):
 
         if self.num_cpu > 1:
             with multiprocessing.Pool(processes=self.num_cpu) as pool:
-                results = list(tqdm.tqdm(
-                    pool.imap(self.read_tracer1, files), total=len(files), position=mpi_rank
-                ))
+                if mpi_rank == 0:
+                    results = list(tqdm.tqdm(
+                        pool.imap(self.read_tracer1, files), total=len(files)))
+                else:
+                    results = list(pool.imap(self.read_tracer1, files))
         else:
             results = [self.read_tracer1(file) for file in files]
 
@@ -226,7 +228,7 @@ def reset_global_counter():
         globals.counter = multiprocessing.Value('i', 0)
         globals.lock = multiprocessing.Lock()
 
-    def run(self, healpix_ids=None, mpi_rank=0):
+    def run(self, healpix_ids=None, mpi_size=1, mpi_rank=0):
         """Run the computation
 
         This can include the correlation function, the distortion matrix,
@@ -287,12 +289,15 @@ def run(self, healpix_ids=None, mpi_rank=0):
             context = multiprocessing.get_context('fork')
             with context.Pool(processes=self.num_cpu) as pool:
                 num_pairs = pool.map(compute_num_pairs, healpix_ids)
-                healpix_ids = healpix_ids[np.argsort(num_pairs)[::-1]]
+                local_hp_ids = self.do_load_balance(healpix_ids, num_pairs, mpi_size, mpi_rank)
 
-                results = list(tqdm.tqdm(
-                    pool.imap_unordered(compute_xi_and_fisher, healpix_ids),
-                    total=len(healpix_ids), position=mpi_rank
-                ))
+                if mpi_rank == 0:
+                    results = list(tqdm.tqdm(
+                        pool.imap_unordered(compute_xi_and_fisher, local_hp_ids),
+                        total=len(local_hp_ids),
+                    ))
+                else:
+                    results = list(pool.imap_unordered(compute_xi_and_fisher, local_hp_ids))
 
                 for hp_id, res in results:
                     self.optimal_xi_output[hp_id] = res
@@ -325,3 +330,25 @@ def write_results(self, mpi_rank=None):
             )
 
         # TODO: add other modes
+
+    def do_load_balance(self, healpix_ids, num_pairs, mpi_size=1, mpi_rank=0):
+        sort_idx = np.argsort(num_pairs)[::-1]
+        local_hp_ids = healpix_ids[sort_idx]
+        local_weights = np.array(num_pairs)[sort_idx]
+
+        if mpi_size > 1:
+            allocation = self.compute_balanced_chunks(local_weights, mpi_size)
+            local_hp_ids = local_hp_ids[allocation == mpi_rank]
+
+        return local_hp_ids
+
+    @staticmethod
+    def compute_balanced_chunks(weights, mpi_size):
+        allocation = np.zeros(weights.size, dtype=int)
+        proc_sums = np.zeros(mpi_size)
+        for i, w in enumerate(weights):
+            aloc_idx = proc_sums.argmin()
+            allocation[i] = aloc_idx
+            proc_sums[aloc_idx] += w
+
+        return allocation
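The new compute_balanced_chunks helper is a greedy longest-processing-time-first scheduler: healpix IDs are visited in order of decreasing pair count and each one is assigned to the MPI rank with the smallest accumulated weight so far. A minimal standalone sketch of that logic; the weights and the driver lines below are illustrative, not output from the package:

import numpy as np

def compute_balanced_chunks(weights, mpi_size):
    # Same greedy scheme as the new static method: each weight (sorted in
    # descending order by the caller) goes to the currently least-loaded rank.
    allocation = np.zeros(weights.size, dtype=int)
    proc_sums = np.zeros(mpi_size)
    for i, w in enumerate(weights):
        aloc_idx = proc_sums.argmin()
        allocation[i] = aloc_idx
        proc_sums[aloc_idx] += w
    return allocation

# Illustrative pair counts for six healpixels, already sorted in descending order.
weights = np.array([90, 70, 40, 30, 20, 10])
print(compute_balanced_chunks(weights, mpi_size=2))
# -> [0 1 1 0 1 0]; rank 0 gets 90+30+10 = 130, rank 1 gets 70+40+20 = 130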
" "Please lower the numper of MPI processes.") - num_tasks_per_proc = len(lya2pt.files) // mpi_size - remainder = len(lya2pt.files) % mpi_size - if mpi_rank < remainder: - start = int(mpi_rank * (num_tasks_per_proc + 1)) - stop = int(start + num_tasks_per_proc + 1) - else: - start = int(mpi_rank * num_tasks_per_proc + remainder) - stop = int(start + num_tasks_per_proc) - if mpi_rank == 0: total_t1 = time.time() + print('Reading tracers...', flush=True) + lya2pt.read_tracers(mpi_rank=mpi_rank) + + if mpi_rank == 0: print('Starting computation...', flush=True) + lya2pt.run(mpi_size=mpi_size, mpi_rank=mpi_rank) - lya2pt.read_tracers(lya2pt.files[start:stop]) - lya2pt.run(mpi_rank=mpi_rank) + if mpi_rank == 0: + print('Writing results...', flush=True) lya2pt.write_results(mpi_rank=mpi_rank) + mpi_comm.Barrier() if mpi_rank == 0: total_t2 = time.time() print(f'Total time: {(total_t2-total_t1):.3f} sec', flush=True)