Skip to content

Commit

Permalink
Simplify templates in par_strength
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewreisner committed Aug 19, 2024
1 parent 7bfa9d0 commit 930dfee
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 71 deletions.
90 changes: 43 additions & 47 deletions raptor/par_strength.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <type_traits>

#include "core/par_matrix.hpp"
#include "core/matrix_traits.hpp"

using namespace raptor;

Expand Down Expand Up @@ -64,15 +65,17 @@ template<> struct norm_coupling<strength_norm::abs>
static constexpr double strongest(double a, double b) { return std::max(std::abs(a), b); }
};

template<bool is_bsr>
template<class T>
struct mat_args {
const std::conditional_t<is_bsr, BSRMatrix &, Matrix &> mat;
T & mat;
const int * variables;
int beg;
};
template <class T>
mat_args(T&, const int*, int)->mat_args<T>;

template <class P>
constexpr double value(Matrix & mat, int i) {
constexpr double value(CSRMatrix & mat, int i) {
return mat.vals[i];
}
template <class P>
Expand All @@ -83,8 +86,8 @@ constexpr double value(BSRMatrix & mat, int i) {
if (P::comp(val, curr)) curr = val;
return curr;
}
template <class P, bool filter, bool is_bsr>
constexpr double strongest_element(int i, int row_var, const mat_args<is_bsr> & a) {
template <class P, bool filter, class T>
constexpr double strongest_element(int i, int row_var, const mat_args<T> & a) {
auto curr = P::init;
for (int j = a.beg; j < a.mat.idx1[i+1]; ++j) {
auto col = a.mat.idx2[j];
Expand All @@ -96,9 +99,9 @@ constexpr double strongest_element(int i, int row_var, const mat_args<is_bsr> &
}
return curr;
}
template <class P, bool is_bsr>
template <class P, class T>
constexpr double strongest_connection(int row, int row_var, int num_variables,
mat_args<is_bsr> on_proc, mat_args<is_bsr> off_proc) {
mat_args<T> on_proc, mat_args<T> off_proc) {
if (num_variables == 1) {
return P::strongest(
strongest_element<P, false>(row, row_var, on_proc),
Expand All @@ -109,11 +112,12 @@ constexpr double strongest_connection(int row, int row_var, int num_variables,
strongest_element<P, true>(row, row_var, off_proc));
}
}
template<bool is_bsr>
struct append_args : mat_args<is_bsr> { Matrix & soc; };
template <class P, bool filter, bool is_bsr>
template<class T>
struct append_args : mat_args<T> { Matrix & soc; };

template <class P, bool filter, class T>
constexpr void add_connections(int row, int row_var, double threshold,
const append_args<is_bsr> & args) {
const append_args<T> & args) {
for (int j = args.beg; j < args.mat.idx1[row+1]; ++j) {
auto col = args.mat.idx2[j];
if constexpr (filter)
Expand All @@ -126,9 +130,9 @@ constexpr void add_connections(int row, int row_var, double threshold,
}
}
}
template <class P, bool is_bsr>
template <class P, class T>
constexpr void add_strong_connections(int row, int row_var, int num_variables, double threshold,
append_args<is_bsr> on_proc, append_args<is_bsr> off_proc) {
append_args<T> on_proc, append_args<T> off_proc) {
if (num_variables == 1) {
add_connections<P, false>(row, row_var, threshold, on_proc);
add_connections<P, false>(row, row_var, threshold, off_proc);
Expand All @@ -141,6 +145,8 @@ constexpr void add_strong_connections(int row, int row_var, int num_variables, d

void hybrid_strength(ParCSRMatrix & A, ParCSRMatrix & S,
double theta, int num_variables, int *variables, int *off_variables) {
auto & on_proc = dynamic_cast<CSRMatrix&>(*A.on_proc);
auto & off_proc = dynamic_cast<CSRMatrix&>(*A.off_proc);
for (int i = 0; i < A.local_num_rows; i++)
{
auto row_start_on = A.on_proc->idx1[i];
Expand All @@ -165,9 +171,9 @@ void hybrid_strength(ParCSRMatrix & A, ParCSRMatrix & S,
auto row_scale = [&]() {
auto get_row_scale = [&](auto comp) {
using P = decltype(comp);
return strongest_connection<P, false>(i, row_var, num_variables,
{*A.on_proc, variables, row_start_on},
{*A.off_proc, off_variables, row_start_off});
return strongest_connection<P, CSRMatrix>(i, row_var, num_variables,
{on_proc, variables, row_start_on},
{off_proc, off_variables, row_start_off});
};
if (diag < 0.0)
return get_row_scale(positive_coupling{});
Expand All @@ -188,9 +194,9 @@ void hybrid_strength(ParCSRMatrix & A, ParCSRMatrix & S,
// row_max * theta
auto add_row = [&](auto comp) {
using P = decltype(comp);
add_strong_connections<P, false>(i, row_var, num_variables, threshold,
{{*A.on_proc, variables, row_start_on}, *S.on_proc},
{{*A.off_proc, off_variables, row_start_off}, *S.off_proc});
add_strong_connections<P, CSRMatrix>(i, row_var, num_variables, threshold,
{{on_proc, variables, row_start_on}, *S.on_proc},
{{off_proc, off_variables, row_start_off}, *S.off_proc});
};
if (diag < 0)
add_row(positive_coupling{});
Expand All @@ -202,11 +208,13 @@ void hybrid_strength(ParCSRMatrix & A, ParCSRMatrix & S,
}
} // hybrid_strength

template<bool is_bsr, strength_norm snorm>
void norm_strength(ParCSRMatrix & A, ParCSRMatrix & S,
template<strength_norm snorm, class T, is_bsr_or_csr<T> = true>
void norm_strength(T & A, ParCSRMatrix & S,
double theta, int num_variables, int *variables, int *off_variables) {
auto *bsr_diag = dynamic_cast<BSRMatrix*>(A.on_proc);
auto *bsr_offd = dynamic_cast<BSRMatrix*>(A.off_proc);
using seq_t = sequential_matrix_t<T>;
auto & diag = dynamic_cast<seq_t&>(*A.on_proc);
auto & offd = dynamic_cast<seq_t&>(*A.off_proc);

using P = norm_coupling<snorm>;
for (int i = 0; i < A.local_num_rows; i++)
{
Expand All @@ -224,40 +232,28 @@ void norm_strength(ParCSRMatrix & A, ParCSRMatrix & S,

auto row_var = (num_variables > 1) ? variables[i] : -1;
// Find value with max magnitude in row
auto row_scale = [&]() {
if constexpr (is_bsr)
return strongest_connection<P, true>(i, row_var, num_variables,
{*bsr_diag, variables, row_start_on},
{*bsr_offd, off_variables, row_start_off});
else
return strongest_connection<P, false>(i, row_var, num_variables,
{*A.on_proc, variables, row_start_on},
{*A.off_proc, off_variables, row_start_off});
}();
auto row_scale = strongest_connection<P, seq_t>(i, row_var, num_variables,
{diag, variables, row_start_on},
{offd, off_variables, row_start_off});

// Multiply row max magnitude by theta
auto threshold = row_scale * theta;

// Always add diagonal
S.on_proc->idx2[S.on_proc->nnz] = i;
if constexpr (is_bsr)
if constexpr (is_bsr_v<T>)
S.on_proc->vals[S.on_proc->nnz] = has_zero_diag ? 0 :
value<P>(dynamic_cast<BSRMatrix&>(*A.on_proc), row_start_on - 1);
value<P>(diag, row_start_on - 1);
else
S.on_proc->vals[S.on_proc->nnz] = has_zero_diag ? 0 : A.on_proc->vals[row_start_on - 1];
S.on_proc->nnz++;

// Add all off-diagonal entries to strength
// if magnitude is greater than or equal to
// row_max * theta
if constexpr (is_bsr)
add_strong_connections<P, true>(i, row_var, num_variables, threshold,
{{*bsr_diag, variables, row_start_on}, *S.on_proc},
{{*bsr_offd, off_variables, row_start_off}, *S.off_proc});
else
add_strong_connections<P, false>(i, row_var, num_variables, threshold,
{*A.on_proc, variables, row_start_on},
{*A.off_proc, off_variables, row_start_off});
add_strong_connections<P, seq_t>(i, row_var, num_variables, threshold,
{{diag, variables, row_start_on}, *S.on_proc},
{{offd, off_variables, row_start_off}, *S.off_proc});
}
S.on_proc->idx1[i+1] = S.on_proc->nnz;
S.off_proc->idx1[i+1] = S.off_proc->nnz;
Expand Down Expand Up @@ -291,12 +287,12 @@ ParCSRMatrix* classical_strength(ParCSRMatrix* A, double theta, bool tap_amg, in

classical::init_strength(*A->on_proc, *S->on_proc);
classical::init_strength(*A->off_proc, *S->off_proc);
auto is_bsr = dynamic_cast<ParBSRMatrix*>(A);
if (!is_bsr) {
auto bsr = dynamic_cast<ParBSRMatrix*>(A);
if (!bsr) {
classical::hybrid_strength(*A, *S, theta, num_variables, variables, off_variables);
} else {
classical::norm_strength<true, strength_norm::abs>(*A, *S, theta, num_variables,
variables, off_variables);
classical::norm_strength<strength_norm::abs>(*bsr, *S, theta, num_variables,
variables, off_variables);
}

classical::finalize_strength(*A, *S);
Expand Down
25 changes: 1 addition & 24 deletions raptor/ruge_stuben/par_interpolation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// License: Simplified BSD, http://opensource.org/licenses/BSD-2-Clause
#include "assert.h"
#include "raptor/core/types.hpp"
#include "raptor/core/matrix_traits.hpp"
#include "raptor/core/par_matrix.hpp"
#include "raptor/ruge_stuben/par_interpolation.hpp"

Expand Down Expand Up @@ -1978,33 +1979,9 @@ ParBSRMatrix * one_point_interpolation(const ParBSRMatrix & A,
return ret;
}

template <class T>
using is_bsr_or_csr = std::enable_if_t<std::is_same_v<T, ParCSRMatrix> ||
std::is_same_v<T, ParBSRMatrix>, bool>;

template <class T> struct is_bsr : std::false_type {};
template <> struct is_bsr<ParBSRMatrix> : std::true_type {};
template <class T> inline constexpr bool is_bsr_v = is_bsr<T>::value;

BSRMatrix & bsr_cast(Matrix &mat) { return dynamic_cast<BSRMatrix &>(mat); }
const BSRMatrix & bsr_cast(const Matrix & mat) { return dynamic_cast<const BSRMatrix&>(mat); }

template <class T> struct matrix_value;
template <> struct matrix_value<ParCSRMatrix> { using type = double; };
template <> struct matrix_value<ParBSRMatrix> { using type = double*; };
template <class T>
using matrix_value_t = typename matrix_value<T>::type;

template <class T> struct sequential_matrix;
template<> struct sequential_matrix<ParBSRMatrix> { using type = BSRMatrix; };
template<> struct sequential_matrix<ParCSRMatrix> { using type = CSRMatrix; };
template <class T>
using sequential_matrix_t = typename sequential_matrix<T>::type;

namespace lair {
namespace {


/*
Helper type providing access to received rows based
on whether they are on_proc or off_proc
Expand Down

0 comments on commit 930dfee

Please sign in to comment.