From 7c7d32baff58ac478772484ba4656890fabacaf3 Mon Sep 17 00:00:00 2001 From: sadko4u Date: Fri, 29 Nov 2024 12:41:21 +0300 Subject: [PATCH] ARM NEON optimizations for depan_lin and depan_eqpow --- include/private/dsp/arch/arm/neon-d32/pan.h | 254 ++++++++++++++++++++ src/main/arm/neon-d32.cpp | 4 + src/test/ptest/pan/depan_eqpow.cpp | 4 +- src/test/ptest/pan/depan_lin.cpp | 4 +- src/test/utest/pan/depan_eqpow.cpp | 4 +- src/test/utest/pan/depan_lin.cpp | 4 +- 6 files changed, 266 insertions(+), 8 deletions(-) create mode 100644 include/private/dsp/arch/arm/neon-d32/pan.h diff --git a/include/private/dsp/arch/arm/neon-d32/pan.h b/include/private/dsp/arch/arm/neon-d32/pan.h new file mode 100644 index 00000000..3724527a --- /dev/null +++ b/include/private/dsp/arch/arm/neon-d32/pan.h @@ -0,0 +1,254 @@ +/* + * Copyright (C) 2024 Linux Studio Plugins Project + * (C) 2024 Vladimir Sadovnikov + * + * This file is part of lsp-dsp-lib + * Created on: 29 нояб. 2024 г. + * + * lsp-dsp-lib is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * any later version. + * + * lsp-dsp-lib is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with lsp-dsp-lib. If not, see . + */ + +#ifndef PRIVATE_DSP_ARCH_ARM_NEON_D32_PAN_H_ +#define PRIVATE_DSP_ARCH_ARM_NEON_D32_PAN_H_ + +#ifndef PRIVATE_DSP_ARCH_ARM_NEON_D32_IMPL + #error "This header should not be included directly" +#endif /* PRIVATE_DSP_ARCH_ARM_NEON_D32_IMPL */ + +namespace lsp +{ + namespace neon_d32 + { + IF_ARCH_ARM( + static const float depan_lin_const_f[] __lsp_aligned32 = + { + LSP_DSP_VEC8(1e-18f) + }; + ); + + void depan_lin(float *dst, const float *l, const float *r, float dfl, size_t count) + { + /* + const float sl = fabsf(l[i]); + const float sr = fabsf(r[i]); + const float den = sl + sr; + dst[i] = (den >= 1e-18f) ? sr / den : dfl; + */ + ARCH_ARM_ASM + ( + __ASM_EMIT("vdup.32 q15, %y[dfl]") /* q15 = dfl */ + __ASM_EMIT("vldm %[CC], {q12-q13}") /* q12-q13 = thresh */ + __ASM_EMIT("subs %[count], #8") + __ASM_EMIT("blo 2f") + /* 8x blocks */ + __ASM_EMIT("1:") + __ASM_EMIT("vldm %[l]!, {q0-q1}") /* q0-q1 = l */ + __ASM_EMIT("vldm %[r]!, {q2-q3}") /* q2-q3 = r */ + __ASM_EMIT("vabs.f32 q0, q0") /* q0 = fabsf(l) */ + __ASM_EMIT("vabs.f32 q1, q1") + __ASM_EMIT("vabs.f32 q2, q2") /* q2 = fabsf(r) */ + __ASM_EMIT("vabs.f32 q3, q3") + __ASM_EMIT("vadd.f32 q0, q0, q2") /* q0 = den = fabsf(l) + fabsf(r) */ + __ASM_EMIT("vadd.f32 q1, q1, q3") + __ASM_EMIT("vcge.f32 q4, q0, q12") /* q4 = [den >= thresh] */ + __ASM_EMIT("vcge.f32 q5, q1, q13") + __ASM_EMIT("vrecpe.f32 q6, q0") /* q6 = s2 */ + __ASM_EMIT("vrecpe.f32 q7, q1") + __ASM_EMIT("vrecps.f32 q8, q6, q0") /* q8 = (2 - R*s2) */ + __ASM_EMIT("vrecps.f32 q9, q7, q1") + __ASM_EMIT("vmul.f32 q6, q8, q6") /* q6 = s2' = s2 * (2 - R*s2) */ + __ASM_EMIT("vmul.f32 q7, q9, q7") + __ASM_EMIT("vrecps.f32 q8, q6, q0") /* q8 = (2 - R*s2') */ + __ASM_EMIT("vrecps.f32 q9, q7, q1") + __ASM_EMIT("vmul.f32 q0, q8, q6") /* q0 = s2" = s2' * (2 - R*s2) = 1/s2 */ + __ASM_EMIT("vmul.f32 q1, q9, q7") + __ASM_EMIT("vmul.f32 q0, q0, q2") /* q0 = pan = fabsf(r) / den */ + __ASM_EMIT("vmul.f32 q1, q1, q3") + __ASM_EMIT("vbif q0, q15, q4") /* q0 = (den >= thresh) ? fabsf(r)/den : dfl */ + __ASM_EMIT("vbif q1, q15, q5") + __ASM_EMIT("subs %[count], #8") + __ASM_EMIT("vstm %[dst]!, {q0-q1}") + __ASM_EMIT("bhs 1b") + /* 4x block */ + __ASM_EMIT("2:") + __ASM_EMIT("adds %[count], #4") + __ASM_EMIT("blt 4f") + __ASM_EMIT("vldm %[l]!, {q0}") /* q0 = l */ + __ASM_EMIT("vldm %[r]!, {q2}") /* q2 = r */ + __ASM_EMIT("vabs.f32 q0, q0") /* q0 = fabsf(l) */ + __ASM_EMIT("vabs.f32 q2, q2") /* q2 = fabsf(r) */ + __ASM_EMIT("vadd.f32 q0, q0, q2") /* q0 = den = fabsf(l) + fabsf(r) */ + __ASM_EMIT("vcge.f32 q4, q0, q12") /* q4 = [den >= thresh] */ + __ASM_EMIT("vrecpe.f32 q6, q0") /* q6 = s2 */ + __ASM_EMIT("vrecps.f32 q8, q6, q0") /* q8 = (2 - R*s2) */ + __ASM_EMIT("vmul.f32 q6, q8, q6") /* q6 = s2' = s2 * (2 - R*s2) */ + __ASM_EMIT("vrecps.f32 q8, q6, q0") /* q8 = (2 - R*s2') */ + __ASM_EMIT("vmul.f32 q0, q8, q6") /* q0 = s2" = s2' * (2 - R*s2) = 1/s2 */ + __ASM_EMIT("vmul.f32 q0, q0, q2") /* q0 = pan = fabsf(r) / den */ + __ASM_EMIT("vbif q0, q15, q4") /* q0 = (den >= thresh) ? fabsf(r)/den : dfl */ + __ASM_EMIT("sub %[count], #4") + __ASM_EMIT("vstm %[dst]!, {q0-q1}") + /* 1x blocks */ + __ASM_EMIT("4:") + __ASM_EMIT("adds %[count], #3") + __ASM_EMIT("blt 6f") + __ASM_EMIT("5:") + __ASM_EMIT("vld1.32 {d0[], d1[]}, [%[l]]!") + __ASM_EMIT("vld1.32 {d4[], d5[]}, [%[r]]!") + __ASM_EMIT("vabs.f32 q0, q0") /* q0 = fabsf(l) */ + __ASM_EMIT("vabs.f32 q2, q2") /* q2 = fabsf(r) */ + __ASM_EMIT("vadd.f32 q0, q0, q2") /* q0 = den = fabsf(l) + fabsf(r) */ + __ASM_EMIT("vcge.f32 q4, q0, q12") /* q4 = [den >= thresh] */ + __ASM_EMIT("vrecpe.f32 q6, q0") /* q6 = s2 */ + __ASM_EMIT("vrecps.f32 q8, q6, q0") /* q8 = (2 - R*s2) */ + __ASM_EMIT("vmul.f32 q6, q8, q6") /* q6 = s2' = s2 * (2 - R*s2) */ + __ASM_EMIT("vrecps.f32 q8, q6, q0") /* q8 = (2 - R*s2') */ + __ASM_EMIT("vmul.f32 q0, q8, q6") /* q0 = s2" = s2' * (2 - R*s2) = 1/s2 */ + __ASM_EMIT("vmul.f32 q0, q0, q2") /* q0 = pan = fabsf(r) / den */ + __ASM_EMIT("vbif q0, q15, q4") /* q0 = (den >= thresh) ? fabsf(r)/den : dfl */ + __ASM_EMIT("subs %[count], #1") + __ASM_EMIT("vst1.32 {d0[0]}, [%[dst]]!") + __ASM_EMIT("bge 5b") + /* end */ + __ASM_EMIT("6:") + + : [dst] "+r" (dst), [l] "+r" (l), [r] "+r" (r), + [count] "+r" (count), + [dfl] "+t" (dfl) + : [CC] "r" (&depan_lin_const_f[0]) + : "cc", "memory", + /* "q0" */, "q1", "q2", "q3", + "q4", "q5", "q6", "q7", + "q8", "q9", + "q12", "q13", "q14", "q15" + ); + } + + IF_ARCH_ARM( + static const float depan_eqpow_const_f[] __lsp_aligned32 = + { + LSP_DSP_VEC8(1e-36f) + }; + ); + + void depan_eqpow(float *dst, const float *l, const float *r, float dfl, size_t count) + { + /* + const float sl = l[i] * l[i]; + const float sr = r[i] * r[i]; + const float den = sl + sr; + dst[i] = (den >= 1e-36f) ? sr / den : dfl; + */ + ARCH_ARM_ASM + ( + __ASM_EMIT("vdup.32 q15, %y[dfl]") /* q15 = dfl */ + __ASM_EMIT("vldm %[CC], {q12-q13}") /* q12-q13 = thresh */ + __ASM_EMIT("subs %[count], #8") + __ASM_EMIT("blo 2f") + /* 8x blocks */ + __ASM_EMIT("1:") + __ASM_EMIT("vldm %[l]!, {q0-q1}") /* q0-q1 = l */ + __ASM_EMIT("vldm %[r]!, {q2-q3}") /* q2-q3 = r */ + __ASM_EMIT("vmul.f32 q0, q0, q0") /* q0 = l*l */ + __ASM_EMIT("vmul.f32 q1, q1, q1") + __ASM_EMIT("vmul.f32 q2, q2, q2") /* q2 = r*r */ + __ASM_EMIT("vmul.f32 q3, q3, q3") + __ASM_EMIT("vadd.f32 q0, q0, q2") /* q0 = den = l*l + r*r */ + __ASM_EMIT("vadd.f32 q1, q1, q3") + __ASM_EMIT("vcge.f32 q4, q0, q12") /* q4 = [den >= thresh] */ + __ASM_EMIT("vcge.f32 q5, q1, q13") + __ASM_EMIT("vrsqrte.f32 q6, q0") /* q6 = x0 */ + __ASM_EMIT("vrsqrte.f32 q7, q1") + __ASM_EMIT("vmul.f32 q8, q6, q0") /* q8 = R * x0 */ + __ASM_EMIT("vmul.f32 q9, q7, q1") + __ASM_EMIT("vrsqrts.f32 q10, q8, q6") /* q10 = (3 - R * x0 * x0) / 2 */ + __ASM_EMIT("vrsqrts.f32 q11, q9, q7") + __ASM_EMIT("vmul.f32 q6, q6, q10") /* q6 = x1 = x0 * (3 - R * x0 * x0) / 2 */ + __ASM_EMIT("vmul.f32 q7, q7, q11") + __ASM_EMIT("vmul.f32 q8, q6, q0") /* q8 = R * x1 */ + __ASM_EMIT("vmul.f32 q9, q7, q1") + __ASM_EMIT("vrsqrts.f32 q10, q8, q6") /* q10 = (3 - R * x1 * x1) / 2 */ + __ASM_EMIT("vrsqrts.f32 q11, q9, q7") + __ASM_EMIT("vmul.f32 q0, q6, q10") /* q0 = 1/sqrt(den) = x2 = x1 * (3 - R * x1 * x1) / 2 */ + __ASM_EMIT("vmul.f32 q1, q7, q11") + __ASM_EMIT("vmul.f32 q0, q2, q6") /* q0 = pan = r*r/sqrt(den) */ + __ASM_EMIT("vmul.f32 q1, q2, q7") + __ASM_EMIT("vbif q0, q15, q4") /* q0 = (den >= thresh) ? fabsf(r)/den : dfl */ + __ASM_EMIT("vbif q1, q15, q5") + __ASM_EMIT("subs %[count], #8") + __ASM_EMIT("vstm %[dst]!, {q0-q1}") + __ASM_EMIT("bhs 1b") + /* 4x block */ + __ASM_EMIT("2:") + __ASM_EMIT("adds %[count], #4") + __ASM_EMIT("blt 4f") + __ASM_EMIT("vldm %[l]!, {q0}") /* q0 = l */ + __ASM_EMIT("vldm %[r]!, {q2}") /* q2 = r */ + __ASM_EMIT("vmul.f32 q0, q0, q0") /* q0 = l*l */ + __ASM_EMIT("vmul.f32 q2, q2, q2") /* q2 = r*r */ + __ASM_EMIT("vadd.f32 q0, q0, q2") /* q0 = den = l*l + r*r */ + __ASM_EMIT("vcge.f32 q4, q0, q12") /* q4 = [den >= thresh] */ + __ASM_EMIT("vrsqrte.f32 q6, q0") /* q6 = x0 */ + __ASM_EMIT("vmul.f32 q8, q6, q0") /* q8 = R * x0 */ + __ASM_EMIT("vrsqrts.f32 q10, q8, q6") /* q10 = (3 - R * x0 * x0) / 2 */ + __ASM_EMIT("vmul.f32 q6, q6, q10") /* q6 = x1 = x0 * (3 - R * x0 * x0) / 2 */ + __ASM_EMIT("vmul.f32 q8, q6, q0") /* q8 = R * x1 */ + __ASM_EMIT("vrsqrts.f32 q10, q8, q6") /* q10 = (3 - R * x1 * x1) / 2 */ + __ASM_EMIT("vmul.f32 q0, q6, q10") /* q0 = 1/sqrt(den) = x2 = x1 * (3 - R * x1 * x1) / 2 */ + __ASM_EMIT("vmul.f32 q0, q2, q6") /* q0 = pan = r*r/sqrt(den) */ + __ASM_EMIT("vbif q0, q15, q4") /* q0 = (den >= thresh) ? fabsf(r)/den : dfl */ + __ASM_EMIT("sub %[count], #4") + __ASM_EMIT("vstm %[dst]!, {q0-q1}") + /* 1x blocks */ + __ASM_EMIT("4:") + __ASM_EMIT("adds %[count], #3") + __ASM_EMIT("blt 6f") + __ASM_EMIT("5:") + __ASM_EMIT("vld1.32 {d0[], d1[]}, [%[l]]!") + __ASM_EMIT("vld1.32 {d4[], d5[]}, [%[r]]!") + __ASM_EMIT("vmul.f32 q0, q0, q0") /* q0 = l*l */ + __ASM_EMIT("vmul.f32 q2, q2, q2") /* q2 = r*r */ + __ASM_EMIT("vadd.f32 q0, q0, q2") /* q0 = den = l*l + r*r */ + __ASM_EMIT("vcge.f32 q4, q0, q12") /* q4 = [den >= thresh] */ + __ASM_EMIT("vrsqrte.f32 q6, q0") /* q6 = x0 */ + __ASM_EMIT("vmul.f32 q8, q6, q0") /* q8 = R * x0 */ + __ASM_EMIT("vrsqrts.f32 q10, q8, q6") /* q10 = (3 - R * x0 * x0) / 2 */ + __ASM_EMIT("vmul.f32 q6, q6, q10") /* q6 = x1 = x0 * (3 - R * x0 * x0) / 2 */ + __ASM_EMIT("vmul.f32 q8, q6, q0") /* q8 = R * x1 */ + __ASM_EMIT("vrsqrts.f32 q10, q8, q6") /* q10 = (3 - R * x1 * x1) / 2 */ + __ASM_EMIT("vmul.f32 q0, q6, q10") /* q0 = 1/sqrt(den) = x2 = x1 * (3 - R * x1 * x1) / 2 */ + __ASM_EMIT("vmul.f32 q0, q2, q6") /* q0 = pan = r*r/sqrt(den) */ + __ASM_EMIT("vbif q0, q15, q4") /* q0 = (den >= thresh) ? fabsf(r)/den : dfl */ + __ASM_EMIT("subs %[count], #1") + __ASM_EMIT("vst1.32 {d0[0]}, [%[dst]]!") + __ASM_EMIT("bge 5b") + /* end */ + __ASM_EMIT("6:") + + : [dst] "+r" (dst), [l] "+r" (l), [r] "+r" (r), + [count] "+r" (count), + [dfl] "+t" (dfl) + : [CC] "r" (&depan_eqpow_const_f[0]) + : "cc", "memory", + /* "q0" */, "q1", "q2", "q3", + "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15" + ); + } + } /* namespace neon_d32 */ +} /* namespace lsp */ + + +#endif /* PRIVATE_DSP_ARCH_ARM_NEON_D32_PAN_H_ */ diff --git a/src/main/arm/neon-d32.cpp b/src/main/arm/neon-d32.cpp index 9dcb918b..5ad6ae9e 100644 --- a/src/main/arm/neon-d32.cpp +++ b/src/main/arm/neon-d32.cpp @@ -68,6 +68,7 @@ #include #include #include + #include #include #include #include @@ -418,6 +419,9 @@ EXPORT1(mix_add3); EXPORT1(mix_add4); + EXPORT1(depan_lin); + EXPORT1(depan_eqpow); + EXPORT1(lin_inter_set); EXPORT1(lin_inter_mul2); EXPORT1(lin_inter_mul3); diff --git a/src/test/ptest/pan/depan_eqpow.cpp b/src/test/ptest/pan/depan_eqpow.cpp index 5958bb60..d18309d0 100644 --- a/src/test/ptest/pan/depan_eqpow.cpp +++ b/src/test/ptest/pan/depan_eqpow.cpp @@ -56,7 +56,7 @@ namespace lsp IF_ARCH_ARM( namespace neon_d32 { -// void depan_eqpow(float *dst, const float *l, const float *r, float dfl, size_t count); + void depan_eqpow(float *dst, const float *l, const float *r, float dfl, size_t count); } ) @@ -111,7 +111,7 @@ PTEST_BEGIN("dsp.pan", depan_eqpow, 5, 1000) IF_ARCH_X86(CALL(avx::depan_eqpow)); IF_ARCH_X86(CALL(avx::depan_eqpow_fma3)); IF_ARCH_X86(CALL(avx512::depan_eqpow)); -// IF_ARCH_ARM(CALL(neon_d32::depan_eqpow)); + IF_ARCH_ARM(CALL(neon_d32::depan_eqpow)); // IF_ARCH_AARCH64(CALL(asimd::depan_eqpow)); PTEST_SEPARATOR; } diff --git a/src/test/ptest/pan/depan_lin.cpp b/src/test/ptest/pan/depan_lin.cpp index 500702a9..bc582042 100644 --- a/src/test/ptest/pan/depan_lin.cpp +++ b/src/test/ptest/pan/depan_lin.cpp @@ -55,7 +55,7 @@ namespace lsp IF_ARCH_ARM( namespace neon_d32 { -// void depan_lin(float *dst, const float *l, const float *r, float dfl, size_t count); + void depan_lin(float *dst, const float *l, const float *r, float dfl, size_t count); } ) @@ -109,7 +109,7 @@ PTEST_BEGIN("dsp.pan", depan_lin, 5, 1000) IF_ARCH_X86(CALL(sse::depan_lin)); IF_ARCH_X86(CALL(avx::depan_lin)); IF_ARCH_X86(CALL(avx512::depan_lin)); -// IF_ARCH_ARM(CALL(neon_d32::depan_lin)); + IF_ARCH_ARM(CALL(neon_d32::depan_lin)); // IF_ARCH_AARCH64(CALL(asimd::depan_lin)); PTEST_SEPARATOR; } diff --git a/src/test/utest/pan/depan_eqpow.cpp b/src/test/utest/pan/depan_eqpow.cpp index c9af0f6a..5ab8e6fa 100644 --- a/src/test/utest/pan/depan_eqpow.cpp +++ b/src/test/utest/pan/depan_eqpow.cpp @@ -52,7 +52,7 @@ namespace lsp IF_ARCH_ARM( namespace neon_d32 { -// void depan_eqpow(float *dst, const float *l, const float *r, float dfl, size_t count); + void depan_eqpow(float *dst, const float *l, const float *r, float dfl, size_t count); } ) @@ -118,7 +118,7 @@ UTEST_BEGIN("dsp.pan", depan_eqpow) IF_ARCH_X86(CALL(avx::depan_eqpow, 32)); IF_ARCH_X86(CALL(avx::depan_eqpow_fma3, 32)); IF_ARCH_X86(CALL(avx512::depan_eqpow, 64)); -// IF_ARCH_ARM(CALL(neon_d32::depan_eqpow, 16)); + IF_ARCH_ARM(CALL(neon_d32::depan_eqpow, 16)); // IF_ARCH_AARCH64(CALL(asimd::depan_eqpow, 16)); } diff --git a/src/test/utest/pan/depan_lin.cpp b/src/test/utest/pan/depan_lin.cpp index d83d0627..cd3077eb 100644 --- a/src/test/utest/pan/depan_lin.cpp +++ b/src/test/utest/pan/depan_lin.cpp @@ -51,7 +51,7 @@ namespace lsp IF_ARCH_ARM( namespace neon_d32 { -// void depan_lin(float *dst, const float *l, const float *r, float dfl, size_t count); + void depan_lin(float *dst, const float *l, const float *r, float dfl, size_t count); } ) @@ -116,7 +116,7 @@ UTEST_BEGIN("dsp.pan", depan_lin) IF_ARCH_X86(CALL(sse::depan_lin, 16)); IF_ARCH_X86(CALL(avx::depan_lin, 32)); IF_ARCH_X86(CALL(avx512::depan_lin, 64)); -// IF_ARCH_ARM(CALL(neon_d32::depan_lin, 16)); + IF_ARCH_ARM(CALL(neon_d32::depan_lin, 16)); // IF_ARCH_AARCH64(CALL(asimd::depan_lin, 16)); }