From de748887eb1ab9b6429eb05b5abbeb9456fd5c7c Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 22 Oct 2025 15:53:30 -0700 Subject: [PATCH 1/4] Fix Rope interleaved kernel --- .../lib/rotary_embedding_kernel_neon_fp16.cpp | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp index 3a93723fc3b52..e611009733fbf 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp @@ -150,8 +150,8 @@ RopeKernel_Fp16_Impl( if (i + 15 < dim) { float16x8_t x0 = MlasLoadFloat16x8(input + i); float16x8_t x1 = MlasLoadFloat16x8(input + i + 8); - float16x8_t sin_val = MlasLoadFloat16x8(sin + i); - float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i / 2); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i / 2); for (; i + 31 < dim; i += 16) { float16x8_t real = vuzp1q_f16(x0, x1); float16x8_t imag = vuzp2q_f16(x0, x1); @@ -163,8 +163,8 @@ RopeKernel_Fp16_Impl( MlasStoreFloat16x8(output + i + 8, y1); x0 = MlasLoadFloat16x8(input + i + 16); x1 = MlasLoadFloat16x8(input + i + 24); - sin_val = MlasLoadFloat16x8(sin + i + 16); - cos_val = MlasLoadFloat16x8(cos + i + 16); + sin_val = MlasLoadFloat16x8(sin + (i + 16) / 2); + cos_val = MlasLoadFloat16x8(cos + (i + 16) / 2); } float16x8_t real = vuzp1q_f16(x0, x1); float16x8_t imag = vuzp2q_f16(x0, x1); @@ -181,8 +181,8 @@ RopeKernel_Fp16_Impl( float16x4_t x1 = MlasLoadFloat16x4(input + i + 4); float16x4_t real = vuzp1_f16(x0, x1); float16x4_t imag = vuzp2_f16(x0, x1); - float16x4_t sin_val = MlasLoadFloat16x4(sin + i); - float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i / 2); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i / 2); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); float16x4_t y0 = vzip1_f16(real_out, imag_out); @@ -201,12 +201,12 @@ RopeKernel_Fp16_Impl( imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); real = MlasLoadLaneFloat16x4<2>(input + i + 4, real); imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag); - sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); - sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); - sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); - cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); - cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); - cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i / 2 + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i / 2 + 2, cos_val); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); MlasStoreLaneFloat16x4<0>(output + i, real_out); @@ -224,10 +224,10 @@ RopeKernel_Fp16_Impl( imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); - sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); - sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); - cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); - cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); MlasStoreLaneFloat16x4<0>(output + i, real_out); @@ -241,8 +241,8 @@ RopeKernel_Fp16_Impl( float16x4_t cos_val = MlasZeroFloat16x4(); real = MlasLoadLaneFloat16x4<0>(input + i, real); imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); - sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); - cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); MlasStoreLaneFloat16x4<0>(output + i, real_out); From 32a586a254c15c2eb917da99ddb51b57cb414861 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 22 Oct 2025 16:16:04 -0700 Subject: [PATCH 2/4] Add test --- .../mlas/unittest/test_rope_neon_fp16.cpp | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp diff --git a/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp new file mode 100644 index 0000000000000..24ddbab995aed --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp @@ -0,0 +1,102 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_rope_neon_fp16.h + +Abstract: + + Tests for MLAS fp16 RoPE on NEON. + +--*/ + +#include +#include + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/rotary_embedding.h" +#include "core/mlas/lib/rotary_embedding_kernel_neon.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonFp16RoPETest : public MlasTestBase { + private: + const float Pi = 2 * std::acos(0.0f); + + void Test(size_t rotary_emb_dim, bool interleaved) { + // Per kernel logic (both fallback and optimized), the sin/cos tables + // are always half the rotary embedding dimension. + const size_t table_len = rotary_emb_dim / 2; + + std::vector input(rotary_emb_dim); + std::vector sin_data(table_len); + std::vector cos_data(table_len); + std::vector output_ref(rotary_emb_dim); + std::vector output_impl(rotary_emb_dim); + + // Initialize input data + for (size_t i = 0; i < rotary_emb_dim; ++i) { + input[i] = MLAS_FP16(static_cast(i + 1)); + } + + // Initialize sin/cos tables + for (size_t i = 0; i < table_len; ++i) { + float theta = static_cast(i) / 1000.0f * Pi; + sin_data[i] = MLAS_FP16(std::sin(theta)); + cos_data[i] = MLAS_FP16(std::cos(theta)); + } + + // Call fallback implementation + MlasRotaryEmbedOneRow_FallBack(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_ref.data()); + + // Call dispatched implementation (which should pick up the NEON kernel) + MlasRotaryEmbedOneRow(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_impl.data()); + + // Compare results + for (size_t i = 0; i < rotary_emb_dim; i++) { + ASSERT_EQ(output_impl[i].val, output_ref[i].val) + << "Expected bits: " << output_ref[i].val << " (" << output_ref[i].ToFloat() << ")" + << " Actual bits: " << output_impl[i].val << " (" << output_impl[i].ToFloat() << ")" + << " @[" << i << "], " + << "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved; + } + } + + public: + static const char* GetTestSuiteName() { + return "NeonFp16RoPE"; + } + + void ExecuteShort(void) override { + // Test dimensions that cover main loops and various remainders + Test(6, false); + Test(6, true); + Test(16, false); + Test(16, true); + Test(24, false); + Test(24, true); + Test(32, false); + Test(32, true); + Test(42, false); + Test(42, true); + Test(64, false); + Test(64, true); + Test(70, false); + Test(70, true); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) From 64bf0fca225975c908451163c0a5bc915652692a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 22 Oct 2025 16:18:10 -0700 Subject: [PATCH 3/4] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp index 24ddbab995aed..409b5ce47b930 100644 --- a/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp +++ b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - test_rope_neon_fp16.h + test_rope_neon_fp16.cpp Abstract: From b88d691d4b43a0467a3756618c08e06cf0a90744 Mon Sep 17 00:00:00 2001 From: Tianlei WU Date: Thu, 23 Oct 2025 11:02:55 -0700 Subject: [PATCH 4/4] fix build error --- onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp index 409b5ce47b930..dac8e01cafbc1 100644 --- a/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp +++ b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp @@ -13,6 +13,7 @@ Module Name: Tests for MLAS fp16 RoPE on NEON. --*/ +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) #include #include @@ -22,8 +23,6 @@ Module Name: #include "core/mlas/lib/rotary_embedding.h" #include "core/mlas/lib/rotary_embedding_kernel_neon.h" -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) - class MlasNeonFp16RoPETest : public MlasTestBase { private: const float Pi = 2 * std::acos(0.0f);