diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp index 3a93723fc3b52..e611009733fbf 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp @@ -150,8 +150,8 @@ RopeKernel_Fp16_Impl( if (i + 15 < dim) { float16x8_t x0 = MlasLoadFloat16x8(input + i); float16x8_t x1 = MlasLoadFloat16x8(input + i + 8); - float16x8_t sin_val = MlasLoadFloat16x8(sin + i); - float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i / 2); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i / 2); for (; i + 31 < dim; i += 16) { float16x8_t real = vuzp1q_f16(x0, x1); float16x8_t imag = vuzp2q_f16(x0, x1); @@ -163,8 +163,8 @@ RopeKernel_Fp16_Impl( MlasStoreFloat16x8(output + i + 8, y1); x0 = MlasLoadFloat16x8(input + i + 16); x1 = MlasLoadFloat16x8(input + i + 24); - sin_val = MlasLoadFloat16x8(sin + i + 16); - cos_val = MlasLoadFloat16x8(cos + i + 16); + sin_val = MlasLoadFloat16x8(sin + (i + 16) / 2); + cos_val = MlasLoadFloat16x8(cos + (i + 16) / 2); } float16x8_t real = vuzp1q_f16(x0, x1); float16x8_t imag = vuzp2q_f16(x0, x1); @@ -181,8 +181,8 @@ RopeKernel_Fp16_Impl( float16x4_t x1 = MlasLoadFloat16x4(input + i + 4); float16x4_t real = vuzp1_f16(x0, x1); float16x4_t imag = vuzp2_f16(x0, x1); - float16x4_t sin_val = MlasLoadFloat16x4(sin + i); - float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i / 2); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i / 2); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); float16x4_t y0 = vzip1_f16(real_out, imag_out); @@ -201,12 +201,12 @@ RopeKernel_Fp16_Impl( imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); real = MlasLoadLaneFloat16x4<2>(input + i + 4, real); imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag); - sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); - sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); - sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); - cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); - cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); - cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i / 2 + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i / 2 + 2, cos_val); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); MlasStoreLaneFloat16x4<0>(output + i, real_out); @@ -224,10 +224,10 @@ RopeKernel_Fp16_Impl( imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); - sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); - sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); - cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); - cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); MlasStoreLaneFloat16x4<0>(output + i, real_out); @@ -241,8 +241,8 @@ RopeKernel_Fp16_Impl( float16x4_t cos_val = MlasZeroFloat16x4(); real = MlasLoadLaneFloat16x4<0>(input + i, real); imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); - sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); - cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); MlasStoreLaneFloat16x4<0>(output + i, real_out); diff --git a/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp new file mode 100644 index 0000000000000..dac8e01cafbc1 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp @@ -0,0 +1,101 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_rope_neon_fp16.cpp + +Abstract: + + Tests for MLAS fp16 RoPE on NEON. + +--*/ +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +#include +#include + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/rotary_embedding.h" +#include "core/mlas/lib/rotary_embedding_kernel_neon.h" + +class MlasNeonFp16RoPETest : public MlasTestBase { + private: + const float Pi = 2 * std::acos(0.0f); + + void Test(size_t rotary_emb_dim, bool interleaved) { + // Per kernel logic (both fallback and optimized), the sin/cos tables + // are always half the rotary embedding dimension. + const size_t table_len = rotary_emb_dim / 2; + + std::vector input(rotary_emb_dim); + std::vector sin_data(table_len); + std::vector cos_data(table_len); + std::vector output_ref(rotary_emb_dim); + std::vector output_impl(rotary_emb_dim); + + // Initialize input data + for (size_t i = 0; i < rotary_emb_dim; ++i) { + input[i] = MLAS_FP16(static_cast(i + 1)); + } + + // Initialize sin/cos tables + for (size_t i = 0; i < table_len; ++i) { + float theta = static_cast(i) / 1000.0f * Pi; + sin_data[i] = MLAS_FP16(std::sin(theta)); + cos_data[i] = MLAS_FP16(std::cos(theta)); + } + + // Call fallback implementation + MlasRotaryEmbedOneRow_FallBack(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_ref.data()); + + // Call dispatched implementation (which should pick up the NEON kernel) + MlasRotaryEmbedOneRow(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_impl.data()); + + // Compare results + for (size_t i = 0; i < rotary_emb_dim; i++) { + ASSERT_EQ(output_impl[i].val, output_ref[i].val) + << "Expected bits: " << output_ref[i].val << " (" << output_ref[i].ToFloat() << ")" + << " Actual bits: " << output_impl[i].val << " (" << output_impl[i].ToFloat() << ")" + << " @[" << i << "], " + << "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved; + } + } + + public: + static const char* GetTestSuiteName() { + return "NeonFp16RoPE"; + } + + void ExecuteShort(void) override { + // Test dimensions that cover main loops and various remainders + Test(6, false); + Test(6, true); + Test(16, false); + Test(16, true); + Test(24, false); + Test(24, true); + Test(32, false); + Test(32, true); + Test(42, false); + Test(42, true); + Test(64, false); + Test(64, true); + Test(70, false); + Test(70, true); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)