diff --git a/src/trainer/unlearn/pdu.py b/src/trainer/unlearn/pdu.py index ee60c0ba..e79bcc58 100644 --- a/src/trainer/unlearn/pdu.py +++ b/src/trainer/unlearn/pdu.py @@ -117,7 +117,9 @@ def compute_loss(self, model, inputs, return_outputs=False): maxLogits = logits.max(dim=-1)[0] averageLogits = logits.mean(dim=-1) - forget_loss = ((maxLogits - averageLogits) ** 2).mean() + forget_loss = (maxLogits - averageLogits) ** 2 + mask = (forget_inputs["labels"] != -100).reshape(-1) + forget_loss = (forget_loss * mask).sum() / mask.sum() retain_inputs = inputs["retain"] retain_inputs = {