From 756fcc3992a2b4d823dc9f5e0d2b54ccf4e8e107 Mon Sep 17 00:00:00 2001
From: kiritorl <1021709528@qq.com>
Date: Tue, 20 Jan 2026 10:29:45 +0800
Subject: [PATCH 1/3] fix(npu): update the native KLDivLoss implementation for
 comparison.

---
 test/transformers/test_jsd.py | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py
index db23e5f22..ad80a0057 100644
--- a/test/transformers/test_jsd.py
+++ b/test/transformers/test_jsd.py
@@ -18,6 +18,22 @@
 set_seed(42)
 
 
+class CustomKLDivLoss(torch.nn.Module):
+    def __init__(self, reduction="none", log_target=True):
+        super().__init__()
+
+    def forward(self, input, target):
+        original_dtype = input.dtype
+
+        if input.dtype in [torch.float16, torch.bfloat16]:
+            input = input.float()
+            target = target.float()
+
+        loss = torch.exp(target) * (target - input)
+
+        return loss.to(original_dtype)
+
+
 class JSD(torch.nn.Module):
     def __init__(
         self,
@@ -26,7 +42,10 @@ def __init__(
         dtype: torch.dtype = torch.float,
     ):
         super(JSD, self).__init__()
-        self.kl = KLDivLoss(reduction="none", log_target=True)
+        if device == "npu":
+            self.kl = CustomKLDivLoss(reduction="none", log_target=True)
+        else:
+            self.kl = KLDivLoss(reduction="none", log_target=True)
         self.beta = beta
         self.ignore_index = ignore_index
         self.dtype = dtype

From 8bafa447b7d91594949e033e7f5cb99d1ac15ef8 Mon Sep 17 00:00:00 2001
From: kiritorl <1021709528@qq.com>
Date: Tue, 20 Jan 2026 15:25:21 +0800
Subject: [PATCH 2/3] Refactor(npu): rename helper class and add comments for
 clarity

---
 test/transformers/test_jsd.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py
index ad80a0057..8f29f560e 100644
--- a/test/transformers/test_jsd.py
+++ b/test/transformers/test_jsd.py
@@ -18,7 +18,15 @@
 set_seed(42)
 
 
-class CustomKLDivLoss(torch.nn.Module):
+class NPUKLDivLoss(torch.nn.Module):
+    """
+    A custom KLDivLoss for NPU.
+
+    On NPU devices, torch.nn.KLDivLoss does not compute gradients with respect to the target.
+    This leads to incorrect gradient computation when the target depends on the input,
+    such as in JSD or reverse KLDiv.
+    See https://github.com/linkedin/Liger-Kernel/issues/1021 for more details.
+    """
     def __init__(self, reduction="none", log_target=True):
         super().__init__()
 
@@ -43,7 +51,7 @@ def __init__(
     ):
         super(JSD, self).__init__()
         if device == "npu":
-            self.kl = CustomKLDivLoss(reduction="none", log_target=True)
+            self.kl = NPUKLDivLoss(reduction="none", log_target=True)
         else:
             self.kl = KLDivLoss(reduction="none", log_target=True)
         self.beta = beta

From f5dad7e78c6c9ec550dcda6dcf713d144853c421 Mon Sep 17 00:00:00 2001
From: kiritorl <1021709528@qq.com>
Date: Tue, 20 Jan 2026 16:09:08 +0800
Subject: [PATCH 3/3] run pre-commit checks and format code

---
 test/transformers/test_jsd.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py
index 8f29f560e..873586409 100644
--- a/test/transformers/test_jsd.py
+++ b/test/transformers/test_jsd.py
@@ -21,24 +21,25 @@
 class NPUKLDivLoss(torch.nn.Module):
     """
     A custom KLDivLoss for NPU.
-    
+
     On NPU devices, torch.nn.KLDivLoss does not compute gradients with respect to the target.
     This leads to incorrect gradient computation when the target depends on the input,
     such as in JSD or reverse KLDiv.
     See https://github.com/linkedin/Liger-Kernel/issues/1021 for more details.
     """
+
     def __init__(self, reduction="none", log_target=True):
         super().__init__()
-    
+
     def forward(self, input, target):
         original_dtype = input.dtype
-    
+
         if input.dtype in [torch.float16, torch.bfloat16]:
             input = input.float()
             target = target.float()
-    
+
         loss = torch.exp(target) * (target - input)
-    
+
         return loss.to(original_dtype)
 
 