Skip to content

Commit 6497c99

Browse files
committed
Restructure code
1 parent 00a6d3e commit 6497c99

File tree

4 files changed

+182
-120
lines changed

4 files changed

+182
-120
lines changed

bin/run_vega_mpi.py

Lines changed: 114 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,113 +1,129 @@
11
#!/usr/bin/env python
22
import argparse
3-
import sys
3+
# import sys
44

5-
from mpi4py import MPI
5+
# from mpi4py import MPI
66

7-
from vega import VegaInterface
7+
# from vega import VegaInterface
88

99
if __name__ == '__main__':
1010
pars = argparse.ArgumentParser(
1111
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
1212
description='Run Vega in parallel.')
1313

14-
pars.add_argument('config', type=str, default=None, help='Config file')
14+
pars.add_argument('config', type=str, required=True, help='Config file')
15+
pars.add_argument(
16+
'-s', '--sampler', type=str, default='Polychord', required=False,
17+
choices=['Polychord', 'PocoMC'], help='Sampler to use'
18+
)
1519
args = pars.parse_args()
1620

17-
mpi_comm = MPI.COMM_WORLD
18-
cpu_rank = mpi_comm.Get_rank()
19-
20-
def print_func(message):
21-
if cpu_rank == 0:
22-
print(message)
23-
sys.stdout.flush()
24-
mpi_comm.barrier()
25-
26-
print_func('Initializing Vega')
27-
28-
# Initialize Vega and get the sampling parameters
29-
vega = VegaInterface(args.config)
30-
sampling_params = vega.sample_params['limits']
31-
32-
print_func('Finished initializing Vega')
33-
34-
# Check if we need the distortion
35-
use_distortion = vega.main_config['control'].getboolean('use_distortion', True)
36-
if not use_distortion:
37-
for key, data in vega.data.items():
38-
data._distortion_mat = None
39-
test_model = vega.compute_model(vega.params, run_init=True)
40-
41-
# Check if we need to run over a Monte Carlo mock
42-
run_montecarlo = vega.main_config['control'].getboolean('run_montecarlo', False)
43-
if run_montecarlo and vega.mc_config is not None:
44-
# Get the MC seed and forecast flag
45-
seed = vega.main_config['control'].getfloat('mc_seed', 0)
46-
forecast = vega.main_config['control'].getboolean('forecast', False)
47-
48-
# Create the mocks
49-
vega.monte_carlo_sim(vega.mc_config['params'], seed=seed, forecast=forecast)
50-
51-
# Set to sample the MC params
52-
sampling_params = vega.mc_config['sample']['limits']
53-
print_func('Created Monte Carlo realization of the correlation')
54-
elif run_montecarlo:
55-
raise ValueError('You asked to run over a Monte Carlo simulation,'
56-
' but no "[monte carlo]" section provided.')
57-
58-
# Run sampler
59-
if not vega.run_sampler:
60-
raise ValueError('Warning: You called "run_vega_mpi.py" without asking'
61-
' for the sampler. Add "run_sampler = True" to the "[control]" section.')
62-
63-
if vega.sampler == 'Polychord':
21+
if args.sampler == 'Polychord':
6422
from vega.samplers.polychord import Polychord
6523

66-
print_func('Running Polychord')
67-
sampler = Polychord(vega.main_config['Polychord'], sampling_params)
68-
sampler.run(vega.log_lik)
69-
70-
elif vega.sampler == 'PocoMC':
24+
print('Running Polychord')
25+
sampler = Polychord(args.config)
26+
sampler.run()
27+
elif args.sampler == 'PocoMC':
7128
from vega.samplers.pocomc import PocoMC
72-
import pocomc
73-
from multiprocessing import Pool
74-
from schwimmbad import MPIPool
75-
76-
print_func('Running PocoMC')
77-
sampler = PocoMC(vega.main_config['PocoMC'], sampling_params)
78-
# sampler.run(vega.log_lik)
79-
80-
if not sampler.use_mpi:
81-
assert False
82-
83-
def log_lik(theta):
84-
params = {name: val for name, val in zip(sampler.names, theta)}
85-
return vega.log_lik(params)
86-
87-
mpi_comm.barrier()
88-
with MPIPool(mpi_comm) as pool:
89-
sampler.pocomc_sampler = pocomc.Sampler(
90-
sampler.prior, log_lik,
91-
pool=pool, output_dir=sampler.path,
92-
dynamic=sampler.dynamic, precondition=sampler.precondition,
93-
n_effective=sampler.n_effective, n_active=sampler.n_active,
94-
)
95-
sampler.pocomc_sampler.run(sampler.n_total, sampler.n_evidence, save_every=sampler.save_every)
96-
97-
# with Pool(sampler.num_cpu) as pool:
98-
# sampler.pocomc_sampler = pocomc.Sampler(
99-
# sampler.prior, log_lik,
100-
# pool=pool, output_dir=sampler.path,
101-
# dynamic=sampler.dynamic, precondition=sampler.precondition,
102-
# n_effective=sampler.n_effective, n_active=sampler.n_active,
103-
# )
104-
# sampler.pocomc_sampler.run(sampler.n_total, sampler.n_evidence, save_every=sampler.save_every)
105-
106-
# sampler.run(vega.log_lik)
107-
# mpi_comm.barrier()
108-
109-
if cpu_rank == 0:
110-
sampler.write_chain()
111-
# mpi_comm.barrier()
112-
113-
print(f'CPU #{cpu_rank}: Finished running sampler')
29+
30+
print('Running PocoMC')
31+
sampler = PocoMC(args.config)
32+
sampler.run()
33+
34+
# mpi_comm = MPI.COMM_WORLD
35+
# cpu_rank = mpi_comm.Get_rank()
36+
37+
# def print_func(message):
38+
# if cpu_rank == 0:
39+
# print(message)
40+
# sys.stdout.flush()
41+
42+
# print_func('Initializing Vega')
43+
44+
# # Initialize Vega and get the sampling parameters
45+
# vega = VegaInterface(args.config)
46+
# sampling_params = vega.sample_params['limits']
47+
48+
# print_func('Finished initializing Vega')
49+
50+
# # Check if we need the distortion
51+
# use_distortion = vega.main_config['control'].getboolean('use_distortion', True)
52+
# if not use_distortion:
53+
# for key, data in vega.data.items():
54+
# data._distortion_mat = None
55+
# test_model = vega.compute_model(vega.params, run_init=True)
56+
57+
# # Check if we need to run over a Monte Carlo mock
58+
# run_montecarlo = vega.main_config['control'].getboolean('run_montecarlo', False)
59+
# if run_montecarlo and vega.mc_config is not None:
60+
# # Get the MC seed and forecast flag
61+
# seed = vega.main_config['control'].getfloat('mc_seed', 0)
62+
# forecast = vega.main_config['control'].getboolean('forecast', False)
63+
64+
# # Create the mocks
65+
# vega.monte_carlo_sim(vega.mc_config['params'], seed=seed, forecast=forecast)
66+
67+
# # Set to sample the MC params
68+
# sampling_params = vega.mc_config['sample']['limits']
69+
# print_func('Created Monte Carlo realization of the correlation')
70+
# elif run_montecarlo:
71+
# raise ValueError('You asked to run over a Monte Carlo simulation,'
72+
# ' but no "[monte carlo]" section provided.')
73+
74+
# # Run sampler
75+
# if not vega.run_sampler:
76+
# raise ValueError('Warning: You called "run_vega_mpi.py" without asking'
77+
# ' for the sampler. Add "run_sampler = True" to the "[control]" section.')
78+
79+
# if vega.sampler == 'Polychord':
80+
# from vega.samplers.polychord import Polychord
81+
82+
# print_func('Running Polychord')
83+
# sampler = Polychord(vega.main_config['Polychord'], sampling_params)
84+
# sampler.run(vega.log_lik)
85+
86+
# elif vega.sampler == 'PocoMC':
87+
# from vega.samplers.pocomc import PocoMC
88+
# import pocomc
89+
# from multiprocessing import Pool
90+
# from schwimmbad import MPIPool
91+
92+
# print_func('Running PocoMC')
93+
# sampler = PocoMC(vega.main_config['PocoMC'], sampling_params)
94+
# # sampler.run(vega.log_lik)
95+
96+
# if not sampler.use_mpi:
97+
# assert False
98+
99+
# def log_lik(theta):
100+
# params = {name: val for name, val in zip(sampler.names, theta)}
101+
# return vega.log_lik(params)
102+
103+
# mpi_comm.barrier()
104+
# with MPIPool(mpi_comm) as pool:
105+
# sampler.pocomc_sampler = pocomc.Sampler(
106+
# sampler.prior, log_lik,
107+
# pool=pool, output_dir=sampler.path,
108+
# dynamic=sampler.dynamic, precondition=sampler.precondition,
109+
# n_effective=sampler.n_effective, n_active=sampler.n_active,
110+
# )
111+
# sampler.pocomc_sampler.run(sampler.n_total, sampler.n_evidence, save_every=sampler.save_every)
112+
113+
# # with Pool(sampler.num_cpu) as pool:
114+
# # sampler.pocomc_sampler = pocomc.Sampler(
115+
# # sampler.prior, log_lik,
116+
# # pool=pool, output_dir=sampler.path,
117+
# # dynamic=sampler.dynamic, precondition=sampler.precondition,
118+
# # n_effective=sampler.n_effective, n_active=sampler.n_active,
119+
# # )
120+
# # sampler.pocomc_sampler.run(sampler.n_total, sampler.n_evidence, save_every=sampler.save_every)
121+
122+
# # sampler.run(vega.log_lik)
123+
# # mpi_comm.barrier()
124+
125+
# if cpu_rank == 0:
126+
# sampler.write_chain()
127+
# # mpi_comm.barrier()
128+
129+
# print(f'CPU #{cpu_rank}: Finished running sampler')

vega/samplers/pocomc.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
class PocoMC(Sampler):
1515
""" Interface between Vega and the PocoMC sampler """
1616

17-
def __init__(self, pocomc_setup, limits):
18-
super().__init__(pocomc_setup, limits)
17+
def __init__(self, args):
18+
super().__init__(args)
1919

2020
def get_sampler_settings(self, sampler_config, num_params, num_derived):
2121
# Initialize the pocomc settings
@@ -35,34 +35,33 @@ def get_sampler_settings(self, sampler_config, num_params, num_derived):
3535
for par in self.limits]
3636
)
3737

38-
def vec_log_lik(self, theta, log_lik_func):
38+
def vec_log_lik(self, theta):
3939
params = {name: val for name, val in zip(self.names, theta)}
40-
return log_lik_func(params)
40+
return self.vega.log_lik(params)
4141

42-
def run(self, log_lik_func):
42+
def run(self):
4343
if self.use_mpi:
44-
self._run_mpi(log_lik_func)
44+
self._run_mpi()
4545
else:
46-
self._run_multiprocessing(log_lik_func)
46+
self._run_multiprocessing()
4747

48-
def _run_mpi(self, log_lik_func):
48+
self.print_func('Finished running sampler')
49+
50+
def _run_mpi(self):
4951
""" Run the PocoMC sampler """
50-
mpi_comm = MPI.COMM_WORLD
51-
with MPIPool(mpi_comm) as pool:
52+
with MPIPool(self.mpi_comm) as pool:
5253
self.pocomc_sampler = pocomc.Sampler(
53-
self.prior, self.vec_log_lik, likelihood_args=(log_lik_func),
54-
pool=pool, output_dir=self.path,
54+
self.prior, self.vec_log_lik, pool=pool, output_dir=self.path,
5555
dynamic=self.dynamic, precondition=self.precondition,
5656
n_effective=self.n_effective, n_active=self.n_active,
5757
)
5858
self.pocomc_sampler.run(self.n_total, self.n_evidence, save_every=self.save_every)
5959

60-
def _run_multiprocessing(self, log_lik_func):
60+
def _run_multiprocessing(self):
6161
""" Run the PocoMC sampler """
6262
with Pool(self.num_cpu) as pool:
6363
self.pocomc_sampler = pocomc.Sampler(
64-
self.prior, self.vec_log_lik, likelihood_args=(log_lik_func),
65-
pool=pool, output_dir=self.path,
64+
self.prior, self.vec_log_lik, pool=pool, output_dir=self.path,
6665
dynamic=self.dynamic, precondition=self.precondition,
6766
n_effective=self.n_effective, n_active=self.n_active,
6867
)

vega/samplers/polychord.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
class Polychord(Sampler):
99
''' Interface between Vega and the nested sampler PolyChord '''
1010

11-
def __init__(self, polychord_setup, limits, log_lik_func):
12-
super().__init__(polychord_setup, limits, log_lik_func)
11+
def __init__(self, args):
12+
super().__init__(args)
1313

1414
def get_sampler_settings(self, sampler_config, num_params, num_derived):
1515
"""Extract polychord settings and create the settings object.
@@ -96,7 +96,7 @@ def log_lik(theta):
9696
for i, name in enumerate(self.names):
9797
params[name] = theta[i]
9898

99-
log_lik = log_lik_func(params)
99+
log_lik = self.vega.log_lik(params)
100100
return log_lik, []
101101

102102
def prior(hypercube):

vega/samplers/sampler_interface.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44

55
from mpi4py import MPI
66

7+
from vega import VegaInterface
78
from vega.parameters.param_utils import build_names
89

910

1011
class Sampler:
1112
''' Interface between Vega and the nested sampler PolyChord '''
1213

13-
def __init__(self, sampler_config, limits):
14+
def __init__(self, args):
1415
"""
1516
1617
Parameters
@@ -22,11 +23,52 @@ def __init__(self, sampler_config, limits):
2223
log_lik_func : f(params)
2324
Log Likelihood function to be passed to Polychord
2425
"""
25-
self.limits = limits
26-
self.names = list(limits.keys())
27-
self.num_params = len(limits)
26+
self.mpi_comm = MPI.COMM_WORLD
27+
self.cpu_rank = self.mpi_comm.Get_rank()
28+
29+
print_func('Initializing Vega', self.cpu_rank)
30+
31+
# Initialize Vega and get the sampling parameters
32+
self.vega = VegaInterface(args.config)
33+
sampling_params = self.vega.sample_params['limits']
34+
35+
print_func('Finished initializing Vega', self.cpu_rank)
36+
37+
# Check if we need to run over a Monte Carlo mock
38+
run_montecarlo = self.vega.main_config['control'].getboolean('run_montecarlo', False)
39+
if run_montecarlo and self.vega.mc_config is not None:
40+
# Get the MC seed and forecast flag
41+
seed = self.vega.main_config['control'].getint('mc_seed', 0)
42+
forecast = self.vega.main_config['control'].getboolean('forecast', False)
43+
44+
# Create the mocks
45+
self.vega.monte_carlo_sim(self.vega.mc_config['params'], seed=seed, forecast=forecast)
46+
47+
# Set to sample the MC params
48+
sampling_params = self.vega.mc_config['sample']['limits']
49+
print_func('Created Monte Carlo realization of the correlation', self.cpu_rank)
50+
elif run_montecarlo:
51+
raise ValueError('You asked to run over a Monte Carlo simulation,'
52+
' but no "[monte carlo]" section provided.')
53+
54+
# Run sampler
55+
if not self.vega.run_sampler:
56+
raise ValueError(
57+
'Warning: You called "run_vega_mpi.py" without asking'
58+
' for the sampler. Add "run_sampler = True" to the "[control]" section.'
59+
)
60+
61+
self.limits = sampling_params
62+
self.names = list(sampling_params.keys())
63+
self.num_params = len(sampling_params)
2864
self.num_derived = 0
2965
# self.log_lik = log_lik_func
66+
67+
if self.vega.sampler == 'Polychord':
68+
sampler_config = self.vega.main_config['Polychord']
69+
elif self.vega.sampler == 'PocoMC':
70+
sampler_config = self.vega.main_config['PocoMC']
71+
3072
self.getdist_latex = sampler_config.getboolean('getdist_latex', True)
3173

3274
# Check limits are well defined
@@ -52,6 +94,11 @@ def __init__(self, sampler_config, limits):
5294
# Initialize the sampler settings
5395
self.get_sampler_settings(sampler_config, self.num_params, self.num_derived)
5496

97+
def print_func(self, message):
98+
if self.cpu_rank == 0:
99+
print(message)
100+
sys.stdout.flush()
101+
55102
def write_parnames(self, parnames_path):
56103
mpi_comm = MPI.COMM_WORLD
57104
cpu_rank = mpi_comm.Get_rank()

0 commit comments

Comments
 (0)