-
Notifications
You must be signed in to change notification settings - Fork 3.5k
[MLAS] Fix rotary interleaved NEON kernel #26390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
tianleiwu
wants to merge
4
commits into
main
Choose a base branch
from
tlwu/fix_rope_neon
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+119
−18
Draft
Changes from 2 commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| /*++ | ||
|
|
||
| Copyright (c) Microsoft Corporation. All rights reserved. | ||
|
|
||
| Licensed under the MIT License. | ||
|
|
||
| Module Name: | ||
|
|
||
| test_rope_neon_fp16.h | ||
|
|
||
| Abstract: | ||
|
|
||
| Tests for MLAS fp16 RoPE on NEON. | ||
|
|
||
| --*/ | ||
|
|
||
| #include <vector> | ||
| #include <cmath> | ||
|
|
||
| #include "test_util.h" | ||
| #include "core/mlas/lib/mlasi.h" | ||
| #include "core/mlas/lib/rotary_embedding.h" | ||
| #include "core/mlas/lib/rotary_embedding_kernel_neon.h" | ||
|
|
||
| #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) | ||
|
|
||
| class MlasNeonFp16RoPETest : public MlasTestBase { | ||
| private: | ||
| const float Pi = 2 * std::acos(0.0f); | ||
|
|
||
| void Test(size_t rotary_emb_dim, bool interleaved) { | ||
| // Per kernel logic (both fallback and optimized), the sin/cos tables | ||
| // are always half the rotary embedding dimension. | ||
| const size_t table_len = rotary_emb_dim / 2; | ||
|
|
||
| std::vector<MLAS_FP16> input(rotary_emb_dim); | ||
| std::vector<MLAS_FP16> sin_data(table_len); | ||
| std::vector<MLAS_FP16> cos_data(table_len); | ||
| std::vector<MLAS_FP16> output_ref(rotary_emb_dim); | ||
| std::vector<MLAS_FP16> output_impl(rotary_emb_dim); | ||
|
|
||
| // Initialize input data | ||
| for (size_t i = 0; i < rotary_emb_dim; ++i) { | ||
| input[i] = MLAS_FP16(static_cast<float>(i + 1)); | ||
| } | ||
|
|
||
| // Initialize sin/cos tables | ||
| for (size_t i = 0; i < table_len; ++i) { | ||
| float theta = static_cast<float>(i) / 1000.0f * Pi; | ||
| sin_data[i] = MLAS_FP16(std::sin(theta)); | ||
| cos_data[i] = MLAS_FP16(std::cos(theta)); | ||
| } | ||
|
|
||
| // Call fallback implementation | ||
| MlasRotaryEmbedOneRow_FallBack<MLAS_FP16>(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_ref.data()); | ||
|
|
||
| // Call dispatched implementation (which should pick up the NEON kernel) | ||
| MlasRotaryEmbedOneRow<MLAS_FP16>(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_impl.data()); | ||
|
|
||
| // Compare results | ||
| for (size_t i = 0; i < rotary_emb_dim; i++) { | ||
| ASSERT_EQ(output_impl[i].val, output_ref[i].val) | ||
| << "Expected bits: " << output_ref[i].val << " (" << output_ref[i].ToFloat() << ")" | ||
| << " Actual bits: " << output_impl[i].val << " (" << output_impl[i].ToFloat() << ")" | ||
| << " @[" << i << "], " | ||
| << "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved; | ||
| } | ||
| } | ||
|
|
||
| public: | ||
| static const char* GetTestSuiteName() { | ||
| return "NeonFp16RoPE"; | ||
| } | ||
|
|
||
| void ExecuteShort(void) override { | ||
| // Test dimensions that cover main loops and various remainders | ||
| Test(6, false); | ||
| Test(6, true); | ||
| Test(16, false); | ||
| Test(16, true); | ||
| Test(24, false); | ||
| Test(24, true); | ||
| Test(32, false); | ||
| Test(32, true); | ||
| Test(42, false); | ||
| Test(42, true); | ||
| Test(64, false); | ||
| Test(64, true); | ||
| Test(70, false); | ||
| Test(70, true); | ||
| } | ||
| }; | ||
|
|
||
| static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { | ||
| size_t count = 0; | ||
| if (is_short_execute) { | ||
| count += MlasDirectShortExecuteTests<MlasNeonFp16RoPETest>::RegisterShortExecute(); | ||
| } | ||
| return count; | ||
| }); | ||
|
|
||
| #endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.