Generalize AR-VAE loss as monotonic regularization + refactor AR-VAE to use generic impl.
nathanpainchaud committed Aug 24, 2023
1 parent b5e2403 commit cbaa6e0
Showing 3 changed files with 65 additions and 15 deletions.
vital/metrics/train/functional.py (31 additions, 0 deletions)
@@ -111,6 +111,37 @@ def kl_div_zmuv(mu: Tensor, logvar: Tensor) -> Tensor:
    return reduce(kl_div_by_samples, reduction="elementwise_mean")


def monotonic_regularization_loss(input: Tensor, target: Tensor, delta: float) -> Tensor:
    """Computes a regularization loss that enforces a monotonic relationship between the input and target.

    Notes:
        - This is a generalization of the attribute regularization loss proposed by the AR-VAE
          (link to the paper: https://arxiv.org/pdf/2004.05485.pdf)

    Args:
        input: Input values to regularize so that they have a monotonic relationship with the `target` values.
        target: Values used to determine the target monotonic ordering of the values.
        delta: Hyperparameter that decides the spread of the posterior distribution.

    Returns:
        (1,), Monotonic regularization term for aligning the input to the target.
    """
    # Compute input distance matrix
    broad_input = input.view(-1, 1).repeat(1, len(input))
    input_dist_mat = broad_input - broad_input.transpose(1, 0)

    # Compute target distance matrix
    broad_target = target.view(-1, 1).repeat(1, len(target))
    target_dist_mat = broad_target - broad_target.transpose(1, 0)

    # Compute regularization loss
    input_tanh = torch.tanh(input_dist_mat * delta)
    target_sign = torch.sign(target_dist_mat)
    loss = F.l1_loss(input_tanh, target_sign)

    return loss


def ntxent_loss(z_i: Tensor, z_j: Tensor, temperature: float = 1) -> Tensor:
"""Computes the NT-Xent loss for contrastive learning.
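
Usage sketch for the new `monotonic_regularization_loss` above (not part of the commit; the tensor values are made up for illustration). The loss shrinks when the pairwise ordering of `input` agrees with that of `target`, since `tanh(delta * dist)` is pulled toward the sign of the corresponding target distance.

import torch

from vital.metrics.train.functional import monotonic_regularization_loss

# Toy tensors: `codes` are ordered like `attrs`, so the pairwise orderings
# agree and the loss is comparatively small.
codes = torch.tensor([0.1, 0.5, 0.9, 1.3])
attrs = torch.tensor([1.0, 2.0, 3.0, 4.0])
aligned = monotonic_regularization_loss(codes, attrs, delta=1.0)

# Reversing the codes flips every pairwise ordering, which increases the loss.
misaligned = monotonic_regularization_loss(codes.flip(0), attrs, delta=1.0)
assert aligned < misaligned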
vital/metrics/train/metric.py (31 additions, 1 deletion)
@@ -1,6 +1,6 @@
from torch import Tensor, nn

-from vital.metrics.train.functional import differentiable_dice_score, ntxent_loss
+from vital.metrics.train.functional import differentiable_dice_score, monotonic_regularization_loss, ntxent_loss


class DifferentiableDiceCoefficient(nn.Module):
@@ -52,6 +52,36 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
        )


class MonotonicRegularizationLoss(nn.Module):
    """Computes a regularization loss that enforces a monotonic relationship between the input and target.

    Notes:
        - This is a generalization of the attribute regularization loss proposed by the AR-VAE
          (link to the paper: https://arxiv.org/pdf/2004.05485.pdf)
    """

    def __init__(self, delta: float):
        """Initializes class instance.

        Args:
            delta: Hyperparameter that decides the spread of the posterior distribution.
        """
        super().__init__()
        self.delta = delta

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Actual metric calculation.

        Args:
            input: Input values to regularize so that they have a monotonic relationship with the `target` values.
            target: Values used to determine the target monotonic ordering of the values.

        Returns:
            (1,), Calculated monotonic regularization loss.
        """
        return monotonic_regularization_loss(input, target, self.delta)


class NTXent(nn.Module):
"""Computes the NT-Xent loss for contrastive learning."""

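Usage sketch for the new module wrapper (not part of the commit; shapes and values are illustrative). The module simply stores `delta` and defers to the functional implementation above.

import torch

from vital.metrics.train.metric import MonotonicRegularizationLoss

attr_reg = MonotonicRegularizationLoss(delta=1.0)
latent_code = torch.randn(8)  # one latent dimension across a batch of 8
attribute = torch.rand(8)     # attribute values that dimension should track
loss = attr_reg(latent_code, attribute)  # scalar tensor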
vital/tasks/autoencoder.py (3 additions, 14 deletions)
@@ -8,7 +8,7 @@
from torchmetrics.utilities.data import to_onehot

from vital.metrics.train.functional import kl_div_zmuv
-from vital.metrics.train.metric import DifferentiableDiceCoefficient
+from vital.metrics.train.metric import DifferentiableDiceCoefficient, MonotonicRegularizationLoss
from vital.tasks.generic import SharedStepsTask
from vital.utils.decorators import auto_move_data

@@ -298,6 +298,7 @@ def __init__(self, attrs: Sequence[str], gamma: float = 10, delta: float = 1, **kwargs):
            **kwargs: Additional parameters to pass along to ``super().__init__()``.
        """
        super().__init__(**kwargs)
+        self._attr_reg_loss = MonotonicRegularizationLoss(delta)

    def _compute_latent_space_metrics(self, out: Dict[str, Tensor], batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Computes metrics on the input's encoding in the latent space.
@@ -323,19 +324,7 @@ def _compute_latent_space_metrics(self, out: Dict[str, Tensor], batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
            # Extract dimension to regularize and target for the current attribute
            latent_code = out[self.model.encoding_tag][:, attr_idx]
            attribute = batch[attr]

-            # Compute latent distance matrix
-            latent_code = latent_code.view(-1, 1).repeat(1, len(latent_code))
-            lc_dist_mat = latent_code - latent_code.transpose(1, 0)
-
-            # Compute attribute distance matrix
-            attribute = attribute.view(-1, 1).repeat(1, len(attribute))
-            attribute_dist_mat = attribute - attribute.transpose(1, 0)
-
-            # Compute regularization loss
-            lc_tanh = torch.tanh(lc_dist_mat * self.hparams.delta)
-            attribute_sign = torch.sign(attribute_dist_mat)
-            metrics[f"attr_reg/{attr}"] = F.l1_loss(lc_tanh, attribute_sign)
+            metrics[f"attr_reg/{attr}"] = self._attr_reg_loss(latent_code, attribute)

        return metrics
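
Since the commit is a pure refactor, the removed inline computation and the new shared loss should agree numerically. A minimal sanity-check sketch (not part of the commit; tensor shapes are illustrative):

import torch
import torch.nn.functional as F

from vital.metrics.train.functional import monotonic_regularization_loss

delta = 1.0
latent_code = torch.randn(8)
attribute = torch.rand(8)

# Old inline version, as removed above from `_compute_latent_space_metrics`
lc = latent_code.view(-1, 1).repeat(1, len(latent_code))
lc_dist_mat = lc - lc.transpose(1, 0)
attr = attribute.view(-1, 1).repeat(1, len(attribute))
attr_dist_mat = attr - attr.transpose(1, 0)
old_loss = F.l1_loss(torch.tanh(lc_dist_mat * delta), torch.sign(attr_dist_mat))

# New generic implementation
new_loss = monotonic_regularization_loss(latent_code, attribute, delta)
assert torch.allclose(old_loss, new_loss)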
