Conversation

@zheliuyu (Contributor) commented on Nov 8, 2025

What does this PR do?

We have validated the accuracy and performance of kernels-community/liger_kernels, loaded through the kernels library, on NPU. The test results are presented below.
@MekkCyber @drbh please help review, thanks!

Test script

import torch
import torch.nn as nn
from typing import Union
import logging
import time

import torch_npu  # Ascend plugin; registers the "npu" device with PyTorch

from kernels import (
    Device,
    LayerRepository,
    Mode,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
    kernelize,
)

_kernels_available = True

# Target device for all tensors and modules in this script.
DEVICE = "npu"

# Setting the level to DEBUG will show which kernels are being used.
# logging.basicConfig(level=logging.DEBUG)

_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
    "npu": {
        Mode.INFERENCE: LayerRepository(
            repo_id="kernels-community/liger_kernels",
            layer_name="LigerRMSNorm",
        )
    }
}

register_kernel_mapping(_KERNEL_MAPPING)


# PyTorch reference implementation of RMSNorm.
class Qwen3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# Compiled version on NPU: compile an instance so its forward pass runs through
# the aot_eager backend (compiling the bare class would not compile forward()).
def compiled_torch_rmsnorm(hidden_size, eps):
    return torch.compile(Qwen3RMSNorm(hidden_size, eps), backend="aot_eager")


# Triton RMSNorm kernel from the Hub; kernelize() swaps in its forward() below.
@use_kernel_forward_from_hub("RMSNorm")
class TritonRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("This method will be replaced by the kernel from the hub.")


def test_rmsnorm(s1, s2, s3, hidden_size=1024, eps=1e-5):
    torch.manual_seed(42)
    x = torch.randn(s1, s2, s3, hidden_size, dtype=torch.float32, device=DEVICE)
    weight = torch.rand(hidden_size, dtype=torch.float32, device=DEVICE)

    # PyTorch reference implementation
    torch_rmsnorm = Qwen3RMSNorm(hidden_size, eps).to(DEVICE)
    start_time = time.time()
    torch_res = torch_rmsnorm(x)
    print(f"torch_rmsnorm time: {time.time() - start_time:.4f} seconds")

    # Compiled version
    compiled_rmsnorm = compiled_torch_rmsnorm(hidden_size, eps).to(DEVICE)
    start_time = time.time()
    compiled_res = compiled_rmsnorm(x)
    print(f"compiled_rmsnorm time: {time.time() - start_time:.4f} seconds")
    
    # Triton RMSNorm kernel
    triton_rmsnorm = TritonRMSNorm(hidden_size, eps).to(DEVICE)
    kernelize(triton_rmsnorm, device=DEVICE, mode=Mode.INFERENCE)
    start_time = time.time()
    triton_res = triton_rmsnorm(x)
    print(f"triton_rmsnorm time: {time.time() - start_time:.4f} seconds")

    assert torch.allclose(compiled_res, torch_res, atol=1e-2, rtol=0.0)
    assert torch.allclose(triton_res, torch_res, atol=1e-2, rtol=0.0)
    print(f"-----------------------shape [{s1}, {s2}, {s3}, {hidden_size}] RMSNorm test passed!-----------------------")


if __name__ == "__main__":
    test_rmsnorm(1, 1, 1, 1)
    print("Warmup end.")
    
    test_rmsnorm(1, 1, 1024)
    test_rmsnorm(1, 1, 8, 128)
    test_rmsnorm(1, 1, 16, 128)
    test_rmsnorm(1, 1, 128, 1024)

Output:

Warmup end.

torch_rmsnorm time: 0.0010 seconds
compiled_rmsnorm time: 0.0009 seconds
triton_rmsnorm time: 0.0004 seconds
-----------------------shape [1, 1, 1024, 1024] RMSNorm test passed!-----------------------
torch_rmsnorm time: 0.0010 seconds
compiled_rmsnorm time: 0.0009 seconds
triton_rmsnorm time: 0.0004 seconds
-----------------------shape [1, 1, 8, 128] RMSNorm test passed!-----------------------
torch_rmsnorm time: 0.0011 seconds
compiled_rmsnorm time: 0.0009 seconds
triton_rmsnorm time: 0.0004 seconds
-----------------------shape [1, 1, 16, 128] RMSNorm test passed!-----------------------
torch_rmsnorm time: 0.0013 seconds
compiled_rmsnorm time: 0.0011 seconds
triton_rmsnorm time: 0.0004 seconds
-----------------------shape [1, 1, 128, 1024] RMSNorm test passed!-----------------------
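
A note on the timing methodology: NPU kernels, like CUDA kernels, can launch asynchronously, so a bare time.time() around a single call may mostly capture launch overhead. Below is a minimal sketch of a synchronized, averaged timer; it assumes torch_npu exposes torch.npu.synchronize(), analogous to torch.cuda.synchronize().

import time
import torch
import torch_npu  # registers the "npu" device and torch.npu utilities

def timed_npu(fn, *args, iters: int = 100):
    # Warm up once so one-time compilation and caching are excluded.
    fn(*args)
    torch.npu.synchronize()  # assumed torch_npu API, mirrors torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn(*args)
    torch.npu.synchronize()
    return (time.time() - start) / iters

For example, timed_npu(triton_rmsnorm, x) reports the average per-call latency after a device sync.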

@zheliuyu marked this pull request as ready for review, November 10, 2025 11:16
@MekkCyber (Contributor) left a comment

Thanks!

Comment on lines +74 to +79:

    "npu": {
        Mode.INFERENCE: LayerRepository(
            repo_id="kernels-community/liger_kernels",
            layer_name="LigerRMSNorm",
        )
    },
Contributor:

Does Triton work with NPU?

Contributor (Author):

Yes, Triton can run on NPU. It only requires pip install triton-ascend; none of the original Triton code needs to be modified.
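
For illustration, this is the kind of stock Triton kernel that should run unchanged once triton-ascend is installed (a minimal sketch, not part of this PR; it assumes the standard triton / triton.language API and a tensor already on the "npu" device):

import torch
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # One program instance handles one BLOCK_SIZE chunk of the flattened tensor.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + 1.0, mask=mask)


def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    add_one_kernel[grid](x, out, n, BLOCK_SIZE=1024)
    return out

With triton-ascend installed, add_one(torch.ones(4096, device="npu")) should dispatch to the NPU without any edits to the kernel source.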

Contributor:

Sounds good! Thanks for the benchmarks

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
