Commit
Merge branch 'ComputationalPsychiatry:master' into predict
SylvainEstebe authored Jan 16, 2025
2 parents 312067e + c03a551 commit 70e3370
Showing 43 changed files with 4,718 additions and 3,880 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/test.yml
@@ -57,10 +57,10 @@ jobs:
# Step 8: Run Tests and Generate Coverage Report
- name: Run tests and coverage
run: |
poetry run pytest ./tests/ --cov=./src/pyhgf/ --cov-report=xml
poetry run pytest ./tests/ --cov=./pyhgf/ --cov-report=xml --cov-branch
# Step 9: Upload Coverage Report to Codecov
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }} # Make sure to set this token in repository secrets
token: ${{ secrets.CODECOV_TOKEN }}
2 changes: 1 addition & 1 deletion .gitignore
@@ -8,7 +8,7 @@ docs/source/generated/*
docs/source/auto_examples/*
htmlcov/
.coverage
.coverage.xml
coverage.xml
pyhgf.egg-info
build
coverage.xm
13 changes: 12 additions & 1 deletion docs/source/api.rst
@@ -46,6 +46,16 @@ Continuous nodes
continuous_node_posterior_update
continuous_node_posterior_update_ehgf

Exponential family
------------------

.. currentmodule:: pyhgf.updates.posterior.exponential

.. autosummary::
:toctree: generated/pyhgf.updates.posterior.exponential

posterior_update_exponential_family_dynamic

Prediction steps
================

@@ -144,7 +154,8 @@ Exponential family
.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.exponential

prediction_error_update_exponential_family
prediction_error_update_exponential_family_fixed
prediction_error_update_exponential_family_dynamic

Distribution
************
Binary file removed docs/source/images/multivariate_hgf.gif
Binary file not shown.
Binary file added docs/source/images/multivariate_normal.gif
2 changes: 1 addition & 1 deletion docs/source/learn.md
@@ -76,7 +76,7 @@ How to create and manipulate a network of probabilistic nodes for reinforcement
:::{grid-item-card} Generalised Bayesian filtering
:link: generalised_filtering
:link-type: ref
:img-top: ./images/multivariate_hgf.gif
:img-top: ./images/multivariate_normal.gif


Predict, filter and smooth any distribution from the exponential family using generalisations of the Hierarchical Gaussian Filter.
96 changes: 48 additions & 48 deletions docs/source/notebooks/0.1-Theory.ipynb

Large diffs are not rendered by default.

247 changes: 220 additions & 27 deletions docs/source/notebooks/0.2-Creating_networks.ipynb

Large diffs are not rendered by default.

1,021 changes: 661 additions & 360 deletions docs/source/notebooks/0.3-Generalised_filtering.ipynb

Large diffs are not rendered by default.

619 changes: 277 additions & 342 deletions docs/source/notebooks/1.1-Binary_HGF.ipynb

Large diffs are not rendered by default.

151 changes: 76 additions & 75 deletions docs/source/notebooks/1.2-Categorical_HGF.ipynb

Large diffs are not rendered by default.

613 changes: 274 additions & 339 deletions docs/source/notebooks/1.3-Continuous_HGF.ipynb

Large diffs are not rendered by default.

312 changes: 140 additions & 172 deletions docs/source/notebooks/2-Using_custom_response_functions.ipynb

Large diffs are not rendered by default.

474 changes: 221 additions & 253 deletions docs/source/notebooks/3-Multilevel_HGF.ipynb

Large diffs are not rendered by default.

222 changes: 95 additions & 127 deletions docs/source/notebooks/4-Parameter_recovery.ipynb

Large diffs are not rendered by default.

292 changes: 160 additions & 132 deletions docs/source/notebooks/5-Non_linear_value_coupling.ipynb

Large diffs are not rendered by default.

261 changes: 115 additions & 146 deletions docs/source/notebooks/Example_1_Heart_rate_variability.ipynb

Large diffs are not rendered by default.

103 changes: 52 additions & 51 deletions docs/source/notebooks/Example_2_Input_node_volatility_coupling.ipynb

Large diffs are not rendered by default.

531 changes: 267 additions & 264 deletions docs/source/notebooks/Example_3_Multi_armed_bandit.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

986 changes: 602 additions & 384 deletions docs/source/notebooks/Exercise_2_Bayesian_reinforcement_learning.ipynb

Large diffs are not rendered by default.

1,917 changes: 944 additions & 973 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyhgf/__init__.py
@@ -7,7 +7,7 @@
import numpy as np
import pandas as pd

__version__ = "0.2.1"
__version__ = "0.2.2"


def load_data(dataset: str) -> Union[Tuple[np.ndarray, ...], np.ndarray]:
5 changes: 3 additions & 2 deletions pyhgf/distribution.py
@@ -112,8 +112,9 @@ def logp(
to compute the surprise. This can include values over which inference is
performed in a PyMC model (e.g. the inverse temperature of a binary softmax).
input_data :
An array of input time series where the first dimension is the number of models
to fit in parallel.
An array of input time series. The first dimension is the number of time steps
and the second dimension is the number of features. The number of features is
the number of input nodes times the input dimensions.
time_steps :
An array of input time steps where the first dimension is the number of models
to fit in parallel.
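To make the shape convention above concrete, a small hypothetical example (two univariate input nodes, so the feature dimension is 2 * 1):

```python
import numpy as np

# hypothetical data: 200 time steps observed by two univariate input nodes,
# i.e. n_features = n_input_nodes * input_dimension = 2 * 1 = 2
# -> input_data has shape (n_time_steps, n_features) = (200, 2)
input_data = np.random.standard_normal((200, 2))
```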
87 changes: 77 additions & 10 deletions pyhgf/math.py
@@ -18,15 +18,67 @@ class MultivariateNormal:
"""

@staticmethod
def sufficient_statistics(x: ArrayLike) -> Array:
"""Compute the sufficient statistics for the multivariate normal."""
def sufficient_statistics_from_observations(x: ArrayLike) -> Array:
"""Compute the expected sufficient statistics from a single observation."""
return jnp.hstack([x, jnp.outer(x, x)[jnp.tril_indices(x.shape[0])]])

@staticmethod
def sufficient_statistics_from_parameters(
mean: ArrayLike, covariance: ArrayLike
) -> Array:
"""Compute the expected sufficient statistics from distribution parameter.
Parameters
----------
mean :
Mean of the Gaussian distribution.
covariance :
Covariance of the Gaussian distribution.
Returns
-------
xis :
The sufficient statistics.
"""
return jnp.append(
mean,
(covariance + jnp.outer(mean, mean))[jnp.tril_indices(covariance.shape[0])],
)

@staticmethod
def base_measure(k: int) -> float:
"""Compute the base measures for the multivariate normal."""
return (2 * jnp.pi) ** (-k / 2)

@staticmethod
def parameters_from_sufficient_statistics(
xis: ArrayLike, dimension: int
) -> Tuple[Array, Array]:
"""Compute the distribution parameters from the sufficient statistics.
Parameters
----------
xis :
The sufficient statistics.
dimension :
The dimension of the multivariate normal distribution.
Returns
-------
mean, covariance :
The parameters of the distribution (mean and covariance).
"""
mean = xis[:dimension]
covariance = jnp.zeros((dimension, dimension))
covariance = covariance.at[jnp.tril_indices(dimension)].set(
xis[dimension:] - jnp.outer(mean, mean)[jnp.tril_indices(dimension)]
)
covariance += covariance.T - jnp.diag(covariance.diagonal())

return mean, covariance


class Normal:
"""The univariate normal as an exponential family distribution.
@@ -38,28 +90,42 @@ class Normal:
"""

@staticmethod
def sufficient_statistics(x: float) -> Array:
"""Sufficient statistics for the univariate normal."""
def sufficient_statistics_from_observations(x: float) -> Array:
"""Compute the expected sufficient statistics from a single observation."""
return jnp.array([x, x**2])

@staticmethod
def expected_sufficient_statistics(mu: float, sigma) -> Array:
"""Compute expected sufficient statistics from the mean and std."""
return jnp.array([mu, mu**2 + sigma**2])
def sufficient_statistics_from_parameters(mean: float, variance: float) -> Array:
"""Compute the expected sufficient statistics from distribution parameter.
Parameters
----------
mean :
Mean of the Gaussian distribution.
variance :
Variance of the Gaussian distribution.
Returns
-------
xis :
The sufficient statistics.
"""
return jnp.array([mean, mean**2 + variance])

@staticmethod
def base_measure() -> float:
"""Compute the base measure of the univariate normal."""
return 1 / (jnp.sqrt(2 * jnp.pi))

@staticmethod
def parameters(xis: ArrayLike) -> Tuple[float, float]:
"""Get parameters from the expected sufficient statistics.
def parameters_from_sufficient_statistics(xis: ArrayLike) -> Tuple[float, float]:
"""Compute the distribution parameters from the sufficient statistics.
Parameters
----------
xis :
The expected sufficient statistics.
The sufficient statistics.
Returns
-------
@@ -69,6 +135,7 @@ def parameters(xis: ArrayLike) -> Tuple[float, float]:
"""
mean = xis[0]
variance = xis[1] - (mean**2)

return mean, variance


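Taken together, the renamed and added helpers form a round trip between distribution parameters and sufficient statistics. A minimal sketch, assuming Normal and MultivariateNormal are imported from pyhgf.math:

```python
import jax.numpy as jnp

from pyhgf.math import MultivariateNormal, Normal

# univariate case: xis = [E[x], E[x^2]] = [mean, mean^2 + variance]
xis = Normal.sufficient_statistics_from_parameters(mean=1.0, variance=2.0)
mean, variance = Normal.parameters_from_sufficient_statistics(xis)  # -> (1.0, 2.0)

# multivariate case (d = 2): xis stacks the mean with the lower-triangular part
# of E[x x^T] = covariance + outer(mean, mean), i.e. d + d(d+1)/2 = 5 values
mu = jnp.array([0.0, 1.0])
cov = jnp.array([[1.0, 0.2], [0.2, 2.0]])
xis = MultivariateNormal.sufficient_statistics_from_parameters(mean=mu, covariance=cov)
mu_back, cov_back = MultivariateNormal.parameters_from_sufficient_statistics(
    xis, dimension=2
)  # recovers mu and cov
```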
71 changes: 51 additions & 20 deletions pyhgf/model/add_nodes.py
@@ -115,9 +115,9 @@ def add_binary_state(
def add_ef_state(
network: Network,
n_nodes: int,
node_parameters: Dict,
additional_parameters: Dict,
value_children: Tuple = (None, None),
node_parameters: dict,
additional_parameters: dict,
value_children: tuple = (None, None),
):
"""Add exponential family state node(s) to a network."""
node_type = 3
@@ -128,14 +128,27 @@ def add_ef_state(
"learning": "generalised-filtering",
"nus": 3.0,
"xis": jnp.array([0.0, 1.0]),
"mean": 0.0,
"observed": 1,
}

node_parameters = update_parameters(
node_parameters, default_parameters, additional_parameters
)

# the size of the sufficient statistics vector of a multivariate normal
# distribution is given by d + d(d+1) / 2, where d is the dimension
d = node_parameters["dimension"]
n_suff_stats = d + d * (d + 1) // 2
node_parameters["mean"] = jnp.zeros(d) if d > 1 else 0.0
node_parameters["observation_ss"] = jnp.zeros(n_suff_stats)
if node_parameters["distribution"] == "normal":
node_parameters["xis"] = jnp.array([0.0, 1.0])
elif node_parameters["distribution"] == "multivariate-normal":
node_parameters["xis"] = (
MultivariateNormal.sufficient_statistics_from_parameters(
mean=jnp.zeros(d), covariance=jnp.identity(d)
)
)
network = insert_nodes(
network=network,
n_nodes=n_nodes,
@@ -147,22 +160,40 @@
# loop over the indexes of nodes created in the previous step
for node_idx in range(network.n_nodes - 1, network.n_nodes - n_nodes - 1, -1):

if network.attributes[node_idx]["learning"] == "generalised-filtering":

# create the sufficient statistic function and store in the side parameters
if network.attributes[node_idx]["distribution"] == "normal":
sufficient_stats_fn = Normal().sufficient_statistics
elif network.attributes[node_idx]["distribution"] == "multivariate-normal":
sufficient_stats_fn = MultivariateNormal().sufficient_statistics

network.attributes[node_idx].pop("dimension")
network.attributes[node_idx].pop("distribution")
network.attributes[node_idx].pop("learning")
# create the sufficient statistic function and store in the side parameters
if network.attributes[node_idx]["distribution"] == "normal":
sufficient_stats_fn = Normal().sufficient_statistics_from_observations
elif network.attributes[node_idx]["distribution"] == "multivariate-normal":
sufficient_stats_fn = (
MultivariateNormal().sufficient_statistics_from_observations
)
else:
raise ValueError(
"The distribution should be either 'normal' or 'multivariate-normal'."
)

# add the sufficient statistics function in the side parameters
network.additional_parameters.setdefault(node_idx, {})[
"sufficient_stats_fn"
] = sufficient_stats_fn
# add the sufficient statistics function in the side parameters
network.additional_parameters.setdefault(node_idx, {})[
"sufficient_stats_fn"
] = sufficient_stats_fn

if "hgf" in network.attributes[node_idx]["learning"]:

# create a collection of continuous state nodes
# to track the sufficient statistics of the implied distribution
for i in range(n_suff_stats):
network.add_nodes(value_children=node_idx)
network.add_nodes(value_children=network.n_nodes - 1)
if (
"-2" in network.attributes[node_idx]["learning"]
or "-3" in network.attributes[node_idx]["learning"]
):
network.add_nodes(volatility_children=network.n_nodes - 1)
if "-3" in network.attributes[node_idx]["learning"]:
network.add_nodes(volatility_children=network.n_nodes - 1)

network.attributes[node_idx].pop("distribution")
network.attributes[node_idx].pop("learning")

return network
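
For orientation, a hedged sketch of how these defaults surface through the public API, assuming the "ef-state" kind string and that extra keyword arguments to Network.add_nodes are forwarded as additional_parameters:

```python
from pyhgf.model import Network

# univariate exponential-family state node: with the defaults above,
# xis = [0.0, 1.0], i.e. E[x] = 0 and E[x^2] = mean^2 + variance = 1
net = Network().add_nodes(kind="ef-state")

# 3-dimensional multivariate normal node: the sufficient statistics vector
# has d + d * (d + 1) // 2 = 3 + 6 = 9 entries, initialised from a zero
# mean and an identity covariance
mv_net = Network().add_nodes(
    kind="ef-state",
    distribution="multivariate-normal",
    dimension=3,
)
```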

@@ -197,7 +228,7 @@ def add_categorical_state(
"surprise": 0.0,
"kl_divergence": 0.0,
"alpha": jnp.ones(n_categories),
"observed": jnp.ones(n_categories, dtype=int),
"observed": 1,
"mean": jnp.array([1.0 / n_categories] * n_categories),
"binary_parameters": binary_parameters,
}
4 changes: 4 additions & 0 deletions pyhgf/model/network.py
@@ -109,6 +109,10 @@ def create_belief_propagation_fn(
belief propagation function bound to the ``scan_fn`` attribute.
"""
# get the dimension of the input nodes
if not self.input_dim:
self.get_input_dimension()

# create the update sequence if it does not already exist
if self.update_sequence is None:
self.update_sequence = get_update_sequence(
8 changes: 1 addition & 7 deletions pyhgf/plots.py
@@ -264,13 +264,7 @@ def plot_network(network: "Network") -> "Source":

style = "filled" if idx in network.input_idxs else ""

if network.edges[idx].node_type == 0:
# binary state node
graphviz_structure.node(
f"x_{idx}", label=str(idx), shape="ellipse", style=style
)

elif network.edges[idx].node_type == 1:
if network.edges[idx].node_type == 1:
# binary state node
graphviz_structure.node(
f"x_{idx}", label=str(idx), shape="square", style=style