239 changes: 190 additions & 49 deletions diffrax/misc/sde_kl_divergence.py
@@ -1,74 +1,215 @@
import operator
from typing import Any, Tuple

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from ..brownian import AbstractBrownianPath
from ..custom_types import PyTree
from ..custom_types import Array, PyTree, Scalar
from ..term import (
AbstractTerm,
ControlTerm,
MultiTerm,
ODETerm,
WeaklyDiagonalControlTerm,
)


def _kl_diagonal(drift: Array, diffusion: Array):
"""This is the case where diffusion matrix is
a diagonal matrix
"""
diffusion = jnp.where(
jax.lax.stop_gradient(diffusion) > 1e-7,
diffusion,
jnp.full_like(diffusion, fill_value=1e-7) * jnp.sign(diffusion),
)
scale = drift / diffusion
return 0.5 * jnp.sum(scale**2)


def _kl(drift1, drift2, diffusion):
inv_diffusion = jnp.linalg.pinv(diffusion)
scale = inv_diffusion @ (drift1 - drift2)
def _kl_full_matrix(drift: Array, diffusion: Array):
"""General case"""
scale = jnp.linalg.pinv(diffusion) @ drift
return 0.5 * jnp.sum(scale**2)
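
# For reference: both helpers evaluate the integrand of the Girsanov-style KL
# formula for two SDEs sharing the same diffusion sigma,
#     KL = E[ 0.5 * \int || sigma(t, y)^+ (mu1 - mu2)(t, y) ||^2 dt ],
# where `drift` is already the difference mu1 - mu2 (computed in `_AugDrift.vf`
# below) and sigma^+ denotes the (pseudo-)inverse of the diffusion:
# `_kl_diagonal` divides elementwise, `_kl_full_matrix` uses `jnp.linalg.pinv`.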


class _AugDrift(eqx.Module):
drift1: callable
drift2: callable
diffusion: callable
context: callable
def _assert_array(x: Any):
assert isinstance(
x, jnp.ndarray
), "`sde_kl_divergence` can only handle array-valued drifts and diffusions"


def _handle(drift: Array, diffusion: Array):
"""According to the shape of drift and diffusion,
select the right way to compute KL divergence
"""
_assert_array(drift)
_assert_array(diffusion)
if drift.shape == diffusion.shape:
return _kl_diagonal(drift, diffusion)
else:
return _kl_full_matrix(drift, diffusion)
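
# Shape-based dispatch, for example:
#     _handle(drift=jnp.zeros((3,)), diffusion=jnp.ones((3,)))    -> _kl_diagonal
#     _handle(drift=jnp.zeros((3,)), diffusion=jnp.ones((3, 4)))  -> _kl_full_matrix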


def _kl_block_diffusion(drift: PyTree, diffusion: PyTree):
"""The case where diffusion matrix is a block diagonal matrix"""
kl = jtu.tree_map(
_handle,
drift,
diffusion,
)

kl = jtu.tree_reduce(
operator.add,
kl,
)
return kl


class _AugDrift(AbstractTerm):

drift1: ODETerm
drift2: ODETerm
diffusion: AbstractTerm

def vf(self, t: Scalar, y: PyTree, args) -> PyTree:
# In this implementation we restrict ourselves to the case where the
# diffusion may be a block-diagonal matrix. Each block can follow a
# different `vf_prod`:
# - PyTree of drift: (*, *, ..., *)
# - PyTree of diffusion: (*, *, ..., *)
# For example,
# - output of drift can be
# drift = {"block1": jnp.zeros((2,)),
# "block2": jnp.zeros((2,)),
# "block3": jnp.zeros((3,))}
# - output of diffusion (which mixes between the two types)
# diffusion = {"block1": jnp.ones((2,)), #-> WeaklyDiagonal
# "block2": jnp.ones((2, 3)), #-> General case
# "block3": jnp.ones((3, 4))} #-> General case
#
# NOTE: here `args` is expected to carry the `context` as a function
# (whereas `args` is normally an arbitrary PyTree)

def __call__(self, t, y, args):
y, _ = y
context = self.context(t)
aug_y = jnp.concatenate([y, context], axis=-1)
drift1 = self.drift1(t, aug_y, args)
drift2 = self.drift2(t, y, args)
diffusion = self.diffusion(t, y, args)
kl_divergence = jtu.tree_map(_kl, drift1, drift2, diffusion)
kl_divergence = jtu.tree_reduce(operator.add, kl_divergence)

# check if there is context
context = args
aug_y = y if context is None else jnp.concatenate([y, context(t)], axis=-1)

drift1 = self.drift1.vf(t, aug_y, args)
drift2 = self.drift2.vf(t, y, args)

drift = jtu.tree_map(operator.sub, drift1, drift2)
diffusion = self.diffusion.vf(t, y, args)

# get tree structure of drift and diffusion
drift_tree_structure = jtu.tree_structure(drift)
diffusion_tree_structure = jtu.tree_structure(diffusion)

if drift_tree_structure == diffusion_tree_structure:
# drift and diffusion have the same tree structure;
# check the shapes to determine how to compute the KL term
# (at this point only the tree structure is checked, not the term type)

if isinstance(drift, jnp.ndarray):
# in this case the PyTree is a single array: (*)

# here we dispatch on the concrete type of the control term
if isinstance(self.diffusion, WeaklyDiagonalControlTerm):
# diffusion must be a jnp.ndarray as well, because
# diffusion and drift have the same tree structure,
# so we don't need to check the type of diffusion here
kl_divergence = _kl_diagonal(drift, diffusion)
elif isinstance(self.diffusion, ControlTerm):
kl_divergence = _kl_full_matrix(drift, diffusion)
else:
# the more general case: we assume that, on each leaf,
# if drift and diffusion have the same shape
# -> WeaklyDiagonalControlTerm
# else
# -> ControlTerm
kl_divergence = _kl_block_diffusion(drift, diffusion)
else:
raise ValueError(
"drift and diffusion should have the same PyTree structure"
+ f" \n {drift_tree_structure} != {diffusion_tree_structure}"
)
return drift1, kl_divergence

@staticmethod
def contr(t0: Scalar, t1: Scalar) -> Scalar:
return t1 - t0

class _AugDiffusion(eqx.Module):
diffusion: callable
@staticmethod
def prod(vf: PyTree, control: Scalar) -> PyTree:
return jtu.tree_map(lambda v: control * v, vf)

def __call__(self, t, y, args):
y, _ = y
diffusion = self.diffusion(t, y, args)
return diffusion, 0.0

class _AugControlTerm(AbstractTerm):

control_term: AbstractTerm

def __init__(self, term: AbstractTerm) -> None:
self.control_term = term

class _AugBrownianPath(eqx.Module):
bm: AbstractBrownianPath
def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
y, _ = y
vf = self.control_term.vf(t, y, args)
return vf, 0.0

@property
def t0(self):
return self.bm.t0
def contr(self, t0: Scalar, t1: Scalar) -> PyTree:
return self.control_term.contr(t0, t1), 0.0

@property
def t1(self):
return self.bm.t1
def vf_prod(self, t: Scalar, y: PyTree, args: PyTree, control: PyTree) -> PyTree:
y, _ = y
return self.control_term.vf_prod(t, y, args, control), 0.0

def evaluate(self, t0, t1):
return self.bm.evaluate(t0, t1), 0.0
def prod(self, vf: PyTree, control: PyTree) -> PyTree:
vf, _ = vf
control, _ = control
return self.control_term.prod(vf, control), 0.0


def sde_kl_divergence(
*,
drift1: callable,
drift2: callable,
diffusion: callable,
context: callable,
y0: PyTree,
bm: AbstractBrownianPath,
):
drift1: ODETerm, drift2: ODETerm, diffusion: AbstractTerm, y0: PyTree
) -> Tuple[MultiTerm, PyTree]:
"""
Compute the KL divergence between two SDEs sharing the same diffusion.

This function currently supports the case where the outputs of
`drift1`, `drift2` and `diffusion` have the same PyTree structure.

This generalizes to the case where the diffusion matrix is block
diagonal. Each block may use a different matrix-vector
multiplication; `diffusion` should be implemented on top of
`diffrax.ControlTerm` and specify how such multiplications are
performed. The associated Brownian path may need to be customized as well.

The following example is acceptable:

a = drift1(t, y, args)
jax.tree_structure(a) # PyTreeDef({'block1': *, 'block2': *})
a['block1'].shape # (2,)
a['block2'].shape # (3,)

b = diffusion(t, y, args)
jax.tree_structure(b) # PyTreeDef({'block1': *, 'block2': *})
b['block1'].shape # (2,)
b['block2'].shape # (3,4)

Args:
drift1 (ODETerm): the drift of the first SDE (posterior)
drift2 (ODETerm): the drift of the second SDE (prior)
diffusion (AbstractTerm): the shared diffusion
y0 (PyTree): initial state
Returns:
An augmented SDE with KL information and the augmented initial state
"""

aug_y0 = (y0, 0.0)
return (
_AugDrift(drift1, drift2, diffusion, context),
_AugDiffusion(diffusion),
aug_y0,
_AugBrownianPath(bm),
)
aug_drift = _AugDrift(drift1, drift2, diffusion)
aug_control = _AugControlTerm(diffusion)
aug_sde = MultiTerm(aug_drift, aug_control)
return aug_sde, aug_y0
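
A minimal usage sketch of the new API (not part of this PR; the toy drift and diffusion functions, the Euler solver and step size, and the direct module import are illustrative assumptions):

import jax.numpy as jnp
import jax.random as jr

import diffrax
from diffrax.misc.sde_kl_divergence import sde_kl_divergence

t0, t1, dt0 = 0.0, 1.0, 0.01
y0 = jnp.zeros((2,))

# Posterior drift, prior drift, and a shared weakly-diagonal diffusion.
drift1 = diffrax.ODETerm(lambda t, y, args: -y)
drift2 = diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y))
bm = diffrax.VirtualBrownianTree(t0, t1, tol=1e-3, shape=(2,), key=jr.PRNGKey(0))
diffusion = diffrax.WeaklyDiagonalControlTerm(lambda t, y, args: 0.5 * jnp.ones_like(y), bm)

# Build the augmented SDE and initial state, then solve as usual.
aug_sde, aug_y0 = sde_kl_divergence(drift1=drift1, drift2=drift2, diffusion=diffusion, y0=y0)
sol = diffrax.diffeqsolve(aug_sde, diffrax.Euler(), t0, t1, dt0, aug_y0, args=None)
ys, kl = sol.ys  # kl[-1] is the accumulated KL divergence at t1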
6 changes: 3 additions & 3 deletions examples/neural_sde.ipynb → examples/neural_sde_gan.ipynb
