From 45e91cae11d8a037b835c341f6e4f07dcfc193c6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 2 Feb 2022 17:40:22 -0800 Subject: [PATCH] Bijective tensors for caching intermediate values (#1334) Summary: Pull Request resolved: https://github.com/facebookresearch/beanmachine/pull/1334 ### Motivation As described in https://github.com/facebookincubator/flowtorch/issues/88, we would like a way to cache intermediate values computed in the flow. ### Changes proposed This PR assigns this responsibility to a new class, BijectiveTensor. A BijectiveTensor keeps track of the layer that created it, the original tensor, and whether it comes from a call to 'forward' or 'inverse'. It inherits from torch.Tensor. By default, an operation on a BijectiveTensor returns a torch.Tensor (unless the operation is a Bijector). Whether BijectiveTensors are used (they are by default) can be controlled with the context manager `set_record_flow_graph`. Pull Request resolved: https://github.com/facebookincubator/flowtorch/pull/89 Test Plan: A test file can be found in `tests/test_bijectivetensor.py`. ### Types of changes - [ ] Docs change / refactoring / dependency upgrade - [ ] Bug fix (non-breaking change which fixes an issue) - [X] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ### Checklist - [X] My code follows the code style of this project. - [X] My change requires a change to the documentation. - [ ] I have updated the documentation accordingly. - [X] I have read the **[CONTRIBUTING](https://github.com/facebookincubator/flowtorch/blob/main/CONTRIBUTING.md)** document. - [X] I have added tests to cover my changes. - [X] All new and existing tests passed. - [X] The title of my pull request is a short description of the requested changes. 
Reviewed By: ToddSmall Differential Revision: D33927956 Pulled By: stefanwebb fbshipit-source-id: 2f3e53efb69f5839cb0649bc0d834cc7a0503553 --- examples/learn_bivariate_normal.py | 2 +- flowtorch/bijectors/affine.py | 4 +- flowtorch/bijectors/affine_autoregressive.py | 4 +- flowtorch/bijectors/affine_fixed.py | 35 +++-- flowtorch/bijectors/autoregressive.py | 44 ++++-- flowtorch/bijectors/base.py | 121 ++++++++++++--- flowtorch/bijectors/bijective_tensor.py | 147 +++++++++++++++++++ flowtorch/bijectors/compose.py | 95 +++++++++--- flowtorch/bijectors/elementwise.py | 10 +- flowtorch/bijectors/elu.py | 28 ++-- flowtorch/bijectors/exp.py | 28 ++-- flowtorch/bijectors/fixed.py | 6 +- flowtorch/bijectors/leaky_relu.py | 28 ++-- flowtorch/bijectors/ops/affine.py | 35 ++--- flowtorch/bijectors/ops/spline.py | 42 +++--- flowtorch/bijectors/permute.py | 27 ++-- flowtorch/bijectors/power.py | 27 ++-- flowtorch/bijectors/sigmoid.py | 28 ++-- flowtorch/bijectors/softplus.py | 28 ++-- flowtorch/bijectors/spline.py | 4 +- flowtorch/bijectors/spline_autoregressive.py | 4 +- flowtorch/bijectors/tanh.py | 28 ++-- flowtorch/bijectors/utils.py | 58 ++++++++ flowtorch/bijectors/volume_preserving.py | 7 +- flowtorch/distributions/flow.py | 12 +- flowtorch/parameters/base.py | 4 +- flowtorch/parameters/dense_autoregressive.py | 2 +- flowtorch/parameters/tensor.py | 2 +- setup.py | 2 +- tests/test_bijectivetensor.py | 110 ++++++++++++++ tests/test_bijector.py | 2 +- tests/test_distribution.py | 6 +- website/docs/users/conditional.mdx | 2 +- website/src/theme/Examples/snippets.js | 12 +- 34 files changed, 712 insertions(+), 282 deletions(-) create mode 100644 flowtorch/bijectors/bijective_tensor.py create mode 100644 flowtorch/bijectors/utils.py create mode 100644 tests/test_bijectivetensor.py diff --git a/examples/learn_bivariate_normal.py b/examples/learn_bivariate_normal.py index e33766fb..aee32cc6 100644 --- a/examples/learn_bivariate_normal.py +++ b/examples/learn_bivariate_normal.py 
@@ -22,7 +22,7 @@ def learn_bivariate_normal() -> None: # Lazily instantiated flow plus base and target distributions bijectors = bij.AffineAutoregressive( - params=params.DenseAutoregressive(hidden_dims=(32,)) + params_fn=params.DenseAutoregressive(hidden_dims=(32,)) ) base_dist = torch.distributions.Independent( torch.distributions.Normal(torch.zeros(2), torch.ones(2)), 1 diff --git a/flowtorch/bijectors/affine.py b/flowtorch/bijectors/affine.py index 42a58bc2..c3d635e2 100644 --- a/flowtorch/bijectors/affine.py +++ b/flowtorch/bijectors/affine.py @@ -16,7 +16,7 @@ class Affine(AffineOp, Elementwise): def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, @@ -24,7 +24,7 @@ def __init__( log_scale_max_clip: float = 3.0, sigmoid_bias: float = 2.0, ) -> None: - super().__init__(params, shape=shape, context_shape=context_shape) + super().__init__(params_fn, shape=shape, context_shape=context_shape) self.log_scale_min_clip = log_scale_min_clip self.log_scale_max_clip = log_scale_max_clip self.sigmoid_bias = sigmoid_bias diff --git a/flowtorch/bijectors/affine_autoregressive.py b/flowtorch/bijectors/affine_autoregressive.py index f4b89c54..610e5477 100644 --- a/flowtorch/bijectors/affine_autoregressive.py +++ b/flowtorch/bijectors/affine_autoregressive.py @@ -12,7 +12,7 @@ class AffineAutoregressive(AffineOp, Autoregressive): def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, @@ -21,7 +21,7 @@ def __init__( sigmoid_bias: float = 2.0, ) -> None: super().__init__( - params, + params_fn, shape=shape, context_shape=context_shape, ) diff --git a/flowtorch/bijectors/affine_fixed.py b/flowtorch/bijectors/affine_fixed.py index 023ef623..af916519 100644 --- a/flowtorch/bijectors/affine_fixed.py +++ b/flowtorch/bijectors/affine_fixed.py @@ 
-1,11 +1,12 @@ # Copyright (c) Meta Platforms, Inc import math -from typing import Optional +from typing import Optional, Sequence, Tuple import flowtorch import torch from flowtorch.bijectors.fixed import Fixed +from flowtorch.bijectors.utils import requires_log_detJ class AffineFixed(Fixed): @@ -18,36 +19,38 @@ class AffineFixed(Fixed): # TODO: Handle non-scalar loc and scale with correct broadcasting semantics def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, loc: float = 0.0, scale: float = 1.0 ) -> None: - super().__init__(params, shape=shape, context_shape=context_shape) + super().__init__(params_fn, shape=shape, context_shape=context_shape) self.loc = loc self.scale = scale def _forward( self, x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return self.loc + self.scale * x + params: Optional[Sequence[torch.Tensor]], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = self.loc + self.scale * x + ladj: Optional[torch.Tensor] = None + if requires_log_detJ(): + ladj = self._log_abs_det_jacobian(x, y, params) + return y, ladj def _inverse( - self, - y: torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return (y - self.loc) / self.scale + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = (y - self.loc) / self.scale + ladj: Optional[torch.Tensor] = None + if requires_log_detJ(): + ladj = self._log_abs_det_jacobian(x, y, params) + return x, ladj def _log_abs_det_jacobian( - self, - x: torch.Tensor, - y: torch.Tensor, - context: Optional[torch.Tensor] = None, + self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> torch.Tensor: return torch.full_like(x, math.log(abs(self.scale))) diff --git a/flowtorch/bijectors/autoregressive.py 
b/flowtorch/bijectors/autoregressive.py index 967711cb..8367b51b 100644 --- a/flowtorch/bijectors/autoregressive.py +++ b/flowtorch/bijectors/autoregressive.py @@ -1,12 +1,14 @@ # Copyright (c) Meta Platforms, Inc -from typing import Any, cast, Optional +from typing import Any, cast, Optional, Sequence import flowtorch import flowtorch.parameters import torch import torch.distributions.constraints as constraints from flowtorch.bijectors.base import Bijector +from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor +from flowtorch.bijectors.utils import is_record_flow_graph_enabled from flowtorch.parameters.dense_autoregressive import DenseAutoregressive @@ -17,7 +19,7 @@ class Autoregressive(Bijector): def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, @@ -28,14 +30,14 @@ def __init__( self.codomain = constraints.independent(constraints.real, len(shape)) # currently only DenseAutoregressive has a `permutation` buffer - if not params: - params = DenseAutoregressive() # type: ignore + if not params_fn: + params_fn = DenseAutoregressive() # type: ignore # TODO: Replace P.DenseAutoregressive with P.Autoregressive # In the future there will be other autoregressive parameter classes - assert params is not None and issubclass(params.cls, DenseAutoregressive) + assert params_fn is not None and issubclass(params_fn.cls, DenseAutoregressive) - super().__init__(params, shape=shape, context_shape=context_shape) + super().__init__(params_fn, shape=shape, context_shape=context_shape) def inverse( self, @@ -45,25 +47,37 @@ def inverse( ) -> torch.Tensor: # TODO: Allow that context can have a batch shape assert context is None # or context.shape == (self._context_size,) - params = self.params - assert params is not None - + assert self._params_fn is not None + if self._check_bijective_y(y, context): + assert isinstance(y, 
BijectiveTensor) + return y.get_parent_from_bijector(self) x_new = torch.zeros_like(y) # NOTE: Inversion is an expensive operation that scales in the # dimension of the input permutation = ( - params.permutation + self._params_fn.permutation ) # TODO: type-safe named buffer (e.g. "permutation") access # TODO: Make permutation, inverse work for other event shapes + log_detJ: Optional[torch.Tensor] = None for idx in cast(torch.LongTensor, permutation): - x_new[..., idx] = self._inverse(y, x_new.clone(), context)[..., idx] + _params = self._params_fn(x_new.clone(), context=context) + x_temp, log_detJ = self._inverse(y, params=_params) + x_new[..., idx] = x_temp[..., idx] + # _log_detJ = out[1] + # log_detJ = _log_detJ + if is_record_flow_graph_enabled(): + x_new = to_bijective_tensor( + x_new, + y, + context=context, + bijector=self, + mode="inverse", + log_detJ=log_detJ, + ) return x_new def _log_abs_det_jacobian( - self, - x: torch.Tensor, - y: torch.Tensor, - context: Optional[torch.Tensor] = None, + self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> torch.Tensor: raise NotImplementedError diff --git a/flowtorch/bijectors/base.py b/flowtorch/bijectors/base.py index 17e55ebc..6a388c83 100644 --- a/flowtorch/bijectors/base.py +++ b/flowtorch/bijectors/base.py @@ -1,25 +1,30 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional, Sequence, Union +import warnings +from typing import Optional, Sequence, Tuple, Union, Callable, Iterator -import flowtorch -import flowtorch.distributions import flowtorch.parameters import torch import torch.distributions +from flowtorch.bijectors.bijective_tensor import to_bijective_tensor, BijectiveTensor +from flowtorch.bijectors.utils import is_record_flow_graph_enabled from flowtorch.parameters import Parameters from torch.distributions import constraints +ParamFnType = Callable[ + [Optional[torch.Tensor], Optional[torch.Tensor]], Optional[Sequence[torch.Tensor]] +] + class 
Bijector(metaclass=flowtorch.LazyMeta): codomain: constraints.Constraint = constraints.real domain: constraints.Constraint = constraints.real _shape: torch.Size _context_shape: Optional[torch.Size] - _params: Optional[Union[Parameters, torch.nn.ModuleList]] = None + _params_fn: Optional[Union[Parameters, torch.nn.ModuleList]] = None def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, @@ -37,19 +42,27 @@ def __init__( self._context_shape = context_shape # Instantiate parameters (tensor, hypernets, etc.) - if params is not None: + if params_fn is not None: param_shapes = self.param_shapes(shape) - self._params = params( # type: ignore + self._params_fn = params_fn( # type: ignore param_shapes, self._shape, self._context_shape ) - @property - def params(self) -> Optional[Union[Parameters, torch.nn.ModuleList]]: - return self._params - - @params.setter - def params(self, value: Optional[Union[Parameters, torch.nn.ModuleList]]) -> None: - self._params = value + def parameters(self) -> Iterator[torch.Tensor]: + assert self._params_fn is not None + if hasattr(self._params_fn, "parameters"): + for param in self._params_fn.parameters(): + yield param + + def _check_bijective_x( + self, x: torch.Tensor, context: Optional[torch.Tensor] + ) -> bool: + return ( + isinstance(x, BijectiveTensor) + and x.from_inverse() + and x.check_bijector(self) + and x.check_context(context) + ) def forward( self, @@ -58,18 +71,41 @@ def forward( ) -> torch.Tensor: # TODO: Allow that context can have a batch shape assert context is None # or context.shape == (self._context_size,) - return self._forward(x, context) + if self._check_bijective_x(x, context): + assert isinstance(x, BijectiveTensor) + return x.get_parent_from_bijector(self) + + params = self._params_fn(x, context) if self._params_fn is not None else None + y, log_detJ = self._forward(x, params) + if ( + 
is_record_flow_graph_enabled() + and not isinstance(y, BijectiveTensor) + and not (isinstance(x, BijectiveTensor) and y in set(x.parents())) + ): + # we exclude y that are bijective tensors for Compose + y = to_bijective_tensor(x, y, context, self, log_detJ, mode="forward") + return y def _forward( self, x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + params: Optional[Sequence[torch.Tensor]], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Abstract method to compute forward transformation. """ raise NotImplementedError + def _check_bijective_y( + self, y: torch.Tensor, context: Optional[torch.Tensor] + ) -> bool: + return ( + isinstance(y, BijectiveTensor) + and y.from_forward() + and y.check_bijector(self) + and y.check_context(context) + ) + def inverse( self, y: torch.Tensor, @@ -78,14 +114,27 @@ def inverse( ) -> torch.Tensor: # TODO: Allow that context can have a batch shape assert context is None # or context.shape == (self._context_size,) - return self._inverse(y, x, context) + if self._check_bijective_y(y, context): + assert isinstance(y, BijectiveTensor) + return y.get_parent_from_bijector(self) + + # TODO: What to do in this line? + params = self._params_fn(x, context) if self._params_fn is not None else None + x, log_detJ = self._inverse(y, params) + + if ( + is_record_flow_graph_enabled() + and not isinstance(x, BijectiveTensor) + and not (isinstance(y, BijectiveTensor) and x in set(y.parents())) + ): + x = to_bijective_tensor(x, y, context, self, log_detJ, mode="inverse") + return x def _inverse( self, y: torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + params: Optional[Sequence[torch.Tensor]], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Abstract method to compute inverse transformation. """ @@ -101,13 +150,39 @@ def log_abs_det_jacobian( Computes the log det jacobian `log |dy/dx|` given input and output. 
By default, assumes a volume preserving bijection. """ - return self._log_abs_det_jacobian(x, y, context) + # TODO: Allow that context can have a batch shape + assert context is None # or context.shape == (self._context_size,) + ladj = None + if ( + isinstance(y, BijectiveTensor) + and y.from_forward() + and y.check_bijector(self) + and y.check_context(context) + ): + ladj = y.log_detJ + elif ( + isinstance(x, BijectiveTensor) + and x.from_inverse() + and x.check_bijector(self) + and x.check_context(context) + ): + ladj = x.log_detJ + if ladj is None: + if is_record_flow_graph_enabled(): + warnings.warn( + "Computing _log_abs_det_jacobian from values and not from cache." + ) + params = ( + self._params_fn(x, context) if self._params_fn is not None else None + ) + return self._log_abs_det_jacobian(x, y, params) + return ladj def _log_abs_det_jacobian( self, x: torch.Tensor, y: torch.Tensor, - context: Optional[torch.Tensor] = None, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: """ Computes the log det jacobian `log |dy/dx|` given input and output. 
diff --git a/flowtorch/bijectors/bijective_tensor.py b/flowtorch/bijectors/bijective_tensor.py new file mode 100644 index 00000000..8a3e6338 --- /dev/null +++ b/flowtorch/bijectors/bijective_tensor.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc +from typing import Any, Optional, Iterator, Type, TYPE_CHECKING, Union + +if TYPE_CHECKING: + from flowtorch.bijectors.base import Bijector + +from torch import Tensor + + +class BijectiveTensor(Tensor): + def __repr__(self) -> str: + r_str = ( + super(BijectiveTensor, self) + .__repr__() + .replace("tensor", "bijective_tensor") + ) + return r_str + + def register( + self, + input: Tensor, + output: Tensor, + context: Optional[Tensor], + bijector: "Bijector", + log_detJ: Optional[Tensor], + mode: str, + ) -> "BijectiveTensor": + self._input = input + self._output = output + self._context = context + self._bijector = bijector + self._log_detJ = log_detJ + self._mode = mode + + if not (self.from_forward() or self.from_inverse()): + raise RuntimeError( + f"BijectiveTensor mode must be either `'forward'` \ +or `'inverse'`. got {self._mode}" + ) + + return self + + @classmethod + def __torch_function__( + cls: Type["BijectiveTensor"], + func: Any, + types: Any, + args: Any = (), + kwargs: Any = None, + ) -> Union[Any, Tensor]: + if kwargs is None: + kwargs = {} + # we don't want to create a new BijectiveTensor when summing, + # calling zeros_like etc. 
+ types = tuple(Tensor if _type is BijectiveTensor else _type for _type in types) + return Tensor.__torch_function__(func, types, args, kwargs) + + def check_bijector(self, bijector: "Bijector") -> bool: + is_bijector = bijector in tuple(self.bijectors()) + return is_bijector + + def bijectors(self) -> Iterator["Bijector"]: + yield self._bijector + for parent in self.parents(): + if isinstance(parent, BijectiveTensor): + yield parent._bijector + + def get_parent_from_bijector(self, bijector: "Bijector") -> Tensor: + if self._bijector is bijector: + return self.parent + for parent in self.parents(): + if not isinstance(parent, BijectiveTensor): + break + if parent._bijector is bijector: + return parent.parent + raise RuntimeError("bijector not found in flow") + + def check_context(self, context: Optional[Tensor]) -> bool: + return self._context is context + + def from_forward(self) -> bool: + return self._mode == "forward" + + def from_inverse(self) -> bool: + return self._mode == "inverse" + + def detach_from_flow(self) -> Tensor: + detached_tensor = self._output if self.from_forward() else self._input + if isinstance(detached_tensor, BijectiveTensor): + raise RuntimeError("the detached tensor is an instance of BijectiveTensor.") + return detached_tensor + + def has_ancestor(self, tensor: Tensor) -> bool: + if tensor is self: + return False # self is no parent of self + elif self.from_forward() and self._input is tensor: + return True + elif self.from_inverse() and self._output is tensor: + return True + elif self.from_forward() and isinstance(self._input, BijectiveTensor): + return self._input.has_ancestor(tensor) + elif self.from_inverse() and isinstance(self._output, BijectiveTensor): + return self._output.has_ancestor(tensor) + else: + return False + + @property + def log_detJ(self) -> Optional[Tensor]: + return self._log_detJ + + @property + def parent(self) -> Tensor: + if self.from_forward(): + return self._input + else: + return self._output + + def 
parents(self) -> Iterator[Tensor]: + child: Union[Tensor, BijectiveTensor] = self + while True: + assert isinstance(child, BijectiveTensor) + child = parent = child.parent + yield parent + if not isinstance(child, BijectiveTensor): + break + + +def to_bijective_tensor( + x: Tensor, + y: Tensor, + context: Optional[Tensor], + bijector: "Bijector", + log_detJ: Optional[Tensor], + mode: str = "forward", +) -> BijectiveTensor: + if mode == "inverse": + x_bij = BijectiveTensor(x) + x_bij.register(x, y, context, bijector, log_detJ, mode=mode) + return x_bij + elif mode == "forward": + y_bij = BijectiveTensor(y) + y_bij.register(x, y, context, bijector, log_detJ, mode=mode) + return y_bij + else: + raise NotImplementedError( + f"mode {mode} is not supported, must be one of 'forward' or 'inverse'." + ) diff --git a/flowtorch/bijectors/compose.py b/flowtorch/bijectors/compose.py index b2b1cfad..5bc13317 100644 --- a/flowtorch/bijectors/compose.py +++ b/flowtorch/bijectors/compose.py @@ -1,12 +1,12 @@ # Copyright (c) Meta Platforms, Inc +from typing import Optional, Sequence, Iterator -from typing import Optional, Sequence - -import flowtorch import flowtorch.parameters import torch import torch.distributions from flowtorch.bijectors.base import Bijector +from flowtorch.bijectors.bijective_tensor import to_bijective_tensor, BijectiveTensor +from flowtorch.bijectors.utils import is_record_flow_graph_enabled, requires_log_detJ from torch.distributions.utils import _sum_rightmost @@ -31,24 +31,43 @@ def __init__( self.domain = self.bijectors[0].domain # type: ignore self.codomain = self.bijectors[-1].codomain # type: ignore - # Make parameters accessible to dist.Flow - self._params = torch.nn.ModuleList( - [ - b._params # type: ignore - for b in self.bijectors - if isinstance(b._params, torch.nn.Module) # type: ignore - ] - ) - self._context_shape = context_shape - # NOTE: We overwrite forward rather than _forward so that the composed - # bijectors can handle the caching 
separately! - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + def parameters(self) -> Iterator[torch.Tensor]: + for b in self.bijectors: + for param in b.parameters(): # type: ignore + yield param + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + log_detJ: Optional[torch.Tensor] = None + x_temp = x for bijector in self.bijectors: - x = bijector.forward(x, context) # type: ignore + y = bijector.forward(x_temp, context) # type: ignore + if is_record_flow_graph_enabled() and requires_log_detJ(): + if isinstance(y, BijectiveTensor) and y.from_forward(): + _log_detJ = y._log_detJ + elif isinstance(x_temp, BijectiveTensor) and x_temp.from_inverse(): + _log_detJ = x_temp._log_detJ + else: + raise RuntimeError( + "neither of x nor y contains the log-abs-det-jacobian" + ) + log_detJ = log_detJ + _log_detJ if log_detJ is not None else _log_detJ + x_temp = y - return x + # TODO: Check that this doesn't contain bugs! 
+ if ( + is_record_flow_graph_enabled() + and not isinstance(y, BijectiveTensor) + and not (isinstance(x, BijectiveTensor) and y in set(x.parents())) + ): + # we exclude y that are bijective tensors for Compose + y = to_bijective_tensor(x, x_temp, context, self, log_detJ, mode="forward") + return y def inverse( self, @@ -56,10 +75,30 @@ def inverse( x: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None, ) -> torch.Tensor: + log_detJ: Optional[torch.Tensor] = None + y_temp = y for bijector in reversed(self.bijectors): - y = bijector.inverse(y, x, context) # type: ignore + x = bijector.inverse(y_temp, context) # type: ignore + if is_record_flow_graph_enabled() and requires_log_detJ(): + if isinstance(y_temp, BijectiveTensor) and y_temp.from_forward(): + _log_detJ = y_temp._log_detJ + elif isinstance(x, BijectiveTensor) and x.from_inverse(): + _log_detJ = x._log_detJ + else: + raise RuntimeError( + "neither of x nor y contains the log-abs-det-jacobian" + ) + log_detJ = log_detJ + _log_detJ if log_detJ is not None else _log_detJ + y_temp = x # type: ignore - return y + # TODO: Check that this doesn't contain bugs! + if ( + is_record_flow_graph_enabled() + and not isinstance(x, BijectiveTensor) + and not (isinstance(y, BijectiveTensor) and x in set(y.parents())) + ): + x = to_bijective_tensor(y_temp, y, context, self, log_detJ, mode="inverse") + return x # type: ignore def log_abs_det_jacobian( self, x: torch.Tensor, y: torch.Tensor, context: torch.Tensor = None @@ -72,8 +111,24 @@ def log_abs_det_jacobian( torch.zeros_like(y), self.domain.event_dim, ) + + if isinstance(x, BijectiveTensor) and x.has_ancestor(y): + # If x is a BijectiveTensor and has y as ancestor, then the + # inversion flow.inverse(y) = x has already been computed and + # we can recover the chain of parents instead of re-computing it. 
+ _use_cached_inverse = True + parents = [] + while isinstance(x, BijectiveTensor) and x is not y: + parents.append(x) + x = x.parent + else: + _use_cached_inverse = False + for bijector in reversed(self.bijectors): - y_inv = bijector.inverse(y, context) # type: ignore + if not _use_cached_inverse: + y_inv = bijector.inverse(y, context) # type: ignore + else: + y_inv = parents.pop() ldj += bijector.log_abs_det_jacobian(y_inv, y, context) # type: ignore y = y_inv return ldj diff --git a/flowtorch/bijectors/elementwise.py b/flowtorch/bijectors/elementwise.py index d2a505d1..67ca5dcc 100644 --- a/flowtorch/bijectors/elementwise.py +++ b/flowtorch/bijectors/elementwise.py @@ -11,15 +11,15 @@ class Elementwise(Bijector): def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, **kwargs: Any ) -> None: - if not params: - params = Tensor() # type: ignore + if not params_fn: + params_fn = Tensor() # type: ignore - assert params is None or issubclass(params.cls, Tensor) + assert params_fn is None or issubclass(params_fn.cls, Tensor) - super().__init__(params, shape=shape, context_shape=context_shape) + super().__init__(params_fn, shape=shape, context_shape=context_shape) diff --git a/flowtorch/bijectors/elu.py b/flowtorch/bijectors/elu.py index 5976cd13..ac5a3494 100644 --- a/flowtorch/bijectors/elu.py +++ b/flowtorch/bijectors/elu.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional +from typing import Optional, Sequence, Tuple import torch import torch.distributions.constraints as constraints @@ -15,26 +15,22 @@ class ELU(Fixed): # TODO: Setting the alpha value of ELU as __init__ argument def _forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return F.elu(x) + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + 
y = F.elu(x) + ladj = self._log_abs_det_jacobian(x, y, params) + return y, ladj def _inverse( - self, - y: torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return torch.max(y, torch.zeros_like(y)) + torch.min( + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = torch.max(y, torch.zeros_like(y)) + torch.min( torch.log1p(y + eps), torch.zeros_like(y) ) + ladj = self._log_abs_det_jacobian(x, y, params) + return x, ladj def _log_abs_det_jacobian( - self, - x: torch.Tensor, - y: torch.Tensor, - context: Optional[torch.Tensor] = None, + self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> torch.Tensor: return -F.relu(-x) diff --git a/flowtorch/bijectors/exp.py b/flowtorch/bijectors/exp.py index e475ca03..2856d312 100644 --- a/flowtorch/bijectors/exp.py +++ b/flowtorch/bijectors/exp.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional +from typing import Optional, Sequence, Tuple import torch import torch.distributions.constraints as constraints @@ -14,24 +14,20 @@ class Exp(Fixed): codomain = constraints.positive def _forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return torch.exp(x) + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = torch.exp(x) + ladj = self._log_abs_det_jacobian(x, y, params) + return y, ladj def _inverse( - self, - y: torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return y.log() + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = y.log() + ladj = self._log_abs_det_jacobian(x, y, params) + return x, ladj def _log_abs_det_jacobian( - self, - x: torch.Tensor, - y: torch.Tensor, - context: 
Optional[torch.Tensor] = None, + self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> torch.Tensor: return x diff --git a/flowtorch/bijectors/fixed.py b/flowtorch/bijectors/fixed.py index 80b4a871..485c52de 100644 --- a/flowtorch/bijectors/fixed.py +++ b/flowtorch/bijectors/fixed.py @@ -10,15 +10,15 @@ class Fixed(Bijector): def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, ) -> None: # TODO: In the future, make Fixed actually mean that there is no autograd # through params - super().__init__(params, shape=shape, context_shape=context_shape) - assert params is None + super().__init__(params_fn, shape=shape, context_shape=context_shape) + assert params_fn is None def param_shapes(self, shape: torch.Size) -> Sequence[torch.Size]: """ diff --git a/flowtorch/bijectors/leaky_relu.py b/flowtorch/bijectors/leaky_relu.py index 256055de..79ce58f4 100644 --- a/flowtorch/bijectors/leaky_relu.py +++ b/flowtorch/bijectors/leaky_relu.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc import math -from typing import Optional +from typing import Optional, Sequence, Tuple import torch import torch.nn.functional as F @@ -12,25 +12,21 @@ class LeakyReLU(Fixed): # TODO: Setting the slope of Leaky ReLU as __init__ argument def _forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return F.leaky_relu(x) + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = F.leaky_relu(x) + ladj = self._log_abs_det_jacobian(x, y, params) + return y, ladj def _inverse( - self, - y: torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return F.leaky_relu(y, negative_slope=100.0) + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor]]: + x = F.leaky_relu(y, negative_slope=100.0) + ladj = self._log_abs_det_jacobian(x, y, params) + return x, ladj def _log_abs_det_jacobian( - self, - x: torch.Tensor, - y: torch.Tensor, - context: Optional[torch.Tensor] = None, + self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> torch.Tensor: return torch.where( x >= 0.0, torch.zeros_like(x), torch.ones_like(x) * math.log(0.01) diff --git a/flowtorch/bijectors/ops/affine.py b/flowtorch/bijectors/ops/affine.py index 3f2a2494..d9cdf56f 100644 --- a/flowtorch/bijectors/ops/affine.py +++ b/flowtorch/bijectors/ops/affine.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional, Tuple +from typing import Optional, Sequence, Tuple import flowtorch import torch @@ -18,7 +18,7 @@ class Affine(Bijector): def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, @@ -26,55 +26,46 @@ def __init__( log_scale_max_clip: float = 3.0, sigmoid_bias: float = 2.0, ) -> None: - super().__init__(params, shape=shape, context_shape=context_shape) + super().__init__(params_fn, shape=shape, context_shape=context_shape) self.log_scale_min_clip = log_scale_min_clip self.log_scale_max_clip = log_scale_max_clip self.sigmoid_bias = sigmoid_bias def _forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - params = self.params + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, torch.Tensor]: assert params is not None - mean, log_scale = params(x, context=context) + mean, log_scale = params log_scale = clamp_preserve_gradients( log_scale, self.log_scale_min_clip, self.log_scale_max_clip ) scale = torch.exp(log_scale) y = scale * x + mean - return y + return y, _sum_rightmost(log_scale, self.domain.event_dim) def _inverse( - self, - y: torch.Tensor, - x: 
Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - params = self.params + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, torch.Tensor]: assert params is not None - mean, log_scale = params(x, context=context) + mean, log_scale = params log_scale = clamp_preserve_gradients( log_scale, self.log_scale_min_clip, self.log_scale_max_clip ) inverse_scale = torch.exp(-log_scale) x_new = (y - mean) * inverse_scale - return x_new + return x_new, _sum_rightmost(log_scale, self.domain.event_dim) def _log_abs_det_jacobian( self, x: torch.Tensor, y: torch.Tensor, - context: Optional[torch.Tensor] = None, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: - params = self.params assert params is not None - # Note: params will take care of caching "mean, log_scale = params(x)" - _, log_scale = params(x, context=context) + _, log_scale = params log_scale = clamp_preserve_gradients( log_scale, self.log_scale_min_clip, self.log_scale_max_clip ) diff --git a/flowtorch/bijectors/ops/spline.py b/flowtorch/bijectors/ops/spline.py index 0336e8e4..687d1bac 100644 --- a/flowtorch/bijectors/ops/spline.py +++ b/flowtorch/bijectors/ops/spline.py @@ -21,7 +21,7 @@ class Spline(Bijector): def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, @@ -41,51 +41,47 @@ def __init__( self.bound = bound self.order = order - super().__init__(params, shape=shape, context_shape=context_shape) + super().__init__(params_fn, shape=shape, context_shape=context_shape) def _forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - y, _ = self._op(x, x, context) - return y + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + y, log_detJ = self._op(x, params) + return y, _sum_rightmost(log_detJ, 
self.domain.event_dim) def _inverse( - self, - y: torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - x_new, _ = self._op(y, x, context=context, inverse=True) - return x_new + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + x_new, log_detJ = self._op(y, params, inverse=True) + + # TODO: Should I invert the sign of log_detJ? + # TODO: A unit test that compares log_detJ from _forward and _inverse + return x_new, _sum_rightmost(log_detJ, self.domain.event_dim) def _log_abs_det_jacobian( self, x: torch.Tensor, y: torch.Tensor, - context: Optional[torch.Tensor] = None, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: - _, log_detJ = self._op(x, x, context) + _, log_detJ = self._op(x, params) return _sum_rightmost(log_detJ, self.domain.event_dim) def _op( self, input: torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, + params: Optional[Sequence[torch.Tensor]], inverse: bool = False, **kwargs: Any ) -> Tuple[torch.Tensor, torch.Tensor]: - params = self.params assert params is not None + lambdas: Optional[torch.Tensor] = None if self.order == "linear": - widths, heights, derivatives, lambdas = params(x, context=context) + widths, heights, derivatives, lambdas = params lambdas = torch.sigmoid(lambdas) else: - widths, heights, derivatives = params(x, context=context) - lambdas = None + widths, heights, derivatives = params # Constrain parameters # TODO: Move to flowtorch.ops function? 
diff --git a/flowtorch/bijectors/permute.py b/flowtorch/bijectors/permute.py index 631ba448..f371c9af 100644 --- a/flowtorch/bijectors/permute.py +++ b/flowtorch/bijectors/permute.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional +from typing import Optional, Sequence, Tuple import flowtorch import torch @@ -17,35 +17,34 @@ class Permute(Fixed, VolumePreserving): # TODO: A new abstraction so can defer construction of permutation def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, permutation: Optional[torch.Tensor] = None ) -> None: - super().__init__(params, shape=shape, context_shape=context_shape) + super().__init__(params_fn, shape=shape, context_shape=context_shape) self.permutation = permutation def _forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if self.permutation is None: self.permutation = torch.randperm(x.shape[-1]) - return torch.index_select(x, -1, self.permutation) + y = torch.index_select(x, -1, self.permutation) + ladj = self._log_abs_det_jacobian(x, y, params) + return y, ladj def _inverse( - self, - y: torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if self.permutation is None: self.permutation = torch.randperm(y.shape[-1]) - return torch.index_select(y, -1, self.inv_permutation) + x = torch.index_select(y, -1, self.inv_permutation) + ladj = self._log_abs_det_jacobian(x, y, params) + return x, ladj @lazy_property def inv_permutation(self) -> Optional[torch.Tensor]: diff --git a/flowtorch/bijectors/power.py b/flowtorch/bijectors/power.py index 
445efe9b..0aea101e 100644 --- a/flowtorch/bijectors/power.py +++ b/flowtorch/bijectors/power.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional +from typing import Optional, Sequence, Tuple import flowtorch import torch @@ -18,34 +18,35 @@ class Power(Fixed): # TODO: Tensor valued exponents and corresponding determination of event_dim def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, exponent: float = 2.0, ) -> None: - super().__init__(params, shape=shape, context_shape=context_shape) + super().__init__(params_fn, shape=shape, context_shape=context_shape) self.exponent = exponent def _forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return x.pow(self.exponent) + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = x.pow(self.exponent) + ladj = self._log_abs_det_jacobian(x, y, params) + return y, ladj def _inverse( self, y: torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return y.pow(1 / self.exponent) + params: Optional[Sequence[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = y.pow(1 / self.exponent) + ladj = self._log_abs_det_jacobian(x, y, params) + return x, ladj def _log_abs_det_jacobian( self, x: torch.Tensor, y: torch.Tensor, - context: Optional[torch.Tensor] = None, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: return torch.abs(self.exponent * y / x).log() diff --git a/flowtorch/bijectors/sigmoid.py b/flowtorch/bijectors/sigmoid.py index da480984..e9fccaec 100644 --- a/flowtorch/bijectors/sigmoid.py +++ b/flowtorch/bijectors/sigmoid.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional +from typing import Optional, Sequence, Tuple import 
torch import torch.distributions.constraints as constraints @@ -13,26 +13,22 @@ class Sigmoid(Fixed): codomain = constraints.unit_interval def _forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return clipped_sigmoid(x) + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = clipped_sigmoid(x) + ladj = self._log_abs_det_jacobian(x, y, params) + return y, ladj def _inverse( - self, - y: torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: finfo = torch.finfo(y.dtype) y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps) - return y.log() - torch.log1p(-y) + x = y.log() - torch.log1p(-y) + ladj = self._log_abs_det_jacobian(x, y, params) + return x, ladj def _log_abs_det_jacobian( - self, - x: torch.Tensor, - y: torch.Tensor, - context: Optional[torch.Tensor] = None, + self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> torch.Tensor: return -F.softplus(-x) - F.softplus(x) diff --git a/flowtorch/bijectors/softplus.py b/flowtorch/bijectors/softplus.py index a2401f95..5633a3dd 100644 --- a/flowtorch/bijectors/softplus.py +++ b/flowtorch/bijectors/softplus.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional +from typing import Optional, Sequence, Tuple import flowtorch.ops import torch @@ -16,24 +16,20 @@ class Softplus(Fixed): codomain = constraints.positive def _forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return F.softplus(x) + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = F.softplus(x) + ladj = self._log_abs_det_jacobian(x, y, params) + return y, ladj def _inverse( - self, - y: 
torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return flowtorch.ops.softplus_inv(y) + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = flowtorch.ops.softplus_inv(y) + ladj = self._log_abs_det_jacobian(x, y, params) + return x, ladj def _log_abs_det_jacobian( - self, - x: torch.Tensor, - y: torch.Tensor, - context: Optional[torch.Tensor] = None, + self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> torch.Tensor: return -F.softplus(-x) diff --git a/flowtorch/bijectors/spline.py b/flowtorch/bijectors/spline.py index 409a04b1..ae407795 100644 --- a/flowtorch/bijectors/spline.py +++ b/flowtorch/bijectors/spline.py @@ -11,7 +11,7 @@ class Spline(SplineOp, Elementwise): def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, @@ -20,7 +20,7 @@ def __init__( order: str = "linear" ) -> None: super().__init__( - params, + params_fn, shape=shape, context_shape=context_shape, count_bins=count_bins, diff --git a/flowtorch/bijectors/spline_autoregressive.py b/flowtorch/bijectors/spline_autoregressive.py index 06307315..f0850095 100644 --- a/flowtorch/bijectors/spline_autoregressive.py +++ b/flowtorch/bijectors/spline_autoregressive.py @@ -12,7 +12,7 @@ class SplineAutoregressive(SplineOp, Autoregressive): def __init__( self, - params: Optional[flowtorch.Lazy] = None, + params_fn: Optional[flowtorch.Lazy] = None, *, shape: torch.Size, context_shape: Optional[torch.Size] = None, @@ -21,7 +21,7 @@ def __init__( order: str = "linear" ) -> None: super().__init__( - params, + params_fn, shape=shape, context_shape=context_shape, count_bins=count_bins, diff --git a/flowtorch/bijectors/tanh.py b/flowtorch/bijectors/tanh.py index 52d86839..5aae732e 100644 --- a/flowtorch/bijectors/tanh.py +++ 
b/flowtorch/bijectors/tanh.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc import math -from typing import Optional +from typing import Optional, Sequence, Tuple import torch import torch.distributions.constraints as constraints @@ -16,24 +16,20 @@ class Tanh(Fixed): codomain = constraints.interval(-1.0, 1.0) def _forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return torch.tanh(x) + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = torch.tanh(x) + ladj = self._log_abs_det_jacobian(x, y, params) + return y, ladj def _inverse( - self, - y: torch.Tensor, - x: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return torch.atanh(y) + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = torch.atanh(y) + ladj = self._log_abs_det_jacobian(x, y, params) + return x, ladj def _log_abs_det_jacobian( - self, - x: torch.Tensor, - y: torch.Tensor, - context: Optional[torch.Tensor] = None, + self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> torch.Tensor: return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x)) diff --git a/flowtorch/bijectors/utils.py b/flowtorch/bijectors/utils.py new file mode 100644 index 00000000..376751f1 --- /dev/null +++ b/flowtorch/bijectors/utils.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc +import functools +from typing import Any, Callable, List, Sequence + +_RECORD_FLOW = True + + +class _context_manager: + def __init__(self, value: bool = True) -> None: + self.value = value + self.prev: List[bool] = [] + + def __call__(self, func: Callable) -> Any: + @functools.wraps(func) + def decorate_context(*args: Any, **kwargs: Sequence[Any]) -> Any: + with self: + return func(*args, **kwargs) + + return decorate_context + + def __enter__(self) -> None: + pass + + def 
__exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + pass + + +class set_record_flow_graph(_context_manager): + def __enter__(self) -> None: + global _RECORD_FLOW + self.prev.append(_RECORD_FLOW) + _RECORD_FLOW = self.value + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + global _RECORD_FLOW + _RECORD_FLOW = self.prev.pop() + + +def is_record_flow_graph_enabled() -> bool: + return _RECORD_FLOW + + +_REQUIRES_LOG_DETJ = True + + +class set_requires_log_detJ(_context_manager): + def __enter__(self) -> None: + global _REQUIRES_LOG_DETJ + self.prev.append(_REQUIRES_LOG_DETJ) + _REQUIRES_LOG_DETJ = self.value + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + global _REQUIRES_LOG_DETJ + _REQUIRES_LOG_DETJ = self.prev.pop() + + +def requires_log_detJ() -> bool: + return _REQUIRES_LOG_DETJ diff --git a/flowtorch/bijectors/volume_preserving.py b/flowtorch/bijectors/volume_preserving.py index 04d45d99..4f8b6bd7 100644 --- a/flowtorch/bijectors/volume_preserving.py +++ b/flowtorch/bijectors/volume_preserving.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional +from typing import Optional, Sequence import torch import torch.distributions @@ -9,10 +9,7 @@ class VolumePreserving(Bijector): def _log_abs_det_jacobian( - self, - x: torch.Tensor, - y: torch.Tensor, - context: Optional[torch.Tensor] = None, + self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> torch.Tensor: # TODO: Confirm that this should involve `x`/`self.domain` and not # `y`/`self.codomain` diff --git a/flowtorch/distributions/flow.py b/flowtorch/distributions/flow.py index ffb57d96..bfb0e97d 100644 --- a/flowtorch/distributions/flow.py +++ b/flowtorch/distributions/flow.py @@ -1,12 +1,13 @@ # Copyright (c) Meta Platforms, Inc -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, Iterator import flowtorch import torch import 
torch.distributions as dist from torch import Tensor from torch.distributions.utils import _sum_rightmost +from torch.nn import Parameter class Flow(torch.nn.Module, dist.Distribution, metaclass=flowtorch.LazyMeta): @@ -26,7 +27,7 @@ def __init__( self.bijector = bijector(shape=base_dist.event_shape) # Required so that parameters are registered with nn.Module - self.params = self.bijector._params # type: ignore + self.params = self.bijector.parameters() # type: ignore # TODO: Confirm that the following logic works. Shouldn't it use # .domain and .codomain?? Infer shape from constructed self.bijector @@ -42,6 +43,13 @@ def __init__( self, batch_shape, event_shape, validate_args=validate_args ) + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + for p in super().parameters(recurse=recurse): + yield p + if recurse: + for p in self.bijector.parameters(): # type: ignore + yield p + def condition(self, context: torch.Tensor) -> "Flow": self._context = context return self diff --git a/flowtorch/parameters/base.py b/flowtorch/parameters/base.py index e28aef1d..72e4b69f 100644 --- a/flowtorch/parameters/base.py +++ b/flowtorch/parameters/base.py @@ -26,7 +26,7 @@ def forward( self, x: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None, - ) -> Sequence[torch.Tensor]: + ) -> Optional[Sequence[torch.Tensor]]: # TODO: Caching etc. return self._forward(x, context) @@ -34,7 +34,7 @@ def _forward( self, x: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None, - ) -> Sequence[torch.Tensor]: + ) -> Optional[Sequence[torch.Tensor]]: # I raise an exception rather than using @abstractmethod and # metaclass=ABC so that we can reserve the metaclass for lazy # evaluation. 
diff --git a/flowtorch/parameters/dense_autoregressive.py b/flowtorch/parameters/dense_autoregressive.py index f6104797..8110e5a6 100644 --- a/flowtorch/parameters/dense_autoregressive.py +++ b/flowtorch/parameters/dense_autoregressive.py @@ -147,7 +147,7 @@ def _forward( self, x: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None, - ) -> Sequence[torch.Tensor]: + ) -> Optional[Sequence[torch.Tensor]]: assert x is not None # Flatten x diff --git a/flowtorch/parameters/tensor.py b/flowtorch/parameters/tensor.py index c098136a..3de8680a 100644 --- a/flowtorch/parameters/tensor.py +++ b/flowtorch/parameters/tensor.py @@ -23,5 +23,5 @@ def __init__( def _forward( self, x: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None - ) -> Sequence[torch.Tensor]: + ) -> Optional[Sequence[torch.Tensor]]: return list(self.params) diff --git a/setup.py b/setup.py index 4959b0c9..479bdcac 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ "flake8-bugbear", "mypy", "toml", - "usort", + "usort==0.6.4", ] diff --git a/tests/test_bijectivetensor.py b/tests/test_bijectivetensor.py new file mode 100644 index 00000000..72bbdf70 --- /dev/null +++ b/tests/test_bijectivetensor.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc +import time + +import flowtorch.parameters as params +import pytest +import torch +from flowtorch.bijectors import AffineAutoregressive, Compose +from flowtorch.bijectors.utils import set_record_flow_graph + +dim_x = 32 + + +def get_net() -> AffineAutoregressive: + ar = Compose( + [ + AffineAutoregressive(params.DenseAutoregressive()), + AffineAutoregressive(params.DenseAutoregressive()), + AffineAutoregressive(params.DenseAutoregressive()), + ] + ) + ar = ar( + shape=torch.Size( + [ + dim_x, + ] + ) + ) + return ar + + +def test_forward(): + ar = get_net() + x = torch.randn(50, dim_x, requires_grad=True) + y = ar.forward(x) + assert ar.inverse(y) is x + assert ar.forward(y) is not x + + with set_record_flow_graph(False): 
+ y = ar.forward(x) + assert ar.inverse(y) is not x + assert ar.forward(y) is not x + + +def test_backward(): + ar = get_net() + x = torch.randn(50, dim_x, requires_grad=True) + y = ar.inverse(x) + assert ar.forward(y) is x + assert ar.inverse(y) is not x + + with set_record_flow_graph(False): + y = ar.inverse(x) + assert ar.forward(y) is not x + assert ar.inverse(y) is not x + + +@pytest.mark.parametrize("mode", ["forward", "inverse"]) +def test_gradient_matching(mode): + ar = get_net() + + print("test with bijective tensor") + t0 = time.time() + with set_record_flow_graph(True): + x = torch.randn(50, dim_x, requires_grad=True) + t1 = time.time() + if mode == "forward": + y = ar.forward(x) + xinv = ar.inverse(y) + ldj = ar.log_abs_det_jacobian(x, y).sum() + else: + y = ar.inverse(x) + xinv = ar.forward(y) + ldj = ar.log_abs_det_jacobian(y, x).sum() + assert xinv is x + print(f"op with bij tensor took {time.time() - t1} for mode={mode}") + ldj.backward() + g_bijtensor = x.grad.clone() + bij_time = time.time() - t0 + print("bij tensor time: ", bij_time) + + print("test with regular tensor") + t0 = time.time() + with set_record_flow_graph(False): + x.grad = None + t1 = time.time() + if mode == "forward": + y = ar.forward(x) + xinv = ar.inverse(y) + ldj = ar.log_abs_det_jacobian(x, y).sum() + else: + y = ar.inverse(x) + xinv = ar.forward(y) + ldj = ar.log_abs_det_jacobian(y, x).sum() + assert xinv is not x + print(f"op with regular tensor took {time.time() - t1} for mode={mode}") + ldj.backward() + g_tensor = x.grad.clone() + tensor_time = time.time() - t0 + print("regular tensor time: ", tensor_time) + + print("diff between grads: ", (g_bijtensor - g_tensor).norm(2)) + torch.testing.assert_allclose(g_bijtensor, g_tensor) + + # This is flaky and should probably not be merged, but it's a good + # soundness check locally + assert bij_time < tensor_time, f"Bijective tensor {mode}+backprop took longer" + + +if __name__ == "__main__": + pytest.main([__file__, "--capture", 
"no"]) diff --git a/tests/test_bijector.py b/tests/test_bijector.py index b023c091..adb4b68f 100644 --- a/tests/test_bijector.py +++ b/tests/test_bijector.py @@ -31,7 +31,7 @@ def flow(request): def test_jacobian(flow, epsilon=1e-2): # Instantiate transformed distribution and parameters bij = flow.bijector - params = bij.params + params = bij._params_fn # Calculate auto-diff Jacobian x = torch.randn(*flow.event_shape) diff --git a/tests/test_distribution.py b/tests/test_distribution.py index 3c628bac..db7c9095 100644 --- a/tests/test_distribution.py +++ b/tests/test_distribution.py @@ -29,7 +29,7 @@ def make_tdist(): def test_neals_funnel_vi(): torch.manual_seed(42) nf = dist.NealsFunnel() - bijector = bijs.AffineAutoregressive(params=params.DenseAutoregressive()) + bijector = bijs.AffineAutoregressive(params_fn=params.DenseAutoregressive()) base_dist = torch.distributions.Independent( torch.distributions.Normal(torch.zeros(2), torch.ones(2)), 1 @@ -41,8 +41,8 @@ def test_neals_funnel_vi(): num_elbo_mc_samples = 200 for _ in range(100): z0 = flow.base_dist.rsample(sample_shape=(num_elbo_mc_samples,)) - zk = bijector._forward(z0) - ldj = bijector._log_abs_det_jacobian(z0, zk) + zk = bijector.forward(z0) + ldj = zk._log_detJ neg_elbo = -nf.log_prob(zk).sum() neg_elbo += flow.base_dist.log_prob(z0).sum() - ldj.sum() diff --git a/website/docs/users/conditional.mdx b/website/docs/users/conditional.mdx index 0cc20e48..14112666 100644 --- a/website/docs/users/conditional.mdx +++ b/website/docs/users/conditional.mdx @@ -16,7 +16,7 @@ $$ where $\mathbf{z}$ is the latent variable and $\mathbf{x}$ the observed one, that hopefully contains a member close to the true posterior of the model, $p(\mathbf{z}\mid\mathbf{x})$. In other cases, we may wish to learn to generate an object $\mathbf{x}$ conditioned on some context $\mathbf{c}$ using $p_\theta(\mathbf{x}\mid\mathbf{c})$ and observations $\{(\mathbf{x}_n,\mathbf{c}_n)\}^N_{n=1}$. 
For instance, $\mathbf{x}$ may be a spoken sentence and $\mathbf{c}$ a number of speech features. -The theory of Normalizing Flows is easily generalized to conditional distributions. We denote the variable to condition on by $C=\mathbf{c}\in\mathbb{R}^M$. A simple multivariate source of noise, for example a standard i.i.d. normal distribution, $X\sim\mathcal{N}(\mathbf{0},I_{D\times D})$, is passed through a vector-valued bijection that also conditions on C, $g:\mathbb{R}^D\times\mathbb{R}^M\rightarrow\mathbb{R}^D$, to produce the more complex transformed variable $Y=g(X;C=\mathbf{c})$. In practice, this is usually accomplished by making the parameters for a known normalizing flow bijection $g$ the output of a hypernet neural network that inputs $\mathbf{c}$. +The theory of Normalizing Flows is easily generalized to conditional distributions. We denote the variable to condition on by $C=\mathbf{c}\in\mathbb{R}^M$. A simple multivariate source of noise, for example a standard i.i.d. normal distribution, $X\sim\mathcal{N}(\mathbf{0},I_{D\times D})$, is passed through a vector-valued bijection that also conditions on C, $g:\mathbb{R}^D\times\mathbb{R}^M\rightarrow\mathbb{R}^D$, to produce the more complex transformed variable $Y=g(X;C=\mathbf{c})$. In practice, this is usually accomplished by making the parameters for a known normalizing flow bijection $g$ the output of a hypernet neural network (passed to the bijector as `params_fn`) that inputs $\mathbf{c}$. Sampling of conditional transforms simply involves evaluating $Y=g(X; C=\mathbf{c})$. Conditioning the bijections on $\mathbf{c}$, the same formula holds for scoring as for the joint multivariate case. 
diff --git a/website/src/theme/Examples/snippets.js b/website/src/theme/Examples/snippets.js index 2083f6f9..985a87bd 100644 --- a/website/src/theme/Examples/snippets.js +++ b/website/src/theme/Examples/snippets.js @@ -3,13 +3,13 @@ const snippets = [ label: "Bivariate Normal", code: `import torch -import flowtorch.bijectors as bij -import flowtorch.distributions as dist -import flowtorch.parameters as params +import flowtorch.bijectors as B +import flowtorch.distributions as D +import flowtorch.parameters as P # Lazily instantiated flow plus base and target distributions -params = params.DenseAutoregressive(hidden_dims=(32,)) -bijectors = bij.AffineAutoregressive(params=params) +params_fn = P.DenseAutoregressive(hidden_dims=(32,)) +bijectors = B.AffineAutoregressive(params_fn=params_fn) base_dist = torch.distributions.Independent( torch.distributions.Normal(torch.zeros(2), torch.ones(2)), 1 @@ -20,7 +20,7 @@ target_dist = torch.distributions.Independent( ) # Instantiate transformed distribution and parameters -flow = dist.Flow(base_dist, bijectors) +flow = D.Flow(base_dist, bijectors) # Training loop opt = torch.optim.Adam(flow.parameters(), lr=5e-3)