diff --git a/rankcse/models.py b/rankcse/models.py
index 9be41e00..c20dcd95 100644
--- a/rankcse/models.py
+++ b/rankcse/models.py
@@ -70,9 +70,7 @@ def __init__(self, tau, gamma_):
         self.student_temp_scaled_sim = Similarity(tau)
         self.gamma_ = gamma_
 
-    def forward(self, teacher_top1_sim_pred, z1, z2):
-        student_top1_sim_pred = self.student_temp_scaled_sim(z1.unsqueeze(1), z2.unsqueeze(0))
-
+    def forward(self, teacher_top1_sim_pred, student_top1_sim_pred):
         p = F.log_softmax(student_top1_sim_pred.fill_diagonal_(float('-inf')), dim=-1)
         q = F.softmax(teacher_top1_sim_pred.fill_diagonal_(float('-inf')), dim=-1)
         loss = -(q*p).nansum() / q.nansum()
@@ -88,11 +86,10 @@ def __init__(self, tau, gamma_):
         self.gamma_ = gamma_
         self.eps = 1e-7
 
-    def forward(self, teacher_top1_sim_pred, z1, z2):
-        student_top1_sim_pred = self.temp_scaled_sim(z1.unsqueeze(1), z2.unsqueeze(0))
+    def forward(self, teacher_top1_sim_pred, student_top1_sim_pred):
 
-        y_pred = student_top1_sim_pred # .fill_diagonal_(float('-inf')).softmax(dim=-1)
-        y_true = teacher_top1_sim_pred # .fill_diagonal_(float('-inf')).softmax(dim=-1)
+        y_pred = student_top1_sim_pred
+        y_true = teacher_top1_sim_pred
 
         # shuffle for randomised tie resolution
         random_indices = torch.randperm(y_pred.shape[-1])
@@ -109,8 +106,7 @@ def forward(self, teacher_top1_sim_pred, z1, z2):
         observation_loss = torch.log(cumsums + self.eps) - preds_sorted_by_true_minus_max
         observation_loss[mask] = 0.0
 
-        return self.gamma_ * torch.mean(torch.mean(observation_loss, dim=1))
-
+        return self.gamma_ * torch.mean(torch.sum(observation_loss, dim=1))
 
 class Pooler(nn.Module):
     """
@@ -159,6 +155,12 @@ def cl_init(cls, config):
         cls.mlp = MLPLayer(config)
     cls.sim = Similarity(temp=cls.model_args.temp)
     cls.div = Divergence(beta_=cls.model_args.beta_)
+    if cls.model_args.distillation_loss == "listnet":
+        cls.distillation_loss_fct = ListNet(cls.model_args.tau2, cls.model_args.gamma_)
+    elif cls.model_args.distillation_loss == "listmle":
+        cls.distillation_loss_fct = ListMLE(cls.model_args.tau2, cls.model_args.gamma_)
+    else:
+        raise NotImplementedError
     cls.init_weights()
 
 def cl_forward(cls,
@@ -280,8 +282,8 @@ def cl_forward(cls,
    loss = loss_fct(cos_sim, labels)
 
     # RankCSE - knowledge distillation loss
-    distillation_loss_fct = (ListNet(cls.model_args.tau2, cls.model_args.gamma_) if cls.model_args.distillation_loss == "listnet" else ListMLE(cls.model_args.tau2, cls.model_args.gamma_))
-    kd_loss = distillation_loss_fct(teacher_top1_sim_pred.to(cls.device), z1, z2)
+    student_top1_sim_pred = cos_sim.clone()
+    kd_loss = cls.distillation_loss_fct(teacher_top1_sim_pred.to(cls.device), student_top1_sim_pred)
 
     # RankCSE - self-distillation loss
     z1_z2_cos = cos_sim.clone()
@@ -470,4 +472,4 @@ def forward(self,
             mlm_input_ids=mlm_input_ids,
             mlm_labels=mlm_labels,
             teacher_top1_sim_pred=teacher_top1_sim_pred,
-        )
\ No newline at end of file
+        )
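
The refactor above moves the student-similarity computation out of the loss modules and into cl_forward, so each distillation loss is constructed once in cl_init and then consumes a precomputed similarity matrix. Below is a minimal, self-contained sketch of that interface for the ListNet case; it reuses the Similarity and ListNet names from rankcse/models.py, but the batch size, embedding dimension, temperature values, and the random teacher scores are illustrative placeholders, not the repository's actual configuration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Similarity(nn.Module):
    """Temperature-scaled cosine similarity, as in SimCSE/RankCSE."""
    def __init__(self, temp):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        return self.cos(x, y) / self.temp

class ListNet(nn.Module):
    """Top-1 ListNet distillation over precomputed similarity matrices."""
    def __init__(self, tau, gamma_):
        super().__init__()
        self.tau = tau  # kept for interface parity; unused once the
                        # student similarities arrive precomputed
        self.gamma_ = gamma_

    def forward(self, teacher_top1_sim_pred, student_top1_sim_pred):
        # Mask self-similarities on the diagonal, then take the cross-entropy
        # between teacher and student top-1 ranking distributions
        # (0 * -inf yields nan on the diagonal, which nansum skips).
        p = F.log_softmax(student_top1_sim_pred.fill_diagonal_(float('-inf')), dim=-1)
        q = F.softmax(teacher_top1_sim_pred.fill_diagonal_(float('-inf')), dim=-1)
        return self.gamma_ * (-(q * p).nansum() / q.nansum())

# Toy usage mirroring the new call site: z1/z2 stand in for the two views'
# embeddings; cos_sim is cloned before the loss call because fill_diagonal_
# mutates its argument in place, matching student_top1_sim_pred = cos_sim.clone()
# in the patched cl_forward.
z1, z2 = torch.randn(8, 768), torch.randn(8, 768)
cos_sim = Similarity(temp=0.05)(z1.unsqueeze(1), z2.unsqueeze(0))  # (8, 8)
teacher_top1_sim_pred = torch.randn(8, 8)  # placeholder teacher scores
kd_loss = ListNet(tau=0.05, gamma_=1.0)(teacher_top1_sim_pred, cos_sim.clone())

One design note: cloning cos_sim at the call site keeps the in-place diagonal masking inside the loss from corrupting the similarity matrix that the contrastive and self-distillation terms reuse afterwards.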