Skip to content

Commit

Permalink
ef-state node supporting hgf learning
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Dec 20, 2024
1 parent a260e8c commit 9c377a3
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 29 deletions.
61 changes: 43 additions & 18 deletions pyhgf/model/add_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def add_ef_state(
n_nodes: int,
node_parameters: Dict,
additional_parameters: Dict,
value_children: Tuple = (None, None),
):
"""Add exponential family state node(s) to a network."""
node_type = 3
Expand All @@ -128,41 +127,67 @@ def add_ef_state(
"learning": "generalised-filtering",
"nus": 3.0,
"xis": jnp.array([0.0, 1.0]),
"mean": 0.0,
"observed": 1,
}

node_parameters = update_parameters(
node_parameters, default_parameters, additional_parameters
)

# the size of the sufficient statistics vector of a multivariate normal
# distribution is given by d + d(d+1) / 2, where d is the dimension
d = node_parameters["dimension"]
n_suff_stats = d + d * (d + 1) // 2
node_parameters["mean"] = jnp.zeros(d) if d > 1 else 0.0
node_parameters["xis"] = jnp.ones(n_suff_stats)
if "hgf" in node_parameters["learning"]:
node_parameters["nus"] = jnp.zeros(n_suff_stats)

network = insert_nodes(
network=network,
n_nodes=n_nodes,
node_type=node_type,
node_parameters=node_parameters,
value_children=value_children,
value_children=(None, None),
)

# loop over the indexes of nodes created in the previous step
for node_idx in range(network.n_nodes - 1, network.n_nodes - n_nodes - 1, -1):

if network.attributes[node_idx]["learning"] == "generalised-filtering":

# create the sufficient statistic function and store in the side parameters
if network.attributes[node_idx]["distribution"] == "normal":
sufficient_stats_fn = Normal().sufficient_statistics
elif network.attributes[node_idx]["distribution"] == "multivariate-normal":
sufficient_stats_fn = MultivariateNormal().sufficient_statistics

network.attributes[node_idx].pop("dimension")
network.attributes[node_idx].pop("distribution")
network.attributes[node_idx].pop("learning")
# create the sufficient statistic function and store in the side parameters
if network.attributes[node_idx]["distribution"] == "normal":
sufficient_stats_fn = Normal().sufficient_statistics_from_observations
elif network.attributes[node_idx]["distribution"] == "multivariate-normal":
sufficient_stats_fn = (
MultivariateNormal().sufficient_statistics_from_observations
)
else:
raise ValueError(
"The distribution should be either 'normal' or 'multivariate-normal'."
)

# add the sufficient statistics function in the side parameters
network.additional_parameters.setdefault(node_idx, {})[
"sufficient_stats_fn"
] = sufficient_stats_fn
network.attributes[node_idx].pop("dimension")
network.attributes[node_idx].pop("distribution")

# add the sufficient statistics function in the side parameters
network.additional_parameters.setdefault(node_idx, {})[
"sufficient_stats_fn"
] = sufficient_stats_fn

if "hgf" in network.attributes[node_idx]["learning"]:

# create a collection of continuous state nodes
# to track the sufficient statistics of the implied distribution
for i in range(n_suff_stats):
network.add_nodes(value_children=node_idx)
network.add_nodes(value_children=network.n_nodes - 1)
if (
"-2" in network.attributes[node_idx]["learning"]
or "-3" in network.attributes[node_idx]["learning"]
):
network.add_nodes(volatility_children=network.n_nodes - 1)
if "-3" in network.attributes[node_idx]["learning"]:
network.add_nodes(volatility_children=network.n_nodes - 1)

return network

Expand Down
3 changes: 1 addition & 2 deletions pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,13 +428,12 @@ def add_nodes(
node_parameters=node_parameters,
additional_parameters=additional_parameters,
)
elif "ef-state" in kind:
elif kind == "ef-state":
self = add_ef_state(
network=self,
n_nodes=n_nodes,
node_parameters=node_parameters,
additional_parameters=additional_parameters,
value_children=value_children,
)
elif kind == "categorical-state":
self = add_categorical_state(
Expand Down
70 changes: 70 additions & 0 deletions pyhgf/updates/posterior/exponential.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

from functools import partial
from typing import Dict

import jax.numpy as jnp
from jax import jit

from pyhgf.typing import Attributes, Edges


@partial(jit, static_argnames=("edges", "node_idx", "sufficient_stats_fn"))
def posterior_update_exponential_family_dynamic(
    attributes: Dict, edges: Edges, node_idx: int, **args
) -> Attributes:
    r"""Recover the implied learning rates of an exponential family state node.

    This step is meant to run after the value parents of the node (and their own
    first value parent) have received their posterior update, which is why it is
    usually placed at the end of the update sequence. For every value parent, the
    implied :math:`\nu` is obtained as a ratio:

    .. math::
        \nu \leftarrow \frac{\delta}{\Delta}

    where :math:`\delta` is the prediction error of the value parent (its mean minus
    its expected mean) and :math:`\Delta` is the expectation differential (mean
    minus expected mean of the first value parent of that parent). The resulting
    vector is stored under the `"nus"` key of the node.

    Parameters
    ----------
    attributes :
        The attributes of the probabilistic nodes.
    edges :
        The edges of the probabilistic nodes as a tuple of
        :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the node
        number. For each node, the index lists the value and volatility parents and
        children.
    node_idx :
        Index of the exponential family state node whose `"nus"` vector is
        overwritten.
    **args :
        Extra keyword arguments, accepted and ignored.

    Returns
    -------
    attributes :
        The updated attributes of the probabilistic nodes.

    Notes
    -----
    No guard is applied on the denominator: a null expectation differential yields
    a non-finite value for the corresponding entry.

    References
    ----------
    .. [1] Mathys, C., & Weber, L. (2020). Hierarchical Gaussian Filtering of Sufficient
       Statistic Time Series for Active Inference. In Active Inference (pp. 52–58).
       Springer International Publishing. https://doi.org/10.1007/978-3-030-64919-7_7

    """

    def _mean_shift(idx):
        # how far the mean of node `idx` sits from what was expected
        return attributes[idx]["mean"] - attributes[idx]["expected_mean"]

    value_parents = edges[node_idx].value_parents or []

    # numerator: prediction error observed in each value parent
    prediction_errors = [_mean_shift(parent_idx) for parent_idx in value_parents]

    # denominator: shift observed one level above, in the first value parent
    # of each value parent
    expectation_differentials = [
        _mean_shift(edges[parent_idx].value_parents[0])
        for parent_idx in value_parents
    ]

    attributes[node_idx]["nus"] = jnp.array(prediction_errors) / jnp.array(
        expectation_differentials
    )

    return attributes
4 changes: 3 additions & 1 deletion pyhgf/updates/prediction/dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def dirichlet_node_prediction(
if value_parent_idxs is not None:
parameters = jnp.array(
[
Normal().parameters(xis=attributes[parent_idx]["xis"])
Normal().parameters_from_sufficient_statistics(
xis=attributes[parent_idx]["xis"]
)
for parent_idx in value_parent_idxs
]
)
Expand Down
4 changes: 2 additions & 2 deletions pyhgf/updates/prediction_error/dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def create_cluster(operands: Tuple, edges: Edges, node_idx: int) -> Attributes:
# initialize the new cluster using candidate values
attributes[value_parent_idx]["xis"] = jnp.where(
cluster_idx == i,
Normal().expected_sufficient_statistics(
mu=candidate_mean, sigma=candidate_sigma
Normal().sufficient_statistics_from_parameters(
mean=candidate_mean, variance=candidate_sigma**2
),
attributes[value_parent_idx]["xis"],
)
Expand Down
60 changes: 58 additions & 2 deletions pyhgf/updates/prediction_error/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@partial(jit, static_argnames=("edges", "node_idx", "sufficient_stats_fn"))
def prediction_error_update_exponential_family(
def prediction_error_update_exponential_family_fixed(
attributes: Dict, edges: Edges, node_idx: int, sufficient_stats_fn: Callable, **args
) -> Attributes:
r"""Update the parameters of an exponential family distribution.
Expand Down Expand Up @@ -49,7 +49,7 @@ def prediction_error_update_exponential_family(
Springer International Publishing. https://doi.org/10.1007/978-3-030-64919-7_7
"""
# update the hyperparameter vectors
# retrieve the expected sufficient statistics from new observations
xis = attributes[node_idx]["xis"] + (1 / (1 + attributes[node_idx]["nus"])) * (
sufficient_stats_fn(x=attributes[node_idx]["mean"])
- attributes[node_idx]["xis"]
Expand All @@ -61,3 +61,59 @@ def prediction_error_update_exponential_family(
)

return attributes


@partial(jit, static_argnames=("edges", "node_idx", "sufficient_stats_fn"))
def prediction_error_update_exponential_family_dynamic(
attributes: Dict, edges: Edges, node_idx: int, sufficient_stats_fn: Callable, **args
) -> Attributes:
r"""Pass the expected sufficient statistics to the implied continuous nodes.
When updating an exponential family state node without assuming that :math:`nu` is
fixed, the node convert the new observation into sufficient statistics and pass the
values to the implied continuous nodes. The new values for the vector :math:`nu`
are recovered in another posterior update, by observing the learning rate in the
continuous nodes, usually at the end of the sequence.
Parameters
----------
attributes :
The attributes of the probabilistic nodes.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the node
number. For each node, the index lists the value and volatility parents and
children.
node_idx :
Pointer to the value parent node that will be updated.
sufficient_stats_fn :
Compute the sufficient statistics of the probability distribution. This should
be one of the method implemented in the distribution class in
:py:class:`pyhgf.math.Normal`, for a univariate normal.
Returns
-------
attributes :
The updated attributes of the probabilistic nodes.
References
----------
.. [1] Mathys, C., & Weber, L. (2020). Hierarchical Gaussian Filtering of Sufficient
Statistic Time Series for Active Inference. In Active Inference (pp. 52–58).
Springer International Publishing. https://doi.org/10.1007/978-3-030-64919-7_7
"""
# retrieve the expected sufficient statistics from new observations
xis = sufficient_stats_fn(x=attributes[node_idx]["mean"])

for parent_idx, value in zip(
edges[node_idx].value_parents or [], xis or [], strict=True
):

# blank update in the case of unobserved value
attributes[parent_idx]["observed"] = attributes[node_idx]["observed"]

# pass the new value
attributes[parent_idx]["mean"] = value

return attributes
47 changes: 43 additions & 4 deletions pyhgf/utils/get_update_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
)
from pyhgf.updates.posterior.exponential import (
posterior_update_exponential_family_dynamic,
)
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction
Expand All @@ -20,7 +23,8 @@
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error
from pyhgf.updates.prediction_error.exponential import (
prediction_error_update_exponential_family,
prediction_error_update_exponential_family_dynamic,
prediction_error_update_exponential_family_fixed,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -75,6 +79,12 @@ def get_update_sequence(
)
]

# do not update continuous nodes that are parents of an ef state node
# (node_type == 3). The list is filtered in one pass instead of calling
# .remove() while iterating over it: removing during iteration skips the
# element that follows each removal, and a node with several ef state children
# would be removed twice (ValueError). The slice assignment keeps the same list
# object, so later in-place operations on it are unaffected.
nodes_without_posterior_update[:] = [
    i
    for i in nodes_without_posterior_update
    if not any(
        network.edges[child_idx].node_type == 3
        for child_idx in network.edges[i].value_children or []
    )
]

# prediction updates ---------------------------------------------------------------
while True:
no_update = True
Expand Down Expand Up @@ -166,7 +176,7 @@ def get_update_sequence(
]

# if this node has no parent, no need to compute prediction errors
# unless this is an exponential family state node
# unless this is an exponential family state node with fixed learning rate
if len(all_parents) == 0:
if network.edges[idx].node_type == 3:

Expand All @@ -180,13 +190,16 @@ def get_update_sequence(
# create the sufficient statistic function
# for the exponential family node
ef_update = Partial(
prediction_error_update_exponential_family,
prediction_error_update_exponential_family_fixed,
sufficient_stats_fn=sufficient_stats_fn,
)
update_fn = ef_update
no_update = False
update_sequence.append((idx, update_fn))
nodes_without_prediction_error.remove(idx)

network.attributes[idx].pop("learning")

else:
nodes_without_prediction_error.remove(idx)
else:
Expand All @@ -199,6 +212,29 @@ def get_update_sequence(
update_fn = binary_state_node_prediction_error
elif network.edges[idx].node_type == 2:
update_fn = continuous_node_prediction_error
elif network.edges[idx].node_type == 3:
# retrieve the desired sufficient statistics function
# from the side parameter dictionary
sufficient_stats_fn = network.additional_parameters[idx][
"sufficient_stats_fn"
]
network.additional_parameters[idx].pop("sufficient_stats_fn")
# create the sufficient statistic function
# for the exponential family node
ef_update = Partial(
prediction_error_update_exponential_family_dynamic,
sufficient_stats_fn=sufficient_stats_fn,
)
update_fn = ef_update
no_update = False
update_sequence.append((idx, update_fn))

# add the posterior update here
# this step will be moved to the end of the sequence later
update_sequence.append(
(idx, posterior_update_exponential_family_dynamic)
)
network.attributes[idx].pop("learning")
elif network.edges[idx].node_type == 4:
update_fn = dirichlet_node_prediction_error
elif network.edges[idx].node_type == 5:
Expand Down Expand Up @@ -232,7 +268,10 @@ def get_update_sequence(
# move all categorical steps to the end of the sequence
for step in update_sequence:
if not isinstance(step[1], Partial):
if step[1].__name__ == "categorical_state_update":
if step[1].__name__ in [
"posterior_update_exponential_family_dynamic",
"categorical_state_update",
]:
update_sequence.remove(step)
update_sequence.append(step)

Expand Down

0 comments on commit 9c377a3

Please sign in to comment.