19
19
20
20
// / \author Yuuichi Asahi (yuuichi.asahi@cea.fr)
21
21
22
+ #include " KokkosBlas_util.hpp"
22
23
#include " KokkosBatched_Util.hpp"
23
24
#include " KokkosBatched_Tbsv_Serial_Internal.hpp"
24
25
25
26
namespace KokkosBatched {
26
-
27
+ namespace Impl {
27
28
template <typename AViewType, typename XViewType>
28
29
KOKKOS_INLINE_FUNCTION static int checkTbsvInput ([[maybe_unused]] const AViewType &A,
29
30
[[maybe_unused]] const XViewType &x, [[maybe_unused]] const int k) {
30
- static_assert (Kokkos::is_view <AViewType>::value , " KokkosBatched::tbsv: AViewType is not a Kokkos::View." );
31
- static_assert (Kokkos::is_view <XViewType>::value , " KokkosBatched::tbsv: XViewType is not a Kokkos::View." );
31
+ static_assert (Kokkos::is_view_v <AViewType>, " KokkosBatched::tbsv: AViewType is not a Kokkos::View." );
32
+ static_assert (Kokkos::is_view_v <XViewType>, " KokkosBatched::tbsv: XViewType is not a Kokkos::View." );
32
33
static_assert (AViewType::rank == 2 , " KokkosBatched::tbsv: AViewType must have rank 2." );
33
34
static_assert (XViewType::rank == 1 , " KokkosBatched::tbsv: XViewType must have rank 1." );
34
35
@@ -63,15 +64,17 @@ KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewTyp
63
64
return 0 ;
64
65
}
65
66
67
+ } // namespace Impl
68
+
66
69
// // Lower non-transpose ////
67
70
template <typename ArgDiag>
68
71
struct SerialTbsv <Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
69
72
template <typename AViewType, typename XViewType>
70
73
KOKKOS_INLINE_FUNCTION static int invoke (const AViewType &A, const XViewType &x, const int k) {
71
- auto info = checkTbsvInput (A, x, k);
74
+ auto info = Impl:: checkTbsvInput (A, x, k);
72
75
if (info) return info;
73
76
74
- return SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke (
77
+ return Impl:: SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke (
75
78
ArgDiag::use_unit_diag, A.extent (1 ), A.data (), A.stride_0 (), A.stride_1 (), x.data (), x.stride_0 (), k);
76
79
}
77
80
};
@@ -81,11 +84,12 @@ template <typename ArgDiag>
81
84
struct SerialTbsv <Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
82
85
template <typename AViewType, typename XViewType>
83
86
KOKKOS_INLINE_FUNCTION static int invoke (const AViewType &A, const XViewType &x, const int k) {
84
- auto info = checkTbsvInput (A, x, k);
87
+ auto info = Impl:: checkTbsvInput (A, x, k);
85
88
if (info) return info;
86
89
87
- return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke (
88
- ArgDiag::use_unit_diag, false , A.extent (1 ), A.data (), A.stride_0 (), A.stride_1 (), x.data (), x.stride_0 (), k);
90
+ return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke (
91
+ KokkosBlas::Impl::OpID (), ArgDiag::use_unit_diag, A.extent (1 ), A.data (), A.stride_0 (), A.stride_1 (), x.data (),
92
+ x.stride_0 (), k);
89
93
}
90
94
};
91
95
@@ -94,11 +98,12 @@ template <typename ArgDiag>
94
98
struct SerialTbsv <Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
95
99
template <typename AViewType, typename XViewType>
96
100
KOKKOS_INLINE_FUNCTION static int invoke (const AViewType &A, const XViewType &x, const int k) {
97
- auto info = checkTbsvInput (A, x, k);
101
+ auto info = Impl:: checkTbsvInput (A, x, k);
98
102
if (info) return info;
99
103
100
- return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke (
101
- ArgDiag::use_unit_diag, true , A.extent (1 ), A.data (), A.stride_0 (), A.stride_1 (), x.data (), x.stride_0 (), k);
104
+ return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke (
105
+ KokkosBlas::Impl::OpConj (), ArgDiag::use_unit_diag, A.extent (1 ), A.data (), A.stride_0 (), A.stride_1 (), x.data (),
106
+ x.stride_0 (), k);
102
107
}
103
108
};
104
109
@@ -107,10 +112,10 @@ template <typename ArgDiag>
107
112
struct SerialTbsv <Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
108
113
template <typename AViewType, typename XViewType>
109
114
KOKKOS_INLINE_FUNCTION static int invoke (const AViewType &A, const XViewType &x, const int k) {
110
- auto info = checkTbsvInput (A, x, k);
115
+ auto info = Impl:: checkTbsvInput (A, x, k);
111
116
if (info) return info;
112
117
113
- return SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke (
118
+ return Impl:: SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke (
114
119
ArgDiag::use_unit_diag, A.extent (1 ), A.data (), A.stride_0 (), A.stride_1 (), x.data (), x.stride_0 (), k);
115
120
}
116
121
};
@@ -120,11 +125,12 @@ template <typename ArgDiag>
120
125
struct SerialTbsv <Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
121
126
template <typename AViewType, typename XViewType>
122
127
KOKKOS_INLINE_FUNCTION static int invoke (const AViewType &A, const XViewType &x, const int k) {
123
- auto info = checkTbsvInput (A, x, k);
128
+ auto info = Impl:: checkTbsvInput (A, x, k);
124
129
if (info) return info;
125
130
126
- return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke (
127
- ArgDiag::use_unit_diag, false , A.extent (1 ), A.data (), A.stride_0 (), A.stride_1 (), x.data (), x.stride_0 (), k);
131
+ return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke (
132
+ KokkosBlas::Impl::OpID (), ArgDiag::use_unit_diag, A.extent (1 ), A.data (), A.stride_0 (), A.stride_1 (), x.data (),
133
+ x.stride_0 (), k);
128
134
}
129
135
};
130
136
@@ -133,11 +139,12 @@ template <typename ArgDiag>
133
139
struct SerialTbsv <Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
134
140
template <typename AViewType, typename XViewType>
135
141
KOKKOS_INLINE_FUNCTION static int invoke (const AViewType &A, const XViewType &x, const int k) {
136
- auto info = checkTbsvInput (A, x, k);
142
+ auto info = Impl:: checkTbsvInput (A, x, k);
137
143
if (info) return info;
138
144
139
- return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke (
140
- ArgDiag::use_unit_diag, true , A.extent (1 ), A.data (), A.stride_0 (), A.stride_1 (), x.data (), x.stride_0 (), k);
145
+ return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke (
146
+ KokkosBlas::Impl::OpConj (), ArgDiag::use_unit_diag, A.extent (1 ), A.data (), A.stride_0 (), A.stride_1 (), x.data (),
147
+ x.stride_0 (), k);
141
148
}
142
149
};
143
150
0 commit comments