From 1d7915627e889796bf2a89dc6c51dd41e5300886 Mon Sep 17 00:00:00 2001 From: Yang Hau Date: Wed, 31 Jul 2024 00:03:24 +0800 Subject: [PATCH] feat: Add vqdmlal[h|s]_lane[q]_[s16|s32] --- neon2rvv.h | 28 +++++++++++++--- tests/impl.cpp | 88 +++++++++++++++++++++++++++++++++++++++++++++++--- tests/impl.h | 8 ++--- 3 files changed, 112 insertions(+), 12 deletions(-) diff --git a/neon2rvv.h b/neon2rvv.h index 2dcc0e3c..b97ed5ff 100644 --- a/neon2rvv.h +++ b/neon2rvv.h @@ -9210,9 +9210,19 @@ FORCE_INLINE int64x2_t vqdmlal_lane_s32(int64x2_t a, int32x2_t b, int32x2_t c, c return __riscv_vadd_vv_i64m1(a, bc_mulx2, 2); } -// FORCE_INLINE int32_t vqdmlalh_lane_s16(int32_t a, int16_t b, int16x4_t v, const int lane); +FORCE_INLINE int32_t vqdmlalh_lane_s16(int32_t a, int16_t b, int16x4_t c, const int lane) { + int16_t c_lane = vget_lane_s16(c, lane); + int32_t dmull = (int32_t)b * (int32_t)c_lane; + dmull = dmull > INT32_MAX / 2 ? INT32_MAX : dmull < INT32_MIN / 2 ? INT32_MIN : dmull * 2; + return sat_add_int32(a, dmull); +} -// FORCE_INLINE int64_t vqdmlals_lane_s32(int64_t a, int32_t b, int32x2_t v, const int lane); +FORCE_INLINE int64_t vqdmlals_lane_s32(int64_t a, int32_t b, int32x2_t c, const int lane) { + int32_t c_lane = vget_lane_s32(c, lane); + int64_t dmull = (int64_t)b * (int64_t)c_lane; + dmull = dmull > INT64_MAX / 2 ? INT64_MAX : dmull < INT64_MIN / 2 ? INT64_MIN : dmull * 2; + return sat_add_int64(a, dmull); +} FORCE_INLINE int32x4_t vqdmlal_high_lane_s16(int32x4_t a, int16x8_t b, int16x4_t c, const int lane) { vint16m1_t b_high = __riscv_vslidedown_vx_i16m1(b, 4, 8); @@ -9244,9 +9254,19 @@ FORCE_INLINE int64x2_t vqdmlal_laneq_s32(int64x2_t a, int32x2_t b, int32x4_t c, return __riscv_vadd_vv_i64m1(a, bc_mulx2, 2); } -// FORCE_INLINE int32_t vqdmlalh_laneq_s16(int32_t a, int16_t b, int16x8_t c, const int lane); +FORCE_INLINE int32_t vqdmlalh_laneq_s16(int32_t a, int16_t b, int16x8_t c, const int lane) { + int16_t c_lane = vgetq_lane_s16(c, lane); + int32_t dmull = (int32_t)b * (int32_t)c_lane; + dmull = dmull > INT32_MAX / 2 ? INT32_MAX : dmull < INT32_MIN / 2 ? INT32_MIN : dmull * 2; + return sat_add_int32(a, dmull); +} -// FORCE_INLINE int64_t vqdmlals_laneq_s32(int64_t a, int32_t b, int32x4_t c, const int lane); +FORCE_INLINE int64_t vqdmlals_laneq_s32(int64_t a, int32_t b, int32x4_t c, const int lane) { + int32_t c_lane = vgetq_lane_s32(c, lane); + int64_t dmull = (int64_t)b * (int64_t)c_lane; + dmull = dmull > INT64_MAX / 2 ? INT64_MAX : dmull < INT64_MIN / 2 ? INT64_MIN : dmull * 2; + return sat_add_int64(a, dmull); +} FORCE_INLINE int32x4_t vqdmlal_high_laneq_s16(int32x4_t a, int16x8_t b, int16x8_t c, const int lane) { vint16m1_t b_high = __riscv_vslidedown_vx_i16m1(b, 4, 8); diff --git a/tests/impl.cpp b/tests/impl.cpp index 6d9c57fc..e9d15254 100644 --- a/tests/impl.cpp +++ b/tests/impl.cpp @@ -33423,9 +33423,49 @@ result_t test_vqdmlal_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { #endif // ENABLE_TEST_ALL } -result_t test_vqdmlalh_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } +result_t test_vqdmlalh_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const int32_t *_a = (int32_t *)impl.test_cases_int_pointer1; + const int16_t *_b = (int16_t *)impl.test_cases_int_pointer2; + const int16_t *_c = (int16_t *)impl.test_cases_int_pointer3; + int32_t _d, d; + int16x4_t c = vld1_s16(_c); + +#define TEST_IMPL(IDX) \ + _d = sat_add(_a[0], sat_dmull(_b[0], _c[IDX])); \ + d = vqdmlalh_lane_s16(_a[0], _b[0], c, IDX); \ + CHECK_RESULT(d == _d ? TEST_SUCCESS : TEST_FAIL) + + IMM_4_ITER +#undef TEST_IMPL + + return TEST_SUCCESS; +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} -result_t test_vqdmlals_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } +result_t test_vqdmlals_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const int64_t *_a = (int64_t *)impl.test_cases_int_pointer1; + const int32_t *_b = (int32_t *)impl.test_cases_int_pointer2; + const int32_t *_c = (int32_t *)impl.test_cases_int_pointer3; + int64_t _d, d; + int32x2_t c = vld1_s32(_c); + +#define TEST_IMPL(IDX) \ + _d = sat_add(_a[0], sat_dmull(_b[0], _c[IDX])); \ + d = vqdmlals_lane_s32(_a[0], _b[0], c, IDX); \ + CHECK_RESULT(d == _d ? TEST_SUCCESS : TEST_FAIL) + + IMM_2_ITER +#undef TEST_IMPL + + return TEST_SUCCESS; +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} result_t test_vqdmlal_high_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { #ifdef ENABLE_TEST_ALL @@ -33535,9 +33575,49 @@ result_t test_vqdmlal_laneq_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { #endif // ENABLE_TEST_ALL } -result_t test_vqdmlalh_laneq_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } +result_t test_vqdmlalh_laneq_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const int32_t *_a = (int32_t *)impl.test_cases_int_pointer1; + const int16_t *_b = (int16_t *)impl.test_cases_int_pointer2; + const int16_t *_c = (int16_t *)impl.test_cases_int_pointer3; + int32_t _d, d; + int16x8_t c = vld1q_s16(_c); + +#define TEST_IMPL(IDX) \ + _d = sat_add(_a[0], sat_dmull(_b[0], _c[IDX])); \ + d = vqdmlalh_laneq_s16(_a[0], _b[0], c, IDX); \ + CHECK_RESULT(d == _d ? TEST_SUCCESS : TEST_FAIL) + + IMM_8_ITER +#undef TEST_IMPL + + return TEST_SUCCESS; +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} -result_t test_vqdmlals_laneq_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } +result_t test_vqdmlals_laneq_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const int64_t *_a = (int64_t *)impl.test_cases_int_pointer1; + const int32_t *_b = (int32_t *)impl.test_cases_int_pointer2; + const int32_t *_c = (int32_t *)impl.test_cases_int_pointer3; + int64_t _d, d; + int32x4_t c = vld1q_s32(_c); + +#define TEST_IMPL(IDX) \ + _d = sat_add(_a[0], sat_dmull(_b[0], _c[IDX])); \ + d = vqdmlals_laneq_s32(_a[0], _b[0], c, IDX); \ + CHECK_RESULT(d == _d ? TEST_SUCCESS : TEST_FAIL) + + IMM_4_ITER +#undef TEST_IMPL + + return TEST_SUCCESS; +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} result_t test_vqdmlal_high_laneq_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { #ifdef ENABLE_TEST_ALL diff --git a/tests/impl.h b/tests/impl.h index 1919bdf6..fe58a6fc 100644 --- a/tests/impl.h +++ b/tests/impl.h @@ -1948,14 +1948,14 @@ _(vmlal_high_laneq_u32) \ _(vqdmlal_lane_s16) \ _(vqdmlal_lane_s32) \ - /*_(vqdmlalh_lane_s16) */ \ - /*_(vqdmlals_lane_s32) */ \ + _(vqdmlalh_lane_s16) \ + _(vqdmlals_lane_s32) \ _(vqdmlal_high_lane_s16) \ _(vqdmlal_high_lane_s32) \ _(vqdmlal_laneq_s16) \ _(vqdmlal_laneq_s32) \ - /*_(vqdmlalh_laneq_s16) */ \ - /*_(vqdmlals_laneq_s32) */ \ + _(vqdmlalh_laneq_s16) \ + _(vqdmlals_laneq_s32) \ _(vqdmlal_high_laneq_s16) \ _(vqdmlal_high_laneq_s32) \ _(vmls_lane_s16) \