
Question about MultipleNegativesRankingLoss and gradient accumulation steps #2916

Open
DogitoErgoSum opened this issue Aug 29, 2024 · 8 comments

Comments

@DogitoErgoSum

How does the MultipleNegativesRankingLoss function when used with gradient accumulation steps?

According to the docs

For each a_i, it uses all other p_j as negative samples, i.e., for a_i, we have 1 positive example (p_i) and n-1 negative examples (p_j). It then minimizes the negative log-likelihood for softmax-normalized scores.

Are the negatives from other steps used (during accumulation), or are only the negatives from the samples in the current batch (per_device_train_batch_size) used?

@tomaarsen
Collaborator

Hello!

Great question! It's the latter: only the negatives from the samples in the current batch, i.e. per_device_train_batch_size samples, are used. Gradient accumulation therefore does not give you the performance benefit of a larger batch size for in-batch negative losses.

For that, I would recommend using the Cached losses, such as CachedMultipleNegativesRankingLoss. In short, this loss is equivalent to MultipleNegativesRankingLoss, but it cleverly uses caching and mini-batches to reach a very high per_device_train_batch_size while memory usage stays constant, determined by the mini-batch size. For example, you can use CachedMultipleNegativesRankingLoss with a per_device_train_batch_size of 4096 and a mini-batch size of 64, and you'll get the same memory usage as MultipleNegativesRankingLoss with a per_device_train_batch_size of 64. You'll get a stronger training signal, at the cost of some training speed overhead (usually about 20%).
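For illustration, a minimal sketch with the v3 training API (the model name, toy dataset, and output_dir are just placeholders, not part of this discussion):

from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)

model = SentenceTransformer("all-MiniLM-L6-v2")

# Toy (anchor, positive) pairs; in practice this would be your full training set
train_dataset = Dataset.from_dict({
    "anchor": ["What is the capital of France?", "Who wrote Hamlet?"],
    "positive": ["Paris is the capital of France.", "Hamlet was written by Shakespeare."],
})

# mini_batch_size only bounds peak memory; the in-batch negatives come from the
# full per_device_train_batch_size
loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)

args = SentenceTransformerTrainingArguments(
    output_dir="output",
    per_device_train_batch_size=4096,  # each anchor sees up to 4095 in-batch negatives
)

trainer = SentenceTransformerTrainer(model=model, args=args, train_dataset=train_dataset, loss=loss)
trainer.train()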

  • Tom Aarsen

@DogitoErgoSum
Author

Thank you for the fast answer!
I will try the cached version.

@DogitoErgoSum
Author

Last question: how does BatchSamplers.NO_DUPLICATES work with gradient accumulation steps?

@DogitoErgoSum reopened this Aug 29, 2024
@tomaarsen
Collaborator

tomaarsen commented Aug 29, 2024

The "no duplicates" works on a per-batch level, so with e.g. a per_device_train_batch_size of 16 and a gradient accumulation steps of 4, then you'll get 4 batches per loss propagation where each batch does not have duplicate samples in them. With other words, no issues due to duplicates. There's no "cross-batch communication" when doing gradient accumulation other than that the losses from each batch get added together.

If you instead use CachedMNRL with "no duplicates" and e.g. a per_device_train_batch_size of 64 and a mini-batch size of 16, then you get just 1 batch (of 64) per optimizer step. Duplicates are also avoided within this batch, so there are no issues here either.

For context, for those wondering why duplicates can be problematic for in-batch negative losses: if you have e.g. question-answer pairs, and the answer to some unrelated question Y is identical to the answer to question X, then that answer will be treated both as a positive and as a negative for question X, negating the usefulness of that sample.
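As a concrete sketch of the configuration in the example above (hypothetical values; only batch_sampler, per_device_train_batch_size, and gradient_accumulation_steps matter here):

from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    output_dir="output",
    per_device_train_batch_size=16,  # "no duplicates" is enforced within each batch of 16
    gradient_accumulation_steps=4,   # 4 such batches are accumulated per optimizer step
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)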

Does that clear it up?

  • Tom Aarsen

@DogitoErgoSum
Author

Does that clear it up?

Yes. This raises another question: does "no duplicates" check for repeated anchors or positives?

@DogitoErgoSum
Author

And suppose I use per_device_train_batch_size equal to the size of the training data. Will "no duplicates" delete duplicates, or divide the batch into N batches with no duplicates in each batch?

@DogitoErgoSum
Author

Sorry for the question spam. If we use triplets instead of anchor-positive pairs, does the following still happen?

For each a_i, it uses all other p_j as negative samples, i.e., for a_i, we have 1 positive example (p_i) and n-1 negative examples (p_j). It then minimizes the negative log-likelihood for softmax-normalized scores.

@pesuchin
Contributor

pesuchin commented Sep 5, 2024

Hello!

The following code section ensures that there are no duplicates among anchor, positive, and negative:

batch_values = set()
batch_indices = []
for index in remaining_indices:
    # All texts of this sample (anchor, positive, and negative if present)
    sample_values = set(self.dataset[index].values())
    # Skip this sample if any of its texts already appears in the current batch
    if sample_values & batch_values:
        continue

    batch_indices.append(index)
    batch_values.update(sample_values)

When using anchor, positive, and negative columns instead of anchor-positive pairs, sample_values would be {anchor, positive, negative}, and the duplication check is performed with sample_values & batch_values. Therefore, if any of a sample's texts already appears in the batch, that sample is skipped for the current batch and considered again for a later one.

To illustrate with a specific example: in the following case, sample_values & batch_values would result in {"positive1"}, indicating a duplicate, so this sample would be skipped for the current batch:

batch_values = {"anchor1", "positive1", "negative1", "anchor2", "positive2", "negative2"}
sample_values = {"anchor3", "positive1", "negative3"}

In this way, it guarantees that there are no duplicates across anchors, positives, and negatives within a batch. Therefore, I believe the answer to the following question is Yes:

This raises another question: does "no duplicates" check for repeated anchors or positives?

I also think the answer to the following question would be Yes:

If we use triplets instead of anchor-positive pairs, does the following still happen?
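To make that concrete, a minimal sketch of training on (anchor, positive, negative) triplets (toy data and a placeholder model name; per the loss documentation, the explicit negative column is used as a hard negative in addition to the in-batch negatives):

from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses

model = SentenceTransformer("all-MiniLM-L6-v2")

# Toy (anchor, positive, negative) triplets; column names are just illustrative
train_dataset = Dataset.from_dict({
    "anchor": ["What is the capital of France?", "Who wrote Hamlet?"],
    "positive": ["Paris is the capital of France.", "Hamlet was written by Shakespeare."],
    "negative": ["Berlin is the capital of Germany.", "Macbeth was written by Shakespeare."],
})

# Each anchor is scored against its own positive plus all other in-batch positives
# and negatives, which all act as negatives
loss = losses.MultipleNegativesRankingLoss(model)

trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()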
