Skip to content

Commit

Permalink
rebuttal code
Browse files Browse the repository at this point in the history
  • Loading branch information
marcello-negri committed Aug 5, 2024
1 parent 433b099 commit e3f2733
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 10 deletions.
14 changes: 11 additions & 3 deletions enflows/distributions/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,18 @@ def _log_prob(self, inputs, context):
return log_prob

def _sample(self, num_samples, context):
raise NotImplementedError()
# Compute parameters.
means = self.mean_
log_stds = self.log_std_
stds = torch.exp(log_stds)
means = torchutils.repeat_rows(means, num_samples)
stds = torchutils.repeat_rows(stds, num_samples)

# Generate samples.
noise = torch.randn(num_samples, *self._shape, device=means.device)
samples = means + stds * noise
return torchutils.split_leading_dim(samples, [num_samples])

def _mean(self, context):
return self.mean

import torch.distributions as D

Expand Down
1 change: 1 addition & 0 deletions enflows/transforms/injective/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ClampedTheta,
ClampedThetaPositive,
LearnableManifoldFlow,
LearnableParamHyperFlow,
LpManifoldFlow,
CondLpManifoldFlow,
PositiveL1ManifoldFlow,
Expand Down
51 changes: 49 additions & 2 deletions enflows/transforms/injective/fixed_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,58 @@

from enflows.transforms import Transform, ConditionalTransform, Sigmoid, ScalarScale, CompositeTransform, ScalarShift
from enflows.transforms.injective.utils import sph_to_cart_jacobian_sympy, spherical_to_cartesian_torch, cartesian_to_spherical_torch, logabsdet_sph_to_car
from enflows.transforms.injective.utils import check_tensor, sherman_morrison_inverse, SimpleNN, jacobian_sph_to_car, solve_triangular_system
from enflows.transforms.injective.utils import check_tensor, sherman_morrison_inverse, SimpleNN, SimpleNN_uncnstr, jacobian_sph_to_car, solve_triangular_system

import time
from datetime import timedelta
from torch.utils.benchmark import Timer


class ParamHyperFlow(Transform):
def __init__(self):
super().__init__()

def f_given_x(self, x, context=None):
raise NotImplementedError()

def gradient_f_given_x(self, x, context=None):
raise NotImplementedError()

def inverse(self, x, context=None):
f = self.f_given_x(x, context=context)
x_f = torch.cat([x, f], dim=1)

grad_f = self.gradient_f_given_x(x, context=context)
logabsdet = torch.sqrt((1 + grad_f.square().sum(-1)))

return x_f, logabsdet

def forward(self, x_f, context=None):
grad_f = self.gradient_f_given_x(x_f[..., :-1], context=context)
logabsdet = torch.sqrt((1 + grad_f.square().sum(-1)))

return x_f[..., :-1], -logabsdet


class LearnableParamHyperFlow(ParamHyperFlow):
def __init__(self, n):
super().__init__()

self.network = SimpleNN_uncnstr(n, hidden_size=128, output_size=1)

def f_given_x(self, x, context=None):
f = self.network(x)

return f

def gradient_f_given_x(self, x, context=None):
x.requires_grad_(True)
f = self.f_given_x(x, context=context)
grad_f_x = torch.autograd.grad(f,x, grad_outputs=torch.ones_like(f))[0]

return grad_f_x


class ManifoldFlow(Transform):
def __init__(self, logabs_jacobian):
super().__init__()
Expand Down Expand Up @@ -175,14 +221,15 @@ class LearnableManifoldFlow(ManifoldFlow):
def __init__(self, n, logabs_jacobian, max_radius=2.):
super().__init__(logabs_jacobian=logabs_jacobian)

self.network = SimpleNN(n, hidden_size=500, output_size=1, max_radius=max_radius)
self.network = SimpleNN(n, hidden_size=64, output_size=1, max_radius=max_radius)

def r_given_theta(self, theta, context=None):
r = self.network(theta)

return r

def gradient_r_given_theta(self, theta, context=None):
theta.requires_grad_(True)
r = self.r_given_theta(theta, context=context)
grad_r_theta = torch.autograd.grad(r,theta, grad_outputs=torch.ones_like(r))[0]
grad_r_theta_aug = torch.cat([- grad_r_theta, torch.ones_like(grad_r_theta[:, :1])], dim=1)
Expand Down
38 changes: 33 additions & 5 deletions enflows/transforms/injective/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@ class SimpleNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size, max_radius=1.):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.relu = nn.Tanh()
self.fc12 = nn.Linear(hidden_size, hidden_size)
self.relu = nn.ReLU()
self.relu = nn.Tanh()
self.fc2 = nn.Linear(hidden_size, output_size)
self.sigmoid = nn.Sigmoid()
self.max_radius = max_radius

self.mask = torch.ones(input_size, device='cuda')
self.mask[-1] = 0.
# self.mask = torch.ones(input_size, device='cuda')
# self.mask[-1] = 0.

def forward(self, x):
x = self.mask * torch.cos(2*x) + (1 - self.mask) * torch.cos(4*x)
# x = self.mask * torch.cos(2*x) + (1 - self.mask) * torch.cos(4*x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc12(x)
Expand All @@ -30,6 +30,34 @@ def forward(self, x):
x = self.sigmoid(x) * self.max_radius
return x

class SimpleNN_uncnstr(nn.Module):
def __init__(self, input_size, hidden_size, output_size, max_radius=1.):
super(SimpleNN_uncnstr, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.act1 = nn.Tanh()
self.fc12 = nn.Linear(hidden_size, hidden_size)
self.act2 = nn.Tanh()
self.fc23 = nn.Linear(hidden_size, hidden_size)
self.act3 = nn.Tanh()
self.fc3 = nn.Linear(hidden_size, output_size)
self.act4 = nn.Tanh()

# self.mask = torch.ones(input_size, device='cuda')
# self.mask[-1] = 0.

def forward(self, x):
# x = torch.cos(8 * x) + torch.sin(4 * x) + torch.cos(2 * x) + torch.cos(0.5 * x) + torch.cos(0.2 * x)
x = self.fc1(x)
x = self.act1(x)
x = self.fc12(x)
x = self.act2(x)
x = self.fc23(x)
x = self.act3(x)
x = self.fc3(x)
# x = self.act4(x) * torch.pi * 0.5
# x = torch.cos(x)
return x


def sph_to_cart_sympy(spherical):
# sympy implementation of change of variables from spherical to cartesian
Expand Down

0 comments on commit e3f2733

Please sign in to comment.