Hello, thank you for sharing your work!

I have a question regarding the implementation of `instance_contrastive_loss`:

    import torch
    import torch.nn.functional as F

    def instance_contrastive_loss(z1, z2):
        B, T = z1.size(0), z1.size(1)
        if B == 1:
            # contrastive loss requires pair.
            return z1.new_tensor(0.)
        z = torch.cat([z1, z2], dim=0)  # 2B x T x C
        z = z.transpose(0, 1)  # T x 2B x C
        sim = torch.matmul(z, z.transpose(1, 2))  # T x 2B x 2B
        logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # T x 2B x (2B-1)
        logits += torch.triu(sim, diagonal=1)[:, :, 1:]
        logits = -F.log_softmax(logits, dim=-1)

        i = torch.arange(B, device=z1.device)
        loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
        return loss

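For reference, here is a plain-Python sketch of what the function computes at a single time step t, assuming `z1[:, t]` and `z2[:, t]` are given as lists of B feature vectors (`instance_loss_one_step` is a hypothetical helper, not part of the repository). Each of the 2B rows acts as an anchor, its own similarity is excluded, its positive is the other view of the same instance, and the result is the mean of the negative log-softmax at the positive over all 2B anchors. Averaging the two `.mean()` terms and dividing by 2 gives the same quantity, since both terms average over B anchors.

```python
import math


def instance_loss_one_step(z1, z2):
    # Hypothetical plain-Python helper: the instance-wise loss at one time step t.
    # z1, z2: lists of B feature vectors (lists of floats), i.e. z1[:, t] and z2[:, t].
    B = len(z1)
    if B == 1:
        return 0.0  # a single instance has no negatives
    z = z1 + z2  # rows 0..B-1 come from z1, rows B..2B-1 from z2
    n = 2 * B
    sim = [[sum(a * b for a, b in zip(z[r], z[c])) for c in range(n)] for r in range(n)]
    total = 0.0
    for r in range(n):
        others = [sim[r][c] for c in range(n) if c != r]  # drop self-similarity
        m = max(others)
        log_norm = m + math.log(sum(math.exp(v - m) for v in others))  # stable log-sum-exp
        pos = (r + B) % n  # the other view of the same instance
        total += log_norm - sim[r][pos]  # -log softmax at the positive
    return total / n  # same as averaging the two .mean() terms and dividing by 2


z1 = [[1.0, 0.0], [0.0, 1.0]]
z2 = [[1.0, 0.0], [0.0, 1.0]]
print(instance_loss_one_step(z1, z2))  # ≈ 0.5514, i.e. log(1 + 2/e)
```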
In your implementation, you slice the lower triangle with `[:, :, :-1]` and the upper triangle with `[:, :, 1:]` before adding them to form the logits. Why is this? Is there something I have missed?
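For reference, the effect of the two slices can be checked on a single 2B x 2B similarity matrix using plain-Python stand-ins for `torch.tril`/`torch.triu` (a sketch; `tril`, `triu`, `sliced_logits` and `drop_diagonal` are illustrative helpers, not PyTorch or repository functions). With `diagonal=-1` the last column of the lower triangle is always zero, and with `diagonal=1` the first column of the upper triangle is always zero, so dropping those columns and adding the two results appears to give each row's 2B-1 off-diagonal similarities in their original order. The same shift would account for the `B + i - 1` index: for the first B rows the positive lies to the right of the removed diagonal entry.

```python
def tril(m, k):
    # Plain-Python stand-in for torch.tril(m, diagonal=k) on a square 2-D list:
    # keep entries with col - row <= k, zero out the rest.
    n = len(m)
    return [[m[r][c] if c - r <= k else 0 for c in range(n)] for r in range(n)]


def triu(m, k):
    # Stand-in for torch.triu(m, diagonal=k): keep entries with col - row >= k.
    n = len(m)
    return [[m[r][c] if c - r >= k else 0 for c in range(n)] for r in range(n)]


def sliced_logits(sim):
    # Mirrors tril(sim, -1)[:, :-1] + triu(sim, 1)[:, 1:] for one time step.
    lower = [row[:-1] for row in tril(sim, -1)]  # last column of a strictly lower triangle is all zeros
    upper = [row[1:] for row in triu(sim, 1)]    # first column of a strictly upper triangle is all zeros
    return [[x + y for x, y in zip(lo, up)] for lo, up in zip(lower, upper)]


def drop_diagonal(sim):
    # Reference: every row with its own (self-similarity) entry removed.
    return [[v for c, v in enumerate(row) if c != r] for r, row in enumerate(sim)]


B = 2
sim = [[10 * r + c for c in range(2 * B)] for r in range(2 * B)]  # entry encodes (row, col)
logits = sliced_logits(sim)
print(logits)  # [[1, 2, 3], [10, 12, 13], [20, 21, 23], [30, 31, 32]]
assert logits == drop_diagonal(sim)

for i in range(B):
    assert logits[i][B + i - 1] == sim[i][B + i]  # positive for a z1 row shifts left by one
    assert logits[B + i][i] == sim[B + i][i]      # positive for a z2 row keeps its column
```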
Thank you in advance!
Best,