diff --git a/include/dlaf/eigensolver/tridiag_solver/impl.h b/include/dlaf/eigensolver/tridiag_solver/impl.h index 86a293bb9a..d12cf7cd8c 100644 --- a/include/dlaf/eigensolver/tridiag_solver/impl.h +++ b/include/dlaf/eigensolver/tridiag_solver/impl.h @@ -391,7 +391,8 @@ void TridiagSolver::call(comm::CommunicatorGrid grid, Matrix ws_hm{initMirrorMatrix(ws.e0), initMirrorMatrix(ws.e2), initMirrorMatrix(ws.d1), initMirrorMatrix(ws.z0), - initMirrorMatrix(ws.z1), initMirrorMatrix(ws.i2)}; + initMirrorMatrix(ws.z1), initMirrorMatrix(ws.i2), + initMirrorMatrix(ws.i5)}; // Set `ws.e0` to `zero` (needed for Given's rotation to make sure no random values are picked up) matrix::util::set0(thread_priority::normal, ws.e0); diff --git a/include/dlaf/eigensolver/tridiag_solver/merge.h b/include/dlaf/eigensolver/tridiag_solver/merge.h index 94f95df174..31fe2dd8f2 100644 --- a/include/dlaf/eigensolver/tridiag_solver/merge.h +++ b/include/dlaf/eigensolver/tridiag_solver/merge.h @@ -161,6 +161,7 @@ struct DistWorkSpaceHostMirror { HostMirrorMatrix z1; HostMirrorMatrix i2; + HostMirrorMatrix i5; }; template @@ -273,66 +274,6 @@ inline std::size_t ev_sort_order(const ColType coltype) { return DLAF_UNREACHABLE(std::size_t); } -// This function returns number of non-deflated eigenvectors, together with a permutation @p out_ptr -// that represent mapping (sorted non-deflated | sorted deflated) -> initial. -// -// The permutation will allow to keep the mapping between sorted eigenvalues and unsorted eigenvectors, -// which is useful since eigenvectors are more expensive to permuted, so we can keep them in their initial order. -// -// @param n number of eigenvalues -// @param c_ptr array[n] containing the column type of each eigenvector after deflation (initial order) -// @param evals_ptr array[n] of eigenvalues sorted as in_ptr -// @param in_ptr array[n] representing permutation current -> initial (i.e. 
evals[i] -> c_ptr[in_ptr[i]]) -// @param out_ptr array[n] permutation (sorted non-deflated | sorted deflated) -> initial -// -// @return k number of non-deflated eigenvectors -template -SizeType stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* c_ptr, - const T* evals_ptr, const SizeType* in_ptr, - SizeType* out_ptr) { - // Get the number of non-deflated entries - SizeType k = 0; - for (SizeType i = 0; i < n; ++i) { - if (c_ptr[i] != ColType::Deflated) - ++k; - } - - // Create the permutation (sorted non-deflated | sorted deflated) -> initial - // Note: - // Since during deflation, eigenvalues related to deflated eigenvectors, might not be sorted anymore, - // this step also take care of sorting eigenvalues (actually just their related index) by their ascending value. - SizeType i1 = 0; // index of non-deflated values in out - SizeType i2 = k; // index of deflated values - for (SizeType i = 0; i < n; ++i) { - const SizeType ii = in_ptr[i]; - - // non-deflated are untouched, just squeeze them at the beginning as they appear - if (c_ptr[ii] != ColType::Deflated) { - out_ptr[i1] = ii; - ++i1; - } - // deflated are the ones that can have been moved "out-of-order" by deflation... - // ... so each time insert it in the right place based on eigenvalue value - else { - const T a = evals_ptr[ii]; - - SizeType j = i2; - // shift to right all greater values (shift just indices) - for (; j > k; --j) { - const T b = evals_ptr[out_ptr[j - 1]]; - if (a > b) { - break; - } - out_ptr[j] = out_ptr[j - 1]; - } - // and insert the current index in the empty place, such that eigenvalues are sorted. - out_ptr[j] = ii; - ++i2; - } - } - return k; -} - // This function returns number of non-deflated eigenvectors and a tuple with number of upper, dense // and lower non-deflated eigenvectors, together with two permutations: // - @p index_sorted (sort(non-deflated)|sorted(deflated) -> initial. 
@@ -432,46 +373,184 @@ auto stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* typ return std::tuple(k, std::move(n_udl)); } +// This function returns number of global non-deflated eigenvectors, together with two permutations: +// - @p index_sorted_coltype (sort(upper)|sort(dense)|sort(lower)|sort(deflated)) -> initial +// - @p index_sorted (sort(non-deflated)|sort(deflated)) -> initial. +// +// Both permutations are represented using global indices, but: +// - @p index_sorted sorts "globally", i.e. considering all evecs across ranks +// - @p index_sorted_coltype sorts "locally", i.e. considering just evecs from the same rank +// +// In addition, even if all ranks have the full permutation, it is important to highlight that +// thanks to how it is built, i.e. rank-independent permutations, @p index_sorted_coltype can be used as +// if it was distributed, since each "local" tile would contain global indices that are valid +// just on the related rank. +// +// rank | 0 | 1 | 0 | +// initial | 0U| 1L| 2X| 3U| 4U| 5X| 6L| 7L| 8X| +// index_sorted | 3U| 0U| 4U| 6L| 7L| 1L| 2X| 5X| 8X| -> UUULLLXXX +// index_sorted_coltype | 0U| 6L| 7L| 3U| 4U| 5X| 1L| 2X| 8X| +// +// index_sorted_coltype can be used "locally": +// on rank0 | 0U| 6L| 7L| --| --| --| 1L| 2X| 8X| -> ULLLXX +// on rank1 | --| --| --| 3U| 4U| 5X| --| --| --| -> UUX +// +// where U: Upper, D: Dense, L: Lower, X: Deflated +// +// The permutations will allow to keep the mapping between sorted eigenvalues and unsorted eigenvectors, +// which is useful since eigenvectors are more expensive to permute, so we can keep them in their +// initial order. +// +// @param n number of eigenvalues +// @param types array[n] column type of each eigenvector after deflation (initial order) +// @param evals array[n] of eigenvalues sorted as perm_sorted +// @param perm_sorted array[n] current -> initial (i.e.
evals[i] -> types[perm_sorted[i]]) +// @param index_sorted array[n] global(sort(non-deflated)|sort(deflated)) -> initial +// @param index_sorted_coltype array[n] local(sort(upper)|sort(dense)|sort(lower)|sort(deflated)) -> initial +// +// @return k number of non-deflated eigenvectors template -auto stablePartitionIndexForDeflation(const SizeType i_begin, const SizeType i_end, - Matrix& c, - Matrix& evals, - Matrix& in, - Matrix& out) { +SizeType stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub, const SizeType n, + const ColType* types, const T* evals, + const SizeType* perm_sorted, SizeType* index_sorted, + SizeType* index_sorted_coltype) { + // Note: + // (in) types + // column type of the initial indexing + // (in) perm_sorted + // initial <-- sorted by ascending eigenvalue + // (out) index_sorted + // initial <-- sorted by ascending eigenvalue in two groups (non-deflated | deflated) + // (out) index_sorted_coltype + // initial <-- sorted by ascending eigenvalue in four groups (upper | dense | lower | deflated) + const SizeType k = std::count_if(types, types + n, + [](const ColType coltype) { return ColType::Deflated != coltype; }); + + // Create the permutation (sorted non-deflated | sorted deflated) -> initial + // Note: + // Since during deflation, eigenvalues related to deflated eigenvectors, might not be sorted anymore, + // this step also takes care of sorting eigenvalues (actually just their related index) by their ascending value. + SizeType i1 = 0; // index of non-deflated values in out + SizeType i2 = k; // index of deflated values + for (SizeType i = 0; i < n; ++i) { + const SizeType ii = perm_sorted[i]; + + // non-deflated are untouched, just squeeze them at the beginning as they appear + if (types[ii] != ColType::Deflated) { + index_sorted[i1] = ii; + ++i1; + } + // deflated are the ones that can have been moved "out-of-order" by deflation... + // ...
so each time insert it in the right place based on eigenvalue value + else { + const T a = evals[ii]; + + SizeType j = i2; + // shift to right all greater values (shift just indices) + for (; j > k; --j) { + const T b = evals[index_sorted[j - 1]]; + if (a > b) { + break; + } + index_sorted[j] = index_sorted[j - 1]; + } + // and insert the current index in the empty place, such that eigenvalues are sorted. + index_sorted[j] = ii; + ++i2; + } + } + + // Create the permutation (sort(upper)|sort(dense)|sort(lower)|sort(deflated)) -> initial + // Note: + // index_sorted is used as "reference" in order to deal with deflated vectors in the right sorted order. + // In this way, also non-deflated are considered in a sorted way, which is not a requirement, + // but it does not hurt either. + const SizeType nperms = dist_sub.size().cols(); + + using offsets_t = std::array; + std::vector offsets(to_sizet(dist_sub.commGridSize().cols()), {0, 0, 0, 0}); + + for (SizeType j_el = 0; j_el < nperms; ++j_el) { + const SizeType jj_el = index_sorted[to_sizet(j_el)]; + const ColType coltype = types[to_sizet(jj_el)]; + + const comm::IndexT_MPI rank = dist_sub.rankGlobalElement(jj_el); + offsets_t& rank_offsets = offsets[to_sizet(rank)]; + + if (coltype != ColType::Deflated) + ++rank_offsets[1 + ev_sort_order(coltype)]; + } + std::for_each(offsets.begin(), offsets.end(), [](offsets_t& rank_offsets) { + std::partial_sum(rank_offsets.cbegin(), rank_offsets.cend(), rank_offsets.begin()); + }); + + for (SizeType j_el = 0; j_el < nperms; ++j_el) { + const SizeType jj_el = index_sorted[to_sizet(j_el)]; + const ColType coltype = types[to_sizet(jj_el)]; + + const comm::IndexT_MPI rank = dist_sub.rankGlobalElement(jj_el); + offsets_t& rank_offsets = offsets[to_sizet(rank)]; + + const SizeType jjj_el_lc = to_SizeType(rank_offsets[ev_sort_order(coltype)]++); + using matrix::internal::distribution::global_element_from_local_element_on_rank; + const SizeType jjj_el = + 
global_element_from_local_element_on_rank(dist_sub, rank, jjj_el_lc); + + index_sorted_coltype[to_sizet(jjj_el)] = jj_el; + } + + return k; +} + +template +auto stablePartitionIndexForDeflation( + const SizeType i_begin, const SizeType i_end, Matrix& c, + Matrix& evals, Matrix& in, + Matrix& out, Matrix& out_by_coltype) { namespace ex = pika::execution::experimental; namespace di = dlaf::internal; using pika::execution::thread_stacksize; const SizeType n = problemSize(i_begin, i_end, in.distribution()); - auto part_fn = [n](const auto& c_tiles_futs, const auto& evals_tiles_fut, const auto& in_tiles_futs, - const auto& out_tiles) { + auto part_fn = [n](const auto& c_tiles_futs, const auto& evals_tiles_futs, const auto& in_tiles_futs, + const auto& out_tiles, const auto& out_coltype_tiles) { const TileElementIndex zero_idx(0, 0); const ColType* c_ptr = c_tiles_futs[0].get().ptr(zero_idx); - const T* evals_ptr = evals_tiles_fut[0].get().ptr(zero_idx); + const T* evals_ptr = evals_tiles_futs[0].get().ptr(zero_idx); const SizeType* in_ptr = in_tiles_futs[0].get().ptr(zero_idx); SizeType* out_ptr = out_tiles[0].ptr(zero_idx); + SizeType* out_coltype_ptr = out_coltype_tiles[0].ptr(zero_idx); - return stablePartitionIndexForDeflationArrays(n, c_ptr, evals_ptr, in_ptr, out_ptr); + return stablePartitionIndexForDeflationArrays(n, c_ptr, evals_ptr, in_ptr, out_ptr, out_coltype_ptr); }; TileCollector tc{i_begin, i_end}; return ex::when_all(ex::when_all_vector(tc.read(c)), ex::when_all_vector(tc.read(evals)), - ex::when_all_vector(tc.read(in)), ex::when_all_vector(tc.readwrite(out))) | + ex::when_all_vector(tc.read(in)), ex::when_all_vector(tc.readwrite(out)), + ex::when_all_vector(tc.readwrite(out_by_coltype))) | di::transform(di::Policy(thread_stacksize::nostack), std::move(part_fn)); } template -auto stablePartitionIndexForDeflation( - const SizeType i_begin, const SizeType i_end, Matrix& c, - Matrix& evals, Matrix& in, - Matrix& out, Matrix& out_by_coltype) { +auto 
stablePartitionIndexForDeflation(const matrix::Distribution& dist_evecs, const SizeType i_begin, + const SizeType i_end, Matrix& c, + Matrix& evals, + Matrix& in, + Matrix& out, + Matrix& out_by_coltype) { namespace ex = pika::execution::experimental; namespace di = dlaf::internal; using pika::execution::thread_stacksize; const SizeType n = problemSize(i_begin, i_end, in.distribution()); - auto part_fn = [n](const auto& c_tiles_futs, const auto& evals_tiles_futs, const auto& in_tiles_futs, - const auto& out_tiles, const auto& out_coltype_tiles) { + + const matrix::Distribution dist_evecs_sub( + dist_evecs, + matrix::SubDistributionSpec{dist_evecs.globalElementIndex({i_begin, i_begin}, {0, 0}), {n, n}}); + + auto part_fn = [n, dist_evecs_sub](const auto& c_tiles_futs, const auto& evals_tiles_futs, + const auto& in_tiles_futs, const auto& out_tiles, + const auto& out_coltype_tiles) { const TileElementIndex zero_idx(0, 0); const ColType* c_ptr = c_tiles_futs[0].get().ptr(zero_idx); const T* evals_ptr = evals_tiles_futs[0].get().ptr(zero_idx); @@ -479,7 +558,8 @@ auto stablePartitionIndexForDeflation( SizeType* out_ptr = out_tiles[0].ptr(zero_idx); SizeType* out_coltype_ptr = out_coltype_tiles[0].ptr(zero_idx); - return stablePartitionIndexForDeflationArrays(n, c_ptr, evals_ptr, in_ptr, out_ptr, out_coltype_ptr); + return stablePartitionIndexForDeflationArrays(dist_evecs_sub, n, c_ptr, evals_ptr, in_ptr, out_ptr, + out_coltype_ptr); }; TileCollector tc{i_begin, i_end}; @@ -1561,6 +1641,17 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // Initialize the column types vector `c` initColTypes(i_begin, i_split, i_end, ws_h.c); + // Initialize i1 as identity just for single tile sub-problems + if (i_split == i_begin + 1) { + initIndex(i_begin, i_split, ws_h.i1); + } + if (i_split + 1 == i_end) { + initIndex(i_split, i_end, ws_h.i1); + } + + // Update indices of second sub-problem + addIndex(i_split, i_end, n1, ws_h.i1); + // Step #1 // // i1 (out) : initial <--- 
initial (identity map) @@ -1569,13 +1660,6 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // - deflate `d`, `z` and `c` // - apply Givens rotations to `Q` - `evecs` // - if (i_split == i_begin + 1) { - initIndex(i_begin, i_split, ws_h.i1); - } - if (i_split + 1 == i_end) { - initIndex(i_split, i_end, ws_h.i1); - } - addIndex(i_split, i_end, n1, ws_h.i1); sortIndex(i_begin, i_end, ex::just(n1), ws_h.d0, ws_h.i1, ws_hm.i2); auto rots = @@ -1587,8 +1671,6 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, const comm::IndexT_MPI tag = to_int(i_split); applyGivensRotationsToMatrixColumns(grid.rowCommunicator(), tag, i_begin, i_end, std::move(rots), ws.e0); - // Placeholder for rearranging the eigenvectors: (local permutation) - copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws.e1); // Step #2 // @@ -1599,8 +1681,24 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`. // - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented) // - auto k = - stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_h.d0, ws_hm.i2, ws_h.i3) | ex::split(); + auto k = ex::split(stablePartitionIndexForDeflation(dist_evecs, i_begin, i_end, ws_h.c, ws_h.d0, + ws_hm.i2, ws_h.i3, ws_hm.i5)); + + // Reorder Eigenvectors + if constexpr (Backend::MC == B) { + copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.i5, ws.i5); + dlaf::permutations::internal::permuteJustLocal(i_begin, i_end, ws.i5, ws.e0, + ws.e1); + } + else { + // TODO remove this branch. 
It exists just because GPU permuteJustLocal is not implemented yet + copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws_hm.e0); + dlaf::permutations::internal::permuteJustLocal( + i_begin, i_end, ws_hm.i5, ws_hm.e0, ws_hm.e2); + copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e1); + } + + // Reorder Eigenvalues applyIndex(i_begin, i_end, ws_h.i3, ws_h.d0, ws_hm.d1); applyIndex(i_begin, i_end, ws_h.i3, ws_hm.z0, ws_hm.z1); copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.d1, ws_h.d0); @@ -1611,10 +1709,20 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid, // invertIndex(i_begin, i_end, ws_h.i3, ws_hm.i2); + // + // i5 (in) : initial <--- sort by coltype + // i2 (in) : deflated <--- initial + // i4 (out) : deflated <--- sort by col type + // + // TODO This is propedeutic to work in rank1 solver with columns sorted by type, so that they are + // well-shaped for an optimized gemm, but still keeping track of where the actual position sorted by + // eigenvalues is. + applyIndex(i_begin, i_end, ws_hm.i5, ws_hm.i2, ws_h.i4); + // Note: here ws_hm.z0 is used as a contiguous buffer for the laed4 call matrix::util::set0(thread_priority::normal, idx_loc_begin, sz_loc_tiles, ws_hm.e2); solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, idx_loc_begin, sz_loc_tiles, - k, std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_hm.i2, ws_hm.e2); + k, std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_h.i4, ws_hm.e2); // Step #3: Eigenvectors of the tridiagonal system: Q * U // diff --git a/include/dlaf/permutations/general/impl.h b/include/dlaf/permutations/general/impl.h index 84483b9b72..d524003b45 100644 --- a/include/dlaf/permutations/general/impl.h +++ b/include/dlaf/permutations/general/impl.h @@ -27,8 +27,11 @@ #include #include #include +#include +#include #include #include +#include #include #include #include @@ -192,6 +195,81 @@ void Permutations::call(const SizeType i_begin, const SizeType i_end } } +template +void permuteJustLocal(const SizeType 
i_begin, const SizeType i_end, Matrix& perms, + Matrix& mat_in, Matrix& mat_out) { + static_assert(C == Coord::Col, "Just column permutation"); + + namespace ut = matrix::util; + namespace ex = pika::execution::experimental; + namespace di = dlaf::internal; + + using matrix::internal::MatrixRef; + using matrix::internal::SubMatrixSpec; + + if (i_begin == i_end) + return; + + const matrix::Distribution& distr = mat_in.distribution(); + + using matrix::internal::distribution::global_tile_element_distance; + const SubMatrixSpec sub_spec{distr.globalElementIndex({i_begin, i_begin}, {0, 0}), + { + global_tile_element_distance(distr, i_begin, i_end), + global_tile_element_distance(distr, i_begin, i_end), + }}; + MatrixRef mat_sub_in(mat_in, sub_spec); + MatrixRef mat_sub_out(mat_out, sub_spec); + + const matrix::Distribution& dist_sub = mat_sub_in.distribution(); + + const SizeType ntiles = i_end - i_begin; + const auto perms_range = common::iterate_range2d(LocalTileIndex(i_begin, 0), LocalTileSize(ntiles, 1)); + const auto mat_range = common::iterate_range2d(dist_sub.localNrTiles()); + auto sender = ex::when_all(ex::when_all_vector(matrix::selectRead(perms, std::move(perms_range))), + ex::when_all_vector(matrix::selectRead(mat_sub_in, mat_range)), + ex::when_all_vector(matrix::select(mat_sub_out, mat_range))); + + auto permute_fn = [dist_sub](const auto& perm_tiles_futs, const auto& mat_in_tiles, + const auto& mat_out_tiles, auto&&...) 
{ + const SizeType* perm_ptr = perm_tiles_futs[0].get().ptr(); + + const SizeType nperms_lc = dist_sub.localSize().cols(); + if constexpr (D == Device::CPU) { + for (SizeType j_el_lc = 0; j_el_lc < nperms_lc; ++j_el_lc) { + const SizeType j_el = dist_sub.globalElementFromLocalElement(j_el_lc); + const SizeType jj_el = perm_ptr[to_sizet(j_el)]; + + const SizeType j_lc = dist_sub.localTileFromLocalElement(j_el_lc); + const SizeType j_el_tl = dist_sub.tileElementFromLocalElement(j_el_lc); + + const SizeType jj_lc = dist_sub.localTileFromGlobalElement(jj_el); + const SizeType jj_el_tl = dist_sub.tileElementFromGlobalElement(jj_el); + + for (SizeType i_lc = 0; i_lc < dist_sub.localNrTiles().rows(); ++i_lc) { + const std::size_t j_lc_linear = to_sizet(dist_sub.localTileLinearIndex({i_lc, j_lc})); + const std::size_t jj_lc_linear = to_sizet(dist_sub.localTileLinearIndex({i_lc, jj_lc})); + + const auto& tile_in = mat_in_tiles[jj_lc_linear].get(); + auto& tile_out = mat_out_tiles[j_lc_linear]; + + DLAF_ASSERT_HEAVY(tile_in.size().rows() == tile_out.size().rows(), tile_in.size(), + tile_out.size()); + const TileElementSize region(tile_in.size().rows(), 1); + const TileElementIndex sub_in(0, jj_el_tl); + const TileElementIndex sub_out(0, j_el_tl); + + dlaf::tile::lacpy(region, sub_in, tile_in, sub_out, tile_out); + } + } + } + else { + // TODO GPU + } + }; + ex::start_detached(di::transform(di::Policy(), std::move(permute_fn), std::move(sender))); +} + template auto whenAllReadWriteTilesArray(LocalTileIndex begin, LocalTileIndex end, Matrix& matrix) { const LocalTileSize sz{end.row() - begin.row(), end.col() - begin.col()}; diff --git a/test/unit/eigensolver/test_tridiag_solver_merge.cpp b/test/unit/eigensolver/test_tridiag_solver_merge.cpp index 5701003973..426646dddd 100644 --- a/test/unit/eigensolver/test_tridiag_solver_merge.cpp +++ b/test/unit/eigensolver/test_tridiag_solver_merge.cpp @@ -8,7 +8,12 @@ // SPDX-License-Identifier: BSD-3-Clause // +#include +#include + 
+#include #include +#include #include #include @@ -83,7 +88,7 @@ TYPED_TEST(TridiagEigensolverMergeTest, SortIndex) { TEST(StablePartitionIndexOnDeflated, FullRange) { constexpr SizeType n = 10; - const SizeType nb = 3; + constexpr SizeType nb = 3; const LocalElementSize sz(n, 1); const TileElementSize bk(nb, 1); @@ -92,6 +97,7 @@ TEST(StablePartitionIndexOnDeflated, FullRange) { Matrix vals(sz, bk); Matrix in(sz, bk); Matrix out(sz, bk); + Matrix out_by_type(sz, bk); // Note: // UpperHalf -> u @@ -101,16 +107,18 @@ TEST(StablePartitionIndexOnDeflated, FullRange) { // | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial // | l| f|d1|d2| u| u| l| f|d3| l| c_arr - std::vector c_arr{ColType::LowerHalf, ColType::Dense, ColType::Deflated, - ColType::Deflated, ColType::UpperHalf, ColType::UpperHalf, - ColType::LowerHalf, ColType::Dense, ColType::Deflated, - ColType::LowerHalf}; + + const std::vector c_arr{ColType::LowerHalf, ColType::Dense, ColType::Deflated, + ColType::Deflated, ColType::UpperHalf, ColType::UpperHalf, + ColType::LowerHalf, ColType::Dense, ColType::Deflated, + ColType::LowerHalf}; DLAF_ASSERT(c_arr.size() == to_sizet(n), n); dlaf::matrix::util::set(c, [&c_arr](GlobalElementIndex i) { return c_arr[to_sizet(i.row())]; }); // | 1| 4| 2| 3| 0| 5| 6| 7| 8| 9| in_arr // | f| u|d1|d2| l| u| l| f|d3| l| c_arr permuted by in_arr - std::array in_arr{1, 4, 2, 3, 0, 5, 6, 7, 8, 9}; + + constexpr std::array in_arr{1, 4, 2, 3, 0, 5, 6, 7, 8, 9}; dlaf::matrix::util::set(in, [&in_arr](GlobalElementIndex i) { return in_arr[to_sizet(i.row())]; }); // | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial @@ -119,36 +127,88 @@ TEST(StablePartitionIndexOnDeflated, FullRange) { // | 1| 4| 2| 3| 0| 5| 6| 7| 8| 9| in_arr // | f| u|d1|d2| l| u| l| f|d3| l| c_arr permuted by in_arr // |10|20| 2| 3|30|40|50|60| 1|70| vals_arr permuted by in_arr - std::array vals_arr{30, 10, 2, 3, 20, 40, 50, 60, 1, 70}; + + constexpr std::array vals_arr{30, 10, 2, 3, 20, 40, 50, 60, 1, 70}; dlaf::matrix::util::set(vals, 
[&vals_arr](GlobalElementIndex i) { return vals_arr[to_sizet(i.row())]; }); const SizeType i_begin = 0; const SizeType i_end = 4; - auto k = stablePartitionIndexForDeflation(i_begin, i_end, c, vals, in, out); + auto [k_sender, n_udl_sender] = + stablePartitionIndexForDeflation(i_begin, i_end, c, vals, in, out, out_by_type) | + ex::split_tuple(); // | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial // | l| f|d1|d2| u| u| l| f|d3| l| c_arr // |30|10| 2| 3|20|40|50|60| 1|70| vals_arr - // + + const SizeType k = tt::sync_wait(std::move(k_sender)); + ASSERT_EQ(k, 7); + + const auto n_udl = tt::sync_wait(std::move(n_udl_sender)); + EXPECT_EQ(n_udl[0], 2); + EXPECT_EQ(n_udl[1], 2); + EXPECT_EQ(n_udl[2], 3); + // | 1| 4| 0| 5| 6| 7| 9| 8| 2| 3| out_arr - // | f| u| l| u| l| f| l|d3|d1|d2| c_arr permuted by out_arr // |10|20|30|40|50|60|70| 1| 2| 3| vals_arr permuted by out_arr - const SizeType k_value = tt::sync_wait(std::move(k)); - ASSERT_TRUE(k_value == 7); - - std::array expected_out_arr{1, 4, 0, 5, 6, 7, 9, 8, 2, 3}; - auto expected_out = [&expected_out_arr](GlobalElementIndex i) { + constexpr std::array expected_out_arr{1, 4, 0, 5, 6, 7, 9, 8, 2, 3}; + auto expected_out = [&expected_out_arr](const GlobalElementIndex i) { return expected_out_arr[to_sizet(i.row())]; }; + + auto out_sender = tt::sync_wait(out.read(LocalTileIndex(0, 0))); + const SizeType* out_ptr = out_sender.get().ptr(); + CHECK_MATRIX_EQ(expected_out, out); - const SizeType* out_ptr = tt::sync_wait(out.read(LocalTileIndex(0, 0))).get().ptr(); - EXPECT_TRUE(std::is_sorted(out_ptr + k_value, out_ptr + n, + EXPECT_TRUE(std::is_sorted(out_ptr, out_ptr + k, [&vals_arr](const SizeType i, const SizeType j) { + return vals_arr[to_sizet(i)] < vals_arr[to_sizet(j)]; + })) << "non-deflated part of 'out' permutation should be sorted by eigenvalues"; + + EXPECT_TRUE(std::is_sorted(out_ptr + k, out_ptr + n, [&vals_arr](const SizeType i, const SizeType j) { + return vals_arr[to_sizet(i)] < vals_arr[to_sizet(j)]; + })) << 
"deflated part of 'out' permutation should be sorted by eigenvalues"; + + // | 4| 5| 1| 7| 0| 6| 9| 8| 2| 3| out_by_type_arr + // | u| u| f| f| l| l| l|d3|d1|d2| c_arr permuted by out_by_type_arr + // |20|40|10|60|30|50|70| 1| 2| 3| vals_arr permuted by out_arr_by_type_arr + + constexpr std::array expected_ordered{ColType::UpperHalf, ColType::Dense, + ColType::LowerHalf, ColType::Deflated}; + constexpr std::array expected_out_by_type_arr{4, 5, 1, 7, 0, 6, 9, 8, 2, 3}; + auto expected_out_by_type = [&expected_out_by_type_arr](const GlobalElementIndex i) { + return expected_out_by_type_arr[to_sizet(i.row())]; + }; + + auto out_by_type_sender = tt::sync_wait(out_by_type.read(LocalTileIndex(0, 0))); + const SizeType* out_by_type_ptr = out_by_type_sender.get().ptr(); + + CHECK_MATRIX_EQ(expected_out_by_type, out_by_type); + + const std::array partitions = [&]() { + std::array offsets; + offsets[0] = 0; + offsets[4] = n; + std::partial_sum(n_udl.cbegin(), n_udl.cend(), offsets.begin() + 1); + return offsets; + }(); + + for (std::size_t coltype_index = 0; coltype_index < expected_ordered.size(); ++coltype_index) { + const auto begin = partitions[coltype_index]; + const auto end = partitions[coltype_index + 1]; + for (std::size_t i = begin; i < end; ++i) { + const ColType coltype = c_arr[to_sizet(out_by_type_ptr[to_sizet(i)])]; + EXPECT_EQ(expected_ordered[coltype_index], coltype) << " at index " << i; + } + } + + EXPECT_TRUE(std::is_sorted(out_by_type_ptr + k, out_by_type_ptr + n, [&vals_arr](const SizeType i, const SizeType j) { return vals_arr[to_sizet(i)] < vals_arr[to_sizet(j)]; - })); + })) + << "deflated should be sorted in out_by_type"; } TYPED_TEST(TridiagEigensolverMergeTest, Deflation) {