Skip to content

Commit c173033

Browse files
Reinitialize the output features of endogenous variables in probabilistic model
1 parent 24ec245 commit c173033

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

torch_concepts/nn/modules/mid/models/probabilistic_model.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch.distributions import Distribution
1010
from typing import List, Dict, Optional, Type
1111

12-
from .variable import Variable
12+
from .variable import Variable, ExogenousVariable
1313
from .cpd import ParametricCPD
1414

1515

@@ -49,7 +49,10 @@ def _reinitialize_with_new_param(instance, key, new_value):
4949
if k == key:
5050
new_dict[k] = new_value
5151
else:
52-
new_dict[k] = getattr(instance, k, None)
52+
if k == 'bias':
53+
new_dict[k] = False if instance.bias is None else True
54+
else:
55+
new_dict[k] = getattr(instance, k, None)
5356

5457
new_instance = cls(**new_dict)
5558

@@ -156,11 +159,14 @@ def _initialize_model(self, input_parametric_cpds: List[ParametricCPD]):
156159
if concept in self.concept_to_variable:
157160
parametric_cpd.variable = self.concept_to_variable[concept]
158161
parametric_cpd.parents = self.concept_to_variable[concept].parents
159-
new_parametrization = _reinitialize_with_new_param(parametric_cpd.parametrization,
160-
'out_features',
161-
self.concept_to_variable[concept].size)
162-
new_parametric_cpd = ParametricCPD(concepts=[concept], parametrization=new_parametrization)
163-
self.parametric_cpds[concept] = new_parametric_cpd
162+
if not isinstance(parametric_cpd.variable, ExogenousVariable):
163+
new_parametrization = _reinitialize_with_new_param(parametric_cpd.parametrization,
164+
'out_features',
165+
self.concept_to_variable[concept].size)
166+
new_parametric_cpd = ParametricCPD(concepts=[concept], parametrization=new_parametrization)
167+
self.parametric_cpds[concept] = new_parametric_cpd
168+
else:
169+
self.parametric_cpds[concept] = parametric_cpd
164170

165171
# ---- Parent resolution (unchanged) ----
166172
for var in self.variables:

0 commit comments

Comments
 (0)