diff --git a/bin/run_vega_mpi.py b/bin/run_vega_mpi.py index 3b43202..9690990 100644 --- a/bin/run_vega_mpi.py +++ b/bin/run_vega_mpi.py @@ -64,15 +64,15 @@ def print_func(message): from vega.samplers.polychord import Polychord print_func('Running Polychord') - sampler = Polychord(vega.main_config['Polychord'], sampling_params, vega.log_lik) - sampler.run() + sampler = Polychord(vega.main_config['Polychord'], sampling_params) + sampler.run(vega.log_lik) elif vega.sampler == 'PocoMC': from vega.samplers.pocomc import PocoMC print_func('Running PocoMC') - sampler = PocoMC(vega.main_config['PocoMC'], sampling_params, vega.log_lik) + sampler = PocoMC(vega.main_config['PocoMC'], sampling_params) mpi_comm.barrier() - sampler.run() + sampler.run(vega.log_lik) mpi_comm.barrier() if cpu_rank == 0: diff --git a/vega/samplers/pocomc.py b/vega/samplers/pocomc.py index 7f7e137..01daf17 100644 --- a/vega/samplers/pocomc.py +++ b/vega/samplers/pocomc.py @@ -34,33 +34,35 @@ def get_sampler_settings(self, sampler_config, num_params, num_derived): for par in self.limits] ) - def vec_log_lik(self, theta): + def vec_log_lik(self, theta, log_lik_func): params = {name: val for name, val in zip(self.names, theta)} - return self.log_lik(params) + return log_lik_func(params) - def run(self): + def run(self, log_lik_func): if self.use_mpi: - self._run_mpi() + self._run_mpi(log_lik_func) else: - self._run_multiprocessing() + self._run_multiprocessing(log_lik_func) - def _run_mpi(self): + def _run_mpi(self, log_lik_func): """ Run the PocoMC sampler """ mpi_comm = MPI.COMM_WORLD num_mpi_threads = mpi_comm.Get_size() with MPIPoolExecutor(num_mpi_threads) as pool: self.pocomc_sampler = pocomc.Sampler( - self.prior, self.vec_log_lik, pool=pool, output_dir=self.path, + self.prior, self.vec_log_lik, likelihood_args=(log_lik_func), + pool=pool, output_dir=self.path, dynamic=self.dynamic, precondition=self.precondition, n_effective=self.n_effective, n_active=self.n_active, ) self.pocomc_sampler.run(self.n_total, self.n_evidence, save_every=self.save_every) - def _run_multiprocessing(self): + def _run_multiprocessing(self, log_lik_func): """ Run the PocoMC sampler """ with Pool(self.num_cpu) as pool: self.pocomc_sampler = pocomc.Sampler( - self.prior, self.vec_log_lik, pool=pool, output_dir=self.path, + self.prior, self.vec_log_lik, likelihood_args=(log_lik_func), + pool=pool, output_dir=self.path, dynamic=self.dynamic, precondition=self.precondition, n_effective=self.n_effective, n_active=self.n_active, ) diff --git a/vega/samplers/polychord.py b/vega/samplers/polychord.py index 0160f99..30216e3 100644 --- a/vega/samplers/polychord.py +++ b/vega/samplers/polychord.py @@ -96,7 +96,7 @@ def log_lik(theta): for i, name in enumerate(self.names): params[name] = theta[i] - log_lik = self.log_lik(params) + log_lik = log_lik_func(params) return log_lik, [] def prior(hypercube): diff --git a/vega/samplers/sampler_interface.py b/vega/samplers/sampler_interface.py index ac98ee1..1e0817c 100644 --- a/vega/samplers/sampler_interface.py +++ b/vega/samplers/sampler_interface.py @@ -10,7 +10,7 @@ class Sampler: ''' Interface between Vega and the nested sampler PolyChord ''' - def __init__(self, sampler_config, limits, log_lik_func): + def __init__(self, sampler_config, limits): """ Parameters @@ -26,7 +26,7 @@ def __init__(self, sampler_config, limits, log_lik_func): self.names = list(limits.keys()) self.num_params = len(limits) self.num_derived = 0 - self.log_lik = log_lik_func + # self.log_lik = log_lik_func self.getdist_latex = sampler_config.getboolean('getdist_latex', True) # Check limits are well defined @@ -74,5 +74,5 @@ def write_parnames(self, parnames_path): def get_sampler_settings(self, sampler_config, num_params, num_derived): raise NotImplementedError('This method should be implemented in the child class') - def run(self): + def run(self, log_lik_func): raise NotImplementedError('This method should be implemented in the child class')