Skip to content

Commit

Permalink
Add functional and class API for tensor-based (i.e. differentiable) R…
Browse files Browse the repository at this point in the history
…BF kernel (#171)
  • Loading branch information
nathanpainchaud authored Sep 19, 2023
1 parent 4d33c78 commit 49d8266
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
34 changes: 34 additions & 0 deletions vital/metrics/train/functional.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Sequence

import torch
from torch import Tensor
from torch.nn import functional as F
Expand Down Expand Up @@ -206,3 +208,35 @@ def cdist(x1: Tensor, x2: Tensor, **kwargs) -> Tensor:
dist = dist[0]

return dist


def rbf_kernel(x1: Tensor, x2: Tensor = None, length_scale: float | Sequence[float] | Tensor = 1) -> Tensor:
"""Computes the Radial Basis Function kernel (aka Gaussian kernel).
Args:
x1: (M, E), Left argument of the returned kernel k(x1,x2).
x2: (N, E), Right argument of the returned kernel k(x1,x2). If None, uses `x2=x1`.
length_scale: (1,) or (E,), The length-scale of the kernel. If a float, an isotropic kernel is used.
If a Sequence or Tensor, an anisotropic kernel is used to define the length-scale of each feature dimension.
If None, use Silverman's rule-of-thumb to compute an estimate of the optimal (anisotropic) length-scale.
Returns:
(M, N), The kernel k(x1,x2).
"""
if isinstance(length_scale, Sequence):
# Make sure that if the length-scale is specified for each feature dimension, it is in tensor format
length_scale = torch.tensor(length_scale, device=x1.device)

if x2 is None:
x2 = x1.clone()

# Use the trick of distributing `length_scale` on the inputs to minimize computations
x1, x2 = x1 / length_scale, x2 / length_scale
# Reshape inputs so that the squared Euclidean distance can be easily computed using simple broadcasting
x1 = x1[:, None, :] # (N, D) -> (N, 1, D)
x2 = x2[None, :, :] # (N, D) -> (1, N, D)

# Compute the RBF kernel
sq_dist = torch.sum((x1 - x2) ** 2, -1) # sq_dist = |x1 - x2|^2 / l^2
k = torch.exp(-0.5 * sq_dist) # exp( - sq_dist / 2 )
return k
37 changes: 36 additions & 1 deletion vital/metrics/train/metric.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from typing import Sequence

from torch import Tensor, nn

from vital.metrics.train.functional import differentiable_dice_score, monotonic_regularization_loss, ntxent_loss
from vital.metrics.train.functional import (
differentiable_dice_score,
monotonic_regularization_loss,
ntxent_loss,
rbf_kernel,
)


class DifferentiableDiceCoefficient(nn.Module):
Expand Down Expand Up @@ -105,3 +112,31 @@ def forward(self, z_i: Tensor, z_j: Tensor):
(1,), Calculated NT-Xent loss.
"""
return ntxent_loss(z_i, z_j, temperature=self.temperature)


class RBFKernel(nn.Module):
"""Computes the Radial Basis Function kernel (aka Gaussian kernel)."""

def __init__(self, length_scale: float | Sequence[float] | Tensor = 1):
"""Initializes class instance.
Args:
length_scale: length_scale: (1,) or (E,), The length-scale of the kernel. If a float, an isotropic kernel is
used. If a Sequence or Tensor, an anisotropic kernel is used to define the length-scale of each feature
dimension.
"""
super().__init__()
self.length_scale = length_scale

def forward(self, x1: Tensor, x2: Tensor = None) -> Tensor:
"""Actual kernel calculation.
Args:
x1: (M, E), Left argument of the returned kernel k(x1,x2).
x2: (N, E), Right argument of the returned kernel k(x1,x2). If None, uses `x2=x1`,
which ends up evaluating k(x1,x1).
Returns:
(N, M), The kernel k(x1,x2).
"""
return rbf_kernel(x1, x2=x2, length_scale=self.length_scale)

0 comments on commit 49d8266

Please sign in to comment.