Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
WuShichao committed Jul 6, 2024
1 parent ba3816d commit 28fc1b2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
7 changes: 6 additions & 1 deletion bin/inference/pycbc_inference
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,12 @@ with ctx:
if pool.is_main_process():
for fn in [sampler.checkpoint_file, sampler.backup_file]:
with loadfile(fn, 'a') as fp:
fp.write_config_file(cp)
# some models will interally modify original cp for sampling,
# such as joint_primary_marginalized, we need to save original
if hasattr(model, 'original_config'):
fp.write_config_file(model.original_config)
else:
fp.write_config_file(cp)

# Run the sampler
sampler.run()
Expand Down
26 changes: 17 additions & 9 deletions pycbc/inference/models/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,8 @@ def total_loglr(self):
margin_names_vector.remove('logw_partial')
margin_params = {}

print("self.primary_model.current_params: ", self.primary_model.current_params)

if self.static_margin_params_in_other_models:
# Due to the high precision of extrinsic parameters constrined
# by the primary model, the mismatch of wavefroms in others by
Expand All @@ -704,6 +706,7 @@ def total_loglr(self):
# waveform. Using SNR will cancel out the effect of amplitude.err
i_max_extrinsic = numpy.argmax(
numpy.abs(sh_primary) / hh_primary**0.5)
print("i_max_extrinsic: ", i_max_extrinsic)
for p in margin_names_vector:
if isinstance(self.primary_model.current_params[p],
numpy.ndarray):
Expand Down Expand Up @@ -740,12 +743,14 @@ def total_loglr(self):
# not using self.primary_model.current_params, because others_model
# may have its own static parameters
current_params_other = other_model.current_params.copy()
print("[total_loglr] current_params_other 1: ", current_params_other)
if not self.static_margin_params_in_other_models:
for i in range(nums):
current_params_other.update(
{key: value[i] if isinstance(value, numpy.ndarray)
else value for key, value in margin_params.items()})
other_model.update(**current_params_other)
print("[total_loglr] current_params_other 2: ", current_params_other)
other_model.return_sh_hh = True
sh_other, hh_other = other_model.loglr
sh_others[i] += sh_other
Expand All @@ -757,6 +762,7 @@ def total_loglr(self):
{key: value[0] if isinstance(value, numpy.ndarray)
else value for key, value in margin_params.items()})
other_model.update(**current_params_other)
print("other_model.current_params: ", other_model.current_params)
other_model.return_sh_hh = True
sh_other, hh_other = other_model.loglr
other_model.return_sh_hh = False
Expand All @@ -772,6 +778,13 @@ def total_loglr(self):
sh_total = sh_primary + sh_others
hh_total = hh_primary + hh_others

print("sh_primary: ", sh_primary)
print("hh_primary: ", hh_primary)
print("sh_others: ", sh_others)
print("hh_others: ", hh_others)
print("sh_total: ", sh_total)
print("hh_total: ", hh_total)

loglr = self.primary_model.marginalize_loglr(sh_total, hh_total)

return loglr
Expand Down Expand Up @@ -828,8 +841,9 @@ def from_config(cls, cp, **kwargs):
# we need the read from config function from the init; to prevent
# circular imports, we import it here
from pycbc.inference.models import read_from_config
# store the original config file
kwargs['original_config'] = cp
# store the original config file, here use deeocopy to avoid later
# changes of cp affect it
kwargs['original_config'] = cp.__deepcopy__(cp)
# get the submodels
kwargs['primary_lbl'] = shlex.split(cp.get('model', 'primary_model'))
kwargs['others_lbls'] = shlex.split(cp.get('model', 'other_models'))
Expand All @@ -839,9 +853,7 @@ def from_config(cls, cp, **kwargs):
submodel_lbls))
sparam_map = map_params(hpiter(cp.options('static_params'),
submodel_lbls))
print("cp: ", cp)
print("vparam_map: ", vparam_map)
print("sparam_map: ", sparam_map)

# get the acceleration label
kwargs['static_margin_params_in_other_models'] = shlex.split(
cp.get('model', 'static_margin_params_in_other_models'))
Expand Down Expand Up @@ -893,7 +905,6 @@ def from_config(cls, cp, **kwargs):
# set the static params
subcp.add_section('static_params')
for param in sparam_map[lbl]:
print("[set the static params] param: ", param)
subcp.set('static_params', param.subname,
cp.get('static_params', param.fullname))

Expand All @@ -907,7 +918,6 @@ def from_config(cls, cp, **kwargs):
subcp.add_section('variable_params')
for param in vparam_map[lbl]:
if lbl in kwargs['primary_lbl']:
print("[set the variable params] param: ", param)
# set variable_params for the primary model
subcp.set('variable_params', param.subname,
cp.get('variable_params', param.fullname))
Expand Down Expand Up @@ -971,12 +981,10 @@ def from_config(cls, cp, **kwargs):
marginalized_params.append('distance')
if primary_model.marginalize_phase:
marginalized_params.append('coa_phase')
print("marginalized_params: ", marginalized_params)

for p in primary_model.static_params.keys():
p_full = '%s__%s' % (kwargs['primary_lbl'][0], p)
if p_full not in cp['static_params']:
print("p_full: ", p_full)
cp['static_params'][p_full] = "%s" % \
primary_model.static_params[p]

Expand Down

0 comments on commit 28fc1b2

Please sign in to comment.