Commit
added n_hutchinson_samples to base class
marcello-negri committed Nov 28, 2024
1 parent 5cf5365 commit 32aa5de
Showing 1 changed file with 69 additions and 21 deletions.
90 changes: 69 additions & 21 deletions enflows/transforms/injective/fixed_norm.py
@@ -71,8 +71,9 @@ def gradient_f_given_x(self, x, context=None):


class ManifoldFlow(Transform):
def __init__(self, logabs_jacobian):
def __init__(self, logabs_jacobian, n_hutchinson_samples=1):
super().__init__()
self.n_hutchinson_samples = n_hutchinson_samples
self.register_buffer("initialized", torch.tensor(False, dtype=torch.bool))
assert logabs_jacobian in ["cholesky", "analytical_sm", "analytical_lu", "analytical_gauss", "analytical_block", "fff", "rect"]
self.logabs_jacobian = logabs_jacobian
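
The new keyword is stored on the base class and defaults to 1, so existing call sites keep their single-probe behaviour; the subclasses further down in this diff simply forward it. A minimal sketch of what a call site could look like after this commit (only the constructor signatures are taken from the diff; the argument values and the import are assumptions):

from enflows.transforms.injective.fixed_norm import (
    SphereFlow, LearnableManifoldFlow, DeformedSphereFlow)

# Hypothetical call sites; n_hutchinson_samples is forwarded to ManifoldFlow.__init__
sphere   = SphereFlow(n=3, logabs_jacobian="fff", r=1., n_hutchinson_samples=4)
learned  = LearnableManifoldFlow(n=3, logabs_jacobian="fff", n_hutchinson_samples=4)
deformed = DeformedSphereFlow(logabs_jacobian="fff", manifold_type=1)  # default of 1 keeps the old single-probe behaviour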
@@ -102,7 +103,7 @@ def inverse(self, theta, context=None):
elif self.logabs_jacobian == "cholesky":
logabsdet = self.logabs_jacobian_cholesky(theta, theta_r, context=context)
elif self.logabs_jacobian == "fff":
logabsdet = self.logabs_jacobian_fff(x=outputs, context=context)
logabsdet = self.logabs_jacobian_fff_inverse(theta=theta, context=context)
elif self.logabs_jacobian == "rect":
logabsdet = self.logabs_jacobian_conjgrad(theta, context=context)
else:
@@ -126,9 +127,9 @@ def forward(self, inputs, context=None):
elif self.logabs_jacobian == "cholesky":
logabsdet = self.logabs_jacobian_cholesky(outputs[:,:-1], outputs, context=context)
elif self.logabs_jacobian == "fff":
logabsdet = self.logabs_jacobian_fff(x=inputs, context=context)
logabsdet = self.logabs_jacobian_fff_forward(x=inputs, context=context)
elif self.logabs_jacobian == "rect":
logabsdet = self.logabs_jacobian_conjgrad(outputs, context=context)
logabsdet = self.logabs_jacobian_conjgrad(outputs[:,:-1], context=context)
else:
raise ValueError(f"logabs_jacobian {self.logabs_jacobian} is not a valid choice")

@@ -342,14 +343,14 @@ def logabs_jacobian_cholesky(self, theta, theta_r, context=None):
def sample_v(self, x, hutchinson_samples):
batch_size, total_dim = x.shape[0], np.prod(x.shape[1:])
if hutchinson_samples > total_dim:
raise ValueError("Too many Hutchinson samples: got {hutchinson_samples}, expected <= {total_dim}")
raise ValueError("Too many Hutchinson samples: got {hutchinson_samples}, expected <= {total_dim}")

v = torch.randn(batch_size, total_dim, hutchinson_samples, device=x.device, dtype=x.dtype)
# v = torch.rand(batch_size, total_dim, hutchinson_samples, device=x.device, dtype=x.dtype) * torch.pi
q = torch.linalg.qr(v).Q.reshape(*x.shape, hutchinson_samples)
return q * np.sqrt(total_dim)
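
sample_v returns probes of shape (*x.shape, hutchinson_samples): orthonormalised columns rescaled by sqrt(total_dim), so that each probe v satisfies E[v v^T] = I. These are standard Hutchinson probes, i.e. E[v^T A v] = tr(A) for any square matrix A. A small self-contained check of that identity (the toy matrix and the dropped batch dimension below are assumptions, not code from this file):

import torch

torch.manual_seed(0)
d, k = 6, 4                                   # ambient dimension, number of probes
A = torch.randn(d, d)

v = torch.randn(d, k)
q = torch.linalg.qr(v).Q * d ** 0.5           # orthonormal columns rescaled so each probe has E[q q^T] = I

est = (q * (A @ q)).sum(dim=0).mean()         # mean_k  q_k^T A q_k
print(est.item(), torch.trace(A).item())      # unbiased estimate vs. exact trace; variance shrinks as k grows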

def logabs_jacobian_fff(self, x, hutchinson_samples=1, context=None):
def logabs_jacobian_fff_forward(self, x, hutchinson_samples=1, context=None):

def sum_except_batch(x):
"""Sum over all dimensions except the first.
@@ -364,7 +365,7 @@ def sum_except_batch(x):
x.requires_grad_()
theta = cartesian_to_spherical_torch(x)[..., :-1]

vs = self.sample_v(theta, hutchinson_samples)
vs = self.sample_v(theta, self.n_hutchinson_samples)

for k in range(hutchinson_samples):
v = vs[..., k]
@@ -389,7 +390,49 @@ def sum_except_batch(x):

return surrogate
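
The "fff" branch implements the free-form-flows style surrogate: each term pairs a forward-mode JVP through one map with a reverse-mode VJP through the other, so the gradient of the Hutchinson average tracks the gradient of log|det J| when the two maps are (near) inverses. A self-contained sketch of that pattern on a toy linear pair, verifying the gradient identity (the maps f and g, the choice K = d, and the exact-gradient check are assumptions, not code from this file):

import torch
from torch.autograd.forward_ad import dual_level, make_dual, unpack_dual

torch.manual_seed(0)
d = 4
K = d                                            # complete probe basis -> estimate is exact
W = torch.randn(d, d, dtype=torch.float64, requires_grad=True)

def f(x):                                        # toy "encoder" (stand-in for cartesian_to_spherical_torch)
    return x @ W.T

def g(z):                                        # its exact inverse, the "decoder"
    return z @ torch.inverse(W).T

x = torch.randn(1, d, dtype=torch.float64).requires_grad_()
z = f(x)

# orthonormal probes rescaled by sqrt(d), mirroring sample_v above
vs = torch.linalg.qr(torch.randn(d, K, dtype=torch.float64)).Q.T * d ** 0.5

surrogate = 0.
for k in range(K):
    v = vs[k:k + 1]                              # shape (1, d): one probe per pass
    with dual_level():                           # $ g'(z) v $ via forward-mode AD
        _, v1 = unpack_dual(g(make_dual(z.detach(), v)))
    # $ v^T f'(x) $ via backward-mode AD
    (v2,) = torch.autograd.grad(z, x, v, create_graph=True)
    surrogate = surrogate + (v2 * v1.detach()).sum() / K

# the *gradient* of the surrogate matches the gradient of log|det f'(x)| = log|det W|
grad_w = torch.autograd.grad(surrogate, W)[0]
print(torch.allclose(grad_w, torch.inverse(W).T, atol=1e-8))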

def logabs_jacobian_conjgrad(self, latent, num_hutchinson_samples=1, context=None):
def logabs_jacobian_fff_inverse(self, theta, hutchinson_samples=1, context=None):

def sum_except_batch(x):
"""Sum over all dimensions except the first.
:param x: Input tensor. Shape: (batch_size, ...)
:return: Sum over all dimensions except the first. Shape: (batch_size,)
"""
return torch.sum(x.reshape(x.shape[0], -1), dim=1)


surrogate = 0

theta.requires_grad_()
r = self.r_given_theta(theta, context=context)
theta_r = torch.cat([theta, r], dim=1)
# breakpoint()
# print("theta", theta.min().item(), theta.max().item())
x = spherical_to_cartesian_torch(theta_r)

vs = self.sample_v(x, self.n_hutchinson_samples)

for k in range(hutchinson_samples):
v = vs[..., k]

# $ g'(z) v $ via forward-mode AD
with dual_level():
dual_x = make_dual(x, v)

dual_theta = cartesian_to_spherical_torch(dual_x)[...,:-1]
x1, v1 = unpack_dual(dual_theta)
# breakpoint()

# $ v^T f'(x) $ via backward-mode AD
(v2,) = torch.autograd.grad(x, theta, v, create_graph=True)
# $ v^T f'(x) stop_grad(g'(z)) v $
surrogate += sum_except_batch(v2 * v1.detach()) / hutchinson_samples
# surrogate += sum_except_batch(v2 * v1) / hutchinson_samples
# print(surrogate)

return surrogate
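
In both directions the returned quantity is the probe-averaged surrogate v^T f'(x) stop_grad(g'(z)) v, as in the free-form-flows construction these two methods appear to follow: its gradient with respect to the model parameters tracks the gradient of log|det J| when the two maps are near inverses, but its value is not the exact log-determinant, so the cholesky and analytical branches remain the ones to use when the density itself has to be evaluated.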


def logabs_jacobian_conjgrad(self, latent, num_hutchinson_samples=3, context=None):

sample_shape = (*latent.shape, num_hutchinson_samples)
hutchinson_distribution = "normal"
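
The "rect" branch (logabs_jacobian_conjgrad) presumably follows the rectangular-flow recipe of combining Hutchinson probes with matrix-free conjugate-gradient solves against J^T J. A generic CG helper of the kind such an estimator typically relies on; this is a sketch under that assumption, not code from this file:

import torch

def conjugate_gradient(matvec, b, n_iter=50, tol=1e-10):
    """Solve A x = b for symmetric positive-definite A, given only matvec(v) = A v."""
    x = torch.zeros_like(b)
    r = b.clone()                                # residual b - A x, with x = 0
    p = r.clone()
    rs = (r * r).sum()
    for _ in range(n_iter):
        Ap = matvec(p)
        alpha = rs / (p * Ap).sum()
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = (r * r).sum()
        if rs_new.sqrt() < tol:                  # stop once the residual norm is small
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

# Tiny check against a direct solve (the 5x5 SPD system below is an assumption)
torch.manual_seed(0)
M = torch.randn(5, 5, dtype=torch.float64)
A = M @ M.T + 5 * torch.eye(5, dtype=torch.float64)
b = torch.randn(5, dtype=torch.float64)
x = conjugate_gradient(lambda v: A @ v, b)
print(torch.allclose(x, torch.linalg.solve(A, b), atol=1e-6))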
@@ -605,12 +648,16 @@ def _initialize_jacobian(self, inputs):

from siren_pytorch import SirenNet
class LearnableManifoldFlow(ManifoldFlow):
def __init__(self, n, logabs_jacobian, max_radius=2.):
super().__init__(logabs_jacobian=logabs_jacobian)
def __init__(self, n, logabs_jacobian, max_radius=2., n_hutchinson_samples=1):
super().__init__(logabs_jacobian=logabs_jacobian, n_hutchinson_samples=n_hutchinson_samples)

# self.network = SimpleNN(n, hidden_size=256, output_size=1, max_radius=max_radius)
self.network = SirenNet(dim_in=n, dim_hidden=256, dim_out = 1, num_layers = 5,
final_activation = torch.nn.Sigmoid(), w0_initial = 30.)
# self.network = SirenNet(dim_in=n, dim_hidden=256, dim_out = 1, num_layers = 5,
# final_activation = torch.nn.Sigmoid(), w0_initial = 30.)
self.network = nn.Sequential(
nn.Linear(n, 1),
nn.Sigmoid()
)

def r_given_theta(self, theta, context=None):
r = self.network(theta)
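
The learnable radius is now parameterised by a single linear layer followed by a sigmoid instead of the SIREN network (left commented out above), so the predicted radius is squashed into (0, 1). A minimal sketch of that parameterisation in isolation (the value of n and the angle range below are assumptions):

import torch
import torch.nn as nn

torch.manual_seed(0)
n = 3                                            # number of angular inputs (assumption)
radius_net = nn.Sequential(nn.Linear(n, 1), nn.Sigmoid())

theta = torch.rand(8, n) * torch.pi              # a batch of angles
r = radius_net(theta)
print(r.shape, float(r.min()), float(r.max()))   # radii squashed into (0, 1)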
@@ -629,8 +676,8 @@ def gradient_r_given_theta(self, theta, context=None):


class SphereFlow(ManifoldFlow):
def __init__(self, n, logabs_jacobian, r=1.):
super().__init__(logabs_jacobian=logabs_jacobian)
def __init__(self, n, logabs_jacobian, r=1., n_hutchinson_samples=1):
super().__init__(logabs_jacobian=logabs_jacobian, n_hutchinson_samples=n_hutchinson_samples)
self.radius = r
# self.network = SimpleNN(n, hidden_size=50, output_size=1, max_radius=max_radius)

@@ -652,8 +699,8 @@ def gradient_r_given_theta(self, theta, context=None):
return grad_r_theta_aug.unsqueeze(-1)

class DeformedSphereFlow(ManifoldFlow):
def __init__(self, logabs_jacobian, r=1., manifold_type=1):
super().__init__(logabs_jacobian=logabs_jacobian)
def __init__(self, logabs_jacobian, r=1., manifold_type=1, n_hutchinson_samples=1):
super().__init__(logabs_jacobian=logabs_jacobian, n_hutchinson_samples=n_hutchinson_samples)
self.radius = r
self.manifold_type = manifold_type
# self.network = SimpleNN(n, hidden_size=50, output_size=1, max_radius=max_radius)
@@ -724,10 +771,12 @@ def gradient_r_given_theta(self, theta, context=None):
return grad_r_theta_aug.unsqueeze(-1)

class LpManifoldFlow(ManifoldFlow):
def __init__(self, norm, p, logabs_jacobian):
def __init__(self, norm, p, logabs_jacobian, given_radius=1.):
super().__init__(logabs_jacobian=logabs_jacobian)
self.norm = norm
self.p = p
assert given_radius > 0, "radius must be positive"
self.given_radius = given_radius
# self.register_buffer("initialized", torch.tensor(False, dtype=torch.bool))

def r_given_theta(self, theta, context=None):
@@ -737,7 +786,7 @@ def r_given_theta(self, theta, context=None):
r_theta = torch.cat((theta, torch.ones_like(theta[:,:1])), dim=1)
cartesian = spherical_to_cartesian_torch(r_theta)
p_norm = torch.linalg.vector_norm(cartesian, ord=self.p, dim=1)
r = self.norm / (p_norm + eps)
r = self.norm / (p_norm + eps) * self.given_radius

return r.unsqueeze(-1)
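
r_given_theta here exploits the homogeneity of the p-norm: if a direction vector x has ||x||_p = c, then (target / c) * x has p-norm exactly target, so the returned radius places the spherical point on the desired Lp level set (given_radius adds an extra overall scale). A small check of that scaling identity (the random directions below stand in for spherical_to_cartesian_torch, which is not shown in this diff):

import torch

torch.manual_seed(0)
p, target = 1.5, 2.0                             # p-norm order and target norm (assumptions)
x = torch.randn(8, 3)                            # stand-in for unit-radius spherical points
p_norm = torch.linalg.vector_norm(x, ord=p, dim=1, keepdim=True)

r = target / (p_norm + 1e-8)                     # radius returned per point
scaled = r * x
print(torch.linalg.vector_norm(scaled, ord=p, dim=1))   # ~= target for every row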

@@ -949,8 +998,6 @@ def forward(self, inputs, context):
output = torch.cat((clamped_thetas, clamped_last_theta), dim = -1)
logabsdet = output.new_zeros(inputs.shape[:-1])

breakpoint()

return output, logabsdet

def compute_mask(self, arr, vmin, vmax, right_included=False):
@@ -1002,7 +1049,8 @@ def compute_mask(self, arr, vmin, vmax, right_included=False):


def inverse(self, inputs, context):
return inputs, torch.zeros_like(inputs[...,0])
outputs, _ = self.forward(inputs, context)
return outputs, torch.zeros_like(outputs[...,0])
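
The inverse now routes through forward, so the same clamping is applied in both directions while still reporting a zero log-det. A minimal sketch of clamping spherical angles into open intervals; the specific bounds (eps to pi - eps for the leading angles, eps to 2*pi - eps for the last one) are the standard spherical-coordinate ranges and are assumptions about what forward does:

import torch

eps = 1e-5
theta = torch.rand(4, 3) * 4 - 0.5                                    # some angles outside the valid range

clamped_thetas     = theta[..., :-1].clamp(eps, torch.pi - eps)       # polar-type angles
clamped_last_theta = theta[..., -1:].clamp(eps, 2 * torch.pi - eps)   # azimuthal angle
output = torch.cat((clamped_thetas, clamped_last_theta), dim=-1)

logabsdet = output.new_zeros(theta.shape[:-1])                        # clamping reports zero log-det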

class ClampedThetaPositive(Transform):
def __init__(self, eps):
