
Conversation


@kiritorl kiritorl commented Jan 20, 2026

Summary

This PR modifies the NPU test reference for KLDivLoss. The native NPU KLDivLoss operator does not support gradients w.r.t. the target (#1021), which caused failures in test_jsd.py (where input and target are swapped when beta != 0).

To resolve this, I replaced the native operator with a custom implementation built from basic math operations. This allows correct gradient computation for the target and brings the x1.grad results into agreement with the Triton kernel implementation.
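The idea can be sketched as follows. This is a minimal, hypothetical version (the name `kl_div_ref` and the exact signature are illustrative, not the actual patch), assuming the log-space convention of `torch.nn.functional.kl_div(..., log_target=True)`: composing the loss from `exp`, subtraction, and multiplication means autograd can differentiate through the target as well as the input, which the fused native operator cannot.

```python
import torch

def kl_div_ref(log_p: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor:
    """Pointwise KL(q || p) built from basic ops.

    Equivalent to F.kl_div(log_p, log_q, log_target=True, reduction="none"),
    but because it is composed of elementary operations, autograd produces
    gradients w.r.t. both log_p (input) and log_q (target).
    """
    return torch.exp(log_q) * (log_q - log_p)

# Both arguments receive gradients -- including the target:
log_p = torch.randn(4, 8, requires_grad=True)
log_q = torch.randn(4, 8, requires_grad=True)
kl_div_ref(log_p, log_q).sum().backward()
assert log_q.grad is not None  # gradient w.r.t. the target exists
```

With a reference of this shape, the swapped input/target case exercised by test_jsd.py when beta != 0 no longer hits the missing target-gradient path of the native operator.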

Testing Done

I tested test_jsd and test_fused_linear_jsd with the following commands, and all cases passed:

pytest -v test/transformers/test_jsd.py
pytest -v test/transformers/test_fused_linear_jsd.py

Hardware Type: Ascend NPU 910B3

  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@kiritorl kiritorl changed the title fix(npu): update the native KLDivLoss implementation for comparison. [NPU]: update the native KLDivLoss implementation for comparison. (eg.)test_jsd.py Jan 20, 2026

kiritorl commented Jan 20, 2026

Test results on NPU before the change; failure at:

test/transformers/test_jsd.py:160: in _test_correctness_once
assert_verbose_allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)

tensor1 = tensor([[-3.6322e-08, -8.1956e-08, -4.1211e-08,  ..., -1.8999e-07,
         -4.9593e-08, -4.3772e-08],
        [-6.379...-5.6345e-08, -6.3796e-08,  ..., -1.6182e-08,
         -8.1956e-08, -1.2293e-07]], device='npu:0', dtype=torch.bfloat16)
tensor2 = tensor([[-1.0186e-08,  2.9686e-08,  1.0885e-08,  ...,  1.1525e-08,
          1.6182e-08,  6.3155e-09],
        [ 1.397... 1.7229e-08,  3.8883e-08,  ..., -4.5402e-09,
          3.2363e-08,  9.2550e-09]], device='npu:0', dtype=torch.bfloat16)

Test results on NPU after:

tensor1: tensor([[-1.0186e-08,  2.9686e-08,  1.0885e-08,  ...,  1.1525e-08,
          1.6182e-08,  6.3155e-09],
        [ 1.3970e-08,  5.2620e-08,  8.2888e-08,  ..., -5.8790e-09,
         -6.5775e-09,  4.7497e-08],
        [-1.1059e-08, -1.8859e-08, -1.6298e-08,  ..., -8.2655e-09,
          5.5297e-09,  9.8720e-08],
        ...,
        [-1.0012e-08,  1.8068e-07,  0.0000e+00,  ..., -1.2689e-08,
          1.7229e-08, -2.4214e-08],
        [-7.1304e-09,  1.2515e-08,  4.7963e-08,  ..., -1.4808e-07,
          2.2468e-08,  3.3324e-09],
        [-4.1444e-08,  1.7229e-08,  3.8883e-08,  ..., -4.5402e-09,
          3.2363e-08,  9.2550e-09]], device='npu:0', dtype=torch.bfloat16)
tensor2: tensor([[-1.0186e-08,  2.9686e-08,  1.0885e-08,  ...,  1.1525e-08,
          1.6182e-08,  6.3155e-09],
        [ 1.3970e-08,  5.2620e-08,  8.2888e-08,  ..., -5.8790e-09,
         -6.5775e-09,  4.7497e-08],
        [-1.1059e-08, -1.8859e-08, -1.6298e-08,  ..., -8.2655e-09,
          5.5297e-09,  9.8720e-08],
        ...,
        [-1.0012e-08,  1.8068e-07,  0.0000e+00,  ..., -1.2689e-08,
          1.7229e-08, -2.4214e-08],
        [-7.1304e-09,  1.2515e-08,  4.7963e-08,  ..., -1.4808e-07,
          2.2468e-08,  3.3324e-09],
        [-4.1444e-08,  1.7229e-08,  3.8883e-08,  ..., -4.5402e-09,
          3.2363e-08,  9.2550e-09]], device='npu:0', dtype=torch.bfloat16)
PASSED

