Skip to content

Commit

Permalink
Test schwimmbad
Browse files Browse the repository at this point in the history
  • Loading branch information
andreicuceu committed Jun 10, 2024
1 parent 2d59ac2 commit e9556df
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
32 changes: 17 additions & 15 deletions bin/run_vega_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,24 @@ def print_func(message):

print_func('Running PocoMC')
sampler = PocoMC(vega.main_config['PocoMC'], sampling_params)
if sampler.use_mpi:
assert False

def log_lik(theta):
params = {name: val for name, val in zip(sampler.names, theta)}
return vega.log_lik(params)
sampler.run(vega.log_lik)

mpi_comm.barrier()
with Pool(sampler.num_cpu) as pool:
sampler.pocomc_sampler = pocomc.Sampler(
sampler.prior, log_lik,
pool=pool, output_dir=sampler.path,
dynamic=sampler.dynamic, precondition=sampler.precondition,
n_effective=sampler.n_effective, n_active=sampler.n_active,
)
sampler.pocomc_sampler.run(sampler.n_total, sampler.n_evidence, save_every=sampler.save_every)
# if sampler.use_mpi:
# assert False

# def log_lik(theta):
# params = {name: val for name, val in zip(sampler.names, theta)}
# return vega.log_lik(params)

# mpi_comm.barrier()
# with Pool(sampler.num_cpu) as pool:
# sampler.pocomc_sampler = pocomc.Sampler(
# sampler.prior, log_lik,
# pool=pool, output_dir=sampler.path,
# dynamic=sampler.dynamic, precondition=sampler.precondition,
# n_effective=sampler.n_effective, n_active=sampler.n_active,
# )
# sampler.pocomc_sampler.run(sampler.n_total, sampler.n_evidence, save_every=sampler.save_every)

# sampler.run(vega.log_lik)
mpi_comm.barrier()
Expand Down
6 changes: 3 additions & 3 deletions vega/samplers/pocomc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
import pocomc
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from schwimmbad import MPIPool
# from mpi4py.futures import MPIPoolExecutor
from multiprocessing import Pool
from scipy.stats import uniform

Expand Down Expand Up @@ -47,8 +48,7 @@ def run(self, log_lik_func):
def _run_mpi(self, log_lik_func):
""" Run the PocoMC sampler """
mpi_comm = MPI.COMM_WORLD
num_mpi_threads = mpi_comm.Get_size()
with MPIPoolExecutor(num_mpi_threads) as pool:
with MPIPool(mpi_comm) as pool:
self.pocomc_sampler = pocomc.Sampler(
self.prior, self.vec_log_lik, likelihood_args=(log_lik_func),
pool=pool, output_dir=self.path,
Expand Down

0 comments on commit e9556df

Please sign in to comment.