Skip to content

Commit

Permalink
Generalize sampler interface and add support for pocomc
Browse files Browse the repository at this point in the history
  • Loading branch information
andreicuceu committed Jun 10, 2024
1 parent 6f26d0e commit 753b192
Show file tree
Hide file tree
Showing 7 changed files with 288 additions and 194 deletions.
35 changes: 25 additions & 10 deletions bin/run_vega_mpi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/usr/bin/env python
from vega import VegaInterface
from vega.sampler_interface import Sampler
from mpi4py import MPI
import argparse
import sys

from mpi4py import MPI

from vega import VegaInterface

if __name__ == '__main__':
pars = argparse.ArgumentParser(
Expand Down Expand Up @@ -56,12 +56,27 @@ def print_func(message):
' but no "[monte carlo]" section provided.')

# Run sampler
if vega.has_sampler:
print_func('Running the sampler')
sampler = Sampler(vega.main_config['Polychord'],
sampling_params, vega.log_lik)
if not vega.run_sampler:
raise ValueError('Warning: You called "run_vega_mpi.py" without asking'
' for the sampler. Add "run_sampler = True" to the "[control]" section.')

if vega.sampler == 'Polychord':
from vega.samplers.polychord import Polychord

print_func('Running Polychord')
sampler = Polychord(vega.main_config['Polychord'], sampling_params, vega.log_lik)
sampler.run()
else:
raise ValueError('Warning: You called "run_vega_mpi.py" without asking'
' for the sampler. Add "sampler = True" to the "[control]" section.')
elif vega.sampler == 'PocoMC':
from vega.samplers.pocomc import PocoMC

print_func('Running PocoMC')
sampler = PocoMC(vega.main_config['PocoMC'], sampling_params, vega.log_lik)
mpi_comm.barrier()
sampler.run()
mpi_comm.barrier()

if cpu_rank == 0:
sampler.write_chain()
mpi_comm.barrier()

print_func('Finished running sampler')
179 changes: 0 additions & 179 deletions vega/sampler_interface.py

This file was deleted.

Empty file added vega/samplers/__init__.py
Empty file.
63 changes: 63 additions & 0 deletions vega/samplers/pocomc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from pathlib import Path

import numpy as np
import pocomc as pc
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from scipy.stats import uniform

from vega.samplers.sampler_interface import Sampler


class PocoMC(Sampler):
""" Interface between Vega and the PocoMC sampler """

def __init__(self, pocomc_setup, limits, log_lik_func):
super().__init__(pocomc_setup, limits, log_lik_func)

def get_sampler_settings(self, sampler_config, num_params, num_derived):
# Initialize the pocomc settings
self.precondition = sampler_config.getboolean('precondition', True)
self.dynamic = sampler_config.getboolean('dynamic', True)
self.n_effective = sampler_config.getint('n_effective', 512)
self.n_active = sampler_config.getint('n_active', 256)
self.n_total = sampler_config.getint('n_total', 1024)
self.n_evidence = sampler_config.getint('n_evidence', 0)
self.save_every = sampler_config.getint('save_every', 3)

self.prior = pc.Prior(
[uniform(self.limits[par][0], self.limits[par][1]-self.limits[par][0])
for par in self.limits]
)

def run(self):
""" Run the PocoMC sampler """
mpi_comm = MPI.COMM_WORLD
num_mpi_threads = mpi_comm.Get_size()
with MPIPoolExecutor(num_mpi_threads) as pool:
self.pocomc_sampler = pc.Sampler(
self.prior, self.log_lik, pool=pool,
output_dir=self.path, save_every=self.save_every,
dynamic=self.dynamic, precondition=self.precondition,
n_effective=self.n_effective, n_active=self.n_active,
)
self.pocomc_sampler.run(self.n_total, self.n_evidence)

def write_chain(self):
# Get the weighted posterior samples
samples, weights, logl, logp = self.pocomc_sampler.posterior()

# Write the chain
chain_path = Path(self.path) / (self.name + '.txt')
chain = np.column_stack((weights, logl, samples))
print(f'Writing chain to {chain_path}')
np.savetxt(chain_path, chain, header='Weights, Log Likelihood, ' + ', '.join(self.names))

# Write stats
stats_path = Path(self.path) / (self.name + '.stats')
stats = np.column_stack((weights, logl, logp))
np.savetxt(stats_path, stats, header='Weights, Log Likelihood, Log Prior')

# Print Evidence
logZ, logZerr = self.pocomc_sampler.evidence()
print(f'log(Z) = {logZ} +/- {logZerr}')
Loading

0 comments on commit 753b192

Please sign in to comment.