From 8277e69b32d10fd58a004f3febfa615bfdd64edc Mon Sep 17 00:00:00 2001 From: Fabricio Arend Torres <9096900+FabricioArendTorres@users.noreply.github.com> Date: Sun, 16 Jun 2024 19:07:22 +0200 Subject: [PATCH] added support for neural odes --- Dockerfile-pytorchlatest | 11 + README.md | 3 +- docker/Dockerfile-devel | 13 ++ docker/Dockerfile-pytorchlatest | 6 +- docker/requirements.txt | 15 ++ examples/conditional_toy_2d.py | 3 +- flowcon/CNF/neural_odes/__init__.py | 2 - flowcon/CNF/neural_odes/squeeze.py | 35 --- flowcon/flows/base.py | 4 +- flowcon/nn/__init__.py | 5 +- flowcon/nn/nets/invertible_densenet.py | 5 +- flowcon/{CNF => nn}/neural_odes/GITHUB_SOURCE | 0 flowcon/{CNF => nn}/neural_odes/LICENSE | 0 flowcon/nn/neural_odes/__init__.py | 4 + .../neural_odes/diffeq_layers/__init__.py | 0 .../neural_odes/diffeq_layers/basic.py | 0 .../neural_odes/diffeq_layers/container.py | 0 .../neural_odes/diffeq_layers/resnet.py | 0 .../neural_odes/diffeq_layers/wrappers.py | 0 flowcon/{CNF => nn}/neural_odes/odefunc.py | 19 +- flowcon/{CNF => nn}/neural_odes/util.py | 0 .../neural_odes/wrappers/__init__.py | 0 .../wrappers/cnf_regularization.py | 0 flowcon/transforms/__init__.py | 2 + .../{CNF/cnf.py => transforms/neuralode.py} | 213 ++++++++++++------ flowcon/transforms/node/__init__.py | 0 flowcon/transforms/splines/linear.py | 3 +- flowcon/utils/torchutils.py | 36 ++- 28 files changed, 252 insertions(+), 127 deletions(-) create mode 100644 Dockerfile-pytorchlatest create mode 100644 docker/Dockerfile-devel create mode 100644 docker/requirements.txt delete mode 100644 flowcon/CNF/neural_odes/__init__.py delete mode 100644 flowcon/CNF/neural_odes/squeeze.py rename flowcon/{CNF => nn}/neural_odes/GITHUB_SOURCE (100%) rename flowcon/{CNF => nn}/neural_odes/LICENSE (100%) create mode 100644 flowcon/nn/neural_odes/__init__.py rename flowcon/{CNF => nn}/neural_odes/diffeq_layers/__init__.py (100%) rename flowcon/{CNF => nn}/neural_odes/diffeq_layers/basic.py (100%) rename 
flowcon/{CNF => nn}/neural_odes/diffeq_layers/container.py (100%) rename flowcon/{CNF => nn}/neural_odes/diffeq_layers/resnet.py (100%) rename flowcon/{CNF => nn}/neural_odes/diffeq_layers/wrappers.py (100%) rename flowcon/{CNF => nn}/neural_odes/odefunc.py (91%) rename flowcon/{CNF => nn}/neural_odes/util.py (100%) rename flowcon/{CNF => nn}/neural_odes/wrappers/__init__.py (100%) rename flowcon/{CNF => nn}/neural_odes/wrappers/cnf_regularization.py (100%) rename flowcon/{CNF/cnf.py => transforms/neuralode.py} (60%) create mode 100644 flowcon/transforms/node/__init__.py diff --git a/Dockerfile-pytorchlatest b/Dockerfile-pytorchlatest new file mode 100644 index 0000000..8dad7c6 --- /dev/null +++ b/Dockerfile-pytorchlatest @@ -0,0 +1,11 @@ +# Dockerfile-pytorch1.13.1 +FROM pytorch/pytorch:latest + +WORKDIR /app + +# Copy the source code and install the package +COPY . /flowc +RUN pip install /flowc + +CMD ["/bin/bash"] + diff --git a/README.md b/README.md index 799f59f..b9dc4e8 100644 --- a/README.md +++ b/README.md @@ -99,8 +99,9 @@ log_prob = flow.log_prob(inputs) ``` To sample from the flow: + ```python -samples = flow.sample(num_samples) +samples = flow.sample(num_samples) ``` Additional examples of the workflow are provided in [examples folder](examples/). diff --git a/docker/Dockerfile-devel b/docker/Dockerfile-devel new file mode 100644 index 0000000..cfb6477 --- /dev/null +++ b/docker/Dockerfile-devel @@ -0,0 +1,13 @@ +# Dockerfile-pytorch1.13.1 +FROM pytorch/pytorch:latest + +WORKDIR /app + +# Copy the source code and install the package +COPY . /app +RUN pip install pdoc3 +RUN pip install -e . 
+ +CMD ["/bin/bash"] + +EXPOSE 8080 diff --git a/docker/Dockerfile-pytorchlatest b/docker/Dockerfile-pytorchlatest index 8dad7c6..ff44cd0 100644 --- a/docker/Dockerfile-pytorchlatest +++ b/docker/Dockerfile-pytorchlatest @@ -3,9 +3,13 @@ FROM pytorch/pytorch:latest WORKDIR /app -# Copy the source code and install the package +RUN apt update && apt install build-essential -y --no-install-recommends + +# Copy the source code and install the package COPY . /flowc RUN pip install /flowc CMD ["/bin/bash"] +EXPOSE 8080 + diff --git a/docker/requirements.txt b/docker/requirements.txt new file mode 100644 index 0000000..11f9e3c --- /dev/null +++ b/docker/requirements.txt @@ -0,0 +1,15 @@ +matplotlib +numpy +pandas +torchdiffeq +umnn +tqdm +ninja +scikit-learn +h5py +torchtestcase +parameterized +flake8 +pytest +pytest-cov +black \ No newline at end of file diff --git a/examples/conditional_toy_2d.py b/examples/conditional_toy_2d.py index a62c2f0..3234064 100644 --- a/examples/conditional_toy_2d.py +++ b/examples/conditional_toy_2d.py @@ -133,7 +133,8 @@ def plot_model(flow: Flow, dataset: PlaneDataset): plt.title('iteration {}'.format(i + 1)) plt.tight_layout() # plt.show() - plt.show() + plt.savefig(f"figures/conditional_{selected_data}.png") + plt.close() if __name__ == "__main__": diff --git a/flowcon/CNF/neural_odes/__init__.py b/flowcon/CNF/neural_odes/__init__.py deleted file mode 100644 index c261ae5..0000000 --- a/flowcon/CNF/neural_odes/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .odefunc import * -from . 
import diffeq_layers diff --git a/flowcon/CNF/neural_odes/squeeze.py b/flowcon/CNF/neural_odes/squeeze.py deleted file mode 100644 index fb7d32c..0000000 --- a/flowcon/CNF/neural_odes/squeeze.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch.nn as nn - - -def unsqueeze(input, upscale_factor=2): - ''' - [:, C*r^2, H, W] -> [:, C, H*r, W*r] - ''' - batch_size, in_channels, in_height, in_width = input.size() - out_channels = in_channels // (upscale_factor**2) - - out_height = in_height * upscale_factor - out_width = in_width * upscale_factor - - input_view = input.contiguous().view(batch_size, out_channels, upscale_factor, upscale_factor, in_height, in_width) - - output = input_view.permute(0, 1, 4, 2, 5, 3).contiguous() - return output.view(batch_size, out_channels, out_height, out_width) - - -def squeeze(input, downscale_factor=2): - ''' - [:, C, H*r, W*r] -> [:, C*r^2, H, W] - ''' - batch_size, in_channels, in_height, in_width = input.size() - out_channels = in_channels * (downscale_factor**2) - - out_height = in_height // downscale_factor - out_width = in_width // downscale_factor - - input_view = input.contiguous().view( - batch_size, in_channels, out_height, downscale_factor, out_width, downscale_factor - ) - - output = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() - return output.view(batch_size, out_channels, out_height, out_width) diff --git a/flowcon/flows/base.py b/flowcon/flows/base.py index 55307bc..983ab8b 100644 --- a/flowcon/flows/base.py +++ b/flowcon/flows/base.py @@ -50,9 +50,9 @@ def _log_prob(self, inputs, context): def _sample(self, num_samples, context): embedded_context = self._embedding_net(context) if self._context_used_in_base: - noise = self._distribution.sample(num_samples, context=embedded_context) + noise = self._distribution.sample(num_samples, context=embedded_context) else: - repeat_noise = self._distribution.sample(num_samples * embedded_context.shape[0]) + repeat_noise = self._distribution.sample(num_samples * 
embedded_context.shape[0]) noise = torch.reshape( repeat_noise, (embedded_context.shape[0], -1, repeat_noise.shape[1]) diff --git a/flowcon/nn/__init__.py b/flowcon/nn/__init__.py index f9a12ea..778a4f0 100644 --- a/flowcon/nn/__init__.py +++ b/flowcon/nn/__init__.py @@ -1,5 +1,5 @@ from flowcon.nn.nets import * - +from flowcon.nn.neural_odes import odefunc __all__ = ['DenseNet', 'MixedConditionalDenseNet', 'InputConditionalDenseNet', @@ -17,5 +17,6 @@ 'ConvResidualNet', 'ResidualNet', 'MLP', - 'FCBlock' + 'FCBlock', + ] diff --git a/flowcon/nn/nets/invertible_densenet.py b/flowcon/nn/nets/invertible_densenet.py index 177e4bb..5cf7dc3 100644 --- a/flowcon/nn/nets/invertible_densenet.py +++ b/flowcon/nn/nets/invertible_densenet.py @@ -8,6 +8,7 @@ from abc import ABC, abstractmethod from pprint import pformat +import flowcon.utils.torchutils from flowcon.nn.nets import activations from flowcon.nn.nets.extended_basic_nets import ExtendedSequential, ExtendedLinear from flowcon.nn.nets.spectral_norm import scaled_spectral_norm @@ -392,7 +393,7 @@ def __init__(self, dimension, context_features, def forward(self, inputs, context=None): context = self.bn(context) - values_weights = self.dense_net(inputs).unsqueeze(-1) + values_weights = self.dense_net(inputs).unsqueeze(-1) weights = self.custom_attention.attention(context, values_weights) return weights @@ -449,7 +450,7 @@ def forward(self, inputs, context=None): context = self.bn(context) context_embedding = self.context_embedding_net(context) concat_inputs = torch.cat([inputs, context_embedding], -1) - values_weights = self.dense_net(concat_inputs).unsqueeze(-1) + values_weights = self.dense_net(concat_inputs).unsqueeze(-1) weights = self.custom_attention.attention(context, values_weights) return weights diff --git a/flowcon/CNF/neural_odes/GITHUB_SOURCE b/flowcon/nn/neural_odes/GITHUB_SOURCE similarity index 100% rename from flowcon/CNF/neural_odes/GITHUB_SOURCE rename to flowcon/nn/neural_odes/GITHUB_SOURCE diff --git 
a/flowcon/CNF/neural_odes/LICENSE b/flowcon/nn/neural_odes/LICENSE similarity index 100% rename from flowcon/CNF/neural_odes/LICENSE rename to flowcon/nn/neural_odes/LICENSE diff --git a/flowcon/nn/neural_odes/__init__.py b/flowcon/nn/neural_odes/__init__.py new file mode 100644 index 0000000..5cc252a --- /dev/null +++ b/flowcon/nn/neural_odes/__init__.py @@ -0,0 +1,4 @@ + +from flowcon.nn.neural_odes.wrappers.cnf_regularization import RegularizedODEfunc +from flowcon.nn.neural_odes import diffeq_layers +from flowcon.nn.neural_odes.odefunc import * diff --git a/flowcon/CNF/neural_odes/diffeq_layers/__init__.py b/flowcon/nn/neural_odes/diffeq_layers/__init__.py similarity index 100% rename from flowcon/CNF/neural_odes/diffeq_layers/__init__.py rename to flowcon/nn/neural_odes/diffeq_layers/__init__.py diff --git a/flowcon/CNF/neural_odes/diffeq_layers/basic.py b/flowcon/nn/neural_odes/diffeq_layers/basic.py similarity index 100% rename from flowcon/CNF/neural_odes/diffeq_layers/basic.py rename to flowcon/nn/neural_odes/diffeq_layers/basic.py diff --git a/flowcon/CNF/neural_odes/diffeq_layers/container.py b/flowcon/nn/neural_odes/diffeq_layers/container.py similarity index 100% rename from flowcon/CNF/neural_odes/diffeq_layers/container.py rename to flowcon/nn/neural_odes/diffeq_layers/container.py diff --git a/flowcon/CNF/neural_odes/diffeq_layers/resnet.py b/flowcon/nn/neural_odes/diffeq_layers/resnet.py similarity index 100% rename from flowcon/CNF/neural_odes/diffeq_layers/resnet.py rename to flowcon/nn/neural_odes/diffeq_layers/resnet.py diff --git a/flowcon/CNF/neural_odes/diffeq_layers/wrappers.py b/flowcon/nn/neural_odes/diffeq_layers/wrappers.py similarity index 100% rename from flowcon/CNF/neural_odes/diffeq_layers/wrappers.py rename to flowcon/nn/neural_odes/diffeq_layers/wrappers.py diff --git a/flowcon/CNF/neural_odes/odefunc.py b/flowcon/nn/neural_odes/odefunc.py similarity index 91% rename from flowcon/CNF/neural_odes/odefunc.py rename to 
flowcon/nn/neural_odes/odefunc.py index 4f2ba9b..b5bd6c9 100644 --- a/flowcon/CNF/neural_odes/odefunc.py +++ b/flowcon/nn/neural_odes/odefunc.py @@ -2,10 +2,12 @@ import numpy as np import torch import torch.nn as nn -from flowcon.transforms import ActNorm +from typing import * + +from flowcon.transforms.normalization import ActNorm from . import diffeq_layers -from .squeeze import squeeze, unsqueeze +from flowcon.utils.torchutils import unsqueeze, squeeze __all__ = ["ODEnet", "ODEfunc"] @@ -28,13 +30,15 @@ class ODEnet(nn.Module): """ def __init__( - self, hidden_dims, input_shape, strides, conv, layer_type="concat", nonlinearity="softplus", num_squeeze=0, - act_norm=False, scale_output=1 - ): + self, hidden_dims, input_shape, strides=None, conv=False, + layer_type: Literal["ignore", "hyper", + "squash", "concat", "concat_v2", "concatsquash", "blend", "concatcoord"] = "concat", + nonlinearity: Literal["tanh", "relu", "softplus", "elu", "swish", "square", "identity"] = "softplus", + num_squeeze=0, + act_norm=False, scale_output=1): super(ODEnet, self).__init__() self.act_norm = act_norm if act_norm: - self.t_actnorm = ActNorm(1) self.x_actnorm = ActNorm(input_shape[0]) self.scale_output = scale_output @@ -97,8 +101,7 @@ def __init__( def forward(self, t, y): if self.act_norm: - t = self.t_actnorm(t.view(-1, 1))[0].view(t.shape) - y, _ = self.x_actnorm(y) + y, logabsdet_actnorm = self.x_actnorm(y) dx = y # squeeze for _ in range(self.num_squeeze): diff --git a/flowcon/CNF/neural_odes/util.py b/flowcon/nn/neural_odes/util.py similarity index 100% rename from flowcon/CNF/neural_odes/util.py rename to flowcon/nn/neural_odes/util.py diff --git a/flowcon/CNF/neural_odes/wrappers/__init__.py b/flowcon/nn/neural_odes/wrappers/__init__.py similarity index 100% rename from flowcon/CNF/neural_odes/wrappers/__init__.py rename to flowcon/nn/neural_odes/wrappers/__init__.py diff --git a/flowcon/CNF/neural_odes/wrappers/cnf_regularization.py 
b/flowcon/nn/neural_odes/wrappers/cnf_regularization.py similarity index 100% rename from flowcon/CNF/neural_odes/wrappers/cnf_regularization.py rename to flowcon/nn/neural_odes/wrappers/cnf_regularization.py diff --git a/flowcon/transforms/__init__.py b/flowcon/transforms/__init__.py index 32f583e..bb0c9ba 100644 --- a/flowcon/transforms/__init__.py +++ b/flowcon/transforms/__init__.py @@ -86,3 +86,5 @@ from flowcon.transforms.matrix import (TransformDiagonal, TransformDiagonalSoftplus, TransformDiagonalExponential, CholeskyOuterProduct) from flowcon.transforms.lipschitz import (iResBlock) + +from flowcon.transforms.neuralode import NeuralODE \ No newline at end of file diff --git a/flowcon/CNF/cnf.py b/flowcon/transforms/neuralode.py similarity index 60% rename from flowcon/CNF/cnf.py rename to flowcon/transforms/neuralode.py index 5f3837d..735acd3 100644 --- a/flowcon/CNF/cnf.py +++ b/flowcon/transforms/neuralode.py @@ -2,14 +2,21 @@ import torch.nn as nn from torchdiffeq import odeint_adjoint as odeint -from flowcon.CNF.neural_odes.wrappers.cnf_regularization import RegularizedODEfunc +from typing import * +from flowcon.nn.neural_odes import RegularizedODEfunc +from flowcon.transforms import Transform -__all__ = ["CNF"] +__all__ = ["NeuralODE"] -class CNF(nn.Module): - def __init__(self, odefunc, T=1.0, train_T=False, regularization_fns=None, solver='dopri5', atol=1e-5, rtol=1e-5): - super(CNF, self).__init__() +class NeuralODE(Transform): + """ + Transformation given by a neural ode. 
+ """ + + def __init__(self, odefunc, T=1.0, train_T=False, regularization_fns=None, + solver: Literal['dopri5', 'dopri8', 'bosh3'] = 'dopri5', atol=1e-5, rtol=1e-5): + super(NeuralODE, self).__init__() if train_T: self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T)))) else: @@ -30,13 +37,15 @@ def __init__(self, odefunc, T=1.0, train_T=False, regularization_fns=None, solve self.test_rtol = rtol self.solver_options = {} - def forward(self, z, logpz=None, integration_times=None, reverse=False): + def forward(self, inputs, context=None): + return self._ode_transform(inputs, context=context, logpz=None, integration_times=None) - if logpz is None: - _logpz = torch.zeros(z.shape[0], 1).to(z) - else: - _logpz = logpz + def inverse(self, outputs, context=None): + return self._ode_transform(outputs, context=context, logpz=None, integration_times=None, + reverse=True) + def _ode_transform(self, z, context=None, logpz=None, integration_times=None, reverse=False): + _logpz = torch.zeros(z.shape[0], 1).to(z) if integration_times is None: integration_times = torch.tensor([0.0, self.sqrt_end_time * self.sqrt_end_time]).to(z) if reverse: @@ -46,31 +55,20 @@ def forward(self, z, logpz=None, integration_times=None, reverse=False): self.odefunc.before_odeint() # Add regularization states. 
- reg_states = tuple(torch.tensor(0).to(z) for _ in range(self.nreg)) if self.training: - state_t = odeint( - self.odefunc, - (z, _logpz) + reg_states, - integration_times.to(z), - atol=self.atol, - rtol=self.rtol, - method=self.solver, - options=self.solver_options, - adjoint_options={"norm": "seminorm"} - # step_size = self.solver_options["step_size"] - ) + reg_states = tuple(torch.tensor(0).to(z) for _ in range(self.nreg)) + odeint_params = self.odeint_train_params(_logpz, reg_states, z) else: - state_t = odeint( - self.odefunc, - (z, _logpz), - integration_times.to(z), - atol=self.test_atol, - rtol=self.test_rtol, - method=self.test_solver, - adjoint_options={"norm": "seminorm"} - # step_size=self.solver_options["step_size"] - ) + odeint_params = self.odeint_test_params(_logpz, z) + + + + state_t = odeint( + func=self.odefunc, + t=integration_times.to(z), + **odeint_params, + ) if len(integration_times) == 2: state_t = tuple(s[1] for s in state_t) @@ -78,10 +76,26 @@ z_t, logpz_t = state_t[:2] self.regularization_states = state_t[2:] - if logpz is not None: - return z_t, logpz_t - else: - return z_t + return z_t, -logpz_t + + def odeint_train_params(self, _logpz, reg_states, z): + atol, rtol, method, options = self.atol, self.rtol, self.solver, self.solver_options + y0 = (z, _logpz) + reg_states + adjoint_options = {"norm": "seminorm"} + + return {"atol": atol, "method": method, "options": options, "rtol": rtol, "y0": y0, + "adjoint_options": adjoint_options} + + def odeint_test_params(self, _logpz, z): + y0 = (z, _logpz) + atol = self.test_atol + rtol = self.test_rtol + method = self.test_solver + adjoint_options = {"norm": "seminorm"} + options = self.solver_options + + return {"atol": atol, "method": method, "options": options, "rtol": rtol, "y0": y0, + 
"adjoint_options": adjoint_options} def get_regularization_states(self): reg_states = self.regularization_states @@ -92,29 +106,41 @@ def num_evals(self): return self.odefunc._num_evals.item() -class CompactCNF(nn.Module): - def __init__(self, dynamics_network, solver='dopri5', atol=1e-5, rtol=1e-5, - divergence_fn="approximate"): - super(CompactCNF, self).__init__() - assert divergence_fn in ("brute_force", "approximate") +class _ODENetWrapper(torch.nn.Module): + """ + A wrapper around the `torch.nn.Module` that provides the dynamics of the neural ode for the continuous + normalizing flow. + + You should not have to create objects of this class yourself, + since it is being called within the NeuralODE transforms. + + Essentially it just provides a function that outputs the network value combined with an estimate + of the trace of the Jacobian of the network (i.e. the divergence). + This has to be a separate `torch.nn.Module` due to `torchdiffeq`. + """ + + def __init__(self, + dynamics_network, + divergence_fn_train: Literal["approximate", "brute_force"] = "approximate", + divergence_fn_test: Literal["approximate", "brute_force"] = "brute_force", + sampler: Literal["rademacher", "gaussian"] = "rademacher", + ): + super().__init__() nreg = 0 self.diffeq = dynamics_network self.nreg = nreg - self.solver = solver - self.atol = atol - self.rtol = rtol - self.test_solver = solver - self.test_atol = atol - self.test_rtol = rtol self.solver_options = {} self.rademacher = True - if divergence_fn == "brute_force": - self.divergence_fn = divergence_bf - elif divergence_fn == "approximate": - self.divergence_fn = divergence_approx + divergences = dict(approximate=divergence_approx, + brute_force=divergence_bf) + self.sample_like = dict(rademacher=sample_rademacher_like, + gaussian=sample_gaussian_like)[sampler] + + self.divergence_fn_train = divergences[divergence_fn_train] + self.divergence_fn_test = divergences[divergence_fn_test] self.register_buffer("_num_evals", 
torch.tensor(0.)) self.before_odeint() @@ -126,7 +152,19 @@ def before_odeint(self, e=None): def num_evals(self): return self._num_evals.item() - def forward(self, t, states): + def forward(self, t: Union[torch.Tensor, float], states: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[ + torch.Tensor, torch.Tensor]: + """ + + Parameters + ---------- + t time + states current state (a tuple of current position and integrated divergence, i.e. the intermediate logabsdet) + + Returns Dynamics of the states, as to be used in odeint. + ------- + + """ assert len(states) >= 2 y = states[0] @@ -142,28 +180,62 @@ def forward(self, t, states): # Sample and fix the noise. if self._e is None: - if self.rademacher: - self._e = sample_rademacher_like(y) - else: - self._e = sample_gaussian_like(y) + self._e = self.sample_like(y) with torch.set_grad_enabled(True): y.requires_grad_(True) t.requires_grad_(True) dy = self.diffeq(t, y) - # Hack for 2D data to use brute force divergence computation. if not self.training and dy.view(dy.shape[0], -1).shape[1] == 2: divergence = divergence_bf(dy, y).view(batchsize, 1) else: if self.training: - divergence = self.divergence_fn(dy, y, e=self._e).view(batchsize, 1) + divergence = self.divergence_fn_train(dy, y, e=self._e).view(batchsize, 1) else: - divergence = divergence_bf(dy, y, e=self._e).view(batchsize, 1) + divergence = self.divergence_fn_test(dy, y, e=self._e).view(batchsize, 1) + d_states_dt = dy, divergence.squeeze() + return d_states_dt + + +class SimpleCNF(Transform): + def __init__(self, dynamics_network, train_T=True, T=1.0, + solver:Literal['dopri5', 'dopri8', 'bosh3']='dopri5', atol=1e-5, rtol=1e-5, + divergence_fn:Literal["approximate", "brute_force"]="approximate", + eval_mode_divergence_fn:Literal["approximate", "brute_force"]="approximate"): + super(SimpleCNF, self).__init__() + assert divergence_fn in ("brute_force", "approximate") + + nreg = 0 + + if train_T: + self.register_parameter("sqrt_end_time", 
nn.Parameter(torch.sqrt(torch.tensor(T)))) + else: + self.register_buffer("sqrt_end_time", torch.sqrt(torch.tensor(T))) + + self.odefunc = _ODENetWrapper(dynamics_network, divergence_fn_train=divergence_fn, divergence_fn_test=eval_mode_divergence_fn) + self.nreg = nreg + self.solver = solver + self.atol = atol + self.rtol = rtol + self.test_solver = solver + self.test_atol = atol + self.test_rtol = rtol + self.solver_options = {} + self.rademacher = True + + self.odefunc.before_odeint() + + def num_evals(self): + return self.odefunc.num_evals() - return tuple([dy, -divergence]) + def forward(self, inputs, context=None): + return self.integrate(inputs, context=context, logpz=None, integration_times=None) - def integrate(self, z, logpz=None, integration_times=None, reverse=False): + def inverse(self, inputs, context=None): + return self.integrate(inputs, context=context, logpz=None, integration_times=None, reverse=True) + + def integrate(self, z, context=None, logpz=None, integration_times=None, reverse=False): if logpz is None: _logpz = torch.zeros(z.shape[0], 1).to(z) else: @@ -175,11 +247,11 @@ def integrate(self, z, logpz=None, integration_times=None, reverse=False): integration_times = _flip(integration_times, 0) # Refresh the odefunc statistics. 
- self.before_odeint() + self.odefunc.before_odeint() if self.training: state_t = odeint( - self, + self.odefunc, (z, _logpz), integration_times.to(z), atol=self.atol, @@ -191,7 +263,7 @@ def integrate(self, z, logpz=None, integration_times=None, reverse=False): ) else: state_t = odeint( - self, + self.odefunc, (z, _logpz), integration_times.to(z), atol=self.test_atol, @@ -202,17 +274,15 @@ def integrate(self, z, logpz=None, integration_times=None, reverse=False): ) z_t, logpz_t = tuple(s[1] for s in state_t) - - return z_t, logpz_t + return z_t, logpz_t.squeeze() class CompactTimeVariableCNF(nn.Module): - start_time = 0.0 end_time = 1.0 - def __init__(self, dynamics_network, solver='dopri5', atol=1e-5, rtol=1e-5, - divergence_fn="approximate"): + def __init__(self, dynamics_network, solver:Literal['dopri5', 'dopri8', 'bosh3']='dopri5', atol=1e-5, rtol=1e-5, + divergence_fn:Literal["approximate", "brute_force"]="approximate"): super(CompactTimeVariableCNF, self).__init__() assert divergence_fn in ("brute_force", "approximate") @@ -234,6 +304,8 @@ def __init__(self, dynamics_network, solver='dopri5', atol=1e-5, rtol=1e-5, elif divergence_fn == "approximate": self.divergence_fn = divergence_approx + + self.register_buffer("_num_evals", torch.tensor(0.)) self.before_odeint() @@ -252,7 +324,6 @@ def __init__(self, dynamics_network, solver='dopri5', atol=1e-5, rtol=1e-5, ) def integrate(self, t0, t1, z, logpz=None): - _logpz = torch.zeros(z.shape[0], 1).to(z) if logpz is None else logpz initial_state = (t0, t1, z, _logpz) @@ -268,7 +339,7 @@ def integrate(self, t0, t1, z, logpz=None): t=integration_times, **self.get_odeint_kwargs() ) - _, _, z_t, logpz_t = tuple(s[-1] for s in state_t) + _, _, z_t, logpz_t = tuple(s[-1] for s in state_t) return z_t, logpz_t diff --git a/flowcon/transforms/node/__init__.py b/flowcon/transforms/node/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flowcon/transforms/splines/linear.py 
b/flowcon/transforms/splines/linear.py index 8a3956b..9a04637 100644 --- a/flowcon/transforms/splines/linear.py +++ b/flowcon/transforms/splines/linear.py @@ -2,6 +2,7 @@ import torch from torch.nn import functional as F +import flowcon.utils.torchutils from flowcon.transforms.base import InputOutsideDomain from flowcon.utils import torchutils @@ -72,7 +73,7 @@ def linear_spline( ) offsets = cdf[..., 1:] - slopes * bin_boundaries[..., 1:] - inv_bin_idx = inv_bin_idx.unsqueeze(-1) + inv_bin_idx = inv_bin_idx.unsqueeze(-1) input_slopes = slopes.gather(-1, inv_bin_idx)[..., 0] input_offsets = offsets.gather(-1, inv_bin_idx)[..., 0] diff --git a/flowcon/utils/torchutils.py b/flowcon/utils/torchutils.py index fc94bc3..fa4ba4c 100644 --- a/flowcon/utils/torchutils.py +++ b/flowcon/utils/torchutils.py @@ -53,7 +53,7 @@ def repeat_rows(x, num_reps): if not check.is_positive_int(num_reps): raise TypeError("Number of repetitions must be a positive integer.") shape = x.shape - x = x.unsqueeze(1) + x = x.unsqueeze(1) x = x.expand(shape[0], num_reps, *shape[1:]) return merge_leading_dims(x, num_dims=2) @@ -248,3 +248,37 @@ def _flatten(sequence): def _flatten_convert_none_to_zeros(sequence, like_sequence): flat = [p.reshape(-1) if p is not None else torch.zeros_like(q).view(-1) for p, q in zip(sequence, like_sequence)] return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) + + +def unsqueeze(input, upscale_factor=2): + ''' + [:, C*r^2, H, W] -> [:, C, H*r, W*r] + ''' + batch_size, in_channels, in_height, in_width = input.size() + out_channels = in_channels // (upscale_factor**2) + + out_height = in_height * upscale_factor + out_width = in_width * upscale_factor + + input_view = input.contiguous().view(batch_size, out_channels, upscale_factor, upscale_factor, in_height, in_width) + + output = input_view.permute(0, 1, 4, 2, 5, 3).contiguous() + return output.view(batch_size, out_channels, out_height, out_width) + + +def squeeze(input, downscale_factor=2): + ''' 
+ [:, C, H*r, W*r] -> [:, C*r^2, H, W] + ''' + batch_size, in_channels, in_height, in_width = input.size() + out_channels = in_channels * (downscale_factor**2) + + out_height = in_height // downscale_factor + out_width = in_width // downscale_factor + + input_view = input.contiguous().view( + batch_size, in_channels, out_height, downscale_factor, out_width, downscale_factor + ) + + output = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() + return output.view(batch_size, out_channels, out_height, out_width)