Skip to content

Commit 447155e

Browse files
committed
Add multiprocessing option for pocomc
1 parent e4b4b6e commit 447155e

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

vega/samplers/pocomc.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pocomc
55
from mpi4py import MPI
66
from mpi4py.futures import MPIPoolExecutor
7+
from multiprocessing import Pool
78
from scipy.stats import uniform
89

910
from vega.samplers.sampler_interface import Sampler
@@ -25,12 +26,21 @@ def get_sampler_settings(self, sampler_config, num_params, num_derived):
2526
self.n_evidence = sampler_config.getint('n_evidence', 0)
2627
self.save_every = sampler_config.getint('save_every', 3)
2728

29+
self.use_mpi = sampler_config.getboolean('use_mpi', False)
30+
self.num_cpu = sampler_config.getint('num_cpu', 64)
31+
2832
self.prior = pocomc.Prior(
2933
[uniform(self.limits[par][0], self.limits[par][1]-self.limits[par][0])
3034
for par in self.limits]
3135
)
3236

3337
def run(self):
38+
if self.use_mpi:
39+
self._run_mpi()
40+
else:
41+
self._run_multiprocessing()
42+
43+
def _run_mpi(self):
3444
""" Run the PocoMC sampler """
3545
mpi_comm = MPI.COMM_WORLD
3646
num_mpi_threads = mpi_comm.Get_size()
@@ -42,6 +52,16 @@ def run(self):
4252
)
4353
self.pocomc_sampler.run(self.n_total, self.n_evidence, save_every=self.save_every)
4454

55+
def _run_multiprocessing(self):
56+
""" Run the PocoMC sampler """
57+
with Pool(self.num_cpu) as pool:
58+
self.pocomc_sampler = pocomc.Sampler(
59+
self.prior, self.log_lik, pool=pool, output_dir=self.path,
60+
dynamic=self.dynamic, precondition=self.precondition,
61+
n_effective=self.n_effective, n_active=self.n_active,
62+
)
63+
self.pocomc_sampler.run(self.n_total, self.n_evidence, save_every=self.save_every)
64+
4565
def write_chain(self):
4666
# Get the weighted posterior samples
4767
samples, weights, logl, logp = self.pocomc_sampler.posterior()

0 commit comments

Comments
 (0)