Skip to content

Commit

Permalink
Bijective tensors for caching intermediate values (#1334)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch/beanmachine#1334

### Motivation
As described in #88, we'd wish to have a way of caching intermediate values computed in the flow.

### Changes proposed
This PR assigns this responsibility to a new class, BijectiveTensor.
A BijectiveTensor keeps track of the layer that has created it, the original tensor and whether it comes from a call to 'forward' or 'inverse'.
It inherits from torch.Tensor. By default, an operation on a BijectiveTensor returns a torch.Tensor (except if this operation is a Bijector).
One can control if BijectiveTensors should be used (which is the case by default) with the context manager `set_record_flow_graph`.

Pull Request resolved: #89

Test Plan:
A test file can be found in `test/test_bijectivetensor.py`.

### Types of changes
- [ ] Docs change / refactoring / dependency upgrade
- [ ] Bug fix (non-breaking change which fixes an issue)
- [X] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)

### Checklist
- [X] My code follows the code style of this project.
- [X] My change requires a change to the documentation.
- [ ] I have updated the documentation accordingly.
- [X] I have read the **[CONTRIBUTING](https://github.com/facebookincubator/flowtorch/blob/main/CONTRIBUTING.md)** document.
- [X] I have added tests to cover my changes.
- [X] All new and existing tests passed.
- [X] The title of my pull request is a short description of the requested changes.

Reviewed By: ToddSmall

Differential Revision: D33927956

Pulled By: stefanwebb

fbshipit-source-id: 2f3e53efb69f5839cb0649bc0d834cc7a0503553
  • Loading branch information
vmoens authored and facebook-github-bot committed Feb 3, 2022
1 parent 64a3799 commit 45e91ca
Show file tree
Hide file tree
Showing 34 changed files with 712 additions and 282 deletions.
2 changes: 1 addition & 1 deletion examples/learn_bivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
def learn_bivariate_normal() -> None:
# Lazily instantiated flow plus base and target distributions
bijectors = bij.AffineAutoregressive(
params=params.DenseAutoregressive(hidden_dims=(32,))
params_fn=params.DenseAutoregressive(hidden_dims=(32,))
)
base_dist = torch.distributions.Independent(
torch.distributions.Normal(torch.zeros(2), torch.ones(2)), 1
Expand Down
4 changes: 2 additions & 2 deletions flowtorch/bijectors/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class Affine(AffineOp, Elementwise):

def __init__(
self,
params: Optional[flowtorch.Lazy] = None,
params_fn: Optional[flowtorch.Lazy] = None,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
) -> None:
super().__init__(params, shape=shape, context_shape=context_shape)
super().__init__(params_fn, shape=shape, context_shape=context_shape)
self.log_scale_min_clip = log_scale_min_clip
self.log_scale_max_clip = log_scale_max_clip
self.sigmoid_bias = sigmoid_bias
4 changes: 2 additions & 2 deletions flowtorch/bijectors/affine_autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class AffineAutoregressive(AffineOp, Autoregressive):
def __init__(
self,
params: Optional[flowtorch.Lazy] = None,
params_fn: Optional[flowtorch.Lazy] = None,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
Expand All @@ -21,7 +21,7 @@ def __init__(
sigmoid_bias: float = 2.0,
) -> None:
super().__init__(
params,
params_fn,
shape=shape,
context_shape=context_shape,
)
Expand Down
35 changes: 19 additions & 16 deletions flowtorch/bijectors/affine_fixed.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright (c) Meta Platforms, Inc

import math
from typing import Optional
from typing import Optional, Sequence, Tuple

import flowtorch
import torch
from flowtorch.bijectors.fixed import Fixed
from flowtorch.bijectors.utils import requires_log_detJ


class AffineFixed(Fixed):
Expand All @@ -18,36 +19,38 @@ class AffineFixed(Fixed):
# TODO: Handle non-scalar loc and scale with correct broadcasting semantics
def __init__(
self,
params: Optional[flowtorch.Lazy] = None,
params_fn: Optional[flowtorch.Lazy] = None,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
loc: float = 0.0,
scale: float = 1.0
) -> None:
super().__init__(params, shape=shape, context_shape=context_shape)
super().__init__(params_fn, shape=shape, context_shape=context_shape)
self.loc = loc
self.scale = scale

def _forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.loc + self.scale * x
params: Optional[Sequence[torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
y = self.loc + self.scale * x
ladj: Optional[torch.Tensor] = None
if requires_log_detJ():
ladj = self._log_abs_det_jacobian(x, y, params)
return y, ladj

def _inverse(
self,
y: torch.Tensor,
x: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return (y - self.loc) / self.scale
self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
x = (y - self.loc) / self.scale
ladj: Optional[torch.Tensor] = None
if requires_log_detJ():
ladj = self._log_abs_det_jacobian(x, y, params)
return x, ladj

def _log_abs_det_jacobian(
self,
x: torch.Tensor,
y: torch.Tensor,
context: Optional[torch.Tensor] = None,
self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> torch.Tensor:
return torch.full_like(x, math.log(abs(self.scale)))
44 changes: 29 additions & 15 deletions flowtorch/bijectors/autoregressive.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) Meta Platforms, Inc

from typing import Any, cast, Optional
from typing import Any, cast, Optional, Sequence

import flowtorch
import flowtorch.parameters
import torch
import torch.distributions.constraints as constraints
from flowtorch.bijectors.base import Bijector
from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor
from flowtorch.bijectors.utils import is_record_flow_graph_enabled
from flowtorch.parameters.dense_autoregressive import DenseAutoregressive


Expand All @@ -17,7 +19,7 @@ class Autoregressive(Bijector):

def __init__(
self,
params: Optional[flowtorch.Lazy] = None,
params_fn: Optional[flowtorch.Lazy] = None,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
Expand All @@ -28,14 +30,14 @@ def __init__(
self.codomain = constraints.independent(constraints.real, len(shape))

# currently only DenseAutoregressive has a `permutation` buffer
if not params:
params = DenseAutoregressive() # type: ignore
if not params_fn:
params_fn = DenseAutoregressive() # type: ignore

# TODO: Replace P.DenseAutoregressive with P.Autoregressive
# In the future there will be other autoregressive parameter classes
assert params is not None and issubclass(params.cls, DenseAutoregressive)
assert params_fn is not None and issubclass(params_fn.cls, DenseAutoregressive)

super().__init__(params, shape=shape, context_shape=context_shape)
super().__init__(params_fn, shape=shape, context_shape=context_shape)

def inverse(
self,
Expand All @@ -45,25 +47,37 @@ def inverse(
) -> torch.Tensor:
# TODO: Allow that context can have a batch shape
assert context is None # or context.shape == (self._context_size,)
params = self.params
assert params is not None

assert self._params_fn is not None
if self._check_bijective_y(y, context):
assert isinstance(y, BijectiveTensor)
return y.get_parent_from_bijector(self)
x_new = torch.zeros_like(y)
# NOTE: Inversion is an expensive operation that scales in the
# dimension of the input
permutation = (
params.permutation
self._params_fn.permutation
) # TODO: type-safe named buffer (e.g. "permutation") access
# TODO: Make permutation, inverse work for other event shapes
log_detJ: Optional[torch.Tensor] = None
for idx in cast(torch.LongTensor, permutation):
x_new[..., idx] = self._inverse(y, x_new.clone(), context)[..., idx]
_params = self._params_fn(x_new.clone(), context=context)
x_temp, log_detJ = self._inverse(y, params=_params)
x_new[..., idx] = x_temp[..., idx]
# _log_detJ = out[1]
# log_detJ = _log_detJ

if is_record_flow_graph_enabled():
x_new = to_bijective_tensor(
x_new,
y,
context=context,
bijector=self,
mode="inverse",
log_detJ=log_detJ,
)
return x_new

def _log_abs_det_jacobian(
self,
x: torch.Tensor,
y: torch.Tensor,
context: Optional[torch.Tensor] = None,
self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> torch.Tensor:
raise NotImplementedError
121 changes: 98 additions & 23 deletions flowtorch/bijectors/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
# Copyright (c) Meta Platforms, Inc
from typing import Optional, Sequence, Union
import warnings
from typing import Optional, Sequence, Tuple, Union, Callable, Iterator

import flowtorch
import flowtorch.distributions
import flowtorch.parameters
import torch
import torch.distributions
from flowtorch.bijectors.bijective_tensor import to_bijective_tensor, BijectiveTensor
from flowtorch.bijectors.utils import is_record_flow_graph_enabled
from flowtorch.parameters import Parameters
from torch.distributions import constraints

ParamFnType = Callable[
[Optional[torch.Tensor], Optional[torch.Tensor]], Optional[Sequence[torch.Tensor]]
]


class Bijector(metaclass=flowtorch.LazyMeta):
codomain: constraints.Constraint = constraints.real
domain: constraints.Constraint = constraints.real
_shape: torch.Size
_context_shape: Optional[torch.Size]
_params: Optional[Union[Parameters, torch.nn.ModuleList]] = None
_params_fn: Optional[Union[Parameters, torch.nn.ModuleList]] = None

def __init__(
self,
params: Optional[flowtorch.Lazy] = None,
params_fn: Optional[flowtorch.Lazy] = None,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
Expand All @@ -37,19 +42,27 @@ def __init__(
self._context_shape = context_shape

# Instantiate parameters (tensor, hypernets, etc.)
if params is not None:
if params_fn is not None:
param_shapes = self.param_shapes(shape)
self._params = params( # type: ignore
self._params_fn = params_fn( # type: ignore
param_shapes, self._shape, self._context_shape
)

@property
def params(self) -> Optional[Union[Parameters, torch.nn.ModuleList]]:
return self._params

@params.setter
def params(self, value: Optional[Union[Parameters, torch.nn.ModuleList]]) -> None:
self._params = value
def parameters(self) -> Iterator[torch.Tensor]:
assert self._params_fn is not None
if hasattr(self._params_fn, "parameters"):
for param in self._params_fn.parameters():
yield param

def _check_bijective_x(
self, x: torch.Tensor, context: Optional[torch.Tensor]
) -> bool:
return (
isinstance(x, BijectiveTensor)
and x.from_inverse()
and x.check_bijector(self)
and x.check_context(context)
)

def forward(
self,
Expand All @@ -58,18 +71,41 @@ def forward(
) -> torch.Tensor:
# TODO: Allow that context can have a batch shape
assert context is None # or context.shape == (self._context_size,)
return self._forward(x, context)
if self._check_bijective_x(x, context):
assert isinstance(x, BijectiveTensor)
return x.get_parent_from_bijector(self)

params = self._params_fn(x, context) if self._params_fn is not None else None
y, log_detJ = self._forward(x, params)
if (
is_record_flow_graph_enabled()
and not isinstance(y, BijectiveTensor)
and not (isinstance(x, BijectiveTensor) and y in set(x.parents()))
):
# we exclude y that are bijective tensors for Compose
y = to_bijective_tensor(x, y, context, self, log_detJ, mode="forward")
return y

def _forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
params: Optional[Sequence[torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Abstract method to compute forward transformation.
"""
raise NotImplementedError

def _check_bijective_y(
self, y: torch.Tensor, context: Optional[torch.Tensor]
) -> bool:
return (
isinstance(y, BijectiveTensor)
and y.from_forward()
and y.check_bijector(self)
and y.check_context(context)
)

def inverse(
self,
y: torch.Tensor,
Expand All @@ -78,14 +114,27 @@ def inverse(
) -> torch.Tensor:
# TODO: Allow that context can have a batch shape
assert context is None # or context.shape == (self._context_size,)
return self._inverse(y, x, context)
if self._check_bijective_y(y, context):
assert isinstance(y, BijectiveTensor)
return y.get_parent_from_bijector(self)

# TODO: What to do in this line?
params = self._params_fn(x, context) if self._params_fn is not None else None
x, log_detJ = self._inverse(y, params)

if (
is_record_flow_graph_enabled()
and not isinstance(x, BijectiveTensor)
and not (isinstance(y, BijectiveTensor) and x in set(y.parents()))
):
x = to_bijective_tensor(x, y, context, self, log_detJ, mode="inverse")
return x

def _inverse(
self,
y: torch.Tensor,
x: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
params: Optional[Sequence[torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Abstract method to compute inverse transformation.
"""
Expand All @@ -101,13 +150,39 @@ def log_abs_det_jacobian(
Computes the log det jacobian `log |dy/dx|` given input and output.
By default, assumes a volume preserving bijection.
"""
return self._log_abs_det_jacobian(x, y, context)
# TODO: Allow that context can have a batch shape
assert context is None # or context.shape == (self._context_size,)
ladj = None
if (
isinstance(y, BijectiveTensor)
and y.from_forward()
and y.check_bijector(self)
and y.check_context(context)
):
ladj = y.log_detJ
elif (
isinstance(x, BijectiveTensor)
and x.from_inverse()
and x.check_bijector(self)
and x.check_context(context)
):
ladj = x.log_detJ
if ladj is None:
if is_record_flow_graph_enabled():
warnings.warn(
"Computing _log_abs_det_jacobian from values and not from cache."
)
params = (
self._params_fn(x, context) if self._params_fn is not None else None
)
return self._log_abs_det_jacobian(x, y, params)
return ladj

def _log_abs_det_jacobian(
self,
x: torch.Tensor,
y: torch.Tensor,
context: Optional[torch.Tensor] = None,
params: Optional[Sequence[torch.Tensor]],
) -> torch.Tensor:
"""
Computes the log det jacobian `log |dy/dx|` given input and output.
Expand Down
Loading

0 comments on commit 45e91ca

Please sign in to comment.