From e4b4b6e86f5b4bba9fb3258696abebad950a9bfe Mon Sep 17 00:00:00 2001 From: Andrei Cuceu Date: Sun, 9 Jun 2024 20:20:28 -0400 Subject: [PATCH] Fix pocomc wrong argument --- vega/samplers/pocomc.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vega/samplers/pocomc.py b/vega/samplers/pocomc.py index b7beaee..b771f44 100644 --- a/vega/samplers/pocomc.py +++ b/vega/samplers/pocomc.py @@ -1,7 +1,7 @@ from pathlib import Path import numpy as np -import pocomc as pc +import pocomc from mpi4py import MPI from mpi4py.futures import MPIPoolExecutor from scipy.stats import uniform @@ -25,7 +25,7 @@ def get_sampler_settings(self, sampler_config, num_params, num_derived): self.n_evidence = sampler_config.getint('n_evidence', 0) self.save_every = sampler_config.getint('save_every', 3) - self.prior = pc.Prior( + self.prior = pocomc.Prior( [uniform(self.limits[par][0], self.limits[par][1]-self.limits[par][0]) for par in self.limits] ) @@ -35,13 +35,12 @@ def run(self): mpi_comm = MPI.COMM_WORLD num_mpi_threads = mpi_comm.Get_size() with MPIPoolExecutor(num_mpi_threads) as pool: - self.pocomc_sampler = pc.Sampler( - self.prior, self.log_lik, pool=pool, - output_dir=self.path, save_every=self.save_every, + self.pocomc_sampler = pocomc.Sampler( + self.prior, self.log_lik, pool=pool, output_dir=self.path, dynamic=self.dynamic, precondition=self.precondition, n_effective=self.n_effective, n_active=self.n_active, ) - self.pocomc_sampler.run(self.n_total, self.n_evidence) + self.pocomc_sampler.run(self.n_total, self.n_evidence, save_every=self.save_every) def write_chain(self): # Get the weighted posterior samples