[DRAFT] pytests for losses #3167
Conversation
Hello! I think this is heading in the right direction, but I do have some suggestions.

In particular, I think there's a risk that our test implementation has a similar (or the same) bug as our real implementation. However, we still want to ensure that the loss "works". So, one option is to create 2 batches. If we want to use the same model for both, then we have to create 1 "good" batch and 1 "bad" batch. The former is normal data, whereas the latter is the opposite of reality: the positive is the negative, related texts are marked as 0.0 similarity, etc. Then, the same trained model should give a low loss for the "good" batch and a high loss for the "bad" batch. In both cases we should still test that the output 1) is a torch Tensor and 2) is not negative.

I think this is safer than just calculating an "expected loss", because if our implementation is buggy, that expected loss is probably also wrong. We'd then have to create the good and bad batches for each row in the Loss Overview, use each of the appropriate loss functions with those good/bad batches, and ensure that the "good" loss comes out lower than the "bad" loss. Does that make sense?
Umm, I must admit I don't completely understand, do you mind giving an example?
Apologies, it is a bit confusing! Here is a detailed example:

```python
from __future__ import annotations

import pytest
import torch
from torch import nn
from datasets import Dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import (
    CachedGISTEmbedLoss,
    CachedMultipleNegativesRankingLoss,
    GISTEmbedLoss,
    MultipleNegativesRankingLoss,
    TripletLoss,
)
from sentence_transformers.util import batch_to_device

# TODO: Preferably initialize the guide model in a fixture
GUIDE_MODEL = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")

anchor_positive_negative_triplet = {
    "losses": [
        (MultipleNegativesRankingLoss, {}),
        (CachedMultipleNegativesRankingLoss, {}),
        (TripletLoss, {}),
        (CachedGISTEmbedLoss, {"guide": GUIDE_MODEL}),
        (GISTEmbedLoss, {"guide": GUIDE_MODEL}),
    ],
    "correct": Dataset.from_dict({
        "anchor": ["It's very sunny outside", "I love playing soccer", "I am a student"],
        "positive": ["The sun is out today", "I like playing soccer", "I am studying at university"],
        "negative": ["Data science is fun", "Cacti are beautiful", "Speakers are loud"],
    }),
    "incorrect": Dataset.from_dict({
        "anchor": ["It's very sunny outside", "I love playing soccer", "I am a student"],
        "positive": ["Data science is fun", "Cacti are beautiful", "Speakers are loud"],
        "negative": ["The sun is out today", "I like playing soccer", "I am studying at university"],
    }),
}

LOSS_TEST_CASES = [
    (loss_class, loss_args, anchor_positive_negative_triplet["correct"], anchor_positive_negative_triplet["incorrect"])
    for loss_class, loss_args in anchor_positive_negative_triplet["losses"]
]


def prepare_features_labels_from_dataset(model: SentenceTransformer, dataset: Dataset):
    device = model.device
    features = [
        batch_to_device(model.tokenize(dataset[column]), device)
        for column in dataset.column_names
        if column not in ["label", "score"]
    ]
    labels = None
    if "label" in dataset.column_names:
        labels = torch.tensor(dataset["label"]).to(device)
    elif "score" in dataset.column_names:
        labels = torch.tensor(dataset["score"]).to(device)
    return features, labels


def get_and_assert_loss_from_dataset(model: SentenceTransformer, loss_fn: nn.Module, dataset: Dataset):
    features, labels = prepare_features_labels_from_dataset(model, dataset)
    loss = loss_fn.forward(features, labels)
    assert isinstance(loss, torch.Tensor), f"Loss should be a torch.Tensor, but got {type(loss)}"
    assert loss.item() > 0, "Loss should be positive"
    assert loss.shape == (), "Loss should be a scalar"
    assert loss.requires_grad, "Loss should require gradients"
    return loss


@pytest.mark.parametrize("loss_class, loss_args, correct, incorrect", LOSS_TEST_CASES)
def test_loss_function(stsb_bert_tiny_model_reused: SentenceTransformer, loss_class, loss_args, correct, incorrect):
    model = stsb_bert_tiny_model_reused
    loss_fn = loss_class(model, **loss_args)
    correct_loss = get_and_assert_loss_from_dataset(model, loss_fn, correct)
    incorrect_loss = get_and_assert_loss_from_dataset(model, loss_fn, incorrect)
    assert correct_loss < incorrect_loss, "Loss should be lower for correct data than for incorrect data"
```

It can be changed up a bit, but the overall idea is that we have 1 batch of "correct" data and 1 batch of "incorrect" data. If we use a trained model, then the loss of the "correct" data will be lower than the loss of the "incorrect" data. Does that make some more sense? How this file is structured can be updated to whatever is convenient.
Yeap, thanks so much Tom, working on it!
I couldn't find an elegant way to pass the guide model as a fixture, would this work?
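One common workaround for the fixture problem discussed here is to parametrize with a sentinel string and resolve it inside the test via `request.getfixturevalue`, since `pytest.mark.parametrize` cannot take fixture values directly. This is only a sketch: the fixture name `guide_model`, the `fixture:` sentinel convention, and the helper `resolve_fixture_args` are illustrative assumptions, not the project's actual conftest, and a placeholder string stands in for the real guide model.

```python
import pytest


@pytest.fixture(scope="session")
def guide_model():
    # In the real suite this would load the tiny SentenceTransformer guide model
    # once per session; a placeholder string keeps this sketch self-contained.
    return "loaded-guide-model"


def resolve_fixture_args(request, loss_args: dict) -> dict:
    """Replace sentinel strings of the form 'fixture:<name>' with fixture values."""
    resolved = {}
    for key, value in loss_args.items():
        if isinstance(value, str) and value.startswith("fixture:"):
            # request.getfixturevalue dynamically resolves a fixture by name
            resolved[key] = request.getfixturevalue(value.removeprefix("fixture:"))
        else:
            resolved[key] = value
    return resolved


@pytest.mark.parametrize("loss_args", [{}, {"guide": "fixture:guide_model"}])
def test_loss_with_optional_guide(request, loss_args):
    loss_args = resolve_fixture_args(request, loss_args)
    # ... build the loss with loss_args and run the correct/incorrect assertions ...
```

This keeps the parametrize table as plain data while still letting expensive models be session-scoped fixtures.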
I think there is a way for that.
Cool then, I'll add it for the other losses.
What about losses like softmax?
I assume it's because the test tiny model hasn't been trained and the random init makes the losses inconsistent?
Oh, I hadn't considered that the tiny model might not be trained enough, hah. Otherwise we'll have to use a bigger, normal model, like all-MiniLM-L6-v2.
Yeah, this model seems to be working for now, thanks!
I tried using a larger model and the smaller ones as well; the losses are still kinda random. Also, I think I've added the losses that can be tested together; I think the other losses have to be handled differently, wdyt?
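For the losses that take labels rather than pure triplets, the same correct/incorrect idea could be applied by keeping the texts fixed and flipping only the scores. This is a sketch under assumptions: the column names (`sentence1`, `sentence2`, `score`) and the score-flipping helper are illustrative for a score-based loss such as CosineSimilarityLoss, and plain dicts stand in for `datasets.Dataset` to keep it self-contained.

```python
def flip_scores(batch: dict) -> dict:
    """Build the 'incorrect' batch by inverting each similarity score in [0, 1]."""
    return {**batch, "score": [1.0 - s for s in batch["score"]]}


sentence_pair_score_correct = {
    "sentence1": ["It's very sunny outside", "I love playing soccer"],
    "sentence2": ["The sun is out today", "Cacti are beautiful"],
    "score": [1.0, 0.0],  # related pair scored high, unrelated pair scored low
}

# The texts are unchanged; only the labels now contradict reality,
# so a trained model should incur a higher loss on this batch.
sentence_pair_score_incorrect = flip_scores(sentence_pair_score_correct)
```

The existing `prepare_features_labels_from_dataset` helper already picks up a `score` column as labels, so batches like these could plug into the same parametrized test.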
I'm writing this PR to help write test cases for all the losses.
I'm starting with ContrastiveLoss and I wanted to get some early feedback on this. Like you mentioned, the idea is to have a single parametrized function and pass all loss cases into it.
Wdyt?
@tomaarsen