Skip to content

Commit

Permalink
First version of MATX Sparse-Direct-Solve (using dispatch to cuDSS) (#…
Browse files Browse the repository at this point in the history
…849)

* First version of MATX Sparse-Direct-Solve (using dispatch to cuDSS)
  • Loading branch information
aartbik authored Feb 4, 2025
1 parent b3ca482 commit 55dd664
Show file tree
Hide file tree
Showing 5 changed files with 513 additions and 36 deletions.
74 changes: 38 additions & 36 deletions examples/sparse_tensor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

#include "matx.h"

// Note that sparse tensor support in MatX is still experimental.

using namespace matx;

int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
Expand All @@ -42,7 +44,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
cudaExecutor exec{stream};

//
// Print some formats that are used for the versatile sparse tensor
// Print some formats that are used for the universal sparse tensor
// type. Note that common formats like COO and CSR have good library
// support in e.g. cuSPARSE, but MatX provides a much more general
// way to define the sparse tensor storage through a DSL (see doc).
Expand All @@ -68,25 +70,6 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
// | 0, 0, 0, 0, 0, 0, 0, 0 |
// | 0, 0, 3, 4, 0, 5, 0, 0 |
//

constexpr index_t m = 4;
constexpr index_t n = 8;
constexpr index_t nse = 5;

tensor_t<float, 1> values{{nse}};
tensor_t<int, 1> row_idx{{nse}};
tensor_t<int, 1> col_idx{{nse}};

values.SetVals({ 1, 2, 3, 4, 5 });
row_idx.SetVals({ 0, 0, 3, 3, 3 });
col_idx.SetVals({ 0, 1, 2, 3, 5 });

// Note that sparse tensor support in MatX is still experimental.
auto Acoo = experimental::make_tensor_coo(values, row_idx, col_idx, {m, n});

//
// This shows:
//
// tensor_impl_2_f32: SparseTensor{float} Rank: 2, Sizes:[4, 8], Levels:[4, 8]
// nse = 5
// format = ( d0, d1 ) -> ( d0 : compressed(non-unique), d1 : singleton )
Expand All @@ -95,6 +78,13 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
// values = ( 1.0000e+00 2.0000e+00 3.0000e+00 4.0000e+00 5.0000e+00 )
// space = CUDA managed memory
//
auto vals = make_tensor<float>({5});
auto idxi = make_tensor<int>({5});
auto idxj = make_tensor<int>({5});
vals.SetVals({1, 2, 3, 4, 5});
idxi.SetVals({0, 0, 3, 3, 3});
idxj.SetVals({0, 1, 2, 3, 5});
auto Acoo = experimental::make_tensor_coo(vals, idxi, idxj, {4, 8});
print(Acoo);

//
Expand All @@ -107,9 +97,9 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
// use sparse operations that are tailored for the sparse data
// structure (such as scanning by row for CSR).
//
tensor_t<float, 2> A{{m, n}};
for (index_t i = 0; i < m; i++) {
for (index_t j = 0; j < n; j++) {
auto A = make_tensor<float>({4, 8});
for (index_t i = 0; i < 4; i++) {
for (index_t j = 0; j < 8; j++) {
A(i, j) = Acoo(i, j);
}
}
Expand All @@ -119,24 +109,36 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
// SpMM is implemented on COO through cuSPARSE. This is the
// correct way of performing an efficient sparse operation.
//
tensor_t<float, 2> B{{8, 4}};
tensor_t<float, 2> C{{4, 4}};
B.SetVals({ { 0, 1, 2, 3 },
{ 4, 5, 6, 7 },
{ 8, 9, 10, 11 },
{ 12, 13, 14, 15 },
{ 16, 17, 18, 19 },
{ 20, 21, 22, 23 },
{ 24, 25, 26, 27 },
{ 28, 29, 30, 31 } });
auto B = make_tensor<float, 2>({8, 4});
auto C = make_tensor<float>({4, 4});
B.SetVals({
{ 0, 1, 2, 3}, { 4, 5, 6, 7}, { 8, 9, 10, 11}, {12, 13, 14, 15},
{16, 17, 18, 19}, {20, 21, 22, 23}, {24, 25, 26, 27}, {28, 29, 30, 31} });
(C = matmul(Acoo, B)).run(exec);
print(C);

//
// Verify by computing the equivelent dense GEMM.
// Creates a CSR matrix which is used to solve the following
// system of equations AX=Y, where X is the unknown.
//
(C = matmul(A, B)).run(exec);
print(C);
// | 1 2 0 0 | | 1 5 | | 5 17 |
// | 0 3 0 0 | x | 2 6 | = | 6 18 |
// | 0 0 4 0 | | 3 7 | | 12 28 |
// | 0 0 0 5 | | 4 8 | | 20 40 |
//
auto coeffs = make_tensor<float>({5});
auto rowptr = make_tensor<int>({5});
auto colidx = make_tensor<int>({5});
coeffs.SetVals({1, 2, 3, 4, 5});
rowptr.SetVals({0, 2, 3, 4, 5});
colidx.SetVals({0, 1, 1, 2, 3});
auto Acsr = experimental::make_tensor_csr(coeffs, rowptr, colidx, {4, 4});
print(Acsr);
auto X = make_tensor<float>({4, 2});
auto Y = make_tensor<float>({4, 2});
Y.SetVals({ {5, 17}, {6, 18}, {12, 28}, {20, 40} });
(X = solve(Acsr, Y)).run(exec);
print(X);

MATX_EXIT_HANDLER();
}
3 changes: 3 additions & 0 deletions include/matx/core/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,9 @@ template <typename T> constexpr cudaDataType_t MatXTypeToCudaType()
if constexpr (std::is_same_v<T, int8_t>) {
return CUDA_R_8I;
}
if constexpr (std::is_same_v<T, int>) {
return CUDA_R_32I;
}
if constexpr (std::is_same_v<T, float>) {
return CUDA_R_32F;
}
Expand Down
1 change: 1 addition & 0 deletions include/matx/operators/operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
#include "matx/operators/shift.h"
#include "matx/operators/sign.h"
#include "matx/operators/slice.h"
#include "matx/operators/solve.h"
#include "matx/operators/sort.h"
#include "matx/operators/sph2cart.h"
#include "matx/operators/stack.h"
Expand Down
161 changes: 161 additions & 0 deletions include/matx/operators/solve.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2025, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#pragma once

#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#ifdef MATX_EN_CUDSS
#include "matx/transforms/solve/solve_cudss.h"
#endif

namespace matx {
namespace detail {

template <typename OpA, typename OpB>
class SolveOp : public BaseOp<SolveOp<OpA, OpB>> {
private:
typename detail::base_type_t<OpA> a_;
typename detail::base_type_t<OpB> b_;

static constexpr int out_rank = OpB::Rank();
cuda::std::array<index_t, out_rank> out_dims_;
mutable detail::tensor_impl_t<typename OpA::value_type, out_rank> tmp_out_;
mutable typename OpA::value_type *ptr = nullptr;

public:
using matxop = bool;
using matx_transform_op = bool;
using solve_xform_op = bool;
using value_type = typename OpA::value_type;

__MATX_INLINE__ SolveOp(const OpA &a, const OpB &b) : a_(a), b_(b) {
for (int r = 0, rank = Rank(); r < rank; r++) {
out_dims_[r] = b_.Size(r);
}
}

__MATX_INLINE__ std::string str() const {
return "solve(" + get_type_str(a_) + "," + get_type_str(b_) + ")";
}

__MATX_HOST__ __MATX_INLINE__ auto Data() const noexcept { return ptr; }

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto)
operator()(Is... indices) const {
return tmp_out_(indices...);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t
Rank() {
return remove_cvref_t<OpB>::Rank();
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t
Size(int dim) const {
return out_dims_[dim];
}

template <typename Out, typename Executor>
void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const {
static_assert(!is_sparse_tensor_v<OpB>, "sparse rhs not implemented");
if constexpr (is_sparse_tensor_v<OpA>) {
#ifdef MATX_EN_CUDSS
sparse_solve_impl(cuda::std::get<0>(out), a_, b_, ex);
#else
MATX_THROW(matxNotSupported, "Sparse direct solver requires cuDSS");
#endif
} else {
MATX_THROW(matxNotSupported,
"Direct solver currently only supports sparse system");
}
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape,
[[maybe_unused]] Executor &&ex) const noexcept {
static_assert(is_sparse_tensor_v<OpA>,
"Direct solver currently only supports sparse system");
if constexpr (is_matx_op<OpB>()) {
b_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape,
[[maybe_unused]] Executor &&ex) const noexcept {
InnerPreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
detail::AllocateTempTensor(tmp_out_, std::forward<Executor>(ex), out_dims_,
&ptr);
Exec(cuda::std::make_tuple(tmp_out_), std::forward<Executor>(ex));
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PostRun([[maybe_unused]] ShapeType &&shape,
[[maybe_unused]]Executor &&ex) const noexcept {
static_assert(is_sparse_tensor_v<OpA>,
"Direct solver currently only supports sparse system");
if constexpr (is_matx_op<OpB>()) {
b_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
matxFree(ptr);
}
};

} // end namespace detail

/**
* Run a direct SOLVE (viz. X = solve(A, B) solves system AX=B for unknown X).
*
* Note that currently, this operation is only implemented for solving
* a linear system with a very **sparse** matrix A.
*
* @tparam OpA
* Data type of A tensor (sparse)
* @tparam OpB
* Data type of B tensor
*
* @param A
* A Sparse tensor with system coefficients
* @param B
* B Dense tensor of known values
*
* @return
* Operator that produces the output tensor X with the solution
*/
template <typename OpA, typename OpB>
__MATX_INLINE__ auto solve(const OpA &A, const OpB &B) {
return detail::SolveOp(A, B);
}

} // end namespace matx
Loading

0 comments on commit 55dd664

Please sign in to comment.