Increase the number of negatives in CachedGISTEmbedLoss to 2 #2932
Hello! This looks correct to me, well done. I think the only missed case is in calculate_loss, which also needs to handle the extra negative. Ideally I'd update all of the in-batch negatives losses to be able to accept any number of negatives. MultipleNegativesRankingLoss already does; it's just not very well documented.
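For reference, a minimal sketch of that usage (assuming the sentence-transformers v3 Trainer API; the model and column names are only illustrative): every column after "positive" is treated as an additional negative by MultipleNegativesRankingLoss.

```python
from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Columns map to the loss inputs by position: anchor, positive, then any number of negatives.
train_dataset = Dataset.from_dict({
    "anchor": ["What is the capital of France?", "Who wrote Hamlet?"],
    "positive": ["Paris is the capital of France.", "Hamlet was written by Shakespeare."],
    "negative_1": ["Berlin is the capital of Germany.", "Goethe wrote Faust."],
    "negative_2": ["Madrid is the capital of Spain.", "Tolstoy wrote War and Peace."],
})

loss = losses.MultipleNegativesRankingLoss(model)
trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()
```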
Thank you for your advice. As you said, there was an error in calculate_loss, so I fixed it as shown below and it works fine.

```python
def calculate_loss(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
    """Calculate the cross-entropy loss. No need to cache the gradients."""
    if len(reps) == 2:
        anchor, positive = reps
        anchor_guide, positive_guide = reps_guided
        negative = None
        negative_guide = None
    elif len(reps) == 3:
        anchor, positive, negative = reps
        anchor_guide, positive_guide, negative_guide = reps_guided
    elif len(reps) == 4:
        anchor, positive, negative, negative_extra = reps
        anchor_guide, positive_guide, negative_guide, negative_extra_guide = reps_guided
    else:
        raise ValueError("Expected 2, 3, or 4 embeddings, got {}".format(len(reps)))

    anchor = torch.cat(anchor, dim=0)
    positive = torch.cat(positive, dim=0)
    anchor_guide = torch.cat(anchor_guide, dim=0)
    positive_guide = torch.cat(positive_guide, dim=0)
    # Handle the case where we have a negative sample
    if negative:
        negative = torch.cat(negative, dim=0)
        negative_guide = torch.cat(negative_guide, dim=0)
    # Handle the case where we have an extra negative sample (4 embeddings case)
    if len(reps) == 4:
        negative_extra = torch.cat(negative_extra, dim=0)
        negative_extra_guide = torch.cat(negative_extra_guide, dim=0)

    labels = torch.arange(anchor.size(0)).long().to(anchor.device)
    batch_size = anchor.shape[0]

    losses: list[torch.Tensor] = []
    for b in tqdm.trange(
        0,
        batch_size,
        self.mini_batch_size,
        desc="Preparing caches",
        disable=not self.show_progress_bar,
    ):
        e = b + self.mini_batch_size
        # Let's compute the similarity matrices for the combinations of anchor and positive samples.
        guided_ap_sim = self.sim_matrix(anchor_guide[b:e], positive_guide)
        guided_aa_sim = self.sim_matrix(anchor_guide[b:e], anchor_guide)
        guided_pp_sim = self.sim_matrix(positive_guide[b:e], positive_guide)
        # Define the anchor threshold
        guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)

        # Compute similarity scores for current mini-batch.
        # anchor (mbsz,hdim), positive (bsz,hdim)
        ap_sim = self.sim_matrix(anchor[b:e], positive)  # (mbsz,bsz)
        aa_sim = self.sim_matrix(anchor[b:e], anchor)
        pp_sim = self.sim_matrix(positive[b:e], positive)

        # Find which samples cannot be used as negatives because they are
        # more similar to the query than the assigned positive as deemed by the guide model.
        # For these samples, we mask them with -inf to basically ignore their contribution to
        # the loss.
        ap_sim[guided_ap_sim > guided_sim] = -torch.inf
        aa_sim[guided_aa_sim > guided_sim] = -torch.inf
        pp_sim[guided_pp_sim > guided_sim] = -torch.inf

        scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

        # Handle the case where we have a negative sample
        if negative is not None:
            guided_an_sim = self.sim_matrix(anchor_guide[b:e], negative_guide)
            an_sim = self.sim_matrix(anchor[b:e], negative)
            an_sim[guided_an_sim > guided_sim] = -torch.inf
            scores = torch.cat([scores, an_sim], dim=1)

        # Handle the case where we have an extra negative sample
        if len(reps) == 4:
            guided_ane_sim = self.sim_matrix(anchor_guide[b:e], negative_extra_guide)
            ane_sim = self.sim_matrix(anchor[b:e], negative_extra)
            ane_sim[guided_ane_sim > guided_sim] = -torch.inf
            scores = torch.cat([scores, ane_sim], dim=1)

        scores = scores / self.temperature
        loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
        losses.append(loss_mbatch)

    loss = sum(losses)
    return loss
```
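For anyone reading along: reps here is a list with one entry per input column (anchor, positive, negative, ...), and each entry is itself a list of per-mini-batch embedding tensors, which is why the method starts with torch.cat(..., dim=0). A rough, self-contained sketch of that structure (hypothetical sizes, not taken from a real run):

```python
import torch

# Hypothetical sizes: a batch of 32 split into mini-batches of 8, embedding dim 768,
# and four input columns (anchor, positive, negative, negative_extra).
batch_size, mini_batch_size, dim = 32, 8, 768
n_mini_batches = batch_size // mini_batch_size

# One outer entry per column; each entry is a list of per-mini-batch embeddings.
reps = [
    [torch.randn(mini_batch_size, dim) for _ in range(n_mini_batches)]
    for _ in range(4)
]

anchor = torch.cat(reps[0], dim=0)  # (32, 768), as in calculate_loss above
print(anchor.shape)                 # torch.Size([32, 768])
```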
I also rewrote it so that it accepts any number of negatives, as below, but the train loss is 0.000. Can you tell me what the problem is?

```python
class CachedGISTEmbedLoss(losses.CachedGISTEmbedLoss):
    def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guided: List[List[Tensor]]) -> Tensor:
        """Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
        if len(reps) != len(reps_guided):
            raise ValueError("reps and reps_guided must have the same length")

        # Concatenate embeddings along the batch dimension
        concatenated_reps = [torch.cat(rep, dim=0) for rep in reps]
        concatenated_guided_reps = [torch.cat(rep_guide, dim=0) for rep_guide in reps_guided]

        labels = torch.arange(concatenated_reps[0].size(0)).long().to(concatenated_reps[0].device)
        batch_size = concatenated_reps[0].shape[0]

        losses: List[torch.Tensor] = []
        for b in tqdm.trange(
            0,
            batch_size,
            self.mini_batch_size,
            desc="Preparing caches",
            disable=not self.show_progress_bar,
        ):
            e = b + self.mini_batch_size
            # Compute guided similarity matrices for anchor-positive, anchor-anchor, and positive-positive samples
            guided_ap_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[1])
            guided_aa_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[0])
            guided_pp_sim = self.sim_matrix(concatenated_guided_reps[1][b:e], concatenated_guided_reps[1])
            # Define the anchor threshold for each similarity matrix
            guided_sim = guided_ap_sim.diagonal(offset=0).view(-1, 1)

            # Compute similarity scores for the current mini-batch
            ap_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[1])  # anchor-positive similarity
            aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0])  # anchor-anchor similarity
            pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1])  # positive-positive similarity

            # Apply thresholds based on guided model similarities
            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
            pp_sim[guided_pp_sim > guided_sim] = -torch.inf

            # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
            scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

            # If there are negatives (len(reps) > 2), process them
            if len(concatenated_reps) > 2:
                for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                    guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                    neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
                    neg_sim[guided_neg_sim > guided_sim[0]] = -torch.inf
                    scores = torch.cat([scores, neg_sim], dim=1)

            # Normalize the scores and calculate the cross-entropy loss
            scores = scores / self.temperature
            loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
            loss_mbatch.backward()
            losses.append(loss_mbatch.detach())

        loss = sum(losses).requires_grad_()
        self.cache = [[r.grad for r in rs] for rs in reps]  # Cache the gradients
        return loss

    def calculate_loss(self, reps: List[List[Tensor]], reps_guided: List[List[Tensor]]) -> Tensor:
        """Generalized function to calculate the cross-entropy loss without caching gradients."""
        if len(reps) != len(reps_guided):
            raise ValueError("reps and reps_guided must have the same length")

        # Concatenate embeddings along the batch dimension
        concatenated_reps = [torch.cat(rep, dim=0) for rep in reps]
        concatenated_guided_reps = [torch.cat(rep_guide, dim=0) for rep_guide in reps_guided]

        labels = torch.arange(concatenated_reps[0].size(0)).long().to(concatenated_reps[0].device)
        batch_size = concatenated_reps[0].shape[0]

        losses: List[torch.Tensor] = []
        for b in tqdm.trange(
            0,
            batch_size,
            self.mini_batch_size,
            desc="Calculating loss",
            disable=not self.show_progress_bar,
        ):
            e = b + self.mini_batch_size
            # Compute guided similarity matrices for anchor-positive, anchor-anchor, and positive-positive samples
            guided_ap_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[1])
            guided_aa_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[0])
            guided_pp_sim = self.sim_matrix(concatenated_guided_reps[1][b:e], concatenated_guided_reps[1])
            # Define the anchor threshold for each similarity matrix
            guided_sim = guided_ap_sim.diagonal(offset=0).view(-1, 1)

            # Compute similarity scores for the current mini-batch
            ap_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[1])  # anchor-positive similarity
            aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0])  # anchor-anchor similarity
            pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1])  # positive-positive similarity

            # Apply thresholds based on guided model similarities
            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
            pp_sim[guided_pp_sim > guided_sim] = -torch.inf

            # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
            scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

            # If there are negatives (len(reps) > 2), process them
            if len(concatenated_reps) > 2:
                for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                    guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                    neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
                    neg_sim[guided_neg_sim > guided_sim[0]] = -torch.inf
                    scores = torch.cat([scores, neg_sim], dim=1)

            # Normalize the scores and calculate the cross-entropy loss
            scores = scores / self.temperature
            loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
            losses.append(loss_mbatch)

        loss = sum(losses)
        return loss
```
In the negatives filtering part, I modified it so that guided_sim considers all rows of the mini-batch instead of only the 0th row via guided_sim[0] (guided_sim has shape (mini_batch_size, 1), so indexing guided_sim[0] broadcasts a single anchor's threshold over every row). I will experiment soon to check whether this loss function works correctly; a short shape sketch follows the snippets below.

After change:

```python
# If there are negatives (len(reps) > 2), process them
if len(concatenated_reps) > 2:
    for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
        guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
        neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
        neg_sim[guided_neg_sim > guided_sim] = -torch.inf
        scores = torch.cat([scores, neg_sim], dim=1)
```

Before change:

```python
# If there are negatives (len(reps) > 2), process them
if len(concatenated_reps) > 2:
    for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
        guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
        neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
        neg_sim[guided_neg_sim > guided_sim[0]] = -torch.inf
        scores = torch.cat([scores, neg_sim], dim=1)
```
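To make the broadcasting concrete, here is a minimal, self-contained sketch (toy numbers, not from the actual training run) showing how the two variants mask differently:

```python
import torch

# Toy guide similarities for a mini-batch of 3 anchors against 3 negatives.
guided_neg_sim = torch.tensor([[0.90, 0.20, 0.10],
                               [0.55, 0.80, 0.30],
                               [0.45, 0.10, 0.70]])
# Per-anchor thresholds (guide's anchor-positive similarity), shape (3, 1).
guided_sim = torch.tensor([[0.5], [0.6], [0.4]])

# After change: each row is compared against its own anchor's threshold.
print(guided_neg_sim > guided_sim)
# [[ True, False, False],
#  [False,  True, False],
#  [ True, False,  True]]

# Before change: guided_sim[0] has shape (1,), so every row is compared
# against the first anchor's threshold (0.5) only.
print(guided_neg_sim > guided_sim[0])
# [[ True, False, False],
#  [ True,  True, False],
#  [False, False,  True]]
```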
There was a problem: the diagonal offset used to read the guide's anchor-positive similarity for the mini-batch was fixed at 0, when it should start from the mini-batch's starting index (b).

```python
# Before change: guided_sim = guided_ap_sim.diagonal(offset=0).view(-1, 1)
guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)
```

After fixing this, it works normally.
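To see why the offset matters, here is a small self-contained sketch (toy shapes, assuming anchors and positives are aligned by index across the full batch): guided_ap_sim is computed between the mini-batch slice of anchors and all positives, so anchor b + i must be paired with the positive in column b + i, i.e. the diagonal starting at offset b.

```python
import torch

batch_size, mini_batch_size = 6, 2
# Toy guide similarity between ALL anchors and ALL positives; entry [i, j] = sim(anchor_i, positive_j).
full_ap_sim = torch.arange(batch_size * batch_size, dtype=torch.float).reshape(batch_size, batch_size)

b = 2                              # the second mini-batch starts at index 2
e = b + mini_batch_size
guided_ap_sim = full_ap_sim[b:e]   # shape (2, 6): rows are anchors 2 and 3, columns are positives 0..5

# Correct: anchor 2 pairs with positive 2, anchor 3 with positive 3.
print(guided_ap_sim.diagonal(offset=b))  # tensor([14., 21.]) == full_ap_sim[2, 2], full_ap_sim[3, 3]

# Wrong (before the fix): anchor 2 paired with positive 0, anchor 3 with positive 1.
print(guided_ap_sim.diagonal(offset=0))  # tensor([12., 19.]) == full_ap_sim[2, 0], full_ap_sim[3, 1]
```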
Below is the full code of the changed class.

```python
class CachedGISTEmbedLoss(losses.CachedGISTEmbedLoss):
    def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guided: List[List[Tensor]]) -> Tensor:
        """Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
        if len(reps) != len(reps_guided):
            raise ValueError("reps and reps_guided must have the same length")

        # Concatenate embeddings along the batch dimension
        concatenated_reps = [torch.cat(rep, dim=0) for rep in reps]
        concatenated_guided_reps = [torch.cat(rep_guide, dim=0) for rep_guide in reps_guided]

        labels = torch.arange(concatenated_reps[0].size(0)).long().to(concatenated_reps[0].device)
        batch_size = concatenated_reps[0].shape[0]

        losses: List[torch.Tensor] = []
        for b in tqdm.trange(
            0,
            batch_size,
            self.mini_batch_size,
            desc="Preparing caches",
            disable=not self.show_progress_bar,
        ):
            e = b + self.mini_batch_size
            # Compute guided similarity matrices for anchor-positive, anchor-anchor, and positive-positive samples
            guided_ap_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[1])
            guided_aa_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[0])
            guided_pp_sim = self.sim_matrix(concatenated_guided_reps[1][b:e], concatenated_guided_reps[1])
            # Define the anchor threshold for each similarity matrix
            guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)

            # Compute similarity scores for the current mini-batch
            ap_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[1])  # anchor-positive similarity
            aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0])  # anchor-anchor similarity
            pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1])  # positive-positive similarity

            # Apply thresholds based on guided model similarities
            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
            pp_sim[guided_pp_sim > guided_sim] = -torch.inf

            # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
            scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

            # If there are negatives (len(reps) > 2), process them
            if len(concatenated_reps) > 2:
                for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                    guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                    neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
                    neg_sim[guided_neg_sim > guided_sim] = -torch.inf
                    scores = torch.cat([scores, neg_sim], dim=1)

            # Normalize the scores and calculate the cross-entropy loss
            scores = scores / self.temperature
            loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
            loss_mbatch.backward()
            losses.append(loss_mbatch.detach())

        loss = sum(losses).requires_grad_()
        self.cache = [[r.grad for r in rs] for rs in reps]  # Cache the gradients
        return loss

    def calculate_loss(self, reps: List[List[Tensor]], reps_guided: List[List[Tensor]]) -> Tensor:
        """Generalized function to calculate the cross-entropy loss without caching gradients."""
        if len(reps) != len(reps_guided):
            raise ValueError("reps and reps_guided must have the same length")

        # Concatenate embeddings along the batch dimension
        concatenated_reps = [torch.cat(rep, dim=0) for rep in reps]
        concatenated_guided_reps = [torch.cat(rep_guide, dim=0) for rep_guide in reps_guided]

        labels = torch.arange(concatenated_reps[0].size(0)).long().to(concatenated_reps[0].device)
        batch_size = concatenated_reps[0].shape[0]

        losses: List[torch.Tensor] = []
        for b in tqdm.trange(
            0,
            batch_size,
            self.mini_batch_size,
            desc="Calculating loss",
            disable=not self.show_progress_bar,
        ):
            e = b + self.mini_batch_size
            # Compute guided similarity matrices for anchor-positive, anchor-anchor, and positive-positive samples
            guided_ap_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[1])
            guided_aa_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[0])
            guided_pp_sim = self.sim_matrix(concatenated_guided_reps[1][b:e], concatenated_guided_reps[1])
            # Define the anchor threshold for each similarity matrix
            guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1)

            # Compute similarity scores for the current mini-batch
            ap_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[1])  # anchor-positive similarity
            aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0])  # anchor-anchor similarity
            pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1])  # positive-positive similarity

            # Apply thresholds based on guided model similarities
            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
            pp_sim[guided_pp_sim > guided_sim] = -torch.inf

            # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
            scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)

            # If there are negatives (len(reps) > 2), process them
            if len(concatenated_reps) > 2:
                for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                    guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                    neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
                    neg_sim[guided_neg_sim > guided_sim] = -torch.inf
                    scores = torch.cat([scores, neg_sim], dim=1)

            # Normalize the scores and calculate the cross-entropy loss
            scores = scores / self.temperature
            loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
            losses.append(loss_mbatch)

        loss = sum(losses)
        return loss
```

You can see that it works well by running it in the Colab notebook below.
https://colab.research.google.com/drive/1aU7xiepABsAG1UGk-1LuDkfGjAz3I4o3?usp=sharing

How about opening a pull request so that others can use this code?
Great work! Yes, I'd be very open to a PR to extend the behaviour of this class.
I submitted a pull request. Feel free to modify it or tell me your opinion.
I am trying to train by increasing the number of negatives in CachedGISTEmbedLoss to 2 as shown below. Is there any theoretical problem that could occur during training? Training proceeds without error.
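Purely as an illustrative sketch of such a setup (hypothetical model and column names, assuming a CachedGISTEmbedLoss that accepts extra negative columns, e.g. the generalized class from this thread, and the sentence-transformers v3 Trainer API):

```python
from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
guide = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# Two negative columns per anchor; each row yields (anchor, positive, negative_1, negative_2).
train_dataset = Dataset.from_dict({
    "anchor": ["What is the capital of France?"],
    "positive": ["Paris is the capital of France."],
    "negative_1": ["Berlin is the capital of Germany."],
    "negative_2": ["France borders Spain and Belgium."],
})

loss = losses.CachedGISTEmbedLoss(model, guide=guide, mini_batch_size=16)
trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()
```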