Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Triplet Margin Loss] Issue 1118 #1120

Open
wants to merge 57 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
e9811d0
Triplet Marging Loss added
cvnad1 Oct 26, 2024
c19ceb3
Triplet Marging Loss added changes
cvnad1 Oct 26, 2024
c8f6937
Added Triplet Margin Loss tests
cvnad1 Oct 26, 2024
34a8610
added spaces
cvnad1 Oct 26, 2024
1778d05
indentation changes
cvnad1 Oct 26, 2024
5ab74eb
Reduced Indentation and line lengths
cvnad1 Oct 26, 2024
510d1c6
corrected apostrophe marks
cvnad1 Oct 26, 2024
f4f93c3
corrected test file conventions
cvnad1 Oct 26, 2024
52498f5
Triplet loss changes
cvnad1 Nov 2, 2024
ad40005
tests
cvnad1 Nov 2, 2024
ce16c2a
Added jit vmap compatibility
Saanidhyavats Nov 2, 2024
85efeb8
Standardized the code
Saanidhyavats Nov 2, 2024
a15a2c7
Minor correction
Saanidhyavats Nov 2, 2024
fc3c32a
tests
cvnad1 Nov 2, 2024
b9f35a5
modified tests jit vmap
cvnad1 Nov 2, 2024
1d18f1c
whitespace cleared
cvnad1 Nov 2, 2024
b235710
changed dim to shape
cvnad1 Nov 2, 2024
f36416e
changed dim to shape
cvnad1 Nov 2, 2024
490d941
changed vmap and jit
cvnad1 Nov 2, 2024
7d1b43b
removed unused packages
cvnad1 Nov 2, 2024
6a3a9a1
tests modified args
cvnad1 Nov 3, 2024
80c95fb
testing loss function
cvnad1 Nov 3, 2024
7f877ef
testing loss function
cvnad1 Nov 3, 2024
77f6b33
testing loss function
cvnad1 Nov 3, 2024
6b92b6c
Made changes to pass all the tests
Saanidhyavats Nov 3, 2024
7df60cb
Minor change
Saanidhyavats Nov 3, 2024
31a0ece
trailing space
Saanidhyavats Nov 3, 2024
894dfa2
refactored the code
cvnad1 Nov 4, 2024
6bf8aa5
Added triplet_margin_loss in api docs
Saanidhyavats Nov 4, 2024
63a91f6
refactored the code
cvnad1 Nov 4, 2024
9df7799
Merge remote-tracking branch 'origin/issue-1118' into issue-1118
Saanidhyavats Nov 4, 2024
e586719
refactored the code
cvnad1 Nov 4, 2024
a7dd576
Merge branch 'issue-1118' of https://github.com/cvnad1/optax into iss…
cvnad1 Nov 4, 2024
cc8377d
Minor correction
Saanidhyavats Nov 4, 2024
afec78f
Minor correction
Saanidhyavats Nov 4, 2024
4d8f50f
Minor correction
Saanidhyavats Nov 4, 2024
3125b82
Trailing Spaces
Saanidhyavats Nov 4, 2024
8706919
Trailing Spaces
Saanidhyavats Nov 4, 2024
954d27e
Trailing Spaces at line 108
Saanidhyavats Nov 4, 2024
fa6e7a7
Removed jit
Saanidhyavats Nov 4, 2024
35b3e50
Documentation
Saanidhyavats Nov 7, 2024
a56c0db
Testing
Saanidhyavats Nov 7, 2024
b46f945
Changed the function code and added test accordingly.
Saanidhyavats Nov 10, 2024
12a1efa
Merge branch 'main' into issue-1118
Saanidhyavats Nov 10, 2024
d91ecc7
pylint correction
Saanidhyavats Nov 10, 2024
61fa085
Merge remote-tracking branch 'origin/issue-1118' into issue-1118
Saanidhyavats Nov 10, 2024
ccd3ce5
pylint correction __init__.py
Saanidhyavats Nov 10, 2024
2f9a138
Added assertion for float
Saanidhyavats Nov 10, 2024
b4ed379
place
cvnad1 Nov 10, 2024
75183f1
Merge branch 'issue-1118' of https://github.com/cvnad1/optax into iss…
cvnad1 Nov 10, 2024
6a5e2f2
function arguments
Saanidhyavats Nov 11, 2024
8d78cf1
minor correction
Saanidhyavats Nov 11, 2024
9a641cc
Merge branch 'google-deepmind:main' into issue-1118
Saanidhyavats Nov 12, 2024
ec34333
Merge branch 'issue-1118' of https://github.com/cvnad1/optax into iss…
cvnad1 Nov 13, 2024
af1bdb2
Added Docstring
cvnad1 Nov 13, 2024
4b4a36e
parameterized test
Saanidhyavats Nov 13, 2024
4ddb7b8
Merge remote-tracking branch 'origin/issue-1118' into issue-1118
Saanidhyavats Nov 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/api/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Losses
softmax_cross_entropy_with_integer_labels
sparsemax_loss
squared_error

triplet_margin_loss

Convex Kullback Leibler divergence
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -116,3 +116,7 @@ Sparsemax
~~~~~~~~~
.. autofunction:: sparsemax_loss
.. autofunction:: multiclass_sparsemax_loss

Triplet margin loss
~~~~~~~~~~~~~~~~~~~
.. autofunction:: triplet_margin_loss
1 change: 1 addition & 0 deletions optax/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@
from optax.losses._regression import log_cosh
from optax.losses._regression import squared_error
from optax.losses._self_supervised import ntxent
from optax.losses._self_supervised import triplet_loss
from optax.losses._smoothing import smooth_labels
68 changes: 67 additions & 1 deletion optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,72 @@ def ntxent(
denom = jnp.sum(jnp.exp(xcs_shift_diffs), axis=1, keepdims=True)
denom += numer_exp
log_softm = numer - jnp.log(denom)
loss = -jnp.where(matches == 1, log_softm, 0.0).sum() / matches.sum()
loss = -jnp.where(matches == 1, log_softm, 0.0).sum()/matches.sum()

return loss


def triplet_loss(
anchors: chex.Array,
positives: chex.Array,
negatives: chex.Array,
axis: chex.Numeric = -1,
p: chex.Numeric = 2,
margin: chex.Numeric = 1.0,
eps: chex.Numeric = 1e-6,
reduction: str = 'none',
) -> chex.Array:
"""Computes the triplet loss for a batch of embeddings.

Examples:
>>> import jax.numpy as jnp
>>> import optax
>>> import chex
>>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
>>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
>>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
>>> output =optax.triplet_loss(anchors, positives, negatives, margin=1.0)
>>> print(output)
>>> Array([0.14142442, 0.14142442], dtype=float32)

Args:
anchors: An array of anchor embeddings, with shape [batch, feature_dim].
positives: An array of positive embeddings
(similar to anchors), with shape [batch, feature_dim].
negatives: An array of negative embeddings
(dissimilar to anchors), with shape [batch, feature_dim].
axis: The axis along which to compute the distances
(default is -1).
p: The norm degree for distance calculation
(default is 2 for Euclidean distance).
margin: The minimum margin by which the positive distance
should be smaller than the negative distance.
eps: A small epsilon value to ensure numerical stability
in the distance calculation.
reduction: Specifies the reduction to apply to the
output: 'none' | 'mean' | 'sum'.

Returns:
The computed triplet loss as an array or scalar
depending on the reduction parameter.
If reduction is 'mean' or 'sum', returns a scalar.

References:
Learning shallow convolutional feature descriptors with triplet losses
by V. Balntas, E. Riba et al.
<https://bmva-archive.org.uk/bmvc/2016/papers/paper119/abstract119.pdf>
"""
chex.assert_type([anchors], float)
chex.assert_type([positives], float)
chex.assert_type([negatives], float)
positive_distance = jnp.sqrt(jnp.power(anchors - positives, p).sum(axis) + eps
)
negative_distance = jnp.sqrt(jnp.power(anchors - negatives, p).sum(axis) + eps
)
loss = jnp.maximum(positive_distance - negative_distance + margin, 0)
if reduction == 'mean':
return loss.mean()
elif reduction == 'sum':
return loss.sum()
else:
return loss
67 changes: 64 additions & 3 deletions optax/losses/_self_supervised_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
# ==============================================================================
"""Tests for self-supervised losses in `optax.losses._self_supervised.py`."""

from absl.testing import absltest
from absl.testing import absltest, parameterized
import chex
import jax
import jax.numpy as jnp
import numpy as np
from optax.losses import _self_supervised

from optax.losses import _self_supervised

class NtxentTest(chex.TestCase):

Expand All @@ -46,7 +47,6 @@ def setUp(self):

@chex.all_variants
def test_batched(self):
"""Tests for a full batch."""
np.testing.assert_allclose(
self.variant(_self_supervised.ntxent)(self.ys, self.ts_1),
self.exp_1,
Expand All @@ -65,6 +65,67 @@ def test_batched(self):
atol=1e-4,
)

class TripletMarginLossTest(chex.TestCase, parameterized.TestCase):

def setUp(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid using numerical values as expected returns.
They may fail depending on the backend for example.
You may consider simple test cases with a "handmade" function (see e.g. the lbfgs tests). You can check for specific inputs (like zeros or ones).

You may also add a test for some specific behaviors (like using swap here).

Also you should test this function under jit/vmap etc... (see the chex.all_variant utility in some other tests).

super().setUp()
self.a1 = jnp.ones((2, 2))
self.p1 = jnp.zeros((2, 2))
self.n1 = jnp.ones((2, 2)) * 2
self.a2 = jnp.zeros((2, 2))
self.p2 = jnp.ones((2, 2))
self.n2 = jnp.ones((2, 2)) * 2

@chex.all_variants
@parameterized.parameters([
{
'anchor': jnp.ones((2, 2)),
'positive': jnp.zeros((2, 2)),
'negative': jnp.ones((2, 2)) * 2,
'margin': 1.0,
},
{
'anchor': jnp.zeros((2, 2)),
'positive': jnp.ones((2, 2)),
'negative': jnp.ones((2, 2)) * 2,
'margin': 1.0,
}
])
def test_batched(self, anchor, positive, negative, margin):
def testing_triplet_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6):
ap_distance = jnp.sqrt(jnp.sum(jnp.power(a - p, p_norm)) + eps)
an_distance = jnp.sqrt(jnp.sum(jnp.power(a - n, p_norm)) + eps)
return jnp.maximum(ap_distance - an_distance + margin, 0)

handmade_result = testing_triplet_loss(
a=anchor, p=positive, n=negative, margin=margin
)
result = self.variant(_self_supervised.triplet_loss)(
anchor, positive, negative
)
np.testing.assert_allclose(result, handmade_result, atol=1e-4)

@chex.all_variants
@parameterized.parameters([
{
'anchor': jnp.ones((2, 2)),
'positive': jnp.zeros((2, 2)),
'negative': jnp.ones((2, 2)) * 2,
},
])
def test_vmap(self, anchor, positive, negative):
original_loss = _self_supervised.triplet_loss(anchor, positive,
negative, reduction='none')
anchor_batched = anchor.reshape(1, *anchor.shape)
positive_batched = positive.reshape(1, *positive.shape)
negative_batched = negative.reshape(1, *negative.shape)
vmap_loss = self.variant(jax.vmap(_self_supervised.triplet_loss,
in_axes=(0, 0, 0)))(anchor_batched,
positive_batched,
negative_batched)
np.testing.assert_allclose(vmap_loss.flatten(), original_loss.flatten()
, atol=1e-4)


if __name__ == '__main__':
absltest.main()
Loading