Skip to content

Commit

Permalink
feat: Add _mm_madd[|ubs]_[e]pi16
Browse files Browse the repository at this point in the history
  • Loading branch information
howjmay committed Jan 27, 2024
1 parent 7a0e236 commit 4185564
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 89 deletions.
36 changes: 32 additions & 4 deletions sse2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -1911,11 +1911,39 @@ FORCE_INLINE __m128i _mm_loadu_si64(void const *mem_addr) {
return vreinterpretq_i64_m128i(__riscv_vslideup_vx_i64m1(zeros, ld, 0, 1));
}

// FORCE_INLINE __m128i _mm_madd_epi16 (__m128i a, __m128i b) {}

// FORCE_INLINE __m128i _mm_maddubs_epi16 (__m128i a, __m128i b) {}
FORCE_INLINE __m128i _mm_madd_epi16(__m128i a, __m128i b) {
vint16m1_t _a = vreinterpretq_m128i_i16(a);
vint16m1_t _b = vreinterpretq_m128i_i16(b);
vint32m2_t wmul = __riscv_vwmul_vv_i32m2(_a, _b, 8);
vint32m2_t wmul_s = __riscv_vslidedown_vx_i32m2(wmul, 1, 8);
vint32m2_t wmul_add = __riscv_vadd_vv_i32m2(wmul, wmul_s, 8);
return vreinterpretq_i32_m128i(__riscv_vnsra_wx_i32m1(
__riscv_vreinterpret_v_i32m2_i64m2(wmul_add), 0, 4));
}

FORCE_INLINE __m128i _mm_maddubs_epi16(__m128i a, __m128i b) {
vint16m2_t _a = __riscv_vreinterpret_v_u16m2_i16m2(
__riscv_vzext_vf2_u16m2(vreinterpretq_m128i_u8(a), 16));
vint16m2_t _b = __riscv_vsext_vf2_i16m2(vreinterpretq_m128i_i8(b), 16);
vint16m2_t mul = __riscv_vmul_vv_i16m2(_a, _b, 16);
vint16m2_t mul_s = __riscv_vslidedown_vx_i16m2(mul, 1, 16);
vint32m4_t mul_add = __riscv_vwadd_vv_i32m4(mul, mul_s, 16);
vint16m2_t sat = __riscv_vnclip_wx_i16m2(mul_add, 0, __RISCV_VXRM_RDN, 16);
return vreinterpretq_i16_m128i(
__riscv_vnsra_wx_i16m1(__riscv_vreinterpret_v_i16m2_i32m2(sat), 0, 16));
}

// FORCE_INLINE __m64 _mm_maddubs_pi16 (__m64 a, __m64 b) {}
FORCE_INLINE __m64 _mm_maddubs_pi16(__m64 a, __m64 b) {
vint16m2_t _a = __riscv_vreinterpret_v_u16m2_i16m2(
__riscv_vzext_vf2_u16m2(vreinterpretq_m128i_u8(a), 8));
vint16m2_t _b = __riscv_vsext_vf2_i16m2(vreinterpretq_m128i_i8(b), 8);
vint16m2_t mul = __riscv_vmul_vv_i16m2(_a, _b, 8);
vint16m2_t mul_s = __riscv_vslidedown_vx_i16m2(mul, 1, 8);
vint32m4_t mul_add = __riscv_vwadd_vv_i32m4(mul, mul_s, 8);
vint16m2_t sat = __riscv_vnclip_wx_i16m2(mul_add, 0, __RISCV_VXRM_RDN, 8);
return vreinterpretq_i16_m128i(
__riscv_vnsra_wx_i16m1(__riscv_vreinterpret_v_i16m2_i32m2(sat), 0, 8));
}

FORCE_INLINE void *_mm_malloc(size_t size, size_t align) {
void *ptr;
Expand Down
143 changes: 58 additions & 85 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5896,31 +5896,26 @@ result_t test_mm_loadu_si32(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

result_t test_mm_madd_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1;
// const int16_t *_b = (const int16_t *)impl.test_cases_int_pointer2;
// int32_t d0 = (int32_t)_a[0] * _b[0];
// int32_t d1 = (int32_t)_a[1] * _b[1];
// int32_t d2 = (int32_t)_a[2] * _b[2];
// int32_t d3 = (int32_t)_a[3] * _b[3];
// int32_t d4 = (int32_t)_a[4] * _b[4];
// int32_t d5 = (int32_t)_a[5] * _b[5];
// int32_t d6 = (int32_t)_a[6] * _b[6];
// int32_t d7 = (int32_t)_a[7] * _b[7];
//
// int32_t e[4];
// e[0] = d0 + d1;
// e[1] = d2 + d3;
// e[2] = d4 + d5;
// e[3] = d6 + d7;
//
// __m128i a = load_m128i(_a);
// __m128i b = load_m128i(_b);
// __m128i c = _mm_madd_epi16(a, b);
// return VALIDATE_INT32_M128(c, e);
// #else
#ifdef ENABLE_TEST_ALL
const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1;
const int16_t *_b = (const int16_t *)impl.test_cases_int_pointer2;
int32_t __c[8];
for (int i = 0; i < 8; i++) {
__c[i] = (int32_t)_a[i] * _b[i];
}

int32_t _c[4];
for (int i = 0; i < 4; i++) {
_c[i] = __c[2 * i] + __c[2 * i + 1];
}

__m128i a = load_m128i(_a);
__m128i b = load_m128i(_b);
__m128i c = _mm_madd_epi16(a, b);
return VALIDATE_INT32_M128(c, _c);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_maskmoveu_si128(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down Expand Up @@ -8997,72 +8992,50 @@ result_t test_mm_hsubs_pi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

result_t test_mm_maddubs_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const uint8_t *_a = (const uint8_t *)impl.test_cases_int_pointer1;
// const int8_t *_b = (const int8_t *)impl.test_cases_int_pointer2;
// int32_t d0 = (int32_t)(_a[0] * _b[0]);
// int32_t d1 = (int32_t)(_a[1] * _b[1]);
// int32_t d2 = (int32_t)(_a[2] * _b[2]);
// int32_t d3 = (int32_t)(_a[3] * _b[3]);
// int32_t d4 = (int32_t)(_a[4] * _b[4]);
// int32_t d5 = (int32_t)(_a[5] * _b[5]);
// int32_t d6 = (int32_t)(_a[6] * _b[6]);
// int32_t d7 = (int32_t)(_a[7] * _b[7]);
// int32_t d8 = (int32_t)(_a[8] * _b[8]);
// int32_t d9 = (int32_t)(_a[9] * _b[9]);
// int32_t d10 = (int32_t)(_a[10] * _b[10]);
// int32_t d11 = (int32_t)(_a[11] * _b[11]);
// int32_t d12 = (int32_t)(_a[12] * _b[12]);
// int32_t d13 = (int32_t)(_a[13] * _b[13]);
// int32_t d14 = (int32_t)(_a[14] * _b[14]);
// int32_t d15 = (int32_t)(_a[15] * _b[15]);
//
// int16_t e[8];
// e[0] = saturate_16(d0 + d1);
// e[1] = saturate_16(d2 + d3);
// e[2] = saturate_16(d4 + d5);
// e[3] = saturate_16(d6 + d7);
// e[4] = saturate_16(d8 + d9);
// e[5] = saturate_16(d10 + d11);
// e[6] = saturate_16(d12 + d13);
// e[7] = saturate_16(d14 + d15);
//
// __m128i a = load_m128i(_a);
// __m128i b = load_m128i(_b);
// __m128i c = _mm_maddubs_epi16(a, b);
// return VALIDATE_INT16_M128(c, e);
// #else
#ifdef ENABLE_TEST_ALL
const uint8_t *_a = (const uint8_t *)impl.test_cases_int_pointer1;
const int8_t *_b = (const int8_t *)impl.test_cases_int_pointer2;
int32_t __c[16];
for (int i = 0; i < 16; i++) {
__c[i] = (int32_t)(_a[i] * _b[i]);
}

int16_t _c[8];
for (int i = 0; i < 8; i++) {
_c[i] = saturate_16(__c[2 * i] + __c[2 * i + 1]);
}

__m128i a = load_m128i(_a);
__m128i b = load_m128i(_b);
__m128i c = _mm_maddubs_epi16(a, b);
return VALIDATE_INT16_M128(c, _c);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_maddubs_pi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
// #ifdef ENABLE_TEST_ALL
// const uint8_t *_a = (const uint8_t *)impl.test_cases_int_pointer1;
// const int8_t *_b = (const int8_t *)impl.test_cases_int_pointer2;
// int16_t d0 = (int16_t)(_a[0] * _b[0]);
// int16_t d1 = (int16_t)(_a[1] * _b[1]);
// int16_t d2 = (int16_t)(_a[2] * _b[2]);
// int16_t d3 = (int16_t)(_a[3] * _b[3]);
// int16_t d4 = (int16_t)(_a[4] * _b[4]);
// int16_t d5 = (int16_t)(_a[5] * _b[5]);
// int16_t d6 = (int16_t)(_a[6] * _b[6]);
// int16_t d7 = (int16_t)(_a[7] * _b[7]);
//
// int16_t e[4];
// e[0] = saturate_16(d0 + d1);
// e[1] = saturate_16(d2 + d3);
// e[2] = saturate_16(d4 + d5);
// e[3] = saturate_16(d6 + d7);
//
// __m64 a = load_m64(_a);
// __m64 b = load_m64(_b);
// __m64 c = _mm_maddubs_pi16(a, b);
//
// return VALIDATE_INT16_M64(c, e);
// #else
#ifdef ENABLE_TEST_ALL
const uint8_t *_a = (const uint8_t *)impl.test_cases_int_pointer1;
const int8_t *_b = (const int8_t *)impl.test_cases_int_pointer2;
int32_t __c[8];
for (int i = 0; i < 8; i++) {
__c[i] = (int32_t)(_a[i] * _b[i]);
}

int16_t _c[4];
for (int i = 0; i < 4; i++) {
_c[i] = saturate_16(__c[2 * i] + __c[2 * i + 1]);
}

__m64 a = load_m64(_a);
__m64 b = load_m64(_b);
__m64 c = _mm_maddubs_pi16(a, b);

return VALIDATE_INT16_M64(c, _c);
#else
return TEST_UNIMPL;
// #endif // ENABLE_TEST_ALL
#endif // ENABLE_TEST_ALL
}

result_t test_mm_mulhrs_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down

0 comments on commit 4185564

Please sign in to comment.