Skip to content

Commit d28324c

Browse files
author
Yuuichi Asahi
committed
refactor serial tbsv implementation details and tests
Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
1 parent 4c33556 commit d28324c

6 files changed

+260
-175
lines changed

batched/dense/impl/KokkosBatched_Pbtrs_Serial_Internal.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_
1818
#define KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_
1919

20+
#include "KokkosBlas_util.hpp"
2021
#include "KokkosBatched_Util.hpp"
2122
#include "KokkosBatched_Tbsv_Serial_Internal.hpp"
2223

@@ -50,8 +51,9 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalLower<Algo::Pbtrs::Unblocked>::inv
5051
SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(false, an, A, as0, as1, x, xs0, kd);
5152

5253
// Solve L**T *X = B, overwriting B with X.
53-
constexpr bool do_conj = Kokkos::ArithTraits<ValueType>::is_complex;
54-
SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(false, do_conj, an, A, as0, as1, x, xs0, kd);
54+
using op =
55+
std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;
56+
SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(op(), false, an, A, as0, as1, x, xs0, kd);
5557

5658
return 0;
5759
}
@@ -76,8 +78,9 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalUpper<Algo::Pbtrs::Unblocked>::inv
7678
/**/ ValueType *KOKKOS_RESTRICT x,
7779
const int xs0, const int kd) {
7880
// Solve U**T *X = B, overwriting B with X.
79-
constexpr bool do_conj = Kokkos::ArithTraits<ValueType>::is_complex;
80-
SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(false, do_conj, an, A, as0, as1, x, xs0, kd);
81+
using op =
82+
std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;
83+
SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(op(), false, an, A, as0, as1, x, xs0, kd);
8184

8285
// Solve U*X = B, overwriting B with X.
8386
SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(false, an, A, as0, as1, x, xs0, kd);

batched/dense/impl/KokkosBatched_Tbsv_Serial_Impl.hpp

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@
1919

2020
/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr)
2121

22+
#include "KokkosBlas_util.hpp"
2223
#include "KokkosBatched_Util.hpp"
2324
#include "KokkosBatched_Tbsv_Serial_Internal.hpp"
2425

2526
namespace KokkosBatched {
26-
27+
namespace Impl {
2728
template <typename AViewType, typename XViewType>
2829
KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewType &A,
2930
[[maybe_unused]] const XViewType &x, [[maybe_unused]] const int k) {
30-
static_assert(Kokkos::is_view<AViewType>::value, "KokkosBatched::tbsv: AViewType is not a Kokkos::View.");
31-
static_assert(Kokkos::is_view<XViewType>::value, "KokkosBatched::tbsv: XViewType is not a Kokkos::View.");
31+
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::tbsv: AViewType is not a Kokkos::View.");
32+
static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::tbsv: XViewType is not a Kokkos::View.");
3233
static_assert(AViewType::rank == 2, "KokkosBatched::tbsv: AViewType must have rank 2.");
3334
static_assert(XViewType::rank == 1, "KokkosBatched::tbsv: XViewType must have rank 1.");
3435

@@ -63,15 +64,17 @@ KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewTyp
6364
return 0;
6465
}
6566

67+
} // namespace Impl
68+
6669
//// Lower non-transpose ////
6770
template <typename ArgDiag>
6871
struct SerialTbsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
6972
template <typename AViewType, typename XViewType>
7073
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
71-
auto info = checkTbsvInput(A, x, k);
74+
auto info = Impl::checkTbsvInput(A, x, k);
7275
if (info) return info;
7376

74-
return SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(
77+
return Impl::SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(
7578
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
7679
}
7780
};
@@ -81,11 +84,12 @@ template <typename ArgDiag>
8184
struct SerialTbsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
8285
template <typename AViewType, typename XViewType>
8386
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
84-
auto info = checkTbsvInput(A, x, k);
87+
auto info = Impl::checkTbsvInput(A, x, k);
8588
if (info) return info;
8689

87-
return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
88-
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
90+
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
91+
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
92+
x.stride_0(), k);
8993
}
9094
};
9195

@@ -94,11 +98,12 @@ template <typename ArgDiag>
9498
struct SerialTbsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
9599
template <typename AViewType, typename XViewType>
96100
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
97-
auto info = checkTbsvInput(A, x, k);
101+
auto info = Impl::checkTbsvInput(A, x, k);
98102
if (info) return info;
99103

100-
return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
101-
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
104+
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
105+
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
106+
x.stride_0(), k);
102107
}
103108
};
104109

@@ -107,10 +112,10 @@ template <typename ArgDiag>
107112
struct SerialTbsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
108113
template <typename AViewType, typename XViewType>
109114
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
110-
auto info = checkTbsvInput(A, x, k);
115+
auto info = Impl::checkTbsvInput(A, x, k);
111116
if (info) return info;
112117

113-
return SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(
118+
return Impl::SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(
114119
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
115120
}
116121
};
@@ -120,11 +125,12 @@ template <typename ArgDiag>
120125
struct SerialTbsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
121126
template <typename AViewType, typename XViewType>
122127
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
123-
auto info = checkTbsvInput(A, x, k);
128+
auto info = Impl::checkTbsvInput(A, x, k);
124129
if (info) return info;
125130

126-
return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
127-
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
131+
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
132+
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
133+
x.stride_0(), k);
128134
}
129135
};
130136

@@ -133,11 +139,12 @@ template <typename ArgDiag>
133139
struct SerialTbsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
134140
template <typename AViewType, typename XViewType>
135141
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
136-
auto info = checkTbsvInput(A, x, k);
142+
auto info = Impl::checkTbsvInput(A, x, k);
137143
if (info) return info;
138144

139-
return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
140-
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
145+
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
146+
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
147+
x.stride_0(), k);
141148
}
142149
};
143150

batched/dense/impl/KokkosBatched_Tbsv_Serial_Internal.hpp

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
#include "KokkosBatched_Util.hpp"
2323

2424
namespace KokkosBatched {
25-
25+
namespace Impl {
2626
///
2727
/// Serial Internal Impl
2828
/// ====================
2929

3030
///
31-
/// Lower, Non-Transpose
31+
/// Lower
3232
///
3333

3434
template <typename AlgoType>
@@ -70,49 +70,37 @@ KOKKOS_INLINE_FUNCTION int SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invok
7070

7171
template <typename AlgoType>
7272
struct SerialTbsvInternalLowerTranspose {
73-
template <typename ValueType>
74-
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int an,
73+
template <typename Op, typename ValueType>
74+
KOKKOS_INLINE_FUNCTION static int invoke(Op op, const bool use_unit_diag, const int an,
7575
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
7676
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k);
7777
};
7878

7979
template <>
80-
template <typename ValueType>
80+
template <typename Op, typename ValueType>
8181
KOKKOS_INLINE_FUNCTION int SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
82-
const bool use_unit_diag, const bool do_conj, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
83-
const int as1,
82+
Op op, const bool use_unit_diag, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
8483
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k) {
8584
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
8685
#pragma unroll
8786
#endif
8887
for (int j = an - 1; j >= 0; --j) {
8988
auto temp = x[j * xs0];
90-
91-
if (do_conj) {
92-
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
93-
#pragma unroll
94-
#endif
95-
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
96-
temp -= Kokkos::ArithTraits<ValueType>::conj(A[(i - j) * as0 + j * as1]) * x[i * xs0];
97-
}
98-
if (!use_unit_diag) temp = temp / Kokkos::ArithTraits<ValueType>::conj(A[0 + j * as1]);
99-
} else {
10089
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
10190
#pragma unroll
10291
#endif
103-
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
104-
temp -= A[(i - j) * as0 + j * as1] * x[i * xs0];
105-
}
106-
if (!use_unit_diag) temp = temp / A[0 + j * as1];
92+
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
93+
temp -= op(A[(i - j) * as0 + j * as1]) * x[i * xs0];
10794
}
95+
if (!use_unit_diag) temp = temp / op(A[0 + j * as1]);
10896
x[j * xs0] = temp;
10997
}
11098

11199
return 0;
112100
}
113101

114102
///
115-
/// Upper, Non-Transpose
103+
/// Upper
116104
///
117105

118106
template <typename AlgoType>
@@ -154,46 +142,36 @@ KOKKOS_INLINE_FUNCTION int SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invok
154142

155143
template <typename AlgoType>
156144
struct SerialTbsvInternalUpperTranspose {
157-
template <typename ValueType>
158-
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int an,
145+
template <typename Op, typename ValueType>
146+
KOKKOS_INLINE_FUNCTION static int invoke(Op op, const bool use_unit_diag, const int an,
159147
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
160148
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k);
161149
};
162150

163151
template <>
164-
template <typename ValueType>
152+
template <typename Op, typename ValueType>
165153
KOKKOS_INLINE_FUNCTION int SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
166-
const bool use_unit_diag, const bool do_conj, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
167-
const int as1,
154+
Op op, const bool use_unit_diag, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
168155
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k) {
169156
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
170157
#pragma unroll
171158
#endif
172159
for (int j = 0; j < an; j++) {
173160
auto temp = x[j * xs0];
174-
if (do_conj) {
175-
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
176-
#pragma unroll
177-
#endif
178-
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
179-
temp -= Kokkos::ArithTraits<ValueType>::conj(A[(i + k - j) * as0 + j * as1]) * x[i * xs0];
180-
}
181-
if (!use_unit_diag) temp = temp / Kokkos::ArithTraits<ValueType>::conj(A[k * as0 + j * as1]);
182-
} else {
183161
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
184162
#pragma unroll
185163
#endif
186-
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
187-
temp -= A[(i + k - j) * as0 + j * as1] * x[i * xs0];
188-
}
189-
if (!use_unit_diag) temp = temp / A[k * as0 + j * as1];
164+
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
165+
temp -= op(A[(i + k - j) * as0 + j * as1]) * x[i * xs0];
190166
}
167+
if (!use_unit_diag) temp = temp / op(A[k * as0 + j * as1]);
191168
x[j * xs0] = temp;
192169
}
193170

194171
return 0;
195172
}
196173

174+
} // namespace Impl
197175
} // namespace KokkosBatched
198176

199177
#endif // KOKKOSBATCHED_TBSV_SERIAL_INTERNAL_HPP_

0 commit comments

Comments
 (0)