-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add JGET HGF (volatility coupling for continuous input nodes) (#125)
* volatility coupling for input nodes * add a tutorial on input nodes * refactoring
- Loading branch information
1 parent
5b9e102
commit e0d8d35
Showing
18 changed files
with
1,547 additions
and
336 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
536 changes: 536 additions & 0 deletions
536
docs/source/notebooks/Example_2_Input_node_volatility_coupling.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
202 changes: 202 additions & 0 deletions
202
docs/source/notebooks/Example_2_Input_node_volatility_coupling.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
--- | ||
jupytext: | ||
formats: ipynb,md:myst | ||
text_representation: | ||
extension: .md | ||
format_name: myst | ||
format_version: 0.13 | ||
jupytext_version: 1.15.1 | ||
kernelspec: | ||
display_name: Python 3 (ipykernel) | ||
language: python | ||
name: python3 | ||
--- | ||
|
||
(example_1)= | ||
# Example 2: Estimating the mean and precision of an input node | ||
|
||
```{code-cell} ipython3 | ||
%%capture | ||
import sys | ||
if 'google.colab' in sys.modules: | ||
! pip install pyhgf | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
--- | ||
from pyhgf.distribution import HGFDistribution | ||
from pyhgf.model import HGF | ||
import numpy as np | ||
import pymc as pm | ||
import arviz as az | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
from scipy.stats import norm | ||
``` | ||
|
||
+++ {"editable": true, "slideshow": {"slide_type": ""}} | ||
|
||
Where the standard continuous HGF assumes a known precision in the input node (usually set to something high), this assumption can be relaxed and the filter can also try to estimate this quantity from the data. In this notebook, we demonstrate how we can infer the value of the mean, of the precision, or both value at the same time, using the appropriate value and volatility coupling parents. | ||
|
||
+++ {"editable": true, "slideshow": {"slide_type": ""}} | ||
|
||
## Unkown mean, known precision | ||
|
||
+++ {"editable": true, "slideshow": {"slide_type": ""}} | ||
|
||
```{hint} | ||
The {ref}`continuous_hgf` is an example of a model assuming a continuous input with known precision and unknown mean. It is further assumed that the mean is changing overtime, and we want the model to track this rate of change by adding a volatility node on the top of the value parent (two-level continuous HGF), and event track the rate of change of this rate of change by adding another volatility parent (three-level continuous HGF). | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
--- | ||
dist_mean, dist_std = 5, 1 | ||
input_data = np.random.normal(loc=dist_mean, scale=dist_std, size=1000) | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
--- | ||
mean_hgf = ( | ||
HGF(model_type=None) | ||
.add_input_node(kind="continuous", continuous_parameters={'continuous_precision': 1}) | ||
.add_value_parent(children_idxs=[0], tonic_volatility=-8.0) | ||
.init() | ||
).input_data(input_data) | ||
mean_hgf.plot_network() | ||
``` | ||
|
||
+++ {"editable": true, "slideshow": {"slide_type": ""}} | ||
|
||
```{note} | ||
We are setting the tonic volatility to something low for visualization purposes, but changing this value can make the model learn in fewer iterations. | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
tags: [hide-input] | ||
--- | ||
# get the nodes trajectories | ||
df = mean_hgf.to_pandas() | ||
fig, ax = plt.subplots(figsize=(12, 5)) | ||
x = np.linspace(-10, 10, 1000) | ||
for i, color in zip([0, 2, 5, 10, 50, 500], plt.cm.Greys(np.linspace(.2, 1, 6))): | ||
# extract the sufficient statistics from the input node (and parents) | ||
mean = df.x_1_expected_mean.iloc[i] | ||
std = np.sqrt( | ||
1/(mean_hgf.attributes[0]["expected_precision"]) | ||
) | ||
# the model expectations | ||
ax.plot(x, norm(mean, std).pdf(x), color=color, label=i) | ||
# the sampling distribution | ||
ax.fill_between(x, norm(dist_mean, dist_std).pdf(x), color="#582766", alpha=.2) | ||
ax.legend(title="Iterations") | ||
ax.set_xlabel("Input (u)") | ||
ax.set_ylabel("Density") | ||
plt.grid(linestyle=":") | ||
sns.despine() | ||
``` | ||
|
||
+++ {"editable": true, "slideshow": {"slide_type": ""}} | ||
|
||
## Kown mean, unknown precision | ||
|
||
+++ {"editable": true, "slideshow": {"slide_type": ""}} | ||
|
||
## Unkown mean, unknown precision | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
--- | ||
dist_mean, dist_std = 5, 1 | ||
input_data = np.random.normal(loc=dist_mean, scale=dist_std, size=1000) | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
--- | ||
mean_precision_hgf = ( | ||
HGF(model_type=None) | ||
.add_input_node(kind="continuous", continuous_parameters={'continuous_precision': 0.01}) | ||
.add_value_parent(children_idxs=[0], tonic_volatility=-6.0) | ||
.add_volatility_parent(children_idxs=[0], tonic_volatility=-6.0) | ||
.init() | ||
).input_data(input_data) | ||
mean_precision_hgf.plot_network() | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
tags: [hide-input] | ||
--- | ||
# get the nodes trajectories | ||
df = mean_precision_hgf.to_pandas() | ||
fig, ax = plt.subplots(figsize=(12, 5)) | ||
x = np.linspace(-10, 10, 1000) | ||
for i, color in zip(range(0, 150, 15), plt.cm.Greys(np.linspace(.2, 1, 10))): | ||
# extract the sufficient statistics from the input node (and parents) | ||
mean = df.x_1_expected_mean.iloc[i] | ||
std = np.sqrt( | ||
1/(mean_precision_hgf.attributes[0]["expected_precision"] * (1/np.exp(df.x_2_expected_mean.iloc[i]))) | ||
) | ||
# the model expectations | ||
ax.plot(x, norm(mean, std).pdf(x), color=color, label=i) | ||
# the sampling distribution | ||
ax.fill_between(x, norm(dist_mean, dist_std).pdf(x), color="#582766", alpha=.2) | ||
ax.legend(title="Iterations") | ||
ax.set_xlabel("Input (u)") | ||
ax.set_ylabel("Density") | ||
plt.grid(linestyle=":") | ||
sns.despine() | ||
``` | ||
|
||
+++ {"editable": true, "slideshow": {"slide_type": ""}} | ||
|
||
## System configuration | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
--- | ||
%load_ext watermark | ||
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.