Skip to content

Commit

Permalink
TridiagSolver (dist): STEP1 rank-independent sort of eigenvalues by c…
Browse files Browse the repository at this point in the history
…olumn type for rank1 solver (#967)
  • Loading branch information
albestro authored Dec 11, 2023
1 parent a10493c commit b99bb16
Show file tree
Hide file tree
Showing 4 changed files with 348 additions and 110 deletions.
3 changes: 2 additions & 1 deletion include/dlaf/eigensolver/tridiag_solver/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ void TridiagSolver<B, D, T>::call(comm::CommunicatorGrid grid, Matrix<T, Device:
// Mirror workspace on host memory for CPU-only kernels
DistWorkSpaceHostMirror<T, D> ws_hm{initMirrorMatrix(ws.e0), initMirrorMatrix(ws.e2),
initMirrorMatrix(ws.d1), initMirrorMatrix(ws.z0),
initMirrorMatrix(ws.z1), initMirrorMatrix(ws.i2)};
initMirrorMatrix(ws.z1), initMirrorMatrix(ws.i2),
initMirrorMatrix(ws.i5)};

// Set `ws.e0` to `zero` (needed for Given's rotation to make sure no random values are picked up)
matrix::util::set0<B, T, D>(thread_priority::normal, ws.e0);
Expand Down
281 changes: 190 additions & 91 deletions include/dlaf/eigensolver/tridiag_solver/merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ struct DistWorkSpaceHostMirror {
HostMirrorMatrix<T, D> z1;

HostMirrorMatrix<SizeType, D> i2;
HostMirrorMatrix<SizeType, D> i5;
};

template <class T>
Expand Down Expand Up @@ -273,66 +274,6 @@ inline std::size_t ev_sort_order(const ColType coltype) {
return DLAF_UNREACHABLE(std::size_t);
}

// This function returns number of non-deflated eigenvectors, together with a permutation @p out_ptr
// that represent mapping (sorted non-deflated | sorted deflated) -> initial.
//
// The permutation will allow to keep the mapping between sorted eigenvalues and unsorted eigenvectors,
// which is useful since eigenvectors are more expensive to permuted, so we can keep them in their initial order.
//
// @param n number of eigenvalues
// @param c_ptr array[n] containing the column type of each eigenvector after deflation (initial order)
// @param evals_ptr array[n] of eigenvalues sorted as in_ptr
// @param in_ptr array[n] representing permutation current -> initial (i.e. evals[i] -> c_ptr[in_ptr[i]])
// @param out_ptr array[n] permutation (sorted non-deflated | sorted deflated) -> initial
//
// @return k number of non-deflated eigenvectors
template <class T>
SizeType stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* c_ptr,
const T* evals_ptr, const SizeType* in_ptr,
SizeType* out_ptr) {
// Get the number of non-deflated entries
SizeType k = 0;
for (SizeType i = 0; i < n; ++i) {
if (c_ptr[i] != ColType::Deflated)
++k;
}

// Create the permutation (sorted non-deflated | sorted deflated) -> initial
// Note:
// Since during deflation, eigenvalues related to deflated eigenvectors, might not be sorted anymore,
// this step also take care of sorting eigenvalues (actually just their related index) by their ascending value.
SizeType i1 = 0; // index of non-deflated values in out
SizeType i2 = k; // index of deflated values
for (SizeType i = 0; i < n; ++i) {
const SizeType ii = in_ptr[i];

// non-deflated are untouched, just squeeze them at the beginning as they appear
if (c_ptr[ii] != ColType::Deflated) {
out_ptr[i1] = ii;
++i1;
}
// deflated are the ones that can have been moved "out-of-order" by deflation...
// ... so each time insert it in the right place based on eigenvalue value
else {
const T a = evals_ptr[ii];

SizeType j = i2;
// shift to right all greater values (shift just indices)
for (; j > k; --j) {
const T b = evals_ptr[out_ptr[j - 1]];
if (a > b) {
break;
}
out_ptr[j] = out_ptr[j - 1];
}
// and insert the current index in the empty place, such that eigenvalues are sorted.
out_ptr[j] = ii;
++i2;
}
}
return k;
}

// This function returns number of non-deflated eigenvectors and a tuple with number of upper, dense
// and lower non-deflated eigenvectors, together with two permutations:
// - @p index_sorted (sort(non-deflated)|sorted(deflated) -> initial.
Expand Down Expand Up @@ -432,54 +373,187 @@ auto stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* typ
return std::tuple(k, std::move(n_udl));
}

// This function returns number of global non-deflated eigenvectors, together with two permutations:
// - @p index_sorted (sort(non-deflated)|sort(deflated)) -> initial.
// - @p index_sorted_coltype (sort(upper)|sort(dense)|sort(lower)|sort(deflated)) -> initial
//
// Both permutations are represented using global indices, but:
// - @p index_sorted sorts "globally", i.e. considering all evecs across ranks
// - @p index_sorted_coltype sorts "locally", i.e. consdering just evecs from the same rank
//
// In addition, even if all ranks have the full permutation, it is important to highlight that
// thanks to how it is built, i.e. rank-independent permutations, @p index_sorted_coltype can be used as
// if it was distributed, since each "local" tile would contain global indices that are valid
// just on the related rank.
//
// rank | 0 | 1 | 0 |
// initial | 0U| 1L| 2X| 3U| 4U| 5X| 6L| 7L| 8X|
// index_sorted | 0U| 1L| 3U| 4U| 6L| 7L| 2X| 5X| 8X| -> sort(non-deflated) | sort(deflated)
// index_sorted_col_type | 0U| 1L| 6L| 3U| 4U| 5X| 7L| 2X| 8X| -> rank0(ULLLXX) - rank1(UUX)
//
// index_sorted_col_type can be used "locally":
// on rank0 | 0U| 1L| 6L| --| --| --| 7L| 2X| 8X| -> ULLLXX
// on rank1 | --| --| --| 3U| 4U| 5X| --| --| --| -> UUX
//
// where U: Upper, D: Dense, L: Lower, X: Deflated
//
// The permutations will allow to keep the mapping between sorted eigenvalues and unsorted eigenvectors,
// which is useful since eigenvectors are more expensive to permute, so we can keep them in their
// initial order.
//
// @param n number of eigenvalues
// @param types array[n] column type of each eigenvector after deflation (initial order)
// @param evals array[n] of eigenvalues sorted as perm_sorted
// @param perm_sorted array[n] current -> initial (i.e. evals[i] -> types[perm_sorted[i]])
// @param index_sorted array[n] global(sort(non-deflated)|sort(deflated))) -> initial
// @param index_sorted_coltype array[n] local(sort(upper)|sort(dense)|sort(lower)|sort(deflated))) -> initial
//
// @return k number of non-deflated eigenvectors
template <class T>
auto stablePartitionIndexForDeflation(const SizeType i_begin, const SizeType i_end,
Matrix<const ColType, Device::CPU>& c,
Matrix<const T, Device::CPU>& evals,
Matrix<const SizeType, Device::CPU>& in,
Matrix<SizeType, Device::CPU>& out) {
SizeType stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub, const SizeType n,
const ColType* types, const T* evals,
const SizeType* perm_sorted, SizeType* index_sorted,
SizeType* index_sorted_coltype) {
const SizeType k = std::count_if(types, types + n,
[](const ColType coltype) { return ColType::Deflated != coltype; });

// Create the permutation (sorted non-deflated | sorted deflated) -> initial
// Note:
// Since during deflation, eigenvalues related to deflated eigenvectors, might not be sorted anymore,
// this step also take care of sorting eigenvalues (actually just their related index) by their ascending value.
SizeType i1 = 0; // index of non-deflated values
SizeType i2 = k; // index of deflated values
for (SizeType i = 0; i < n; ++i) {
const SizeType ii = perm_sorted[i];

// non-deflated are untouched, just squeeze them at the beginning as they appear
if (types[ii] != ColType::Deflated) {
index_sorted[i1] = ii;
++i1;
}
// deflated are the ones that can have been moved "out-of-order" by deflation...
// ... so each time insert it in the right place based on eigenvalue value
else {
const T a = evals[ii];

SizeType j = i2;
// shift to right all greater values (just the indices)
for (; j > k; --j) {
const T b = evals[index_sorted[j - 1]];
if (a > b) {
break;
}
index_sorted[j] = index_sorted[j - 1];
}
// and insert the current index in the empty place, such that eigenvalues are sorted.
index_sorted[j] = ii;
++i2;
}
}

// Create the permutation (sort(upper)|sort(dense)|sort(lower)|sort(deflated)) -> initial
// Note:
// index_sorted is used as "reference" in order to deal with deflated vectors in the right sorted order.
// In this way, also non-deflated are considered in a sorted way, which is not a requirement,
// but it does not hurt either.
const SizeType nperms = dist_sub.size().cols();

// Detect how many non-deflated per type (on each rank)
using offsets_t = std::array<std::size_t, 4>;
std::vector<offsets_t> offsets(to_sizet(dist_sub.grid_size().cols()), {0, 0, 0, 0});

for (SizeType j_el = 0; j_el < nperms; ++j_el) {
const SizeType jj_el = index_sorted[to_sizet(j_el)];
const ColType coltype = types[to_sizet(jj_el)];

const comm::IndexT_MPI rank = dist_sub.rank_global_element<Coord::Col>(jj_el);
offsets_t& rank_offsets = offsets[to_sizet(rank)];

if (coltype != ColType::Deflated)
++rank_offsets[1 + ev_sort_order(coltype)];
}
std::for_each(offsets.begin(), offsets.end(), [](offsets_t& rank_offsets) {
std::partial_sum(rank_offsets.cbegin(), rank_offsets.cend(), rank_offsets.begin());
});

// Each rank computes all rank permutations.
// Using previously calculated offsets (per rank), the permutation is already split in column types,
// so this loops over indices, checks the column type and eventually put the index in the right bin.
for (SizeType j_el = 0; j_el < nperms; ++j_el) {
const SizeType jj_el = index_sorted[to_sizet(j_el)];
const ColType coltype = types[to_sizet(jj_el)];

const comm::IndexT_MPI rank = dist_sub.rank_global_element<Coord::Col>(jj_el);
offsets_t& rank_offsets = offsets[to_sizet(rank)];

const SizeType jjj_el_lc = to_SizeType(rank_offsets[ev_sort_order(coltype)]++);
using matrix::internal::distribution::global_element_from_local_element_on_rank;
const SizeType jjj_el =
global_element_from_local_element_on_rank<Coord::Col>(dist_sub, rank, jjj_el_lc);

index_sorted_coltype[to_sizet(jjj_el)] = jj_el;
}

return k;
}

template <class T>
auto stablePartitionIndexForDeflation(
const SizeType i_begin, const SizeType i_end, Matrix<const ColType, Device::CPU>& c,
Matrix<const T, Device::CPU>& evals, Matrix<const SizeType, Device::CPU>& in,
Matrix<SizeType, Device::CPU>& out, Matrix<SizeType, Device::CPU>& out_by_coltype) {
namespace ex = pika::execution::experimental;
namespace di = dlaf::internal;
using pika::execution::thread_stacksize;

const SizeType n = problemSize(i_begin, i_end, in.distribution());
auto part_fn = [n](const auto& c_tiles_futs, const auto& evals_tiles_fut, const auto& in_tiles_futs,
const auto& out_tiles) {
auto part_fn = [n](const auto& c_tiles_futs, const auto& evals_tiles_futs, const auto& in_tiles_futs,
const auto& out_tiles, const auto& out_coltype_tiles) {
const TileElementIndex zero_idx(0, 0);
const ColType* c_ptr = c_tiles_futs[0].get().ptr(zero_idx);
const T* evals_ptr = evals_tiles_fut[0].get().ptr(zero_idx);
const T* evals_ptr = evals_tiles_futs[0].get().ptr(zero_idx);
const SizeType* in_ptr = in_tiles_futs[0].get().ptr(zero_idx);
SizeType* out_ptr = out_tiles[0].ptr(zero_idx);
SizeType* out_coltype_ptr = out_coltype_tiles[0].ptr(zero_idx);

return stablePartitionIndexForDeflationArrays(n, c_ptr, evals_ptr, in_ptr, out_ptr);
return stablePartitionIndexForDeflationArrays(n, c_ptr, evals_ptr, in_ptr, out_ptr, out_coltype_ptr);
};

TileCollector tc{i_begin, i_end};
return ex::when_all(ex::when_all_vector(tc.read(c)), ex::when_all_vector(tc.read(evals)),
ex::when_all_vector(tc.read(in)), ex::when_all_vector(tc.readwrite(out))) |
ex::when_all_vector(tc.read(in)), ex::when_all_vector(tc.readwrite(out)),
ex::when_all_vector(tc.readwrite(out_by_coltype))) |
di::transform(di::Policy<Backend::MC>(thread_stacksize::nostack), std::move(part_fn));
}

template <class T>
auto stablePartitionIndexForDeflation(
const SizeType i_begin, const SizeType i_end, Matrix<const ColType, Device::CPU>& c,
Matrix<const T, Device::CPU>& evals, Matrix<const SizeType, Device::CPU>& in,
Matrix<SizeType, Device::CPU>& out, Matrix<SizeType, Device::CPU>& out_by_coltype) {
auto stablePartitionIndexForDeflation(const matrix::Distribution& dist_evecs, const SizeType i_begin,
const SizeType i_end, Matrix<const ColType, Device::CPU>& c,
Matrix<const T, Device::CPU>& evals,
Matrix<const SizeType, Device::CPU>& in,
Matrix<SizeType, Device::CPU>& out,
Matrix<SizeType, Device::CPU>& out_by_coltype) {
namespace ex = pika::execution::experimental;
namespace di = dlaf::internal;
using pika::execution::thread_stacksize;

const SizeType n = problemSize(i_begin, i_end, in.distribution());
auto part_fn = [n](const auto& c_tiles_futs, const auto& evals_tiles_futs, const auto& in_tiles_futs,
const auto& out_tiles, const auto& out_coltype_tiles) {

const matrix::Distribution dist_evecs_sub(
dist_evecs, {dist_evecs.global_element_index({i_begin, i_begin}, {0, 0}), {n, n}});

auto part_fn = [n, dist_evecs_sub](const auto& c_tiles_futs, const auto& evals_tiles_futs,
const auto& in_tiles_futs, const auto& out_tiles,
const auto& out_coltype_tiles) {
const TileElementIndex zero_idx(0, 0);
const ColType* c_ptr = c_tiles_futs[0].get().ptr(zero_idx);
const T* evals_ptr = evals_tiles_futs[0].get().ptr(zero_idx);
const SizeType* in_ptr = in_tiles_futs[0].get().ptr(zero_idx);
SizeType* out_ptr = out_tiles[0].ptr(zero_idx);
SizeType* out_coltype_ptr = out_coltype_tiles[0].ptr(zero_idx);

return stablePartitionIndexForDeflationArrays(n, c_ptr, evals_ptr, in_ptr, out_ptr, out_coltype_ptr);
return stablePartitionIndexForDeflationArrays(dist_evecs_sub, n, c_ptr, evals_ptr, in_ptr, out_ptr,
out_coltype_ptr);
};

TileCollector tc{i_begin, i_end};
Expand Down Expand Up @@ -1561,21 +1635,25 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid,
// Initialize the column types vector `c`
initColTypes(i_begin, i_split, i_end, ws_h.c);

// Step #1
//
// i1 (out) : initial <--- initial (identity map)
// i2 (out) : initial <--- pre_sorted
//
// - deflate `d`, `z` and `c`
// - apply Givens rotations to `Q` - `evecs`
//
// Initialize i1 as identity just for single tile sub-problems
if (i_split == i_begin + 1) {
initIndex(i_begin, i_split, ws_h.i1);
}
if (i_split + 1 == i_end) {
initIndex(i_split, i_end, ws_h.i1);
}

// Update indices of second sub-problem
addIndex(i_split, i_end, n1, ws_h.i1);

// Step #1
//
// i1 (out) : initial <--- initial (or identity map)
// i2 (out) : initial <--- pre_sorted
//
// - deflate `d`, `z` and `c`
// - apply Givens rotations to `Q` - `evecs`
//
sortIndex(i_begin, i_end, ex::just(n1), ws_h.d0, ws_h.i1, ws_hm.i2);

auto rots =
Expand All @@ -1587,34 +1665,55 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid,
const comm::IndexT_MPI tag = to_int(i_split);
applyGivensRotationsToMatrixColumns(grid.rowCommunicator(), tag, i_begin, i_end, std::move(rots),
ws.e0);
// Placeholder for rearranging the eigenvectors: (local permutation)
copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws.e1);

// Step #2
//
// i2 (in) : initial <--- pre_sorted
// i3 (out) : initial <--- deflated
// i5 (out) : initial <--- local(UDL|X)
//
// - reorder eigenvectors locally so that they are well-shaped for gemm optimization (i.e. UDLX)
// - reorder `d0 -> d1`, `z0 -> z1`, using `i3` such that deflated entries are at the bottom.
// - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`.
// - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented)
//
auto k =
stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_h.d0, ws_hm.i2, ws_h.i3) | ex::split();
auto k = ex::split(stablePartitionIndexForDeflation(dist_evecs, i_begin, i_end, ws_h.c, ws_h.d0,
ws_hm.i2, ws_h.i3, ws_hm.i5));

// Reorder Eigenvectors
using dlaf::permutations::internal::permuteJustLocal;
if constexpr (Backend::MC == B) {
copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.i5, ws.i5);
permuteJustLocal<T, Coord::Col>(i_begin, i_end, ws.i5, ws.e0, ws.e1);
}
else {
copy(idx_loc_begin, sz_loc_tiles, ws.e0, ws_hm.e0);
permuteJustLocal<T, Coord::Col>(i_begin, i_end, ws_hm.i5, ws_hm.e0, ws_hm.e2);
copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e1);
}

// Reorder Eigenvalues
applyIndex(i_begin, i_end, ws_h.i3, ws_h.d0, ws_hm.d1);
applyIndex(i_begin, i_end, ws_h.i3, ws_hm.z0, ws_hm.z1);
copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.d1, ws_h.d0);

//
// i3 (in) : initial <--- deflated
// i2 (out) : initial ---> deflated
// i2 (out) : deflated <--- initial
//
invertIndex(i_begin, i_end, ws_h.i3, ws_hm.i2);

//
// i5 (in) : initial <--- local(UDL|X)
// i2 (in) : deflated <--- initial
// i4 (out) : deflated <--- local(UDL|X)
//
applyIndex(i_begin, i_end, ws_hm.i5, ws_hm.i2, ws_h.i4);

// Note: here ws_hm.z0 is used as a contiguous buffer for the laed4 call
matrix::util::set0<Backend::MC>(thread_priority::normal, idx_loc_begin, sz_loc_tiles, ws_hm.e2);
solveRank1ProblemDist(row_task_chain(), col_task_chain(), i_begin, i_end, idx_loc_begin, sz_loc_tiles,
k, std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_hm.i2, ws_hm.e2);
k, std::move(scaled_rho), ws_hm.d1, ws_hm.z1, ws_h.d0, ws_h.i4, ws_hm.e2);

// Step #3: Eigenvectors of the tridiagonal system: Q * U
//
Expand Down
Loading

0 comments on commit b99bb16

Please sign in to comment.