MNRL with Multiple hard negatives per query and NoDuplicatesBatchSampler #2954
Hello! Apologies for the delayed response. Thanks for the detailed writeup!
I knew when designing the trainer that people would want to extend the samplers beyond what the library offers out of the gate, so I've exposed the get_batch_sampler (and get_multi_dataset_batch_sampler) methods to allow power users to subclass the trainer and customize their training further. In other words, you're recommended to implement e.g. a custom batch sampler and plug it in via such a subclass.
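A minimal sketch of what such a subclass could look like, assuming the get_batch_sampler signature and sampler import path from recent sentence-transformers versions (check both against your installed version):

```python
from torch.utils.data import BatchSampler

from sentence_transformers import SentenceTransformerTrainer
from sentence_transformers.sampler import NoDuplicatesBatchSampler


class CustomSamplerTrainer(SentenceTransformerTrainer):
    def get_batch_sampler(
        self, dataset, batch_size, drop_last, valid_label_columns=None, generator=None
    ) -> BatchSampler:
        # Swap NoDuplicatesBatchSampler for any BatchSampler subclass that
        # implements the sampling strategy you need, e.g. one that only
        # de-duplicates on the negative columns.
        return NoDuplicatesBatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            valid_label_columns=valid_label_columns,
            generator=generator,
        )
```

The subclassed trainer is then used exactly like the stock SentenceTransformerTrainer; only the batch sampler construction changes.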
Got it! So the (main) issue is that this would break the Symmetric Losses. Makes more sense now.
Yes, in the best-case scenario, we will be training the model on 5x more samples than expected. Worst-case, it can significantly hurt performance (I don't have numbers to back this, just a hunch).
So the idea would be to implement another sampler that filters the negatives accordingly, and potentially handles the multiple-negative-columns case, right? Do you think there is enough interest to add this as a PR? I think I've already implemented 70% of it on my side.
Exactly.
Yeah I have that hunch too, it smells like overfitting on the positives.
I think it might be a bit too similar to the existing NoDuplicatesBatchSampler.
I agree it can be confusing. I need to implement it anyway, so how about this? I will create a branch to see what I need to change to make the training work. Depending on the level of changes, we can discuss whether a new sampler makes sense or not.
@tomaarsen I've added a PR.
@ArthurCamara thanks for the PR, I like the idea and wanted to know if you tested your assumption on whether avoiding training multiple times on the same anchor-positive pairs improves performance or not.
Hi @claeyzre!
As stated in the documentation, the MNRL loss (and, by extension, its Cached variant) can handle more than one hard negative per anchor, as these extra negatives will be included in the "pool" of all the negatives each anchor will be scored against.
That being said, there are two issues related to that in the current version:
First, using the recommended NoDuplicatesBatchSampler will ignore all but the first hard negative in the dataset for each anchor. For instance, take the training_nli_v3.py example. The dataset has the format (a_1, p_1, n_1), (a_1, p_1, n_2), ... with multiple hard negatives per query (see rows 17-21 of the triplet subset). However, the sampler will skip all rows except the first when building each batch, as both the anchor and the positive are already present in the batch_values set. One could try to work around it by setting valid_label_columns, but that only makes the sampler ignore those columns when considering what is already present in the batch.

Addressing this is somewhat straightforward: we can skip only the rows where either the positive or the negative has already been seen among all the negatives in the batch (or, when there is no negative column, revert to the current behaviour):
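Something along these lines could work; this is only a sketch, with the "positive"/"negative" column names and the sampler attributes assumed from recent sentence-transformers versions:

```python
import torch

from sentence_transformers.sampler import NoDuplicatesBatchSampler


class NoDuplicateNegativesBatchSampler(NoDuplicatesBatchSampler):
    """Sketch: only de-duplicate against the negatives in a batch, so rows that
    repeat an (anchor, positive) pair with a new hard negative are kept."""

    def __iter__(self):
        if self.generator and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)

        remaining_indices = set(torch.randperm(len(self.dataset), generator=self.generator).tolist())
        while remaining_indices:
            batch_values = set()  # negatives already placed in this batch
            batch_indices = []
            for index in remaining_indices:
                sample = self.dataset[index]
                if "negative" in sample:
                    # Skip the row if its positive or its negative collides with a
                    # negative already in the batch; repeated anchors/positives pass.
                    if sample["positive"] in batch_values or sample["negative"] in batch_values:
                        continue
                    new_values = {sample["negative"]}
                else:
                    # No negative column: revert to de-duplicating on all columns.
                    new_values = {v for k, v in sample.items() if k not in (self.valid_label_columns or [])}
                    if new_values & batch_values:
                        continue

                batch_indices.append(index)
                if len(batch_indices) == self.batch_size:
                    yield batch_indices
                    break
                batch_values.update(new_values)
            else:
                if not self.drop_last:
                    yield batch_indices

            remaining_indices -= set(batch_indices)
```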
However, even after fixing this, MNRL would still behave differently from what we would expect. In the current implementation, if the dataset has multiple hard negatives in the format (a_1, p_1, n_1), (a_1, p_1, n_2), the loss would be computed n_hard_negatives times for each anchor, as each time (a_1, p_1, n_k) appears in the dataset, the loss w.r.t. a_1 is computed again. Ideally, we would want to add all hard negatives to the (larger) pool of negatives and compute the positive score just once, with each anchor appearing in a single row as sketched below.
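For illustration, a toy version of that layout might look like this (the values and the model name are placeholders; as noted above, the extra negative columns are folded into the shared pool of negatives):

```python
from datasets import Dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Toy dataset: one row per anchor, with the hard negatives spread over
# several columns, so each (anchor, positive) pair appears exactly once.
train_dataset = Dataset.from_dict({
    "anchor": ["a_1", "a_2"],
    "positive": ["p_1", "p_2"],
    "negative_1": ["n_11", "n_21"],
    "negative_2": ["n_12", "n_22"],
})

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model
# All negative columns end up in the pool of candidates that every anchor is
# scored against, while the positive score is computed only once per anchor.
loss = MultipleNegativesRankingLoss(model)
```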
The easiest way around this (IMO) is to allow for multiple negative columns in the sampler (similar to the output of mine_hard_negatives when as_triplets is set to False). This means changing the __iter__ snippet above to something like the sketch after this paragraph. From some initial tests, this seems to be enough to make it work with MNRL and CMNRL, but there could be more to it.
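Roughly, the duplicate check could be factored into a small helper that copes with any number of negative_* columns; a sketch, with the column names as assumptions:

```python
def values_for_batch(sample: dict, batch_values: set, valid_label_columns: list) -> set:
    """Return the values this row would add to the batch, or an empty set if the
    row should be skipped. Intended to replace the duplicate check in __iter__."""
    row_negatives = {value for key, value in sample.items() if key.startswith("negative")}
    if row_negatives:
        # Skip the row if its positive or any of its negatives already appear
        # among the negatives placed in this batch so far.
        if (row_negatives | {sample["positive"]}) & batch_values:
            return set()
        return row_negatives
    # No negative columns: fall back to the current all-column de-duplication.
    all_values = {value for key, value in sample.items() if key not in valid_label_columns}
    return set() if all_values & batch_values else all_values
```

Inside __iter__, a row is then added only when values_for_batch(...) returns a non-empty set, and batch_values is updated with that set.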