From 7de1c40d79250b6c73d796ea0a33274722101a29 Mon Sep 17 00:00:00 2001
From: Yang Hau
Date: Fri, 26 Jul 2024 03:12:24 +0800
Subject: [PATCH] feat: Add vqdmlal_high_lane_[s16|s32]

---
 neon2rvv.h     | 24 ++++++++++++++++------
 tests/impl.cpp | 54 ++++++++++++++++++++++++++++++++++++++++++++++++--
 tests/impl.h   |  4 ++--
 3 files changed, 72 insertions(+), 10 deletions(-)

diff --git a/neon2rvv.h b/neon2rvv.h
index 63892d0d..6f3d26cd 100644
--- a/neon2rvv.h
+++ b/neon2rvv.h
@@ -8993,15 +8993,15 @@ FORCE_INLINE uint64x2_t vmlal_high_laneq_u32(uint64x2_t a, uint32x4_t b, uint32x
                           __riscv_vwmaccu_vv_u64m2(__riscv_vlmul_ext_v_u64m1_u64m2(a), b_high, c_dup, 2));
 }
 
-FORCE_INLINE int32x4_t vqdmlal_lane_s16(int32x4_t a, int16x4_t b, int16x4_t c, const int __d) {
-  vint16m1_t c_dup = __riscv_vrgather_vx_i16m1(c, __d, 4);
+FORCE_INLINE int32x4_t vqdmlal_lane_s16(int32x4_t a, int16x4_t b, int16x4_t c, const int lane) {
+  vint16m1_t c_dup = __riscv_vrgather_vx_i16m1(c, lane, 4);
   vint32m1_t bc_mul = __riscv_vlmul_trunc_v_i32m2_i32m1(__riscv_vwmul_vv_i32m2(b, c_dup, 4));
   vint32m1_t bc_mulx2 = __riscv_vmul_vx_i32m1(bc_mul, 2, 4);
   return __riscv_vadd_vv_i32m1(a, bc_mulx2, 4);
 }
 
-FORCE_INLINE int64x2_t vqdmlal_lane_s32(int64x2_t a, int32x2_t b, int32x2_t c, const int __d) {
-  vint32m1_t c_dup = __riscv_vrgather_vx_i32m1(c, __d, 2);
+FORCE_INLINE int64x2_t vqdmlal_lane_s32(int64x2_t a, int32x2_t b, int32x2_t c, const int lane) {
+  vint32m1_t c_dup = __riscv_vrgather_vx_i32m1(c, lane, 2);
   vint64m1_t bc_mul = __riscv_vlmul_trunc_v_i64m2_i64m1(__riscv_vwmul_vv_i64m2(b, c_dup, 2));
   vint64m1_t bc_mulx2 = __riscv_vmul_vx_i64m1(bc_mul, 2, 2);
   return __riscv_vadd_vv_i64m1(a, bc_mulx2, 2);
@@ -9011,9 +9011,21 @@ FORCE_INLINE int64x2_t vqdmlal_lane_s32(int64x2_t a, int32x2_t b, int32x2_t c, c
 
 // FORCE_INLINE int64_t vqdmlals_lane_s32(int64_t a, int32_t b, int32x2_t v, const int lane);
 
-// FORCE_INLINE int32x4_t vqdmlal_high_lane_s16(int32x4_t a, int16x8_t b, int16x4_t v, const int lane);
+FORCE_INLINE int32x4_t vqdmlal_high_lane_s16(int32x4_t a, int16x8_t b, int16x4_t c, const int lane) {
+  vint16m1_t b_high = __riscv_vslidedown_vx_i16m1(b, 4, 8);
+  vint16m1_t c_dup = __riscv_vrgather_vx_i16m1(c, lane, 4);
+  vint32m1_t bc_mul = __riscv_vlmul_trunc_v_i32m2_i32m1(__riscv_vwmul_vv_i32m2(b_high, c_dup, 4));
+  vint32m1_t bc_mulx2 = __riscv_vmul_vx_i32m1(bc_mul, 2, 4);
+  return __riscv_vadd_vv_i32m1(a, bc_mulx2, 4);
+}
 
-// FORCE_INLINE int64x2_t vqdmlal_high_lane_s32(int64x2_t a, int32x4_t b, int32x2_t v, const int lane);
+FORCE_INLINE int64x2_t vqdmlal_high_lane_s32(int64x2_t a, int32x4_t b, int32x2_t c, const int lane) {
+  vint32m1_t b_high = __riscv_vslidedown_vx_i32m1(b, 2, 4);
+  vint32m1_t c_dup = __riscv_vrgather_vx_i32m1(c, lane, 2);
+  vint64m1_t bc_mul = __riscv_vlmul_trunc_v_i64m2_i64m1(__riscv_vwmul_vv_i64m2(b_high, c_dup, 2));
+  vint64m1_t bc_mulx2 = __riscv_vmul_vx_i64m1(bc_mul, 2, 2);
+  return __riscv_vadd_vv_i64m1(a, bc_mulx2, 2);
+}
 
 // FORCE_INLINE int32x4_t vqdmlal_laneq_s16(int32x4_t a, int16x4_t b, int16x8_t v, const int lane);
 
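Note on the neon2rvv.h hunk above: the "high" variants differ from the existing
vqdmlal_lane_* only in that vslidedown first moves the upper half of b (elements
4..7, resp. 2..3) into positions 0..1/0..3, and vrgather_vx broadcasts element
`lane` of c, so the same widening-multiply/double/accumulate path can be reused.
For reference, a minimal scalar model of the s16 variant is sketched below. The
helper names sat_add_i32/sat_dmull_i16 are illustrative only (the test suite's
sat_add/sat_dmull helpers play the same role); the model shows the fully
saturating Arm semantics that the tests check against.

    #include <stdint.h>

    /* Illustrative scalar model of vqdmlal_high_lane_s16 (not part of the patch). */
    static int32_t sat_add_i32(int64_t x, int64_t y) {
      int64_t s = x + y;  /* exact in 64 bits */
      return s > INT32_MAX ? INT32_MAX : (s < INT32_MIN ? INT32_MIN : (int32_t)s);
    }

    static int32_t sat_dmull_i16(int16_t b, int16_t c) {
      int32_t p = (int32_t)b * (int32_t)c;        /* exact: fits in 32 bits */
      /* doubling overflows only for INT16_MIN * INT16_MIN */
      return p == 0x40000000 ? INT32_MAX : p * 2;
    }

    static void ref_vqdmlal_high_lane_s16(int32_t d[4], const int32_t a[4],
                                          const int16_t b[8], const int16_t c[4],
                                          int lane) {
      for (int i = 0; i < 4; i++)  /* high half of b: elements 4..7 */
        d[i] = sat_add_i32(a[i], sat_dmull_i16(b[i + 4], c[lane]));
    }
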
diff --git a/tests/impl.cpp b/tests/impl.cpp
index cd9e0a07..be196be0 100644
--- a/tests/impl.cpp
+++ b/tests/impl.cpp
@@ -31892,9 +31892,59 @@ result_t test_vqdmlalh_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
 
 result_t test_vqdmlals_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
 
-result_t test_vqdmlal_high_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
+result_t test_vqdmlal_high_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
+#ifdef ENABLE_TEST_ALL
+  const int32_t *_a = (int32_t *)impl.test_cases_int_pointer1;
+  const int16_t *_b = (int16_t *)impl.test_cases_int_pointer2;
+  const int16_t *_c = (int16_t *)impl.test_cases_int_pointer3;
+  int32x4_t a = vld1q_s32(_a);
+  int16x8_t b = vld1q_s16(_b);
+  int16x4_t c = vld1_s16(_c);
+  int32x4_t d;
+  int32_t _d[4];
+
+#define TEST_IMPL(IDX)                                         \
+  for (int i = 0; i < 4; i++) {                                \
+    _d[i] = sat_add(_a[i], sat_dmull(_b[i + 4], _c[IDX]));     \
+  }                                                            \
+  d = vqdmlal_high_lane_s16(a, b, c, IDX);                     \
+  CHECK_RESULT(validate_int32(d, _d[0], _d[1], _d[2], _d[3]))
+
+  IMM_4_ITER
+#undef TEST_IMPL
+
+  return TEST_SUCCESS;
+#else
+  return TEST_UNIMPL;
+#endif  // ENABLE_TEST_ALL
+}
 
-result_t test_vqdmlal_high_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
+result_t test_vqdmlal_high_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
+#ifdef ENABLE_TEST_ALL
+  const int64_t *_a = (int64_t *)impl.test_cases_int_pointer1;
+  const int32_t *_b = (int32_t *)impl.test_cases_int_pointer2;
+  const int32_t *_c = (int32_t *)impl.test_cases_int_pointer3;
+  int64x2_t a = vld1q_s64(_a);
+  int32x4_t b = vld1q_s32(_b);
+  int32x2_t c = vld1_s32(_c);
+  int64x2_t d;
+  int64_t _d[2];
+
+#define TEST_IMPL(IDX)                                         \
+  for (int i = 0; i < 2; i++) {                                \
+    _d[i] = sat_add(_a[i], sat_dmull(_b[i + 2], _c[IDX]));     \
+  }                                                            \
+  d = vqdmlal_high_lane_s32(a, b, c, IDX);                     \
+  CHECK_RESULT(validate_int64(d, _d[0], _d[1]))
+
+  IMM_2_ITER
+#undef TEST_IMPL
+
+  return TEST_SUCCESS;
+#else
+  return TEST_UNIMPL;
+#endif  // ENABLE_TEST_ALL
+}
 
 result_t test_vqdmlal_laneq_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
 
diff --git a/tests/impl.h b/tests/impl.h
index 0d94a702..60ddb442 100644
--- a/tests/impl.h
+++ b/tests/impl.h
@@ -1950,8 +1950,8 @@
   _(vqdmlal_lane_s32)                                          \
   /*_(vqdmlalh_lane_s16) */                                    \
   /*_(vqdmlals_lane_s32) */                                    \
-  /*_(vqdmlal_high_lane_s16) */                                \
-  /*_(vqdmlal_high_lane_s32) */                                \
+  _(vqdmlal_high_lane_s16)                                     \
+  _(vqdmlal_high_lane_s32)                                     \
   /*_(vqdmlal_laneq_s16) */                                    \
   /*_(vqdmlal_laneq_s32) */                                    \
   /*_(vqdmlalh_laneq_s16) */                                   \
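
For anyone trying the new intrinsics out of tree, a minimal usage sketch follows.
It assumes only standard NEON intrinsics that neon2rvv already provides
(vld1q_s32/vld1q_s16/vld1_s16 are used by the tests above; vst1q_s32 is assumed
to be available as in standard NEON). The values are chosen so no saturation
occurs.

    #include <stdio.h>
    #include "neon2rvv.h"  /* assumes the header is on the include path */

    int main(void) {
      int32_t acc[4] = {1, 2, 3, 4};
      /* low half (ignored by the _high_ variant) | high half */
      int16_t bv[8] = {9, 9, 9, 9, 10, 20, 30, 40};
      int16_t cv[4] = {1, 2, 3, 5};

      int32x4_t a = vld1q_s32(acc);
      int16x8_t b = vld1q_s16(bv);
      int16x4_t c = vld1_s16(cv);

      /* d[i] = sat(acc[i] + sat(2 * bv[i + 4] * cv[1])); lane 1 of c is 2 */
      int32x4_t d = vqdmlal_high_lane_s16(a, b, c, 1);

      int32_t out[4];
      vst1q_s32(out, d);
      printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);
      /* prints: 41 82 123 164 */
      return 0;
    }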