
Conversation

@noemotiovon

  • Auto-detect device and select NPU/GPU implementation
  • Fix NaN from uninitialized memory in A/dA matrices
  • Replace problematic triton kernel with PyTorch (NPU compiler limitation)
  • Add exp overflow protection and varlen sequence support

@coderabbitai
Contributor

coderabbitai bot commented Feb 5, 2026

Important: Review skipped (draft detected).

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command. You can disable this status message by setting reviews.review_status to false in the CodeRabbit configuration file.


@gemini-code-assist
Contributor

Summary of Changes

Hello @noemotiovon, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the chunk_gla operator by integrating a new, optimized backend specifically for Huawei Ascend NPUs. The changes allow the system to automatically detect and utilize NPU hardware for improved performance. Crucially, it resolves several numerical stability issues that could lead to NaNs and expands the operator's capabilities to handle variable-length inputs, ensuring broader applicability and reliability across different computing environments.

Highlights

  • NPU Backend Integration: Introduced a dedicated NPU (Ascend) backend for the chunk_gla operator, enabling device auto-detection to dynamically select between NPU and GPU implementations based on the available hardware.
  • Numerical Stability Fixes: Addressed critical numerical stability issues by fixing NaN propagation originating from uninitialized memory in the A and dA attention matrices during both forward and backward passes, ensuring more robust computations (see the sketch after this list).
  • Triton Kernel Replacement: Replaced a problematic Triton kernel in the backward pass (dqk_intra) with a pure PyTorch implementation due to identified limitations and incorrect results produced by the Triton compiler on NPU for cross-sub-block computations.
  • Enhanced Robustness and Features: Implemented exponential overflow protection for gate calculations and added comprehensive support for variable-length sequences using cu_seqlens, making the operator more versatile and resilient.
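
A minimal sketch of the masking and overflow guard described above (the function name, shapes, and the single-gate-column simplification are assumptions for illustration, not the PR's actual code):

```python
import torch

def stabilize_intra_chunk(A: torch.Tensor, g_q: torch.Tensor, g_k: torch.Tensor) -> torch.Tensor:
    """Illustrative numerical-stability guards for one chunk.

    A:        [..., BT, BT] intra-chunk scores
    g_q, g_k: [..., BT]     cumulative log-gates at query/key positions
                            (per-feature gates would be handled analogously)
    """
    BT = A.shape[-1]
    # 1) Zero the strict upper triangle so entries above the causal diagonal,
    #    which a kernel may leave uninitialized, cannot propagate NaNs.
    causal = torch.tril(torch.ones(BT, BT, dtype=torch.bool, device=A.device))
    A = torch.where(causal, A, torch.zeros_like(A))
    # 2) Guard exp() against overflow: with decaying gates the cumulative log-gate
    #    difference g_q[i] - g_k[j] should be <= 0 in the causal region, so clamping
    #    the exponent at 0 keeps stray positive values from producing inf.
    exponent = (g_q.unsqueeze(-1) - g_k.unsqueeze(-2)).clamp(max=0.0)
    return A * torch.exp(exponent)
```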


Changelog
  • fla/ops/gla/__init__.py
    • Added device detection logic using fla.utils.get_available_device() to dynamically import chunk_gla_npu for NPU devices or the default chunk_gla for other devices (a minimal sketch of this pattern follows the changelog).
    • Updated imports to reflect the new device-aware selection mechanism.
  • fla/ops/gla/chunk_npu.py
    • Added a new file containing the NPU-optimized implementation of chunk_gla.
    • Introduced chunk_gla_fwd_kernel_o_npu with NPU-specific Triton configurations (e.g., adjusted block sizes and warp counts) and a fix using tl.where to prevent NaNs from uninitialized memory in the A matrix.
    • Implemented chunk_gla_fwd_npu for the forward pass, which includes explicit zeroing of the upper triangle of the A matrix using a causal mask to prevent uninitialized memory issues on NPU, and ensures fp32 state handling.
    • Developed chunk_gla_bwd_dqk_intra_npu, a pure PyTorch implementation for the intra-chunk backward pass of dq and dk, replacing a problematic Triton kernel that exhibited issues on NPU. This implementation includes exp overflow protection and supports variable-length sequences.
    • Created chunk_gla_bwd_npu for the overall backward pass, incorporating NPU-specific fixes for the dA matrix (explicit zeroing of the upper triangle) and utilizing the new PyTorch-based dqk_intra function.
    • Wrapped the NPU forward and backward passes within ChunkGLAFunctionNPU, a torch.autograd.Function, which dynamically determines chunk_size and saves fp32 states for numerical precision.
    • Exposed chunk_gla_npu as the main public interface for the NPU backend, including input validation for cu_seqlens and initial_state.
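
As a side note for reviewers, the device-aware selection described in the __init__.py entry could be realized roughly as follows. This is a sketch only: it assumes get_available_device() reports "npu" on Ascend, that the default implementation lives in fla/ops/gla/chunk.py, and that the NPU variant is re-exported under the common name; none of these details are confirmed by the PR.

```python
# Sketch of a device-aware fla/ops/gla/__init__.py (illustrative, not the PR's exact code).
from fla.utils import get_available_device

if get_available_device() == "npu":
    # Huawei Ascend: NPU-specific implementation with its Triton/PyTorch fixes.
    from fla.ops.gla.chunk_npu import chunk_gla_npu as chunk_gla
else:
    # Default GPU/Triton path.
    from fla.ops.gla.chunk import chunk_gla

__all__ = ["chunk_gla"]
```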

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces NPU (Ascend) backend support for the chunk_gla operation. The changes include device-aware dynamic imports, NPU-specific Triton kernels, and several fixes for numerical stability and correctness on NPU hardware. Notably, a problematic Triton kernel in the backward pass is replaced with a pure PyTorch implementation to work around NPU compiler limitations, and explicit masking is added to prevent NaNs from uninitialized memory.

My review focuses on the new NPU implementation. I've suggested a performance improvement by vectorizing a loop that handles variable-length sequences. I've also pointed out a couple of type hint mismatches that should be corrected for code clarity and correctness. Overall, this is a well-structured and thoughtful implementation for adding new hardware support.

    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

Severity: medium

The function chunk_gla_fwd_npu returns five tensors (g_cumsum, A, h, ht, o), but its return type hint indicates a tuple of three tensors. This should be corrected to match the actual return values for type consistency and clarity.

Suggested change:
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

Comment on lines +221 to +229
        local_pos = torch.zeros(T_dim, dtype=torch.long, device=A.device)
        for seq_idx in range(num_seqs):
            seq_start = cu_seqlens_cpu[seq_idx].item()
            seq_end = cu_seqlens_cpu[seq_idx + 1].item()
            seq_len = seq_end - seq_start
            if seq_len > 0:
                # Local positions within this sequence, mod BT for chunk-local position
                seq_positions = torch.arange(seq_len, device=A.device) % BT
                local_pos[seq_start:seq_end] = seq_positions

Severity: medium

The for-loop to compute local_pos for variable-length sequences can be a performance bottleneck when dealing with a large number of sequences. This can be vectorized using torch.repeat_interleave for better efficiency.

        # Vectorized implementation for performance
        lens = cu_seqlens_cpu[1:] - cu_seqlens_cpu[:-1]
        # Create a tensor where each element is the start of its sequence
        seq_starts = torch.repeat_interleave(cu_seqlens_cpu[:-1], lens).to(A.device)
        # Create a tensor of global positions
        global_pos = torch.arange(T_dim, device=A.device)
        # Compute local positions
        local_pos = (global_pos - seq_starts) % BT
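
For anyone who wants a quick sanity check that the loop and the vectorized form agree, a standalone snippet with made-up sizes (not part of the PR) could be:

```python
import torch

# Hypothetical varlen layout: three sequences packed into T_dim = 10, chunk size BT = 4.
cu_seqlens_cpu = torch.tensor([0, 3, 7, 10])
T_dim, BT = 10, 4
num_seqs = len(cu_seqlens_cpu) - 1

# Loop version (as in the reviewed code).
loop_pos = torch.zeros(T_dim, dtype=torch.long)
for seq_idx in range(num_seqs):
    s, e = cu_seqlens_cpu[seq_idx].item(), cu_seqlens_cpu[seq_idx + 1].item()
    if e > s:
        loop_pos[s:e] = torch.arange(e - s) % BT

# Vectorized version (as suggested above).
lens = cu_seqlens_cpu[1:] - cu_seqlens_cpu[:-1]
seq_starts = torch.repeat_interleave(cu_seqlens_cpu[:-1], lens)
vec_pos = (torch.arange(T_dim) - seq_starts) % BT

assert torch.equal(loop_pos, vec_pos)  # both give [0, 1, 2, 0, 1, 2, 3, 0, 1, 2]
```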

Comment on lines +483 to +490
        local_pos = torch.zeros(T_dim, dtype=torch.long, device=dA.device)
        for seq_idx in range(num_seqs):
            seq_start = cu_seqlens_cpu[seq_idx].item()
            seq_end = cu_seqlens_cpu[seq_idx + 1].item()
            seq_len = seq_end - seq_start
            if seq_len > 0:
                seq_positions = torch.arange(seq_len, device=dA.device) % BT
                local_pos[seq_start:seq_end] = seq_positions

Severity: medium

The for-loop to compute local_pos for variable-length sequences can be a performance bottleneck when dealing with a large number of sequences. This can be vectorized using torch.repeat_interleave for better efficiency. This is similar to a suggestion in chunk_gla_fwd_npu.

        # Vectorized implementation for performance
        lens = cu_seqlens_cpu[1:] - cu_seqlens_cpu[:-1]
        # Create a tensor where each element is the start of its sequence
        seq_starts = torch.repeat_interleave(cu_seqlens_cpu[:-1], lens).to(dA.device)
        # Create a tensor of global positions
        global_pos = torch.arange(T_dim, device=dA.device)
        # Compute local positions
        local_pos = (global_pos - seq_starts) % BT

    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: int | None = None,

Severity: medium

The type hint for the scale parameter is int | None, but it is used as a float and the docstring also describes it as Optional[float]. The type hint should be changed to float | None to match its actual usage and documentation.

Suggested change:
-    scale: int | None = None,
+    scale: float | None = None,

@zhiyuan1i
Collaborator

Thank you for your contribution. We'll be open-sourcing a different backend dispatch soon. At that time, it will be better adapted to different hardware.

@noemotiovon
Author

> Thank you for your contribution. We'll be open-sourcing a different backend dispatch soon. At that time, it will be better adapted to different hardware.

That’s fantastic news! Since NPUs and GPUs differ significantly at the architectural level, Triton kernel implementations naturally diverge and really benefit from a proper dispatch mechanism.

Could you share a bit more about the open-source plan on this front? Also, if there’s anything I can help with, I’d be very happy to contribute and get involved.

@zhiyuan1i
Collaborator

> Thank you for your contribution. We'll be open-sourcing a different backend dispatch soon. At that time, it will be better adapted to different hardware.
>
> That’s fantastic news! Since NPUs and GPUs differ significantly at the architectural level, Triton kernel implementations naturally diverge and really benefit from a proper dispatch mechanism.
>
> Could you share a bit more about the open-source plan on this front? Also, if there’s anything I can help with, I’d be very happy to contribute and get involved.

We have introduced a new decorator that designates functions as dispatchable. Each backend targeted for dispatch includes a verifier to check whether the current input is compatible with that backend. This architecture allows substitution at arbitrary granularity (e.g., single operations or complete forward/backward passes) and enables mixing and matching different backends. Moving forward, we will also support registering multiple dispatch backends for the same function, enabling dispatch decisions based on varying performance characteristics and constraints. The core functionality is already implemented and will be open-sourced following code cleanup.

Our primary concern is the lack of continuous integration for additional backends; currently, we are limited to a single H100. We would be extremely grateful if we could obtain on-demand NPU backends for CI gating purposes.
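
To make the described dispatch architecture a bit more concrete, here is a minimal sketch of what a dispatchable-function decorator with per-backend verifiers might look like. It is purely illustrative and not the upcoming implementation; all names and signatures are assumptions.

```python
from typing import Any, Callable

# Registry: function name -> list of (verifier, implementation) pairs.
_BACKENDS: dict[str, list[tuple[Callable[..., bool], Callable[..., Any]]]] = {}

def dispatchable(fn: Callable) -> Callable:
    """Mark fn as dispatchable; registered backends are tried in order."""
    name = fn.__name__
    _BACKENDS.setdefault(name, [])

    def wrapper(*args, **kwargs):
        for verifier, impl in _BACKENDS[name]:
            # A verifier checks whether the current inputs are compatible with
            # that backend (device type, dtype, head dim, varlen layout, ...).
            if verifier(*args, **kwargs):
                return impl(*args, **kwargs)
        return fn(*args, **kwargs)  # fall back to the reference implementation

    wrapper.__name__ = name
    return wrapper

def register_backend(name: str, verifier: Callable[..., bool]) -> Callable:
    """Register an alternative implementation for a dispatchable function."""
    def decorator(impl: Callable) -> Callable:
        _BACKENDS.setdefault(name, []).append((verifier, impl))
        return impl
    return decorator

# Usage sketch:
@dispatchable
def chunk_gla(q, k, v, g, **kwargs):
    ...  # reference (GPU/Triton) path

@register_backend("chunk_gla", verifier=lambda q, *a, **kw: q.device.type == "npu")
def chunk_gla_npu_impl(q, k, v, g, **kwargs):
    ...  # NPU-specific path
```

Under this kind of scheme, substitution granularity depends only on which functions carry the decorator, e.g., a complete forward/backward pass or a single inner operation.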

@noemotiovon
Author

That sounds really exciting!
Our next step will be to start from the existing GPU Triton kernels and port them to NPU Triton kernels. Once the functionality runs correctly on NPU, we’ll be able to provide access to NPU machines and help integrate them into your upstream CI.
The Liger-Kernel community is currently doing something very similar, and we are already in the process of CI integration on that side as well:
linkedin/Liger-Kernel#1022

@zhiyuan1i
Collaborator

> That sounds really exciting! Our next step will be to start from the existing GPU Triton kernels and port them to NPU Triton kernels. Once the functionality runs correctly on NPU, we’ll be able to provide access to NPU machines and help integrate them into your upstream CI. The Liger-Kernel community is currently doing something very similar, and we are already in the process of CI integration on that side as well: linkedin/Liger-Kernel#1022

Could we connect offline?
lizhiyuan@moonshot.cn
