diff --git a/sse2rvv.h b/sse2rvv.h
index 4728f0b..2b538f5 100644
--- a/sse2rvv.h
+++ b/sse2rvv.h
@@ -2438,13 +2438,45 @@ FORCE_INLINE __m128 _mm_sqrt_ss(__m128 a) {
   return vreinterpretq_f32_m128(__riscv_vslideup_vx_f32m1(_a, rnd, 0, 1));
 }
 
-// FORCE_INLINE __m128i _mm_sra_epi16 (__m128i a, __m128i count) {}
+FORCE_INLINE __m128i _mm_sra_epi16(__m128i a, __m128i count) {
+  vint16m1_t _a = vreinterpretq_m128i_i16(a);
+  vint64m1_t _count = vreinterpretq_m128i_i64(count);
+  // SSE treats the low 64 bits of count as an unsigned shift amount, and any
+  // amount above 15 fills each lane with its sign bit. vsra.vx only reads the
+  // low log2(SEW) bits of its scalar operand, so saturate the amount instead
+  // of letting it wrap.
+  uint64_t shift = (uint64_t)__riscv_vmv_x_s_i64m1_i64(_count);
+  if (shift > 15) {
+    shift = 15;
+  }
+  return vreinterpretq_i16_m128i(__riscv_vsra_vx_i16m1(_a, shift, 8));
+}
 
-// FORCE_INLINE __m128i _mm_sra_epi32 (__m128i a, __m128i count) {}
+FORCE_INLINE __m128i _mm_sra_epi32(__m128i a, __m128i count) {
+  vint32m1_t _a = vreinterpretq_m128i_i32(a);
+  vint64m1_t _count = vreinterpretq_m128i_i64(count);
+  // Same saturation as _mm_sra_epi16, with 31 as the largest useful shift.
+  uint64_t shift = (uint64_t)__riscv_vmv_x_s_i64m1_i64(_count);
+  if (shift > 31) {
+    shift = 31;
+  }
+  return vreinterpretq_i32_m128i(__riscv_vsra_vx_i32m1(_a, shift, 4));
+}
 
-// FORCE_INLINE __m128i _mm_srai_epi16 (__m128i a, int imm8) {}
+FORCE_INLINE __m128i _mm_srai_epi16(__m128i a, int imm8) {
+  vint16m1_t _a = vreinterpretq_m128i_i16(a);
+  // Any bit set outside the low four (including a negative imm8) means the
+  // shift is out of range; 15 reproduces the required sign fill.
+  const int shift = (imm8 & ~15) ? 15 : imm8;
+  return vreinterpretq_i16_m128i(__riscv_vsra_vx_i16m1(_a, shift, 8));
+}
 
-// FORCE_INLINE __m128i _mm_srai_epi32 (__m128i a, int imm8) {}
+FORCE_INLINE __m128i _mm_srai_epi32(__m128i a, int imm8) {
+  vint32m1_t _a = vreinterpretq_m128i_i32(a);
+  // Out-of-range immediates saturate to 31, which yields the sign fill.
+  const int shift = (imm8 & ~31) ? 31 : imm8;
+  return vreinterpretq_i32_m128i(__riscv_vsra_vx_i32m1(_a, shift, 4));
+}
 
 // FORCE_INLINE __m128i _mm_srl_epi16 (__m128i a, __m128i count) {}
 
diff --git a/tests/impl.cpp b/tests/impl.cpp
index 06ab316..bf48934 100644
--- a/tests/impl.cpp
+++ b/tests/impl.cpp
@@ -7022,89 +7022,89 @@ result_t test_mm_sqrt_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
 }
 
 result_t test_mm_sra_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  //   const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1;
-  //   const int64_t count = (int64_t)(iter % 18 - 1); // range: -1 ~ 16
-  //
-  //   int16_t d[8];
-  //   d[0] = (count & ~15) ? (_a[0] < 0 ? ~UINT16_C(0) : 0) : (_a[0] >> count);
-  //   d[1] = (count & ~15) ? (_a[1] < 0 ? ~UINT16_C(0) : 0) : (_a[1] >> count);
-  //   d[2] = (count & ~15) ? (_a[2] < 0 ? ~UINT16_C(0) : 0) : (_a[2] >> count);
-  //   d[3] = (count & ~15) ? (_a[3] < 0 ? ~UINT16_C(0) : 0) : (_a[3] >> count);
-  //   d[4] = (count & ~15) ? (_a[4] < 0 ? ~UINT16_C(0) : 0) : (_a[4] >> count);
-  //   d[5] = (count & ~15) ? (_a[5] < 0 ? ~UINT16_C(0) : 0) : (_a[5] >> count);
-  //   d[6] = (count & ~15) ? (_a[6] < 0 ? ~UINT16_C(0) : 0) : (_a[6] >> count);
-  //   d[7] = (count & ~15) ? (_a[7] < 0 ? ~UINT16_C(0) : 0) : (_a[7] >> count);
-  //
-  //   __m128i a = _mm_load_si128((const __m128i *)_a);
-  //   __m128i b = _mm_set1_epi64x(count);
-  //   __m128i c = _mm_sra_epi16(a, b);
-  //
-  //   return VALIDATE_INT16_M128(c, d);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1;
+  const int64_t count = (int64_t)(iter % 18 - 1); // range: -1 ~ 16
+
+  int16_t d[8];
+  d[0] = (count & ~15) ? (_a[0] < 0 ? ~UINT16_C(0) : 0) : (_a[0] >> count);
+  d[1] = (count & ~15) ? (_a[1] < 0 ? ~UINT16_C(0) : 0) : (_a[1] >> count);
+  d[2] = (count & ~15) ? (_a[2] < 0 ? ~UINT16_C(0) : 0) : (_a[2] >> count);
+  d[3] = (count & ~15) ? (_a[3] < 0 ? ~UINT16_C(0) : 0) : (_a[3] >> count);
+  d[4] = (count & ~15) ? (_a[4] < 0 ? ~UINT16_C(0) : 0) : (_a[4] >> count);
+  d[5] = (count & ~15) ? (_a[5] < 0 ? ~UINT16_C(0) : 0) : (_a[5] >> count);
+  d[6] = (count & ~15) ? (_a[6] < 0 ? ~UINT16_C(0) : 0) : (_a[6] >> count);
+  d[7] = (count & ~15) ? (_a[7] < 0 ? ~UINT16_C(0) : 0) : (_a[7] >> count);
+
+  __m128i a = _mm_load_si128((const __m128i *)_a);
+  __m128i b = _mm_set1_epi64x(count);
+  __m128i c = _mm_sra_epi16(a, b);
+
+  return VALIDATE_INT16_M128(c, d);
+#else
   return TEST_UNIMPL;
-  // #endif // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_sra_epi32(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  //   const int32_t *_a = (const int32_t *)impl.test_cases_int_pointer1;
-  //   const int64_t count = (int64_t)(iter % 34 - 1); // range: -1 ~ 32
-  //
-  //   int32_t d[4];
-  //   d[0] = (count & ~31) ? (_a[0] < 0 ? ~UINT32_C(0) : 0) : _a[0] >> count;
-  //   d[1] = (count & ~31) ? (_a[1] < 0 ? ~UINT32_C(0) : 0) : _a[1] >> count;
-  //   d[2] = (count & ~31) ? (_a[2] < 0 ? ~UINT32_C(0) : 0) : _a[2] >> count;
-  //   d[3] = (count & ~31) ? (_a[3] < 0 ? ~UINT32_C(0) : 0) : _a[3] >> count;
-  //
-  //   __m128i a = _mm_load_si128((const __m128i *)_a);
-  //   __m128i b = _mm_set1_epi64x(count);
-  //   __m128i c = _mm_sra_epi32(a, b);
-  //
-  //   return VALIDATE_INT32_M128(c, d);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const int32_t *_a = (const int32_t *)impl.test_cases_int_pointer1;
+  const int64_t count = (int64_t)(iter % 34 - 1); // range: -1 ~ 32
+
+  int32_t d[4];
+  d[0] = (count & ~31) ? (_a[0] < 0 ? ~UINT32_C(0) : 0) : _a[0] >> count;
+  d[1] = (count & ~31) ? (_a[1] < 0 ? ~UINT32_C(0) : 0) : _a[1] >> count;
+  d[2] = (count & ~31) ? (_a[2] < 0 ? ~UINT32_C(0) : 0) : _a[2] >> count;
+  d[3] = (count & ~31) ? (_a[3] < 0 ? ~UINT32_C(0) : 0) : _a[3] >> count;
+
+  __m128i a = _mm_load_si128((const __m128i *)_a);
+  __m128i b = _mm_set1_epi64x(count);
+  __m128i c = _mm_sra_epi32(a, b);
+
+  return VALIDATE_INT32_M128(c, d);
+#else
   return TEST_UNIMPL;
-  // #endif // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_srai_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  //   const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1;
-  //   const int32_t b = (int32_t)(iter % 18 - 1); // range: -1 ~ 16
-  //   int16_t d[8];
-  //   int count = (b & ~15) ? 15 : b;
-  //
-  //   for (int i = 0; i < 8; i++) {
-  //     d[i] = _a[i] >> count;
-  //   }
-  //
-  //   __m128i a = _mm_load_si128((const __m128i *)_a);
-  //   __m128i c = _mm_srai_epi16(a, b);
-  //
-  //   return VALIDATE_INT16_M128(c, d);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const int16_t *_a = (const int16_t *)impl.test_cases_int_pointer1;
+  const int32_t b = (int32_t)(iter % 18 - 1); // range: -1 ~ 16
+  int16_t d[8];
+  int count = (b & ~15) ? 15 : b;
+
+  for (int i = 0; i < 8; i++) {
+    d[i] = _a[i] >> count;
+  }
+
+  __m128i a = _mm_load_si128((const __m128i *)_a);
+  __m128i c = _mm_srai_epi16(a, b);
+
+  return VALIDATE_INT16_M128(c, d);
+#else
   return TEST_UNIMPL;
-  // #endif // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_srai_epi32(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  //   const int32_t *_a = (const int32_t *)impl.test_cases_int_pointer1;
-  //   const int32_t b = (int32_t)(iter % 34 - 1); // range: -1 ~ 32
-  //
-  //   int32_t d[4];
-  //   int count = (b & ~31) ? 31 : b;
-  //   for (int i = 0; i < 4; i++) {
-  //     d[i] = _a[i] >> count;
-  //   }
-  //
-  //   __m128i a = _mm_load_si128((const __m128i *)_a);
-  //   __m128i c = _mm_srai_epi32(a, b);
-  //
-  //   return VALIDATE_INT32_M128(c, d);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const int32_t *_a = (const int32_t *)impl.test_cases_int_pointer1;
+  const int32_t b = (int32_t)(iter % 34 - 1); // range: -1 ~ 32
+
+  int32_t d[4];
+  int count = (b & ~31) ? 31 : b;
+  for (int i = 0; i < 4; i++) {
+    d[i] = _a[i] >> count;
+  }
+
+  __m128i a = _mm_load_si128((const __m128i *)_a);
+  __m128i c = _mm_srai_epi32(a, b);
+
+  return VALIDATE_INT32_M128(c, d);
+#else
   return TEST_UNIMPL;
-  // #endif // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_srl_epi16(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {