diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index f5204e87b3..bccf76c420 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -114,21 +114,22 @@ template struct Expm1Functor } // x, y finite numbers - realT cosY_val; - auto cosY_val_multi_ptr = sycl::address_space_cast< - sycl::access::address_space::private_space, - sycl::access::decorated::yes>(&cosY_val); - const realT sinY_val = sycl::sincos(y, cosY_val_multi_ptr); - const realT sinhalfY_val = std::sin(y / 2); + const realT cosY_val = std::cos(y); + const realT sinY_val = (y == 0) ? y : std::sin(y); + const realT sinhalfY_val = (y == 0) ? y : std::sin(y / 2); const realT res_re = std::expm1(x) * cosY_val - 2 * sinhalfY_val * sinhalfY_val; - const realT res_im = std::exp(x) * sinY_val; + realT res_im = std::exp(x) * sinY_val; return resT{res_re, res_im}; } else { static_assert(std::is_floating_point_v || std::is_same_v); + static_assert(std::is_same_v); + if (in == 0) { + return in; + } return std::expm1(in); } } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index 97768fc8e9..d0b86a79f3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -81,11 +81,15 @@ template struct SinFunctor */ if (in_re_finite && in_im_finite) { #ifdef USE_SYCL_FOR_COMPLEX_TYPES - return exprm_ns::sin( + resT res = exprm_ns::sin( exprm_ns::complex(in)); // std::sin(in); #else - return std::sin(in); + resT res = std::sin(in); #endif + if (in_re == realT(0)) { + res.real(std::copysign(realT(0), in_re)); + } + return res; } /* @@ -176,8 +180,10 @@ template struct SinFunctor return resT{sinh_im, -sinh_re}; } else { - static_assert(std::is_floating_point_v || - std::is_same_v); + static_assert(std::is_same_v); + if (in == 0) { + return in; + } return std::sin(in); } }