|
9 | 9 | from torch.distributions import Distribution |
10 | 10 | from typing import List, Dict, Optional, Type |
11 | 11 |
|
12 | | -from .variable import Variable |
| 12 | +from .variable import Variable, ExogenousVariable |
13 | 13 | from .cpd import ParametricCPD |
14 | 14 |
|
15 | 15 |
|
@@ -49,7 +49,10 @@ def _reinitialize_with_new_param(instance, key, new_value): |
49 | 49 | if k == key: |
50 | 50 | new_dict[k] = new_value |
51 | 51 | else: |
52 | | - new_dict[k] = getattr(instance, k, None) |
| 52 | + if k == 'bias': |
| 53 | + new_dict[k] = False if instance.bias is None else True |
| 54 | + else: |
| 55 | + new_dict[k] = getattr(instance, k, None) |
53 | 56 |
|
54 | 57 | new_instance = cls(**new_dict) |
55 | 58 |
|
@@ -156,11 +159,14 @@ def _initialize_model(self, input_parametric_cpds: List[ParametricCPD]): |
156 | 159 | if concept in self.concept_to_variable: |
157 | 160 | parametric_cpd.variable = self.concept_to_variable[concept] |
158 | 161 | parametric_cpd.parents = self.concept_to_variable[concept].parents |
159 | | - new_parametrization = _reinitialize_with_new_param(parametric_cpd.parametrization, |
160 | | - 'out_features', |
161 | | - self.concept_to_variable[concept].size) |
162 | | - new_parametric_cpd = ParametricCPD(concepts=[concept], parametrization=new_parametrization) |
163 | | - self.parametric_cpds[concept] = new_parametric_cpd |
| 162 | + if not isinstance(parametric_cpd.variable, ExogenousVariable): |
| 163 | + new_parametrization = _reinitialize_with_new_param(parametric_cpd.parametrization, |
| 164 | + 'out_features', |
| 165 | + self.concept_to_variable[concept].size) |
| 166 | + new_parametric_cpd = ParametricCPD(concepts=[concept], parametrization=new_parametrization) |
| 167 | + self.parametric_cpds[concept] = new_parametric_cpd |
| 168 | + else: |
| 169 | + self.parametric_cpds[concept] = parametric_cpd |
164 | 170 |
|
165 | 171 | # ---- Parent resolution (unchanged) ---- |
166 | 172 | for var in self.variables: |
|
0 commit comments