Skip to content

Add cosine similarity metric #3203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/contrib/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Contrib module metrics

AveragePrecision
CohenKappa
CosineSimilarity
GpuInfo
PrecisionRecallCurve
ROC_AUC
Expand Down
1 change: 1 addition & 0 deletions ignite/contrib/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from ignite.contrib.metrics.gpu_info import GpuInfo
from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve
from ignite.contrib.metrics.roc_auc import ROC_AUC, RocCurve
from ignite.contrib.metrics.cosine_similarity import CosineSimilarity
97 changes: 97 additions & 0 deletions ignite/contrib/metrics/cosine_similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from typing import Sequence, Union, Callable

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["CosineSimilarity"]


class CosineSimilarity(Metric):
r"""Calculates the mean of the `cosine similarity <https://en.wikipedia.org/wiki/Cosine_similarity>`_.

.. math:: \text{cosine\_similarity} = \frac{1}{N} \sum_{i=1}^N \frac{x_i \cdot y_i}{\max ( \| x_i \|_2 \| y_i \|_2 , \epsilon)}

where :math:`y_{i}` is the prediction tensor and :math:`x_{i}` is ground true tensor.

- ``update`` must receive output of the form ``(y_pred, y)``.

Args:
eps: a small value to avoid division by zero. Default: 1e-8
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
to the metric to transform the output into the form expected by the metric.

``y_pred`` and ``y`` should have the same shape.

For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

.. include:: defaults.rst
:start-after: :orphan:

.. testcode::

metric = CosineSimilarity()
metric.attach(default_evaluator, 'cosine_similarity')
preds = torch.tensor([
[1, 2, 4, 1],
[2, 3, 1, 5],
[1, 3, 5, 1],
[1, 5, 1 ,11]
]).float()
target = torch.tensor([
[1, 5, 1 ,11],
[1, 3, 5, 1],
[2, 3, 1, 5],
[1, 2, 4, 1]
]).float()
state = default_evaluator.run([[preds, target]])
print(state.metrics['cosine_similarity'])

.. testoutput::

0.5080491304397583
"""

def __init__(
self,
eps: float = 1e-8,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu")
):
super().__init__(output_transform, device)

self.eps = eps

_state_dict_all_req_keys = ("_sum_of_cos_similarities", "_num_examples")

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_cos_similarities = torch.tensor(0.0, device=self._device)
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred = output[0].flatten(start_dim=1).detach()
y = output[1].flatten(start_dim=1).detach()
cos_similarities = torch.nn.functional.cosine_similarity(y_pred, y, dim=1, eps=self.eps)
self._sum_of_cos_similarities += torch.sum(cos_similarities).to(self._device)
self._num_examples += y.shape[0]

@sync_all_reduce("_sum_of_cos_similarities", "_num_examples")
def compute(self) -> Union[float, torch.Tensor]:
if self._num_examples == 0:
raise NotComputableError("CosineSimilarity must have at least one example before it can be computed.")
return self._sum_of_cos_similarities.item() / self._num_examples