Skip to content

Commit

Permalink
First version of MATX SpMM (using dispatch to cuSPARSE) (#843)
Browse files Browse the repository at this point in the history
* First version of MATX SpMM (using dispatch to cuSPARSE)
  • Loading branch information
aartbik authored Jan 29, 2025
1 parent cda2122 commit 5446cbc
Show file tree
Hide file tree
Showing 4 changed files with 362 additions and 6 deletions.
29 changes: 25 additions & 4 deletions examples/sparse_tensor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,36 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
// use sparse operations that are tailored for the sparse data
// structure (such as scanning by row for CSR).
//
tensor_t<float, 2> Dense{{m, n}};
tensor_t<float, 2> A{{m, n}};
for (index_t i = 0; i < m; i++) {
for (index_t j = 0; j < n; j++) {
Dense(i, j) = Acoo(i, j);
A(i, j) = Acoo(i, j);
}
}
print(Dense);
print(A);

// TODO: operations on Acoo
//
// SpMM is implemented on COO through cuSPARSE. This is the
// correct way of performing an efficient sparse operation.
//
tensor_t<float, 2> B{{8, 4}};
tensor_t<float, 2> C{{4, 4}};
B.SetVals({ { 0, 1, 2, 3 },
{ 4, 5, 6, 7 },
{ 8, 9, 10, 11 },
{ 12, 13, 14, 15 },
{ 16, 17, 18, 19 },
{ 20, 21, 22, 23 },
{ 24, 25, 26, 27 },
{ 28, 29, 30, 31 } });
(C = matmul(Acoo, B)).run(exec);
print(C);

//
// Verify by computing the equivelent dense GEMM.
//
(C = matmul(A, B)).run(exec);
print(C);

MATX_EXIT_HANDLER();
}
5 changes: 4 additions & 1 deletion include/matx/core/sparse_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ class sparse_tensor_t
VAL, TF::DIM, DimDesc, detail::SparseTensorData<VAL, CRD, POS, TF>> {
public:
using sparse_tensor = bool;
using val_type = VAL;
using crd_type = CRD;
using pos_type = POS;
using Format = TF;
static constexpr int DIM = TF::DIM;
static constexpr int LVL = TF::LVL;
using Format = TF;

//
// Constructs a sparse tensor with given shape and contents.
Expand Down
13 changes: 12 additions & 1 deletion include/matx/operators/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#include "matx/transforms/matmul/matmul_cuda.h"
#include "matx/transforms/matmul/matmul_cusparse.h"
#ifdef MATX_EN_CPU_MATMUL
#include "matx/transforms/matmul/matmul_cblas.h"
#endif
Expand Down Expand Up @@ -113,7 +114,17 @@ namespace matx

template <typename Out, typename Executor>
void Exec(Out &&out, Executor &&ex) const {
if constexpr (!std::is_same_v<PermDims, no_permute_t>) {
// Perform SpMM or otherwise GEMM.
static_assert(!is_sparse_tensor_v<OpB>, "sparse rhs not implemented");
if constexpr (is_sparse_tensor_v<OpA>) {
if constexpr (!std::is_same_v<PermDims, no_permute_t>) {
sparse_matmul_impl(permute(cuda::std::get<0>(out), perm_), a_, b_, ex, alpha_, beta_);
}
else {
sparse_matmul_impl(cuda::std::get<0>(out), a_, b_, ex, alpha_, beta_);
}
}
else if constexpr (!std::is_same_v<PermDims, no_permute_t>) {
matmul_impl(permute(cuda::std::get<0>(out), perm_), a_, b_, ex, alpha_, beta_);
}
else {
Expand Down
Loading

0 comments on commit 5446cbc

Please sign in to comment.