diff --git a/sse2rvv.h b/sse2rvv.h
index 4ab5f17..1f298fa 100644
--- a/sse2rvv.h
+++ b/sse2rvv.h
@@ -12,7 +12,7 @@
  * Contributors to this work are:
  *   Yang Hau
  *   Cheng-Hao
- * 
+ *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
  * in the Software without restriction, including without limitation the rights
@@ -248,6 +248,49 @@ typedef union ALIGN_STRUCT(16) SIMDVec {
 #endif
 #endif
 
+// XRM
+// #define __RISCV_VXRM_RNU 0 // round-to-nearest-up (add +0.5 LSB)
+// #define __RISCV_VXRM_RNE 1 // round-to-nearest-even
+// #define __RISCV_VXRM_RDN 2 // round-down (truncate)
+// #define __RISCV_VXRM_ROD 3 // round-to-odd (OR bits into LSB, aka "jam")
+// FRM
+// #define __RISCV_FRM_RNE 0 // round to nearest, ties to even
+// #define __RISCV_FRM_RTZ 1 // round towards zero
+// #define __RISCV_FRM_RDN 2 // round down (towards -infinity)
+// #define __RISCV_FRM_RUP 3 // round up (towards +infinity)
+// #define __RISCV_FRM_RMM 4 // round to nearest, ties to max magnitude
+
+// The bit field mapping to the FCSR (floating-point control and status
+// register)
+typedef struct {
+  uint8_t nx : 1;
+  uint8_t uf : 1;
+  uint8_t of : 1;
+  uint8_t dz : 1;
+  uint8_t nv : 1;
+  uint8_t frm : 3;
+  uint32_t reserved : 24;
+} fcsr_bitfield;
+
+/* Rounding mode macros. */
+#define _MM_FROUND_TO_NEAREST_INT 0x00
+#define _MM_FROUND_TO_NEG_INF 0x01
+#define _MM_FROUND_TO_POS_INF 0x02
+#define _MM_FROUND_TO_ZERO 0x03
+#define _MM_FROUND_CUR_DIRECTION 0x04
+#define _MM_FROUND_NO_EXC 0x08
+#define _MM_FROUND_RAISE_EXC 0x00
+#define _MM_FROUND_NINT (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_RAISE_EXC)
+#define _MM_FROUND_FLOOR (_MM_FROUND_TO_NEG_INF | _MM_FROUND_RAISE_EXC)
+#define _MM_FROUND_CEIL (_MM_FROUND_TO_POS_INF | _MM_FROUND_RAISE_EXC)
+#define _MM_FROUND_TRUNC (_MM_FROUND_TO_ZERO | _MM_FROUND_RAISE_EXC)
+#define _MM_FROUND_RINT (_MM_FROUND_CUR_DIRECTION | _MM_FROUND_RAISE_EXC)
+#define _MM_FROUND_NEARBYINT (_MM_FROUND_CUR_DIRECTION | _MM_FROUND_NO_EXC)
+#define _MM_ROUND_NEAREST 0x0000
+#define _MM_ROUND_DOWN 0x2000
+#define _MM_ROUND_UP 0x4000
+#define _MM_ROUND_TOWARD_ZERO 0x6000
+
 // forward declaration
 FORCE_INLINE int _mm_extract_pi16(__m64 a, int imm8);
 FORCE_INLINE __m64 _mm_sad_pu8(__m64 a, __m64 b);
@@ -2537,7 +2580,33 @@ FORCE_INLINE __m128 _mm_rcp_ss(__m128 a) {
 
 // FORCE_INLINE __m128d _mm_round_pd (__m128d a, int rounding) {}
 
-// FORCE_INLINE __m128 _mm_round_ps (__m128 a, int rounding) {}
+// Round the four packed single-precision lanes of a to integral values.
+// Bits [1:0] of rounding select an explicit mode; any other value (such as
+// _MM_FROUND_CUR_DIRECTION) uses the current dynamic rounding mode.
+// _MM_FROUND_NO_EXC is masked off before decoding, so _MM_FROUND_FLOOR, _CEIL,
+// _TRUNC and _NINT pick the same mode as their NO_EXC counterparts.
+// NOTE(review): lanes take a round trip through int32, so values outside the
+// int32 range, NaN and infinity are presumably not preserved - confirm.
+FORCE_INLINE __m128 _mm_round_ps(__m128 a, int rounding) {
+  vfloat32m1_t _a = vreinterpretq_m128_f32(a);
+  switch (rounding & ~_MM_FROUND_NO_EXC) {
+  case _MM_FROUND_TO_NEAREST_INT:
+    return vreinterpretq_f32_m128(__riscv_vfcvt_f_x_v_f32m1(
+        __riscv_vfcvt_x_f_v_i32m1_rm(_a, __RISCV_FRM_RNE, 4), 4));
+  case _MM_FROUND_TO_NEG_INF:
+    return vreinterpretq_f32_m128(__riscv_vfcvt_f_x_v_f32m1(
+        __riscv_vfcvt_x_f_v_i32m1_rm(_a, __RISCV_FRM_RDN, 4), 4));
+  case _MM_FROUND_TO_POS_INF:
+    return vreinterpretq_f32_m128(__riscv_vfcvt_f_x_v_f32m1(
+        __riscv_vfcvt_x_f_v_i32m1_rm(_a, __RISCV_FRM_RUP, 4), 4));
+  case _MM_FROUND_TO_ZERO:
+    return vreinterpretq_f32_m128(__riscv_vfcvt_f_x_v_f32m1(
+        __riscv_vfcvt_x_f_v_i32m1_rm(_a, __RISCV_FRM_RTZ, 4), 4));
+  default: // _MM_FROUND_CUR_DIRECTION and unrecognized values
+    return vreinterpretq_f32_m128(
+        __riscv_vfcvt_f_x_v_f32m1(__riscv_vfcvt_x_f_v_i32m1(_a, 4), 4));
+  }
+}
 
 // FORCE_INLINE __m128d _mm_round_sd (__m128d a, __m128d b, int rounding) {}
 
@@ -2637,7 +2706,35 @@ FORCE_INLINE __m128 _mm_set_ps1(float a) {
   return vreinterpretq_f32_m128(__riscv_vfmv_v_f_f32m1(a, 4));
 }
 
-// FORCE_INLINE void _MM_SET_ROUNDING_MODE (unsigned int a) {}
+// Set the floating-point rounding mode by rewriting the frm field of fcsr.
+// a is one of the _MM_ROUND_* constants; any other value selects
+// round-to-nearest-even. All other fcsr bits are written back as read.
+FORCE_INLINE void _MM_SET_ROUNDING_MODE(unsigned int a) {
+  union {
+    fcsr_bitfield field;
+    uint32_t value;
+  } r;
+
+  // Use the scalar view as the asm operand instead of relying on the
+  // compiler accepting an aggregate for a register constraint.
+  __asm__ volatile("csrr %0, fcsr" : "=r"(r.value));
+
+  switch (a) {
+  case _MM_ROUND_TOWARD_ZERO:
+    r.field.frm = __RISCV_FRM_RTZ;
+    break;
+  case _MM_ROUND_DOWN:
+    r.field.frm = __RISCV_FRM_RDN;
+    break;
+  case _MM_ROUND_UP:
+    r.field.frm = __RISCV_FRM_RUP;
+    break;
+  default: //_MM_ROUND_NEAREST
+    r.field.frm = __RISCV_FRM_RNE;
+  }
+
+  __asm__ volatile("csrw fcsr, %0" : : "r"(r.value));
+}
 
 FORCE_INLINE __m128d _mm_set_sd(double a) {
   double arr[2] = {a, 0};
diff --git a/tests/impl.cpp b/tests/impl.cpp
index 38e9536..2fa3a7a 100644
--- a/tests/impl.cpp
+++ b/tests/impl.cpp
@@ -3098,42 +3098,42 @@ result_t test_mm_set_ps1(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
 
 result_t test_mm_set_rounding_mode(const SSE2RVV_TEST_IMPL &impl,
                                    uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  //   const float *_a = impl.test_cases_float_pointer1;
-  //   result_t res_toward_zero, res_to_neg_inf, res_to_pos_inf, res_nearest;
-  //
-  //   __m128 a = load_m128(_a);
-  //   __m128 b, c;
-  //
-  //   _MM_SET_ROUNDING_MODE(_MM_ROUND_TOWARD_ZERO);
-  //   b = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
-  //   c = _mm_round_ps(a, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
-  //   res_toward_zero = validate_128bits(c, b);
-  //
-  //   _MM_SET_ROUNDING_MODE(_MM_ROUND_DOWN);
-  //   b = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
-  //   c = _mm_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
-  //   res_to_neg_inf = validate_128bits(c, b);
-  //
-  //   _MM_SET_ROUNDING_MODE(_MM_ROUND_UP);
-  //   b = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
-  //   c = _mm_round_ps(a, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC);
-  //   res_to_pos_inf = validate_128bits(c, b);
-  //
-  //   _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);
-  //   b = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
-  //   c = _mm_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
-  //   res_nearest = validate_128bits(c, b);
-  //
-  //   if (res_toward_zero == TEST_SUCCESS && res_to_neg_inf == TEST_SUCCESS &&
-  //       res_to_pos_inf == TEST_SUCCESS && res_nearest == TEST_SUCCESS) {
-  //     return TEST_SUCCESS;
-  //   } else {
-  //     return TEST_FAIL;
-  //   }
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const float *_a = impl.test_cases_float_pointer1;
+  result_t res_toward_zero, res_to_neg_inf, res_to_pos_inf, res_nearest;
+
+  __m128 a = load_m128(_a);
+  __m128 b, c;
+
+  _MM_SET_ROUNDING_MODE(_MM_ROUND_TOWARD_ZERO);
+  b = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
+  c = _mm_round_ps(a, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
+  res_toward_zero = validate_128bits(c, b);
+
+  _MM_SET_ROUNDING_MODE(_MM_ROUND_DOWN);
+  b = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
+  c = _mm_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
+  res_to_neg_inf = validate_128bits(c, b);
+
+  _MM_SET_ROUNDING_MODE(_MM_ROUND_UP);
+  b = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
+  c = _mm_round_ps(a, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC);
+  res_to_pos_inf = validate_128bits(c, b);
+
+  _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);
+  b = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
+  c = _mm_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+  res_nearest = validate_128bits(c, b);
+
+  if (res_toward_zero == TEST_SUCCESS && res_to_neg_inf == TEST_SUCCESS &&
+      res_to_pos_inf == TEST_SUCCESS && res_nearest == TEST_SUCCESS) {
+    return TEST_SUCCESS;
+  } else {
+    return TEST_FAIL;
+  }
+#else
   return TEST_UNIMPL;
-  // #endif  // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_set_ss(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
@@ -10074,87 +10074,87 @@ result_t test_mm_round_pd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
 }
 
 result_t test_mm_round_ps(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {
-  // #ifdef ENABLE_TEST_ALL
-  //   const float *_a = impl.test_cases_float_pointer1;
-  //   float f[4];
-  //   __m128 ret;
-  //
-  //   __m128 a = load_m128(_a);
-  //   switch (iter & 0x7) {
-  //   case 0:
-  //     f[0] = bankers_rounding(_a[0]);
-  //     f[1] = bankers_rounding(_a[1]);
-  //     f[2] = bankers_rounding(_a[2]);
-  //     f[3] = bankers_rounding(_a[3]);
-  //
-  //     ret = _mm_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
-  //     break;
-  //   case 1:
-  //     f[0] = floorf(_a[0]);
-  //     f[1] = floorf(_a[1]);
-  //     f[2] = floorf(_a[2]);
-  //     f[3] = floorf(_a[3]);
-  //
-  //     ret = _mm_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
-  //     break;
-  //   case 2:
-  //     f[0] = ceilf(_a[0]);
-  //     f[1] = ceilf(_a[1]);
-  //     f[2] = ceilf(_a[2]);
-  //     f[3] = ceilf(_a[3]);
-  //
-  //     ret = _mm_round_ps(a, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC);
-  //     break;
-  //   case 3:
-  //     f[0] = _a[0] > 0 ? floorf(_a[0]) : ceilf(_a[0]);
-  //     f[1] = _a[1] > 0 ? floorf(_a[1]) : ceilf(_a[1]);
-  //     f[2] = _a[2] > 0 ? floorf(_a[2]) : ceilf(_a[2]);
-  //     f[3] = _a[3] > 0 ? floorf(_a[3]) : ceilf(_a[3]);
-  //
-  //     ret = _mm_round_ps(a, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
-  //     break;
-  //   case 4:
-  //     f[0] = bankers_rounding(_a[0]);
-  //     f[1] = bankers_rounding(_a[1]);
-  //     f[2] = bankers_rounding(_a[2]);
-  //     f[3] = bankers_rounding(_a[3]);
-  //
-  //     _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);
-  //     ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
-  //     break;
-  //   case 5:
-  //     f[0] = floorf(_a[0]);
-  //     f[1] = floorf(_a[1]);
-  //     f[2] = floorf(_a[2]);
-  //     f[3] = floorf(_a[3]);
-  //
-  //     _MM_SET_ROUNDING_MODE(_MM_ROUND_DOWN);
-  //     ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
-  //     break;
-  //   case 6:
-  //     f[0] = ceilf(_a[0]);
-  //     f[1] = ceilf(_a[1]);
-  //     f[2] = ceilf(_a[2]);
-  //     f[3] = ceilf(_a[3]);
-  //
-  //     _MM_SET_ROUNDING_MODE(_MM_ROUND_UP);
-  //     ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
-  //     break;
-  //   case 7:
-  //     f[0] = _a[0] > 0 ? floorf(_a[0]) : ceilf(_a[0]);
-  //     f[1] = _a[1] > 0 ? floorf(_a[1]) : ceilf(_a[1]);
-  //     f[2] = _a[2] > 0 ? floorf(_a[2]) : ceilf(_a[2]);
-  //     f[3] = _a[3] > 0 ? floorf(_a[3]) : ceilf(_a[3]);
-  //
-  //     _MM_SET_ROUNDING_MODE(_MM_ROUND_TOWARD_ZERO);
-  //     ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
-  //     break;
-  //   }
-  //
-  //   return validate_float(ret, f[0], f[1], f[2], f[3]);
-  // #else
+#ifdef ENABLE_TEST_ALL
+  const float *_a = impl.test_cases_float_pointer1;
+  float f[4];
+  __m128 ret;
+
+  __m128 a = load_m128(_a);
+  switch (iter & 0x7) {
+  case 0:
+    f[0] = bankers_rounding(_a[0]);
+    f[1] = bankers_rounding(_a[1]);
+    f[2] = bankers_rounding(_a[2]);
+    f[3] = bankers_rounding(_a[3]);
+
+    ret = _mm_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+    break;
+  case 1:
+    f[0] = floorf(_a[0]);
+    f[1] = floorf(_a[1]);
+    f[2] = floorf(_a[2]);
+    f[3] = floorf(_a[3]);
+
+    ret = _mm_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
+    break;
+  case 2:
+    f[0] = ceilf(_a[0]);
+    f[1] = ceilf(_a[1]);
+    f[2] = ceilf(_a[2]);
+    f[3] = ceilf(_a[3]);
+
+    ret = _mm_round_ps(a, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC);
+    break;
+  case 3:
+    f[0] = _a[0] > 0 ? floorf(_a[0]) : ceilf(_a[0]);
+    f[1] = _a[1] > 0 ? floorf(_a[1]) : ceilf(_a[1]);
+    f[2] = _a[2] > 0 ? floorf(_a[2]) : ceilf(_a[2]);
+    f[3] = _a[3] > 0 ? floorf(_a[3]) : ceilf(_a[3]);
+
+    ret = _mm_round_ps(a, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
+    break;
+  case 4:
+    f[0] = bankers_rounding(_a[0]);
+    f[1] = bankers_rounding(_a[1]);
+    f[2] = bankers_rounding(_a[2]);
+    f[3] = bankers_rounding(_a[3]);
+
+    _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);
+    ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
+    break;
+  case 5:
+    f[0] = floorf(_a[0]);
+    f[1] = floorf(_a[1]);
+    f[2] = floorf(_a[2]);
+    f[3] = floorf(_a[3]);
+
+    _MM_SET_ROUNDING_MODE(_MM_ROUND_DOWN);
+    ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
+    break;
+  case 6:
+    f[0] = ceilf(_a[0]);
+    f[1] = ceilf(_a[1]);
+    f[2] = ceilf(_a[2]);
+    f[3] = ceilf(_a[3]);
+
+    _MM_SET_ROUNDING_MODE(_MM_ROUND_UP);
+    ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
+    break;
+  case 7:
+    f[0] = _a[0] > 0 ? floorf(_a[0]) : ceilf(_a[0]);
+    f[1] = _a[1] > 0 ? floorf(_a[1]) : ceilf(_a[1]);
+    f[2] = _a[2] > 0 ? floorf(_a[2]) : ceilf(_a[2]);
+    f[3] = _a[3] > 0 ? floorf(_a[3]) : ceilf(_a[3]);
+
+    _MM_SET_ROUNDING_MODE(_MM_ROUND_TOWARD_ZERO);
+    ret = _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION);
+    break;
+  }
+
+  return validate_float(ret, f[0], f[1], f[2], f[3]);
+#else
   return TEST_UNIMPL;
-  // #endif  // ENABLE_TEST_ALL
+#endif // ENABLE_TEST_ALL
 }
 
 result_t test_mm_round_sd(const SSE2RVV_TEST_IMPL &impl, uint32_t iter) {