Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ RopeKernel_Fp16_Impl<true>(
if (i + 15 < dim) {
float16x8_t x0 = MlasLoadFloat16x8(input + i);
float16x8_t x1 = MlasLoadFloat16x8(input + i + 8);
float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
float16x8_t sin_val = MlasLoadFloat16x8(sin + i / 2);
float16x8_t cos_val = MlasLoadFloat16x8(cos + i / 2);
for (; i + 31 < dim; i += 16) {
float16x8_t real = vuzp1q_f16(x0, x1);
float16x8_t imag = vuzp2q_f16(x0, x1);
Expand All @@ -163,8 +163,8 @@ RopeKernel_Fp16_Impl<true>(
MlasStoreFloat16x8(output + i + 8, y1);
x0 = MlasLoadFloat16x8(input + i + 16);
x1 = MlasLoadFloat16x8(input + i + 24);
sin_val = MlasLoadFloat16x8(sin + i + 16);
cos_val = MlasLoadFloat16x8(cos + i + 16);
sin_val = MlasLoadFloat16x8(sin + (i + 16) / 2);
cos_val = MlasLoadFloat16x8(cos + (i + 16) / 2);
}
float16x8_t real = vuzp1q_f16(x0, x1);
float16x8_t imag = vuzp2q_f16(x0, x1);
Expand All @@ -181,8 +181,8 @@ RopeKernel_Fp16_Impl<true>(
float16x4_t x1 = MlasLoadFloat16x4(input + i + 4);
float16x4_t real = vuzp1_f16(x0, x1);
float16x4_t imag = vuzp2_f16(x0, x1);
float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
float16x4_t sin_val = MlasLoadFloat16x4(sin + i / 2);
float16x4_t cos_val = MlasLoadFloat16x4(cos + i / 2);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
float16x4_t y0 = vzip1_f16(real_out, imag_out);
Expand All @@ -201,12 +201,12 @@ RopeKernel_Fp16_Impl<true>(
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
real = MlasLoadLaneFloat16x4<2>(input + i + 4, real);
imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val);
sin_val = MlasLoadLaneFloat16x4<2>(sin + i / 2 + 2, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val);
cos_val = MlasLoadLaneFloat16x4<2>(cos + i / 2 + 2, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
Expand All @@ -224,10 +224,10 @@ RopeKernel_Fp16_Impl<true>(
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val);
sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val);
cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
Expand All @@ -241,8 +241,8 @@ RopeKernel_Fp16_Impl<true>(
float16x4_t cos_val = MlasZeroFloat16x4();
real = MlasLoadLaneFloat16x4<0>(input + i, real);
imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val);
cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val);
float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
MlasStoreLaneFloat16x4<0>(output + i, real_out);
Expand Down
102 changes: 102 additions & 0 deletions onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*++

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

Module Name:

test_rope_neon_fp16.h

Abstract:

Tests for MLAS fp16 RoPE on NEON.

--*/

#include <vector>
#include <cmath>

#include "test_util.h"
#include "core/mlas/lib/mlasi.h"
#include "core/mlas/lib/rotary_embedding.h"
#include "core/mlas/lib/rotary_embedding_kernel_neon.h"

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)

class MlasNeonFp16RoPETest : public MlasTestBase {
private:
const float Pi = 2 * std::acos(0.0f);

void Test(size_t rotary_emb_dim, bool interleaved) {
// Per kernel logic (both fallback and optimized), the sin/cos tables
// are always half the rotary embedding dimension.
const size_t table_len = rotary_emb_dim / 2;

std::vector<MLAS_FP16> input(rotary_emb_dim);
std::vector<MLAS_FP16> sin_data(table_len);
std::vector<MLAS_FP16> cos_data(table_len);
std::vector<MLAS_FP16> output_ref(rotary_emb_dim);
std::vector<MLAS_FP16> output_impl(rotary_emb_dim);

// Initialize input data
for (size_t i = 0; i < rotary_emb_dim; ++i) {
input[i] = MLAS_FP16(static_cast<float>(i + 1));
}

// Initialize sin/cos tables
for (size_t i = 0; i < table_len; ++i) {
float theta = static_cast<float>(i) / 1000.0f * Pi;
sin_data[i] = MLAS_FP16(std::sin(theta));
cos_data[i] = MLAS_FP16(std::cos(theta));
}

// Call fallback implementation
MlasRotaryEmbedOneRow_FallBack<MLAS_FP16>(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_ref.data());

// Call dispatched implementation (which should pick up the NEON kernel)
MlasRotaryEmbedOneRow<MLAS_FP16>(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_impl.data());

// Compare results
for (size_t i = 0; i < rotary_emb_dim; i++) {
ASSERT_EQ(output_impl[i].val, output_ref[i].val)
<< "Expected bits: " << output_ref[i].val << " (" << output_ref[i].ToFloat() << ")"
<< " Actual bits: " << output_impl[i].val << " (" << output_impl[i].ToFloat() << ")"
<< " @[" << i << "], "
<< "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved;
}
}

public:
static const char* GetTestSuiteName() {
return "NeonFp16RoPE";
}

void ExecuteShort(void) override {
// Test dimensions that cover main loops and various remainders
Test(6, false);
Test(6, true);
Test(16, false);
Test(16, true);
Test(24, false);
Test(24, true);
Test(32, false);
Test(32, true);
Test(42, false);
Test(42, true);
Test(64, false);
Test(64, true);
Test(70, false);
Test(70, true);
}
};

static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
size_t count = 0;
if (is_short_execute) {
count += MlasDirectShortExecuteTests<MlasNeonFp16RoPETest>::RegisterShortExecute();
}
return count;
});

#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
Loading