Add wrapper around torch's native cdist to handle non-batched collections of tensors (#170)
nathanpainchaud authored Sep 8, 2023
1 parent f293a3c commit 4d33c78
Showing 1 changed file with 28 additions and 0 deletions: vital/metrics/train/functional.py
@@ -178,3 +178,31 @@ def ntxent_loss(z_i: Tensor, z_j: Tensor, temperature: float = 1) -> Tensor:
loss = torch.sum(all_losses) / (2 * batch_size)

return loss


def cdist(x1: Tensor, x2: Tensor, **kwargs) -> Tensor:
    """Wrapper around torch's native `cdist` function to use it on non-batched inputs.

    Args:
        x1: ([B,]P,M), Input collection of row tensors.
        x2: ([B,]R,M), Input collection of row tensors.
        **kwargs: Additional parameters to pass along to torch's native `cdist`.

    Returns:
        ([B,]P,R), Pairwise p-norm distances between the row tensors.
    """
    if x1.ndim != x2.ndim:
        raise ValueError(
            f"Wrapper around torch's `cdist` only supports input tensors that are either both batched or both "
            f"non-batched. However, the current shapes do not match: {x1.shape=} and {x2.shape=}."
        )

    if no_batch := x1.ndim < 3:
        x1 = x1[None, ...]
        x2 = x2[None, ...]

    dist = torch.cdist(x1, x2, **kwargs)
    if no_batch:
        dist = dist[0]

    return dist
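
For illustration, a minimal usage sketch of the new wrapper (not part of the commit; the example shapes and the p=2 keyword, which is simply forwarded to torch's native `cdist`, are chosen arbitrarily, and the import path assumes the package layout implied by the changed file):

import torch

from vital.metrics.train.functional import cdist

# Non-batched inputs: (P, M) and (R, M) row collections -> (P, R) distance matrix
x1, x2 = torch.randn(5, 3), torch.randn(7, 3)
print(cdist(x1, x2, p=2).shape)  # torch.Size([5, 7])

# Batched inputs: (B, P, M) and (B, R, M) -> (B, P, R), same behavior as torch.cdist
x1, x2 = torch.randn(4, 5, 3), torch.randn(4, 7, 3)
print(cdist(x1, x2, p=2).shape)  # torch.Size([4, 5, 7])

# Mixing a batched and a non-batched input raises the wrapper's ValueError
# cdist(torch.randn(5, 3), torch.randn(4, 7, 3))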
