Skip to content

Commit

Permalink
Add a complicated analytical test based on review
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
  • Loading branch information
Yuuichi Asahi committed Jan 8, 2025
1 parent 6e051b4 commit 65d6fb4
Showing 1 changed file with 92 additions and 1 deletion.
93 changes: 92 additions & 1 deletion batched/dense/unit_test/Test_Batched_SerialGetrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,25 @@ struct Functor_BatchedSerialGemm {
/// [0. 0. 1.]
/// [0. 0. 0.]]
/// piv2 = [0 1 2]
/// 3x3 more general matrix
/// which satisfies PA = LU
/// P = [[0 0 1]
/// [1 0 0]
/// [0 1 0]]
/// A = [[1 2 3]
/// [2 -4 6]
/// [3 -9 -3]]
/// L = [[1 0 0]
/// [1/3 1 0]
/// [2/3 2/5 1]]
/// U = [[-3 -9 -3]
/// [ 0 5 4]
/// [ 0 0 32/5]]
/// Note P is obtained by piv = [2 2 2]
/// We compare the non-diagnoal elements of L only, which is
/// NL = [[0 0 0]
/// [1/3 0 0]
/// [2/3 2/5 0]]
/// \param Nb [in] Batch size of matrices
template <typename DeviceType, typename ScalarType, typename LayoutType, typename AlgoTagType>
void impl_test_batched_getrf_analytical(const int Nb) {
Expand All @@ -140,12 +159,21 @@ void impl_test_batched_getrf_analytical(const int Nb) {
View3DType A2("A2", Nb, M, N), LU2("LU2", Nb, M, N);
PivView2DType ipiv2("ipiv2", Nb, N), ipiv2_ref("ipiv1_ref", Nb, N);

// Complicated matrix
View3DType A3("A3", Nb, N, N), LU3("LU3", Nb, N, N), L3("L3", Nb, N, N), U3("U3", Nb, N, N),
L3_ref("L3_ref", Nb, N, N), U3_ref("U3_ref", Nb, N, N);
PivView2DType ipiv3("ipiv3", Nb, N), ipiv3_ref("ipiv3_ref", Nb, N);

auto h_A0 = Kokkos::create_mirror_view(A0);
auto h_A1 = Kokkos::create_mirror_view(A1);
auto h_A2 = Kokkos::create_mirror_view(A2);
auto h_A3 = Kokkos::create_mirror_view(A3);
auto h_L3_ref = Kokkos::create_mirror_view(L3_ref);
auto h_U3_ref = Kokkos::create_mirror_view(U3_ref);
auto h_ipiv0_ref = Kokkos::create_mirror_view(ipiv0_ref);
auto h_ipiv1_ref = Kokkos::create_mirror_view(ipiv1_ref);
auto h_ipiv2_ref = Kokkos::create_mirror_view(ipiv2_ref);
auto h_ipiv3_ref = Kokkos::create_mirror_view(ipiv3_ref);
for (int ib = 0; ib < Nb; ib++) {
for (int i = 0; i < M; i++) {
h_ipiv0_ref(ib, i) = i;
Expand All @@ -162,28 +190,67 @@ void impl_test_batched_getrf_analytical(const int Nb) {
h_A2(ib, j, i) = i == j ? 1.0 : 0.0;
}
}

h_A3(ib, 0, 0) = 1.0;
h_A3(ib, 0, 1) = 2.0;
h_A3(ib, 0, 2) = 3.0;
h_A3(ib, 1, 0) = 2.0;
h_A3(ib, 1, 1) = -4.0;
h_A3(ib, 1, 2) = 6.0;
h_A3(ib, 2, 0) = 3.0;
h_A3(ib, 2, 1) = -9.0;
h_A3(ib, 2, 2) = -3.0;

h_L3_ref(ib, 0, 0) = 0.0;
h_L3_ref(ib, 0, 1) = 0.0;
h_L3_ref(ib, 0, 2) = 0.0;
h_L3_ref(ib, 1, 0) = 1.0 / 3.0;
h_L3_ref(ib, 1, 1) = 0.0;
h_L3_ref(ib, 1, 2) = 0.0;
h_L3_ref(ib, 2, 0) = 2.0 / 3.0;
h_L3_ref(ib, 2, 1) = 2.0 / 5.0;
h_L3_ref(ib, 2, 2) = 0.0;

h_U3_ref(ib, 0, 0) = 3.0;
h_U3_ref(ib, 0, 1) = -9.0;
h_U3_ref(ib, 0, 2) = -3.0;
h_U3_ref(ib, 1, 0) = 0.0;
h_U3_ref(ib, 1, 1) = 5.0;
h_U3_ref(ib, 1, 2) = 4.0;
h_U3_ref(ib, 2, 0) = 0.0;
h_U3_ref(ib, 2, 1) = 0.0;
h_U3_ref(ib, 2, 2) = 32.0 / 5.0;

h_ipiv3_ref(ib, 0) = 2;
h_ipiv3_ref(ib, 1) = 2;
h_ipiv3_ref(ib, 2) = 2;
}

Kokkos::deep_copy(A0, h_A0);
Kokkos::deep_copy(A1, h_A1);
Kokkos::deep_copy(A2, h_A2);
Kokkos::deep_copy(A3, h_A3);
Kokkos::deep_copy(LU0, A0);
Kokkos::deep_copy(LU1, A1);
Kokkos::deep_copy(LU2, A2);
Kokkos::deep_copy(LU3, A3);

// getrf to factorize matrix A = P * L * U
auto info0 = Functor_BatchedSerialGetrf<DeviceType, View3DType, PivView2DType, AlgoTagType>(LU0, ipiv0).run();
auto info1 = Functor_BatchedSerialGetrf<DeviceType, View3DType, PivView2DType, AlgoTagType>(LU1, ipiv1).run();
auto info2 = Functor_BatchedSerialGetrf<DeviceType, View3DType, PivView2DType, AlgoTagType>(LU2, ipiv2).run();
auto info3 = Functor_BatchedSerialGetrf<DeviceType, View3DType, PivView2DType, AlgoTagType>(LU3, ipiv3).run();

Kokkos::fence();
EXPECT_EQ(info0, 0);
EXPECT_EQ(info1, 0);
EXPECT_EQ(info2, 0);
EXPECT_EQ(info3, 0);

auto h_ipiv0 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv0);
auto h_ipiv1 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv1);
auto h_ipiv2 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv2);
auto h_ipiv3 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv3);

for (int ib = 0; ib < Nb; ib++) {
// Check if piv0 = [0 1 2 3]
Expand All @@ -195,8 +262,20 @@ void impl_test_batched_getrf_analytical(const int Nb) {
EXPECT_EQ(h_ipiv1(ib, i), h_ipiv1_ref(ib, i));
EXPECT_EQ(h_ipiv2(ib, i), h_ipiv2_ref(ib, i));
}
// Check if piv3 = [2 2 2]
for (int i = 0; i < N; i++) {
EXPECT_EQ(h_ipiv3(ib, i), h_ipiv3_ref(ib, i));
}
}

// Reconstruct L and U from Factorized matrix A
// Copy non-diagonal lower triangular components to NL
create_triangular_matrix<View3DType, View3DType, KokkosBatched::Uplo::Lower, KokkosBatched::Diag::NonUnit>(LU3, L3,
-1);

// Copy upper triangular components to U
create_triangular_matrix<View3DType, View3DType, KokkosBatched::Uplo::Upper, KokkosBatched::Diag::NonUnit>(LU3, U3);

RealType eps = 1.0e1 * ats::epsilon();
auto h_LU0 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), LU0);
auto h_LU1 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), LU1);
Expand All @@ -216,6 +295,18 @@ void impl_test_batched_getrf_analytical(const int Nb) {
}
}
}

// For complicated matrix, we compare L and U with reference L and U
auto h_L3 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), L3);
auto h_U3 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), U3);
for (int ib = 0; ib < Nb; ib++) {
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
EXPECT_NEAR_KK(h_L3(ib, i, j), h_L3_ref(ib, i, j), eps);
EXPECT_NEAR_KK(h_U3(ib, i, j), h_U3_ref(ib, i, j), eps);
}
}
}
}

/// \brief Implementation details of batched getrf test
Expand Down Expand Up @@ -258,7 +349,7 @@ void impl_test_batched_getrf(const int N, const int BlkSize) {
Kokkos::fence();
EXPECT_EQ(info, 0);

// Reconstruct L and D from Factorized matrix A
// Reconstruct L and U from Factorized matrix A
// Copy non-diagonal lower triangular components to NL
create_triangular_matrix<View3DType, View3DType, KokkosBatched::Uplo::Lower, KokkosBatched::Diag::NonUnit>(LU, NL,
-1);
Expand Down

0 comments on commit 65d6fb4

Please sign in to comment.