diff --git a/examples/sparse_tensor.cu b/examples/sparse_tensor.cu
index 2ac5cb14..241ee8e3 100644
--- a/examples/sparse_tensor.cu
+++ b/examples/sparse_tensor.cu
@@ -107,15 +107,36 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
   //   use sparse operations that are tailored for the sparse data
   //   structure (such as scanning by row for CSR).
   //
-  tensor_t<float, 2> Dense{{m, n}};
+  tensor_t<float, 2> A{{m, n}};
   for (index_t i = 0; i < m; i++) {
     for (index_t j = 0; j < n; j++) {
-      Dense(i, j) = Acoo(i, j);
+      A(i, j) = Acoo(i, j);
     }
   }
-  print(Dense);
+  print(A);
 
-  // TODO: operations on Acoo
+  //
+  // SpMM is implemented on COO through cuSPARSE. This is the
+  // correct way of performing an efficient sparse operation.
+  //
+  tensor_t<float, 2> B{{8, 4}};
+  tensor_t<float, 2> C{{4, 4}};
+  B.SetVals({ { 0, 1, 2, 3 },
+              { 4, 5, 6, 7 },
+              { 8, 9, 10, 11 },
+              { 12, 13, 14, 15 },
+              { 16, 17, 18, 19 },
+              { 20, 21, 22, 23 },
+              { 24, 25, 26, 27 },
+              { 28, 29, 30, 31 } });
+  (C = matmul(Acoo, B)).run(exec);
+  print(C);
+
+  //
+  // Verify by computing the equivalent dense GEMM.
+  //
+  (C = matmul(A, B)).run(exec);
+  print(C);
 
   MATX_EXIT_HANDLER();
 }
diff --git a/include/matx/core/sparse_tensor.h b/include/matx/core/sparse_tensor.h
index e0a1e03d..f351fb97 100644
--- a/include/matx/core/sparse_tensor.h
+++ b/include/matx/core/sparse_tensor.h
@@ -57,9 +57,12 @@ class sparse_tensor_t
     : public detail::tensor_impl_t<VAL, TF::DIM, DimDesc, detail::SparseTensorData<VAL, CRD, POS, TF>> {
 public:
   using sparse_tensor = bool;
+  using val_type = VAL;
+  using crd_type = CRD;
+  using pos_type = POS;
+  using Format = TF;
   static constexpr int DIM = TF::DIM;
   static constexpr int LVL = TF::LVL;
-  using Format = TF;
 
   //
   // Constructs a sparse tensor with given shape and contents.
diff --git a/include/matx/operators/matmul.h b/include/matx/operators/matmul.h
index add2a556..7d1b70f3 100644
--- a/include/matx/operators/matmul.h
+++ b/include/matx/operators/matmul.h
@@ -36,6 +36,7 @@
 #include "matx/core/type_utils.h"
 #include "matx/operators/base_operator.h"
 #include "matx/transforms/matmul/matmul_cuda.h"
+#include "matx/transforms/matmul/matmul_cusparse.h"
 #ifdef MATX_EN_CPU_MATMUL
 #include "matx/transforms/matmul/matmul_cblas.h"
 #endif
@@ -113,7 +114,17 @@
 
     template <typename Out, typename Executor>
     void Exec(Out &&out, Executor &&ex) const {
-      if constexpr (!std::is_same_v<PermDims, no_permute_t>) {
+      // Perform SpMM or otherwise GEMM.
+      static_assert(!is_sparse_tensor_v<OpB>, "sparse rhs not implemented");
+      if constexpr (is_sparse_tensor_v<OpA>) {
+        if constexpr (!std::is_same_v<PermDims, no_permute_t>) {
+          sparse_matmul_impl(permute(cuda::std::get<0>(out), perm_), a_, b_, ex, alpha_, beta_);
+        }
+        else {
+          sparse_matmul_impl(cuda::std::get<0>(out), a_, b_, ex, alpha_, beta_);
+        }
+      }
+      else if constexpr (!std::is_same_v<PermDims, no_permute_t>) {
         matmul_impl(permute(cuda::std::get<0>(out), perm_), a_, b_, ex, alpha_, beta_);
       }
       else {
diff --git a/include/matx/transforms/matmul/matmul_cusparse.h b/include/matx/transforms/matmul/matmul_cusparse.h
new file mode 100644
index 00000000..e66c4f23
--- /dev/null
+++ b/include/matx/transforms/matmul/matmul_cusparse.h
@@ -0,0 +1,321 @@
+////////////////////////////////////////////////////////////////////////////////
+// BSD 3-Clause License
+//
+// Copyright (c) 2025, NVIDIA Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+/////////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+
+#include <numeric>
+
+#include <cusparse.h>
+
+#include "matx/core/cache.h"
+#include "matx/core/sparse_tensor.h"
+#include "matx/core/tensor.h"
+
+namespace matx {
+
+namespace detail {
+
+// Translate MatXType for indices to cuSPARSE index type.
+template <typename T>
+constexpr cusparseIndexType_t MatXTypeToCuSparseIndexType() {
+  if constexpr (std::is_same_v<T, uint16_t>) {
+    return CUSPARSE_INDEX_16U;
+  }
+  if constexpr (std::is_same_v<T, int32_t>) {
+    return CUSPARSE_INDEX_32I;
+  }
+  if constexpr (std::is_same_v<T, int64_t>) {
+    return CUSPARSE_INDEX_64I;
+  }
+  if constexpr (std::is_same_v<T, index_t>) {
+    return CUSPARSE_INDEX_64I;
+  }
+}
+
+/**
+ * Parameters needed to execute a cuSPARSE GEMM.
+ */
+struct MatMulCUSPARSEParams_t {
+  MatXDataType_t dtype;
+  MatXDataType_t ptype;
+  MatXDataType_t ctype;
+  int rank;
+  cudaStream_t stream;
+  float alpha;
+  float beta;
+  index_t nse;
+  index_t m;
+  index_t n;
+  index_t k;
+  cusparseOperation_t opA;
+  cusparseOperation_t opB;
+  // Matrix handles in cuSPARSE are data specific (unlike e.g. cuBLAS
+  // where the same plan can be shared between different data buffers).
+  void *ptrA0;
+  void *ptrA1;
+  void *ptrA2;
+  void *ptrA3;
+  void *ptrA4;
+  void *ptrB;
+  void *ptrC;
+};
+
+template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
+class MatMulCUSPARSEHandle_t {
+public:
+  using TA = typename TensorTypeA::value_type;
+  using TB = typename TensorTypeB::value_type;
+  using TC = typename TensorTypeC::value_type;
+
+  static constexpr int RANKA = TensorTypeA::Rank();
+  static constexpr int RANKB = TensorTypeB::Rank();
+  static constexpr int RANKC = TensorTypeC::Rank();
+
+  /**
+   * Construct a sparse GEMM handle
+   *   SpMV
+   *   SpMM    <- for now
+   *   SpGEMM
+   */
+  MatMulCUSPARSEHandle_t(TensorTypeC &c, const TensorTypeA &a,
+                         const TensorTypeB &b, cudaStream_t stream, float alpha,
+                         float beta) {
+    MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
+
+    static_assert(RANKA == 2);
+    static_assert(RANKB == 2);
+    static_assert(RANKC == 2);
+
+    MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 2), matxInvalidSize);
+    MATX_ASSERT(c.Size(RANKC - 1) == b.Size(RANKB - 1), matxInvalidSize);
+    MATX_ASSERT(c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize);
+
+    params_ = GetGemmParams(c, a, b, stream, alpha, beta);
+
+    [[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_);
+    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
+
+    // Create cuSPARSE handle for sparse matrix A.
+    static_assert(is_sparse_tensor_v<TensorTypeA>);
+    cusparseIndexType_t pt =
+        MatXTypeToCuSparseIndexType<typename TensorTypeA::pos_type>();
+    cusparseIndexType_t ct =
+        MatXTypeToCuSparseIndexType<typename TensorTypeA::crd_type>();
+    cusparseIndexBase_t zb = CUSPARSE_INDEX_BASE_ZERO;
+    cudaDataType dta = MatXTypeToCudaType<TA>();
+    if constexpr (TensorTypeA::Format::isCOO()) {
+      ret = cusparseCreateCoo(&matA_, params_.m, params_.k, params_.nse,
+                              params_.ptrA3, params_.ptrA4, params_.ptrA0, ct,
+                              zb, dta);
+    } else if constexpr (TensorTypeA::Format::isCSR()) {
+      ret = cusparseCreateCsr(&matA_, params_.m, params_.k, params_.nse,
+                              params_.ptrA2, params_.ptrA4, params_.ptrA0, pt,
+                              ct, zb, dta);
+    } else if constexpr (TensorTypeA::Format::isCSC()) {
+      ret = cusparseCreateCsc(&matA_, params_.m, params_.k, params_.nse,
+                              params_.ptrA2, params_.ptrA4, params_.ptrA0, pt,
+                              ct, zb, dta);
+    } else {
+      MATX_THROW(matxNotSupported, "SpMM currently only supports COO/CSR/CSC");
+    }
+    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
+
+    // Create cuSPARSE handle for dense matrices B and C.
+    static_assert(is_tensor_view_v<TensorTypeB>);
+    static_assert(is_tensor_view_v<TensorTypeC>);
+    cudaDataType dtb = MatXTypeToCudaType<TB>();
+    cudaDataType dtc = MatXTypeToCudaType<TC>();
+    const cusparseOrder_t order = CUSPARSE_ORDER_ROW; // TODO: support col B,C?
+    ret = cusparseCreateDnMat(&matB_, params_.k, params_.n, /*ld=*/params_.n,
+                              params_.ptrB, dtb, order);
+    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
+    ret = cusparseCreateDnMat(&matC_, params_.m, params_.n, /*ld=*/params_.n,
+                              params_.ptrC, dtc, order);
+    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
+
+    // Allocate a workspace for SpMM.
+    const cusparseSpMMAlg_t algo = CUSPARSE_SPMM_ALG_DEFAULT;
+    const cudaDataType comptp = dtc; // TODO: support separate comp type?!
+    ret = cusparseSpMM_bufferSize(handle_, params_.opA, params_.opB,
+                                  &params_.alpha, matA_, matB_, &params_.beta,
+                                  matC_, comptp, algo, &workspaceSize_);
+    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
+    if (workspaceSize_) {
+      matxAlloc((void **)&workspace_, workspaceSize_, MATX_DEVICE_MEMORY);
+    }
+  }
+
+  ~MatMulCUSPARSEHandle_t() {
+    if (workspaceSize_) {
+      matxFree(workspace_);
+    }
+    cusparseDestroy(handle_);
+  }
+
+  static detail::MatMulCUSPARSEParams_t
+  GetGemmParams(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b,
+                cudaStream_t stream, float alpha, float beta) {
+    detail::MatMulCUSPARSEParams_t params;
+    params.dtype = TypeToInt<typename TensorTypeA::val_type>();
+    params.ptype = TypeToInt<typename TensorTypeA::pos_type>();
+    params.ctype = TypeToInt<typename TensorTypeA::crd_type>();
+    params.rank = c.Rank();
+    params.stream = stream;
+    params.alpha = alpha;
+    params.beta = beta;
+    // TODO: simple no-batch, row-wise, no-transpose for now
+    params.nse = a.Nse();
+    params.m = a.Size(TensorTypeA::Rank() - 2);
+    params.n = b.Size(TensorTypeB::Rank() - 1);
+    params.k = a.Size(TensorTypeA::Rank() - 1);
+    params.opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
+    params.opB = CUSPARSE_OPERATION_NON_TRANSPOSE;
+    // Matrix handles in cuSPARSE are data specific. Therefore, the pointers
+    // to the underlying buffers are part of the GEMM parameters.
+    params.ptrA0 = a.Data();
+    params.ptrA1 = a.POSData(0);
+    params.ptrA2 = a.POSData(1);
+    params.ptrA3 = a.CRDData(0);
+    params.ptrA4 = a.CRDData(1);
+    params.ptrB = b.Data();
+    params.ptrC = c.Data();
+    return params;
+  }
+
+  __MATX_INLINE__ void Exec([[maybe_unused]] TensorTypeC &c,
+                            [[maybe_unused]] const TensorTypeA &a,
+                            [[maybe_unused]] const TensorTypeB &b) {
+    MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL);
+    const cusparseSpMMAlg_t algo = CUSPARSE_SPMM_ALG_DEFAULT;
+    const cudaDataType comptp = MatXTypeToCudaType<TC>(); // TODO: see above
+    [[maybe_unused]] cusparseStatus_t ret =
+        cusparseSpMM(handle_, params_.opA, params_.opB, &params_.alpha, matA_,
+                     matB_, &params_.beta, matC_, comptp, algo, workspace_);
+    MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
+  }
+
+private:
+  cusparseHandle_t handle_ = nullptr; // TODO: share handle globally?
+  cusparseSpMatDescr_t matA_ = nullptr;
+  cusparseDnMatDescr_t matB_ = nullptr;
+  cusparseDnMatDescr_t matC_ = nullptr;
+  size_t workspaceSize_ = 0;
+  void *workspace_ = nullptr;
+  detail::MatMulCUSPARSEParams_t params_;
+};
+
+/**
+ * Crude hash on GEMM to get a reasonably good delta for collisions. This
+ * doesn't need to be perfect, but fast enough to not slow down lookups, and
+ * different enough so the common GEMM parameters change.
+ */
+struct MatMulCUSPARSEParamsKeyHash {
+  std::size_t operator()(const MatMulCUSPARSEParams_t &k) const noexcept {
+    return std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.ptrA0)) +
+           std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.ptrB)) +
+           std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.ptrC)) +
+           std::hash<uint64_t>()(reinterpret_cast<uint64_t>(k.stream));
+  }
+};
+
+/**
+ * Test GEMM parameters for equality. Unlike the hash, all parameters must
+ * match exactly to ensure the hashed kernel can be reused for the computation.
+ */
+struct MatMulCUSPARSEParamsKeyEq {
+  bool operator()(const MatMulCUSPARSEParams_t &l,
+                  const MatMulCUSPARSEParams_t &t) const noexcept {
+    return l.dtype == t.dtype && l.ptype == t.ptype && l.ctype == t.ctype &&
+           l.rank == t.rank && l.stream == t.stream && l.alpha == t.alpha &&
+           l.beta == t.beta && l.nse == t.nse && l.m == t.m && l.n == t.n &&
+           l.k == t.k && l.opA == t.opA && l.opB == t.opB &&
+           l.ptrA0 == t.ptrA0 && l.ptrA1 == t.ptrA1 && l.ptrA2 == t.ptrA2 &&
+           l.ptrA3 == t.ptrA3 && l.ptrA4 == t.ptrA4 && l.ptrB == t.ptrB &&
+           l.ptrC == t.ptrC;
+  }
+};
+
+using gemm_cusparse_cache_t =
+    std::unordered_map<MatMulCUSPARSEParams_t, std::any,
+                       MatMulCUSPARSEParamsKeyHash, MatMulCUSPARSEParamsKeyEq>;
+
+} // end namespace detail
+
+template <typename Op>
+__MATX_INLINE__ auto getCUSPARSESupportedTensor(const Op &in,
+                                                cudaStream_t stream) {
+  const auto support_func = [&in]() {
+    if constexpr (is_tensor_view_v<Op>) {
+      return in.Stride(Op::Rank() - 1) == 1; // TODO: more than row-wise
+    } else {
+      return true;
+    }
+  };
+  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
+}
+
+template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
+void sparse_matmul_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B,
+                        const cudaExecutor &exec, float alpha = 1.0,
+                        float beta = 0.0) {
+  MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
+  const auto stream = exec.getStream();
+
+  auto a = A; // always sparse
+  auto b = getCUSPARSESupportedTensor(B, stream);
+  auto c = getCUSPARSESupportedTensor(C, stream);
+
+  // TODO: some more checking, supported type? on device? etc.
+
+  typedef decltype(c) ctype;
+  typedef decltype(a) atype;
+  typedef decltype(b) btype;
+
+  // Get parameters required by these tensors (for caching).
+  auto params =
+      detail::MatMulCUSPARSEHandle_t<ctype, atype, btype>::GetGemmParams(
+          c, a, b, stream, alpha, beta);
+
+  // Lookup and cache.
+  using cache_val_type = detail::MatMulCUSPARSEHandle_t<ctype, atype, btype>;
+  detail::GetCache().LookupAndExec<detail::gemm_cusparse_cache_t>(
+      detail::GetCacheIdFromType<detail::gemm_cusparse_cache_t>(), params,
+      [&]() {
+        return std::make_shared<cache_val_type>(c, a, b, stream, alpha, beta);
+      },
+      [&](std::shared_ptr<cache_val_type> cache_type) {
+        cache_type->Exec(c, a, b);
+      });
+}
+
+} // end namespace matx
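
For readers unfamiliar with the cuSPARSE generic API, the standalone sketch below (not part of this patch) shows the bare call sequence that MatMulCUSPARSEHandle_t wraps for the COO case: create the handle and the sparse/dense descriptors, query the SpMM workspace size, and run cusparseSpMM. The row-major layout with ld = n mirrors the choices made in the new header; the tiny 2x3 matrix, the data values, and all variable names are made up purely for illustration.

// spmm_sketch.cu -- illustrative only; build with: nvcc spmm_sketch.cu -lcusparse
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>
#include <cusparse.h>

int main() {
  // C (2x2) = A (2x3 COO, 3 nonzeros) * B (3x2, row-major).
  const int64_t m = 2, k = 3, n = 2, nnz = 3;
  std::vector<int> hRow{0, 0, 1}, hCol{0, 2, 1};       // A = [[1 0 2], [0 3 0]]
  std::vector<float> hVal{1.f, 2.f, 3.f};
  std::vector<float> hB{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; // B = [[1 2], [3 4], [5 6]]
  std::vector<float> hC(m * n, 0.f);

  int *dRow, *dCol;
  float *dVal, *dB, *dC;
  cudaMalloc((void **)&dRow, nnz * sizeof(int));
  cudaMalloc((void **)&dCol, nnz * sizeof(int));
  cudaMalloc((void **)&dVal, nnz * sizeof(float));
  cudaMalloc((void **)&dB, k * n * sizeof(float));
  cudaMalloc((void **)&dC, m * n * sizeof(float));
  cudaMemcpy(dRow, hRow.data(), nnz * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(dCol, hCol.data(), nnz * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(dVal, hVal.data(), nnz * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), k * n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dC, hC.data(), m * n * sizeof(float), cudaMemcpyHostToDevice);

  cusparseHandle_t handle;
  cusparseCreate(&handle);

  // Descriptors: sparse A in COO, dense B and C in row-major order with ld = n.
  cusparseSpMatDescr_t matA;
  cusparseDnMatDescr_t matB, matC;
  cusparseCreateCoo(&matA, m, k, nnz, dRow, dCol, dVal, CUSPARSE_INDEX_32I,
                    CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F);
  cusparseCreateDnMat(&matB, k, n, /*ld=*/n, dB, CUDA_R_32F, CUSPARSE_ORDER_ROW);
  cusparseCreateDnMat(&matC, m, n, /*ld=*/n, dC, CUDA_R_32F, CUSPARSE_ORDER_ROW);

  // Workspace query followed by the actual SpMM: C = alpha * A * B + beta * C.
  float alpha = 1.f, beta = 0.f;
  size_t bufSize = 0;
  void *buf = nullptr;
  cusparseSpMM_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                          CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB,
                          &beta, matC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT,
                          &bufSize);
  cudaMalloc(&buf, bufSize);
  cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
               CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta,
               matC, CUDA_R_32F, CUSPARSE_SPMM_ALG_DEFAULT, buf);

  cudaMemcpy(hC.data(), dC, m * n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("%g %g\n%g %g\n", hC[0], hC[1], hC[2], hC[3]); // expect 11 14 / 9 12

  cusparseDestroySpMat(matA);
  cusparseDestroyDnMat(matB);
  cusparseDestroyDnMat(matC);
  cusparseDestroy(handle);
  cudaFree(dRow); cudaFree(dCol); cudaFree(dVal);
  cudaFree(dB); cudaFree(dC); cudaFree(buf);
  return 0;
}

The MatX path in this patch performs the same sequence, but builds the descriptors from the sparse tensor's POS/CRD/value buffers and caches the handle keyed on the GEMM parameters so repeated (C = matmul(Acoo, B)).run(exec) calls reuse the workspace.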