From 4185564e81c906c8d113e20e56b3052fb7a3a862 Mon Sep 17 00:00:00 2001 From: Yang Hau Date: Fri, 26 Jan 2024 23:35:11 +0800 Subject: [PATCH] feat: Add _mm_madd[|ubs]_[e]pi16 --- sse2rvv.h | 36 +++++++++++-- tests/impl.cpp | 143 ++++++++++++++++++++----------------------------- 2 files changed, 90 insertions(+), 89 deletions(-) diff --git a/sse2rvv.h b/sse2rvv.h index 766fa55..6f4fad0 100644 --- a/sse2rvv.h +++ b/sse2rvv.h @@ -1911,11 +1911,39 @@ FORCE_INLINE __m128i _mm_loadu_si64(void const *mem_addr) { return vreinterpretq_i64_m128i(__riscv_vslideup_vx_i64m1(zeros, ld, 0, 1)); } -// FORCE_INLINE __m128i _mm_madd_epi16 (__m128i a, __m128i b) {} - -// FORCE_INLINE __m128i _mm_maddubs_epi16 (__m128i a, __m128i b) {} +FORCE_INLINE __m128i _mm_madd_epi16(__m128i a, __m128i b) { + vint16m1_t _a = vreinterpretq_m128i_i16(a); + vint16m1_t _b = vreinterpretq_m128i_i16(b); + vint32m2_t wmul = __riscv_vwmul_vv_i32m2(_a, _b, 8); + vint32m2_t wmul_s = __riscv_vslidedown_vx_i32m2(wmul, 1, 8); + vint32m2_t wmul_add = __riscv_vadd_vv_i32m2(wmul, wmul_s, 8); + return vreinterpretq_i32_m128i(__riscv_vnsra_wx_i32m1( + __riscv_vreinterpret_v_i32m2_i64m2(wmul_add), 0, 4)); +} + +FORCE_INLINE __m128i _mm_maddubs_epi16(__m128i a, __m128i b) { + vint16m2_t _a = __riscv_vreinterpret_v_u16m2_i16m2( + __riscv_vzext_vf2_u16m2(vreinterpretq_m128i_u8(a), 16)); + vint16m2_t _b = __riscv_vsext_vf2_i16m2(vreinterpretq_m128i_i8(b), 16); + vint16m2_t mul = __riscv_vmul_vv_i16m2(_a, _b, 16); + vint16m2_t mul_s = __riscv_vslidedown_vx_i16m2(mul, 1, 16); + vint32m4_t mul_add = __riscv_vwadd_vv_i32m4(mul, mul_s, 16); + vint16m2_t sat = __riscv_vnclip_wx_i16m2(mul_add, 0, __RISCV_VXRM_RDN, 16); + return vreinterpretq_i16_m128i( + __riscv_vnsra_wx_i16m1(__riscv_vreinterpret_v_i16m2_i32m2(sat), 0, 16)); +} -// FORCE_INLINE __m64 _mm_maddubs_pi16 (__m64 a, __m64 b) {} +FORCE_INLINE __m64 _mm_maddubs_pi16(__m64 a, __m64 b) { + vint16m2_t _a = __riscv_vreinterpret_v_u16m2_i16m2( + __riscv_vzext_vf2_u16m2(vreinterpretq_m128i_u8(a), 8)); + vint16m2_t _b = __riscv_vsext_vf2_i16m2(vreinterpretq_m128i_i8(b), 8); + vint16m2_t mul = __riscv_vmul_vv_i16m2(_a, _b, 8); + vint16m2_t mul_s = __riscv_vslidedown_vx_i16m2(mul, 1, 8); + vint32m4_t mul_add = __riscv_vwadd_vv_i32m4(mul, mul_s, 8); + vint16m2_t sat = __riscv_vnclip_wx_i16m2(mul_add, 0, __RISCV_VXRM_RDN, 8); + return vreinterpretq_i16_m128i( + __riscv_vnsra_wx_i16m1(__riscv_vreinterpret_v_i16m2_i32m2(sat), 0, 8)); +} FORCE_INLINE void *_mm_malloc(size_t size, size_t align) { void *ptr; diff --git a/tests/impl.cpp b/tests/impl.cpp index 30a80c1..072bd4b 100644 --- a/tests/impl.cpp +++ b/tests/impl.cpp @@ -5896,31 +5896,26 @@ result_t test_mm_loadu_si32(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { } result_t test_mm_madd_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { - // #ifdef ENABLE_TEST_ALL - // const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1; - // const int16_t *_b = (const int16_t *)impl.test_cases_int_pointer2; - // int32_t d0 = (int32_t)_a[0] * _b[0]; - // int32_t d1 = (int32_t)_a[1] * _b[1]; - // int32_t d2 = (int32_t)_a[2] * _b[2]; - // int32_t d3 = (int32_t)_a[3] * _b[3]; - // int32_t d4 = (int32_t)_a[4] * _b[4]; - // int32_t d5 = (int32_t)_a[5] * _b[5]; - // int32_t d6 = (int32_t)_a[6] * _b[6]; - // int32_t d7 = (int32_t)_a[7] * _b[7]; - // - // int32_t e[4]; - // e[0] = d0 + d1; - // e[1] = d2 + d3; - // e[2] = d4 + d5; - // e[3] = d6 + d7; - // - // __m128i a = load_m128i(_a); - // __m128i b = load_m128i(_b); - // __m128i c = _mm_madd_epi16(a, b); - // return VALIDATE_INT32_M128(c, e); - // #else +#ifdef ENABLE_TEST_ALL + const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1; + const int16_t *_b = (const int16_t *)impl.test_cases_int_pointer2; + int32_t __c[8]; + for (int i = 0; i < 8; i++) { + __c[i] = (int32_t)_a[i] * _b[i]; + } + + int32_t _c[4]; + for (int i = 0; i < 4; i++) { + _c[i] = __c[2 * i] + __c[2 * i + 1]; + } + + __m128i a = load_m128i(_a); + __m128i b = load_m128i(_b); + __m128i c = _mm_madd_epi16(a, b); + return VALIDATE_INT32_M128(c, _c); +#else return TEST_UNIMPL; - // #endif // ENABLE_TEST_ALL +#endif // ENABLE_TEST_ALL } result_t test_mm_maskmoveu_si128(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { @@ -8997,72 +8992,50 @@ result_t test_mm_hsubs_pi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { } result_t test_mm_maddubs_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { - // #ifdef ENABLE_TEST_ALL - // const uint8_t *_a = (const uint8_t *)impl.test_cases_int_pointer1; - // const int8_t *_b = (const int8_t *)impl.test_cases_int_pointer2; - // int32_t d0 = (int32_t)(_a[0] * _b[0]); - // int32_t d1 = (int32_t)(_a[1] * _b[1]); - // int32_t d2 = (int32_t)(_a[2] * _b[2]); - // int32_t d3 = (int32_t)(_a[3] * _b[3]); - // int32_t d4 = (int32_t)(_a[4] * _b[4]); - // int32_t d5 = (int32_t)(_a[5] * _b[5]); - // int32_t d6 = (int32_t)(_a[6] * _b[6]); - // int32_t d7 = (int32_t)(_a[7] * _b[7]); - // int32_t d8 = (int32_t)(_a[8] * _b[8]); - // int32_t d9 = (int32_t)(_a[9] * _b[9]); - // int32_t d10 = (int32_t)(_a[10] * _b[10]); - // int32_t d11 = (int32_t)(_a[11] * _b[11]); - // int32_t d12 = (int32_t)(_a[12] * _b[12]); - // int32_t d13 = (int32_t)(_a[13] * _b[13]); - // int32_t d14 = (int32_t)(_a[14] * _b[14]); - // int32_t d15 = (int32_t)(_a[15] * _b[15]); - // - // int16_t e[8]; - // e[0] = saturate_16(d0 + d1); - // e[1] = saturate_16(d2 + d3); - // e[2] = saturate_16(d4 + d5); - // e[3] = saturate_16(d6 + d7); - // e[4] = saturate_16(d8 + d9); - // e[5] = saturate_16(d10 + d11); - // e[6] = saturate_16(d12 + d13); - // e[7] = saturate_16(d14 + d15); - // - // __m128i a = load_m128i(_a); - // __m128i b = load_m128i(_b); - // __m128i c = _mm_maddubs_epi16(a, b); - // return VALIDATE_INT16_M128(c, e); - // #else +#ifdef ENABLE_TEST_ALL + const uint8_t *_a = (const uint8_t *)impl.test_cases_int_pointer1; + const int8_t *_b = (const int8_t *)impl.test_cases_int_pointer2; + int32_t __c[16]; + for (int i = 0; i < 16; i++) { + __c[i] = (int32_t)(_a[i] * _b[i]); + } + + int16_t _c[8]; + for (int i = 0; i < 8; i++) { + _c[i] = saturate_16(__c[2 * i] + __c[2 * i + 1]); + } + + __m128i a = load_m128i(_a); + __m128i b = load_m128i(_b); + __m128i c = _mm_maddubs_epi16(a, b); + return VALIDATE_INT16_M128(c, _c); +#else return TEST_UNIMPL; - // #endif // ENABLE_TEST_ALL +#endif // ENABLE_TEST_ALL } result_t test_mm_maddubs_pi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) { - // #ifdef ENABLE_TEST_ALL - // const uint8_t *_a = (const uint8_t *)impl.test_cases_int_pointer1; - // const int8_t *_b = (const int8_t *)impl.test_cases_int_pointer2; - // int16_t d0 = (int16_t)(_a[0] * _b[0]); - // int16_t d1 = (int16_t)(_a[1] * _b[1]); - // int16_t d2 = (int16_t)(_a[2] * _b[2]); - // int16_t d3 = (int16_t)(_a[3] * _b[3]); - // int16_t d4 = (int16_t)(_a[4] * _b[4]); - // int16_t d5 = (int16_t)(_a[5] * _b[5]); - // int16_t d6 = (int16_t)(_a[6] * _b[6]); - // int16_t d7 = (int16_t)(_a[7] * _b[7]); - // - // int16_t e[4]; - // e[0] = saturate_16(d0 + d1); - // e[1] = saturate_16(d2 + d3); - // e[2] = saturate_16(d4 + d5); - // e[3] = saturate_16(d6 + d7); - // - // __m64 a = load_m64(_a); - // __m64 b = load_m64(_b); - // __m64 c = _mm_maddubs_pi16(a, b); - // - // return VALIDATE_INT16_M64(c, e); - // #else +#ifdef ENABLE_TEST_ALL + const uint8_t *_a = (const uint8_t *)impl.test_cases_int_pointer1; + const int8_t *_b = (const int8_t *)impl.test_cases_int_pointer2; + int32_t __c[8]; + for (int i = 0; i < 8; i++) { + __c[i] = (int32_t)(_a[i] * _b[i]); + } + + int16_t _c[4]; + for (int i = 0; i < 4; i++) { + _c[i] = saturate_16(__c[2 * i] + __c[2 * i + 1]); + } + + __m64 a = load_m64(_a); + __m64 b = load_m64(_b); + __m64 c = _mm_maddubs_pi16(a, b); + + return VALIDATE_INT16_M64(c, _c); +#else return TEST_UNIMPL; - // #endif // ENABLE_TEST_ALL +#endif // ENABLE_TEST_ALL } result_t test_mm_mulhrs_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {