@tianleiwu

Code review shows that the interleaved NEON kernel indexes the sin/cos tables incorrectly:

  1. Test Code Logic:
    The test code test_rope.h allocates the sin and cos tables based on the interleaved flag:

    size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim;
    std::vector<float> sin_data(table_len);
    std::vector<float> cos_data(table_len);

    For the interleaved = true case, the test creates sin and cos tables of length rotary_emb_dim / 2.

  2. AVX2 (fp32) Kernel Logic (interleaved = true):
    This kernel loads the sin/cos data using an index of i / 2:

    float32x8_t sin_val = _mm256_loadu_ps(sin_data + i / 2);
    float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2);

    This logic expects a sin/cos table of length rotary_emb_dim / 2.
    Conclusion: The AVX2 (fp32) kernel is consistent with the test code.

  3. NEON (fp16) Kernel Logic (interleaved = true):
    This kernel loads the sin/cos data using an index of i:

    // Enters loop with sin_val = MlasLoadFloat16x8(sin + i);
    //...
    // Inside loop, for next iteration:
    sin_val = MlasLoadFloat16x8(sin + i + 16); 

    This logic expects a sin/cos table of length rotary_emb_dim.
    Conclusion: The NEON (fp16) kernel is NOT consistent with the test code.
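
The difference between the two indexing schemes can be sketched with a small model. This is an illustration only, not the MLAS source: it assumes a vector loop that consumes 16 interleaved inputs (8 pairs) per iteration and loads 8 consecutive table entries, as the loads quoted above suggest, and it ignores any tail handling. The name TableExtent is hypothetical.

```cpp
#include <cstddef>

// Hypothetical model of a vector loop over interleaved inputs.
// Each iteration handles 16 inputs (8 pairs) and loads 8 consecutive table
// entries. Returns one past the highest table index read.
// Assumption: dim is a nonzero multiple of 16 (no tail path is modelled).
inline std::size_t TableExtent(std::size_t dim, bool half_index) {
    std::size_t extent = 0;
    for (std::size_t i = 0; i + 16 <= dim; i += 16) {
        // Table offset of the load: i / 2 (AVX2 style) or i (NEON style).
        const std::size_t first = half_index ? i / 2 : i;
        const std::size_t end = first + 8;  // 8 lanes per load
        if (end > extent) {
            extent = end;
        }
    }
    return extent;
}
```

Under this model, rotary_emb_dim = 64 with tables of 32 entries gives an extent of 32 for the i / 2 scheme and 56 for the i scheme, so the second scheme would read 24 entries past the end of each table.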

Summary

The RopeKernel_Avx2_fp32_Impl<true> kernel correctly aligns with the test code (and the fallback implementation) by expecting a sin/cos table of length rotary_emb_dim / 2.

The RopeKernel_Fp16_Impl<true> (NEON) kernel incorrectly expects a table of length rotary_emb_dim. When run against the provided test, the NEON kernel will read past the end of the sin_data and cos_data vectors.
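
For reference, a minimal scalar sketch of interleaved RoPE that indexes the tables with i / 2 is shown below. It is an assumption-laden illustration rather than the fallback implementation itself: the function name RopeInterleavedRef is hypothetical, the rotation convention is the standard (x, y) pair rotation, and dim is assumed to be even.

```cpp
#include <cstddef>
#include <vector>

// Scalar reference for interleaved RoPE (illustrative, not the MLAS source).
// input holds dim values laid out as (x0, y0, x1, y1, ...).
// sin_table and cos_table hold dim / 2 entries, one per (x, y) pair.
inline std::vector<float> RopeInterleavedRef(const std::vector<float>& input,
                                             const std::vector<float>& sin_table,
                                             const std::vector<float>& cos_table) {
    const std::size_t dim = input.size();
    std::vector<float> output(dim);
    for (std::size_t i = 0; i + 1 < dim; i += 2) {
        // .at() throws std::out_of_range if the table is shorter than dim / 2.
        const float s = sin_table.at(i / 2);
        const float c = cos_table.at(i / 2);
        output[i] = input[i] * c - input[i + 1] * s;
        output[i + 1] = input[i] * s + input[i + 1] * c;
    }
    return output;
}
```

Because each pair shares one sin and one cos entry, tables of length rotary_emb_dim / 2 are sufficient here. Indexing the same tables with i instead of i / 2 would make .at() throw for every pair in the upper half of the input.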

@tianleiwu changed the title from "[MLAD] Fix rotary interleaved NEON kernel" to "[MLAS] Fix rotary interleaved NEON kernel" on Oct 22, 2025
@tianleiwu requested a review from Copilot on October 22, 2025 23:16

Copilot AI left a comment

Pull Request Overview

This PR fixes a critical bug in the NEON fp16 rotary embedding kernel for the interleaved mode. The kernel was incorrectly indexing the sin/cos lookup tables with the full input dimension index instead of half the dimension, causing out-of-bounds reads. The fix aligns the NEON implementation with the AVX2 kernel and the fallback implementation by dividing indices by 2 when accessing the sin/cos tables in interleaved mode.

Key changes:

  • Corrected sin/cos table indexing in the interleaved NEON fp16 kernel from i to i/2
  • Added comprehensive unit tests for NEON fp16 RoPE operations covering various dimensions and both interleaved/non-interleaved modes
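
The testing approach described above can be sketched in scalar C++. This is not the test file added in the PR: RopePairwise, RopeBlocked, and MaxDiffForDim are hypothetical names, float stands in for fp16, and RopeBlocked only mimics the shape of a vector loop (16 inputs per step, table block starting at i / 2) so that it can be compared against a pair-by-pair reference with tables of exactly dim / 2 entries.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Pair-by-pair formula used as the expected result (illustrative reference).
inline void RopePairwise(const float* in, const float* sin_t, const float* cos_t,
                         std::size_t dim, float* out) {
    for (std::size_t i = 0; i + 1 < dim; i += 2) {
        out[i] = in[i] * cos_t[i / 2] - in[i + 1] * sin_t[i / 2];
        out[i + 1] = in[i] * sin_t[i / 2] + in[i + 1] * cos_t[i / 2];
    }
}

// Blocked variant shaped like a vector loop: 16 inputs (8 pairs) per step,
// with the table block starting at i / 2. Hypothetical stand-in for a kernel.
inline void RopeBlocked(const float* in, const float* sin_t, const float* cos_t,
                        std::size_t dim, float* out) {
    std::size_t i = 0;
    for (; i + 16 <= dim; i += 16) {
        const float* s = sin_t + i / 2;
        const float* c = cos_t + i / 2;
        for (std::size_t lane = 0; lane < 8; ++lane) {
            const float x = in[i + 2 * lane];
            const float y = in[i + 2 * lane + 1];
            out[i + 2 * lane] = x * c[lane] - y * s[lane];
            out[i + 2 * lane + 1] = x * s[lane] + y * c[lane];
        }
    }
    for (; i + 1 < dim; i += 2) {  // scalar tail for the remaining pairs
        out[i] = in[i] * cos_t[i / 2] - in[i + 1] * sin_t[i / 2];
        out[i + 1] = in[i] * sin_t[i / 2] + in[i + 1] * cos_t[i / 2];
    }
}

// Largest absolute difference between the two variants for one even dim.
inline float MaxDiffForDim(std::size_t dim) {
    std::vector<float> in(dim), sin_t(dim / 2), cos_t(dim / 2);
    for (std::size_t k = 0; k < dim; ++k) {
        in[k] = 0.01f * static_cast<float>(k) - 0.5f;
    }
    for (std::size_t k = 0; k < dim / 2; ++k) {
        sin_t[k] = std::sin(0.1f * static_cast<float>(k));
        cos_t[k] = std::cos(0.1f * static_cast<float>(k));
    }
    std::vector<float> expected(dim), actual(dim);
    RopePairwise(in.data(), sin_t.data(), cos_t.data(), dim, expected.data());
    RopeBlocked(in.data(), sin_t.data(), cos_t.data(), dim, actual.data());
    float max_diff = 0.0f;
    for (std::size_t k = 0; k < dim; ++k) {
        max_diff = std::max(max_diff, std::fabs(expected[k] - actual[k]));
    }
    return max_diff;
}
```

Sweeping MaxDiffForDim over several dimensions, including ones that are not a multiple of 16, exercises both the blocked path and the tail. Sizing the tables to exactly dim / 2 is the design choice that matters: running such a sweep under a memory checker is what would expose a kernel that indexes the tables with i.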

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp | New test file validating NEON fp16 RoPE kernel against fallback implementation |
| onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp | Fixed all sin/cos table access points in interleaved mode to use i/2 indexing |


Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@tianleiwu marked this pull request as draft on October 23, 2025 16:53