add support for categorical state nodes (#83)
* update tutorials

* add a fill_categorical_state_node function and a method to add a categorical state node

* updated plotting functions for categorical nodes

* move the get_update_sequence function to the network submodule

* clever update sequence

* do not automatically create an update sequence when providing inputs if there is none

* proper order for categorical input updates

* surprise in the categorical node

* move the to_pandas method to the network submodule

* fix performance issues and correct the handling of surprise in to_pandas

* create an init method to cache the update function before providing data

* fix an error in the get updates function

* fix error in expectation in the categorical node

* fix the order of update/predictions at the network level

* update pre-commit

* add Kullback-Leibler divergence

* docs
LegrandNico authored Aug 30, 2023
1 parent ba2b36c commit c8c7dbd
Showing 38 changed files with 7,529 additions and 5,078 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
@@ -29,6 +29,8 @@ jobs:
      - name: Install dependencies
        run: |
          sudo apt-get install graphviz
          pip install jax==0.4.14
          pip install jaxlib==0.4.14
          pip install -r requirements-tests.txt
          pip install ipykernel coverage pytest pytest-cov
          python -m ipykernel install --user --name python3
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -9,12 +9,12 @@ repos:
    hooks:
      - id: isort
  - repo: https://github.com/ambv/black
    rev: 23.1.0
    rev: 23.7.0
    hooks:
      - id: black
        language_version: python3
  - repo: https://github.com/pycqa/flake8
    rev: 6.0.0
    rev: 6.1.0
    hooks:
      - id: flake8
  - repo: https://github.com/pycqa/pydocstyle
@@ -24,7 +24,7 @@ repos:
      args: ['--ignore', 'D213,D100,D203,D104']
      files: ^pyhgf/
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: 'v1.1.1'
    rev: 'v1.5.1'
    hooks:
      - id: mypy
        files: ^pyhgf/
2 changes: 2 additions & 0 deletions docs/source/api.rst
@@ -116,3 +116,5 @@ Utilities for manipulating networks of probabilistic nodes.
   beliefs_propagation
   trim_sequence
   list_branches
   fill_categorical_state_node
   get_update_sequence
308 changes: 177 additions & 131 deletions docs/source/notebooks/0-Creating_networks.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions docs/source/notebooks/0-Creating_networks.md
@@ -196,6 +196,7 @@ many_value_children_hgf = (
    .add_input_node(kind="continuous", input_idxs=1)
    .add_value_parent(children_idxs=[0, 1])
    .add_volatility_parent(children_idxs=[2])
    .init()
)
# plot the network
@@ -264,6 +265,7 @@ many_volatility_children_hgf = (
    .add_value_parent(children_idxs=[0])
    .add_value_parent(children_idxs=[1])
    .add_volatility_parent(children_idxs=[2, 3])
    .init()
)
# plot the network
@@ -328,6 +330,7 @@ many_binary_children_hgf = (
    .add_value_parent(children_idxs=[1])
    .add_value_parent(children_idxs=[2, 3])
    .add_volatility_parent(children_idxs=[4])
    .init()
)
# plot the network
1,317 changes: 0 additions & 1,317 deletions docs/source/notebooks/1-Binary_HGF.ipynb

This file was deleted.

1,333 changes: 1,333 additions & 0 deletions docs/source/notebooks/1.1-Binary_HGF.ipynb

Large diffs are not rendered by default.

docs/source/notebooks/1.1-Binary_HGF.md
@@ -178,6 +178,11 @@ The data is being passed to the distribution when the instance is created.
```

```{code-cell} ipython3
---
editable: true
slideshow:
  slide_type: ''
---
with pm.Model() as two_levels_binary_hgf:
    omega_2 = pm.Uniform("omega_2", -3.5, 0.0)
1,009 changes: 1,009 additions & 0 deletions docs/source/notebooks/1.2-Categorical_HGF.ipynb

Large diffs are not rendered by default.

312 changes: 312 additions & 0 deletions docs/source/notebooks/1.2-Categorical_HGF.md
@@ -0,0 +1,312 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.7
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

(categorical_hgf)=
# The categorical Hierarchical Gaussian Filter
```{warning}
The categorical state node and the categorical state-transition node are still work in progress. The examples provided here are for illustration only; they may change, or may not work, until the official release.
```

```{code-cell} ipython3
:tags: [hide-cell]

%%capture
import sys
if 'google.colab' in sys.modules:
    ! pip install pyhgf
```

```{code-cell} ipython3
from pyhgf.model import HGF
import numpy as np
from pyhgf import load_data
import jax.numpy as jnp
from jax import jit, grad, vjp
from jax.tree_util import Partial
import matplotlib.pyplot as plt
from pyhgf.plots import plot_nodes
import seaborn as sns
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
```

The binary state nodes introduced in the previous section are useful for encoding stochastic Boolean variables, which are common in reinforcement learning and decision-making designs. However, discrete variables can have more than two categories, in which case they need to be encoded by a categorical distribution. Here, we introduce two probabilistic nodes tailored to handle this kind of variable: the **categorical state node** and the **categorical state-transition node**.

Both nodes are generalisations of the binary HGF, in the sense that they internally represent a collection of binary state nodes. We use **categorical HGF** in a broad sense for any HGF that can handle categorical distributions; as we will illustrate below, there are several ways to do this, so it is more precise to refer to the kind of node used internally (the **categorical state node** or the **categorical state-transition node**).
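
To make the "collection of binary state nodes" intuition concrete, a categorical observation over $K$ categories can be written as a one-hot vector of $K$ binary observations, one per underlying binary state node. The snippet below is a minimal sketch with illustrative values (it is not part of the tutorial's dataset):

```{code-cell} ipython3
import numpy as np

# a categorical observation among K=4 categories...
observation = 2

# ...maps to K binary observations, one per binary state node
one_hot = np.eye(4)[observation]
one_hot  # array([0., 0., 1., 0.])
```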

+++

## Simulating a dataset
We start by simulating a dataset to which we can apply the categorical HGFs. The dataset consists of a categorical input with $K=10$ categories. The underlying contingencies are generated by three Dirichlet distributions, from each of which we sequentially draw 250 observations (750 in total).

```{code-cell} ipython3
# generate some categorical input data using three underlying distributions
p1 = np.random.dirichlet(alpha=[1, 2, 3, 5, 9, 13, 17, 25, 30, 35])
p2 = np.random.dirichlet(alpha=[1, 2, 3, 5, 30, 30, 5, 3, 2, 1])
p3 = np.random.dirichlet(alpha=[35, 30, 25, 17, 13, 9, 5, 3, 2, 1])
input_data = np.array(
    [np.random.multinomial(n=1, pvals=p) for p in [p1, p2, p3] for _ in range(250)]
).T
```

The Dirichlet distributions are parametrized so that the contingencies move from one skewed distribution, to a centred distribution, to another, oppositely skewed distribution. The resulting sequence of categorical observations looks like this:

```{code-cell} ipython3
plt.figure(figsize=(12, 3))
plt.imshow(input_data, interpolation="none", aspect="auto", cmap="binary")
plt.ylabel("Categories")
plt.xlabel("Time")
plt.title("Categorical observations");
```

```{note}
In the results figure further below, the lower panels represent the surprise associated with the categorical node. Here, we use the [Kullback-Leibler divergence between two Dirichlet distributions](https://statproofbook.github.io/P/dir-kl.html) as a measure of Bayesian surprise. The Kullback-Leibler divergence of the Dirichlet distribution $P$ from the Dirichlet distribution $Q$ is given by the following equation:
$$
KL[P||Q] = \ln{\frac{\Gamma(\sum_{i=1}^k\alpha_{1i})}{\Gamma(\sum_{i=1}^k\alpha_{2i})}} + \sum_{i=1}^k \ln{\frac{\Gamma(\alpha_{2i})}{\Gamma(\alpha_{1i})}} + \sum_{i=1}^k(\alpha_{1i} - \alpha_{2i}) \left[ \psi(\alpha_{1i}) - \psi(\sum_{i=1}^k \alpha_{1i}) \right]
$$
```
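
For reference, this divergence is straightforward to write in JAX. The sketch below is a hypothetical helper (not pyhgf's internal implementation), using `gammaln` and `digamma` from `jax.scipy.special`:

```{code-cell} ipython3
import jax.numpy as jnp
from jax.scipy.special import digamma, gammaln


def dirichlet_kl(alpha_1, alpha_2):
    """KL[P||Q] between Dirichlet(alpha_1) and Dirichlet(alpha_2)."""
    a1_sum = alpha_1.sum()
    return (
        gammaln(a1_sum)
        - gammaln(alpha_2.sum())
        + jnp.sum(gammaln(alpha_2) - gammaln(alpha_1))
        + jnp.sum((alpha_1 - alpha_2) * (digamma(alpha_1) - digamma(a1_sum)))
    )


# e.g. the divergence between two illustrative concentration vectors
dirichlet_kl(jnp.array([1.0, 2.0, 3.0]), jnp.array([3.0, 2.0, 1.0]))
```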

```{code-cell} ipython3
# add a blank input time series for the categorical state node:
# the categorical state node does not receive observations directly;
# only the binary nodes are the actual inputs of the network
input_data = np.vstack([[0.0] * input_data.shape[1], input_data])
```
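
As a quick sanity check of the assumed layout, the stacked array now has one (blank) row for the categorical input plus one row per category:

```{code-cell} ipython3
# 1 blank categorical input row + 10 binary rows, over 3 x 250 observations
input_data.shape  # (11, 750)
```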

## The categorical state node

+++

### Creating the probabilistic network

```{code-cell} ipython3
categorical_hgf = (
    HGF(model_type=None, verbose=False)
    .add_input_node(
        kind="categorical",
        categorical_parameters={"n_categories": 10},
        binary_parameters={"omega_2": -2.0},
    )
    .init()
)
```

```{code-cell} ipython3
categorical_hgf.plot_network()
```

### Fitting the model forwards

```{code-cell} ipython3
categorical_hgf.input_data(input_data=input_data.T);
```
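
The belief trajectories can also be exported to a data frame for inspection. This is a usage sketch, assuming the `to_pandas` method is available on the fitted model:

```{code-cell} ipython3
# export the node trajectories to a pandas.DataFrame for quick inspection
categorical_hgf.to_pandas().head()
```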

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
tags: [hide-cell]
---
fig, axs = plt.subplots(nrows=5, figsize=(12, 9), sharex=True)

# belief trajectories for one of the nodes in the implied binary network
plot_nodes(categorical_hgf, node_idxs=31, axs=axs[0])

# mean of the implied Dirichlet distribution
axs[1].imshow(
    categorical_hgf.node_trajectories[0]["mu"].T, interpolation="none", aspect="auto"
)
axs[1].set_title("Mean of the implied Dirichlet distribution", loc="left")
axs[1].set_ylabel("Categories")

# observations
axs[2].imshow(input_data, interpolation="none", aspect="auto", cmap="binary")
axs[2].set_title("Observations", loc="left")
axs[2].set_ylabel("Categories")

# surprise (Kullback-Leibler divergences)
axs[3].plot(
    np.arange(2, len(categorical_hgf.node_trajectories[0]["kl_divergence"])),
    categorical_hgf.node_trajectories[0]["kl_divergence"][2:],
    color="#2a2a2a",
    linewidth=0.5,
    zorder=-1,
    label="Surprise",
)
axs[3].fill_between(
    x=np.arange(2, len(categorical_hgf.node_trajectories[0]["kl_divergence"])),
    y1=categorical_hgf.node_trajectories[0]["kl_divergence"][2:],
    y2=categorical_hgf.node_trajectories[0]["kl_divergence"][2:].min(),
    color="#7f7f7f",
    alpha=0.1,
    zorder=-1,
)
axs[3].set_title("Kullback-Leibler divergences", loc="left")
axs[3].set_ylabel("Surprise")

# surprise (sum of the binary surprises)
axs[4].plot(
    np.arange(len(categorical_hgf.node_trajectories[0]["binary_surprise"])),
    categorical_hgf.node_trajectories[0]["binary_surprise"],
    color="#2a2a2a",
    linewidth=0.5,
    zorder=-1,
    label="Surprise",
)
axs[4].fill_between(
    x=np.arange(len(categorical_hgf.node_trajectories[0]["binary_surprise"])),
    y1=categorical_hgf.node_trajectories[0]["binary_surprise"],
    y2=categorical_hgf.node_trajectories[0]["binary_surprise"].min(),
    color="#7f7f7f",
    alpha=0.1,
    zorder=-1,
)
axs[4].set_title("Sum of binary surprises", loc="left")
axs[4].set_ylabel("Surprise")

sns.despine()
plt.tight_layout()
```

### Inference using MCMC sampling

+++

In the binary and continuous HGF examples, we used the {py:class}`pyhgf.distribution.HGFDistribution` class to create a PyMC-compatible distribution of the HGF. This works for the most standard models, where a pre-defined distribution exactly matches the network specification. However, for more exotic network structures, as is the case here with categorical state nodes (where the number of nodes in the network grows with the number of categories), we need a more flexible approach that lets us wrap a PyMC distribution around any kind of network.

This is what we do below (see [this blog post](https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/) and the [PyMC documentation](https://www.pymc.io/projects/examples/en/latest/case_studies/wrapping_jax_function.html) for details). We start by creating a function that computes the surprise of the model, here using the Kullback-Leibler divergences of the implied Dirichlet distributions.

```{code-cell} ipython3
def categorical_surprise(omega_2, hgf, input_data):
    # replace with a new omega in the model
    for va_pa in hgf.edges[0].value_parents:
        for va_pa_va_pa in hgf.edges[va_pa].value_parents:
            for va_pa_va_pa_va_pa in hgf.edges[va_pa_va_pa].value_parents:
                hgf.attributes[va_pa_va_pa_va_pa]["omega"] = omega_2

    # fit the model to new data
    hgf.input_data(input_data=input_data.T)

    # compute the surprises from KL divergences
    surprise = hgf.node_trajectories[0]["kl_divergence"][2:].sum()

    # return an infinite surprise if the model could not fit at any point
    surprise = jnp.where(
        jnp.any(jnp.isnan(hgf.node_trajectories[0]["xi"])), jnp.inf, surprise
    )

    return surprise


surprise_fn = Partial(categorical_surprise, hgf=categorical_hgf, input_data=input_data)
```

We create both the jitted function and the vector-Jacobian product (VJP) required to build a custom Op in PyTensor:

```{code-cell} ipython3
jitted_custom_op_jax = jit(surprise_fn)


def vjp_custom_op_jax(x, gz):
    _, vjp_fn = vjp(surprise_fn, x)
    return vjp_fn(gz)[0]


jitted_vjp_custom_op_jax = jit(vjp_custom_op_jax)
```

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
tags: [hide-cell]
---
# The CustomOp needs `make_node`, `perform` and `grad`.
class CustomOp(Op):
    def make_node(self, x):
        # Create a PyTensor node specifying the number and type of inputs and outputs
        # We convert the input into a PyTensor tensor variable
        inputs = [pt.as_tensor_variable(x)]
        # Output has the same type and shape as `x`
        outputs = [inputs[0].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        # Evaluate the Op result for a specific numerical input
        # The inputs are always wrapped in a list
        (x,) = inputs
        result = jitted_custom_op_jax(x)
        # The results should be assigned in place to the nested list
        # of outputs provided by PyTensor. If you have multiple
        # outputs and results, you should assign each at outputs[i][0]
        outputs[0][0] = np.asarray(result, dtype="float64")

    def grad(self, inputs, output_gradients):
        # Create a PyTensor expression of the gradient
        (x,) = inputs
        (gz,) = output_gradients
        # We reference the VJP Op created below, which encapsulates
        # the gradient operation
        return [vjp_custom_op(x, gz)]


class VJPCustomOp(Op):
    def make_node(self, x, gz):
        # Make sure the two inputs are tensor variables
        inputs = [pt.as_tensor_variable(x), pt.as_tensor_variable(gz)]
        # Output has the same type and shape as the first input
        outputs = [inputs[0].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (x, gz) = inputs
        result = jitted_vjp_custom_op_jax(x, gz)
        outputs[0][0] = np.asarray(result, dtype="float64")


# Instantiate the Ops
custom_op = CustomOp()
vjp_custom_op = VJPCustomOp()
```
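
We can quickly check the wiring of the custom Op by evaluating it at a candidate parameter value. This is an illustrative sanity check only; it assumes the jitted surprise function evaluates without error on this network:

```{code-cell} ipython3
# evaluate the surprise Op at omega_2 = -2.0 (illustrative check only)
omega_2 = pt.scalar("omega_2")
surprise = custom_op(omega_2)
surprise.eval({omega_2: -2.0})
```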

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
# Sampling is left commented out while the categorical nodes are work in progress
# with pm.Model() as model:
#     omega_2 = pm.Normal("omega_2", -2.0, 2)
#     pm.Potential("hgf", custom_op(omega_2))
#     categorical_idata = pm.sample(chains=2)
```

+++ {"editable": true, "slideshow": {"slide_type": ""}}

## The categorical state-transition node

```{warning}
This is work in progress.
```

+++

# System configuration

```{code-cell} ipython3
---
editable: true
slideshow:
slide_type: ''
---
%load_ext watermark
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib
```
