-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add support for categorical state nodes (#83)
* update tutorials * - add a fill_categorical_state_nodes function - method to add a categorical state node * updated plotting functions for categorical nodes * move the get_update_sequence function in the network's submodule. * clever update sequence * do not automatically create an update sequence when providing inputs if there is none * proper order for categorical input updates * surprise in the categorical node * move the to_pandas method in the network submodule * performance issues and correct handling of surprise in to_pandas * create an init method to cache the update function before providing data * error in the get updates function * fix error in expectation in the categorical node * fix the order of update/predictions at the network level * update pre-commit * add Kullback_Leibler divergence * docs
- Loading branch information
1 parent
ba2b36c
commit c8c7dbd
Showing
38 changed files
with
7,529 additions
and
5,078 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,312 @@ | ||
--- | ||
jupytext: | ||
formats: ipynb,md:myst | ||
text_representation: | ||
extension: .md | ||
format_name: myst | ||
format_version: 0.13 | ||
jupytext_version: 1.14.7 | ||
kernelspec: | ||
display_name: Python 3 (ipykernel) | ||
language: python | ||
name: python3 | ||
--- | ||
|
||
(categorical_hgf)= | ||
# The categorical Hierarchical Gaussian Filter | ||
```{warning} | ||
The categorical state node and the categorical state-transition nodes are still work in progress. The examples provided here are given for illustration. Things may change or not work until the official publication. | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
:tags: [hide-cell] | ||
%%capture | ||
import sys | ||
if 'google.colab' in sys.modules: | ||
! pip install pyhgf | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
from pyhgf.model import HGF | ||
import numpy as np | ||
from pyhgf import load_data | ||
import jax.numpy as jnp | ||
from jax import jit, grad, vjp | ||
from jax.tree_util import Partial | ||
import matplotlib.pyplot as plt | ||
from pyhgf.plots import plot_nodes | ||
import seaborn as sns | ||
import pytensor | ||
import pytensor.tensor as pt | ||
from pytensor.graph import Apply, Op | ||
``` | ||
|
||
The binary state nodes that we introduced in the previous section are useful to encode information about stochastic boolean variables that are common in reinforcement learning and decision-making design. However, situations may occur where discrete variables can have more than two categories and therefore need to be encoded by a categorical distribution. Here, we introduce two probabilistic nodes tailored to handle this kind of variable: the **categorical state node** and the **categorical state-transition node**. | ||
|
||
Both nodes are a generalisation of the binary HGFs (in the sense that they internally represent a collection of binary state nodes). We refer to **categorical HGF** in a broad sense for HGFs that can handle categorical distributions, but as we will illustrate below, there are many ways to do that and a more precise terminology is to refer to the kind of node used internally (the **categorical state node** and the **categorical state-transition node**). | ||
|
||
+++ | ||
|
||
## Simulating a dataset | ||
We start by simulating a dataset on which we can apply the categorical HGFs. The dataset consists of a categorical input where the number of categories $K=10$. The underlying contingencies are generated by three Dirichlet distributions on which we sample 150 observations sequentially. | ||
|
||
```{code-cell} ipython3 | ||
# generate some categorical inputs data using three underlying distributions | ||
p1 = np.random.dirichlet(alpha=[1, 2, 3, 5, 9, 13, 17, 25, 30, 35]) | ||
p2 = np.random.dirichlet(alpha=[1, 2, 3, 5, 30, 30, 5, 3, 2, 1]) | ||
p3 = np.random.dirichlet(alpha=[35, 30, 25, 17, 13, 9, 5, 3, 2, 1]) | ||
input_data = np.array( | ||
[np.random.multinomial(n=1, pvals=p) for p in [p1, p2, p3] for _ in range(250)] | ||
).T | ||
``` | ||
|
||
The Dirichlet distributions are parametrized in such a way that it goes from a "skewed" distribution to a centred distribution to another "skewed" distribution. The resulting sequence of categorical observations then looks like this: | ||
|
||
```{code-cell} ipython3 | ||
plt.figure(figsize=(12, 3)) | ||
plt.imshow(input_data, interpolation="none", aspect="auto", cmap="binary") | ||
plt.ylabel("Categories") | ||
plt.xlabel("Time") | ||
plt.title("Categorical observations"); | ||
``` | ||
|
||
```{note} | ||
The lower part of the figure represent the surprise associated with the categorical node. Here, we use the [Kullback-Leibler divergence between two Dirichlet distributions](https://statproofbook.github.io/P/dir-kl.html) as a measure of Bayesian surprise. The Kullback-Leibler divergence of the Dirichlet distribution $P$ from the Dirichlet distribution $Q$ is given by the following equation: | ||
$$ | ||
KL[P||Q] = \ln{\frac{\Gamma(\sum_{i=1}^k\alpha_{1i})}{\Gamma(\sum_{i=1}^k\alpha_{2i})}} + \sum_{i=1}^k \ln{\frac{\Gamma(\alpha_{2i})}{\Gamma(\alpha_{1i})}} + \sum_{i=1}^k(\alpha_{1i} - \alpha_{2i}) \left[ \psi(\alpha_{1i}) - \psi(\sum_{i=1}^k \alpha_{1i}) \right] | ||
$$ | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
# adding a blank input time series for the categorical state node | ||
# this is because the categorical state node does not receive anything | ||
# only binary nodes are the actual inputs of the network | ||
input_data = np.vstack([[0.0] * input_data.shape[1], input_data]) | ||
``` | ||
|
||
## The categorical state node | ||
|
||
+++ | ||
|
||
### Creating the probabilistic network | ||
|
||
```{code-cell} ipython3 | ||
categorical_hgf = ( | ||
HGF(model_type=None, verbose=False) | ||
.add_input_node( | ||
kind="categorical", | ||
categorical_parameters={"n_categories": 10}, | ||
binary_parameters={"omega_2": -2.0} | ||
) | ||
.init() | ||
) | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
categorical_hgf.plot_network() | ||
``` | ||
|
||
### Fitting the model forwards | ||
|
||
```{code-cell} ipython3 | ||
categorical_hgf.input_data(input_data=input_data.T); | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
tags: [hide-cell] | ||
--- | ||
fig, axs = plt.subplots(nrows=5, figsize=(12, 9), sharex=True) | ||
plot_nodes(categorical_hgf, node_idxs=31, axs=axs[0]) | ||
axs[1].imshow(categorical_hgf.node_trajectories[0]["mu"].T, interpolation="none", aspect="auto") | ||
axs[1].set_title("Mean of the implied Dirichlet distribution", loc="left") | ||
axs[1].set_ylabel("Categories") | ||
# observations | ||
axs[2].imshow(input_data, interpolation="none", aspect="auto", cmap="binary") | ||
axs[2].set_title("Observations", loc="left") | ||
axs[2].set_ylabel("Categories") | ||
# surprise | ||
axs[3].plot( | ||
np.arange(2, len(categorical_hgf.node_trajectories[0]["kl_divergence"])), | ||
categorical_hgf.node_trajectories[0]["kl_divergence"][2:], | ||
color="#2a2a2a", | ||
linewidth=0.5, | ||
zorder=-1, | ||
label="Surprise", | ||
) | ||
axs[3].fill_between( | ||
x=np.arange(2, len(categorical_hgf.node_trajectories[0]["kl_divergence"])), | ||
y1=categorical_hgf.node_trajectories[0]["kl_divergence"][2:], | ||
y2=categorical_hgf.node_trajectories[0]["kl_divergence"][2:].min(), | ||
color="#7f7f7f", | ||
alpha=0.1, | ||
zorder=-1, | ||
) | ||
axs[3].set_title("Kullback-Leibler divergences", loc="left") | ||
axs[2].set_ylabel("Surprises") | ||
axs[4].plot( | ||
np.arange(len(categorical_hgf.node_trajectories[0]["binary_surprise"])), | ||
categorical_hgf.node_trajectories[0]["binary_surprise"], | ||
color="#2a2a2a", | ||
linewidth=0.5, | ||
zorder=-1, | ||
label="Surprise", | ||
) | ||
axs[4].fill_between( | ||
x=np.arange(len(categorical_hgf.node_trajectories[0]["binary_surprise"])), | ||
y1=categorical_hgf.node_trajectories[0]["binary_surprise"], | ||
y2=categorical_hgf.node_trajectories[0]["binary_surprise"].min(), | ||
color="#7f7f7f", | ||
alpha=0.1, | ||
zorder=-1, | ||
) | ||
axs[4].set_title("Sum of binary surprises", loc="left") | ||
axs[2].set_ylabel("Surprises") | ||
sns.despine() | ||
plt.tight_layout() | ||
``` | ||
|
||
### Inference using MCMC sampling | ||
|
||
+++ | ||
|
||
In the binary and continuous HGF example, we have been using the {py:class}`pyhgf.distribution.HGFDistribution` class to create a PyMC-compatible distribution of the HGF. This was possible when using the most standard models as we can easily write a pre-defined distribution that fits exactly the network specification. However, when using more exotic network structures, as this is the case here with the categorical state nodes where the number of nodes in the network grows with the number of categories, we need a more flexible approach that can let us wrap a PyMC distribution for every kind of network we can have. | ||
|
||
This is what we are doing below (see [this blog post](https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/) and the [PyMC documentation](https://www.pymc.io/projects/examples/en/latest/case_studies/wrapping_jax_function.html) on how to do that). First, we start by creating a function that computes the surprise of the model, here using the Kullback-Leibler divergences of the implied Dirichlet distributions. | ||
|
||
```{code-cell} ipython3 | ||
def categorical_surprise(omega_2, hgf, input_data): | ||
# replace with a new omega in the model | ||
for va_pa in hgf.edges[0].value_parents: | ||
for va_pa_va_pa in hgf.edges[va_pa].value_parents: | ||
for va_pa_va_pa_va_pa in hgf.edges[va_pa_va_pa].value_parents: | ||
hgf.attributes[va_pa_va_pa_va_pa]["omega"] = omega_2 | ||
# fit the model to new data | ||
hgf.input_data(input_data=input_data.T) | ||
# compute the surprises from KL divergences | ||
surprise = hgf.node_trajectories[0]["kl_divergence"][2:].sum() | ||
# return an infinite surprise if the model could not fit at any point | ||
surprise = jnp.where( | ||
jnp.any(jnp.isnan(hgf.node_trajectories[0]["xi"])), jnp.inf, surprise | ||
) | ||
return surprise | ||
surprise_fn = Partial(categorical_surprise, hgf=categorical_hgf, input_data=input_data) | ||
``` | ||
|
||
We create both jitted and the vector-jacobian product requiered for a custom Op in PyTensor: | ||
|
||
```{code-cell} ipython3 | ||
jitted_custom_op_jax = jit(surprise_fn) | ||
def vjp_custom_op_jax(x, gz): | ||
_, vjp_fn = vjp(surprise_fn, x) | ||
return vjp_fn(gz)[0] | ||
jitted_vjp_custom_op_jax = jit(vjp_custom_op_jax) | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
tags: [hide-cell] | ||
--- | ||
# The CustomOp needs `make_node`, `perform` and `grad`. | ||
class CustomOp(Op): | ||
def make_node(self, x): | ||
# Create a PyTensor node specifying the number and type of inputs and outputs | ||
# We convert the input into a PyTensor tensor variable | ||
inputs = [pt.as_tensor_variable(x)] | ||
# Output has the same type and shape as `x` | ||
outputs = [inputs[0].type()] | ||
return Apply(self, inputs, outputs) | ||
def perform(self, node, inputs, outputs): | ||
# Evaluate the Op result for a specific numerical input | ||
# The inputs are always wrapped in a list | ||
(x,) = inputs | ||
result = jitted_custom_op_jax(x) | ||
# The results should be assigned inplace to the nested list | ||
# of outputs provided by PyTensor. If you have multiple | ||
# outputs and results, you should assign each at outputs[i][0] | ||
outputs[0][0] = np.asarray(result, dtype="float64") | ||
def grad(self, inputs, output_gradients): | ||
# Create a PyTensor expression of the gradient | ||
(x,) = inputs | ||
(gz,) = output_gradients | ||
# We reference the VJP Op created below, which encapsulates | ||
# the gradient operation | ||
return [vjp_custom_op(x, gz)] | ||
class VJPCustomOp(Op): | ||
def make_node(self, x, gz): | ||
# Make sure the two inputs are tensor variables | ||
inputs = [pt.as_tensor_variable(x), pt.as_tensor_variable(gz)] | ||
# Output has the shape type and shape as the first input | ||
outputs = [inputs[0].type()] | ||
return Apply(self, inputs, outputs) | ||
def perform(self, node, inputs, outputs): | ||
(x, gz) = inputs | ||
result = jitted_vjp_custom_op_jax(x, gz) | ||
outputs[0][0] = np.asarray(result, dtype="float64") | ||
# Instantiate the Ops | ||
custom_op = CustomOp() | ||
vjp_custom_op = VJPCustomOp() | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
--- | ||
# with pm.Model() as model: | ||
# omega_2 = pm.Normal("omega_2", -2.0, 2) | ||
# pm.Potential("hgf", custom_op(omega_2)) | ||
# categorical_idata = pm.sample(chains=2) | ||
``` | ||
|
||
+++ {"editable": true, "slideshow": {"slide_type": ""}} | ||
|
||
## The categorical state-transition node | ||
|
||
```{warning} | ||
This is work in progress. | ||
``` | ||
|
||
+++ | ||
|
||
# System configuration | ||
|
||
```{code-cell} ipython3 | ||
--- | ||
editable: true | ||
slideshow: | ||
slide_type: '' | ||
--- | ||
%load_ext watermark | ||
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib | ||
``` |
Oops, something went wrong.