
MNRL with Multiple hard negatives per query and NoDuplicatesBatchSampler #2954

Open
ArthurCamara opened this issue Sep 23, 2024 · 7 comments


@ArthurCamara
Contributor

ArthurCamara commented Sep 23, 2024

As stated in the documentation, the MNRL loss (and, by extension, its Cached variant) can handle more than one hard negative per anchor: the extra negatives are included in the "pool" of negatives that each anchor is scored against.

That being said, there are two issues related to that in the current version:
First, using the recommended NoDuplicatesBatchSampler will ignore all but the first hard negative in the dataset for each anchor. For instance, take the training_nli_v3.py example. The dataset has the format (a_1, p_1, n_1), (a_1, p_1, n_2), ... with multiple hard negatives per query (see rows 17-21 of the triplet subset). However, the sampler will skip all of these rows except the first when building each batch, as both the anchor and the positive are already present in the batch_values set. One could try to work around it by setting valid_label_columns, but that only makes the sampler ignore those columns when checking what is already present in the batch.
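As a toy illustration, the duplicate check behaves roughly like the following standalone sketch (an assumed simplification of the sampler, not the library code; the values are hypothetical):

```python
# Every text value in a row is tested against the values already in the
# batch, so repeated anchor/positive rows get skipped.
rows = [
    ("a_1", "p_1", "n_1"),
    ("a_1", "p_1", "n_2"),  # skipped: a_1 and p_1 are already in the batch
    ("a_1", "p_1", "n_3"),  # skipped for the same reason
    ("a_2", "p_2", "n_4"),
]

batch_values, batch = set(), []
for anchor, positive, negative in rows:
    if {anchor, positive, negative} & batch_values:
        continue  # some value already appears in this batch
    batch.append((anchor, positive, negative))
    batch_values.update({anchor, positive, negative})

# Only the first (a_1, p_1, *) row survives in this batch; the other
# hard negatives for a_1 are deferred to later batches (or dropped).
```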

Addressing this is fairly straightforward: we can skip only the rows whose negative has already been seen among the values in the batch (or, when there is no negative column, fall back to checking the positive, as in the current behaviour):

    def __iter__(self) -> Iterator[list[int]]:
        """
        Iterate over the remaining non-yielded indices. For each index, check if the sample values are already in the
        batch. If not, add the sample values to the batch and keep going until the batch is full. If the batch is full, yield
        the batch indices and continue with the next batch.
        """
        if self.generator and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)
        anchor_column = self.dataset.column_names[0]
        positive_column = self.dataset.column_names[1]
        negative_column = (
            self.dataset.column_names[2] if len(self.dataset.column_names) > 2 else None
        )
        remaining_indices = set(
            torch.randperm(len(self.dataset), generator=self.generator).tolist()
        )
        while remaining_indices:
            batch_values = set()
            batch_indices = []
            for index in remaining_indices:
                sample = self.dataset[index]
                # Skip this row if its negative (or, lacking a negative column, its positive) already appears in the batch values
                if negative_column:
                    if sample[negative_column] in batch_values:
                        continue
                elif sample[positive_column] in batch_values:
                    continue
                batch_indices.append(index)
                if len(batch_indices) == self.batch_size:
                    yield batch_indices
                    break

                batch_values.add(sample[anchor_column])
                batch_values.add(sample[positive_column])
            else:
                # NOTE: some indices might still have been ignored here
                if not self.drop_last:
                    yield batch_indices

            remaining_indices -= set(batch_indices)
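Run standalone on toy rows (a sketch of just the duplicate check, outside the sampler class), the rule above keeps repeated anchor/positive pairs as long as each negative is new to the batch:

```python
rows = [
    ("a_1", "p_1", "n_1"),
    ("a_1", "p_1", "n_2"),  # kept: its negative is new to the batch
    ("a_2", "p_2", "p_1"),  # skipped: its negative duplicates an in-batch positive
    ("a_3", "p_3", "n_3"),
]

batch_values, batch = set(), []
for anchor, positive, negative in rows:
    if negative in batch_values:  # the proposed check for triplet datasets
        continue
    batch.append((anchor, positive, negative))
    # Only anchors and positives are recorded, mirroring the snippet above
    batch_values.add(anchor)
    batch_values.add(positive)
```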

However, even with this fix, MNRL would still behave differently from what we would expect. In the current implementation, if the dataset has multiple hard negatives in the format (a_1, p_1, n_1), (a_1, p_1, n_2), the loss is computed n_hard_negatives times for each anchor: every time (a_1, p_1, n_k) occurs in the dataset, the loss w.r.t. a_1 is computed again.
Ideally, we would want to add all hard negatives to the (larger) pool of negatives, and compute the positive score just once.
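A quick way to see the repetition (a toy count over hypothetical triplet rows, not the actual loss computation):

```python
from collections import Counter

# 5 triplet rows sharing the same anchor-positive pair
triplets = [("a_1", "p_1", f"n_{k}") for k in range(1, 6)]

positive_pair_counts = Counter((a, p) for a, p, _ in triplets)
negative_pair_counts = Counter((a, n) for a, _, n in triplets)

# (a_1, p_1) contributes to the loss 5 times, each (a_1, n_k) only once
```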

The easiest way around this (IMO) is to allow for multiple negative columns in the sampler (similar to the output of mine_hard_negatives when as_triplets is set to False).

This means changing the __iter__ snippet above to something like this:


    negative_columns = [
        self.dataset.column_names[i] for i in range(2, len(self.dataset.column_names))
    ]
    (…)
    if negative_columns:
        if any(sample[negative_column] in batch_values for negative_column in negative_columns):
            continue

From some initial tests, this seems to be enough to make it work with MNRL and CMNRL, but there could be more to it.
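As a standalone sketch, the combined check would behave like the following (a toy dict-of-columns stand-in for the dataset; the column names are assumptions for illustration):

```python
# Toy "dataset" with two negative columns, the shape mine_hard_negatives
# produces when as_triplets=False.
dataset = {
    "anchor":     ["a_1", "a_2"],
    "positive":   ["p_1", "p_2"],
    "negative_1": ["n_1", "p_1"],  # a_2's first negative collides with p_1
    "negative_2": ["n_2", "n_3"],
}
column_names = list(dataset)
negative_columns = column_names[2:]

batch_values, batch_indices = set(), []
for index in range(len(dataset["anchor"])):
    sample = {col: dataset[col][index] for col in column_names}
    # Skip the row if any of its negatives is already a batch value
    if any(sample[col] in batch_values for col in negative_columns):
        continue
    batch_indices.append(index)
    batch_values.add(sample["anchor"])
    batch_values.add(sample["positive"])

# Row 1 is skipped because its negative_1 ("p_1") is already in the batch
```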

@tomaarsen
Collaborator

Hello!

Apologies for the delayed response. Thanks for the detailed writeup!

  1. In this dataset (see screenshot), it will indeed skip the 4 samples with identical anchor/positive and a different negative in the first batch. Apart from niche scenarios (e.g. exclusively the same anchor for the entire dataset), pretty much all data should eventually be fed through the model during training, just in different batches.
    The reason that we prevent duplicates not only among the positives/negatives but also among the anchors is the symmetric losses: MNSRL and CMNSRL. These losses (although they don't accept negatives) also train with "Given a positive, find the sample with the highest similarity out of all anchors in the batch." In short, they use the anchors as the in-batch negatives, so here we also need to avoid duplicate anchors.

[screenshot of the dataset]

  2. You also mention that in this triplet approach, if you have e.g. 5 hard negatives for each anchor-positive pair, then the anchor-positive pair gets trained 5x, whereas each anchor-negative only gets trained once. I agree that this might be unexpected, and may also perform worse. I can't say for certain whether it results in worse models or not, though. And I agree: adding more negative columns like in mine_hard_negatives is a solid fix.

I think the NoDuplicatesBatchSampler can be seen as a bit of overkill, i.e. it does a little more than might be strictly necessary for e.g. normal MNRL. However, I think its strength is its generalisability, simplicity and consistency. It does what it says on the tin, even if it's slightly suboptimal.

I knew when designing the trainer that people would want to extend the samplers beyond what the library offers out of the gate, so I've exposed get_batch_sampler (and get_multi_dataset_batch_sampler) to allow power users to subclass the trainer and customize their training further. In other words, you're recommended to implement e.g. a TripletBatchSampler and use that for the sampling.

  • Tom Aarsen

@ArthurCamara
Contributor Author

> Hello!
>
> Apologies for the delayed response. Thanks for the detailed writeup!
>
> 1. In this dataset (see screenshot), it will indeed skip the 4 samples with identical anchor/positive and a different negative in the first batch. Apart from niche scenarios (e.g. exclusively the same anchor for the entire dataset), pretty much all data should eventually be fed through the model during training, just in different batches.
>    The reason that we prevent duplicates not only among the positives/negatives but also among the anchors is the symmetric losses: MNSRL and CMNSRL. These losses (although they don't accept negatives) also train with "Given a positive, find the sample with the highest similarity out of all anchors in the batch." In short, they use the anchors as the in-batch negatives, so here we also need to avoid duplicate anchors.
>
> [screenshot of the dataset]

Got it! So the (main) issue is that this would break the Symmetric Losses. Makes more sense now.

> 2. You also mention that in this triplet approach, if you have e.g. 5 hard negatives for each anchor-positive pair, then the anchor-positive pair gets trained 5x, whereas each anchor-negative only gets trained once. I agree that this might be unexpected, and may also perform worse. I can't say for certain whether it results in worse models or not, though. And I agree: adding more negative columns like in mine_hard_negatives is a solid fix.

Yes, in the best-case scenario, we will be training the model on 5x more samples than expected. Worst-case, it can significantly hurt performance (I don't have numbers to back this up, just a hunch).

> I think the NoDuplicatesBatchSampler can be seen as a bit of overkill, i.e. it does a little more than might be strictly necessary for e.g. normal MNRL. However, I think its strength is its generalisability, simplicity and consistency. It does what it says on the tin, even if it's slightly suboptimal.
>
> I knew when designing the trainer that people would want to extend the samplers beyond what the library offers out of the gate, so I've exposed get_batch_sampler (and get_multi_dataset_batch_sampler) to allow power users to subclass the trainer and customize their training further. In other words, you're recommended to implement e.g. a TripletBatchSampler and use that for the sampling.
>
> • Tom Aarsen

So the idea would be to implement another sampler that filters the negatives accordingly, and potentially handles the multiple-negative-columns format, right? Do you think there is enough interest to add this as a PR? I think I've already implemented 70% of it on my side.

@tomaarsen
Collaborator

> Got it! So the (main) issue is that this would break the symmetric losses. Makes more sense now.

Exactly.

> Worst-case, it can significantly hurt performance (I don't have numbers to back this up, just a hunch).

Yeah I have that hunch too, it smells like overfitting on the positives.

> So the idea would be to implement another sampler that filters the negatives accordingly, and potentially handles the multiple-negative-columns format, right? Do you think there is enough interest to add this as a PR? I think I've already implemented 70% of it on my side.

I think it might be a bit too similar to the NoDuplicatesBatchSampler, and perhaps then it gets a smidge confusing. E.g. the new sampler would be recommended for the in-batch negatives losses, but not the symmetric ones. And despite being for triplets, it actually isn't recommended for the Batch...TripletLoss classes, and it does nothing for TripletLoss itself.
In short, I'm a bit hesitant to add it, as I'm worried that I'll overload users with way too many options. I think the list of losses in Sentence Transformers is already enough to make potential users dizzy, hah.

  • Tom Aarsen

@ArthurCamara
Contributor Author

I agree it can be confusing. I need to implement it anyway, so how about this? I will create a branch to see what I need to change to make the training work. Depending on the scale of the changes, we can discuss whether a new sampler makes sense or not.

@ArthurCamara
Contributor Author

@tomaarsen I've added a MultipleNegativesBatchSampler in #2960. It is still a draft, but let me know what you think.

@claeyzre

@ArthurCamara thanks for the PR, I like the idea and wanted to know if you tested your assumption about whether avoiding training the anchor-positive pairs multiple times improves performance or not.

@ArthurCamara
Contributor Author

> @ArthurCamara thanks for the PR, I like the idea and wanted to know if you tested your assumption about whether avoiding training the anchor-positive pairs multiple times improves performance or not.

Hi @claeyzre!
I haven't tested it yet, but I still think that training the same positive pair 5x while training each negative only once is not what the user would (necessarily) expect. IMHO, the ideal approach when the dataset has many hard negatives for a single positive is to add all of the hard negatives to the same batch and, if using one of the losses with in-batch negatives, draw the remaining negatives from there. The easiest way to do this is to allow the dataset to have multiple negatives in the same row, like what happens when calling mine_hard_negatives with as_triplets=False. I've implemented this in this (draft) PR: #2960
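For illustration, the proposed row layout (hypothetical column names and values, matching the shape mine_hard_negatives produces with as_triplets=False) would look roughly like this:

```python
# One row per anchor-positive pair; hard negatives spread over extra columns.
# In practice this would be a datasets.Dataset; a plain dict of columns is
# used here to keep the sketch dependency-free.
train_rows = {
    "anchor":     ["what is the capital of france"],
    "positive":   ["Paris is the capital of France."],
    "negative_1": ["Lyon is the third-largest city in France."],
    "negative_2": ["Berlin is the capital of Germany."],
    "negative_3": ["Madrid is the capital of Spain."],
}

# Everything after the first two columns is treated as a negative column,
# mirroring the negative_columns logic from the sampler snippet above.
negative_columns = list(train_rows)[2:]
```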
