From 10997fcc4e7e5b9d6da38a00dff2e8b1ba3568eb Mon Sep 17 00:00:00 2001
From: Yang Hau
Date: Fri, 26 Jul 2024 11:37:08 +0800
Subject: [PATCH] feat: Add vqdmlal_laneq_[s16|s32]

---
 neon2rvv.h     | 14 +++++++++++--
 tests/impl.cpp | 54 ++++++++++++++++++++++++++++++++++++++++++++++++--
 tests/impl.h   |  4 ++--
 3 files changed, 66 insertions(+), 6 deletions(-)

diff --git a/neon2rvv.h b/neon2rvv.h
index 6f3d26cd..ead2ddc1 100644
--- a/neon2rvv.h
+++ b/neon2rvv.h
@@ -9027,9 +9027,19 @@ FORCE_INLINE int64x2_t vqdmlal_high_lane_s32(int64x2_t a, int32x4_t b, int32x2_t
   return __riscv_vadd_vv_i64m1(a, bc_mulx2, 2);
 }
 
-// FORCE_INLINE int32x4_t vqdmlal_laneq_s16(int32x4_t a, int16x4_t b, int16x8_t v, const int lane);
+FORCE_INLINE int32x4_t vqdmlal_laneq_s16(int32x4_t a, int16x4_t b, int16x8_t c, const int lane) {
+  vint16m1_t c_dup = __riscv_vrgather_vx_i16m1(c, lane, 8);
+  vint32m1_t bc_mul = __riscv_vlmul_trunc_v_i32m2_i32m1(__riscv_vwmul_vv_i32m2(b, c_dup, 4));
+  vint32m1_t bc_mulx2 = __riscv_vmul_vx_i32m1(bc_mul, 2, 4);
+  return __riscv_vadd_vv_i32m1(a, bc_mulx2, 4);
+}
 
-// FORCE_INLINE int64x2_t vqdmlal_laneq_s32(int64x2_t a, int32x2_t b, int32x4_t v, const int lane);
+FORCE_INLINE int64x2_t vqdmlal_laneq_s32(int64x2_t a, int32x2_t b, int32x4_t c, const int lane) {
+  vint32m1_t c_dup = __riscv_vrgather_vx_i32m1(c, lane, 4);
+  vint64m1_t bc_mul = __riscv_vlmul_trunc_v_i64m2_i64m1(__riscv_vwmul_vv_i64m2(b, c_dup, 2));
+  vint64m1_t bc_mulx2 = __riscv_vmul_vx_i64m1(bc_mul, 2, 2);
+  return __riscv_vadd_vv_i64m1(a, bc_mulx2, 2);
+}
 
 // FORCE_INLINE int32_t vqdmlalh_laneq_s16(int32_t a, int16_t b, int16x8_t v, const int lane);
 
diff --git a/tests/impl.cpp b/tests/impl.cpp
index be196be0..a7251c1e 100644
--- a/tests/impl.cpp
+++ b/tests/impl.cpp
@@ -31946,9 +31946,59 @@ result_t test_vqdmlal_high_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t ite
 #endif // ENABLE_TEST_ALL
 }
 
-result_t test_vqdmlal_laneq_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
+result_t test_vqdmlal_laneq_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
+#ifdef ENABLE_TEST_ALL
+  const int32_t *_a = (int32_t *)impl.test_cases_int_pointer1;
+  const int16_t *_b = (int16_t *)impl.test_cases_int_pointer2;
+  const int16_t *_c = (int16_t *)impl.test_cases_int_pointer3;
+  int32x4_t d;
+  int32_t _d[4];
+  int32x4_t a = vld1q_s32(_a);
+  int16x4_t b = vld1_s16(_b);
+  int16x8_t c = vld1q_s16(_c);
+
+#define TEST_IMPL(IDX)                                       \
+  for (int i = 0; i < 4; i++) {                              \
+    _d[i] = sat_add(_a[i], sat_dmull(_b[i], _c[IDX]));       \
+  }                                                          \
+  d = vqdmlal_laneq_s16(a, b, c, IDX);                       \
+  CHECK_RESULT(validate_int32(d, _d[0], _d[1], _d[2], _d[3]))
+
+  IMM_8_ITER
+#undef TEST_IMPL
+
+  return TEST_SUCCESS;
+#else
+  return TEST_UNIMPL;
+#endif // ENABLE_TEST_ALL
+}
 
-result_t test_vqdmlal_laneq_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
+result_t test_vqdmlal_laneq_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
+#ifdef ENABLE_TEST_ALL
+  const int64_t *_a = (int64_t *)impl.test_cases_int_pointer1;
+  const int32_t *_b = (int32_t *)impl.test_cases_int_pointer2;
+  const int32_t *_c = (int32_t *)impl.test_cases_int_pointer3;
+  int64x2_t d;
+  int64x2_t a = vld1q_s64(_a);
+  int32x2_t b = vld1_s32(_b);
+  int32x4_t c = vld1q_s32(_c);
+  int64_t _d[2];
+
+#define TEST_IMPL(IDX)                                 \
+  for (int i = 0; i < 2; i++) {                        \
+    _d[i] = sat_add(_a[i], sat_dmull(_b[i], _c[IDX])); \
+  }                                                    \
+  d = vqdmlal_laneq_s32(a, b, c, IDX);                 \
+  CHECK_RESULT(validate_int64(d, _d[0], _d[1]))
+
+  IMM_4_ITER
+#undef TEST_IMPL
+
+  return TEST_SUCCESS;
+#else
+  return TEST_UNIMPL;
+#endif // ENABLE_TEST_ALL
+}
 
 result_t test_vqdmlalh_laneq_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
 
diff --git a/tests/impl.h b/tests/impl.h
index 60ddb442..a1c324a9 100644
--- a/tests/impl.h
+++ b/tests/impl.h
@@ -1952,8 +1952,8 @@
   /*_(vqdmlals_lane_s32) */ \
   _(vqdmlal_high_lane_s16) \
   _(vqdmlal_high_lane_s32) \
-  /*_(vqdmlal_laneq_s16) */ \
-  /*_(vqdmlal_laneq_s32) */ \
+  _(vqdmlal_laneq_s16) \
+  _(vqdmlal_laneq_s32) \
   /*_(vqdmlalh_laneq_s16) */ \
   /*_(vqdmlals_laneq_s32) */ \
   /*_(vqdmlal_high_laneq_s16) */ \
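
Note on semantics (not part of the patch): vqdmlal_laneq_s16 and vqdmlal_laneq_s32 are signed saturating doubling multiply-accumulate long operations that take the multiplicand from one lane of a 128-bit vector. The RVV mapping above broadcasts the lane with __riscv_vrgather_vx_*, widens with __riscv_vwmul_vv_*, doubles with __riscv_vmul_vx_*, and accumulates with __riscv_vadd_vv_*, mirroring the existing vqdmlal_high_lane_* implementations visible in the hunk context. Below is a minimal scalar sketch of what the new tests check; the sat_dmull_s16/sat_add_s32 helpers are illustrative stand-ins, not the sat_dmull/sat_add helpers already defined in the test suite.

#include <stdint.h>

/* Illustrative saturating doubling multiply: sat32(2 * x * y).
   Only INT16_MIN * INT16_MIN overflows the int32 range after doubling. */
static int32_t sat_dmull_s16(int16_t x, int16_t y) {
  int64_t p = 2 * (int64_t)x * (int64_t)y;
  if (p > INT32_MAX) return INT32_MAX;
  if (p < INT32_MIN) return INT32_MIN;
  return (int32_t)p;
}

/* Illustrative saturating int32 addition. */
static int32_t sat_add_s32(int32_t x, int32_t y) {
  int64_t s = (int64_t)x + (int64_t)y;
  if (s > INT32_MAX) return INT32_MAX;
  if (s < INT32_MIN) return INT32_MIN;
  return (int32_t)s;
}

/* Scalar reference for vqdmlal_laneq_s16:
   d[i] = sat32(a[i] + sat32(2 * b[i] * v[lane])), i = 0..3, lane in [0, 7]. */
static void ref_vqdmlal_laneq_s16(int32_t d[4], const int32_t a[4], const int16_t b[4],
                                  const int16_t v[8], int lane) {
  for (int i = 0; i < 4; i++) {
    d[i] = sat_add_s32(a[i], sat_dmull_s16(b[i], v[lane]));
  }
}

The s32 variant has the same shape with int32 inputs, int64 accumulation, and lane restricted to [0, 3], which is why the s16 test iterates IMM_8_ITER and the s32 test iterates IMM_4_ITER.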