Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
suppress shape checks
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
Yuuichi Asahi committed Dec 18, 2024
1 parent 8b6a003 commit f02c3df
Showing 1 changed file with 30 additions and 27 deletions.
57 changes: 30 additions & 27 deletions batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -49,15 +49,18 @@ KOKKOS_INLINE_FUNCTION static int checkTrsvInput([[maybe_unused]] const AViewTyp
return 1;
}

const int lda = A.extent(0), n = A.extent(1);
if (lda < Kokkos::max(1, n)) {
Kokkos::printf(
"KokkosBatched::trsv: leading dimension of A must not be smaller than "
"max(1, n): "
"lda = %d, n = %d\n",
lda, n);
return 1;
}
// FIXME : check leading dimension is suppressed for now
// because of the compatibility issue with Trilinos
// const int lda = A.extent(0), n = A.extent(1);
// if (lda < Kokkos::max(1, n)) {
// Kokkos::printf(
// "KokkosBatched::trsv: leading dimension of A must not be smaller than "
// "max(1, n): "
// "lda = %d, n = %d\n",
// lda, n);
// return 1;
// }

#endif
return 0;
}
@@ -71,7 +74,7 @@ struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::CompactM
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -111,7 +114,7 @@ struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::Unblocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -126,7 +129,7 @@ struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::Blocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -144,7 +147,7 @@ struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::CompactMKL
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -184,7 +187,7 @@ struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::Unblocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -199,7 +202,7 @@ struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -217,7 +220,7 @@ struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Compac
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -257,7 +260,7 @@ struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Unbloc
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -271,7 +274,7 @@ struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Blocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -288,7 +291,7 @@ struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::CompactM
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -328,7 +331,7 @@ struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::Unblocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -343,7 +346,7 @@ struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::Blocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -361,7 +364,7 @@ struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::CompactMKL
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -401,7 +404,7 @@ struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::Unblocked>
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -416,7 +419,7 @@ struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -434,7 +437,7 @@ struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Compac
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -474,7 +477,7 @@ struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Unbloc
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
@@ -488,7 +491,7 @@ struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Blocke
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;
// if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

0 comments on commit f02c3df

Please sign in to comment.