Skip to content

Commit

Permalink
Directly pass loglik function
Browse files Browse the repository at this point in the history
  • Loading branch information
andreicuceu committed Jun 10, 2024
1 parent eecd6c6 commit 9c0c7ae
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 17 deletions.
8 changes: 4 additions & 4 deletions bin/run_vega_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ def print_func(message):
from vega.samplers.polychord import Polychord

print_func('Running Polychord')
sampler = Polychord(vega.main_config['Polychord'], sampling_params, vega.log_lik)
sampler.run()
sampler = Polychord(vega.main_config['Polychord'], sampling_params)
sampler.run(vega.log_lik)
elif vega.sampler == 'PocoMC':
from vega.samplers.pocomc import PocoMC

print_func('Running PocoMC')
sampler = PocoMC(vega.main_config['PocoMC'], sampling_params, vega.log_lik)
sampler = PocoMC(vega.main_config['PocoMC'], sampling_params)
mpi_comm.barrier()
sampler.run()
sampler.run(vega.log_lik)
mpi_comm.barrier()

if cpu_rank == 0:
Expand Down
20 changes: 11 additions & 9 deletions vega/samplers/pocomc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,35 @@ def get_sampler_settings(self, sampler_config, num_params, num_derived):
for par in self.limits]
)

def vec_log_lik(self, theta):
def vec_log_lik(self, theta, log_lik_func):
params = {name: val for name, val in zip(self.names, theta)}
return self.log_lik(params)
return log_lik_func(params)

def run(self):
def run(self, log_lik_func):
if self.use_mpi:
self._run_mpi()
self._run_mpi(log_lik_func)
else:
self._run_multiprocessing()
self._run_multiprocessing(log_lik_func)

def _run_mpi(self):
def _run_mpi(self, log_lik_func):
""" Run the PocoMC sampler """
mpi_comm = MPI.COMM_WORLD
num_mpi_threads = mpi_comm.Get_size()
with MPIPoolExecutor(num_mpi_threads) as pool:
self.pocomc_sampler = pocomc.Sampler(
self.prior, self.vec_log_lik, pool=pool, output_dir=self.path,
self.prior, self.vec_log_lik, likelihood_args=(log_lik_func),
pool=pool, output_dir=self.path,
dynamic=self.dynamic, precondition=self.precondition,
n_effective=self.n_effective, n_active=self.n_active,
)
self.pocomc_sampler.run(self.n_total, self.n_evidence, save_every=self.save_every)

def _run_multiprocessing(self):
def _run_multiprocessing(self, log_lik_func):
""" Run the PocoMC sampler """
with Pool(self.num_cpu) as pool:
self.pocomc_sampler = pocomc.Sampler(
self.prior, self.vec_log_lik, pool=pool, output_dir=self.path,
self.prior, self.vec_log_lik, likelihood_args=(log_lik_func),
pool=pool, output_dir=self.path,
dynamic=self.dynamic, precondition=self.precondition,
n_effective=self.n_effective, n_active=self.n_active,
)
Expand Down
2 changes: 1 addition & 1 deletion vega/samplers/polychord.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def log_lik(theta):
for i, name in enumerate(self.names):
params[name] = theta[i]

log_lik = self.log_lik(params)
log_lik = log_lik_func(params)
return log_lik, []

def prior(hypercube):
Expand Down
6 changes: 3 additions & 3 deletions vega/samplers/sampler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class Sampler:
''' Interface between Vega and the nested sampler PolyChord '''

def __init__(self, sampler_config, limits, log_lik_func):
def __init__(self, sampler_config, limits):
"""
Parameters
Expand All @@ -26,7 +26,7 @@ def __init__(self, sampler_config, limits, log_lik_func):
self.names = list(limits.keys())
self.num_params = len(limits)
self.num_derived = 0
self.log_lik = log_lik_func
# self.log_lik = log_lik_func
self.getdist_latex = sampler_config.getboolean('getdist_latex', True)

# Check limits are well defined
Expand Down Expand Up @@ -74,5 +74,5 @@ def write_parnames(self, parnames_path):
def get_sampler_settings(self, sampler_config, num_params, num_derived):
raise NotImplementedError('This method should be implemented in the child class')

def run(self):
def run(self, log_lik_func):
raise NotImplementedError('This method should be implemented in the child class')

0 comments on commit 9c0c7ae

Please sign in to comment.