Skip to content

Commit

Permalink
Merge pull request #295 from iburyl/in-place-triangular-matrix-vector
Browse files Browse the repository at this point in the history
In-place triangular matrix vector product + tests
  • Loading branch information
mhoemmen authored Oct 18, 2024
2 parents 1c85df1 + eb910eb commit 7988045
Show file tree
Hide file tree
Showing 3 changed files with 386 additions and 0 deletions.
134 changes: 134 additions & 0 deletions include/experimental/__p1673_bits/blas2_matrix_vector_product.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,29 @@ struct is_custom_tri_mat_vec_product_with_update_avail<
>
: std::true_type{};

// In-place triangular matrix-vector product with update
template <class Exec, class A_t, class Tri_t, class D_t, class Y_t, class = void>
struct is_custom_tri_mat_vec_product_with_inplace_avail : std::false_type {};

template <class Exec, class A_t, class Tri_t, class D_t, class Y_t>
struct is_custom_tri_mat_vec_product_with_inplace_avail<
Exec, A_t, Tri_t, D_t, Y_t,
std::enable_if_t<
std::is_void_v<
decltype(triangular_matrix_vector_product
(std::declval<Exec>(),
std::declval<A_t>(),
std::declval<Tri_t>(),
std::declval<D_t>(),
std::declval<Y_t>()
)
)
>
&& !linalg::impl::is_inline_exec_v<Exec>
>
>
: std::true_type{};

} // end anonymous namespace

namespace impl {
Expand Down Expand Up @@ -1278,6 +1301,117 @@ void triangular_matrix_vector_product(
triangular_matrix_vector_product(impl::default_exec_t{}, A, t, d, x, y, z);
}

// In-place triangular matrix-vector product: y := A * y

MDSPAN_TEMPLATE_REQUIRES(
class ElementType_A,
class SizeType_A,
::std::size_t numRows_A,
::std::size_t numCols_A,
class Layout_A,
class Accessor_A,
class Triangle,
class DiagonalStorage,
class ElementType_y,
class SizeType_y,
::std::size_t ext_y,
class Layout_y,
class Accessor_y,
/* requires */ (impl::always_unique_mapping_v<Layout_A, extents<SizeType_A, numRows_A, numCols_A>>)
)
void triangular_matrix_vector_product(
linalg::impl::inline_exec_t&& /* exec */,
mdspan<ElementType_A, extents<SizeType_A, numRows_A, numCols_A> , Layout_A, Accessor_A> A,
Triangle t,
DiagonalStorage d,
mdspan<ElementType_y, extents<SizeType_y, ext_y>, Layout_y, Accessor_y> y)
{
using size_type = std::common_type_t<SizeType_A, SizeType_y>;
constexpr bool explicitDiagonal =
std::is_same_v<DiagonalStorage, explicit_diagonal_t>;

if constexpr (std::is_same_v<Triangle, lower_triangle_t>) {
for (size_type k = 0; k < A.extent(1); ++k) {
size_type j = A.extent(1) - (k + size_type(1));
ElementType_y tmp = y(j);
for (size_type i = j + size_type(1); i < A.extent(0); ++i) {
y(i) += A(i,j) * tmp;
}
if constexpr (explicitDiagonal) {
y(j) = y(j) * A(j,j);
}
}
}
else {
for (size_type j = 0; j < A.extent(1); ++j) {
ElementType_y tmp = y(j);
for (size_type i = 0; i < j; ++i) {
y(i) += A(i,j) * tmp;
}
if constexpr (explicitDiagonal) {
y(j) = y(j) * A(j,j);
}
}
}
}

template<class ExecutionPolicy,
class ElementType_A,
class SizeType_A,
::std::size_t numRows_A,
::std::size_t numCols_A,
class Layout_A,
class Accessor_A,
class Triangle,
class DiagonalStorage,
class ElementType_y,
class SizeType_y, ::std::size_t ext_y,
class Layout_y,
class Accessor_y>
void triangular_matrix_vector_product(
ExecutionPolicy&& exec,
mdspan<ElementType_A, extents<SizeType_A, numRows_A, numCols_A>, Layout_A, Accessor_A> A,
Triangle t,
DiagonalStorage d,
mdspan<ElementType_y, extents<SizeType_y, ext_y>, Layout_y, Accessor_y> y)
{
constexpr bool use_custom = is_custom_tri_mat_vec_product_with_inplace_avail<
decltype(execpolicy_mapper(exec)),
decltype(A), decltype(t), decltype(d), decltype(y)
>::value;

if constexpr(use_custom) {
triangular_matrix_vector_product(execpolicy_mapper(exec), A, t, d, y);
} else {
triangular_matrix_vector_product(linalg::impl::inline_exec_t(), A, t, d, y);
}
}

MDSPAN_TEMPLATE_REQUIRES(
class ElementType_A,
class SizeType_A,
::std::size_t numRows_A,
::std::size_t numCols_A,
class Layout_A,
class Accessor_A,
class Triangle,
class DiagonalStorage,
class ElementType_y,
class SizeType_y,
::std::size_t ext_y,
class Layout_y,
class Accessor_y,
/* requires */ (impl::always_unique_mapping_v<Layout_A, extents<SizeType_A, numRows_A, numCols_A>>)
)
void triangular_matrix_vector_product(
mdspan<ElementType_A, extents<SizeType_A, numRows_A, numCols_A>, Layout_A, Accessor_A> A,
Triangle t,
DiagonalStorage d,
mdspan<ElementType_y, extents<SizeType_y, ext_y>, Layout_y, Accessor_y> y)
{
triangular_matrix_vector_product(linalg::impl::default_exec_t(), A, t, d, y);
}

} // end namespace linalg
} // end inline namespace __p1673_version_0
} // end namespace MDSPAN_IMPL_PROPOSED_NAMESPACE
Expand Down
1 change: 1 addition & 0 deletions tests/native/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,5 @@ linalg_add_test(syr)
linalg_add_test(syrk)
linalg_add_test(transposed)
linalg_add_test(trmm)
linalg_add_test(trmv)
linalg_add_test(trsm)
Loading

0 comments on commit 7988045

Please sign in to comment.