Skip to content

Commit

Permalink
store primary model label
Browse files Browse the repository at this point in the history
  • Loading branch information
WuShichao committed Feb 21, 2024
1 parent 6ebb2c5 commit 8421326
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions pycbc/inference/models/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ def __init__(self, variable_params, submodels, **kwargs):

# assume the ground-based submodel as the primary model
self.primary_model = self.submodels[kwargs['primary_lbl'][0]]
self.primary_lbl = kwargs['primary_lbl'][0]
self.other_models = self.submodels.copy()
self.other_models.pop(kwargs['primary_lbl'][0])
self.other_models = list(self.other_models.values())
Expand Down Expand Up @@ -736,10 +737,9 @@ def _loglikelihood(self):
# update parameters in primary_model,
# other_models will be updated in total_loglr,
# because other_models need to handle margin_params
for lbl, _ in enumerate(self.primary_model):
self.primary_model.update(
**{p.subname: self.current_params[p.fullname]
for p in self.param_map[lbl]})
self.primary_model.update(
**{p.subname: self.current_params[p.fullname]
for p in self.param_map[self.primary_lbl]})

# calculate the combined loglikelihood
logl = self.total_loglr() + self.primary_model.lognl + \
Expand Down Expand Up @@ -929,9 +929,8 @@ def get_loglr():
# update parameters in primary_model,
# other_models will be updated in total_loglr,
# because other_models need to handle margin_params
for lbl, _ in enumerate(self.primary_model):
p = {param.subname: self.current_params[param.fullname]
for param in self.param_map[lbl]}
p = {param.subname: self.current_params[param.fullname]
for param in self.param_map[self.primary_lbl]}
p.update(rec)
self.primary_model.update(**p)
return self.total_loglr()
Expand Down

0 comments on commit 8421326

Please sign in to comment.