Skip to content

Commit

Permalink
Merge pull request #86 from pattonkan/set-mode
Browse files Browse the repository at this point in the history
Add rounding intrinsics
  • Loading branch information
howjmay authored Sep 5, 2024
2 parents 12be94f + 2c46729 commit c90c592
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 118 deletions.
91 changes: 88 additions & 3 deletions sse2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
* Contributors to this work are:
* Yang Hau <yuanyanghau@gmail.com>
* Cheng-Hao <chahsiao@gmail.com>
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
Expand Down Expand Up @@ -248,6 +248,49 @@ typedef union ALIGN_STRUCT(16) SIMDVec {
#endif
#endif

// XRM
// #define __RISCV_VXRM_RNU 0 // round-to-nearest-up (add +0.5 LSB)
// #define __RISCV_VXRM_RNE 1 // round-to-nearest-even
// #define __RISCV_VXRM_RDN 2 // round-down (truncate)
// #define __RISCV_VXRM_ROD 3 // round-to-odd (OR bits into LSB, aka "jam")
// FRM
// #define __RISCV_FRM_RNE 0 // round to nearest, ties to even
// #define __RISCV_FRM_RTZ 1 // round towards zero
// #define __RISCV_FRM_RDN 2 // round down (towards -infinity)
// #define __RISCV_FRM_RUP 3 // round up (towards +infinity)
// #define __RISCV_FRM_RMM 4 // round to nearest, ties to max magnitude

// Bit-field view of the RISC-V FCSR (floating-point control and status
// register). With the LSB-first bit-field allocation used by the RISC-V
// psABI the members map to:
//   bit  0     nx  - inexact
//   bit  1     uf  - underflow
//   bit  2     of  - overflow
//   bit  3     dz  - divide by zero
//   bit  4     nv  - invalid operation
//   bits 5-7   frm - dynamic rounding mode (one of __RISCV_FRM_*)
//   bits 8-31  reserved
// Every member is declared with the same 32-bit base type so that all fields
// share one 32-bit storage unit. Mixing uint8_t and uint32_t bit-fields leaves
// the packing (and even the struct size) up to the compiler/ABI.
typedef struct {
  uint32_t nx : 1;
  uint32_t uf : 1;
  uint32_t of : 1;
  uint32_t dz : 1;
  uint32_t nv : 1;
  uint32_t frm : 3;
  uint32_t reserved : 24;
} fcsr_bitfield;

/* Rounding mode macros. */
// _MM_FROUND_*: values for the `rounding` argument of _mm_round_ps.
// A rounding direction ...
#define _MM_FROUND_TO_NEAREST_INT 0x00
#define _MM_FROUND_TO_NEG_INF 0x01
#define _MM_FROUND_TO_POS_INF 0x02
#define _MM_FROUND_TO_ZERO 0x03
#define _MM_FROUND_CUR_DIRECTION 0x04
// ... may be OR-ed with one exception-control flag. _MM_FROUND_RAISE_EXC is 0,
// so OR-ing it in leaves the direction value unchanged.
#define _MM_FROUND_NO_EXC 0x08
#define _MM_FROUND_RAISE_EXC 0x00
// Named combinations of a direction and an exception-control flag.
#define _MM_FROUND_NINT (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_RAISE_EXC)
#define _MM_FROUND_FLOOR (_MM_FROUND_TO_NEG_INF | _MM_FROUND_RAISE_EXC)
#define _MM_FROUND_CEIL (_MM_FROUND_TO_POS_INF | _MM_FROUND_RAISE_EXC)
#define _MM_FROUND_TRUNC (_MM_FROUND_TO_ZERO | _MM_FROUND_RAISE_EXC)
#define _MM_FROUND_RINT (_MM_FROUND_CUR_DIRECTION | _MM_FROUND_RAISE_EXC)
#define _MM_FROUND_NEARBYINT (_MM_FROUND_CUR_DIRECTION | _MM_FROUND_NO_EXC)
// _MM_ROUND_*: values accepted by _MM_SET_ROUNDING_MODE. They use a different
// numeric encoding from the _MM_FROUND_* directions above and must not be
// passed to _mm_round_ps.
#define _MM_ROUND_NEAREST 0x0000
#define _MM_ROUND_DOWN 0x2000
#define _MM_ROUND_UP 0x4000
#define _MM_ROUND_TOWARD_ZERO 0x6000

// forward declaration
FORCE_INLINE int _mm_extract_pi16(__m64 a, int imm8);
FORCE_INLINE __m64 _mm_sad_pu8(__m64 a, __m64 b);
Expand Down Expand Up @@ -2537,7 +2580,26 @@ FORCE_INLINE __m128 _mm_rcp_ss(__m128 a) {

// FORCE_INLINE __m128d _mm_round_pd (__m128d a, int rounding) {}

// Round the four single-precision lanes of a to integral values and return
// them as floats.
//
// rounding is decoded the way the x86 immediate is laid out:
//   bits 1:0  explicit direction (_MM_FROUND_TO_NEAREST_INT, _TO_NEG_INF,
//             _TO_POS_INF, _TO_ZERO)
//   bit  2    _MM_FROUND_CUR_DIRECTION: ignore bits 1:0 and use the dynamic
//             rounding mode (see _MM_SET_ROUNDING_MODE)
//   bit  3    _MM_FROUND_NO_EXC: accepted but not acted upon here
// Masking the direction bits makes _MM_FROUND_FLOOR/CEIL/TRUNC/NINT (which
// are built with _MM_FROUND_RAISE_EXC == 0) select their explicit direction
// instead of silently falling back to the current one.
//
// The rounding itself goes through a float -> int32 -> float round trip, so
// two fix-ups are applied afterwards:
//   * the sign of the input is copied onto the result, so values that round to
//     zero keep their sign (e.g. -0.3 -> -0.0);
//   * lanes with |a| >= 2^23 are already integral, and together with +-inf and
//     NaN they cannot safely go through int32, so they are passed through
//     unchanged.
FORCE_INLINE __m128 _mm_round_ps(__m128 a, int rounding) {
  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
  vint32m1_t _i;

  // The _rm intrinsics need a compile-time rounding mode, hence the switch.
  switch (rounding & 0x07) {
  case _MM_FROUND_TO_NEAREST_INT:
    _i = __riscv_vfcvt_x_f_v_i32m1_rm(_a, __RISCV_FRM_RNE, 4);
    break;
  case _MM_FROUND_TO_NEG_INF:
    _i = __riscv_vfcvt_x_f_v_i32m1_rm(_a, __RISCV_FRM_RDN, 4);
    break;
  case _MM_FROUND_TO_POS_INF:
    _i = __riscv_vfcvt_x_f_v_i32m1_rm(_a, __RISCV_FRM_RUP, 4);
    break;
  case _MM_FROUND_TO_ZERO:
    _i = __riscv_vfcvt_x_f_v_i32m1_rm(_a, __RISCV_FRM_RTZ, 4);
    break;
  default:  // _MM_FROUND_CUR_DIRECTION bit set: use the dynamic mode
    _i = __riscv_vfcvt_x_f_v_i32m1(_a, 4);
    break;
  }

  vfloat32m1_t _rounded = __riscv_vfcvt_f_x_v_f32m1(_i, 4);
  // Magnitude from the rounded value, sign from the input.
  _rounded = __riscv_vfsgnj_vv_f32m1(_rounded, _a, 4);

  // True for lanes whose magnitude is below 2^23; false for large values,
  // infinities and NaN (the comparison is false for NaN).
  vbool32_t _has_fraction =
      __riscv_vmflt_vf_f32m1_b32(__riscv_vfabs_v_f32m1(_a, 4), 8388608.0f, 4);
  // vmerge picks the second operand where the mask is set, the first
  // elsewhere.
  return vreinterpretq_f32_m128(
      __riscv_vmerge_vvm_f32m1(_a, _rounded, _has_fraction, 4));
}

// FORCE_INLINE __m128d _mm_round_sd (__m128d a, __m128d b, int rounding) {}

Expand Down Expand Up @@ -2637,7 +2699,30 @@ FORCE_INLINE __m128 _mm_set_ps1(float a) {
return vreinterpretq_f32_m128(__riscv_vfmv_v_f_f32m1(a, 4));
}

// Set the dynamic floating-point rounding mode (the frm field of fcsr).
//
// a is one of _MM_ROUND_NEAREST, _MM_ROUND_DOWN, _MM_ROUND_UP or
// _MM_ROUND_TOWARD_ZERO; any other value selects round-to-nearest-even.
//
// Only the frm field is written (via fsrm). The accrued exception flags in
// fcsr are left untouched, and the function does not depend on how the
// compiler lays out a bit-field struct or on passing an aggregate as an
// inline-asm register operand.
FORCE_INLINE void _MM_SET_ROUNDING_MODE(unsigned int a) {
  unsigned int frm;

  switch (a) {
  case _MM_ROUND_TOWARD_ZERO:
    frm = __RISCV_FRM_RTZ;
    break;
  case _MM_ROUND_DOWN:
    frm = __RISCV_FRM_RDN;
    break;
  case _MM_ROUND_UP:
    frm = __RISCV_FRM_RUP;
    break;
  default:  // _MM_ROUND_NEAREST
    frm = __RISCV_FRM_RNE;
    break;
  }

  __asm__ volatile("fsrm %0" : : "r"(frm));
}

FORCE_INLINE __m128d _mm_set_sd(double a) {
double arr[2] = {a, 0};
Expand Down
230 changes: 115 additions & 115 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3098,42 +3098,42 @@ result_t test_mm_set_ps1(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {

// Checks that each mode selected through _MM_SET_ROUNDING_MODE makes
// _mm_round_ps(_MM_FROUND_CUR_DIRECTION) agree with the matching explicit
// _MM_FROUND_TO_* request. All four modes are always exercised, ending with
// _MM_ROUND_NEAREST, before the verdict is returned.
result_t test_mm_set_rounding_mode(const SSE2RVV_TEST_IMPL &impl,
                                   uint32_t iter) {
#ifdef ENABLE_TEST_ALL
    __m128 input = load_m128(impl.test_cases_float_pointer1);
    __m128 via_mode, via_arg;
    int mismatches = 0;

    // Toward zero.
    _MM_SET_ROUNDING_MODE(_MM_ROUND_TOWARD_ZERO);
    via_mode = _mm_round_ps(input, _MM_FROUND_CUR_DIRECTION);
    via_arg = _mm_round_ps(input, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
    if (validate_128bits(via_arg, via_mode) != TEST_SUCCESS)
        mismatches++;

    // Toward negative infinity.
    _MM_SET_ROUNDING_MODE(_MM_ROUND_DOWN);
    via_mode = _mm_round_ps(input, _MM_FROUND_CUR_DIRECTION);
    via_arg = _mm_round_ps(input, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
    if (validate_128bits(via_arg, via_mode) != TEST_SUCCESS)
        mismatches++;

    // Toward positive infinity.
    _MM_SET_ROUNDING_MODE(_MM_ROUND_UP);
    via_mode = _mm_round_ps(input, _MM_FROUND_CUR_DIRECTION);
    via_arg = _mm_round_ps(input, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC);
    if (validate_128bits(via_arg, via_mode) != TEST_SUCCESS)
        mismatches++;

    // To nearest; this also leaves the mode at nearest on exit.
    _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);
    via_mode = _mm_round_ps(input, _MM_FROUND_CUR_DIRECTION);
    via_arg =
        _mm_round_ps(input, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    if (validate_128bits(via_arg, via_mode) != TEST_SUCCESS)
        mismatches++;

    return mismatches == 0 ? TEST_SUCCESS : TEST_FAIL;
#else
    return TEST_UNIMPL;
#endif  // ENABLE_TEST_ALL
}

result_t test_mm_set_ss(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down Expand Up @@ -10074,87 +10074,87 @@ result_t test_mm_round_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
}

// Exercises _mm_round_ps in eight configurations selected by the low three
// bits of iter: cases 0-3 request an explicit direction
// (_MM_FROUND_TO_* | _MM_FROUND_NO_EXC); cases 4-7 set the mode with
// _MM_SET_ROUNDING_MODE and round with _MM_FROUND_CUR_DIRECTION. The expected
// lanes are computed with scalar helpers (bankers_rounding, floorf, ceilf).
//
// The rounding mode is put back to _MM_ROUND_NEAREST before returning so that
// cases 5-7 do not leak a non-default mode into whatever runs next.
result_t test_mm_round_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
    const float *_a = impl.test_cases_float_pointer1;
    float f[4];
    __m128 ret;

    __m128 a = load_m128(_a);
    switch (iter & 0x7) {
    case 0:
        for (int i = 0; i < 4; i++)
            f[i] = bankers_rounding(_a[i]);

        ret = _mm_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
        break;
    case 1:
        for (int i = 0; i < 4; i++)
            f[i] = floorf(_a[i]);

        ret = _mm_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
        break;
    case 2:
        for (int i = 0; i < 4; i++)
            f[i] = ceilf(_a[i]);

        ret = _mm_round_ps(a, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC);
        break;
    case 3:
        // Toward zero: floor for positive inputs, ceil otherwise.
        for (int i = 0; i < 4; i++)
            f[i] = _a[i] > 0 ? floorf(_a[i]) : ceilf(_a[i]);

        ret = _mm_round_ps(a, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
        break;
    case 4:
        for (int i = 0; i < 4; i++)
            f[i] = bankers_rounding(_a[i]);

        _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);
        ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
        break;
    case 5:
        for (int i = 0; i < 4; i++)
            f[i] = floorf(_a[i]);

        _MM_SET_ROUNDING_MODE(_MM_ROUND_DOWN);
        ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
        break;
    case 6:
        for (int i = 0; i < 4; i++)
            f[i] = ceilf(_a[i]);

        _MM_SET_ROUNDING_MODE(_MM_ROUND_UP);
        ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
        break;
    case 7:
        // Toward zero: floor for positive inputs, ceil otherwise.
        for (int i = 0; i < 4; i++)
            f[i] = _a[i] > 0 ? floorf(_a[i]) : ceilf(_a[i]);

        _MM_SET_ROUNDING_MODE(_MM_ROUND_TOWARD_ZERO);
        ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
        break;
    }

    // Undo any mode change made above before doing further float work.
    _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);

    return validate_float(ret, f[0], f[1], f[2], f[3]);
#else
    return TEST_UNIMPL;
#endif  // ENABLE_TEST_ALL
}

result_t test_mm_round_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
Expand Down

0 comments on commit c90c592

Please sign in to comment.