Skip to content

Commit

Permalink
Merge pull request #11 from rivosinc/dev/PingTakPeterTang/invhyper
Browse files Browse the repository at this point in the history
added inverse hyperbolic functions acosh/asinh/atanh
  • Loading branch information
PingTakPeterTang authored Jan 4, 2024
2 parents 239cfd7 + 4792241 commit 1bf7bb3
Show file tree
Hide file tree
Showing 19 changed files with 988 additions and 9 deletions.
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ set(PROJECT_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_atan2DI.c
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_atan2piD.c
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_atan2piDI.c
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_acoshD.c
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_acoshDI.c
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_asinhD.c
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_asinhDI.c
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_atanhD.c
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_atanhDI.c
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_expD_tbl.c
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_expD.c
${CMAKE_CURRENT_SOURCE_DIR}/src/rvvlm_expDI.c
Expand Down
69 changes: 64 additions & 5 deletions include/rvvlm.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ union sui64_fp64 {

#define FAST2SUM(X, Y, S, s, vlen) \
do { \
S = __riscv_vfadd((X), (Y), (vlen)); \
s = __riscv_vfsub((X), (S), (vlen)); \
s = __riscv_vfadd((s), (Y), (vlen)); \
(S) = __riscv_vfadd((X), (Y), (vlen)); \
(s) = __riscv_vfsub((X), (S), (vlen)); \
(s) = __riscv_vfadd((s), (Y), (vlen)); \
} while (0)

#define POS2SUM(X, Y, S, s, vlen) \
Expand All @@ -92,8 +92,8 @@ union sui64_fp64 {

#define PROD_X1Y1(x, y, prod_hi, prod_lo, vlen) \
do { \
prod_hi = __riscv_vfmul((x), (y), (vlen)); \
prod_lo = __riscv_vfmsub((x), (y), (prod_hi), (vlen)); \
(prod_hi) = __riscv_vfmul((x), (y), (vlen)); \
(prod_lo) = __riscv_vfmsub((x), (y), (prod_hi), (vlen)); \
} while (0)

#define DIV_N1D2(numer, denom, delta_d, Q, q, vlen) \
Expand All @@ -115,6 +115,32 @@ union sui64_fp64 {
(Q) = __riscv_vfadd((Q), _q, (vlen)); \
} while (0)

#define DIV2_N2D2(numer, delta_n, denom, delta_d, Q, delta_Q, vlen) \
do { \
VFLOAT _q; \
(Q) = __riscv_vfdiv((numer), (denom), (vlen)); \
_q = __riscv_vfnmsub((Q), (denom), (numer), (vlen)); \
_q = __riscv_vfnmsac(_q, (Q), (delta_d), (vlen)); \
_q = __riscv_vfadd(_q, (delta_n), (vlen)); \
(delta_Q) = __riscv_vfmul(_q, __riscv_vfrec7((denom), (vlen)), (vlen)); \
} while (0)

#define SQRT2_X2(x, delta_x, r, delta_r, vlen) \
do { \
VFLOAT xx = __riscv_vfadd((x), (delta_x), (vlen)); \
VBOOL x_eq_0 = __riscv_vmfeq(xx, fp_posZero, (vlen)); \
xx = __riscv_vfmerge(xx, fp_posOne, x_eq_0, (vlen)); \
(r) = __riscv_vfsqrt(xx, (vlen)); \
(delta_r) = __riscv_vfnmsub((r), (r), (x), (vlen)); \
(delta_r) = __riscv_vfadd((delta_r), (delta_x), (vlen)); \
(delta_r) = __riscv_vfmul((delta_r), __riscv_vfrec7(xx, (vlen)), (vlen)); \
/* (delta_r) = __riscv_vfdiv((delta_r), xx, (vlen)); */ \
(delta_r) = __riscv_vfmul((delta_r), 0x1.0p-1, (vlen)); \
(delta_r) = __riscv_vfmul((delta_r), (r), (vlen)); \
(r) = __riscv_vfmerge((r), fp_posZero, x_eq_0, (vlen)); \
(delta_r) = __riscv_vfmerge((delta_r), fp_posZero, x_eq_0, (vlen)); \
} while (0)

#define IDENTIFY(vclass, stencil, identity_mask, vlen) \
identity_mask = \
__riscv_vmsgtu(__riscv_vand((vclass), (stencil), (vlen)), 0, (vlen));
Expand Down Expand Up @@ -189,6 +215,27 @@ union sui64_fp64 {
#define RVVLM_ATAN2PIDI_VSET_CONFIG "rvvlm_fp64m2.h"
#define RVVLM_ATAN2PIDI_FIXEDPT rvvlm_atan2piI

// FP64 acosh function configuration
#define RVVLM_ACOSHD_VSET_CONFIG "rvvlm_fp64m2.h"
#define RVVLM_ACOSHD_STD rvvlm_acosh

#define RVVLM_ACOSHDI_VSET_CONFIG "rvvlm_fp64m2.h"
#define RVVLM_ACOSHDI_STD rvvlm_acoshI

// FP64 asinh function configuration
#define RVVLM_ASINHD_VSET_CONFIG "rvvlm_fp64m2.h"
#define RVVLM_ASINHD_STD rvvlm_asinh

#define RVVLM_ASINHDI_VSET_CONFIG "rvvlm_fp64m2.h"
#define RVVLM_ASINHDI_STD rvvlm_asinhI

// FP64 atanh function configuration
#define RVVLM_ATANHD_VSET_CONFIG "rvvlm_fp64m2.h"
#define RVVLM_ATANHD_MIXED rvvlm_atanh

#define RVVLM_ATANHDI_VSET_CONFIG "rvvlm_fp64m2.h"
#define RVVLM_ATANHDI_MIXED rvvlm_atanhI

// FP64 exp function configuration
#define RVVLM_EXPD_VSET_CONFIG "rvvlm_fp64m4.h"
#define RVVLM_EXPD_STD rvvlm_expD_std
Expand Down Expand Up @@ -381,6 +428,18 @@ void RVVLM_ATAN2PIDI_FIXEDPT(size_t xy_len, const double *y, size_t stride_y,
const double *x, size_t stride_x, double *z,
size_t stride_z);

void RVVLM_ACOSHD_STD(size_t x_len, const double *x, double *y);
void RVVLM_ACOSHDI_STD(size_t x_len, const double *x, size_t stride_x,
double *y, size_t stride_y);

void RVVLM_ASINHD_STD(size_t x_len, const double *x, double *y);
void RVVLM_ASINHDI_STD(size_t x_len, const double *x, size_t stride_x,
double *y, size_t stride_y);

void RVVLM_ATANHD_MIXED(size_t x_len, const double *x, double *y);
void RVVLM_ATANHDI_MIXED(size_t x_len, const double *x, size_t stride_x,
double *y, size_t stride_y);

void RVVLM_EXPD_STD(size_t x_len, const double *x, double *y);
void RVVLM_EXPDI_STD(size_t x_len, const double *x, size_t stride_x, double *y,
size_t stride_y);
Expand Down
120 changes: 120 additions & 0 deletions include/rvvlm_acoshD.inc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// SPDX-FileCopyrightText: 2023 Rivos Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "rvvlm_invhyperD.h"

#if (STRIDE == UNIT_STRIDE)
#define F_VER1 RVVLM_ACOSHD_STD
#else
#define F_VER1 RVVLM_ACOSHDI_STD
#endif

#include <fenv.h>

// Acosh(x) is defined for x >= 1 by the formula log(x + sqrt(x*x - 1))
// and for the log function log(2^n z), we uses the expansion in terms of atanh
// n log(2) + 2 atanh((z-1)/(z+1))
// This algorithm obtains this scale factor 2^n from the input x, and computes
// the expression x' + sqrt(x'*x' - 2^(-2n)) thus avoiding possible overflow or
// excess computation such as computing sqrt(x*x - 1) by x * sqrt(1 - 1/(x*x))
// which needs a division on top of a sqrt.
void F_VER1(API) {
size_t vlen;
VFLOAT vx, vx_orig, vy, vy_special;
VBOOL special_args;

SET_ROUNDTONEAREST;
// stripmining over input arguments
for (; _inarg_n > 0; _inarg_n -= vlen) {
vlen = VSET(_inarg_n);
vx = VFLOAD_INARG1(vlen);
#if defined(COMPILE_FOR_ASINH)
vx_orig = vx;
#endif

#if defined(COMPILE_FOR_ACOSH)
// Handle Inf and NaN and input <= 1.0
EXCEPTION_HANDLING_ACOSH(vx, special_args, vy_special, vlen);
#else
// Handle Inf and NaN and |input} < 2^(-30)
EXCEPTION_HANDLING_ASINH(vx, special_args, vy_special, vlen);
vx = __riscv_vfsgnj(vx, fp_posOne, vlen);
#endif

// Need to scale x so that x + sqrt(x*x +/- 1) doesn't overflow
// Since x >= 1, we scale x down by 2^(-550) if x >= 2^500 and set 1 to 0
VINT n;
VFLOAT u;
SCALE_X(vx, n, u, vlen);
// n is 0 or 500; and u is +/-1.0 or 0.0

// sqrt(x*x + u) extra precisely
VFLOAT A, a;
#if defined(COMPILE_FOR_ACOSH)
XSQ_PLUS_U_ACOSH(vx, u, A, a, vlen);
#else
XSQ_PLUS_U_ASINH(vx, u, A, a, vlen);
#endif
// A + a is x*x + u

VFLOAT B, b;
SQRT2_X2(A, a, B, b, vlen);
// B + b is sqrt(x*x + u) to about 7 extra bits

// x dominants B for acosh
VFLOAT S, s;
#if defined(COMPILE_FOR_ACOSH)
FAST2SUM(vx, B, S, s, vlen);
s = __riscv_vfadd(s, b, vlen);
#else
FAST2SUM(B, vx, S, s, vlen);
s = __riscv_vfadd(s, b, vlen);
#endif

// x + sqrt(x*x + u) is accurately represented as S + s
// We first scale S, s so that it falls roughly in [1/rt2, rt2]
SCALE_4_LOG(S, s, n, vlen);

// log(x + sqrt(x*x + u)) = n * log(2) + log(y); y = S + s
// since log(y) = 2 atanh( (y-1)/(y+1) ) to be approximated
// by t + t^3 * poly(t^2), t = 2 (y-1)/(y+1)
// We now compute the numerator and denominator and its quotient
// to extra precision
VFLOAT numer, delta_numer, denom, delta_denom;
TRANSFORM_2_ATANH(S, s, numer, delta_numer, denom, delta_denom, vlen);

VFLOAT r_hi, r_lo, r;
DIV2_N2D2(numer, delta_numer, denom, delta_denom, r_hi, r_lo, vlen);
r = __riscv_vfadd(r_hi, r_lo, vlen);

VFLOAT n_flt = __riscv_vfcvt_f(n, vlen);

VFLOAT poly;
LOG_POLY(r, r_lo, poly, vlen);
// At this point r_hi + poly approximates log(X)

// Reconstruction: logB(in_arg) = n logB(2) + log(X) * logB(e), computed as
// n*(logB_2_hi + logB_2_lo) + r * (logB_e_hi + logB_e_lo) + poly *
// logB_e_hi It is best to compute n * logB_2_hi + r * logB_e_hi in extra
// precision

A = __riscv_vfmul(n_flt, LOG2_HI, vlen);
FAST2SUM(A, r_hi, S, s, vlen);
s = __riscv_vfmacc(s, LOG2_LO, n_flt, vlen);
s = __riscv_vfadd(s, poly, vlen);

vy = __riscv_vfadd(S, s, vlen);
#if defined(COMPILE_FOR_ASINH)
vy = __riscv_vfsgnj(vy, vx_orig, vlen);
#endif
vy = __riscv_vmerge(vy, vy_special, special_args, vlen);

// copy vy into y and increment addr pointers
VFSTORE_OUTARG1(vy, vlen);

INCREMENT_INARG1(vlen);
INCREMENT_OUTARG1(vlen);
}
RESTORE_FRM;
}
147 changes: 147 additions & 0 deletions include/rvvlm_asinhcoshD.inc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// SPDX-FileCopyrightText: 2023 Rivos Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "rvvlm_invhyperD.h"

#if defined(COMPILE_FOR_ACOSH)
#if (STRIDE == UNIT_STRIDE)
#define F_VER1 RVVLM_ACOSHD_STD
#else
#define F_VER1 RVVLM_ACOSHDI_STD
#endif
#else
#if (STRIDE == UNIT_STRIDE)
#define F_VER1 RVVLM_ASINHD_STD
#else
#define F_VER1 RVVLM_ASINHDI_STD
#endif
#endif

#include <fenv.h>

// Acosh(x) is defined for x >= 1 by the formula log(x + sqrt(x*x - 1))
// Asinh(x) is defined for all finite x by the formula log(x + sqrt(x*x + 1))
// Acosh is always positive, and Asinh(-x) = -Asinh(x). Thus we in general
// work with |x| and restore the sign (if necessary) in the end.
// For the log function log(2^n z), we uses the expansion in terms of atanh:
// n log(2) + 2 atanh((z-1)/(z+1))
// The algorithm here first scales down x by 2^(-550) when |x| >= 2^500.
// And for such large x, both acosh and asinh equals log(2x) to very high
// precision. We safely ignore the +/- 1 when this is the case.
//
// A power 2^n is determined by the value of x + sqrt(x*x +/- 1) so that
// scaling the expression by 2^(-n) transforms it to the range [0.71, 1.42].
// Log(t) for t in this region is computed by 2 atanh((t-1)/(t+1))
// More precisely, we use s = 2(t-1)/(t+1) and approximate the function
// 2 atanh(s/2) by s + s^3 * polynomial(s^2).
// The final result is n * log(2) + s + s^3 * polynomial(s^2)
// which is computed with care.
void F_VER1(API) {
size_t vlen;
VFLOAT vx, vx_orig, vy, vy_special;
VBOOL special_args;

SET_ROUNDTONEAREST;
// stripmining over input arguments
for (; _inarg_n > 0; _inarg_n -= vlen) {
vlen = VSET(_inarg_n);
vx = VFLOAD_INARG1(vlen);
#if defined(COMPILE_FOR_ASINH)
vx_orig = vx;
#endif

#if defined(COMPILE_FOR_ACOSH)
// Handle Inf and NaN and input <= 1.0
EXCEPTION_HANDLING_ACOSH(vx, special_args, vy_special, vlen);
#else
// Handle Inf and NaN and |input| < 2^(-30)
EXCEPTION_HANDLING_ASINH(vx, special_args, vy_special, vlen);
vx = __riscv_vfsgnj(vx, fp_posOne, vlen);
#endif

// Need to scale x so that x + sqrt(x*x +/- 1) doesn't overflow
// We scale x down by 2^(-550) if x >= 2^500 and set the "+/- 1" to 0
VINT n;
VFLOAT u;
SCALE_X(vx, n, u, vlen);
// n is 0 or 500; and u is +/-1.0 or 0.0

// sqrt(x*x + u) extra precisely
VFLOAT A, a;
#if defined(COMPILE_FOR_ACOSH)
XSQ_PLUS_U_ACOSH(vx, u, A, a, vlen);
#else
XSQ_PLUS_U_ASINH(vx, u, A, a, vlen);
#endif
// A + a is x*x + u

VFLOAT B, b;
#if defined(COMPILE_FOR_ACOSH)
SQRT2_X2(A, a, B, b, vlen);
// B + b is sqrt(x*x + u) to about 7 extra bits
#else
// For asinh, we need the sqrt to double-double precision
VFLOAT recip = __riscv_vfrdiv(A, fp_posOne, vlen);
B = __riscv_vfsqrt(A, vlen);
b = __riscv_vfnmsub(B, B, A, vlen);
b = __riscv_vfadd(b, a, vlen);
VFLOAT B_recip = __riscv_vfmul(B, recip, vlen);
b = __riscv_vfmul(b, 0x1.0p-1, vlen);
b = __riscv_vfmul(b, B_recip, vlen);
#endif

VFLOAT S, s;
#if defined(COMPILE_FOR_ACOSH)
// x dominantes B for acosh
FAST2SUM(vx, B, S, s, vlen);
s = __riscv_vfadd(s, b, vlen);
#else
// B dominates x for asinh
FAST2SUM(B, vx, S, s, vlen);
s = __riscv_vfadd(s, b, vlen);
#endif

// x + sqrt(x*x + u) is accurately represented as S + s
// We first scale S, s by 2^(-n) so that the scaled value
// falls roughly in [1/rt2, rt2]
SCALE_4_LOG(S, s, n, vlen);

// log(x + sqrt(x*x + u)) = n * log(2) + log(y); y = S + s
// We use log(y) = 2 atanh( (y-1)/(y+1) ) and approximate the latter
// by t + t^3 * poly(t^2), t = 2 (y-1)/(y+1)

// We now compute the numerator 2(y-1) and denominator y+1 and their
// quotient to extra precision
VFLOAT numer, delta_numer, denom, delta_denom;
TRANSFORM_2_ATANH(S, s, numer, delta_numer, denom, delta_denom, vlen);

VFLOAT r_hi, r_lo, r;
DIV2_N2D2(numer, delta_numer, denom, delta_denom, r_hi, r_lo, vlen);
r = __riscv_vfadd(r_hi, r_lo, vlen);

VFLOAT poly;
LOG_POLY(r, r_lo, poly, vlen);
// At this point r_hi + poly approximates log(X)

// Compose the final result: n * log(2) + r_hi + poly
VFLOAT n_flt = __riscv_vfcvt_f(n, vlen);
A = __riscv_vfmul(n_flt, LOG2_HI, vlen);
FAST2SUM(A, r_hi, S, s, vlen);
s = __riscv_vfmacc(s, LOG2_LO, n_flt, vlen);
s = __riscv_vfadd(s, poly, vlen);

vy = __riscv_vfadd(S, s, vlen);
#if defined(COMPILE_FOR_ASINH)
vy = __riscv_vfsgnj(vy, vx_orig, vlen);
#endif
vy = __riscv_vmerge(vy, vy_special, special_args, vlen);

// copy vy into y and increment addr pointers
VFSTORE_OUTARG1(vy, vlen);

INCREMENT_INARG1(vlen);
INCREMENT_OUTARG1(vlen);
}
RESTORE_FRM;
}
Loading

0 comments on commit 1bf7bb3

Please sign in to comment.