From 6667b94f7312e00178a4b9770b6b0c9e32288d7b Mon Sep 17 00:00:00 2001 From: Luc Berger-Vergiat Date: Wed, 18 Dec 2024 12:14:10 -0700 Subject: [PATCH] Blas - rot: fixing interface of rot The cosine coefficient is strictly real while the sine coefficient can be real or complex leading to a bug in the current API. This commit should fix that for the native and TPL implementation and the associated unit-test is also fixed accordingly. Signed-off-by: Luc Berger-Vergiat --- blas/impl/KokkosBlas1_rot_impl.hpp | 11 +++--- blas/impl/KokkosBlas1_rot_spec.hpp | 23 +++++++------ blas/src/KokkosBlas1_rot.hpp | 26 ++++++++++---- blas/tpls/KokkosBlas1_rot_tpl_spec_avail.hpp | 7 +++- blas/tpls/KokkosBlas1_rot_tpl_spec_decl.hpp | 36 ++++++++++++++------ blas/unit_test/Test_Blas1_rot.hpp | 6 ++-- 6 files changed, 74 insertions(+), 35 deletions(-) diff --git a/blas/impl/KokkosBlas1_rot_impl.hpp b/blas/impl/KokkosBlas1_rot_impl.hpp index e139e916be..612ba60bf5 100644 --- a/blas/impl/KokkosBlas1_rot_impl.hpp +++ b/blas/impl/KokkosBlas1_rot_impl.hpp @@ -23,14 +23,15 @@ namespace KokkosBlas { namespace Impl { -template +template struct rot_functor { using scalar_type = typename VectorView::non_const_value_type; VectorView X, Y; - ScalarView c, s; + MagnitudeView c; + ScalarView s; - rot_functor(VectorView const& X_, VectorView const& Y_, ScalarView const& c_, ScalarView const& s_) + rot_functor(VectorView const& X_, VectorView const& Y_, MagnitudeView const& c_, ScalarView const& s_) : X(X_), Y(Y_), c(c_), s(s_) {} KOKKOS_INLINE_FUNCTION @@ -41,8 +42,8 @@ struct rot_functor { } }; -template -void Rot_Invoke(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c, + template +void Rot_Invoke(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c, ScalarView const& s) { Kokkos::RangePolicy rot_policy(space, 0, X.extent(0)); rot_functor rot_func(X, Y, c, s); diff --git a/blas/impl/KokkosBlas1_rot_spec.hpp b/blas/impl/KokkosBlas1_rot_spec.hpp index 493cd648cf..61f14fb302 100644 --- a/blas/impl/KokkosBlas1_rot_spec.hpp +++ b/blas/impl/KokkosBlas1_rot_spec.hpp @@ -29,7 +29,7 @@ namespace KokkosBlas { namespace Impl { // Specialization struct which defines whether a specialization exists -template +template struct rot_eti_spec_avail { enum : bool { value = false }; }; @@ -49,7 +49,8 @@ struct rot_eti_spec_avail { EXECSPACE, \ Kokkos::View, Kokkos::MemoryTraits>, \ Kokkos::View::mag_type, LAYOUT, Kokkos::Device, \ - Kokkos::MemoryTraits>> { \ + Kokkos::MemoryTraits>, \ + Kokkos::View, Kokkos::MemoryTraits>> {\ enum : bool { value = true }; \ }; @@ -61,19 +62,19 @@ namespace KokkosBlas { namespace Impl { // Unification layer -template ::value, - bool eti_spec_avail = rot_eti_spec_avail::value> +template ::value, + bool eti_spec_avail = rot_eti_spec_avail::value> struct Rot { - static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c, + static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c, ScalarView const& s); }; #if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY //! Full specialization of Rot. -template -struct Rot { - static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c, +template +struct Rot { + static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c, ScalarView const& s) { Kokkos::Profiling::pushRegion(KOKKOSKERNELS_IMPL_COMPILE_LIBRARY ? "KokkosBlas::rot[ETI]" : "KokkosBlas::rot[noETI]"); @@ -86,7 +87,7 @@ struct Rot(space, X, Y, c, s); + Rot_Invoke(space, X, Y, c, s); Kokkos::Profiling::popRegion(); } }; @@ -108,6 +109,7 @@ struct Rot, Kokkos::MemoryTraits>, \ Kokkos::View::mag_type, LAYOUT, Kokkos::Device, \ Kokkos::MemoryTraits>, \ + Kokkos::View, Kokkos::MemoryTraits>, \ false, true>; // @@ -121,6 +123,7 @@ struct Rot, Kokkos::MemoryTraits>, \ Kokkos::View::mag_type, LAYOUT, Kokkos::Device, \ Kokkos::MemoryTraits>, \ + Kokkos::View, Kokkos::MemoryTraits>, \ false, true>; #include diff --git a/blas/src/KokkosBlas1_rot.hpp b/blas/src/KokkosBlas1_rot.hpp index 0c36eab426..01b8a62f06 100644 --- a/blas/src/KokkosBlas1_rot.hpp +++ b/blas/src/KokkosBlas1_rot.hpp @@ -21,22 +21,28 @@ namespace KokkosBlas { -template -void rot(execution_space const& space, VectorView const& X, VectorView const& Y, ScalarView const& c, +template +void rot(execution_space const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c, ScalarView const& s) { static_assert(Kokkos::is_execution_space::value, "rot: execution_space template parameter is not a Kokkos " "execution space."); static_assert(Kokkos::is_view_v, "KokkosBlas::rot: VectorView is not a Kokkos::View."); + static_assert(Kokkos::is_view_v, "KokkosBlas::rot: MagnitudeView is not a Kokkos::View."); static_assert(Kokkos::is_view_v, "KokkosBlas::rot: ScalarView is not a Kokkos::View."); static_assert(VectorView::rank == 1, "rot: VectorView template parameter needs to be a rank 1 view"); + static_assert(MagnitudeView::rank == 0, "rot: MagnitudeView template parameter needs to be a rank 0 view"); static_assert(ScalarView::rank == 0, "rot: ScalarView template parameter needs to be a rank 0 view"); static_assert(Kokkos::SpaceAccessibility::accessible, "rot: VectorView template parameter memory space needs to be accessible " "from " "execution_space template parameter"); + static_assert(Kokkos::SpaceAccessibility::accessible, + "rot: MagnitudeView template parameter memory space needs to be accessible " + "from " + "execution_space template parameter"); static_assert(Kokkos::SpaceAccessibility::accessible, - "rot: VectorView template parameter memory space needs to be accessible " + "rot: ScalarView template parameter memory space needs to be accessible " "from " "execution_space template parameter"); static_assert(std::is_same::value, @@ -55,21 +61,27 @@ void rot(execution_space const& space, VectorView const& X, VectorView const& Y, Kokkos::Device, Kokkos::MemoryTraits>; + using MagnitudeView_Internal = Kokkos::View::array_layout, + Kokkos::Device, + Kokkos::MemoryTraits>; + using ScalarView_Internal = Kokkos::View::array_layout, Kokkos::Device, Kokkos::MemoryTraits>; VectorView_Internal X_(X), Y_(Y); - ScalarView_Internal c_(c), s_(s); + MagnitudeView_Internal c_(c); + ScalarView_Internal s_(s); Kokkos::Profiling::pushRegion("KokkosBlas::rot"); - Impl::Rot::rot(space, X_, Y_, c_, s_); + Impl::Rot::rot(space, X_, Y_, c_, s_); Kokkos::Profiling::popRegion(); } -template -void rot(VectorView const& X, VectorView const& Y, ScalarView const& c, ScalarView const& s) { +template +void rot(VectorView const& X, VectorView const& Y, MagnitudeView const& c, ScalarView const& s) { const typename VectorView::execution_space space = typename VectorView::execution_space(); rot(space, X, Y, c, s); } diff --git a/blas/tpls/KokkosBlas1_rot_tpl_spec_avail.hpp b/blas/tpls/KokkosBlas1_rot_tpl_spec_avail.hpp index fee65fce14..6f4784b389 100644 --- a/blas/tpls/KokkosBlas1_rot_tpl_spec_avail.hpp +++ b/blas/tpls/KokkosBlas1_rot_tpl_spec_avail.hpp @@ -20,7 +20,7 @@ namespace KokkosBlas { namespace Impl { // Specialization struct which defines whether a specialization exists -template +template struct rot_tpl_spec_avail { enum : bool { value = false }; }; @@ -37,6 +37,9 @@ namespace Impl { struct rot_tpl_spec_avail, \ Kokkos::MemoryTraits>, \ + Kokkos::View::mag_type, LAYOUT, \ + Kokkos::Device, \ + Kokkos::MemoryTraits>, \ Kokkos::View, \ Kokkos::MemoryTraits>> { \ enum : bool { value = true }; \ @@ -64,6 +67,8 @@ KOKKOSBLAS1_ROT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex, Kokkos::LayoutLeft, struct rot_tpl_spec_avail< \ EXECSPACE, \ Kokkos::View, Kokkos::MemoryTraits>, \ + Kokkos::View::mag_type, LAYOUT, Kokkos::Device, \ + Kokkos::MemoryTraits>, \ Kokkos::View, Kokkos::MemoryTraits>> { \ enum : bool { value = true }; \ }; diff --git a/blas/tpls/KokkosBlas1_rot_tpl_spec_decl.hpp b/blas/tpls/KokkosBlas1_rot_tpl_spec_decl.hpp index dfe747bf88..d34cbd885f 100644 --- a/blas/tpls/KokkosBlas1_rot_tpl_spec_decl.hpp +++ b/blas/tpls/KokkosBlas1_rot_tpl_spec_decl.hpp @@ -157,12 +157,15 @@ namespace Impl { EXECSPACE, \ Kokkos::View, Kokkos::MemoryTraits>, \ Kokkos::View, Kokkos::MemoryTraits>, \ + Kokkos::View, Kokkos::MemoryTraits>, \ true, ETI_SPEC_AVAIL> { \ using VectorView = \ Kokkos::View, Kokkos::MemoryTraits>; \ + using MagnitudeView = \ + Kokkos::View, Kokkos::MemoryTraits>; \ using ScalarView = \ Kokkos::View, Kokkos::MemoryTraits>; \ - static void rot(EXECSPACE const& space, VectorView const& X, VectorView const& Y, ScalarView const& c, \ + static void rot(EXECSPACE const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c, \ ScalarView const& s) { \ Kokkos::Profiling::pushRegion("KokkosBlas::rot[TPL_CUBLAS,double]"); \ rot_print_specialization(); \ @@ -182,13 +185,16 @@ namespace Impl { struct Rot< \ EXECSPACE, \ Kokkos::View, Kokkos::MemoryTraits>, \ + Kokkos::View, Kokkos::MemoryTraits>, \ Kokkos::View, Kokkos::MemoryTraits>, true, \ ETI_SPEC_AVAIL> { \ using VectorView = \ Kokkos::View, Kokkos::MemoryTraits>; \ + using MagnitudeView = \ + Kokkos::View, Kokkos::MemoryTraits>; \ using ScalarView = \ Kokkos::View, Kokkos::MemoryTraits>; \ - static void rot(EXECSPACE const& space, VectorView const& X, VectorView const& Y, ScalarView const& c, \ + static void rot(EXECSPACE const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c, \ ScalarView const& s) { \ Kokkos::Profiling::pushRegion("KokkosBlas::rot[TPL_CUBLAS,float]"); \ rot_print_specialization(); \ @@ -210,12 +216,17 @@ namespace Impl { Kokkos::View*, LAYOUT, Kokkos::Device, \ Kokkos::MemoryTraits>, \ Kokkos::View, Kokkos::MemoryTraits>, \ + Kokkos::View, LAYOUT, Kokkos::Device, \ + Kokkos::MemoryTraits>, \ true, ETI_SPEC_AVAIL> { \ using VectorView = Kokkos::View*, LAYOUT, Kokkos::Device, \ Kokkos::MemoryTraits>; \ - using ScalarView = \ + using MagnitudeView = \ Kokkos::View, Kokkos::MemoryTraits>; \ - static void rot(EXECSPACE const& space, VectorView const& X, VectorView const& Y, ScalarView const& c, \ + using ScalarView = \ + Kokkos::View, LAYOUT, Kokkos::Device, \ + Kokkos::MemoryTraits>; \ + static void rot(EXECSPACE const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c, \ ScalarView const& s) { \ Kokkos::Profiling::pushRegion("KokkosBlas::rot[TPL_CUBLAS,complex]"); \ rot_print_specialization(); \ @@ -225,7 +236,8 @@ namespace Impl { KOKKOSBLAS_IMPL_CUBLAS_SAFE_CALL(cublasGetPointerMode(singleton.handle, &pointer_mode)); \ KOKKOSBLAS_IMPL_CUBLAS_SAFE_CALL(cublasSetPointerMode(singleton.handle, CUBLAS_POINTER_MODE_DEVICE)); \ cublasZdrot(singleton.handle, X.extent_int(0), reinterpret_cast(X.data()), 1, \ - reinterpret_cast(Y.data()), 1, c.data(), s.data()); \ + reinterpret_cast(Y.data()), 1, c.data(), \ + reinterpret_cast(s.data())); \ KOKKOSBLAS_IMPL_CUBLAS_SAFE_CALL(cublasSetPointerMode(singleton.handle, pointer_mode)); \ Kokkos::Profiling::popRegion(); \ } \ @@ -237,13 +249,17 @@ namespace Impl { EXECSPACE, \ Kokkos::View*, LAYOUT, Kokkos::Device, \ Kokkos::MemoryTraits>, \ - Kokkos::View, Kokkos::MemoryTraits>, true, \ + Kokkos::View, Kokkos::MemoryTraits>, \ + Kokkos::View, LAYOUT, Kokkos::Device, \ + Kokkos::MemoryTraits>, true, \ ETI_SPEC_AVAIL> { \ - using VectorView = Kokkos::View, LAYOUT, Kokkos::Device, \ + using VectorView = Kokkos::View*, LAYOUT, Kokkos::Device, \ Kokkos::MemoryTraits>; \ - using ScalarView = \ + using MagnitudeView = \ Kokkos::View, Kokkos::MemoryTraits>; \ - static void rot(EXECSPACE const& space, VectorView const& X, VectorView const& Y, ScalarView const& c, \ + using ScalarView = Kokkos::View, LAYOUT, Kokkos::Device, \ + Kokkos::MemoryTraits>; \ + static void rot(EXECSPACE const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c, \ ScalarView const& s) { \ Kokkos::Profiling::pushRegion("KokkosBlas::rot[TPL_CUBLAS,complex]"); \ rot_print_specialization(); \ @@ -253,7 +269,7 @@ namespace Impl { KOKKOSBLAS_IMPL_CUBLAS_SAFE_CALL(cublasGetPointerMode(singleton.handle, &pointer_mode)); \ KOKKOSBLAS_IMPL_CUBLAS_SAFE_CALL(cublasSetPointerMode(singleton.handle, CUBLAS_POINTER_MODE_DEVICE)); \ cublasCsrot(singleton.handle, X.extent_int(0), reinterpret_cast(X.data()), 1, \ - reinterpret_cast(Y.data()), 1, c.data(), s.data()); \ + reinterpret_cast(Y.data()), 1, c.data(), reinterpret_cast(s.data())); \ KOKKOSBLAS_IMPL_CUBLAS_SAFE_CALL(cublasSetPointerMode(singleton.handle, pointer_mode)); \ Kokkos::Profiling::popRegion(); \ } \ diff --git a/blas/unit_test/Test_Blas1_rot.hpp b/blas/unit_test/Test_Blas1_rot.hpp index db9367cb42..629e6db003 100644 --- a/blas/unit_test/Test_Blas1_rot.hpp +++ b/blas/unit_test/Test_Blas1_rot.hpp @@ -19,12 +19,14 @@ template int test_rot() { using mag_type = typename Kokkos::ArithTraits::mag_type; using vector_type = Kokkos::View; - using scalar_type = Kokkos::View; + using magnitude_type = Kokkos::View; + using scalar_type = Kokkos::View; using vector_ref_type = Kokkos::View; vector_type X("X", 4), Y("Y", 4); vector_ref_type Xref("Xref", 4), Yref("Yref", 4); - scalar_type c("c"), s("s"); + magnitude_type c("c"); + scalar_type s("s"); // Initialize inputs typename vector_type::HostMirror X_h = Kokkos::create_mirror_view(X);