Commit

docstring and refactor update of mean volatility parent

LegrandNico committed Nov 7, 2023
1 parent 09f3740 commit 55653fc
Showing 4 changed files with 185 additions and 762 deletions.
847 changes: 124 additions & 723 deletions docs/source/notebooks/1.3-Continuous_HGF.ipynb

Large diffs are not rendered by default.

14 changes: 4 additions & 10 deletions docs/source/notebooks/1.3-Continuous_HGF.md
@@ -124,10 +124,6 @@ $$surprise = -log(p)$$
#### Plot correlation
Node parameters that are highly correlated across time are likely to indicate that the model did not learn the hierarchical structure of the data but instead overfitted on some components. One way to quickly check the correlations between node parameters is to use the `plot_correlations` function embedded in the HGF class.

```{code-cell} ipython3
two_levels_continuous_hgf.to_pandas()
```

```{code-cell} ipython3
two_levels_continuous_hgf.plot_correlations();
```
@@ -193,7 +189,7 @@ three_levels_continuous_hgf_bis.plot_trajectories();
three_levels_continuous_hgf_bis.surprise()
```

Now we are getting a global surprise of `-1011` with the new model, as compared to a global surprise of `-381` before. It looks like the $\omega$ value at the second level can play an important role in minimizing surprise for this kind of time series. But how can we decide on which value to choose? Doing this by trial and error would be a bit tedious. Instead, we can use dedicated Bayesian methods that will infer the values of $\omega$ that minimize the surprise (i.e. that maximize the likelihood of the new observations given parameter priors).
Now we are getting a global surprise of `-828` with the new model, as compared to a global surprise of `-910` before. It looks like the $\omega$ value at the second level can play an important role in minimizing surprise for this kind of time series. But how can we decide on which value to choose? Doing this by trial and error would be a bit tedious. Instead, we can use dedicated Bayesian methods that will infer the values of $\omega$ that minimize the surprise (i.e. that maximize the likelihood of the new observations given parameter priors).
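For illustration only, the trial-and-error approach could look like the short sketch below. The constructor arguments and parameter layout are assumptions that may differ across pyhgf versions, and the input series is a stand-in for the tutorial's data.

```python
import numpy as np
from pyhgf.model import HGF

# Hypothetical stand-in for the tutorial's time series.
np.random.seed(123)
input_data = np.cumsum(np.random.normal(size=200))

# Compare the global surprise under a few candidate second-level omega values.
for omega_2 in (-2.0, -4.0, -6.0):
    hgf = HGF(
        n_levels=2,
        model_type="continuous",
        tonic_volatility={"1": -3.0, "2": omega_2},  # assumed parameter layout
    )
    hgf.input_data(input_data=input_data)
    print(f"omega_2 = {omega_2}: surprise = {float(hgf.surprise()):.2f}")
```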

+++

@@ -202,9 +198,7 @@ In the previous section, we assumed we knew the parameters of the HGF models bef

Because the HGF classes are built on top of [JAX](https://github.com/google/jax), they are natively differentiable and compatible with optimisation libraries. Here, we use [PyMC](https://www.pymc.io/welcome.html) to perform MCMC sampling. PyMC can use any log probability function (here the negative surprise of the model) as a building block for a new distribution by wrapping it in its underlying tensor library [Aesara](https://aesara.readthedocs.io/en/latest/), now [PyTensor](https://pytensor.readthedocs.io/en/latest/). pyhgf includes a PyMC-compatible distribution that can do this automatically: {py:class}`pyhgf.distribution.HGFDistribution`.
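As a rough sketch of this workflow (the argument and keyword names here are assumptions taken from the library's documented examples and may differ across pyhgf versions; the input series is a stand-in):

```python
import numpy as np
import pymc as pm
from pyhgf.distribution import HGFDistribution
from pyhgf.response import first_level_gaussian_surprise

# Hypothetical stand-in for the tutorial's time series.
np.random.seed(123)
input_data = np.cumsum(np.random.normal(size=200))

# Wrap the HGF's negative surprise as a log-probability Op for PyMC.
hgf_logp_op = HGFDistribution(
    n_levels=2,
    input_data=input_data,
    response_function=first_level_gaussian_surprise,
)

with pm.Model() as two_level_hgf:
    # Prior over the second-level tonic volatility; the keyword name is an
    # assumption and may differ across versions.
    tonic_volatility_2 = pm.Normal("tonic_volatility_2", -5.0, 2.0)
    pm.Potential("hgf_loglike", hgf_logp_op(tonic_volatility_2=tonic_volatility_2))
    idata = pm.sample(chains=2)
```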

```{code-cell} ipython3
```
+++

### Two-level model
#### Creating the model
@@ -287,7 +281,7 @@ hgf_mcmc = HGF(
```

```{code-cell} ipython3
hgf_mcmc.plot_trajectories()
hgf_mcmc.plot_trajectories();
```

```{code-cell} ipython3
@@ -361,7 +355,7 @@ editable: true
slideshow:
slide_type: ''
---
hgf_mcmc.plot_trajectories(ci=True);
hgf_mcmc.plot_trajectories();
```

3 changes: 2 additions & 1 deletion src/pyhgf/updates/continuous.py
@@ -289,7 +289,8 @@ def continuous_node_prediction(
"""
# Get the new expected mean
expected_mean = predict_mean(attributes, edges, time_step, node_idx)
# Get the new expected precision

# Get the new expected precision and predicted volatility (Ω)
expected_precision, predicted_volatility = predict_precision(
attributes, edges, time_step, node_idx
)
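A small sketch of the assumed handoff (not the library's verbatim code, which is elided from this diff): the prediction step presumably caches Ω in the node's attributes so the prediction-error step can later read it back, as the update below does via `attributes[child_idx]["temp"]["predicted_volatility"]`.

```python
# Sketch of the assumed data structure only; names mirror this diff,
# values are placeholders.
attributes = {0: {"temp": {}}}
expected_precision, predicted_volatility = 1.2, 0.4  # placeholder outputs of predict_precision
attributes[0]["temp"]["predicted_volatility"] = predicted_volatility
```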
83 changes: 55 additions & 28 deletions src/pyhgf/updates/prediction_error/nodes/continuous.py
@@ -258,7 +258,36 @@ def prediction_error_mean_volatility_parent(
volatility_parent_idx: int,
precision_volatility_parent: ArrayLike,
) -> Array:
"""Send prediction-error and update the mean of the volatility parent.
r"""Update the mean of the volatility parent.
The new mean of the volatility parent :math:`a` of a state node at time :math:`k`
is given by:
.. math::
\mu_a^{(k)} = \hat{\mu}_a^{(k)} + \frac{1}{2\pi_a} \\
\sum_{j=1}^{N_{children}} \kappa_j^2 \gamma_j^{(k)} \Delta_j^{(k)}
where :math:`\kappa_j` is the volatility coupling strength between the volatility
parent and the volatility children :math:`j` and :math:`\Delta_j^{(k)}` is the
volatility prediction error given by:
.. math::
\Delta_j^{(k)} = \frac{\hat{\pi}_j^{(k)}}{\pi_j^{(k)}} + \\
\hat{\pi}_j^{(k)} \left( \delta_j^{(k)} \right)^2 - 1
with :math:`\delta_j^{(k)}` the value prediction error
:math:`\delta_j^{(k)} = \mu_j^{k} - \hat{\mu}_j^{k}`.
:math:`\gamma_j^{(k)}` is the volatility-weighted precision of the prediction,
given by:
.. math::
\gamma_j^{(k)} = \Omega_j^{(k)} \hat{\pi}_j^{(k)}
with :math:`\Omega_j^{(k)}` the predicted volatility computed in the prediction
step (:func:`pyhgf.updates.prediction.predict_precision`).
Parameters
----------
@@ -297,44 +326,42 @@ def prediction_error_mean_volatility_parent(
expected_mean_volatility_parent = attributes[volatility_parent_idx]["expected_mean"]

# Gather volatility prediction errors from the child nodes
children_volatility_prediction_error = 0.0
children_volatility_precision = 0.0
for child_idx, volatility_coupling in zip(
edges[volatility_parent_idx].volatility_children, # type: ignore
attributes[volatility_parent_idx]["volatility_coupling_children"],
):
# Look at the (optional) volatility parents and update logvol accordingly
logvol = attributes[child_idx]["tonic_volatility"]
if edges[child_idx].volatility_parents is not None:
for children_volatility_parents, volatility_coupling in zip(
edges[child_idx].volatility_parents,
attributes[child_idx]["volatility_coupling_parents"],
):
logvol += (
volatility_coupling
* attributes[children_volatility_parents]["mean"]
)

# Compute new value for nu
nu_children = time_step * jnp.exp(logvol)
nu_children = jnp.where(nu_children > 1e-128, nu_children, jnp.nan)
# retrieve the predicted volatility (Ω) computed in the prediction step
predicted_volatility = attributes[child_idx]["temp"]["predicted_volatility"]

# compute the volatility-weighted precision (γ)
volatility_weighted_precision = (
predicted_volatility * attributes[child_idx]["expected_precision"]
)

# compute the volatility prediction error (VOPE)
vope_children = (
1 / attributes[child_idx]["precision"]
+ (attributes[child_idx]["mean"] - attributes[child_idx]["expected_mean"])
(
attributes[child_idx]["expected_precision"]
/ attributes[child_idx]["precision"]
)
+ attributes[child_idx]["expected_precision"]
* (attributes[child_idx]["mean"] - attributes[child_idx]["expected_mean"])
** 2
) * attributes[child_idx]["expected_precision"] - 1
children_volatility_prediction_error += (
0.5
* volatility_coupling
* nu_children
* attributes[child_idx]["expected_precision"]
/ precision_volatility_parent
* vope_children
- 1
)

# sum over all volatility children
children_volatility_precision += (
volatility_weighted_precision * volatility_coupling * vope_children
)

# weight using the precision of the volatility parent
children_volatility_precision *= 1 / (2 * precision_volatility_parent)

# Estimate the new mean of the volatility parent
mean_volatility_parent = (
expected_mean_volatility_parent + children_volatility_prediction_error
expected_mean_volatility_parent + children_volatility_precision
)

return mean_volatility_parent
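To make the refactored update concrete, here is a minimal numeric sketch of the docstring equations for a single volatility child, using placeholder values (illustration only, not library code):

```python
# Placeholder values for a single volatility child j and its parent a.
expected_precision = 2.0            # pi_hat_j
precision = 2.5                     # pi_j
mean, expected_mean = 0.8, 0.5      # mu_j, mu_hat_j
predicted_volatility = 0.4          # Omega_j, from the prediction step
volatility_coupling = 1.0           # kappa_j
precision_volatility_parent = 3.0   # pi_a
expected_mean_volatility_parent = -2.0

# Volatility-weighted precision: gamma_j = Omega_j * pi_hat_j
gamma = predicted_volatility * expected_precision

# Volatility prediction error: Delta_j = pi_hat_j / pi_j + pi_hat_j * delta_j**2 - 1
vope = (
    expected_precision / precision
    + expected_precision * (mean - expected_mean) ** 2
    - 1
)

# Mean update: mu_a = mu_hat_a + (1 / (2 * pi_a)) * sum_j kappa_j * gamma_j * Delta_j
mean_volatility_parent = expected_mean_volatility_parent + (
    gamma * volatility_coupling * vope
) / (2 * precision_volatility_parent)

print(mean_volatility_parent)  # ≈ -2.00267
```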
