Skip to content

Commit

Permalink
Merge pull request #28 from rivosinc/dev/PingTakPeterTang/cdfnorminv
Browse files Browse the repository at this point in the history
added inverse cdfnorm function
  • Loading branch information
PingTakPeterTang authored Mar 13, 2024
2 parents 02eada9 + bf735f1 commit 194eb48
Show file tree
Hide file tree
Showing 11 changed files with 480 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ set(PROJECT_SOURCES
src/rvvlm_cbrtDI.c
src/rvvlm_cdfnormD.c
src/rvvlm_cdfnormDI.c
src/rvvlm_cdfnorminvD.c
src/rvvlm_cdfnorminvDI.c
src/rvvlm_erfD.c
src/rvvlm_erfDI.c
src/rvvlm_erfcD.c
Expand Down
11 changes: 11 additions & 0 deletions include/rvvlm.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,13 @@ union sui64_fp64 {
#define RVVLM_CDFNORMDI_VSET_CONFIG "rvvlm_fp64m1.h"
#define RVVLM_CDFNORMDI_STD rvvlm_cdfnormI

// FP64 cdfnorminv function configuration
// *_VSET_CONFIG names the header that the implementing source file #includes
// to pick its vector configuration; *_STD is the exported function name.
// "D" is the unit-stride entry point, "DI" the strided one.
#define RVVLM_CDFNORMINVD_VSET_CONFIG "rvvlm_fp64m1.h"
#define RVVLM_CDFNORMINVD_STD rvvlm_cdfnorminv

#define RVVLM_CDFNORMINVDI_VSET_CONFIG "rvvlm_fp64m1.h"
#define RVVLM_CDFNORMINVDI_STD rvvlm_cdfnorminvI

// FP64 erf function configuration
#define RVVLM_ERFD_VSET_CONFIG "rvvlm_fp64m2.h"
#define RVVLM_ERFD_STD rvvlm_erf
Expand Down Expand Up @@ -550,6 +557,10 @@ void RVVLM_CDFNORMD_STD(size_t x_len, const double *x, double *y);
void RVVLM_CDFNORMDI_STD(size_t x_len, const double *x, size_t stride_x,
double *y, size_t stride_y);

// Inverse of the standard normal CDF, applied element-wise to x_len doubles
// read from x, with results written to y.
// The DI variant additionally takes a stride for the input and for the output.
// NOTE(review): stride units (elements vs. bytes) are not visible here --
// confirm against the other strided entry points in this header.
void RVVLM_CDFNORMINVD_STD(size_t x_len, const double *x, double *y);
void RVVLM_CDFNORMINVDI_STD(size_t x_len, const double *x, size_t stride_x,
double *y, size_t stride_y);

void RVVLM_ERFD_STD(size_t x_len, const double *x, double *y);
void RVVLM_ERFDI_STD(size_t x_len, const double *x, size_t stride_x, double *y,
size_t stride_y);
Expand Down
234 changes: 234 additions & 0 deletions include/rvvlm_cdfnorminvD.inc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
// SPDX-FileCopyrightText: 2023 Rivos Inc.
//
// SPDX-License-Identifier: Apache-2.0

// F_VER1 is the name of the function defined in this file. The including
// source file sets STRIDE before including this header: a unit-stride build
// yields the plain entry point, any other setting yields the strided one.
#if (STRIDE == UNIT_STRIDE)
#define F_VER1 RVVLM_CDFNORMINVD_STD
#else
#define F_VER1 RVVLM_CDFNORMINVDI_STD
#endif

// cdfnorminv is defined on (0, 1). By symmetry about 1/2 it suffices to
// consider x in (0, 1/2].
// Two regions of approximation: left is (0, 0x1.2p-3] and right is
// (0x1.2p-3, 1/2]. Both are done with rational functions.
// For right: t*P(t)/Q(t), t = 1/2 - x.
// For left:  y*P(t)/Q(t), y = sqrt(-log(2x)) and t = 1/y.

// Numerator (P) coefficients for the right region, in ascending order
// (P_right_k goes with the k-th power), all in Q79.
// DELTA_P0_right is a floating-point correction, scale 79; it is added to the
// low part of the evaluated numerator.
#define P_right_0 -0x6709ca23d4199a8L
#define P_right_1 -0xfd998fbae8eb3c8L
#define P_right_2 0x48ca86036ae6e955L
#define P_right_3 -0x278f4a98238f8c27L
#define P_right_4 -0x40132208941e6a5aL
#define P_right_5 0x402e2635719a3914L
#define P_right_6 0x31c67fdc7e5073fL
#define P_right_7 -0x12d1e1d375fb5d31L
#define P_right_8 0x4232daca563749dL
#define P_right_9 0xb02a8971665c0dL
#define P_right_10 -0x2a7ae4292a6a4fL
#define DELTA_P0_right 0x1.6c4b0b32778d0p-3

// Denominator (Q) coefficients for the right region, in ascending order
// (Q_right_k goes with the k-th power), all in Q79.
// DELTA_Q0_right is a floating-point correction, scale 79; it is added to the
// low part of the evaluated denominator.
#define Q_right_0 -0x52366e5b14c0970L
#define Q_right_1 -0xca57e95abcc599bL
#define Q_right_2 0x3b6c91ec67f5759cL
#define Q_right_3 -0x1c40d5daa3be22bcL
#define Q_right_4 -0x41f11eb5d837386cL
#define Q_right_5 0x3c6ce478fcd75c9aL
#define Q_right_6 0xbb1cd7270cfba1dL
#define Q_right_7 -0x1988a4116498f1afL
#define Q_right_8 0x44dc3042f103d20L
#define Q_right_9 0x2390e683d02edf3L
#define Q_right_10 -0x8ec66f2a7e410cL
#define DELTA_Q0_right -0x1.29a0161e99446p-3

// Numerator (P) coefficients for the left region, in ascending order
// (P_left_k goes with the k-th power), all in Q67.
// DELTA_P0_left is a floating-point correction added to the low part of the
// evaluated numerator.
#define P_left_0 0x216a32ed581bfL
#define P_left_1 0x5ac486106d127fL
#define P_left_2 0x3a9f84d231c6131L
#define P_left_3 0xb54f6ab23cca5a3L
#define P_left_4 0xecc53db7ed5eccbL
#define P_left_5 0x194382b2de726d58L
#define P_left_6 0x166fc6bd87b1b0b6L
#define P_left_7 0xfd7bc0d477f41a9L
#define P_left_8 0x7fc186088d7ad8cL
#define P_left_9 0x18d6aeeb448b50aL
#define P_left_10 -0x8fb330020a5bL
#define DELTA_P0_left 0x1.b81f6f45914f0p-2

// Denominator (Q) coefficients for the left region, in ascending order
// (Q_left_k goes with the k-th power), all in Q67.
// DELTA_Q0_left is a floating-point correction added to the low part of the
// evaluated denominator.
#define Q_left_0 0x17a09aabf9ceeL
#define Q_left_1 0x4030b9059ffcadL
#define Q_left_2 0x29b26b0d87f7855L
#define Q_left_3 0x87572a13d3fa2ddL
#define Q_left_4 0xd7a728b5620ac3cL
#define Q_left_5 0x1754392b473fd439L
#define Q_left_6 0x1791b9a091a816c2L
#define Q_left_7 0x167f71db9e13b075L
#define Q_left_8 0xcb9f5f3e5e618a4L
#define Q_left_9 0x68271fae767c68eL
#define Q_left_10 0x13745c4fa224b25L
#define DELTA_Q0_left 0x1.f7e7557a34ae6p-2

// cdfnorminv(x) = -sqrt(2)*erfcinv(2x)
// The approximating rational functions are based on those for erfcinv,
// hence the doubling of arguments here and there so that "2x" is created.
//
// Strip-mined vector kernel: each pass loads up to one vector of inputs,
// evaluates the inverse normal CDF and stores the results.
//   - Exceptional inputs (NaN, x <= 0, x >= 1) are resolved by
//     EXCEPTION_HANDLING_CDFNORMINV and merged into the result at the end.
//   - x > 1/2 is folded to 1 - x; the sign of (x - 1/2) is applied to the
//     result.
//   - Lanes with folded x <= 0x1.2p-3 use the "left" approximation, which is
//     built on sqrt(-log(2x)); the other lanes use the "right" one, built on
//     1 - 2x.
//   - Lanes with folded x < 2^-53 take their P and Q values from
//     ERFCINV_PQ_HILO_TINY instead.
// The parameter list comes from the API macro chosen by the including source
// file (presumably matching the prototypes in rvvlm.h -- confirm there).
void F_VER1(API) {
  size_t vlen;
  VFLOAT vx, vx_sign, vy, vy_special;
  VBOOL special_args;

  SET_ROUNDTONEAREST;
  // stripmining over input arguments
  for (; _inarg_n > 0; _inarg_n -= vlen) {
    vlen = VSET(_inarg_n);
    vx = VFLOAD_INARG1(vlen);

    // Handle NaN and x outside (0, 1). Special lanes of vx are replaced by
    // 1/2 so the main path runs harmlessly on them. vy_special is written
    // only when at least one lane is special.
    EXCEPTION_HANDLING_CDFNORMINV(vx, special_args, vy_special, vlen);

    // Sign of the result is the sign of (x - 1/2); fold x > 1/2 to 1 - x.
    vx_sign = __riscv_vfsub(vx, 0x1.0p-1, vlen);
    VFLOAT one_minus_x = __riscv_vfrsub(vx, fp_posOne, vlen);
    VBOOL x_gt_half = __riscv_vmfgt(vx_sign, fp_posZero, vlen);
    vx = __riscv_vmerge(vx, one_minus_x, x_gt_half, vlen);
    // vx is now in (0, 1/2]
    VBOOL x_in_left = __riscv_vmfle(vx, 0x1.2p-3, vlen);
    const int any_left = __riscv_vcpop(x_in_left, vlen) > 0;

    VFLOAT w_hi, w_lo, w_hi_left, w_lo_left, y_hi, y_lo;
    VINT T, T_left, T_tiny;
    VBOOL x_is_tiny;
    // All-false mask. Built from an already-initialized mask so that no
    // indeterminate value is read.
    x_is_tiny = __riscv_vmxor(x_in_left, x_in_left, vlen);

    if (any_left) {
      // Lanes outside the left region get the harmless stand-in 2^-4.
      VFLOAT x_left = VFMV_VF(0x1.0p-4, vlen);
      x_left = __riscv_vmerge(x_left, vx, x_in_left, vlen);
      x_is_tiny = __riscv_vmflt(x_left, 0x1.0p-53, vlen);
      INT n_adjust = 59;
      x_left = __riscv_vfmul(x_left, 0x1.0p60, vlen);
      // adjusting only 59 instead of 60 essentially doubles x
      NEG_LOGX_4_TRANSFORM(x_left, n_adjust, y_hi, y_lo, vlen);

      SQRTX_4_TRANSFORM(y_hi, y_lo, w_hi_left, w_lo_left, T_left, 0x1.0p63,
                        0x1.0p-63, vlen);
      if (__riscv_vcpop(x_is_tiny, vlen) > 0) {
        // Only T_tiny is needed from this second transform, which uses a
        // different scale.
        VFLOAT w_hi_dummy, w_lo_dummy;
        SQRTX_4_TRANSFORM(y_hi, y_lo, w_hi_dummy, w_lo_dummy, T_tiny, 0x1.0p64,
                          0x1.0p-64, vlen);
      }
    }

    // Right region: w_hi + w_lo = 1 - 2x, with w_lo the rounding error of
    // w_hi; T is its fixed-point image scaled by 2^63.
    vx = __riscv_vfadd(vx, vx, vlen);
    w_hi = VFMV_VF(fp_posOne, vlen);
    w_hi = __riscv_vfsub(w_hi, vx, vlen);
    w_lo = __riscv_vfrsub(w_hi, fp_posOne, vlen);
    w_lo = __riscv_vfsub(w_lo, vx, vlen);
    T = __riscv_vfcvt_x(__riscv_vfmul(w_hi, 0x1.0p63, vlen), vlen);
    VFLOAT delta_t = __riscv_vfmul(w_lo, 0x1.0p63, vlen);
    T = __riscv_vadd(T, __riscv_vfcvt_x(delta_t, vlen), vlen);

    // The left-branch values exist only if the left branch ran. With no left
    // lane the merges would be no-ops, so skipping them changes nothing.
    if (any_left) {
      T = __riscv_vmerge(T, T_left, x_in_left, vlen);
      w_hi = __riscv_vmerge(w_hi, w_hi_left, x_in_left, vlen);
      w_lo = __riscv_vmerge(w_lo, w_lo_left, x_in_left, vlen);
    }

    // For transformed branch, compute (w_hi + w_lo) * P(T)/Q(T)
    // The coefficient set (left or right) is chosen per lane by x_in_left.
    VINT P, Q;

    P = __riscv_vmerge(VMVI_VX(P_right_10, vlen), P_left_10, x_in_left, vlen);
    P = PSTEP_I_ab(x_in_left, P_left_6, P_right_6, T,
                   PSTEP_I_ab(x_in_left, P_left_7, P_right_7, T,
                              PSTEP_I_ab(x_in_left, P_left_8, P_right_8, T,
                                         PSTEP_I_ab(x_in_left, P_left_9,
                                                    P_right_9, T, P, vlen),
                                         vlen),
                              vlen),
                   vlen);

    Q = __riscv_vmerge(VMVI_VX(Q_right_10, vlen), Q_left_10, x_in_left, vlen);
    Q = PSTEP_I_ab(x_in_left, Q_left_6, Q_right_6, T,
                   PSTEP_I_ab(x_in_left, Q_left_7, Q_right_7, T,
                              PSTEP_I_ab(x_in_left, Q_left_8, Q_right_8, T,
                                         PSTEP_I_ab(x_in_left, Q_left_9,
                                                    Q_right_9, T, Q, vlen),
                                         vlen),
                              vlen),
                   vlen);

    P = PSTEP_I_ab(
        x_in_left, P_left_0, P_right_0, T,
        PSTEP_I_ab(
            x_in_left, P_left_1, P_right_1, T,
            PSTEP_I_ab(x_in_left, P_left_2, P_right_2, T,
                       PSTEP_I_ab(x_in_left, P_left_3, P_right_3, T,
                                  PSTEP_I_ab(x_in_left, P_left_4, P_right_4, T,
                                             PSTEP_I_ab(x_in_left, P_left_5,
                                                        P_right_5, T, P, vlen),
                                             vlen),
                                  vlen),
                       vlen),
            vlen),
        vlen);

    Q = PSTEP_I_ab(
        x_in_left, Q_left_0, Q_right_0, T,
        PSTEP_I_ab(
            x_in_left, Q_left_1, Q_right_1, T,
            PSTEP_I_ab(x_in_left, Q_left_2, Q_right_2, T,
                       PSTEP_I_ab(x_in_left, Q_left_3, Q_right_3, T,
                                  PSTEP_I_ab(x_in_left, Q_left_4, Q_right_4, T,
                                             PSTEP_I_ab(x_in_left, Q_left_5,
                                                        Q_right_5, T, Q, vlen),
                                             vlen),
                                  vlen),
                       vlen),
            vlen),
        vlen);

    // Split the integer P into a floating-point hi part and the residual lo
    // part, then fold in the floating-point correction for the leading
    // coefficient.
    VFLOAT p_hi, p_lo;
    p_hi = __riscv_vfcvt_f(P, vlen);

    p_lo = __riscv_vfcvt_f(__riscv_vsub(P, __riscv_vfcvt_x(p_hi, vlen), vlen),
                           vlen);
    VFLOAT delta_p0 = VFMV_VF(DELTA_P0_right, vlen);
    delta_p0 = __riscv_vfmerge(delta_p0, DELTA_P0_left, x_in_left, vlen);
    p_lo = __riscv_vfadd(p_lo, delta_p0, vlen);

    // Same hi/lo split for Q.
    VFLOAT q_hi, q_lo;
    q_hi = __riscv_vfcvt_f(Q, vlen);
    q_lo = __riscv_vfcvt_f(__riscv_vsub(Q, __riscv_vfcvt_x(q_hi, vlen), vlen),
                           vlen);
    VFLOAT delta_q0 = VFMV_VF(DELTA_Q0_right, vlen);
    delta_q0 = __riscv_vfmerge(delta_q0, DELTA_Q0_left, x_in_left, vlen);
    q_lo = __riscv_vfadd(q_lo, delta_q0, vlen);

    // Tiny lanes take P and Q from the dedicated tiny-argument evaluation.
    if (__riscv_vcpop(x_is_tiny, vlen) > 0) {
      VFLOAT p_hi_tiny, p_lo_tiny, q_hi_tiny, q_lo_tiny;
      ERFCINV_PQ_HILO_TINY(T_tiny, p_hi_tiny, p_lo_tiny, q_hi_tiny, q_lo_tiny,
                           vlen);
      p_hi = __riscv_vmerge(p_hi, p_hi_tiny, x_is_tiny, vlen);
      p_lo = __riscv_vmerge(p_lo, p_lo_tiny, x_is_tiny, vlen);
      q_hi = __riscv_vmerge(q_hi, q_hi_tiny, x_is_tiny, vlen);
      q_lo = __riscv_vmerge(q_lo, q_lo_tiny, x_is_tiny, vlen);
    }

    // (y_hi, y_lo) <-- (w_hi + w_lo) * (p_hi + p_lo)
    y_hi = __riscv_vfmul(w_hi, p_hi, vlen);
    y_lo = __riscv_vfmsub(w_hi, p_hi, y_hi, vlen);
    y_lo = __riscv_vfmacc(y_lo, w_hi, p_lo, vlen);
    y_lo = __riscv_vfmacc(y_lo, w_lo, p_hi, vlen);

    // Quotient (y_hi + y_lo) / (q_hi + q_lo) is delivered in w_hi.
    DIV_N2D2(y_hi, y_lo, q_hi, q_lo, w_hi, vlen);

    vy = w_hi;

    // Give the result the sign of (x - 1/2).
    vy = __riscv_vfsgnj(vy, vx_sign, vlen);
    // vy_special is assigned only when some lane is special, so read it only
    // in that case.
    if (__riscv_vcpop(special_args, vlen) > 0) {
      vy = __riscv_vmerge(vy, vy_special, special_args, vlen);
    }

    // copy vy into y and increment addr pointers
    VFSTORE_OUTARG1(vy, vlen);

    INCREMENT_INARG1(vlen);
    INCREMENT_OUTARG1(vlen);
  }
  RESTORE_FRM;
}
77 changes: 75 additions & 2 deletions include/rvvlm_inverrorfuncsD.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@
#define Q_tiny_9 0x5e4a26a7c1415755 // scale 57
#define DELTA_Q0_tiny 0x1.8a7adad44d65ap-4 // scale 66

#define Q50_84
#if defined(COMPILE_FOR_ERFCINV)
// Using [P,Q]_tiny_[HI,LO]_k, HI in Q50, LO in Q84
#if defined(Q50_84)
#define P_tiny_HI_0 -0x8593442eL
#define P_tiny_LO_0 -0x4e7245b3L
#define P_tiny_HI_1 -0x7f3dc156b1L
Expand Down Expand Up @@ -81,6 +80,51 @@
#define Q_tiny_LO_9 -0x155b50b48L
#endif

// Tiny-argument coefficient tables used when this header is compiled for
// cdfnorminv (the including source file defines COMPILE_FOR_CDFNORMINV).
// Each coefficient k is split into a HI and a LO integer part.
#if defined(COMPILE_FOR_CDFNORMINV)
// Using [P,Q]_tiny_[HI,LO]_k, HI in Q50, LO in Q84
#define P_tiny_HI_0 -0xbce768cfL
#define P_tiny_LO_0 -0x6824d442L
#define P_tiny_HI_1 -0xb3f23f158aL
#define P_tiny_LO_1 0x120e225b6L
#define P_tiny_HI_2 -0x2e77fdb703eaL
#define P_tiny_LO_2 -0x1e1d72461L
#define P_tiny_HI_3 -0x44fbca4f8507eL
#define P_tiny_LO_3 -0xd2fb9bf1L
#define P_tiny_HI_4 -0x25be85812224dcL
#define P_tiny_LO_4 -0x14663c6d2L
#define P_tiny_HI_5 -0x56d9a544fd76f0L
#define P_tiny_LO_5 -0x1e3fd12d9L
#define P_tiny_HI_6 0xb44c46b00008ccL
#define P_tiny_LO_6 0x123f14b79L
#define P_tiny_HI_7 0x22eb3f29425cc2dL
#define P_tiny_LO_7 -0x1f47840b1L
#define P_tiny_HI_8 0x6b5068e2aa0bc1L
#define P_tiny_LO_8 -0xd830044aL
#define P_tiny_HI_9 0x1e496a7253435eL
#define P_tiny_LO_9 -0xf06a1c9L

#define Q_tiny_HI_0 -0x85933cdaL
#define Q_tiny_LO_0 -0xb5b39d61L
#define Q_tiny_HI_1 -0x7f3de4b69fL
#define Q_tiny_LO_1 -0x151d1cd35L
#define Q_tiny_HI_2 -0x20dd8dc1da27L
#define Q_tiny_LO_2 -0x1706945d7L
#define Q_tiny_HI_3 -0x30dc92d1cd231L
#define Q_tiny_LO_3 0xabde03f9L
#define Q_tiny_HI_4 -0x1af5fcee397d58L
#define Q_tiny_LO_4 -0xc3530d28L
#define Q_tiny_HI_5 -0x42639eeec1d051L
#define Q_tiny_LO_5 0x662b41ecL
#define Q_tiny_HI_6 0x6182b99f6ca998L
#define Q_tiny_LO_6 0x938a5e35L
#define Q_tiny_HI_7 0x17a6848dc07624aL
#define Q_tiny_LO_7 0x8a0484b7L
#define Q_tiny_HI_8 0x105ecd6aac52b12L
#define Q_tiny_LO_8 0x1d1e38258L
#define Q_tiny_HI_9 0xbc944d4f8282afL
#define Q_tiny_LO_9 -0x155b50b48L
#endif

// erfinv(+-1) = +-Inf with divide by zero
// erfinv(x) |x| > 1, real is NaN with invalid
// erfinv(NaN) is NaN, invalid if input is signalling NaN
Expand Down Expand Up @@ -140,6 +184,35 @@
} \
} while (0)

// cdfnorminv(0) = -Inf, erfcinv(1) = Inf with divide by zero
// cdfnorminv(x) x outside [0, 1], real is NaN with invalid
// cdfnorminv(NaN) is NaN, invalid if input is signalling NaN
#define EXCEPTION_HANDLING_CDFNORMINV(vx, special_args, vy_special, vlen) \
do { \
VUINT vclass = __riscv_vfclass((vx), (vlen)); \
IDENTIFY(vclass, 0x39F, (special_args), (vlen)); \
VBOOL x_ge_1 = __riscv_vmfge((vx), fp_posOne, (vlen)); \
(special_args) = __riscv_vmor((special_args), x_ge_1, (vlen)); \
if (__riscv_vcpop((special_args), (vlen)) > 0) { \
VBOOL x_gt_1 = __riscv_vmfgt((vx), fp_posOne, (vlen)); \
VBOOL x_lt_0 = __riscv_vmflt((vx), fp_posZero, (vlen)); \
/* substitute x > 1 or x < 0 with sNaN */ \
(vx) = __riscv_vfmerge((vx), fp_sNaN, x_gt_1, (vlen)); \
(vx) = __riscv_vfmerge((vx), fp_sNaN, x_lt_0, (vlen)); \
/* substitute x = 0 or 1 with +/-Inf and generate div-by-zero signal */ \
VFLOAT tmp = VFMV_VF(fp_posZero, (vlen)); \
VFLOAT x_tmp = __riscv_vfsub((vx), 0x1.0p-1, (vlen)); \
tmp = __riscv_vfsgnj(tmp, x_tmp, (vlen)); \
VBOOL x_eq_1 = __riscv_vmfeq((vx), fp_posOne, (vlen)); \
VBOOL x_eq_0 = __riscv_vmfeq((vx), fp_posZero, (vlen)); \
VBOOL pm_Inf = __riscv_vmor(x_eq_1, x_eq_0, (vlen)); \
tmp = __riscv_vfrec7(pm_Inf, tmp, (vlen)); \
(vy_special) = __riscv_vfsub((special_args), (vx), (vx), (vlen)); \
(vy_special) = __riscv_vmerge((vy_special), tmp, pm_Inf, (vlen)); \
(vx) = __riscv_vfmerge((vx), 0x1.0p-1, (special_args), (vlen)); \
} \
} while (0)

// Compute -log(2^(-n_adjust) * x), where x < 1
#define NEG_LOGX_4_TRANSFORM(vx, n_adjust, y_hi, y_lo, vlen) \
do { \
Expand Down
17 changes: 17 additions & 0 deletions src/rvvlm_cdfnorminvD.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// SPDX-FileCopyrightText: 2023 Rivos Inc.
//
// SPDX-License-Identifier: Apache-2.0

// Translation unit for the unit-stride FP64 cdfnorminv entry point. It only
// configures macros and then pulls in the shared implementation header.
#include <riscv_vector.h>
#include <stdio.h>

#include "rvvlm.h"
// NOTE(review): API_SIGNATURE_11 is defined elsewhere; presumably the
// one-input / one-output parameter list -- confirm in rvvlm.h.
#define API_SIGNATURE API_SIGNATURE_11
// Unit stride makes the implementation header name the function
// RVVLM_CDFNORMINVD_STD.
#define STRIDE UNIT_STRIDE

// Vector configuration header named by rvvlm.h for this function.
#include RVVLM_CDFNORMINVD_VSET_CONFIG

// Selects the cdfnorminv variant of the tiny-argument coefficient tables.
#define COMPILE_FOR_CDFNORMINV
#include "rvvlm_inverrorfuncsD.h"

// The function body itself.
#include "rvvlm_cdfnorminvD.inc.h"
Loading

0 comments on commit 194eb48

Please sign in to comment.