4
4
import pocomc
5
5
from mpi4py import MPI
6
6
from mpi4py .futures import MPIPoolExecutor
7
+ from multiprocessing import Pool
7
8
from scipy .stats import uniform
8
9
9
10
from vega .samplers .sampler_interface import Sampler
@@ -25,12 +26,21 @@ def get_sampler_settings(self, sampler_config, num_params, num_derived):
25
26
self .n_evidence = sampler_config .getint ('n_evidence' , 0 )
26
27
self .save_every = sampler_config .getint ('save_every' , 3 )
27
28
29
+ self .use_mpi = sampler_config .getboolean ('use_mpi' , False )
30
+ self .num_cpu = sampler_config .getint ('num_cpu' , 64 )
31
+
28
32
self .prior = pocomc .Prior (
29
33
[uniform (self .limits [par ][0 ], self .limits [par ][1 ]- self .limits [par ][0 ])
30
34
for par in self .limits ]
31
35
)
32
36
33
37
def run (self ):
38
+ if self .use_mpi :
39
+ self ._run_mpi ()
40
+ else :
41
+ self ._run_multiprocessing ()
42
+
43
+ def _run_mpi (self ):
34
44
""" Run the PocoMC sampler """
35
45
mpi_comm = MPI .COMM_WORLD
36
46
num_mpi_threads = mpi_comm .Get_size ()
@@ -42,6 +52,16 @@ def run(self):
42
52
)
43
53
self .pocomc_sampler .run (self .n_total , self .n_evidence , save_every = self .save_every )
44
54
55
+ def _run_multiprocessing (self ):
56
+ """ Run the PocoMC sampler """
57
+ with Pool (self .num_cpu ) as pool :
58
+ self .pocomc_sampler = pocomc .Sampler (
59
+ self .prior , self .log_lik , pool = pool , output_dir = self .path ,
60
+ dynamic = self .dynamic , precondition = self .precondition ,
61
+ n_effective = self .n_effective , n_active = self .n_active ,
62
+ )
63
+ self .pocomc_sampler .run (self .n_total , self .n_evidence , save_every = self .save_every )
64
+
45
65
def write_chain (self ):
46
66
# Get the weighted posterior samples
47
67
samples , weights , logl , logp = self .pocomc_sampler .posterior ()
0 commit comments