From 842132679ad28d839cb512d8aa4924d82ed75d61 Mon Sep 17 00:00:00 2001
From: WuShichao
Date: Wed, 21 Feb 2024 14:50:19 +0100
Subject: [PATCH] store primary model label

---
 pycbc/inference/models/hierarchical.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/pycbc/inference/models/hierarchical.py b/pycbc/inference/models/hierarchical.py
index d24a12a71d9..75e3a93271a 100644
--- a/pycbc/inference/models/hierarchical.py
+++ b/pycbc/inference/models/hierarchical.py
@@ -633,6 +633,7 @@ def __init__(self, variable_params, submodels, **kwargs):
 
         # assume the ground-based submodel as the primary model
         self.primary_model = self.submodels[kwargs['primary_lbl'][0]]
+        self.primary_lbl = kwargs['primary_lbl'][0]
         self.other_models = self.submodels.copy()
         self.other_models.pop(kwargs['primary_lbl'][0])
         self.other_models = list(self.other_models.values())
@@ -736,10 +737,9 @@ def _loglikelihood(self):
         # update parameters in primary_model,
         # other_models will be updated in total_loglr,
         # because other_models need to handle margin_params
-        for lbl, _ in enumerate(self.primary_model):
-            self.primary_model.update(
-                **{p.subname: self.current_params[p.fullname]
-                   for p in self.param_map[lbl]})
+        self.primary_model.update(
+            **{p.subname: self.current_params[p.fullname]
+               for p in self.param_map[self.primary_lbl]})
 
         # calculate the combined loglikelihood
         logl = self.total_loglr() + self.primary_model.lognl + \
@@ -929,9 +929,8 @@ def get_loglr():
             # update parameters in primary_model,
             # other_models will be updated in total_loglr,
             # because other_models need to handle margin_params
-            for lbl, _ in enumerate(self.primary_model):
-                p = {param.subname: self.current_params[param.fullname]
-                     for param in self.param_map[lbl]}
+            p = {param.subname: self.current_params[param.fullname]
+                 for param in self.param_map[self.primary_lbl]}
             p.update(rec)
             self.primary_model.update(**p)
             return self.total_loglr()
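
Note for reviewers: the patch stores the primary submodel's label once in __init__ and uses it to index param_map directly, instead of looping with enumerate(self.primary_model), where the integer loop index was used as a param_map key. The standalone sketch below illustrates the keyed-by-label update pattern the patch settles on; the Param class, the 'gw' label, and the parameter values are made-up placeholders for illustration, not the actual pycbc objects.

    # Minimal sketch, assuming param_map is a dict keyed by submodel label
    # and each entry lists parameter objects with fullname/subname attributes.
    class Param:
        """Stand-in for pycbc's hierarchical parameter objects."""
        def __init__(self, fullname, subname):
            self.fullname = fullname   # name in the top-level model, e.g. 'gw__mass1'
            self.subname = subname     # name inside the submodel, e.g. 'mass1'

    # hypothetical primary submodel label and parameter map
    primary_lbl = 'gw'                 # stored once, as this patch does in __init__
    param_map = {'gw': [Param('gw__mass1', 'mass1'),
                        Param('gw__mass2', 'mass2')]}
    current_params = {'gw__mass1': 30.0, 'gw__mass2': 25.0}

    # build the keyword dict that would be passed to primary_model.update(**...)
    update_kwargs = {p.subname: current_params[p.fullname]
                     for p in param_map[primary_lbl]}
    print(update_kwargs)               # {'mass1': 30.0, 'mass2': 25.0}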