Skip to content

Commit

Permalink
doc + minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
albestro committed Dec 12, 2023
1 parent 4f31507 commit a824d40
Showing 1 changed file with 79 additions and 9 deletions.
88 changes: 79 additions & 9 deletions include/dlaf/eigensolver/tridiag_solver/merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ auto stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* typ
//
// @return k number of non-deflated eigenvectors
// @return k_local number of local non-deflated eigenvectors
// @return n_udl array with global indices [first_dense, last_dense + 1, last_lower + 1], i.e. the last two are one-past-the-end
template <class T>
auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub, const ColType* types,
const T* evals, SizeType* perm_sorted,
Expand Down Expand Up @@ -500,7 +501,6 @@ auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub
index_sorted_coltype[to_sizet(jjj_el)] = jj_el;
}

// TODO manage edge cases
std::array<SizeType, 3> n_udl = [&]() {
SizeType first_dense;
for (first_dense = 0; first_dense < n; ++first_dense) {
Expand All @@ -510,14 +510,11 @@ auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub
break;
}

SizeType last_dense;
for (last_dense = n - 1; last_dense >= 0; --last_dense) {
const SizeType initial_el = index_sorted_coltype[to_sizet(last_dense)];
const ColType coltype = types[to_sizet(initial_el)];
if (ColType::LowerHalf != coltype && ColType::Deflated != coltype)
break;
}

// Note:
// Eigenvectors will be sorted according to index_sorted_coltype, i.e. a local sort by coltype.
// Since it is a local order, it is legitimate for deflated eigenvectors to be globally interlaced
// with other column types. However, GEMM will be able to skip only the last globally contiguous
// group of deflated eigenvectors, not the ones interlaced with others.
SizeType last_lower;
for (last_lower = n - 1; last_lower >= 0; --last_lower) {
const SizeType initial_el = index_sorted_coltype[to_sizet(last_lower)];
Expand All @@ -526,6 +523,14 @@ auto stablePartitionIndexForDeflationArrays(const matrix::Distribution& dist_sub
break;
}

SizeType last_dense;
for (last_dense = last_lower; last_dense >= 0; --last_dense) {
const SizeType initial_el = index_sorted_coltype[to_sizet(last_dense)];
const ColType coltype = types[to_sizet(initial_el)];
if (ColType::LowerHalf != coltype && ColType::Deflated != coltype)
break;
}

return std::array<SizeType, 3>{first_dense, last_dense + 1, last_lower + 1};
}();

Expand Down Expand Up @@ -1657,6 +1662,71 @@ void multiplyEigenvectors(const matrix::Distribution& dist_sub,
const GlobalElementIndex sub_offset, const SizeType n, const SizeType n_upper,
const SizeType n_lower, Matrix<T, D>& e0, Matrix<T, D>& e1, Matrix<T, D>& e2,
KLcSender&& k_lc, UDLSenders&& n_udl) {
// Note:
// This function computes E0 = E1 . E2
//
// where E1 is the matrix with eigenvectors and it looks like this
//
//             ┌──────────┐ k
//             │    b     │ │
//
//         ┌── ┌───┬──────┬─┬────┐
//         │   │UUU│DDDDDD│ │XXXX│
//         │   │UUU│DDDDDD│ │XXXX│
// n_upper │   │UUU│DDDDDD│ │XXXX│
//         │   │UUU│DDDDDD│ │XXXX│
//         │   │UUU│DDDDDD│ │XXXX│
//         ├── ├───┼──────┼─┤XXXX│
//         │   │   │DDDDDD│L│XXXX│
// n_lower │   │   │DDDDDD│L│XXXX│
//         │   │   │DDDDDD│L│XXXX│
//         └── └───┴──────┴─┴────┘
//             │ a │
//             └───┘
//             │     c      │
//             └────────────┘
//
// Where (a, b, c) are the values from n_udl
//
// Note:
// The E1 matrix does not have all deflated eigenvectors at the end: some of them are "interlaced"
// with the others. The GEMM will anyway perform the computation for these interlaced deflated
// eigenvectors (which are zeroed out), while the copy step is performed at the "local" level, so
// even the interlaced ones get copied to the right spot.
//
// The multiplication is performed in two different steps in order to skip the zero blocks of the
// matrix, created by grouping eigenvectors of different lengths (UPPER, DENSE and LOWER).
//
// 1. GEMM1 = TL . TOP
// 2. GEMM2 = BR . BOTTOM
// 3. copy DEFLATED
//
// ┌────────────┬────┐
// │            │    │
// │            │    │
// │   T O P    │    │
// │            │    │
// │            │    │
// ├────────────┤    │
// │            │    │
// │            │    │
// │B O T T O M │    │
// │            │    │
// └────────────┴────┘
//
// ┌──────────┬─┬────┐       ┌────────────┬────┐
// │          │0│    │       │            │    │
// │          │0│ D  │       │            │    │
// │   TL     │0│ E  │       │  GEMM 1    │ C  │
// │          │0│ F  │       │            │    │
// │          │0│ L  │       │            │ O  │
// ├───┬──────┴─┤ A  │       ├────────────┤    │
// │000│        │ T  │       │            │ P  │
// │000│        │ E  │       │            │    │
// │000│  BR    │ D  │       │  GEMM 2    │ Y  │
// │000│        │    │       │            │    │
// └───┴────────┴────┘       └────────────┴────┘

namespace ex = pika::execution::experimental;

ex::start_detached(
Expand Down

0 comments on commit a824d40

Please sign in to comment.