Skip to content

Commit

Permalink
param_info is a Parameter, and this is not expected by get_name
Browse files Browse the repository at this point in the history
  • Loading branch information
djinnome committed Oct 31, 2024
1 parent 2e47acc commit e6fa616
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
23 changes: 15 additions & 8 deletions pyciemss/mira_integration/compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ def _compile_observables_mira(
def _compile_param_values_mira(
src: mira.modeling.Model,
) -> Dict[str, Union[torch.Tensor, pyro.nn.PyroParam, pyro.nn.PyroSample]]:
values = {}
for param_name in _sort_dependencies(src):
param_values = {}
sorted_dependencies = _sort_dependencies(src)
for param_name in sorted_dependencies:
param_info = src.parameters[param_name]
if param_info.placeholder:
continue
Expand All @@ -98,18 +99,24 @@ def _compile_param_values_mira(
if param_dist is None:
param_value = float(param_info.value)
else:
param_value = mira_distribution_to_pyro(param_dist, free_symbols=values)
idx = sorted_dependencies.index(param_name)
param_value = lambda self: mira_distribution_to_pyro(param_dist, {
k: getattr(self, f"persistent_{k}") for k in sorted_dependencies[:idx]
}
)

if isinstance(param_value, torch.nn.Parameter):
values[param_name] = pyro.nn.PyroParam(param_value)
elif isinstance(param_value, pyro.distributions.Distribution):
values[param_name] = pyro.sample(param_name, param_value)
param_values[param_name] = pyro.nn.PyroParam(param_value)
elif isinstance(param_value, pyro.distributions.distribution.Distribution):
param_values[param_name] = pyro.nn.PyroSample(param_value)
elif isinstance(param_value, (numbers.Number, numpy.ndarray, torch.Tensor)):
values[param_name] = torch.as_tensor(param_value, dtype=torch.float32)
param_values[param_name] = torch.as_tensor(param_value, dtype=torch.float32)
elif isinstance(param_value, Callable):
param_values[param_name] = pyro.nn.PyroSample(param_value)
else:
raise TypeError(f"Unknown parameter type: {type(param_value)}")

return values
return param_values


@eval_deriv.register(mira.modeling.Model)
Expand Down
4 changes: 4 additions & 0 deletions pyciemss/mira_integration/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ def sort_mira_dependencies(src: mira.modeling.Model) -> list:
"""
dependencies = nx.DiGraph()
for param_name, param_info in src.parameters.items():
#param_name = get_name(param_info)
#if param_info.placeholder:
# continue
param_dist = getattr(param_info, "distribution", None)

if param_dist is None:
dependencies.add_node(param_name)
else:
Expand Down

0 comments on commit e6fa616

Please sign in to comment.