diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py
index db23e5f22..873586409 100644
--- a/test/transformers/test_jsd.py
+++ b/test/transformers/test_jsd.py
@@ -18,6 +18,31 @@
 set_seed(42)
 
 
+class NPUKLDivLoss(torch.nn.Module):
+    """
+    A custom KLDivLoss for NPU.
+
+    On NPU devices, torch.nn.KLDivLoss does not compute gradients with respect to the target.
+    This leads to incorrect gradient computation when the target depends on the input,
+    such as in JSD or reverse KLDiv.
+    See https://github.com/linkedin/Liger-Kernel/issues/1021 for more details.
+    """
+
+    def __init__(self, reduction="none", log_target=True):
+        super().__init__()
+
+    def forward(self, input, target):
+        original_dtype = input.dtype
+
+        if input.dtype in [torch.float16, torch.bfloat16]:
+            input = input.float()
+            target = target.float()
+
+        loss = torch.exp(target) * (target - input)
+
+        return loss.to(original_dtype)
+
+
 class JSD(torch.nn.Module):
     def __init__(
         self,
@@ -26,7 +51,10 @@ def __init__(
         dtype: torch.dtype = torch.float,
     ):
         super(JSD, self).__init__()
-        self.kl = KLDivLoss(reduction="none", log_target=True)
+        if device == "npu":
+            self.kl = NPUKLDivLoss(reduction="none", log_target=True)
+        else:
+            self.kl = KLDivLoss(reduction="none", log_target=True)
         self.beta = beta
         self.ignore_index = ignore_index
         self.dtype = dtype
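
For reference, a minimal sketch (outside the patch) of the equivalence NPUKLDivLoss relies on: with log_target=True and reduction="none", torch.nn.KLDivLoss computes exp(target) * (target - input) pointwise, so the manual formula matches the stock forward pass while keeping the target on the autograd graph. The shapes, seed, and tolerance below are illustrative only.

import torch

torch.manual_seed(0)

# Leaf tensors so gradients can be checked on both sides of the KL.
logits_p = torch.randn(4, 8, requires_grad=True)
logits_q = torch.randn(4, 8, requires_grad=True)

log_p = logits_p.log_softmax(dim=-1)  # "input": log-probabilities
log_q = logits_q.log_softmax(dim=-1)  # "target": log-probabilities

# Stock PyTorch KLDivLoss with log-space targets.
ref = torch.nn.KLDivLoss(reduction="none", log_target=True)(log_p, log_q)

# Manual formula used by NPUKLDivLoss.forward.
manual = torch.exp(log_q) * (log_q - log_p)

assert torch.allclose(ref, manual, atol=1e-6)

# The manual formula keeps the target in the autograd graph, which is what
# the NPU KLDivLoss kernel reportedly drops (see the linked issue).
manual.sum().backward()
assert logits_p.grad is not None and logits_q.grad is not None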