Add JGET HGF (volatility coupling for continuous input nodes) (#125)
* volatility coupling for input nodes

* add a tutorial on input nodes

* refactoring
LegrandNico authored Nov 2, 2023
1 parent 5b9e102 commit e0d8d35
Showing 18 changed files with 1,547 additions and 336 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -9,7 +9,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/ambv/black
rev: 23.9.1
rev: 23.10.1
hooks:
- id: black
language_version: python3
@@ -22,10 +22,10 @@ repos:
hooks:
- id: pydocstyle
args: ['--ignore', 'D213,D100,D203,D104']
files: ^pyhgf/
files: ^src/
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.6.0'
rev: 'v1.6.1'
hooks:
- id: mypy
files: ^pyhgf/
files: ^src/
args: [--ignore-missing-imports]
36 changes: 28 additions & 8 deletions docs/source/api.rst
@@ -62,31 +62,51 @@ Propagate prediction errors to the value and volatility parents of a given node.
Binary nodes
~~~~~~~~~~~~

.. currentmodule:: pyhgf.updates.prediction_error.binary
.. currentmodule:: pyhgf.updates.prediction_error.inputs.binary

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.binary
:toctree: generated/pyhgf.updates.prediction_error.inputs.binary

prediction_error_mean_value_parent
prediction_error_precision_value_parent
prediction_error_value_parent
prediction_error_input_value_parent
input_surprise_inf
input_surprise_reg

.. currentmodule:: pyhgf.updates.prediction_error.nodes.binary

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.nodes.binary

prediction_error_mean_value_parent
prediction_error_precision_value_parent
prediction_error_value_parent

Continuous nodes
~~~~~~~~~~~~~~~~

.. currentmodule:: pyhgf.updates.prediction_error.continuous
Updating continuous input nodes.

.. currentmodule:: pyhgf.updates.prediction_error.inputs.continuous

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.continuous
:toctree: generated/pyhgf.updates.prediction_error.inputs.continuous

prediction_error_input_precision_value_parent
prediction_error_input_precision_volatility_parent
prediction_error_input_mean_volatility_parent
prediction_error_input_mean_value_parent


Updating continuous state nodes.

.. currentmodule:: pyhgf.updates.prediction_error.nodes.continuous

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.nodes.continuous

prediction_error_mean_value_parent
prediction_error_precision_value_parent
prediction_error_precision_volatility_parent
prediction_error_mean_volatility_parent
prediction_error_input_mean_value_parent

Prediction steps
================
179 changes: 111 additions & 68 deletions docs/source/notebooks/1.2-Categorical_HGF.ipynb

Large diffs are not rendered by default.

536 changes: 536 additions & 0 deletions docs/source/notebooks/Example_2_Input_node_volatility_coupling.ipynb

Large diffs are not rendered by default.

202 changes: 202 additions & 0 deletions docs/source/notebooks/Example_2_Input_node_volatility_coupling.md
@@ -0,0 +1,202 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

(example_2)=
# Example 2: Estimating the mean and precision of an input node

```{code-cell} ipython3
%%capture
import sys
if 'google.colab' in sys.modules:
! pip install pyhgf
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
from pyhgf.distribution import HGFDistribution
from pyhgf.model import HGF
import numpy as np
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import norm
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}

While the standard continuous HGF assumes a known precision in the input node (usually set to a high value), this assumption can be relaxed so that the filter also estimates this quantity from the data. In this notebook, we demonstrate how to infer the value of the mean, of the precision, or of both at the same time, using the appropriate value and volatility coupling parents.

+++ {"editable": true, "slideshow": {"slide_type": ""}}

## Unknown mean, known precision

+++ {"editable": true, "slideshow": {"slide_type": ""}}

```{hint}
The {ref}`continuous_hgf` is an example of a model assuming a continuous input with known precision and unknown mean. It is further assumed that the mean is changing over time, and we want the model to track this rate of change by adding a volatility node on top of the value parent (two-level continuous HGF), and even to track the rate of change of this rate of change by adding another volatility parent (three-level continuous HGF).
```
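Before turning to the pyhgf model below, the inference problem itself can be sketched in plain NumPy (a conceptual illustration, not the pyhgf API): with a known input precision, a conjugate normal update shifts the estimated mean by a precision-weighted prediction error, which is the fixed point that the value parent of the input node approximates.

```python
import numpy as np

rng = np.random.default_rng(0)
dist_mean, dist_std = 5.0, 1.0
observations = rng.normal(loc=dist_mean, scale=dist_std, size=1000)

# Conjugate normal update with known observation precision: each input
# shifts the belief about the mean by a precision-weighted prediction error.
obs_precision = 1.0 / dist_std**2
mean, precision = 0.0, 1e-2  # weak prior belief about the mean
for u in observations:
    precision += obs_precision
    mean += (obs_precision / precision) * (u - mean)

print(round(mean, 1))  # converges close to the true mean of 5
```

The learning rate `obs_precision / precision` shrinks as evidence accumulates, which is why the model's expectations tighten around the sampling distribution over iterations in the figure below.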

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
dist_mean, dist_std = 5, 1
input_data = np.random.normal(loc=dist_mean, scale=dist_std, size=1000)
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
mean_hgf = (
HGF(model_type=None)
.add_input_node(kind="continuous", continuous_parameters={'continuous_precision': 1})
.add_value_parent(children_idxs=[0], tonic_volatility=-8.0)
.init()
).input_data(input_data)
mean_hgf.plot_network()
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}

```{note}
We set the tonic volatility to a low value for visualization purposes; increasing it lets the model converge in fewer iterations.
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
tags: [hide-input]
---
# get the node trajectories
df = mean_hgf.to_pandas()
fig, ax = plt.subplots(figsize=(12, 5))
x = np.linspace(-10, 10, 1000)
for i, color in zip([0, 2, 5, 10, 50, 500], plt.cm.Greys(np.linspace(.2, 1, 6))):
# extract the sufficient statistics from the input node (and parents)
mean = df.x_1_expected_mean.iloc[i]
std = np.sqrt(
1/(mean_hgf.attributes[0]["expected_precision"])
)
# the model expectations
ax.plot(x, norm(mean, std).pdf(x), color=color, label=i)
# the sampling distribution
ax.fill_between(x, norm(dist_mean, dist_std).pdf(x), color="#582766", alpha=.2)
ax.legend(title="Iterations")
ax.set_xlabel("Input (u)")
ax.set_ylabel("Density")
plt.grid(linestyle=":")
sns.despine()
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}

## Known mean, unknown precision

+++ {"editable": true, "slideshow": {"slide_type": ""}}
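As a conceptual sketch of what this section targets (plain NumPy, not the pyhgf API): when the mean is known, the unknown precision can be tracked with a conjugate Gamma update over the squared prediction errors, which is the quantity a volatility parent of the input node responds to.

```python
import numpy as np

rng = np.random.default_rng(0)
known_mean, dist_std = 5.0, 1.0
observations = rng.normal(loc=known_mean, scale=dist_std, size=1000)

# Conjugate Gamma update for an unknown precision with a known mean:
# alpha grows by 1/2 per observation, beta accumulates half the squared
# prediction error, and alpha / beta is the posterior mean precision.
alpha, beta = 1e-2, 1e-2  # weak Gamma prior over the precision
for u in observations:
    alpha += 0.5
    beta += 0.5 * (u - known_mean) ** 2
estimated_precision = alpha / beta

print(round(estimated_precision, 1))  # close to 1 / dist_std**2 = 1.0
```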

## Unknown mean, unknown precision

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
dist_mean, dist_std = 5, 1
input_data = np.random.normal(loc=dist_mean, scale=dist_std, size=1000)
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
mean_precision_hgf = (
HGF(model_type=None)
.add_input_node(kind="continuous", continuous_parameters={'continuous_precision': 0.01})
.add_value_parent(children_idxs=[0], tonic_volatility=-6.0)
.add_volatility_parent(children_idxs=[0], tonic_volatility=-6.0)
.init()
).input_data(input_data)
mean_precision_hgf.plot_network()
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
tags: [hide-input]
---
# get the node trajectories
df = mean_precision_hgf.to_pandas()
fig, ax = plt.subplots(figsize=(12, 5))
x = np.linspace(-10, 10, 1000)
for i, color in zip(range(0, 150, 15), plt.cm.Greys(np.linspace(.2, 1, 10))):
# extract the sufficient statistics from the input node (and parents)
mean = df.x_1_expected_mean.iloc[i]
std = np.sqrt(
1/(mean_precision_hgf.attributes[0]["expected_precision"] * (1/np.exp(df.x_2_expected_mean.iloc[i])))
)
# the model expectations
ax.plot(x, norm(mean, std).pdf(x), color=color, label=i)
# the sampling distribution
ax.fill_between(x, norm(dist_mean, dist_std).pdf(x), color="#582766", alpha=.2)
ax.legend(title="Iterations")
ax.set_xlabel("Input (u)")
ax.set_ylabel("Density")
plt.grid(linestyle=":")
sns.despine()
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}

## System configuration

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
%load_ext watermark
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib
```
1 change: 1 addition & 0 deletions docs/source/tutorials.md
@@ -33,6 +33,7 @@ glob:
| Notebook | Colab |
| --- | ---|
| {ref}`example_1` | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ilabcode/pyhgf/blob/master/docs/source/notebooks/Example_1_Heart_rate_variability.ipynb)
| {ref}`example_2` | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ilabcode/pyhgf/blob/master/docs/source/notebooks/Example_2_Input_node_volatility_coupling.ipynb)

## Exercises

2 changes: 1 addition & 1 deletion src/pyhgf/distribution.py
@@ -31,7 +31,7 @@ def hgf_logp(
volatility_coupling_1: Union[np.ndarray, ArrayLike, float] = 1.0,
volatility_coupling_2: Union[np.ndarray, ArrayLike, float] = 1.0,
input_data: List[np.ndarray] = [np.nan],
response_function: Callable = None,
response_function: Optional[Callable] = None,
model_type: str = "continuous",
n_levels: int = 2,
response_function_parameters: List[Tuple] = [()],
