Skip to content

Commit 2d59ac2

Browse files
committed
Test direct running
1 parent f3eca9b commit 2d59ac2

File tree

1 file changed

+93
-74
lines changed

1 file changed

+93
-74
lines changed

bin/run_vega_mpi.py

Lines changed: 93 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,77 +6,96 @@
66

77
from vega import VegaInterface
88

9-
# if __name__ == '__main__':
10-
pars = argparse.ArgumentParser(
11-
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
12-
description='Run Vega in parallel.')
13-
14-
pars.add_argument('config', type=str, default=None, help='Config file')
15-
args = pars.parse_args()
16-
17-
mpi_comm = MPI.COMM_WORLD
18-
cpu_rank = mpi_comm.Get_rank()
19-
20-
def print_func(message):
21-
if cpu_rank == 0:
22-
print(message)
23-
sys.stdout.flush()
24-
mpi_comm.barrier()
25-
26-
print_func('Initializing Vega')
27-
28-
# Initialize Vega and get the sampling parameters
29-
vega = VegaInterface(args.config)
30-
sampling_params = vega.sample_params['limits']
31-
32-
print_func('Finished initializing Vega')
33-
34-
# Check if we need the distortion
35-
use_distortion = vega.main_config['control'].getboolean('use_distortion', True)
36-
if not use_distortion:
37-
for key, data in vega.data.items():
38-
data._distortion_mat = None
39-
test_model = vega.compute_model(vega.params, run_init=True)
40-
41-
# Check if we need to run over a Monte Carlo mock
42-
run_montecarlo = vega.main_config['control'].getboolean('run_montecarlo', False)
43-
if run_montecarlo and vega.mc_config is not None:
44-
# Get the MC seed and forecast flag
45-
seed = vega.main_config['control'].getfloat('mc_seed', 0)
46-
forecast = vega.main_config['control'].getboolean('forecast', False)
47-
48-
# Create the mocks
49-
vega.monte_carlo_sim(vega.mc_config['params'], seed=seed, forecast=forecast)
50-
51-
# Set to sample the MC params
52-
sampling_params = vega.mc_config['sample']['limits']
53-
print_func('Created Monte Carlo realization of the correlation')
54-
elif run_montecarlo:
55-
raise ValueError('You asked to run over a Monte Carlo simulation,'
56-
' but no "[monte carlo]" section provided.')
57-
58-
# Run sampler
59-
if not vega.run_sampler:
60-
raise ValueError('Warning: You called "run_vega_mpi.py" without asking'
61-
' for the sampler. Add "run_sampler = True" to the "[control]" section.')
62-
63-
if vega.sampler == 'Polychord':
64-
from vega.samplers.polychord import Polychord
65-
66-
print_func('Running Polychord')
67-
sampler = Polychord(vega.main_config['Polychord'], sampling_params)
68-
sampler.run(vega.log_lik)
69-
elif vega.sampler == 'PocoMC':
70-
from vega.samplers.pocomc import PocoMC
71-
72-
print_func('Running PocoMC')
73-
sampler = PocoMC(vega.main_config['PocoMC'], sampling_params)
74-
mpi_comm.barrier()
75-
sampler.run(vega.log_lik)
76-
mpi_comm.barrier()
77-
78-
if cpu_rank == 0:
79-
sampler.write_chain()
80-
mpi_comm.barrier()
81-
82-
print_func('Finished running sampler')
9+
if __name__ == '__main__':
10+
pars = argparse.ArgumentParser(
11+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
12+
description='Run Vega in parallel.')
13+
14+
pars.add_argument('config', type=str, default=None, help='Config file')
15+
args = pars.parse_args()
16+
17+
mpi_comm = MPI.COMM_WORLD
18+
cpu_rank = mpi_comm.Get_rank()
19+
20+
def print_func(message):
21+
if cpu_rank == 0:
22+
print(message)
23+
sys.stdout.flush()
24+
mpi_comm.barrier()
25+
26+
print_func('Initializing Vega')
27+
28+
# Initialize Vega and get the sampling parameters
29+
vega = VegaInterface(args.config)
30+
sampling_params = vega.sample_params['limits']
31+
32+
print_func('Finished initializing Vega')
33+
34+
# Check if we need the distortion
35+
use_distortion = vega.main_config['control'].getboolean('use_distortion', True)
36+
if not use_distortion:
37+
for key, data in vega.data.items():
38+
data._distortion_mat = None
39+
test_model = vega.compute_model(vega.params, run_init=True)
40+
41+
# Check if we need to run over a Monte Carlo mock
42+
run_montecarlo = vega.main_config['control'].getboolean('run_montecarlo', False)
43+
if run_montecarlo and vega.mc_config is not None:
44+
# Get the MC seed and forecast flag
45+
seed = vega.main_config['control'].getfloat('mc_seed', 0)
46+
forecast = vega.main_config['control'].getboolean('forecast', False)
47+
48+
# Create the mocks
49+
vega.monte_carlo_sim(vega.mc_config['params'], seed=seed, forecast=forecast)
50+
51+
# Set to sample the MC params
52+
sampling_params = vega.mc_config['sample']['limits']
53+
print_func('Created Monte Carlo realization of the correlation')
54+
elif run_montecarlo:
55+
raise ValueError('You asked to run over a Monte Carlo simulation,'
56+
' but no "[monte carlo]" section provided.')
57+
58+
# Run sampler
59+
if not vega.run_sampler:
60+
raise ValueError('Warning: You called "run_vega_mpi.py" without asking'
61+
' for the sampler. Add "run_sampler = True" to the "[control]" section.')
62+
63+
if vega.sampler == 'Polychord':
64+
from vega.samplers.polychord import Polychord
65+
66+
print_func('Running Polychord')
67+
sampler = Polychord(vega.main_config['Polychord'], sampling_params)
68+
sampler.run(vega.log_lik)
69+
70+
elif vega.sampler == 'PocoMC':
71+
from vega.samplers.pocomc import PocoMC
72+
import pocomc
73+
from multiprocessing import Pool
74+
75+
print_func('Running PocoMC')
76+
sampler = PocoMC(vega.main_config['PocoMC'], sampling_params)
77+
if sampler.use_mpi:
78+
assert False
79+
80+
def log_lik(theta):
81+
params = {name: val for name, val in zip(sampler.names, theta)}
82+
return vega.log_lik(params)
83+
84+
mpi_comm.barrier()
85+
with Pool(sampler.num_cpu) as pool:
86+
sampler.pocomc_sampler = pocomc.Sampler(
87+
sampler.prior, log_lik,
88+
pool=pool, output_dir=sampler.path,
89+
dynamic=sampler.dynamic, precondition=sampler.precondition,
90+
n_effective=sampler.n_effective, n_active=sampler.n_active,
91+
)
92+
sampler.pocomc_sampler.run(sampler.n_total, sampler.n_evidence, save_every=sampler.save_every)
93+
94+
# sampler.run(vega.log_lik)
95+
mpi_comm.barrier()
96+
97+
if cpu_rank == 0:
98+
sampler.write_chain()
99+
mpi_comm.barrier()
100+
101+
print_func('Finished running sampler')

0 commit comments

Comments
 (0)