Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add _mm_hadd* and _mm_hsub* #40

Merged
merged 1 commit into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 155 additions & 19 deletions sse2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -1456,37 +1456,173 @@ FORCE_INLINE int _mm_extract_ps(__m128 a, const int imm8) {

// FORCE_INLINE unsigned int _mm_getcsr (void) {}

// FORCE_INLINE __m128i _mm_hadd_epi16 (__m128i a, __m128i b) {}

// FORCE_INLINE __m128i _mm_hadd_epi32 (__m128i a, __m128i b) {}

// FORCE_INLINE __m128d _mm_hadd_pd (__m128d a, __m128d b) {}

// FORCE_INLINE __m64 _mm_hadd_pi16 (__m64 a, __m64 b) {}
// Horizontally add adjacent pairs of signed 16-bit integers in a and b
// (SSSE3 _mm_hadd_epi16): result = {a0+a1, ..., a6+a7, b0+b1, ..., b6+b7}.
FORCE_INLINE __m128i _mm_hadd_epi16(__m128i a, __m128i b) {
  vint16m2_t va = __riscv_vlmul_ext_v_i16m1_i16m2(vreinterpretq_m128i_i16(a));
  vint16m2_t vb = __riscv_vlmul_ext_v_i16m1_i16m2(vreinterpretq_m128i_i16(b));
  // cat = [a0..a7, b0..b7]; next = cat slid down one lane so that lane i of
  // the sum holds cat[i] + cat[i+1].
  vint16m2_t cat = __riscv_vslideup_vx_i16m2(va, vb, 8, 16);
  vint16m2_t next = __riscv_vslidedown_vx_i16m2(cat, 1, 16);
  vint32m2_t sums = __riscv_vreinterpret_v_i16m2_i32m2(
      __riscv_vadd_vv_i16m2(cat, next, 16));
  // Narrowing shift by 0 keeps the low 16 bits of each 32-bit lane, i.e. the
  // even-indexed pair sums, packed into a single m1 register.
  return vreinterpretq_i16_m128i(__riscv_vnsra_wx_i16m1(sums, 0, 8));
}

// Horizontally add adjacent pairs of signed 32-bit integers in a and b
// (SSSE3 _mm_hadd_epi32): result = {a0+a1, a2+a3, b0+b1, b2+b3}.
FORCE_INLINE __m128i _mm_hadd_epi32(__m128i a, __m128i b) {
  vint32m2_t va = __riscv_vlmul_ext_v_i32m1_i32m2(vreinterpretq_m128i_i32(a));
  vint32m2_t vb = __riscv_vlmul_ext_v_i32m1_i32m2(vreinterpretq_m128i_i32(b));
  // cat = [a0..a3, b0..b3]; next pairs each lane with its right neighbor.
  vint32m2_t cat = __riscv_vslideup_vx_i32m2(va, vb, 4, 8);
  vint32m2_t next = __riscv_vslidedown_vx_i32m2(cat, 1, 8);
  vint64m2_t sums = __riscv_vreinterpret_v_i32m2_i64m2(
      __riscv_vadd_vv_i32m2(cat, next, 8));
  // Keep the low 32 bits of each 64-bit lane: the even-indexed pair sums.
  return vreinterpretq_i32_m128i(__riscv_vnsra_wx_i32m1(sums, 0, 4));
}

// Horizontally add pairs of doubles (SSE3 _mm_hadd_pd): {a0+a1, b0+b1}.
FORCE_INLINE __m128d _mm_hadd_pd(__m128d a, __m128d b) {
  vfloat64m2_t va =
      __riscv_vlmul_ext_v_f64m1_f64m2(vreinterpretq_m128d_f64(a));
  vfloat64m2_t vb =
      __riscv_vlmul_ext_v_f64m1_f64m2(vreinterpretq_m128d_f64(b));
  // cat = [a0, a1, b0, b1]; sums[i] = cat[i] + cat[i+1].
  vfloat64m2_t cat = __riscv_vslideup_vx_f64m2(va, vb, 2, 4);
  vfloat64m2_t next = __riscv_vslidedown_vx_f64m2(cat, 1, 4);
  vfloat64m2_t sums = __riscv_vfadd_vv_f64m2(cat, next, 4);
  // 85 == 0b01010101: a mask of the even-indexed lanes, which hold the valid
  // pair sums; vcompress packs them into the low two lanes.
  vbool32_t even =
      __riscv_vreinterpret_v_u8m1_b32(__riscv_vmv_s_x_u8m1(85, 2));
  return vreinterpretq_f64_m128d(__riscv_vlmul_trunc_v_f64m2_f64m1(
      __riscv_vcompress_vm_f64m2(sums, even, 4)));
}

// Horizontally add adjacent pairs of signed 16-bit integers in the 64-bit
// vectors a and b (SSSE3 _mm_hadd_pi16): {a0+a1, a2+a3, b0+b1, b2+b3}.
FORCE_INLINE __m64 _mm_hadd_pi16(__m64 a, __m64 b) {
  vint16m1_t va = vreinterpretq_m64_i16(a);
  vint16m1_t vb = vreinterpretq_m64_i16(b);
  // Pack b's four lanes behind a's, then pair each lane with its neighbor.
  vint16m1_t cat = __riscv_vslideup_vx_i16m1(va, vb, 4, 8);
  vint16m1_t next = __riscv_vslidedown_vx_i16m1(cat, 1, 8);
  vint32m1_t sums = __riscv_vreinterpret_v_i16m1_i32m1(
      __riscv_vadd_vv_i16m1(cat, next, 8));
  // Narrow to the low 16 bits of each 32-bit lane (the even-indexed sums),
  // then widen the mf2 result back to m1 for the __m64 container.
  return vreinterpretq_i16_m64(
      __riscv_vlmul_ext_v_i16mf2_i16m1(__riscv_vnsra_wx_i16mf2(sums, 0, 4)));
}

// FORCE_INLINE __m64 _mm_hadd_pi32 (__m64 a, __m64 b) {}
// Horizontally add pairs of signed 32-bit integers in the 64-bit vectors a
// and b (SSSE3 _mm_hadd_pi32): {a0+a1, b0+b1}.
FORCE_INLINE __m64 _mm_hadd_pi32(__m64 a, __m64 b) {
  vint32m1_t va = vreinterpretq_m64_i32(a);
  vint32m1_t vb = vreinterpretq_m64_i32(b);
  // cat = [a0, a1, b0, b1]; next slides one lane left for neighbor sums.
  vint32m1_t cat = __riscv_vslideup_vx_i32m1(va, vb, 2, 4);
  vint32m1_t next = __riscv_vslidedown_vx_i32m1(cat, 1, 4);
  vint64m1_t sums = __riscv_vreinterpret_v_i32m1_i64m1(
      __riscv_vadd_vv_i32m1(cat, next, 4));
  // Keep the even-indexed sums via a 0-bit narrowing shift, then widen the
  // mf2 result back to m1.
  return vreinterpretq_i32_m64(
      __riscv_vlmul_ext_v_i32mf2_i32m1(__riscv_vnsra_wx_i32mf2(sums, 0, 2)));
}

// FORCE_INLINE __m128 _mm_hadd_ps (__m128 a, __m128 b) {}
// Horizontally add adjacent pairs of floats (SSE3 _mm_hadd_ps):
// {a0+a1, a2+a3, b0+b1, b2+b3}.
FORCE_INLINE __m128 _mm_hadd_ps(__m128 a, __m128 b) {
  vfloat32m2_t va = __riscv_vlmul_ext_v_f32m1_f32m2(vreinterpretq_m128_f32(a));
  vfloat32m2_t vb = __riscv_vlmul_ext_v_f32m1_f32m2(vreinterpretq_m128_f32(b));
  vfloat32m2_t cat = __riscv_vslideup_vx_f32m2(va, vb, 4, 8);
  vfloat32m2_t next = __riscv_vslidedown_vx_f32m2(cat, 1, 8);
  // Add, then reinterpret the float lanes as 64-bit integers so a narrowing
  // shift can extract the even-indexed (valid) sums bit-exactly.
  vint64m2_t sums = __riscv_vreinterpret_v_i32m2_i64m2(
      __riscv_vreinterpret_v_f32m2_i32m2(
          __riscv_vfadd_vv_f32m2(cat, next, 8)));
  return vreinterpretq_i32_m128(__riscv_vnsra_wx_i32m1(sums, 0, 4));
}

// FORCE_INLINE __m128i _mm_hadds_epi16 (__m128i a, __m128i b) {}
// Horizontally add adjacent pairs of signed 16-bit integers with saturation
// (SSSE3 _mm_hadds_epi16); vsadd clamps each sum to [INT16_MIN, INT16_MAX].
FORCE_INLINE __m128i _mm_hadds_epi16(__m128i a, __m128i b) {
  vint16m2_t va = __riscv_vlmul_ext_v_i16m1_i16m2(vreinterpretq_m128i_i16(a));
  vint16m2_t vb = __riscv_vlmul_ext_v_i16m1_i16m2(vreinterpretq_m128i_i16(b));
  // cat = [a0..a7, b0..b7]; next pairs each lane with its right neighbor.
  vint16m2_t cat = __riscv_vslideup_vx_i16m2(va, vb, 8, 16);
  vint16m2_t next = __riscv_vslidedown_vx_i16m2(cat, 1, 16);
  vint32m2_t sums = __riscv_vreinterpret_v_i16m2_i32m2(
      __riscv_vsadd_vv_i16m2(cat, next, 16));
  // Keep the low 16 bits of each 32-bit lane: the even-indexed pair sums.
  return vreinterpretq_i16_m128i(__riscv_vnsra_wx_i16m1(sums, 0, 8));
}

// FORCE_INLINE __m64 _mm_hadds_pi16 (__m64 a, __m64 b) {}
// Horizontally add pairs of signed 16-bit integers in 64-bit vectors with
// saturation (SSSE3 _mm_hadds_pi16).
FORCE_INLINE __m64 _mm_hadds_pi16(__m64 a, __m64 b) {
  vint16m1_t va = vreinterpretq_m64_i16(a);
  vint16m1_t vb = vreinterpretq_m64_i16(b);
  // Pack b behind a, then add each lane to its right neighbor (saturating).
  vint16m1_t cat = __riscv_vslideup_vx_i16m1(va, vb, 4, 8);
  vint16m1_t next = __riscv_vslidedown_vx_i16m1(cat, 1, 8);
  vint32m1_t sums = __riscv_vreinterpret_v_i16m1_i32m1(
      __riscv_vsadd_vv_i16m1(cat, next, 8));
  // Extract the even-indexed sums and widen the mf2 result back to m1.
  return vreinterpretq_i16_m64(
      __riscv_vlmul_ext_v_i16mf2_i16m1(__riscv_vnsra_wx_i16mf2(sums, 0, 4)));
}

// FORCE_INLINE __m128i _mm_hsub_epi16 (__m128i a, __m128i b) {}
// Horizontally subtract adjacent pairs of signed 16-bit integers
// (SSSE3 _mm_hsub_epi16): result = {a0-a1, ..., a6-a7, b0-b1, ..., b6-b7}.
FORCE_INLINE __m128i _mm_hsub_epi16(__m128i a, __m128i b) {
  vint16m2_t va = __riscv_vlmul_ext_v_i16m1_i16m2(vreinterpretq_m128i_i16(a));
  vint16m2_t vb = __riscv_vlmul_ext_v_i16m1_i16m2(vreinterpretq_m128i_i16(b));
  // cat = [a0..a7, b0..b7]; next = cat slid down one lane, so each even lane
  // of cat - next holds element[i] - element[i+1].
  vint16m2_t cat = __riscv_vslideup_vx_i16m2(va, vb, 8, 16);
  vint16m2_t next = __riscv_vslidedown_vx_i16m2(cat, 1, 16);
  vint32m2_t diffs = __riscv_vreinterpret_v_i16m2_i32m2(
      __riscv_vsub_vv_i16m2(cat, next, 16));
  // Keep the low 16 bits of each 32-bit lane: the even-indexed differences.
  return vreinterpretq_i16_m128i(__riscv_vnsra_wx_i16m1(diffs, 0, 8));
}

// FORCE_INLINE __m128i _mm_hsub_epi32 (__m128i a, __m128i b) {}
// Horizontally subtract adjacent pairs of signed 32-bit integers
// (SSSE3 _mm_hsub_epi32): result = {a0-a1, a2-a3, b0-b1, b2-b3}.
FORCE_INLINE __m128i _mm_hsub_epi32(__m128i a, __m128i b) {
  vint32m2_t va = __riscv_vlmul_ext_v_i32m1_i32m2(vreinterpretq_m128i_i32(a));
  vint32m2_t vb = __riscv_vlmul_ext_v_i32m1_i32m2(vreinterpretq_m128i_i32(b));
  // cat = [a0..a3, b0..b3]; subtract each lane's right neighbor from it.
  vint32m2_t cat = __riscv_vslideup_vx_i32m2(va, vb, 4, 8);
  vint32m2_t next = __riscv_vslidedown_vx_i32m2(cat, 1, 8);
  vint64m2_t diffs = __riscv_vreinterpret_v_i32m2_i64m2(
      __riscv_vsub_vv_i32m2(cat, next, 8));
  // Keep the low 32 bits of each 64-bit lane: the even-indexed differences.
  return vreinterpretq_i32_m128i(__riscv_vnsra_wx_i32m1(diffs, 0, 4));
}

// FORCE_INLINE __m128d _mm_hsub_pd (__m128d a, __m128d b) {}
// Horizontally subtract pairs of doubles (SSE3 _mm_hsub_pd): {a0-a1, b0-b1}.
FORCE_INLINE __m128d _mm_hsub_pd(__m128d a, __m128d b) {
  vfloat64m2_t va =
      __riscv_vlmul_ext_v_f64m1_f64m2(vreinterpretq_m128d_f64(a));
  vfloat64m2_t vb =
      __riscv_vlmul_ext_v_f64m1_f64m2(vreinterpretq_m128d_f64(b));
  // cat = [a0, a1, b0, b1]; diffs[i] = cat[i] - cat[i+1].
  vfloat64m2_t cat = __riscv_vslideup_vx_f64m2(va, vb, 2, 4);
  vfloat64m2_t next = __riscv_vslidedown_vx_f64m2(cat, 1, 4);
  vfloat64m2_t diffs = __riscv_vfsub_vv_f64m2(cat, next, 4);
  // 85 == 0b01010101: mask of the even-indexed lanes (the valid pair
  // differences); vcompress packs them into the low two lanes.
  vbool32_t even =
      __riscv_vreinterpret_v_u8m1_b32(__riscv_vmv_s_x_u8m1(85, 2));
  return vreinterpretq_f64_m128d(__riscv_vlmul_trunc_v_f64m2_f64m1(
      __riscv_vcompress_vm_f64m2(diffs, even, 4)));
}

// FORCE_INLINE __m64 _mm_hsub_pi16 (__m64 a, __m64 b) {}
// Horizontally subtract pairs of signed 16-bit integers in 64-bit vectors
// (SSSE3 _mm_hsub_pi16): {a0-a1, a2-a3, b0-b1, b2-b3}.
FORCE_INLINE __m64 _mm_hsub_pi16(__m64 a, __m64 b) {
  vint16m1_t va = vreinterpretq_m64_i16(a);
  vint16m1_t vb = vreinterpretq_m64_i16(b);
  // Pack b behind a, then subtract each lane's right neighbor from it.
  vint16m1_t cat = __riscv_vslideup_vx_i16m1(va, vb, 4, 8);
  vint16m1_t next = __riscv_vslidedown_vx_i16m1(cat, 1, 8);
  vint32m1_t diffs = __riscv_vreinterpret_v_i16m1_i32m1(
      __riscv_vsub_vv_i16m1(cat, next, 8));
  // Extract the even-indexed differences; widen mf2 back to m1 for __m64.
  return vreinterpretq_i16_m64(
      __riscv_vlmul_ext_v_i16mf2_i16m1(__riscv_vnsra_wx_i16mf2(diffs, 0, 4)));
}

// FORCE_INLINE __m64 _mm_hsub_pi32 (__m64 a, __m64 b) {}
// Horizontally subtract pairs of signed 32-bit integers in 64-bit vectors
// (SSSE3 _mm_hsub_pi32): {a0-a1, b0-b1}.
FORCE_INLINE __m64 _mm_hsub_pi32(__m64 a, __m64 b) {
  vint32m1_t va = vreinterpretq_m64_i32(a);
  vint32m1_t vb = vreinterpretq_m64_i32(b);
  // cat = [a0, a1, b0, b1]; subtract each lane's right neighbor from it.
  vint32m1_t cat = __riscv_vslideup_vx_i32m1(va, vb, 2, 4);
  vint32m1_t next = __riscv_vslidedown_vx_i32m1(cat, 1, 4);
  vint64m1_t diffs = __riscv_vreinterpret_v_i32m1_i64m1(
      __riscv_vsub_vv_i32m1(cat, next, 4));
  // Keep the even-indexed differences and widen mf2 back to m1.
  return vreinterpretq_i32_m64(
      __riscv_vlmul_ext_v_i32mf2_i32m1(__riscv_vnsra_wx_i32mf2(diffs, 0, 2)));
}

// FORCE_INLINE __m128 _mm_hsub_ps (__m128 a, __m128 b) {}
// Horizontally subtract adjacent pairs of floats (SSE3 _mm_hsub_ps):
// {a0-a1, a2-a3, b0-b1, b2-b3}.
FORCE_INLINE __m128 _mm_hsub_ps(__m128 a, __m128 b) {
  vfloat32m2_t va = __riscv_vlmul_ext_v_f32m1_f32m2(vreinterpretq_m128_f32(a));
  vfloat32m2_t vb = __riscv_vlmul_ext_v_f32m1_f32m2(vreinterpretq_m128_f32(b));
  vfloat32m2_t cat = __riscv_vslideup_vx_f32m2(va, vb, 4, 8);
  vfloat32m2_t next = __riscv_vslidedown_vx_f32m2(cat, 1, 8);
  // Subtract, then reinterpret the float lanes as 64-bit integers so a
  // narrowing shift can extract the even-indexed differences bit-exactly.
  vint64m2_t diffs = __riscv_vreinterpret_v_i32m2_i64m2(
      __riscv_vreinterpret_v_f32m2_i32m2(
          __riscv_vfsub_vv_f32m2(cat, next, 8)));
  return vreinterpretq_i32_m128(__riscv_vnsra_wx_i32m1(diffs, 0, 4));
}

// FORCE_INLINE __m128i _mm_hsubs_epi16 (__m128i a, __m128i b) {}
// Horizontally subtract adjacent pairs of signed 16-bit integers with
// saturation (SSSE3 _mm_hsubs_epi16); vssub clamps each difference to
// [INT16_MIN, INT16_MAX].
FORCE_INLINE __m128i _mm_hsubs_epi16(__m128i a, __m128i b) {
  vint16m2_t va = __riscv_vlmul_ext_v_i16m1_i16m2(vreinterpretq_m128i_i16(a));
  vint16m2_t vb = __riscv_vlmul_ext_v_i16m1_i16m2(vreinterpretq_m128i_i16(b));
  // cat = [a0..a7, b0..b7]; subtract each lane's right neighbor from it.
  vint16m2_t cat = __riscv_vslideup_vx_i16m2(va, vb, 8, 16);
  vint16m2_t next = __riscv_vslidedown_vx_i16m2(cat, 1, 16);
  vint32m2_t diffs = __riscv_vreinterpret_v_i16m2_i32m2(
      __riscv_vssub_vv_i16m2(cat, next, 16));
  // Keep the low 16 bits of each 32-bit lane: the even-indexed differences.
  return vreinterpretq_i16_m128i(__riscv_vnsra_wx_i16m1(diffs, 0, 8));
}

// FORCE_INLINE __m64 _mm_hsubs_pi16 (__m64 a, __m64 b) {}
// Horizontally subtract pairs of signed 16-bit integers in 64-bit vectors
// with saturation (SSSE3 _mm_hsubs_pi16).
FORCE_INLINE __m64 _mm_hsubs_pi16(__m64 a, __m64 b) {
  vint16m1_t va = vreinterpretq_m64_i16(a);
  vint16m1_t vb = vreinterpretq_m64_i16(b);
  // Pack b behind a, then do a saturating subtract of each lane's neighbor.
  vint16m1_t cat = __riscv_vslideup_vx_i16m1(va, vb, 4, 8);
  vint16m1_t next = __riscv_vslidedown_vx_i16m1(cat, 1, 8);
  vint32m1_t diffs = __riscv_vreinterpret_v_i16m1_i32m1(
      __riscv_vssub_vv_i16m1(cat, next, 8));
  // Extract the even-indexed differences; widen mf2 back to m1 for __m64.
  return vreinterpretq_i16_m64(
      __riscv_vlmul_ext_v_i16mf2_i16m1(__riscv_vnsra_wx_i16mf2(diffs, 0, 4)));
}

FORCE_INLINE __m128i _mm_insert_epi16(__m128i a, int i, int imm8) {
vint16m1_t _a = vreinterpretq_m128i_i16(a);
Expand Down
6 changes: 6 additions & 0 deletions tests/debug_tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ void print_64_bits_s64_arr(const char *var_name, const int64_t *u) {
void print_64_bits_f32_arr(const char *var_name, const float *f) {
printf("%s0: %.3f, %s1: %.3f\n", var_name, f[0], var_name, f[1]);
}
// Debug helper: print the single double held in a 64-bit value, labeled
// "<var_name>0". A 64-bit vector holds exactly one f64 element, hence only
// f[0] is printed.
void print_64_bits_f64_arr(const char *var_name, const double *f) {
printf("%s0: %.6f\n", var_name, f[0]);
}
void print_128_bits_u8_arr(const char *var_name, const uint8_t *u) {
printf("%s0: %3u, %s1: %3u, %s2: %3u, %s3: %3u, %s4: %3u, %s5: %3u, "
"%s6: %3u, %s7: %3u, %s8: %3u, %s9: %3u, %s10: %3u, %s11: %3u, "
Expand Down Expand Up @@ -97,6 +100,9 @@ void print_128_bits_f32_arr(const char *var_name, const float *f) {
printf("%s0: %.3f, %s1: %.3f, %s2: %.3f, %s3: %.3f\n", var_name, f[0],
var_name, f[1], var_name, f[2], var_name, f[3]);
}
// Debug helper: print the two doubles held in a 128-bit value, labeled
// "<var_name>0" and "<var_name>1".
void print_128_bits_f64_arr(const char *var_name, const double *f) {
printf("%s0: %.6f, %s1: %.6f\n", var_name, f[0], var_name, f[1]);
}

void print_u8_64(const char *var_name, uint8_t u0, uint8_t u1, uint8_t u2,
uint8_t u3, uint8_t u4, uint8_t u5, uint8_t u6, uint8_t u7) {
Expand Down
10 changes: 10 additions & 0 deletions tests/debug_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ void print_64_bits_s32_arr(const char *var_name, const int32_t *u);
void print_64_bits_u64_arr(const char *var_name, const uint64_t *u);
void print_64_bits_s64_arr(const char *var_name, const int64_t *u);
void print_64_bits_f32_arr(const char *var_name, const float *f);
void print_64_bits_f64_arr(const char *var_name, const double *f);
void print_128_bits_u8_arr(const char *var_name, const uint8_t *u);
void print_128_bits_s8_arr(const char *var_name, const int8_t *u);
void print_128_bits_u16_arr(const char *var_name, const uint16_t *u);
Expand All @@ -35,6 +36,7 @@ void print_128_bits_s32_arr(const char *var_name, const int32_t *u);
void print_128_bits_u64_arr(const char *var_name, const uint64_t *u);
void print_128_bits_s64_arr(const char *var_name, const int64_t *u);
void print_128_bits_f32_arr(const char *var_name, const float *f);
void print_128_bits_f64_arr(const char *var_name, const double *f);

void print_u8_64(const char *var_name, uint8_t u0, uint8_t u1, uint8_t u2,
uint8_t u3, uint8_t u4, uint8_t u5, uint8_t u6, uint8_t u7);
Expand Down Expand Up @@ -213,6 +215,14 @@ template <typename T> void print_f32_128(const char *var_name, T *a) {
const float *f = (const float *)a;
print_128_bits_f32_arr(var_name, f);
}
// Print a 128-bit vector value as two doubles (by-value overload).
// NOTE(review): the cast reads the object's bytes through a double*; this
// assumes T is a 128-bit POD holding two f64 lanes — strict-aliasing concern
// shared with the sibling print_* helpers, acceptable for debug-only code.
template <typename T> void print_f64_128(const char *var_name, T a) {
const double *f = (const double *)&a;
print_128_bits_f64_arr(var_name, f);
}
// Print a 128-bit vector value as two doubles (pointer overload).
// Assumes *a is at least 128 bits of f64 data — TODO confirm callers.
template <typename T> void print_f64_128(const char *var_name, T *a) {
const double *f = (const double *)a;
print_128_bits_f64_arr(var_name, f);
}

} // namespace SSE2RVV

Expand Down
Loading