Planar inverse with leaky relu #170

Merged · 2 commits · Aug 8, 2024
102 changes: 81 additions & 21 deletions flowjax/bijections/planar.py
@@ -4,30 +4,40 @@
"""

from collections.abc import Callable
from typing import ClassVar
from functools import partial
from typing import ClassVar, Literal

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from jax.nn import softplus
from jax import nn
from jax.numpy.linalg import norm
from jaxtyping import Array, PRNGKeyArray
from jaxtyping import Array, Float, PRNGKeyArray

from flowjax.bijections.bijection import AbstractBijection


class Planar(AbstractBijection):
r"""Planar bijection as used by https://arxiv.org/pdf/1505.05770.pdf.

Uses the transformation :math:`y + u \cdot \text{tanh}(w \cdot x + b)`, where
:math:`u \in \mathbb{R}^D, \ w \in \mathbb{R}^D` and :math:`b \in \mathbb{R}`. In
the unconditional case, :math:`w`, :math:`u` and :math:`b` are learned directly.
In the conditional case they are parameterised by an MLP.
Uses the transformation

.. math::

\boldsymbol{y}=\boldsymbol{x} +
\boldsymbol{u} \cdot \text{tanh}(\boldsymbol{w}^T \boldsymbol{x} + b)

where :math:`\boldsymbol{u} \in \mathbb{R}^D, \ \boldsymbol{w} \in \mathbb{R}^D`
and :math:`b \in \mathbb{R}`. In the unconditional case, the (unbounded) parameters
are learned directly. In the conditional case they are parameterised by an MLP.

Args:
key: Jax random seed.
dim: Dimension of the bijection.
cond_dim: Dimension of extra conditioning variables. Defaults to None.
negative_slope: A positive float. If provided, then a leaky relu activation
(with the corresponding negative slope) is used instead of tanh. This also
provides the advantage that the bijection can be inverted analytically.
**mlp_kwargs: Keyword arguments (excluding in_size and out_size) passed to
the MLP (equinox.nn.MLP). Ignored when cond_dim is None.
"""
@@ -36,13 +46,15 @@ class Planar(AbstractBijection):
cond_shape: tuple[int, ...] | None
conditioner: Callable | None
params: Array | None
negative_slope: float | None

def __init__(
self,
key: PRNGKeyArray,
*,
dim: int,
cond_dim: int | None = None,
negative_slope: float | None = None,
**mlp_kwargs,
):
self.shape = (dim,)
@@ -56,6 +68,8 @@ def __init__(
self.conditioner = eqx.nn.MLP(cond_dim, 2 * dim + 1, **mlp_kwargs, key=key)
self.cond_shape = (cond_dim,)

self.negative_slope = negative_slope

def transform(self, x, condition=None):
return self.get_planar(condition).transform(x)

@@ -77,36 +91,59 @@ def get_planar(self, condition=None):
dim = self.shape[0]
assert params is not None
w, u, bias = params[:dim], params[dim : 2 * dim], params[-1]
return _UnconditionalPlanar(w, u, bias)
return _UnconditionalPlanar(w, u, bias, self.negative_slope)


class _UnconditionalPlanar(AbstractBijection):
"""Unconditional planar bijection, used in Planar.

Note act_scale (u in the paper) is unconstrained and the constraint to ensure
invertibility is applied in the ``get_act_scale``.
invertibility is applied in ``get_act_scale``.
"""

shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
weight: Array
_act_scale: Array
bias: Array
activation: Literal["tanh"] | Literal["leaky_relu"]
activation_fn: Callable
negative_slope: float | None

def __init__(self, weight, act_scale, bias):
def __init__(
self,
weight: Float[Array, " dim"],
act_scale: Float[Array, " dim"],
bias: Float[Array, " "],
negative_slope: float | None = None,
):
self.weight = weight
self._act_scale = act_scale
self.bias = bias
self.shape = weight.shape
self.negative_slope = negative_slope
self._act_scale = act_scale

if negative_slope is None:
self.activation = "tanh"
self.activation_fn = jnp.tanh
else:
if negative_slope <= 0:
raise ValueError("The negative slope value should be >0.")
self.activation = "leaky_relu"
self.activation_fn = partial(nn.leaky_relu, negative_slope=negative_slope)

def transform(self, x, condition=None):
return x + self.get_act_scale() * jnp.tanh(self.weight @ x + self.bias)
u = self.get_act_scale()
return x + u * self.activation_fn(self.weight @ x + self.bias)

def transform_and_log_det(self, x, condition=None):
u = self.get_act_scale()
act = jnp.tanh(x @ self.weight + self.bias)
act = self.activation_fn(x @ self.weight + self.bias)
y = x + u * act
psi = (1 - act**2) * self.weight
if self.activation == "leaky_relu":
psi = jnp.where(act < 0, self.negative_slope, 1) * self.weight
else:
psi = (1 - act**2) * self.weight
log_det = jnp.log(jnp.abs(1 + u @ psi))
return y, log_det

@@ -116,15 +153,38 @@ def get_act_scale(self):
See appendix A1 in https://arxiv.org/pdf/1505.05770.pdf.
"""
wtu = self._act_scale @ self.weight
m_wtu = -1 + jnp.log(1 + softplus(wtu))
m_wtu = -1 + jnp.log(1 + nn.softplus(wtu))
return self._act_scale + (m_wtu - wtu) * self.weight / norm(self.weight) ** 2

def inverse(self, y, condition=None):
raise NotImplementedError(
"The inverse planar transformation is not implemented.",
)
if self.activation != "leaky_relu":
raise NotImplementedError(
"The inverse planar transformation is only implemented with the leaky "
"relu activation function.",
)
return self.inverse_and_log_det(y, condition)[0]

def inverse_and_log_det(self, y, condition=None):
raise NotImplementedError(
"The inverse planar transformation is not implemented.",
)
if self.activation != "leaky_relu":
raise NotImplementedError(
"The inverse planar transformation is only implemented with the leaky "
"relu activation function.",
)
# Expanded explanation, as the inverse is not given in the original paper.
# The derivation steps for the inversion are:
# 1. Let z = w^Tx+b
# 2. We want x=y-uσ(z), where σ is the leaky relu function.
# 3. Sub x=y-uσ(z) into z = w^Tx+b,
# 4. Solve for z, which gives z = (w^Ty+b)/(1+w^Tus), where s is the slope
# σ'(z), i.e. s=1 if z>=0 and s=negative_slope otherwise. To find the
# slope, it is sufficient to check the sign of the numerator w^Ty+b, rather
# than z, as the denominator is constrained to be positive.
# 5. Compute inverse using x=y-uσ(z)

numerator = self.weight @ y + self.bias
relu_slope = jnp.where(numerator < 0, self.negative_slope, 1)
us = self.get_act_scale() * relu_slope
denominator = 1 + self.weight @ us
log_det = -jnp.log(jnp.abs(1 + us @ self.weight))
x = y - us * (numerator / denominator)
return x, log_det
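
For orientation, here is a minimal usage sketch of the leaky-relu variant added above. The values are illustrative, and it assumes `Planar` is exported from `flowjax.bijections` with the usual `transform_and_log_det`/`inverse_and_log_det` bijection API; none of this is part of the diff itself. It follows the derivation in the comments: the analytic inverse should reconstruct the input, and the forward and inverse log-determinants should agree up to sign.

```python
# Sketch (not part of this PR): exercising the analytic inverse with leaky relu.
import jax.numpy as jnp
import jax.random as jr

from flowjax.bijections import Planar  # assumed public export

key = jr.PRNGKey(0)
bijection = Planar(key, dim=3, negative_slope=0.1)  # leaky relu instead of tanh

x = jnp.array([0.5, -1.2, 2.0])
y, fwd_log_det = bijection.transform_and_log_det(x)

# With negative_slope set, inverse/inverse_and_log_det no longer raise
# NotImplementedError; z = (w^T y + b) / (1 + w^T u s) is solved in closed form.
x_recon, inv_log_det = bijection.inverse_and_log_det(y)

assert jnp.allclose(x, x_recon, atol=1e-5)
assert jnp.allclose(fwd_log_det, -inv_log_det, atol=1e-5)
```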
5 changes: 5 additions & 0 deletions flowjax/flows.py
@@ -226,6 +226,7 @@ def planar_flow(
cond_dim: int | None = None,
flow_layers: int = 8,
invert: bool = True,
negative_slope: float | None = None,
**mlp_kwargs,
) -> Transformed:
"""Planar flow as introduced in https://arxiv.org/pdf/1505.05770.pdf.
@@ -241,6 +242,9 @@
invert: Whether to invert the bijection. Broadly, True will prioritise a faster
`inverse` methods, leading to faster `log_prob`, False will prioritise
faster `transform` methods, leading to faster `sample`. Defaults to True.
negative_slope: A positive float. If provided, then a leaky relu activation
(with the corresponding negative slope) is used instead of tanh. This also
provides the advantage that the bijection can be inverted analytically.
**mlp_kwargs: Keyword arguments (excluding in_size and out_size) passed to
the MLP (equinox.nn.MLP). Ignored when cond_dim is None.
"""
@@ -251,6 +255,7 @@ def make_layer(key): # Planar layer + permutation
bij_key,
dim=base_dist.shape[-1],
cond_dim=cond_dim,
negative_slope=negative_slope,
**mlp_kwargs,
)
return _add_default_permute(bijection, base_dist.shape[-1], perm_key)
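
And a rough sketch at the flow level (hypothetical values; `Normal` from `flowjax.distributions` and the `sample`/`log_prob` API are assumed rather than shown in this diff): with `negative_slope` passed through, each planar layer is analytically invertible, so a planar flow built with the default `invert=True` can now also sample, not just evaluate `log_prob`.

```python
# Sketch (not part of this PR): building a planar flow with leaky-relu layers.
import jax.numpy as jnp
import jax.random as jr

from flowjax.distributions import Normal  # assumed available
from flowjax.flows import planar_flow

flow_key, sample_key = jr.split(jr.PRNGKey(0))
flow = planar_flow(
    flow_key,
    base_dist=Normal(jnp.zeros(2)),
    flow_layers=4,
    negative_slope=0.1,  # leaky relu: every layer has an analytic inverse
)

# Previously, sampling with invert=True hit the NotImplementedError in
# Planar.inverse; with leaky relu both directions are available.
samples = flow.sample(sample_key, (100,))
log_probs = flow.log_prob(samples)
```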
13 changes: 12 additions & 1 deletion tests/test_bijections/test_bijections.py
Expand Up @@ -35,6 +35,7 @@
TriangularAffine,
Vmap,
)
from flowjax.bijections.planar import _UnconditionalPlanar

DIM = 3
COND_DIM = 2
@@ -171,6 +172,16 @@
[Affine(jr.uniform(k, (1, 2, 3))) for k in jr.split(KEY, 3)],
axis=-1,
),
"_UnconditionalPlanar (leaky_relu +ve bias)": lambda: _UnconditionalPlanar(
*jnp.split(jr.normal(KEY, (8,)), 2),
bias=jnp.array(100.0), # leads to evaluation in +ve relu portion
negative_slope=0.1,
),
"_UnconditionalPlanar (leaky_relu -ve bias)": lambda: _UnconditionalPlanar(
*jnp.split(jr.normal(KEY, (8,)), 2),
bias=-jnp.array(100.0), # leads to evaluation in -ve relu portion
negative_slope=0.1,
),
"Planar": lambda: Planar(
KEY,
dim=DIM,
@@ -209,7 +220,7 @@ def test_transform_inverse(bijection_name):
y = bijection.transform(x, cond)
try:
x_reconstructed = bijection.inverse(y, cond)
assert x == pytest.approx(x_reconstructed, abs=1e-4)
assert x_reconstructed == pytest.approx(x, abs=1e-4)
except NotImplementedError:
pass
