-
-
Notifications
You must be signed in to change notification settings - Fork 648
Add cosine similarity metric #3203
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
c7f0a1a
add cosine similarity
kzkadc 9c87580
update doc for cosine similarity metric
kzkadc 12a4d56
fix the position of the CosineSimilarity
kzkadc e09b86f
Update ignite/contrib/metrics/cosine_similarity.py
kzkadc 99852d5
autopep8 fix
kzkadc 54c16c1
move CosineSimilarity from contrib.metrics to metrics
kzkadc 9524d20
autopep8 fix
kzkadc b7dd7ea
fix typo
kzkadc ef33b49
Merge branch 'cosine_similarity' of github.com:kzkadc/ignite into cos…
kzkadc 52aeaee
fix typo
kzkadc 78b4029
Update ignite/metrics/cosine_similarity.py
kzkadc 6bfbe88
autopep8 fix
kzkadc a83e02c
Update ignite/metrics/cosine_similarity.py
kzkadc 9e4fd31
fix formatting
kzkadc 377baea
autopep8 fix
kzkadc 87d2854
Merge branch 'cosine_similarity' of github.com:kzkadc/ignite into cos…
kzkadc ad16ce7
fix formatting
kzkadc 6d3193e
autopep8 fix
kzkadc b2b8de4
add test for CosineSimilarity metric
kzkadc 8156dc6
autopep8 fix
kzkadc 0cd8440
Merge branch 'master' into cosine_similarity
vfdev-5 0778d07
Merge branch 'master' into cosine_similarity
vfdev-5 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from typing import Sequence, Union, Callable | ||
|
||
import torch | ||
|
||
from ignite.exceptions import NotComputableError | ||
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce | ||
|
||
__all__ = ["CosineSimilarity"] | ||
|
||
|
||
class CosineSimilarity(Metric): | ||
r"""Calculates the mean of the `cosine similarity <https://en.wikipedia.org/wiki/Cosine_similarity>`_. | ||
|
||
.. math:: \text{cosine\_similarity} = \frac{1}{N} \sum_{i=1}^N \frac{x_i \cdot y_i}{\max ( \| x_i \|_2 \| y_i \|_2 , \epsilon)} | ||
|
||
where :math:`y_{i}` is the prediction tensor and :math:`x_{i}` is ground true tensor. | ||
|
||
- ``update`` must receive output of the form ``(y_pred, y)``. | ||
|
||
Args: | ||
eps: a small value to avoid division by zero. Default: 1e-8 | ||
output_transform: a callable that is used to transform the | ||
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the | ||
form expected by the metric. This can be useful if, for example, you have a multi-output model and | ||
you want to compute the metric with respect to one of the outputs. | ||
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. | ||
device: specifies which device updates are accumulated on. Setting the | ||
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is | ||
non-blocking. By default, CPU. | ||
|
||
Examples: | ||
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. | ||
The output of the engine's ``process_function`` needs to be in the format of | ||
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added | ||
to the metric to transform the output into the form expected by the metric. | ||
|
||
``y_pred`` and ``y`` should have the same shape. | ||
|
||
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. | ||
|
||
.. include:: defaults.rst | ||
:start-after: :orphan: | ||
|
||
.. testcode:: | ||
|
||
metric = CosineSimilarity() | ||
metric.attach(default_evaluator, 'cosine_similarity') | ||
preds = torch.tensor([ | ||
[1, 2, 4, 1], | ||
[2, 3, 1, 5], | ||
[1, 3, 5, 1], | ||
[1, 5, 1 ,11] | ||
]).float() | ||
target = torch.tensor([ | ||
[1, 5, 1 ,11], | ||
[1, 3, 5, 1], | ||
[2, 3, 1, 5], | ||
[1, 2, 4, 1] | ||
]).float() | ||
state = default_evaluator.run([[preds, target]]) | ||
print(state.metrics['cosine_similarity']) | ||
|
||
.. testoutput:: | ||
|
||
0.5080491304397583 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
eps: float = 1e-8, | ||
output_transform: Callable = lambda x: x, | ||
device: Union[str, torch.device] = torch.device("cpu") | ||
): | ||
super().__init__(output_transform, device) | ||
|
||
self.eps = eps | ||
|
||
_state_dict_all_req_keys = ("_sum_of_cos_similarities", "_num_examples") | ||
|
||
@reinit__is_reduced | ||
def reset(self) -> None: | ||
self._sum_of_cos_similarities = torch.tensor(0.0, device=self._device) | ||
self._num_examples = 0 | ||
|
||
@reinit__is_reduced | ||
def update(self, output: Sequence[torch.Tensor]) -> None: | ||
y_pred = output[0].flatten(start_dim=1).detach() | ||
y = output[1].flatten(start_dim=1).detach() | ||
cos_similarities = torch.nn.functional.cosine_similarity(y_pred, y, dim=1, eps=self.eps) | ||
self._sum_of_cos_similarities += torch.sum(cos_similarities).to(self._device) | ||
self._num_examples += y.shape[0] | ||
|
||
@sync_all_reduce("_sum_of_cos_similarities", "_num_examples") | ||
def compute(self) -> Union[float, torch.Tensor]: | ||
if self._num_examples == 0: | ||
raise NotComputableError("CosineSimilarity must have at least one example before it can be computed.") | ||
return self._sum_of_cos_similarities.item() / self._num_examples |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.