From e9556df92d108acca28123b8f30a6e44ffeffb89 Mon Sep 17 00:00:00 2001 From: Andrei Cuceu Date: Mon, 10 Jun 2024 09:59:43 -0400 Subject: [PATCH] Test schwimmbad --- bin/run_vega_mpi.py | 32 +++++++++++++++++--------------- vega/samplers/pocomc.py | 6 +++--- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/bin/run_vega_mpi.py b/bin/run_vega_mpi.py index bfb55b3..8f23bd8 100644 --- a/bin/run_vega_mpi.py +++ b/bin/run_vega_mpi.py @@ -74,22 +74,24 @@ def print_func(message): print_func('Running PocoMC') sampler = PocoMC(vega.main_config['PocoMC'], sampling_params) - if sampler.use_mpi: - assert False - - def log_lik(theta): - params = {name: val for name, val in zip(sampler.names, theta)} - return vega.log_lik(params) + sampler.run(vega.log_lik) - mpi_comm.barrier() - with Pool(sampler.num_cpu) as pool: - sampler.pocomc_sampler = pocomc.Sampler( - sampler.prior, log_lik, - pool=pool, output_dir=sampler.path, - dynamic=sampler.dynamic, precondition=sampler.precondition, - n_effective=sampler.n_effective, n_active=sampler.n_active, - ) - sampler.pocomc_sampler.run(sampler.n_total, sampler.n_evidence, save_every=sampler.save_every) + # if sampler.use_mpi: + # assert False + + # def log_lik(theta): + # params = {name: val for name, val in zip(sampler.names, theta)} + # return vega.log_lik(params) + + # mpi_comm.barrier() + # with Pool(sampler.num_cpu) as pool: + # sampler.pocomc_sampler = pocomc.Sampler( + # sampler.prior, log_lik, + # pool=pool, output_dir=sampler.path, + # dynamic=sampler.dynamic, precondition=sampler.precondition, + # n_effective=sampler.n_effective, n_active=sampler.n_active, + # ) + # sampler.pocomc_sampler.run(sampler.n_total, sampler.n_evidence, save_every=sampler.save_every) # sampler.run(vega.log_lik) mpi_comm.barrier() diff --git a/vega/samplers/pocomc.py b/vega/samplers/pocomc.py index efb547b..f31caa6 100644 --- a/vega/samplers/pocomc.py +++ b/vega/samplers/pocomc.py @@ -3,7 +3,8 @@ import numpy as np import pocomc from mpi4py import MPI -from mpi4py.futures import MPIPoolExecutor +from schwimmbad import MPIPool +# from mpi4py.futures import MPIPoolExecutor from multiprocessing import Pool from scipy.stats import uniform @@ -47,8 +48,7 @@ def run(self, log_lik_func): def _run_mpi(self, log_lik_func): """ Run the PocoMC sampler """ mpi_comm = MPI.COMM_WORLD - num_mpi_threads = mpi_comm.Get_size() - with MPIPoolExecutor(num_mpi_threads) as pool: + with MPIPool(mpi_comm) as pool: self.pocomc_sampler = pocomc.Sampler( self.prior, self.vec_log_lik, likelihood_args=(log_lik_func), pool=pool, output_dir=self.path,