diff --git a/CMakeLists.txt b/CMakeLists.txt index 41a9303a3..f44678af6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -205,11 +205,12 @@ if (MATX_EN_NVPL OR MATX_EN_X86_FFTW OR MATX_EN_BLIS OR MATX_EN_OPENBLAS) if (MATX_EN_NVPL) message(STATUS "Enabling NVPL library support for ARM CPUs with ${INT_TYPE} interface") - find_package(nvpl REQUIRED COMPONENTS fft blas HINTS ${blas_DIR}) + find_package(nvpl REQUIRED COMPONENTS fft blas lapack HINTS ${blas_DIR}) if (NOT MATX_BUILD_32_BIT) target_compile_definitions(matx INTERFACE NVPL_ILP64) endif() - target_link_libraries(matx INTERFACE nvpl::fftw nvpl::blas_${INT_TYPE}_omp) + target_compile_definitions(matx INTERFACE NVPL_LAPACK_COMPLEX_STRUCTURE) + target_link_libraries(matx INTERFACE nvpl::fftw nvpl::blas_${INT_TYPE}_omp nvpl::lapack_${INT_TYPE}_omp) target_compile_definitions(matx INTERFACE MATX_EN_NVPL) else() # FFTW diff --git a/docs_input/api/linalg/decomp/qr.rst b/docs_input/api/linalg/decomp/qr.rst index dbb1d54b2..00477d8b4 100644 --- a/docs_input/api/linalg/decomp/qr.rst +++ b/docs_input/api/linalg/decomp/qr.rst @@ -16,13 +16,13 @@ Examples :end-before: example-end qr-test-1 :dedent: -.. doxygenfunction:: cusolver_qr +.. doxygenfunction:: qr_solver Examples ~~~~~~~~ .. literalinclude:: ../../../../test/00_solver/QR.cu :language: cpp - :start-after: example-begin cusolver_qr-test-1 - :end-before: example-end cusolver_qr-test-1 + :start-after: example-begin qr_solver-test-1 + :end-before: example-end qr_solver-test-1 :dedent: \ No newline at end of file diff --git a/docs_input/api/linalg/decomp/svd.rst b/docs_input/api/linalg/decomp/svd.rst index 68d6a1e18..9f5edd4ed 100644 --- a/docs_input/api/linalg/decomp/svd.rst +++ b/docs_input/api/linalg/decomp/svd.rst @@ -3,7 +3,7 @@ svd ### -Perform a singular value decomposition (SVD) using the power iteration method. +Perform a singular value decomposition (SVD). .. doxygenfunction:: svd diff --git a/docs_input/api/logic/truth/allclose.rst b/docs_input/api/logic/truth/allclose.rst index bca253e8f..a4414cf08 100644 --- a/docs_input/api/logic/truth/allclose.rst +++ b/docs_input/api/logic/truth/allclose.rst @@ -7,7 +7,7 @@ Reduce the closeness of two operators to a single scalar (0D) output. The output from allclose is an ``int`` value since boolean reductions are not available in hardware -.. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, HostExecutor &exec) +.. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, const HostExecutor &exec) .. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, cudaExecutor exec = 0) Examples diff --git a/docs_input/api/manipulation/rearranging/transpose_matrix.rst b/docs_input/api/manipulation/rearranging/transpose_matrix.rst index f9866024c..a637954b5 100644 --- a/docs_input/api/manipulation/rearranging/transpose_matrix.rst +++ b/docs_input/api/manipulation/rearranging/transpose_matrix.rst @@ -10,7 +10,7 @@ Transpose the last two dimensions of an operator Examples ~~~~~~~~ -.. literalinclude:: ../../../../include/matx/transforms/svd.h +.. 
literalinclude:: ../../../../include/matx/transforms/svd/svd_cuda.h :language: cpp :start-after: example-begin transpose_matrix-test-1 :end-before: example-end transpose_matrix-test-1 diff --git a/include/matx/core/tensor_utils.h b/include/matx/core/tensor_utils.h index e5177df39..a30d0c3f9 100644 --- a/include/matx/core/tensor_utils.h +++ b/include/matx/core/tensor_utils.h @@ -39,7 +39,7 @@ #include "matx/core/dlpack.h" #include "matx/core/make_tensor.h" #include "matx/kernels/utility.cuh" - +#include "matx/transforms/copy.h" namespace matx { @@ -1107,15 +1107,41 @@ void print(const Op &op) #endif // not DOXYGEN_ONLY -template -auto OpToTensor(Op &&op, [[maybe_unused]] cudaStream_t stream) { +template +auto OpToTensor(Op &&op, [[maybe_unused]] const Executor &exec) { if constexpr (!is_tensor_view_v) { - return make_tensor::value_type>(op.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream); + if constexpr (is_cuda_executor_v) { + return make_tensor::value_type>(op.Shape(), MATX_ASYNC_DEVICE_MEMORY, exec.getStream()); + } else { + return make_tensor::value_type>(op.Shape(), MATX_HOST_MALLOC_MEMORY); + } } else { return op; } } +/** + * Get a transposed view of a tensor into a user-supplied buffer + * + * @param tp + * Pointer to pre-allocated memory + * @param a + * Tensor to transpose + * @param exec + * Executor + */ +template +__MATX_INLINE__ auto +TransposeCopy(typename TensorType::value_type *tp, const TensorType &a, const Executor &exec) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + auto pa = a.PermuteMatrix(); + auto tv = make_tensor(tp, pa.Shape()); + matx::copy(tv, pa, exec); + return tv; +} + /** * @brief Set the print() precision for floating point values * diff --git a/include/matx/executors/support.h b/include/matx/executors/support.h index 1634af893..a70258d06 100644 --- a/include/matx/executors/support.h +++ b/include/matx/executors/support.h @@ -41,18 +41,25 @@ namespace matx { // FFT #if defined(MATX_EN_NVPL) || defined(MATX_EN_X86_FFTW) - #define MATX_EN_CPU_FFT 1 + #define MATX_EN_CPU_FFT 1 #else - #define MATX_EN_CPU_FFT 0 + #define MATX_EN_CPU_FFT 0 #endif // MatMul #if defined(MATX_EN_NVPL) || defined(MATX_EN_OPENBLAS) || defined(MATX_EN_BLIS) - #define MATX_EN_CPU_MATMUL 1 + #define MATX_EN_CPU_MATMUL 1 #else - #define MATX_EN_CPU_MATMUL 0 + #define MATX_EN_CPU_MATMUL 0 #endif +// Solver +#if defined(MATX_EN_NVPL) + #define MATX_EN_CPU_SOLVER 1 +#else + #define MATX_EN_CPU_SOLVER 0 +#endif + template constexpr bool CheckFFTSupport() { if constexpr (is_host_executor_v) { @@ -114,5 +121,14 @@ constexpr bool CheckMatMulSupport() { } } +template +constexpr bool CheckSolverSupport() { + if constexpr (is_host_executor_v) { + return MATX_EN_CPU_SOLVER; + } else { + return true; + } +} + }; // detail }; // matx \ No newline at end of file diff --git a/include/matx/operators/chol.h b/include/matx/operators/chol.h index 331a03cec..e2645dddf 100644 --- a/include/matx/operators/chol.h +++ b/include/matx/operators/chol.h @@ -35,7 +35,10 @@ #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" -#include "matx/transforms/solver.h" +#include "matx/transforms/chol/chol_cuda.h" +#ifdef MATX_EN_CPU_SOLVER + #include "matx/transforms/chol/chol_lapack.h" +#endif namespace matx { namespace detail { @@ -44,7 +47,7 @@ namespace detail { { private: OpA a_; - cublasFillMode_t uplo_; + SolverFillMode uplo_; mutable matx::tensor_t tmp_out_; public: @@ -54,7 +57,7 @@ namespace detail { using chol_xform_op = bool; __MATX_INLINE__ std::string str() const { return "chol()"; 
} - __MATX_INLINE__ CholOp(OpA a, cublasFillMode_t uplo) : a_(a), uplo_(uplo) { }; + __MATX_INLINE__ CholOp(OpA a, SolverFillMode uplo) : a_(a), uplo_(uplo) { }; // This should never be called template @@ -62,9 +65,7 @@ namespace detail { template void Exec(Out &&out, Executor &&ex) const { - static_assert(is_cuda_executor_v, "chol() only supports the CUDA executor currently"); - - chol_impl(cuda::std::get<0>(out), a_, ex.getStream(), uplo_); + chol_impl(cuda::std::get<0>(out), a_, ex, uplo_); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() @@ -114,7 +115,7 @@ namespace detail { } template -__MATX_INLINE__ auto chol(const OpA &a, cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) { +__MATX_INLINE__ auto chol(const OpA &a, SolverFillMode uplo = SolverFillMode::UPPER) { return detail::CholOp(a, uplo); } diff --git a/include/matx/operators/det.h b/include/matx/operators/det.h index 78a466858..f5b96a8f6 100644 --- a/include/matx/operators/det.h +++ b/include/matx/operators/det.h @@ -35,7 +35,7 @@ #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" -#include "matx/transforms/solver.h" +#include "matx/transforms/det.h" namespace matx { namespace detail { @@ -62,9 +62,7 @@ namespace detail { template void Exec(Out &&out, Executor &&ex) const{ - static_assert(is_cuda_executor_v, "det() only supports the CUDA executor currently"); - - det_impl(cuda::std::get<0>(out), a_, ex.getStream()); + det_impl(cuda::std::get<0>(out), a_, ex); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/eig.h b/include/matx/operators/eig.h index 676275aa7..8a100d08d 100644 --- a/include/matx/operators/eig.h +++ b/include/matx/operators/eig.h @@ -35,7 +35,10 @@ #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" -#include "matx/transforms/solver.h" +#include "matx/transforms/eig/eig_cuda.h" +#ifdef MATX_EN_CPU_SOLVER + #include "matx/transforms/eig/eig_lapack.h" +#endif namespace matx { @@ -47,8 +50,8 @@ namespace detail { { private: OpA a_; - cusolverEigMode_t jobz_; - cublasFillMode_t uplo_; + EigenMode jobz_; + SolverFillMode uplo_; public: using matxop = bool; @@ -57,7 +60,7 @@ namespace detail { using eig_xform_op = bool; __MATX_INLINE__ std::string str() const { return "eig()"; } - __MATX_INLINE__ EigOp(OpA a, cusolverEigMode_t jobz, cublasFillMode_t uplo) : a_(a), jobz_(jobz), uplo_(uplo) { }; + __MATX_INLINE__ EigOp(OpA a, EigenMode jobz, SolverFillMode uplo) : a_(a), jobz_(jobz), uplo_(uplo) { }; // This should never be called template @@ -65,10 +68,9 @@ namespace detail { template void Exec(Out &&out, Executor &&ex) const { - static_assert(is_cuda_executor_v, "eig () only supports the CUDA executor currently"); static_assert(cuda::std::tuple_size_v> == 3, "Must use mtie with 2 outputs on eig(). 
ie: (mtie(O, w) = eig(A))"); - eig_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex.getStream(), jobz_, uplo_); + eig_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex, jobz_, uplo_); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() @@ -94,8 +96,8 @@ namespace detail { template __MATX_INLINE__ auto eig(const OpA &a, - cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR, - cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) { + EigenMode jobz = EigenMode::VECTOR, + SolverFillMode uplo = SolverFillMode::UPPER) { return detail::EigOp(a, jobz, uplo); } diff --git a/include/matx/operators/frexp.h b/include/matx/operators/frexp.h index 538a1fdab..cfdb0604e 100644 --- a/include/matx/operators/frexp.h +++ b/include/matx/operators/frexp.h @@ -35,7 +35,6 @@ #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" -#include "matx/transforms/solver.h" namespace matx { diff --git a/include/matx/operators/lu.h b/include/matx/operators/lu.h index 37978aef2..a1b224ce7 100644 --- a/include/matx/operators/lu.h +++ b/include/matx/operators/lu.h @@ -35,7 +35,11 @@ #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" -#include "matx/transforms/solver.h" +#include "matx/transforms/lu/lu_cuda.h" +#ifdef MATX_EN_CPU_SOLVER + #include "matx/transforms/lu/lu_lapack.h" +#endif + namespace matx { namespace detail { @@ -62,10 +66,9 @@ namespace detail { template void Exec(Out &&out, Executor &&ex) const { - static_assert(is_cuda_executor_v, "lu() only supports the CUDA executor currently"); - static_assert(cuda::std::tuple_size_v> == 3, "Must use mtie with 2 outputs on cusolver_qr(). ie: (mtie(O, piv) = lu(A))"); + static_assert(cuda::std::tuple_size_v> == 3, "Must use mtie with 2 outputs on lu(). ie: (mtie(O, piv) = lu(A))"); - lu_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex.getStream()); + lu_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/prod.h b/include/matx/operators/prod.h index fee3f9bdd..ecbafc988 100644 --- a/include/matx/operators/prod.h +++ b/include/matx/operators/prod.h @@ -133,7 +133,7 @@ namespace detail { template __MATX_INLINE__ auto prod(const InType &in, const int (&dims)[D]) { - static_assert(D < InType::Rank(), "reduction dimensions must be <= Rank of input"); + static_assert(D <= InType::Rank(), "reduction dimensions must be <= Rank of input"); auto perm = detail::getPermuteDims(dims); auto permop = permute(in, perm); diff --git a/include/matx/operators/qr.h b/include/matx/operators/qr.h index 3c9dffc05..9b0de5098 100644 --- a/include/matx/operators/qr.h +++ b/include/matx/operators/qr.h @@ -35,7 +35,10 @@ #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" -#include "matx/transforms/qr.h" +#include "matx/transforms/qr/qr_cuda.h" +#ifdef MATX_EN_CPU_SOLVER + #include "matx/transforms/qr/qr_lapack.h" +#endif namespace matx { @@ -61,7 +64,7 @@ namespace detail { template void Exec(Out &&out, Executor &&ex) const { - static_assert(is_cuda_executor_v, "svd() only supports the CUDA executor currently"); + static_assert(is_cuda_executor_v, "qr() only supports the CUDA executor currently"); static_assert(cuda::std::tuple_size_v> == 3, "Must use mtie with 3 outputs on qr(). 
ie: (mtie(Q, R) = qr(A))"); qr_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex); @@ -107,7 +110,7 @@ __MATX_INLINE__ auto qr(AType A) { namespace detail { template - class CuSolverQROp : public BaseOp> + class SolverQROp : public BaseOp> { private: OpA a_; @@ -117,10 +120,10 @@ namespace detail { using matxop = bool; using value_type = typename OpA::value_type; using matx_transform_op = bool; - using cusolver_qr_xform_op = bool; + using qr_solver_xform_op = bool; - __MATX_INLINE__ std::string str() const { return "cusolver_qr()"; } - __MATX_INLINE__ CuSolverQROp(OpA a) : a_(a) { }; + __MATX_INLINE__ std::string str() const { return "qr_solver()"; } + __MATX_INLINE__ SolverQROp(OpA a) : a_(a) { }; // This should never be called template @@ -128,10 +131,9 @@ namespace detail { template void Exec(Out &&out, Executor &&ex) { - static_assert(is_cuda_executor_v, "cusolver_qr() only supports the CUDA executor currently"); - static_assert(cuda::std::tuple_size_v> == 3, "Must use mtie with 2 outputs on cusolver_qr(). ie: (mtie(A, tau) = eig(A))"); + static_assert(cuda::std::tuple_size_v> == 3, "Must use mtie with 2 outputs on qr_solver(). ie: (mtie(A, tau) = eig(A))"); - cusolver_qr_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex.getStream()); + qr_solver_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() @@ -142,10 +144,10 @@ namespace detail { template __MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) noexcept { - MATX_ASSERT_STR(false, matxNotSupported, "cusolver_qr() must only be called with a single assignment since it has multiple return types"); + MATX_ASSERT_STR(false, matxNotSupported, "qr_solver() must only be called with a single assignment since it has multiple return types"); } - // Size is not relevant in cusolver_qr() since there are multiple return values and it + // Size is not relevant in qr_solver() since there are multiple return values and it // is not allowed to be called in larger expressions constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const { @@ -155,9 +157,13 @@ namespace detail { }; } +/** + * Perform a QR decomposition on a matrix using cuSolver or a LAPACK host library. 
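+ *
+ * A minimal usage sketch (hedged; `A`, `out`, `tau`, and `exec` are illustrative
+ * names, not part of this change):
+ * @code
+ *   // packed Householder factors are written to out, scalar factors to tau
+ *   (mtie(out, tau) = qr_solver(A)).run(exec);
+ * @endcode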
+ * + */ template -__MATX_INLINE__ auto cusolver_qr(const OpA &a) { - return detail::CuSolverQROp(a); +__MATX_INLINE__ auto qr_solver(const OpA &a) { + return detail::SolverQROp(a); } } diff --git a/include/matx/operators/svd.h b/include/matx/operators/svd.h index 33c36202c..7b5df3c0f 100644 --- a/include/matx/operators/svd.h +++ b/include/matx/operators/svd.h @@ -35,8 +35,10 @@ #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" -#include "matx/transforms/solver.h" -#include "matx/transforms/svd.h" +#include "matx/transforms/svd/svd_cuda.h" +#ifdef MATX_EN_CPU_SOLVER + #include "matx/transforms/svd/svd_lapack.h" +#endif namespace matx { @@ -47,8 +49,8 @@ namespace detail { { private: OpA a_; - char jobu_; - char jobv_; + SVDJob jobu_; + SVDJob jobvt_; public: using matxop = bool; @@ -57,7 +59,7 @@ namespace detail { using svd_xform_op = bool; __MATX_INLINE__ std::string str() const { return "svd(" + get_type_str(a_) + ")"; } - __MATX_INLINE__ SVDOp(OpA a, const char jobu, const char jobvt) : a_(a), jobu_(jobu), jobv_(jobvt) { }; + __MATX_INLINE__ SVDOp(OpA a, const SVDJob jobu, const SVDJob jobvt) : a_(a), jobu_(jobu), jobvt_(jobvt) { }; // This should never be called template @@ -65,10 +67,9 @@ namespace detail { template void Exec(Out &&out, Executor &&ex) const { - static_assert(is_cuda_executor_v, "svd() only supports the CUDA executor currently"); - static_assert(cuda::std::tuple_size_v> == 4, "Must use mtie with 3 outputs on svd(). ie: (mtie(U, S, V) = svd(A))"); + static_assert(cuda::std::tuple_size_v> == 4, "Must use mtie with 3 outputs on svd(). ie: (mtie(U, S, VT) = svd(A))"); - svd_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), cuda::std::get<2>(out), a_, ex.getStream(), jobu_, jobv_); + svd_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), cuda::std::get<2>(out), a_, ex, jobu_, jobvt_); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() @@ -92,8 +93,13 @@ namespace detail { }; } +/** + * Perform a singular value decomposition (SVD) using cuSolver or a LAPACK host + * library. + * + */ template -__MATX_INLINE__ auto svd(const OpA &a, const char jobu = 'A', const char jobvt = 'A') { +__MATX_INLINE__ auto svd(const OpA &a, const SVDJob jobu = SVDJob::ALL, const SVDJob jobvt = SVDJob::ALL) { return detail::SVDOp(a, jobu, jobvt); } @@ -125,7 +131,7 @@ namespace detail { template void Exec(Out &&out, Executor &&ex) { static_assert(is_cuda_executor_v, "svdpi() only supports the CUDA executor currently"); - static_assert(cuda::std::tuple_size_v> == 4, "Must use mtie with 3 outputs on svdpi(). ie: (mtie(U, S, V) = svdpi(A))"); + static_assert(cuda::std::tuple_size_v> == 4, "Must use mtie with 3 outputs on svdpi(). ie: (mtie(U, S, VT) = svdpi(A))"); svdpi_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), cuda::std::get<2>(out), a_, x_, iterations_, ex, k_); } @@ -202,7 +208,7 @@ namespace detail { template void Exec(Out &&out, Executor &&ex) { static_assert(is_cuda_executor_v, "svdbpi() only supports the CUDA executor currently"); - static_assert(cuda::std::tuple_size_v> == 4, "Must use mtie with 3 outputs on svdbpi(). ie: (mtie(U, S, V) = svdbpi(A))"); + static_assert(cuda::std::tuple_size_v> == 4, "Must use mtie with 3 outputs on svdbpi(). 
ie: (mtie(U, S, VT) = svdbpi(A))"); svdbpi_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), cuda::std::get<2>(out), a_, max_iters_, tol_, ex); } diff --git a/include/matx/transforms/chol/chol_cuda.h b/include/matx/transforms/chol/chol_cuda.h new file mode 100644 index 000000000..c0854c8f6 --- /dev/null +++ b/include/matx/transforms/chol/chol_cuda.h @@ -0,0 +1,298 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "cublas_v2.h" +#include "cusolverDn.h" + +#include "matx/core/error.h" +#include "matx/core/nvtx.h" +#include "matx/core/tensor.h" +#include "matx/core/cache.h" +#include "matx/transforms/solver_common.h" + +#include +#include + +namespace matx { + +namespace detail { + +/** + * Parameters needed to execute a cholesky factorization. We distinguish unique + * factorizations mostly by the data pointer in A + */ +struct DnCholCUDAParams_t { + int64_t n; + void *A; + size_t batch_size; + cublasFillMode_t uplo; + MatXDataType_t dtype; +}; + +template +class matxDnCholCUDAPlan_t : matxDnCUDASolver_t { + using OutTensor_t = remove_cvref_t; + using T1 = typename remove_cvref_t::value_type; + static constexpr int RANK = OutTensor_t::Rank(); + static_assert(RANK >= 2, "Input/Output tensor must be rank 2 or higher"); + +public: + /** + * Plan for solving + * \f$\textbf{A} = \textbf{L} * \textbf{L^{H}}\f$ or \f$\textbf{A} = + * \textbf{U} * \textbf{U^{H}}\f$ using the Cholesky method + * + * Creates a handle for solving the factorization of A = M * M^H of a dense + * matrix using the Cholesky method, where M is either the upper or lower + * triangular portion of A. Input matrix A must be a square Hermitian matrix + * positive-definite where only the upper or lower triangle is used. 
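+ *
+ * For reference, a hedged sketch of the user-facing call this plan backs
+ * (`A`, `L`, and `exec` are illustrative names):
+ * @code
+ *   // lower-triangular Cholesky factor of a positive-definite matrix A
+ *   (L = chol(A, SolverFillMode::LOWER)).run(exec);
+ * @endcode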
+ * + * @tparam T1 + * Data type of A matrix + * @tparam RANK + * Rank of A matrix + * + * @param a + * Input tensor view + * @param uplo + * Use upper or lower triangle for computation + * + */ + matxDnCholCUDAPlan_t(const ATensor &a, + cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Dim checks + MATX_STATIC_ASSERT_STR(RANK == remove_cvref_t::Rank(), matxInvalidDim, "Cholesky input/output tensor ranks must match"); + + // Type checks + MATX_STATIC_ASSERT_STR(!is_half_v, matxInvalidType, "Cholesky solver does not support half precision"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "Input and Output types must match"); + + params = GetCholParams(a, uplo); + this->GetWorkspaceSize(); + this->AllocateWorkspace(params.batch_size); + } + + void GetWorkspaceSize() override + { + cusolverStatus_t ret = cusolverDnXpotrf_bufferSize(this->handle, this->dn_params, params.uplo, + params.n, MatXTypeToCudaType(), + params.A, params.n, + MatXTypeToCudaType(), &this->dspace, + &this->hspace); + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + } + + static DnCholCUDAParams_t GetCholParams(const ATensor &a, + cublasFillMode_t uplo) + { + DnCholCUDAParams_t params; + params.batch_size = GetNumBatches(a); + params.n = a.Size(RANK - 1); + params.A = a.Data(); + params.uplo = uplo; + params.dtype = TypeToInt(); + + return params; + } + + void Exec(OutputTensor &out, const ATensor &a, + const cudaExecutor &exec, cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + MATX_ASSERT_STR(a.Size(RANK - 1) == a.Size(RANK - 2), matxInvalidSize, "Input to Cholesky must be a square matrix"); + + // Ensure output size matches input + for (int i = 0; i < RANK; i++) { + MATX_ASSERT(out.Size(i) == a.Size(i), matxInvalidSize); + } + + SetBatchPointers(out, this->batch_a_ptrs); + + if (out.Data() != a.Data()) { + (out = a).run(exec); + } + + cusolverDnSetStream(this->handle, exec.getStream()); + int info; + + // At this time cuSolver does not have a batched 64-bit cholesky interface. + // Change this to use the batched version once available. + for (size_t i = 0; i < this->batch_a_ptrs.size(); i++) { + auto ret = cusolverDnXpotrf( + this->handle, this->dn_params, uplo, params.n, MatXTypeToCudaType(), + this->batch_a_ptrs[i], params.n, MatXTypeToCudaType(), + reinterpret_cast(this->d_workspace) + i * this->dspace, this->dspace, + reinterpret_cast(this->h_workspace) + i * this->hspace, this->hspace, + this->d_info + i); + + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + + // This will block. Figure this out later + cudaMemcpy(&info, this->d_info + i, sizeof(info), cudaMemcpyDeviceToHost); + MATX_ASSERT(info == 0, matxSolverError); + } + } + + /** + * Cholesky solver handle destructor + * + * Destroys any helper data used for provider type and any workspace memory + * created + * + */ + ~matxDnCholCUDAPlan_t() {} + +private: + DnCholCUDAParams_t params; +}; + +/** + * Crude hash to get a reasonably good delta for collisions. This doesn't need + * to be perfect, but fast enough to not slow down lookups, and different enough + * so the common solver parameters change + */ +struct DnCholCUDAParamsKeyHash { + std::size_t operator()(const DnCholCUDAParams_t &k) const noexcept + { + return (std::hash()(k.n)) + (std::hash()(k.batch_size)); + } +}; + +/** + * Test cholesky parameters for equality. Unlike the hash, all parameters must + * match. 
+ */ +struct DnCholCUDAParamsKeyEq { + bool operator()(const DnCholCUDAParams_t &l, const DnCholCUDAParams_t &t) const + noexcept + { + return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype; + } +}; + +using chol_cuda_cache_t = std::unordered_map; + +} // end namespace detail + + +/** + * Perform a Cholesky decomposition using a cached plan + * + * See documentation of matxDnCholCUDAPlan_t for a description of how the + * algorithm works. This function provides a simple interface to the cuSolver + * library by deducing all parameters needed to perform a Cholesky decomposition + * from only the matrix A. The input and output parameters may be the same + * tensor. In that case, the input is destroyed and the output is stored + * in-place. Input must be a positive-definite Hermitian or real symmetric matrix. + * + * @tparam T1 + * Data type of matrix A + * @tparam RANK + * Rank of matrix A + * + * @param out + * Output tensor + * @param a + * Input tensor + * @param exec + * CUDA executor + * @param uplo + * Part of matrix to fill + */ +template +void chol_impl(OutputTensor &&out, const ATensor &a, + const cudaExecutor &exec, + SolverFillMode uplo = SolverFillMode::UPPER) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + + using OutputTensor_t = remove_cvref_t; + using T1 = typename OutputTensor_t::value_type; + constexpr int RANK = ATensor::Rank(); + + auto a_new = OpToTensor(a, exec); + + if(!a_new.isSameView(a)) { + (a_new = a).run(exec); + } + + // cuSolver assumes column-major matrices and MatX uses row-major matrices. + // One way to address this is to create a transposed copy of the input to + // use with the factorization, followed by transposing the output. However, + // for matrices with no additional padding, we can also change the value of + // uplo to effectively change the matrix to column-major. This allows us to + // compute the factorization without additional transposes. If we do not + // have contiguous input and output tensors, then we create a temporary + // contiguous tensor for use with cuSolver. + uplo = (uplo == SolverFillMode::UPPER) ? SolverFillMode::LOWER : SolverFillMode::UPPER; + + const bool allContiguous = a_new.IsContiguous() && out.IsContiguous(); + auto tv = [allContiguous, &a_new, &out, &exec]() -> auto { + if (allContiguous) { + (out = a_new).run(exec); + return out; + } else{ + auto t = make_tensor(a_new.Shape(), MATX_ASYNC_DEVICE_MEMORY, exec.getStream()); + (t = a_new).run(exec); + return t; + } + }(); + + cublasFillMode_t uplo_cusolver = (uplo == SolverFillMode::UPPER)? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; + + // Get parameters required by these tensors + auto params = detail::matxDnCholCUDAPlan_t::GetCholParams(tv, uplo_cusolver); + + using cache_val_type = detail::matxDnCholCUDAPlan_t; + detail::GetCache().LookupAndExec( + detail::GetCacheIdFromType(), + params, + [&]() { + return std::make_shared(tv, uplo_cusolver); + }, + [&](std::shared_ptr ctype) { + ctype->Exec(tv, tv, exec, uplo_cusolver); + } + ); + + if (! 
allContiguous) { + matx::copy(out, tv, exec); + } +} + +} // end namespace matx \ No newline at end of file diff --git a/include/matx/transforms/chol/chol_lapack.h b/include/matx/transforms/chol/chol_lapack.h new file mode 100644 index 000000000..1049384f6 --- /dev/null +++ b/include/matx/transforms/chol/chol_lapack.h @@ -0,0 +1,300 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "matx/core/error.h" +#include "matx/core/nvtx.h" +#include "matx/core/tensor.h" +#include "matx/core/cache.h" +#include "matx/executors/host.h" +#include "matx/executors/support.h" +#include "matx/transforms/solver_common.h" + +#include +#include + +namespace matx { + +namespace detail { + +#if MATX_EN_CPU_SOLVER +/** + * Parameters needed to execute a cholesky factorization. We distinguish unique + * factorizations mostly by the data pointer in A + */ +struct DnCholHostParams_t { + lapack_int_t n; + void *A; + size_t batch_size; + char uplo; + MatXDataType_t dtype; +}; + +template +class matxDnCholHostPlan_t : matxDnHostSolver_t::value_type> { + using OutTensor_t = remove_cvref_t; + using T1 = typename remove_cvref_t::value_type; + static constexpr int RANK = OutTensor_t::Rank(); + static_assert(RANK >= 2, "Input/Output tensor must be rank 2 or higher"); + using lapack_type = std::conditional_t>, lapack_scomplex_t, + std::conditional_t>, lapack_dcomplex_t, T1>>; + +public: + /** + * Plan for solving + * \f$\textbf{A} = \textbf{L} * \textbf{L^{H}}\f$ or \f$\textbf{A} = + * \textbf{U} * \textbf{U^{H}}\f$ using the Cholesky method + * + * Creates a handle for solving the factorization of A = M * M^H of a dense + * matrix using the Cholesky method, where M is either the upper or lower + * triangular portion of A. 
Input matrix A must be a square Hermitian matrix + * positive-definite where only the upper or lower triangle is used. This does + * require a workspace. + * + * @tparam T1 + * Data type of A matrix + * @tparam RANK + * Rank of A matrix + * + * @param a + * Input tensor view + * @param uplo + * Use upper or lower triangle for computation + * + */ + matxDnCholHostPlan_t(const ATensor &a, + const char uplo = 'U') + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Dim checks + MATX_STATIC_ASSERT_STR(RANK == remove_cvref_t::Rank(), matxInvalidDim, "Cholesky input/output tensor ranks must match"); + + // Type checks + MATX_STATIC_ASSERT_STR(!is_half_v, matxInvalidType, "Cholesky solver does not support half precision"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "Input and Output types must match"); + + params = GetCholParams(a, uplo); + } + + static DnCholHostParams_t GetCholParams(const ATensor &a, + const char uplo) + { + DnCholHostParams_t params; + params.batch_size = GetNumBatches(a); + params.n = static_cast(a.Size(RANK - 1)); + params.A = a.Data(); + params.uplo = uplo; + params.dtype = TypeToInt(); + + return params; + } + + template + void Exec(OutputTensor &out, const ATensor &a, + const HostExecutor &exec, const char uplo = 'U') + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + MATX_ASSERT_STR(a.Size(RANK - 1) == a.Size(RANK - 2), matxInvalidSize, "Input to Cholesky must be a square matrix"); + + // Ensure output size matches input + for (int i = 0; i < RANK; i++) { + MATX_ASSERT(out.Size(i) == a.Size(i), matxInvalidSize); + } + + SetBatchPointers(out, this->batch_a_ptrs); + + if (out.Data() != a.Data()) { + (out = a).run(exec); + } + + lapack_int_t info; + + for (size_t i = 0; i < this->batch_a_ptrs.size(); i++) { + potrf_dispatch(&uplo, ¶ms.n, + reinterpret_cast(this->batch_a_ptrs[i]), + ¶ms.n, &info); + + MATX_ASSERT(info == 0, matxSolverError); + } + } + + /** + * Cholesky solver handle destructor + * + * Destroys any helper data used for provider type and any workspace memory + * created + * + */ + ~matxDnCholHostPlan_t() {} + +private: + void potrf_dispatch(const char* uplo, const lapack_int_t* n, lapack_type* a, + const lapack_int_t* lda, lapack_int_t* info) + { + if constexpr (std::is_same_v) { + spotrf_(uplo, n, a, lda, info); + } else if constexpr (std::is_same_v) { + dpotrf_(uplo, n, a, lda, info); + } else if constexpr (std::is_same_v) { + cpotrf_(uplo, n, a, lda, info); + } else if constexpr (std::is_same_v) { + zpotrf_(uplo, n, a, lda, info); + } + } + + DnCholHostParams_t params; +}; + +/** + * Crude hash to get a reasonably good delta for collisions. This doesn't need + * to be perfect, but fast enough to not slow down lookups, and different enough + * so the common solver parameters change + */ +struct DnCholHostParamsKeyHash { + std::size_t operator()(const DnCholHostParams_t &k) const noexcept + { + return (std::hash()(k.n)) + (std::hash()(k.batch_size)); + } +}; + +/** + * Test cholesky parameters for equality. Unlike the hash, all parameters must + * match. + */ +struct DnCholHostParamsKeyEq { + bool operator()(const DnCholHostParams_t &l, const DnCholHostParams_t &t) const + noexcept + { + return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype; + } +}; + +using chol_Host_cache_t = std::unordered_map; +#endif + +} // end namespace detail + + +/** + * Perform a Cholesky decomposition using a cached plan + * + * See documentation of matxDnCholHostPlan_t for a description of how the + * algorithm works. 
This function provides a simple interface to the LAPACK + * library by deducing all parameters needed to perform a Cholesky decomposition + * from only the matrix A. The input and output parameters may be the same + * tensor. In that case, the input is destroyed and the output is stored + * in-place. Input must be a positive-definite Hermitian or real symmetric matrix. + * + * @tparam T1 + * Data type of matrix A + * @tparam RANK + * Rank of matrix A + * + * @param out + * Output tensor + * @param a + * Input tensor + * @param exec + * Host executor + * @param uplo + * Part of matrix to fill + */ +template +void chol_impl([[maybe_unused]] OutputTensor &&out, + [[maybe_unused]] const ATensor &a, + [[maybe_unused]] const HostExecutor &exec, + [[maybe_unused]] SolverFillMode uplo = SolverFillMode::UPPER) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + MATX_ASSERT_STR(MATX_EN_CPU_SOLVER, matxInvalidExecutor, + "Trying to run a host Solver executor but host Solver support is not configured"); +#if MATX_EN_CPU_SOLVER + + using OutputTensor_t = remove_cvref_t; + using T1 = typename OutputTensor_t::value_type; + constexpr int RANK = ATensor::Rank(); + + auto a_new = OpToTensor(a, exec); + + if(!a_new.isSameView(a)) { + (a_new = a).run(exec); + } + + // LAPACK assumes column-major matrices and MatX uses row-major matrices. + // One way to address this is to create a transposed copy of the input to + // use with the factorization, followed by transposing the output. However, + // for matrices with no additional padding, we can also change the value of + // uplo to effectively change the matrix to column-major. This allows us to + // compute the factorization without additional transposes. If we do not + // have contiguous input and output tensors, then we create a temporary + // contiguous tensor for use with LAPACK. + uplo = (uplo == SolverFillMode::UPPER) ? SolverFillMode::LOWER : SolverFillMode::UPPER; + + const bool allContiguous = a_new.IsContiguous() && out.IsContiguous(); + auto tv = [allContiguous, &a_new, &out, &exec]() -> auto { + if (allContiguous) { + (out = a_new).run(exec); + return out; + } else{ + auto t = make_tensor(a_new.Shape(), MATX_HOST_MALLOC_MEMORY); + (t = a_new).run(exec); + return t; + } + }(); + + const char uplo_lapack = (uplo == SolverFillMode::UPPER)? 'U' : 'L'; + + // Get parameters required by these tensors + auto params = detail::matxDnCholHostPlan_t::GetCholParams(tv, uplo_lapack); + + using cache_val_type = detail::matxDnCholHostPlan_t; + detail::GetCache().LookupAndExec( + detail::GetCacheIdFromType(), + params, + [&]() { + return std::make_shared(tv, uplo_lapack); + }, + [&](std::shared_ptr ctype) { + ctype->Exec(tv, tv, exec, uplo_lapack); + } + ); + + if (! 
allContiguous) { + matx::copy(out, tv, exec); + } +#endif +} + +} // end namespace matx \ No newline at end of file diff --git a/include/matx/transforms/cub.h b/include/matx/transforms/cub.h index c71be9afe..bc5ddcef6 100644 --- a/include/matx/transforms/cub.h +++ b/include/matx/transforms/cub.h @@ -1401,7 +1401,7 @@ void sort_impl(OutputTensor &a_out, const InputOperator &a, template void sort_impl(OutputTensor &a_out, const InputOperator &a, const SortDirection_t dir, - [[maybe_unused]] HostExecutor &exec) + [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) @@ -1496,7 +1496,7 @@ void cumsum_impl(OutputTensor &a_out, const InputOperator &a, template void cumsum_impl(OutputTensor &a_out, const InputOperator &a, - [[maybe_unused]] HostExecutor &exec) + [[maybe_unused]] const HostExecutor &exec) { #ifdef __CUDACC__ MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) @@ -1759,7 +1759,7 @@ void find_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator * Single-threaded host executor */ template -void find_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] HostExecutor &exec) +void find_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] const HostExecutor &exec) { static_assert(CountTensor::Rank() == 0, "Num found output tensor rank must be 0"); MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) @@ -1878,7 +1878,7 @@ void find_idx_impl(OutputTensor &a_out, CountTensor &num_found, const InputOpera * Single host executor */ template -void find_idx_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] HostExecutor &exec) +void find_idx_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] const HostExecutor &exec) { static_assert(CountTensor::Rank() == 0, "Num found output tensor rank must be 0"); MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) @@ -1989,7 +1989,7 @@ void unique_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperato * Single thread executor */ template -void unique_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, [[maybe_unused]] HostExecutor &exec) +void unique_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, [[maybe_unused]] const HostExecutor &exec) { #ifdef __CUDACC__ static_assert(CountTensor::Rank() == 0, "Num found output tensor rank must be 0"); diff --git a/include/matx/transforms/det.h b/include/matx/transforms/det.h new file mode 100644 index 000000000..f1ef7b32b --- /dev/null +++ b/include/matx/transforms/det.h @@ -0,0 +1,130 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. 
Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "matx/core/error.h" +#include "matx/core/nvtx.h" +#include "matx/core/tensor.h" +#include "matx/core/cache.h" +#include "matx/executors/host.h" +#include "matx/executors/support.h" +#include "matx/transforms/lu/lu_cuda.h" +#ifdef MATX_EN_CPU_SOLVER + #include "matx/transforms/lu/lu_lapack.h" +#endif + +#include +#include + +namespace matx { + +/** + * Compute the determinant of a matrix + * + * Computes the terminant of a matrix by first computing the LU composition, + * then reduces the product of the diagonal elements of U. The input and output + * parameters may be the same tensor. In that case, the input is destroyed and + * the output is stored in-place. + * + * @tparam T1 + * Data type of matrix A + * @tparam RANK + * Rank of matrix A + * + * @param out + * Output tensor view + * @param a + * Input matrix A + * @param exec + * Executor + */ +template +void det_impl(OutputTensor &out, const InputTensor &a, + const Executor &exec) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + MATX_ASSERT_STR(!(is_host_executor_v && !MATX_EN_CPU_SOLVER), matxInvalidExecutor, + "Trying to run a host Solver executor but host Solver support is not configured"); + + static_assert(OutputTensor::Rank() == InputTensor::Rank() - 2, "Output tensor rank must be 2 less than input for det()"); + constexpr int RANK = InputTensor::Rank(); + using value_type = typename OutputTensor::value_type; + using piv_value_type = std::conditional_t, int64_t, lapack_int_t>; + + auto a_new = OpToTensor(a, exec); + + if(!a_new.isSameView(a)) { + (a_new = a).run(exec); + } + + // Get parameters required by these tensors + cuda::std::array s; + + // Set batching dimensions of piv + for (int i = 0; i < RANK - 2; i++) { + s[i] = a_new.Size(i); + } + + index_t piv_len = cuda::std::min(a_new.Size(RANK - 1), a_new.Size(RANK - 2)); + s[RANK - 2] = piv_len; + + tensor_t piv; + tensor_t ac; + + if constexpr (is_cuda_executor_v) { + const auto stream = exec.getStream(); + make_tensor(piv, s, MATX_ASYNC_DEVICE_MEMORY, stream); + make_tensor(ac, a_new.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream); + } else { + make_tensor(piv, s, MATX_HOST_MALLOC_MEMORY); + make_tensor(ac, a_new.Shape(), MATX_HOST_MALLOC_MEMORY); + } + + lu_impl(ac, piv, a_new, exec); + + // Determinant sign adjustment based on piv permutation + // Create indices corresponding to no permutation to compare against + auto pIdxShape = s; + pIdxShape[RANK-2] = matxKeepDim; + auto idx = 
range<0, 1, piv_value_type>({piv_len}, 1, 1); // piv has 1-based indexing + auto piv_idx = clone(idx, pIdxShape); + + // Calculate number of swaps for each matrix in the batch + auto swap_count = sum(as_type(piv != piv_idx), {RANK-2}); + + // Even number of swaps means positive and odd means negative + auto signs = as_type::type>((swap_count & 1) * -2 + 1); + (out = signs * prod(diag(ac), {RANK-2})).run(exec); +} + +} // end namespace matx \ No newline at end of file diff --git a/include/matx/transforms/eig/eig_cuda.h b/include/matx/transforms/eig/eig_cuda.h new file mode 100644 index 000000000..12d5c165e --- /dev/null +++ b/include/matx/transforms/eig/eig_cuda.h @@ -0,0 +1,321 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "cublas_v2.h" +#include "cusolverDn.h" + +#include "matx/core/error.h" +#include "matx/core/nvtx.h" +#include "matx/core/tensor.h" +#include "matx/core/cache.h" +#include "matx/transforms/solver_common.h" + +#include +#include + +namespace matx { + +namespace detail { + +/** + * Parameters needed to execute eigenvalue decomposition. We distinguish + * unique factorizations mostly by the data pointer in A. 
+ */ +struct DnEigCUDAParams_t { + int64_t n; + cusolverEigMode_t jobz; + cublasFillMode_t uplo; + void *A; + void *out; + void *W; + size_t batch_size; + MatXDataType_t dtype; +}; + +template +class matxDnEigCUDAPlan_t : matxDnCUDASolver_t { +public: + using OutTensor_t = remove_cvref_t; + using T1 = typename ATensor::value_type; + using T2 = typename WTensor::value_type; + static constexpr int RANK = OutTensor_t::Rank(); + static_assert(RANK >= 2, "Input/Output tensor must be rank 2 or higher"); + + /** + * Plan computing eigenvalues/vectors on square Hermitian A such that: + * + * \f$\textbf{A} * textbf{V} = \textbf{V} * \textbf{\Lambda}\f$ + * + * + * @tparam T1 + * Data type of A matrix + * @tparam T2 + * Data type of W matrix + * @tparam RANK + * Rank of A matrix + * + * @param w + * Eigenvalues of A + * @param a + * Input tensor view + * @param jobz + * CUSOLVER_EIG_MODE_VECTOR to compute eigenvectors or + * CUSOLVER_EIG_MODE_NOVECTOR to not compute + * @param uplo + * Where to store data in A + * + */ + matxDnEigCUDAPlan_t(WTensor &w, + const ATensor &a, + cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR, + cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Dim checks + MATX_STATIC_ASSERT_STR(RANK == ATensor::Rank(), matxInvalidDim, "Output and A tensor ranks must match for eigen solver"); + MATX_STATIC_ASSERT_STR(RANK - 1 == WTensor::Rank(), matxInvalidDim, "W tensor must be one rank lower than output for eigen solver"); + + // Type checks + MATX_STATIC_ASSERT_STR(!is_half_v, matxInvalidType, "Eigen solver does not support half precision"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "Input and output types must match"); + MATX_STATIC_ASSERT_STR(!is_complex_v, matxInvalidType, "W type must be real"); + MATX_STATIC_ASSERT_STR((std::is_same_v::type, T2>), matxInvalidType, "Out and W inner types must match"); + + params = GetEigParams(w, a, jobz, uplo); + this->GetWorkspaceSize(); + this->AllocateWorkspace(params.batch_size); + } + + void GetWorkspaceSize() override + { + // Use vector mode for a larger workspace size that works for both modes + cusolverStatus_t ret = cusolverDnXsyevd_bufferSize( + this->handle, this->dn_params, CUSOLVER_EIG_MODE_VECTOR, + params.uplo, params.n, MatXTypeToCudaType(), params.A, + params.n, MatXTypeToCudaType(), params.W, + MatXTypeToCudaType(), &this->dspace, + &this->hspace); + + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + } + + static DnEigCUDAParams_t GetEigParams(WTensor &w, + const ATensor &a, + cusolverEigMode_t jobz, + cublasFillMode_t uplo) + { + DnEigCUDAParams_t params; + params.batch_size = GetNumBatches(a); + params.n = a.Size(RANK - 1); + params.A = a.Data(); + params.W = w.Data(); + params.jobz = jobz; + params.uplo = uplo; + params.dtype = TypeToInt(); + + return params; + } + + void Exec(OutputTensor &out, WTensor &w, + const ATensor &a, + const cudaExecutor &exec, + cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR, + cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + MATX_ASSERT_STR(a.Size(RANK - 1) == a.Size(RANK - 2), matxInvalidSize, "Input to eigen must be a square matrix"); + + // Ensure output & w size matches input + for (int i = 0; i < RANK; i++) { + MATX_ASSERT(out.Size(i) == a.Size(i), matxInvalidSize); + if (i < RANK - 1) { + MATX_ASSERT(out.Size(i) == w.Size(i), matxInvalidSize); + } + } + + SetBatchPointers(out, this->batch_a_ptrs); + SetBatchPointers(w, 
this->batch_w_ptrs); + + if (out.Data() != a.Data()) { + (out = a).run(exec); + } + + cusolverDnSetStream(this->handle, exec.getStream()); + int info; + + // At this time cuSolver does not have a batched 64-bit LU interface. Change + // this to use the batched version once available. + for (size_t i = 0; i < this->batch_a_ptrs.size(); i++) { + auto ret = cusolverDnXsyevd( + this->handle, this->dn_params, jobz, uplo, params.n, MatXTypeToCudaType(), + this->batch_a_ptrs[i], params.n, MatXTypeToCudaType(), this->batch_w_ptrs[i], + MatXTypeToCudaType(), + reinterpret_cast(this->d_workspace) + i * this->dspace, this->dspace, + reinterpret_cast(this->h_workspace) + i * this->hspace, this->hspace, + this->d_info + i); + + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + + // This will block. Figure this out later + cudaMemcpy(&info, this->d_info + i, sizeof(info), cudaMemcpyDeviceToHost); + MATX_ASSERT(info == 0, matxSolverError); + } + } + + /** + * Eigen solver handle destructor + * + * Destroys any helper data used for provider type and any workspace memory + * created + * + */ + ~matxDnEigCUDAPlan_t() {} + +private: + std::vector batch_w_ptrs; + DnEigCUDAParams_t params; +}; + +/** + * Crude hash to get a reasonably good delta for collisions. This doesn't need + * to be perfect, but fast enough to not slow down lookups, and different enough + * so the common solver parameters change + */ +struct DnEigCUDAParamsKeyHash { + std::size_t operator()(const DnEigCUDAParams_t &k) const noexcept + { + return (std::hash()(k.n)) + (std::hash()(k.batch_size)); + } +}; + +/** + * Test Eigen parameters for equality. Unlike the hash, all parameters must + * match. + */ +struct DnEigCUDAParamsKeyEq { + bool operator()(const DnEigCUDAParams_t &l, const DnEigCUDAParams_t &t) const noexcept + { + return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype; + } +}; + +using eig_cuda_cache_t = std::unordered_map; + +} // end namespace detail + + +/** + * Perform a Eig decomposition using a cached plan + * + * See documentation of matxDnEigCUDAPlan_t for a description of how the + * algorithm works. This function provides a simple interface to the cuSolver + * library by deducing all parameters needed to perform a eigen decomposition + * from only the matrix A. The input and output parameters may be the same + * tensor. In that case, the input is destroyed and the output is stored + * in-place. Input must be a Hermitian or real symmetric matrix. + * + * @tparam T1 + * Data type of matrix A + * @tparam RANK + * Rank of matrix A + * + * @param out + * Output tensor view + * @param w + * Eigenvalues output + * @param a + * Input matrix A + * @param exec + * CUDA executor + * @param jobz + * EigenMode::VECTOR to compute eigenvectors or + * EigenMode::NO_VECTOR to not compute + * @param uplo + * Where to store data in A + */ +template +void eig_impl(OutputTensor &&out, WTensor &&w, + const ATensor &a, const cudaExecutor &exec, + EigenMode jobz = EigenMode::VECTOR, + SolverFillMode uplo = SolverFillMode::UPPER) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + using T1 = typename remove_cvref_t::value_type; + + auto w_new = OpToTensor(w, exec); + auto a_new = OpToTensor(a, exec); + + if(!a_new.isSameView(a)) { + (a_new = a).run(exec); + } + + /* Temporary WAR + cuSolver doesn't support row-major layouts. Since we want to make the + library appear as though everything is row-major, we take a performance hit + to transpose in and out of the function. Eventually this may be fixed in + cuSolver. 
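+
+     In outline (matching the code below; no new APIs assumed):
+       1. allocate a scratch buffer and TransposeCopy() A into it (a column-major copy),
+       2. run the cached cusolverDnXsyevd plan on that copy, and
+       3. copy the result back through PermuteMatrix() so the output stays row-major.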
+ */ + T1 *tp; + matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY, + exec.getStream()); + auto tv = TransposeCopy(tp, a_new, exec); + + cusolverEigMode_t jobz_cusolver = (jobz == EigenMode::VECTOR) ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; + cublasFillMode_t uplo_cusolver = (uplo == SolverFillMode::UPPER) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; + + // Get parameters required by these tensors + auto params = detail::matxDnEigCUDAPlan_t:: + GetEigParams(w_new, tv, jobz_cusolver, uplo_cusolver); + + // Get cache or new eigen plan if it doesn't exist + using cache_val_type = detail::matxDnEigCUDAPlan_t; + detail::GetCache().LookupAndExec( + detail::GetCacheIdFromType(), + params, + [&]() { + return std::make_shared(w_new, tv, jobz_cusolver, uplo_cusolver); + }, + [&](std::shared_ptr ctype) { + ctype->Exec(tv, w_new, tv, exec, jobz_cusolver, uplo_cusolver); + } + ); + + /* Copy and free async buffer for transpose */ + matx::copy(out, tv.PermuteMatrix(), exec); + matxFree(tp); +} + +} // end namespace matx \ No newline at end of file diff --git a/include/matx/transforms/eig/eig_lapack.h b/include/matx/transforms/eig/eig_lapack.h new file mode 100644 index 000000000..b6bb1c6a9 --- /dev/null +++ b/include/matx/transforms/eig/eig_lapack.h @@ -0,0 +1,355 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "matx/core/error.h" +#include "matx/core/nvtx.h" +#include "matx/core/tensor.h" +#include "matx/core/cache.h" +#include "matx/executors/host.h" +#include "matx/executors/support.h" +#include "matx/transforms/solver_common.h" + +#include +#include + +namespace matx { + +namespace detail { + +#if MATX_EN_CPU_SOLVER +/** + * Parameters needed to execute eigenvalue decomposition. 
We distinguish + * unique factorizations mostly by the data pointer in A. + */ +struct DnEigHostParams_t { + lapack_int_t n; + char jobz; + char uplo; + void *A; + void *out; + void *W; + size_t batch_size; + MatXDataType_t dtype; +}; + +template +class matxDnEigHostPlan_t : matxDnHostSolver_t { +public: + using OutTensor_t = remove_cvref_t; + using T1 = typename ATensor::value_type; + using T2 = typename WTensor::value_type; + static constexpr int RANK = OutTensor_t::Rank(); + static_assert(RANK >= 2, "Input/Output tensor must be rank 2 or higher"); + using lapack_type = std::conditional_t>, lapack_scomplex_t, + std::conditional_t>, lapack_dcomplex_t, T1>>; + + /** + * Plan computing eigenvalues/vectors on square Hermitian A such that: + * + * \f$\textbf{A} * textbf{V} = \textbf{V} * \textbf{\Lambda}\f$ + * + * + * @tparam T1 + * Data type of A matrix + * @tparam T2 + * Data type of W matrix + * @tparam RANK + * Rank of A matrix + * + * @param w + * Eigenvalues of A + * @param a + * Input tensor view + * @param jobz + * 'V' to compute eigenvectors or + * 'N' to not compute + * @param uplo + * Where to store data in A: {'U' or 'L'} + * + */ + matxDnEigHostPlan_t(WTensor &w, + const ATensor &a, + const char jobz = 'V', + const char uplo = 'U') + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Dim checks + MATX_STATIC_ASSERT_STR(RANK == ATensor::Rank(), matxInvalidDim, "Output and A tensor ranks must match for eigen solver"); + MATX_STATIC_ASSERT_STR(RANK - 1 == WTensor::Rank(), matxInvalidDim, "W tensor must be one rank lower than output for eigen solver"); + + // Type checks + MATX_STATIC_ASSERT_STR(!is_half_v, matxInvalidType, "Eigen solver does not support half precision"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "Input and output types must match"); + MATX_STATIC_ASSERT_STR(!is_complex_v, matxInvalidType, "W type must be real"); + MATX_STATIC_ASSERT_STR((std::is_same_v::type, T2>), matxInvalidType, "Out and W inner types must match"); + + params = GetEigParams(w, a, jobz, uplo); + this->GetWorkspaceSize(); + this->AllocateWorkspace(params.batch_size); + } + + void GetWorkspaceSize() override + { + // Perform a workspace query with lwork = -1. + lapack_int_t info; + lapack_type work_query; + T2 rwork_query; + lapack_int_t iwork_query; + + // Use vector mode for a larger workspace size that works for both modes + syevd_dispatch("V", ¶ms.uplo, ¶ms.n, nullptr, ¶ms.n, + nullptr, &work_query, &this->lwork, &rwork_query, + &this->lrwork, &iwork_query, &this->liwork, &info); + + MATX_ASSERT(info == 0, matxSolverError); + + // the real part of the first elem of work holds the optimal lwork. + if constexpr (is_complex_v) { + this->lwork = static_cast(work_query.real); + this->lrwork = static_cast(rwork_query); + } else { + this->lwork = static_cast(work_query); + this->lrwork = 0; // Complex variants do not use rwork. 
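+      // rwork/lrwork are consumed only by the complex heevd paths (cheevd_/zheevd_);
+      // the real syevd paths (ssyevd_/dsyevd_) take no rwork argument, which is why
+      // lrwork is left at 0 for real-valued types.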
+ } + this->liwork = static_cast(iwork_query); + } + + static DnEigHostParams_t GetEigParams(WTensor &w, + const ATensor &a, + char jobz, + char uplo) + { + DnEigHostParams_t params; + params.batch_size = GetNumBatches(a); + params.n = static_cast(a.Size(RANK - 1)); + params.A = a.Data(); + params.W = w.Data(); + params.jobz = jobz; + params.uplo = uplo; + params.dtype = TypeToInt(); + + return params; + } + + template + void Exec(OutputTensor &out, WTensor &w, + const ATensor &a, + const HostExecutor &exec, + const char jobz = 'V', + const char uplo = 'U') + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + MATX_ASSERT_STR(a.Size(RANK - 1) == a.Size(RANK - 2), matxInvalidSize, "Input to eigen must be a square matrix"); + + // Ensure output & w size matches input + for (int i = 0; i < RANK; i++) { + MATX_ASSERT(out.Size(i) == a.Size(i), matxInvalidSize); + if (i < RANK - 1) { + MATX_ASSERT(out.Size(i) == w.Size(i), matxInvalidSize); + } + } + + SetBatchPointers(out, this->batch_a_ptrs); + SetBatchPointers(w, this->batch_w_ptrs); + + if (out.Data() != a.Data()) { + (out = a).run(exec); + } + + lapack_int_t info; + for (size_t i = 0; i < this->batch_a_ptrs.size(); i++) { + syevd_dispatch(&jobz, &uplo, ¶ms.n, + reinterpret_cast(this->batch_a_ptrs[i]), + ¶ms.n, reinterpret_cast(this->batch_w_ptrs[i]), + reinterpret_cast(this->work), &this->lwork, + reinterpret_cast(this->rwork), &this->lrwork, + reinterpret_cast(this->iwork), &this->liwork, &info); + + MATX_ASSERT(info == 0, matxSolverError); + } + } + + /** + * Eigen solver handle destructor + * + * Destroys any helper data used for provider type and any workspace memory + * created + * + */ + ~matxDnEigHostPlan_t() {} + +private: + void syevd_dispatch(const char* jobz, const char* uplo, const lapack_int_t* n, + lapack_type* a, const lapack_int_t* lda, T2* w, lapack_type* work_in, + const lapack_int_t* lwork_in, [[maybe_unused]] T2* rwork_in, + [[maybe_unused]] const lapack_int_t* lrwork_in, lapack_int_t* iwork_in, + const lapack_int_t* liwork_in, lapack_int_t* info) + { + // TODO: remove warning suppression once syevd is optimized in NVPL LAPACK +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + if constexpr (std::is_same_v) { + ssyevd_(jobz, uplo, n, a, lda, w, work_in, lwork_in, iwork_in, liwork_in, info); + } else if constexpr (std::is_same_v) { + dsyevd_(jobz, uplo, n, a, lda, w, work_in, lwork_in, iwork_in, liwork_in, info); + } else if constexpr (std::is_same_v) { + cheevd_(jobz, uplo, n, a, lda, w, work_in, lwork_in, rwork_in, lrwork_in, iwork_in, liwork_in, info); + } else if constexpr (std::is_same_v) { + zheevd_(jobz, uplo, n, a, lda, w, work_in, lwork_in, rwork_in, lrwork_in, iwork_in, liwork_in, info); + } +#pragma GCC diagnostic pop + } + + std::vector batch_w_ptrs; + DnEigHostParams_t params; +}; + +/** + * Crude hash to get a reasonably good delta for collisions. This doesn't need + * to be perfect, but fast enough to not slow down lookups, and different enough + * so the common solver parameters change + */ +struct DnEigHostParamsKeyHash { + std::size_t operator()(const DnEigHostParams_t &k) const noexcept + { + return (std::hash()(k.n)) + (std::hash()(k.batch_size)); + } +}; + +/** + * Test Eigen parameters for equality. Unlike the hash, all parameters must + * match. 
+ */ +struct DnEigHostParamsKeyEq { + bool operator()(const DnEigHostParams_t &l, const DnEigHostParams_t &t) const noexcept + { + return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype; + } +}; + +using eig_Host_cache_t = std::unordered_map; +#endif + +} // end namespace detail + + +/** + * Perform a Eig decomposition using a cached plan + * + * See documentation of matxDnEigHostPlan_t for a description of how the + * algorithm works. This function provides a simple interface to a LAPACK + * library by deducing all parameters needed to perform a eigen decomposition + * from only the matrix A. The input and output parameters may be the same + * tensor. In that case, the input is destroyed and the output is stored + * in-place. Input must be a Hermitian or real symmetric matrix. + * + * @tparam T1 + * Data type of matrix A + * @tparam RANK + * Rank of matrix A + * + * @param out + * Output tensor view + * @param w + * Eigenvalues output + * @param a + * Input matrix A + * @param exec + * Host executor + * @param jobz + * EigenMode::VECTOR to compute eigenvectors or + * EigenMode::NO_VECTOR to not compute + * @param uplo + * Where to store data in A + */ +template +void eig_impl([[maybe_unused]] OutputTensor &&out, + [[maybe_unused]] WTensor &&w, + [[maybe_unused]] const ATensor &a, + [[maybe_unused]] const HostExecutor &exec, + [[maybe_unused]] EigenMode jobz = EigenMode::VECTOR, + [[maybe_unused]] SolverFillMode uplo = SolverFillMode::UPPER) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + MATX_ASSERT_STR(MATX_EN_CPU_SOLVER, matxInvalidExecutor, + "Trying to run a host Solver executor but host Solver support is not configured"); +#if MATX_EN_CPU_SOLVER + + using T1 = typename remove_cvref_t::value_type; + + auto w_new = OpToTensor(w, exec); + auto a_new = OpToTensor(a, exec); + + if(!a_new.isSameView(a)) { + (a_new = a).run(exec); + } + + /* Temporary WAR + LAPACK doesn't support row-major layouts. Since we want to make the + library appear as though everything is row-major, we take a performance hit + to transpose in and out of the function. LAPACKE, however, supports both formats. + */ + T1 *tp; + matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_HOST_MALLOC_MEMORY); + auto tv = TransposeCopy(tp, a_new, exec); + + const char jobz_lapack = (jobz == EigenMode::VECTOR) ? 'V' : 'N'; + const char uplo_lapack = (uplo == SolverFillMode::UPPER) ? 'U': 'L'; + + // Get parameters required by these tensors + auto params = detail::matxDnEigHostPlan_t:: + GetEigParams(w_new, tv, jobz_lapack, uplo_lapack); + + // Get cache or new eigen plan if it doesn't exist + using cache_val_type = detail::matxDnEigHostPlan_t; + detail::GetCache().LookupAndExec( + detail::GetCacheIdFromType(), + params, + [&]() { + return std::make_shared(w_new, tv, jobz_lapack, uplo_lapack); + }, + [&](std::shared_ptr ctype) { + ctype->Exec(tv, w_new, tv, exec, jobz_lapack, uplo_lapack); + } + ); + + /* Copy and free async buffer for transpose */ + matx::copy(out, tv.PermuteMatrix(), exec); + matxFree(tp); +#endif +} + +} // end namespace matx \ No newline at end of file diff --git a/include/matx/transforms/lu/lu_cuda.h b/include/matx/transforms/lu/lu_cuda.h new file mode 100644 index 000000000..791cab902 --- /dev/null +++ b/include/matx/transforms/lu/lu_cuda.h @@ -0,0 +1,296 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "cublas_v2.h" +#include "cusolverDn.h" + +#include "matx/core/error.h" +#include "matx/core/nvtx.h" +#include "matx/core/tensor.h" +#include "matx/core/cache.h" +#include "matx/transforms/solver_common.h" + +#include +#include + +namespace matx { + +namespace detail { + +/** + * Parameters needed to execute an LU factorization. We distinguish unique + * factorizations mostly by the data pointer in A + */ +struct DnLUCUDAParams_t { + int64_t m; + int64_t n; + void *A; + void *piv; + size_t batch_size; + MatXDataType_t dtype; +}; + +template +class matxDnLUCUDAPlan_t : matxDnCUDASolver_t { + using OutTensor_t = remove_cvref_t; + using T1 = typename ATensor::value_type; + using T2 = typename PivotTensor::value_type; + static constexpr int RANK = OutTensor_t::Rank(); + static_assert(RANK >= 2, "Input/Output tensor must be rank 2 or higher"); + +public: + /** + * Plan for factoring A such that \f$\textbf{P} * \textbf{A} = \textbf{L} * + * \textbf{U}\f$ + * + * Creates a handle for factoring matrix A into the format above. Matrix must + * not be singular. 
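+ *
+ * On exit the factorization follows the standard getrf convention: the output
+ * holds U in its upper triangle and a unit-diagonal L below it, and piv records
+ * the row interchanges, so the permutation matrix P is never formed explicitly.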
+ * + * @tparam T1 + * Data type of A matrix + * @tparam T2 + * Data type of Pivot vector + * @tparam RANK + * Rank of A matrix + * + * @param piv + * Pivot indices + * @param a + * Input tensor view + * + */ + matxDnLUCUDAPlan_t(PivotTensor &piv, + const ATensor &a) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Dim checks + MATX_STATIC_ASSERT_STR(RANK-1 == PivotTensor::Rank(), matxInvalidDim, "Pivot tensor rank must be one less than output"); + MATX_STATIC_ASSERT_STR(RANK == ATensor::Rank(), matxInvalidDim, "Output tensor must match A tensor rank in LU"); + + // Type checks + MATX_STATIC_ASSERT_STR(!is_half_v, matxInvalidType, "LU solver does not support half precision"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "Input and Output types must match"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "Pivot tensor type must be int64_t"); + + params = GetLUParams(piv, a); + this->GetWorkspaceSize(); + this->AllocateWorkspace(params.batch_size); + } + + void GetWorkspaceSize() override + { + cusolverStatus_t ret = cusolverDnXgetrf_bufferSize(this->handle, this->dn_params, params.m, + params.n, MatXTypeToCudaType(), + params.A, params.m, + MatXTypeToCudaType(), &this->dspace, + &this->hspace); + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + } + + static DnLUCUDAParams_t GetLUParams(PivotTensor &piv, + const ATensor &a) noexcept + { + DnLUCUDAParams_t params; + params.batch_size = GetNumBatches(a); + params.m = a.Size(RANK - 2); + params.n = a.Size(RANK - 1); + params.A = a.Data(); + params.piv = piv.Data(); + params.dtype = TypeToInt(); + + return params; + } + + void Exec(OutputTensor &out, PivotTensor &piv, + const ATensor &a, const cudaExecutor &exec) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Batch size checks + for(int i = 0 ; i < RANK-2; i++) { + MATX_ASSERT_STR(out.Size(i) == a.Size(i), matxInvalidDim, "Out and A must have the same batch sizes"); + MATX_ASSERT_STR(piv.Size(i) == a.Size(i), matxInvalidDim, "Piv and A must have the same batch sizes"); + } + + // Inner size checks + MATX_ASSERT_STR((out.Size(RANK-2) == params.m) && (out.Size(RANK-1) == params.n), matxInvalidSize, "Out and A shapes do not match"); + MATX_ASSERT_STR(piv.Size(RANK-2) == cuda::std::min(params.m, params.n), matxInvalidSize, "Piv must be ... x min(m,n)"); + + SetBatchPointers(out, this->batch_a_ptrs); + SetBatchPointers(piv, this->batch_piv_ptrs); + + if (out.Data() != a.Data()) { + (out = a).run(exec); + } + + cusolverDnSetStream(this->handle, exec.getStream()); + int info; + + // At this time cuSolver does not have a batched 64-bit LU interface. Change + // this to use the batched version once available. + for (size_t i = 0; i < this->batch_a_ptrs.size(); i++) { + auto ret = cusolverDnXgetrf( + this->handle, this->dn_params, params.m, params.n, MatXTypeToCudaType(), + this->batch_a_ptrs[i], params.m, this->batch_piv_ptrs[i], + MatXTypeToCudaType(), + reinterpret_cast(this->d_workspace) + i * this->dspace, this->dspace, + reinterpret_cast(this->h_workspace) + i * this->hspace, this->hspace, + this->d_info + i); + + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + + // This will block. 
Figure this out later + cudaMemcpy(&info, this->d_info + i, sizeof(info), cudaMemcpyDeviceToHost); + MATX_ASSERT(info == 0, matxSolverError); + } + } + + /** + * LU solver handle destructor + * + * Destroys any helper data used for provider type and any workspace memory + * created + * + */ + ~matxDnLUCUDAPlan_t() {} + +private: + std::vector batch_piv_ptrs; + DnLUCUDAParams_t params; +}; + +/** + * Crude hash to get a reasonably good delta for collisions. This doesn't need + * to be perfect, but fast enough to not slow down lookups, and different enough + * so the common solver parameters change + */ +struct DnLUCUDAParamsKeyHash { + std::size_t operator()(const DnLUCUDAParams_t &k) const noexcept + { + return (std::hash()(k.m)) + (std::hash()(k.n)) + + (std::hash()(k.batch_size)); + } +}; + +/** + * Test LU parameters for equality. Unlike the hash, all parameters must match. + */ +struct DnLUCUDAParamsKeyEq { + bool operator()(const DnLUCUDAParams_t &l, const DnLUCUDAParams_t &t) const noexcept + { + return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size && + l.dtype == t.dtype; + } +}; + +// Static caches of LU this->handles +using lu_cuda_cache_t = std::unordered_map; + +} // end namespace detail + + +/** + * Perform an LU decomposition + * + * See documentation of matxDnLUCUDAPlan_t for a description of how the + * algorithm works. This function provides a simple interface to the cuSolver + * library by deducing all parameters needed to perform an LU decomposition from + * only the matrix A. The input and output parameters may be the same tensor. In + * that case, the input is destroyed and the output is stored in-place. + * + * @tparam T1 + * Data type of matrix A + * @tparam RANK + * Rank of matrix A + * + * @param out + * Output tensor view + * @param piv + * Output of pivot indices + * @param a + * Input matrix A + * @param exec + * CUDA executor + */ +template +void lu_impl(OutputTensor &&out, PivotTensor &&piv, + const ATensor &a, const cudaExecutor &exec) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + + using T1 = typename remove_cvref_t::value_type; + + auto piv_new = OpToTensor(piv, exec); + auto a_new = OpToTensor(a, exec); + + if(!a_new.isSameView(a)) { + (a_new = a).run(exec); + } + + /* Temporary WAR + cuSolver doesn't support row-major layouts. Since we want to make the + library appear as though everything is row-major, we take a performance hit + to transpose in and out of the function. Eventually this may be fixed in + cuSolver. 
+ */ + T1 *tp; + matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY, + exec.getStream()); + auto tv = TransposeCopy(tp, a_new, exec); + auto tvt = tv.PermuteMatrix(); + + // Get parameters required by these tensors + auto params = detail::matxDnLUCUDAPlan_t::GetLUParams(piv_new, tvt); + + // Get cache or new LU plan if it doesn't exist + using cache_val_type = detail::matxDnLUCUDAPlan_t; + detail::GetCache().LookupAndExec( + detail::GetCacheIdFromType(), + params, + [&]() { + return std::make_shared(piv_new, tvt); + }, + [&](std::shared_ptr ctype) { + ctype->Exec(tvt, piv_new, tvt, exec); + } + ); + + /* Temporary WAR + * Copy and free async buffer for transpose */ + matx::copy(out, tv.PermuteMatrix(), exec); + matxFree(tp); +} + +} // end namespace matx \ No newline at end of file diff --git a/include/matx/transforms/lu/lu_lapack.h b/include/matx/transforms/lu/lu_lapack.h new file mode 100644 index 000000000..1d0fce91c --- /dev/null +++ b/include/matx/transforms/lu/lu_lapack.h @@ -0,0 +1,294 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "matx/core/error.h" +#include "matx/core/nvtx.h" +#include "matx/core/tensor.h" +#include "matx/core/cache.h" +#include "matx/executors/host.h" +#include "matx/executors/support.h" +#include "matx/transforms/solver_common.h" + +#include +#include + +namespace matx { + +namespace detail { + +#if MATX_EN_CPU_SOLVER +/** + * Parameters needed to execute an LU factorization. 
We distinguish unique + * factorizations mostly by the data pointer in A + */ +struct DnLUHostParams_t { + lapack_int_t m; + lapack_int_t n; + void *A; + void *piv; + size_t batch_size; + MatXDataType_t dtype; +}; + +template +class matxDnLUHostPlan_t : matxDnHostSolver_t { + using OutTensor_t = remove_cvref_t; + using T1 = typename ATensor::value_type; + using T2 = typename PivotTensor::value_type; + static constexpr int RANK = OutTensor_t::Rank(); + static_assert(RANK >= 2, "Input/Output tensor must be rank 2 or higher"); + using lapack_type = std::conditional_t>, lapack_scomplex_t, + std::conditional_t>, lapack_dcomplex_t, T1>>; + +public: + /** + * Plan for factoring A such that \f$\textbf{P} * \textbf{A} = \textbf{L} * + * \textbf{U}\f$ + * + * Creates a handle for factoring matrix A into the format above. Matrix must + * not be singular. This does not require a workspace. + * + * @tparam T1 + * Data type of A matrix + * @tparam T2 + * Data type of Pivot vector + * @tparam RANK + * Rank of A matrix + * + * @param piv + * Pivot indices + * @param a + * Input tensor view + * + */ + matxDnLUHostPlan_t(PivotTensor &piv, + const ATensor &a) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Dim checks + MATX_STATIC_ASSERT_STR(RANK-1 == PivotTensor::Rank(), matxInvalidDim, "Pivot tensor rank must be one less than output"); + MATX_STATIC_ASSERT_STR(RANK == ATensor::Rank(), matxInvalidDim, "Output tensor must match A tensor rank in LU"); + + // Type checks + MATX_STATIC_ASSERT_STR(!is_half_v, matxInvalidType, "LU solver does not support half precision"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "Input and Output types must match"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, + "Pivot tensor type must match the LAPACK host library integer type"); + + params = GetLUParams(piv, a); + } + + static DnLUHostParams_t GetLUParams(PivotTensor &piv, + const ATensor &a) noexcept + { + DnLUHostParams_t params; + params.batch_size = GetNumBatches(a); + params.m = static_cast(a.Size(RANK - 2)); + params.n = static_cast(a.Size(RANK - 1)); + params.A = a.Data(); + params.piv = piv.Data(); + params.dtype = TypeToInt(); + + return params; + } + + template + void Exec(OutputTensor &out, PivotTensor &piv, + const ATensor &a, const HostExecutor &exec) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Batch size checks + for(int i = 0 ; i < RANK-2; i++) { + MATX_ASSERT_STR(out.Size(i) == a.Size(i), matxInvalidDim, "Out and A must have the same batch sizes"); + MATX_ASSERT_STR(piv.Size(i) == a.Size(i), matxInvalidDim, "Piv and A must have the same batch sizes"); + } + + // Inner size checks + MATX_ASSERT_STR((out.Size(RANK-2) == params.m) && (out.Size(RANK-1) == params.n), matxInvalidSize, "Out and A shapes do not match"); + MATX_ASSERT_STR(piv.Size(RANK-2) == cuda::std::min(params.m, params.n), matxInvalidSize, "Piv must be ... 
x min(m,n)"); + + SetBatchPointers(out, this->batch_a_ptrs); + SetBatchPointers(piv, this->batch_piv_ptrs); + + if (out.Data() != a.Data()) { + (out = a).run(exec); + } + + lapack_int_t info; + for (size_t i = 0; i < this->batch_a_ptrs.size(); i++) { + getrf_dispatch(¶ms.m, ¶ms.n, reinterpret_cast(this->batch_a_ptrs[i]), + ¶ms.m, reinterpret_cast(this->batch_piv_ptrs[i]), &info); + + MATX_ASSERT(info == 0, matxSolverError); + } + } + + /** + * LU solver handle destructor + * + * Destroys any helper data used for provider type and any workspace memory + * created + * + */ + ~matxDnLUHostPlan_t() {} + +private: + void getrf_dispatch(const lapack_int_t* m, const lapack_int_t* n, lapack_type* a, + const lapack_int_t* lda, lapack_int_t* piv, lapack_int_t* info) + { + if constexpr (std::is_same_v) { + sgetrf_(m, n, a, lda, piv, info); + } else if constexpr (std::is_same_v) { + dgetrf_(m, n, a, lda, piv, info); + } else if constexpr (std::is_same_v) { + cgetrf_(m, n, a, lda, piv, info); + } else if constexpr (std::is_same_v) { + zgetrf_(m, n, a, lda, piv, info); + } + } + + std::vector batch_piv_ptrs; + DnLUHostParams_t params; +}; + +/** + * Crude hash to get a reasonably good delta for collisions. This doesn't need + * to be perfect, but fast enough to not slow down lookups, and different enough + * so the common solver parameters change + */ +struct DnLUHostParamsKeyHash { + std::size_t operator()(const DnLUHostParams_t &k) const noexcept + { + return (std::hash()(k.m)) + (std::hash()(k.n)) + + (std::hash()(k.batch_size)); + } +}; + +/** + * Test LU parameters for equality. Unlike the hash, all parameters must match. + */ +struct DnLUHostParamsKeyEq { + bool operator()(const DnLUHostParams_t &l, const DnLUHostParams_t &t) const noexcept + { + return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size && + l.dtype == t.dtype; + } +}; + +// Static caches of LU this->handles +using lu_Host_cache_t = std::unordered_map; +#endif + +} // end namespace detail + + +/** + * Perform an LU decomposition + * + * See documentation of matxDnLUHostPlan_t for a description of how the + * algorithm works. This function provides a simple interface to the LAPACK + * library by deducing all parameters needed to perform an LU decomposition from + * only the matrix A. The input and output parameters may be the same tensor. In + * that case, the input is destroyed and the output is stored in-place. + * + * @tparam T1 + * Data type of matrix A + * @tparam RANK + * Rank of matrix A + * + * @param out + * Output tensor view + * @param piv + * Output of pivot indices + * @param a + * Input matrix A + * @param exec + * Host Executor + */ +template +void lu_impl([[maybe_unused]] OutputTensor &&out, + [[maybe_unused]] PivotTensor &&piv, + [[maybe_unused]] const ATensor &a, + [[maybe_unused]] const HostExecutor &exec) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + MATX_ASSERT_STR(MATX_EN_CPU_SOLVER, matxInvalidExecutor, + "Trying to run a host Solver executor but host Solver support is not configured"); +#if MATX_EN_CPU_SOLVER + + using T1 = typename remove_cvref_t::value_type; + + auto piv_new = OpToTensor(piv, exec); + auto a_new = OpToTensor(a, exec); + + if(!a_new.isSameView(a)) { + (a_new = a).run(exec); + } + + /* Temporary WAR + LAPACK doesn't support row-major layouts. Since we want to make the + library appear as though everything is row-major, we take a performance hit + to transpose in and out of the function. LAPACKE, however, supports both formats. 
+ */ + T1 *tp; + matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_HOST_MALLOC_MEMORY); + auto tv = TransposeCopy(tp, a_new, exec); + auto tvt = tv.PermuteMatrix(); + + // Get parameters required by these tensors + auto params = detail::matxDnLUHostPlan_t::GetLUParams(piv_new, tvt); + + // Get cache or new LU plan if it doesn't exist + using cache_val_type = detail::matxDnLUHostPlan_t; + detail::GetCache().LookupAndExec( + detail::GetCacheIdFromType(), + params, + [&]() { + return std::make_shared(piv_new, tvt); + }, + [&](std::shared_ptr ctype) { + ctype->Exec(tvt, piv_new, tvt, exec); + } + ); + + /* Temporary WAR + * Copy and free async buffer for transpose */ + matx::copy(out, tv.PermuteMatrix(), exec); + matxFree(tp); +#endif +} + +} // end namespace matx \ No newline at end of file diff --git a/include/matx/transforms/qr.h b/include/matx/transforms/qr.h deleted file mode 100644 index 7fce7bd92..000000000 --- a/include/matx/transforms/qr.h +++ /dev/null @@ -1,224 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// BSD 3-Clause License -// -// Copyright (c) 2021, NVIDIA Corporation -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-///////////////////////////////////////////////////////////////////////////////// - -#pragma once - -#include "matx/core/error.h" -#include "matx/core/nvtx.h" -#include "matx/core/tensor.h" -#include "matx/operators/slice.h" -#include -#include - -namespace matx { - -namespace detail { - - template - inline auto qr_internal_workspace(const AType &A, cudaStream_t stream) { - using ATypeS = typename AType::value_type; - const int RANK = AType::Rank(); - - index_t m = A.Size(RANK-2); - index_t n = A.Size(RANK-1); - - cuda::std::array uShape; - for(int i = 0; i < RANK-2; i++) { - uShape[i] = A.Size(i); - } - uShape[RANK-2] = A.Size(RANK-1); - - auto QShape = A.Shape(); - QShape[RANK-1] = m; - - auto Qin = make_tensor(QShape, MATX_ASYNC_DEVICE_MEMORY, stream); - auto wwt = make_tensor(QShape, MATX_ASYNC_DEVICE_MEMORY, stream); - auto u = make_tensor(uShape, MATX_ASYNC_DEVICE_MEMORY, stream); - - return cuda::std::make_tuple(Qin, wwt, u); - } - - template - inline void qr_internal(QType &Q, RType &R, const AType &A, WType workspace, const cudaExecutor &exec) { - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - const auto stream = exec.getStream(); - - static_assert(AType::Rank() >= 2); - static_assert(QType::Rank() == AType::Rank()); - static_assert(RType::Rank() == AType::Rank()); - - MATX_ASSERT_STR(AType::Rank() == QType::Rank(), matxInvalidDim, "qr: A and Q must have the same rank"); - MATX_ASSERT_STR(AType::Rank() == RType::Rank(), matxInvalidDim, "qr: A and R must have the same rank"); - - using ATypeS = typename AType::value_type; - using NTypeS = typename inner_op_type_t::type; - const int RANK = AType::Rank(); - - index_t m = A.Size(RANK-2); - index_t n = A.Size(RANK-1); - index_t k = cuda::std::min(m,n); - if(m<=n) k--; // these matrices have one less update since the diagonal ends on the bottom of the matrix - - auto Qin = cuda::std::get<0>(workspace); - auto wwt = cuda::std::get<1>(workspace); - auto u = cuda::std::get<2>(workspace); - - static_assert(decltype(Qin)::Rank() == QType::Rank()); - static_assert(decltype(wwt)::Rank() == QType::Rank()); - static_assert(decltype(u)::Rank() == QType::Rank()-1); - - // Create Identity matrix - auto E = eye({m, m}); - - // Clone over batch Dims - auto ECShape = Q.Shape(); - ECShape[RANK-1] = matxKeepDim; - ECShape[RANK-2] = matxKeepDim; - - auto I = clone(E, ECShape); - - // Inititalize Q - (Q = I).run(stream); - (R = A).run(stream); - - // we will slice X directly from R. - cuda::std::array xSliceB, xSliceE; - xSliceB.fill(0); xSliceE.fill(matxEnd); - xSliceE[RANK-1] = matxDropDim; // drop last dim to make a vector - - - // v is of size m x 1. Instead of allocating additional memory we will just reuse a row of Qin - cuda::std::array vSliceB, vSliceE; - vSliceB.fill(0); vSliceE.fill(matxEnd); - // select a single row of Q to alias as v - vSliceE[RANK-2] = matxDropDim; - auto v = slice(Qin, vSliceB, vSliceE); - auto xz = v; // alias - - - // N is of size 1. Instead of allocating additional memory we will just reuse an entry of Qin - cuda::std::array nSliceB, nSliceE; - nSliceB.fill(0); nSliceE.fill(matxEnd); - // select a single row of Q to alias as v - nSliceE[RANK-2] = matxDropDim; - nSliceE[RANK-1] = matxDropDim; - - auto N = slice(wwt, nSliceB, nSliceE); - - // N cloned with RANK-2 of size m. 
- cuda::std::array ncShape; - ncShape.fill(matxKeepDim); - ncShape[RANK-2] = m; - auto nc = clone(N,ncShape); - - // aliasing some memory here to share storage and provide clarity in the code below - auto s = N; // alias - auto sc = nc; // alias - auto w = v; // alias - - for(int i = 0 ; i < k ; i++) { - - // slice off a column of R and alias as x - xSliceB[RANK-1] = i; - auto x = slice(R, xSliceB, xSliceE); - - // operator which zeros out values above current index in matrix - (xz = (index(x.Rank()-1) >= i) * x).run(stream); - - // compute L2 norm without sqrt. - (N = sum(abs2(xz))).run(stream); - //(N = sqrt(N)).run(stream); // sqrt folded into next op - - (v = xz + (index(v.Rank()-1) == i) * sign(xz) * sqrt(nc)).run(stream); - - auto r = x; // alias column of R happens to be the same as x - - (s = sum(abs2(v))).run(stream); - //(s = sqrt(s)).run(stream); // sqrt folded into next op - - // IFELSE to avoid nans when dividing by zero - (IFELSE(sc != NTypeS(0), - w = (v / sqrt(sc)), - w = NTypeS(0))).run(stream); - - (u = matvec(conj(transpose_matrix(R)), w, 2 , 0)).run(stream); - - (R = outer(w, conj(u), -1, 1)).run(stream); - - // entries below diagonal should be numerical zero. Zero them out to avoid additional FP error. - (IF(index(x.Rank()-1) > i, r = ATypeS(0)) ).run(stream); - - (wwt = outer(w, conj(w))).run(stream); - - (Qin = Q).run(stream); // save input - matmul_impl(Q, Qin, wwt, exec, -2, 1); - - } - } -} // end namespace detail - -/** - * Perform QR decomposition on a matrix using housholders reflections. If rank > 2 operations are batched. - * - * @tparam QType - * Tensor or operator type for output of Q matrix or tensor output. - * @tparam RType - * Tensor or operator type for output of R matrix - * @tparam AType - * Tensor or operator type for output of A input tensors. - * - * @param Q - * Q output tensor or operator. - * @param R - * R output tensor or operator. - * @param A - * Input tensor or operator for tensor A input. - * @param exec - * CUDA executor - */ -template -inline void qr_impl(QType &Q, RType &R, const AType &A, const cudaExecutor &exec) { - const auto stream = exec.getStream(); - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - static_assert(AType::Rank() >= 2); - static_assert(QType::Rank() == AType::Rank()); - static_assert(RType::Rank() == AType::Rank()); - - MATX_ASSERT_STR(AType::Rank() == QType::Rank(), matxInvalidDim, "qr: A and Q must have the same rank"); - MATX_ASSERT_STR(AType::Rank() == RType::Rank(), matxInvalidDim, "qr: A and R must have the same rank"); - - auto workspace = qr_internal_workspace(A, stream); - qr_internal(Q,R,A,workspace,exec); -} - -} // end namespace matx diff --git a/include/matx/transforms/qr/qr_cuda.h b/include/matx/transforms/qr/qr_cuda.h new file mode 100644 index 000000000..c05ffa533 --- /dev/null +++ b/include/matx/transforms/qr/qr_cuda.h @@ -0,0 +1,482 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. 
Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "cublas_v2.h" +#include "cusolverDn.h" + +#include "matx/core/error.h" +#include "matx/core/nvtx.h" +#include "matx/core/tensor.h" +#include "matx/core/cache.h" +#include "matx/operators/slice.h" +#include "matx/transforms/solver_common.h" + +#include +#include + +namespace matx { + +namespace detail { + + template + inline auto qr_internal_workspace(const AType &A, cudaStream_t stream) { + using ATypeS = typename AType::value_type; + const int RANK = AType::Rank(); + + index_t m = A.Size(RANK-2); + index_t n = A.Size(RANK-1); + + cuda::std::array uShape; + for(int i = 0; i < RANK-2; i++) { + uShape[i] = A.Size(i); + } + uShape[RANK-2] = A.Size(RANK-1); + + auto QShape = A.Shape(); + QShape[RANK-1] = m; + + auto Qin = make_tensor(QShape, MATX_ASYNC_DEVICE_MEMORY, stream); + auto wwt = make_tensor(QShape, MATX_ASYNC_DEVICE_MEMORY, stream); + auto u = make_tensor(uShape, MATX_ASYNC_DEVICE_MEMORY, stream); + + return cuda::std::make_tuple(Qin, wwt, u); + } + + template + inline void qr_internal(QType &Q, RType &R, const AType &A, WType workspace, const cudaExecutor &exec) { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + const auto stream = exec.getStream(); + + static_assert(AType::Rank() >= 2); + static_assert(QType::Rank() == AType::Rank()); + static_assert(RType::Rank() == AType::Rank()); + + MATX_ASSERT_STR(AType::Rank() == QType::Rank(), matxInvalidDim, "qr: A and Q must have the same rank"); + MATX_ASSERT_STR(AType::Rank() == RType::Rank(), matxInvalidDim, "qr: A and R must have the same rank"); + + using ATypeS = typename AType::value_type; + using NTypeS = typename inner_op_type_t::type; + const int RANK = AType::Rank(); + + index_t m = A.Size(RANK-2); + index_t n = A.Size(RANK-1); + index_t k = cuda::std::min(m,n); + if(m<=n) k--; // these matrices have one less update since the diagonal ends on the bottom of the matrix + + auto Qin = cuda::std::get<0>(workspace); + auto wwt = cuda::std::get<1>(workspace); + auto u = cuda::std::get<2>(workspace); + + static_assert(decltype(Qin)::Rank() == QType::Rank()); + static_assert(decltype(wwt)::Rank() == QType::Rank()); + static_assert(decltype(u)::Rank() == QType::Rank()-1); + + // Create Identity matrix + auto E = eye({m, m}); + + 
// Clone over batch Dims + auto ECShape = Q.Shape(); + ECShape[RANK-1] = matxKeepDim; + ECShape[RANK-2] = matxKeepDim; + + auto I = clone(E, ECShape); + + // Inititalize Q + (Q = I).run(stream); + (R = A).run(stream); + + // we will slice X directly from R. + cuda::std::array xSliceB, xSliceE; + xSliceB.fill(0); xSliceE.fill(matxEnd); + xSliceE[RANK-1] = matxDropDim; // drop last dim to make a vector + + + // v is of size m x 1. Instead of allocating additional memory we will just reuse a row of Qin + cuda::std::array vSliceB, vSliceE; + vSliceB.fill(0); vSliceE.fill(matxEnd); + // select a single row of Q to alias as v + vSliceE[RANK-2] = matxDropDim; + auto v = slice(Qin, vSliceB, vSliceE); + auto xz = v; // alias + + + // N is of size 1. Instead of allocating additional memory we will just reuse an entry of Qin + cuda::std::array nSliceB, nSliceE; + nSliceB.fill(0); nSliceE.fill(matxEnd); + // select a single row of Q to alias as v + nSliceE[RANK-2] = matxDropDim; + nSliceE[RANK-1] = matxDropDim; + + auto N = slice(wwt, nSliceB, nSliceE); + + // N cloned with RANK-2 of size m. + cuda::std::array ncShape; + ncShape.fill(matxKeepDim); + ncShape[RANK-2] = m; + auto nc = clone(N,ncShape); + + // aliasing some memory here to share storage and provide clarity in the code below + auto s = N; // alias + auto sc = nc; // alias + auto w = v; // alias + + for(int i = 0 ; i < k ; i++) { + + // slice off a column of R and alias as x + xSliceB[RANK-1] = i; + auto x = slice(R, xSliceB, xSliceE); + + // operator which zeros out values above current index in matrix + (xz = (index(x.Rank()-1) >= i) * x).run(stream); + + // compute L2 norm without sqrt. + (N = sum(abs2(xz))).run(stream); + //(N = sqrt(N)).run(stream); // sqrt folded into next op + + (v = xz + (index(v.Rank()-1) == i) * sign(xz) * sqrt(nc)).run(stream); + + auto r = x; // alias column of R happens to be the same as x + + (s = sum(abs2(v))).run(stream); + //(s = sqrt(s)).run(stream); // sqrt folded into next op + + // IFELSE to avoid nans when dividing by zero + (IFELSE(sc != NTypeS(0), + w = (v / sqrt(sc)), + w = NTypeS(0))).run(stream); + + (u = matvec(conj(transpose_matrix(R)), w, 2 , 0)).run(stream); + + (R = outer(w, conj(u), -1, 1)).run(stream); + + // entries below diagonal should be numerical zero. Zero them out to avoid additional FP error. + (IF(index(x.Rank()-1) > i, r = ATypeS(0)) ).run(stream); + + (wwt = outer(w, conj(w))).run(stream); + + (Qin = Q).run(stream); // save input + matmul_impl(Q, Qin, wwt, exec, -2, 1); + + } + } +} // end namespace detail + + +/** + * Perform QR decomposition on a matrix using housholders reflections. If rank > 2, operations are batched. + * + * @tparam QType + * Tensor or operator type for output of Q matrix or tensor output. + * @tparam RType + * Tensor or operator type for output of R matrix + * @tparam AType + * Tensor or operator type for output of A input tensors. + * + * @param Q + * Q output tensor or operator. + * @param R + * R output tensor or operator. + * @param A + * Input tensor or operator for tensor A input. 
+ * @param exec + * CUDA executor + */ +template +inline void qr_impl(QType &Q, RType &R, const AType &A, const cudaExecutor &exec) { + const auto stream = exec.getStream(); + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + static_assert(AType::Rank() >= 2); + static_assert(QType::Rank() == AType::Rank()); + static_assert(RType::Rank() == AType::Rank()); + + MATX_ASSERT_STR(AType::Rank() == QType::Rank(), matxInvalidDim, "qr: A and Q must have the same rank"); + MATX_ASSERT_STR(AType::Rank() == RType::Rank(), matxInvalidDim, "qr: A and R must have the same rank"); + + auto workspace = detail::qr_internal_workspace(A, stream); + detail::qr_internal(Q,R,A,workspace,exec); +} + + +/********************************************** SOLVER QR + * *********************************************/ + +namespace detail { + +/** + * Parameters needed to execute a QR factorization. We distinguish unique + * factorizations mostly by the data pointer in A + */ +struct DnQRCUDAParams_t { + int64_t m; + int64_t n; + void *A; + void *tau; + size_t batch_size; + MatXDataType_t dtype; +}; + +template +class matxDnQRCUDAPlan_t : matxDnCUDASolver_t { + using OutTensor_t = remove_cvref_t; + using T1 = typename ATensor::value_type; + using T2 = typename TauTensor::value_type; + static constexpr int RANK = OutTensor_t::Rank(); + static_assert(RANK >= 2, "Input/Output tensor must be rank 2 or higher"); + +public: + /** + * Plan for factoring A such that \f$\textbf{A} = \textbf{Q} * \textbf{R}\f$ + * + * Creates a handle for factoring matrix A into the format above. QR + * decomposition in cuBLAS/cuSolver does not return the Q matrix directly, and + * it must be computed separately used the Householder reflections in the tau + * output, along with the overwritten A matrix input. The input and output + * parameters may be the same tensor. In that case, the input is destroyed and + * the output is stored in-place. 
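+ *
+ * In the geqrf convention used here, R is returned in the upper triangle of the
+ * overwritten A, while the Householder vectors that implicitly define Q are
+ * stored below the diagonal; Q can later be materialized or applied from those
+ * vectors together with tau (an orgqr/ungqr-style step).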
+ * + * @tparam T1 + * Data type of A matrix + * @tparam T2 + * Data type of Tau vector + * @tparam RANK + * Rank of A matrix + * + * @param tau + * Scaling factors for reflections + * @param a + * Input tensor view + * + */ + matxDnQRCUDAPlan_t(TauTensor &tau, + const ATensor &a) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Dim checks + MATX_STATIC_ASSERT_STR(RANK-1 == TauTensor::Rank(), matxInvalidDim, "Tau tensor must be one rank less than output tensor"); + MATX_STATIC_ASSERT_STR(RANK == ATensor::Rank(), matxInvalidDim, "Output tensor must match A tensor rank in QR"); + + // Type checks + MATX_STATIC_ASSERT_STR(!is_half_v, matxInvalidType, "QR solver does not support half precision"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "Input and Output types must match"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "A and Tau types must match"); + + params = GetQRParams(tau, a); + this->GetWorkspaceSize(); + this->AllocateWorkspace(params.batch_size); + } + + void GetWorkspaceSize() override + { + cusolverStatus_t ret = cusolverDnXgeqrf_bufferSize( + this->handle, this->dn_params, params.m, params.n, MatXTypeToCudaType(), + params.A, params.m, MatXTypeToCudaType(), params.tau, + MatXTypeToCudaType(), &this->dspace, &this->hspace); + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + } + + static DnQRCUDAParams_t GetQRParams(TauTensor &tau, + const ATensor &a) + { + DnQRCUDAParams_t params; + + params.batch_size = GetNumBatches(a); + params.m = a.Size(RANK - 2); + params.n = a.Size(RANK - 1); + params.A = a.Data(); + params.tau = tau.Data(); + params.dtype = TypeToInt(); + + return params; + } + + void Exec(OutTensor &out, TauTensor &tau, + const ATensor &a, const cudaExecutor &exec) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Batch size checks + for(int i = 0 ; i < RANK-2; i++) { + MATX_ASSERT_STR(out.Size(i) == a.Size(i), matxInvalidDim, "Out and A must have the same batch sizes"); + MATX_ASSERT_STR(tau.Size(i) == a.Size(i), matxInvalidDim, "Tau and A must have the same batch sizes"); + } + + // Inner size checks + MATX_ASSERT_STR((out.Size(RANK-2) == params.m) && (out.Size(RANK-1) == params.n), matxInvalidSize, "Out and A shapes do not match"); + MATX_ASSERT_STR(tau.Size(RANK-2) == cuda::std::min(params.m, params.n), matxInvalidSize, "Tau must be ... x min(m,n)"); + + SetBatchPointers(out, this->batch_a_ptrs); + SetBatchPointers(tau, this->batch_tau_ptrs); + + if (out.Data() != a.Data()) { + (out = a).run(exec); + } + + cusolverDnSetStream(this->handle, exec.getStream()); + int info; + + // At this time cuSolver does not have a batched 64-bit LU interface. Change + // this to use the batched version once available. + for (size_t i = 0; i < this->batch_a_ptrs.size(); i++) { + auto ret = cusolverDnXgeqrf( + this->handle, this->dn_params, params.m, params.n, MatXTypeToCudaType(), + this->batch_a_ptrs[i], params.m, MatXTypeToCudaType(), + this->batch_tau_ptrs[i], MatXTypeToCudaType(), + reinterpret_cast(this->d_workspace) + i * this->dspace, this->dspace, + reinterpret_cast(this->h_workspace) + i * this->hspace, this->hspace, + this->d_info + i); + + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + + // This will block. 
Figure this out later + cudaMemcpy(&info, this->d_info + i, sizeof(info), cudaMemcpyDeviceToHost); + MATX_ASSERT(info == 0, matxSolverError); + } + } + + /** + * QR solver handle destructor + * + * Destroys any helper data used for provider type and any workspace memory + * created + * + */ + ~matxDnQRCUDAPlan_t() {} + +private: + std::vector batch_tau_ptrs; + DnQRCUDAParams_t params; +}; + +/** + * Crude hash to get a reasonably good delta for collisions. This doesn't need + * to be perfect, but fast enough to not slow down lookups, and different enough + * so the common solver parameters change + */ +struct DnQRCUDAParamsKeyHash { + std::size_t operator()(const DnQRCUDAParams_t &k) const noexcept + { + return (std::hash()(k.m)) + (std::hash()(k.n)) + + (std::hash()(k.batch_size)); + } +}; + +/** + * Test QR parameters for equality. Unlike the hash, all parameters must match. + */ +struct DnQRCUDAParamsKeyEq { + bool operator()(const DnQRCUDAParams_t &l, const DnQRCUDAParams_t &t) const noexcept + { + return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size && + l.dtype == t.dtype; + } +}; + +using qr_cuda_cache_t = std::unordered_map; + +} // end namespace detail + +/** + * Perform a QR decomposition using a cached plan + * + * See documentation of matxDnQRCUDAPlan_t for a description of how the + * algorithm works. This function provides a simple interface to the cuSolver + * library by deducing all parameters needed to perform a QR decomposition from + * only the matrix A. The input and output parameters may be the same tensor. In + * that case, the input is destroyed and the output is stored in-place. + * + * @tparam T1 + * Data type of matrix A + * @tparam RANK + * Rank of matrix A + * + * @param out + * Output tensor view + * @param tau + * Output of reflection scalar values + * @param a + * Input tensor A + * @param exec + * CUDA executor + */ +template +void qr_solver_impl(OutTensor &&out, TauTensor &&tau, + const ATensor &a, const cudaExecutor &exec) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + using T1 = typename remove_cvref_t::value_type; + + auto tau_new = OpToTensor(tau, exec); + auto a_new = OpToTensor(a, exec); + + if(!a_new.isSameView(a)) { + (a_new = a).run(exec); + } + + /* Temporary WAR + cuSolver doesn't support row-major layouts. Since we want to make the + library appear as though everything is row-major, we take a performance hit + to transpose in and out of the function. Eventually this may be fixed in + cuSolver. 
+ */ + T1 *tp; + matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY, + exec.getStream()); + auto tv = TransposeCopy(tp, a_new, exec); + auto tvt = tv.PermuteMatrix(); + + // Get parameters required by these tensors + auto params = detail::matxDnQRCUDAPlan_t::GetQRParams(tau_new, tvt); + + // Get cache or new QR plan if it doesn't exist + using cache_val_type = detail::matxDnQRCUDAPlan_t; + detail::GetCache().LookupAndExec( + detail::GetCacheIdFromType(), + params, + [&]() { + return std::make_shared(tau_new, tvt); + }, + [&](std::shared_ptr ctype) { + ctype->Exec(tvt, tau_new, tvt, exec); + } + ); + + /* Temporary WAR + * Copy and free async buffer for transpose */ + matx::copy(out, tv.PermuteMatrix(), exec); + matxFree(tp); +} + +} // end namespace matx diff --git a/include/matx/transforms/qr/qr_lapack.h b/include/matx/transforms/qr/qr_lapack.h new file mode 100644 index 000000000..16284a702 --- /dev/null +++ b/include/matx/transforms/qr/qr_lapack.h @@ -0,0 +1,318 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "matx/core/error.h" +#include "matx/core/nvtx.h" +#include "matx/core/tensor.h" +#include "matx/core/cache.h" +#include "matx/operators/slice.h" +#include "matx/executors/host.h" +#include "matx/executors/support.h" +#include "matx/transforms/solver_common.h" + +#include +#include + +namespace matx { + +namespace detail { + +#if MATX_EN_CPU_SOLVER +/** + * Parameters needed to execute a QR factorization. 
We distinguish unique + * factorizations mostly by the data pointer in A + */ +struct DnQRHostParams_t { + lapack_int_t m; + lapack_int_t n; + void *A; + void *tau; + size_t batch_size; + MatXDataType_t dtype; +}; + +template +class matxDnQRHostPlan_t : matxDnHostSolver_t { + using OutTensor_t = remove_cvref_t; + using T1 = typename ATensor::value_type; + using T2 = typename TauTensor::value_type; + static constexpr int RANK = OutTensor_t::Rank(); + static_assert(RANK >= 2, "Input/Output tensor must be rank 2 or higher"); + using lapack_type = std::conditional_t>, lapack_scomplex_t, + std::conditional_t>, lapack_dcomplex_t, T1>>; +public: + /** + * Plan for factoring A such that \f$\textbf{A} = \textbf{Q} * \textbf{R}\f$ + * + * Creates a handle for factoring matrix A into the format above. QR + * decomposition in LAPACK does not return the Q matrix directly, and + * it must be computed separately used the Householder reflections in the tau + * output, along with the overwritten A matrix input. The input and output + * parameters may be the same tensor. In that case, the input is destroyed and + * the output is stored in-place. + * + * @tparam T1 + * Data type of A matrix + * @tparam T2 + * Data type of Tau vector + * @tparam RANK + * Rank of A matrix + * + * @param tau + * Scaling factors for reflections + * @param a + * Input tensor view + * + */ + matxDnQRHostPlan_t(TauTensor &tau, + const ATensor &a) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Dim checks + MATX_STATIC_ASSERT_STR(RANK-1 == TauTensor::Rank(), matxInvalidDim, "Tau tensor must be one rank less than output tensor"); + MATX_STATIC_ASSERT_STR(RANK == ATensor::Rank(), matxInvalidDim, "Output tensor must match A tensor rank in QR"); + + // Type checks + MATX_STATIC_ASSERT_STR(!is_half_v, matxInvalidType, "QR solver does not support half precision"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "Input and Output types must match"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "A and Tau types must match"); + + params = GetQRParams(tau, a); + this->GetWorkspaceSize(); + this->AllocateWorkspace(params.batch_size); + } + + void GetWorkspaceSize() override + { + // perform workspace query with lwork = -1 + lapack_int_t info; + lapack_type work_query; + + geqrf_dispatch(¶ms.m, ¶ms.n, nullptr, + ¶ms.m, nullptr, + &work_query, &this->lwork, &info); + MATX_ASSERT(info == 0, matxSolverError); + + // the real part of the first elem of work holds the optimal lwork + if constexpr (is_complex_v) { + this->lwork = static_cast(work_query.real); + } else { + this->lwork = static_cast(work_query); + } + } + + static DnQRHostParams_t GetQRParams(TauTensor &tau, + const ATensor &a) + { + DnQRHostParams_t params; + + params.batch_size = GetNumBatches(a); + params.m = static_cast(a.Size(RANK - 2)); + params.n = static_cast(a.Size(RANK - 1)); + params.A = a.Data(); + params.tau = tau.Data(); + params.dtype = TypeToInt(); + + return params; + } + + template + void Exec(OutTensor &out, TauTensor &tau, + const ATensor &a, const HostExecutor &exec) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Batch size checks + for(int i = 0 ; i < RANK-2; i++) { + MATX_ASSERT_STR(out.Size(i) == a.Size(i), matxInvalidDim, "Out and A must have the same batch sizes"); + MATX_ASSERT_STR(tau.Size(i) == a.Size(i), matxInvalidDim, "Tau and A must have the same batch sizes"); + } + + // Inner size checks + MATX_ASSERT_STR((out.Size(RANK-2) == params.m) && (out.Size(RANK-1) == params.n), matxInvalidSize, "Out and 
A shapes do not match"); + MATX_ASSERT_STR(tau.Size(RANK-2) == cuda::std::min(params.m, params.n), matxInvalidSize, "Tau must be ... x min(m,n)"); + + SetBatchPointers(out, this->batch_a_ptrs); + SetBatchPointers(tau, this->batch_tau_ptrs); + + if (out.Data() != a.Data()) { + (out = a).run(exec); + } + + lapack_int_t info; + for (size_t i = 0; i < this->batch_a_ptrs.size(); i++) { + geqrf_dispatch(&params.m, &params.n, reinterpret_cast(this->batch_a_ptrs[i]), + &params.m, reinterpret_cast(this->batch_tau_ptrs[i]), + reinterpret_cast(this->work), &this->lwork, &info); + + MATX_ASSERT(info == 0, matxSolverError); + } + } + + /** + * QR solver handle destructor + * + * Destroys any helper data used for provider type and any workspace memory + * created + * + */ + ~matxDnQRHostPlan_t() {} + +private: + void geqrf_dispatch(const lapack_int_t* m, const lapack_int_t* n, lapack_type* a, + const lapack_int_t* lda, lapack_type* tau, lapack_type* work_in, + const lapack_int_t* lwork_in, lapack_int_t* info) + { + if constexpr (std::is_same_v) { + sgeqrf_(m, n, a, lda, tau, work_in, lwork_in, info); + } else if constexpr (std::is_same_v) { + dgeqrf_(m, n, a, lda, tau, work_in, lwork_in, info); + } else if constexpr (std::is_same_v) { + cgeqrf_(m, n, a, lda, tau, work_in, lwork_in, info); + } else if constexpr (std::is_same_v) { + zgeqrf_(m, n, a, lda, tau, work_in, lwork_in, info); + } + } + + std::vector batch_tau_ptrs; + DnQRHostParams_t params; +}; + +/** + * Crude hash to get a reasonably good delta for collisions. This doesn't need + * to be perfect, but fast enough to not slow down lookups, and different enough + * so the common solver parameters change + */ +struct DnQRHostParamsKeyHash { + std::size_t operator()(const DnQRHostParams_t &k) const noexcept + { + return (std::hash()(k.m)) + (std::hash()(k.n)) + + (std::hash()(k.batch_size)); + } +}; + +/** + * Test QR parameters for equality. Unlike the hash, all parameters must match. + */ +struct DnQRHostParamsKeyEq { + bool operator()(const DnQRHostParams_t &l, const DnQRHostParams_t &t) const noexcept + { + return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size && + l.dtype == t.dtype; + } +}; + +using qr_Host_cache_t = std::unordered_map; +#endif + +} // end namespace detail + +/** + * Perform a QR decomposition using a cached plan + * + * See documentation of matxDnQRHostPlan_t for a description of how the + * algorithm works. This function provides a simple interface to a LAPACK + * library by deducing all parameters needed to perform a QR decomposition from + * only the matrix A. The input and output parameters may be the same tensor. In + * that case, the input is destroyed and the output is stored in-place. 
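A rough usage sketch of the host QR path described above (the shapes, the wrapper function, and the concrete HostExecutor type are illustrative assumptions, not part of this diff):

    #include "matx.h"

    // Minimal sketch: factor one 8x8 float matrix on the CPU via the LAPACK-backed plan.
    void qr_on_host_example() {
      auto a   = matx::make_tensor<float>({8, 8});
      auto out = matx::make_tensor<float>({8, 8});
      auto tau = matx::make_tensor<float>({8});   // min(m, n) Householder scalars
      matx::HostExecutor exec{};                  // executor type name assumed
      // ... fill a with data ...
      matx::qr_solver_impl(out, tau, a, exec);    // out receives R and the reflectors, tau the scalars
    }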
+ * + * @tparam T1 + * Data type of matrix A + * @tparam RANK + * Rank of matrix A + * + * @param out + * Output tensor view + * @param tau + * Output of reflection scalar values + * @param a + * Input tensor A + * @param exec + * Host executor + */ +template +void qr_solver_impl([[maybe_unused]] OutTensor &&out, + [[maybe_unused]] TauTensor &&tau, + [[maybe_unused]] const ATensor &a, + [[maybe_unused]] const HostExecutor &exec) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + MATX_ASSERT_STR(MATX_EN_CPU_SOLVER, matxInvalidExecutor, + "Trying to run a host Solver executor but host Solver support is not configured"); +#if MATX_EN_CPU_SOLVER + + using T1 = typename remove_cvref_t::value_type; + + auto tau_new = OpToTensor(tau, exec); + auto a_new = OpToTensor(a, exec); + + if(!a_new.isSameView(a)) { + (a_new = a).run(exec); + } + + /* Temporary WAR + LAPACK doesn't support row-major layouts. Since we want to make the + library appear as though everything is row-major, we take a performance hit + to transpose in and out of the function. LAPACKE, however, supports both formats. + */ + T1 *tp; + matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_HOST_MALLOC_MEMORY); + auto tv = TransposeCopy(tp, a_new, exec); + auto tvt = tv.PermuteMatrix(); + + // Get parameters required by these tensors + auto params = detail::matxDnQRHostPlan_t::GetQRParams(tau_new, tvt); + + // Get cache or new QR plan if it doesn't exist + using cache_val_type = detail::matxDnQRHostPlan_t; + detail::GetCache().LookupAndExec( + detail::GetCacheIdFromType(), + params, + [&]() { + return std::make_shared(tau_new, tvt); + }, + [&](std::shared_ptr ctype) { + ctype->Exec(tvt, tau_new, tvt, exec); + } + ); + + /* Temporary WAR + * Copy and free async buffer for transpose */ + matx::copy(out, tv.PermuteMatrix(), exec); + matxFree(tp); +#endif +} + +} // end namespace matx diff --git a/include/matx/transforms/reduce.h b/include/matx/transforms/reduce.h index 431a25612..b143b8fdc 100644 --- a/include/matx/transforms/reduce.h +++ b/include/matx/transforms/reduce.h @@ -1530,7 +1530,7 @@ void __MATX_INLINE__ mean_impl(OutType dest, const InType &in, * Single thread host executor */ template -void __MATX_INLINE__ mean_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor &exec) +void __MATX_INLINE__ mean_impl(OutType dest, const InType &in, [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("mean_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API) @@ -1796,7 +1796,7 @@ void __MATX_INLINE__ median_impl(OutType dest, * Single thread host executor */ template -void __MATX_INLINE__ median_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor &exec) +void __MATX_INLINE__ median_impl(OutType dest, const InType &in, [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("median_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API) auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) { @@ -1888,7 +1888,7 @@ void __MATX_INLINE__ sum_impl(OutType dest, const InType &in, cudaExecutor exec * Single thread host executor */ template -void __MATX_INLINE__ sum_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor &exec) +void __MATX_INLINE__ sum_impl(OutType dest, const InType &in, [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("sum_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API) auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) { @@ -1957,7 +1957,7 @@ void 
__MATX_INLINE__ prod_impl(OutType dest, const InType &in, cudaExecutor exec * Single thread host executor */ template -void __MATX_INLINE__ prod_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor &exec) +void __MATX_INLINE__ prod_impl(OutType dest, const InType &in, [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("prod_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API) auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) { @@ -2034,7 +2034,7 @@ void __MATX_INLINE__ max_impl(OutType dest, const InType &in, cudaExecutor exec * Single threaded host executor */ template -void __MATX_INLINE__ max_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor &exec) +void __MATX_INLINE__ max_impl(OutType dest, const InType &in, [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("max_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API) @@ -2112,7 +2112,7 @@ void __MATX_INLINE__ argmax_impl(OutType dest, TensorIndexType &idest, const InT * Single threaded host executor */ template -void __MATX_INLINE__ argmax_impl(OutType dest, TensorIndexType &idest, const InType &in, [[maybe_unused]] HostExecutor &exec) +void __MATX_INLINE__ argmax_impl(OutType dest, TensorIndexType &idest, const InType &in, [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("argmax_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API) @@ -2185,7 +2185,7 @@ void __MATX_INLINE__ min_impl(OutType dest, const InType &in, cudaExecutor exec * Single threaded host executor */ template -void __MATX_INLINE__ min_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor &exec) +void __MATX_INLINE__ min_impl(OutType dest, const InType &in, [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("min_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API) auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) { @@ -2259,7 +2259,7 @@ void __MATX_INLINE__ argmin_impl(OutType dest, TensorIndexType &idest, const InT * SIngle host executor */ template -void __MATX_INLINE__ argmin_impl(OutType dest, TensorIndexType &idest, const InType &in, [[maybe_unused]] HostExecutor &exec) +void __MATX_INLINE__ argmin_impl(OutType dest, TensorIndexType &idest, const InType &in, [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("argmin_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API) @@ -2332,7 +2332,7 @@ void __MATX_INLINE__ any_impl(OutType dest, const InType &in, cudaExecutor exec * Single threaded host executor */ template -void __MATX_INLINE__ any_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor &exec) +void __MATX_INLINE__ any_impl(OutType dest, const InType &in, [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("any_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API) @@ -2405,7 +2405,7 @@ void __MATX_INLINE__ all_impl(OutType dest, const InType &in, cudaExecutor exec * Single threaded host executor */ template -void __MATX_INLINE__ all_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor &exec) +void __MATX_INLINE__ all_impl(OutType dest, const InType &in, [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("all_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API) @@ -2491,7 +2491,7 @@ void __MATX_INLINE__ allclose(OutType dest, const InType1 &in1, const InType2 &i * Single threaded host executor */ template -void __MATX_INLINE__ allclose(OutType dest, const InType1 &in1, const InType2 &in2, 
double rtol, double atol, [[maybe_unused]] HostExecutor &exec) +void __MATX_INLINE__ allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, [[maybe_unused]] const HostExecutor &exec) { MATX_NVTX_START("allclose(" + get_type_str(in1) + ", " + get_type_str(in2) + ")", matx::MATX_NVTX_LOG_API) static_assert(OutType::Rank() == 0, "allclose output must be rank 0"); diff --git a/include/matx/transforms/solver.h b/include/matx/transforms/solver.h deleted file mode 100644 index 06cfcd079..000000000 --- a/include/matx/transforms/solver.h +++ /dev/null @@ -1,1551 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// BSD 3-Clause License -// -// Copyright (c) 2021, NVIDIA Corporation -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -///////////////////////////////////////////////////////////////////////////////// - -#pragma once - -#include "cublas_v2.h" -#include "cusolverDn.h" -#include "matx/core/error.h" -#include "matx/core/nvtx.h" -#include "matx/core/tensor.h" -#include "matx/core/cache.h" -#include -#include - -namespace matx { -namespace detail { -/** - * Dense solver base class that all dense solver types inherit common methods - * and structures from. The dense solvers used in the 64-bit cuSolver API all - * use host and device workspace, as well as an "info" allocation to point to - * issues during solving. 
- */ -class matxDnSolver_t { -public: - matxDnSolver_t() - { - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - [[maybe_unused]] cusolverStatus_t ret; - ret = cusolverDnCreate(&handle); - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - - ret = cusolverDnCreateParams(&dn_params); - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - } - - matxError_t SetAdvancedOptions(cusolverDnFunction_t function, - cusolverAlgMode_t algo) - { - [[maybe_unused]] cusolverStatus_t ret = cusolverDnSetAdvOptions(dn_params, function, algo); - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - - return matxSuccess; - } - - virtual ~matxDnSolver_t() - { - matxFree(d_workspace, cudaStreamDefault); - matxFree(h_workspace, cudaStreamDefault); - matxFree(d_info, cudaStreamDefault); - cusolverDnDestroy(handle); - } - - template - void SetBatchPointers(TensorType &a) - { - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - batch_a_ptrs.clear(); - - if constexpr (TensorType::Rank() == 2) { - batch_a_ptrs.push_back(&a(0, 0)); - } - else { - using shape_type = typename TensorType::desc_type::shape_type; - int batch_offset = 2; - cuda::std::array idx{0}; - auto a_shape = a.Shape(); - // Get total number of batches - size_t total_iter = std::accumulate(a_shape.begin(), a_shape.begin() + TensorType::Rank() - batch_offset, 1, std::multiplies()); - for (size_t iter = 0; iter < total_iter; iter++) { - auto ap = cuda::std::apply([&a](auto... param) { return a.GetPointer(param...); }, idx); - batch_a_ptrs.push_back(ap); - - // Update all but the last 2 indices - UpdateIndices(a, idx, batch_offset); - } - } - - } - - /** - * Get a transposed view of a tensor into a user-supplied buffer - * - * @param tp - * Pointer to pre-allocated memory - * @param a - * Tensor to transpose - * @param stream - * CUDA stream - */ - template - static inline auto - TransposeCopy(typename TensorType::value_type *tp, const TensorType &a, cudaStream_t stream = 0) - { - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - auto pa = a.PermuteMatrix(); - auto tv = make_tensor(tp, pa.Shape()); - matx::copy(tv, pa, stream); - return tv; - } - - template - static inline uint32_t GetNumBatches(const TensorType &a) - { - uint32_t cnt = 1; - for (int i = 3; i <= TensorType::Rank(); i++) { - cnt *= static_cast(a.Size(TensorType::Rank() - i)); - } - - return cnt; - } - - void AllocateWorkspace(size_t batches) - { - if (dspace > 0) { - matxAlloc(&d_workspace, batches * dspace, MATX_DEVICE_MEMORY); - } - - matxAlloc((void **)&d_info, batches * sizeof(*d_info), MATX_DEVICE_MEMORY); - - if (hspace > 0) { - matxAlloc(&h_workspace, batches * hspace, MATX_HOST_MEMORY); - } - } - - virtual void GetWorkspaceSize(size_t *host, size_t *device) = 0; - -protected: - cusolverDnHandle_t handle; - cusolverDnParams_t dn_params; - std::vector batch_a_ptrs; - int *d_info; - void *d_workspace = nullptr; - void *h_workspace = nullptr; - size_t hspace; - size_t dspace; -}; - -/** - * Parameters needed to execute a cholesky factorization. 
We distinguish unique - * factorizations mostly by the data pointer in A - */ -struct DnCholParams_t { - int64_t n; - void *A; - size_t batch_size; - cublasFillMode_t uplo; - MatXDataType_t dtype; -}; - -template -class matxDnCholSolverPlan_t : public matxDnSolver_t { - using OutTensor_t = remove_cvref_t; - static_assert(OutTensor_t::Rank() == remove_cvref_t::Rank(), "Cholesky input/output tensor ranks must match"); - using T1 = typename OutTensor_t::value_type; - static constexpr int RANK = OutTensor_t::Rank(); - -public: - /** - * Plan for solving - * \f$\textbf{A} = \textbf{L} * \textbf{L^{H}}\f$ or \f$\textbf{A} = - * \textbf{U} * \textbf{U^{H}}\f$ using the Cholesky method - * - * Creates a handle for solving the factorization of A = M * M^H of a dense - * matrix using the Cholesky method, where M is either the upper or lower - * triangular portion of A. Input matrix A must be a square Hermitian matrix - * positive-definite where only the upper or lower triangle is used. - * - * @tparam T1 - * Data type of A matrix - * @tparam RANK - * Rank of A matrix - * - * @param a - * Input tensor view - * @param uplo - * Use upper or lower triangle for computation - * - */ - matxDnCholSolverPlan_t(const ATensor &a, - cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) - { - static_assert(RANK >= 2); - - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - params = GetCholParams(a, uplo); - GetWorkspaceSize(&hspace, &dspace); - AllocateWorkspace(params.batch_size); - } - - void GetWorkspaceSize(size_t *host, size_t *device) override - { - cusolverStatus_t ret = cusolverDnXpotrf_bufferSize(handle, dn_params, params.uplo, - params.n, MatXTypeToCudaType(), - params.A, params.n, - MatXTypeToCudaType(), device, - host); - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - } - - static DnCholParams_t GetCholParams(const ATensor &a, - cublasFillMode_t uplo) - { - DnCholParams_t params; - params.batch_size = matxDnSolver_t::GetNumBatches(a); - params.n = a.Size(RANK - 1); - params.A = a.Data(); - params.uplo = uplo; - params.dtype = TypeToInt(); - - return params; - } - - void Exec(OutputTensor &out, const ATensor &a, - cudaStream_t stream, cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) - { - // Ensure matrix is square - MATX_ASSERT(a.Size(RANK - 1) == a.Size(RANK - 2), matxInvalidSize); - - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - // Ensure output size matches input - for (int i = 0; i < RANK; i++) { - MATX_ASSERT(out.Size(i) == a.Size(i), matxInvalidSize); - } - - cusolverDnSetStream(handle, stream); - - SetBatchPointers(out); - if (out.Data() != a.Data()) { - matx::copy(out, a, stream); - } - - // At this time cuSolver does not have a batched 64-bit cholesky interface. - // Change this to use the batched version once available. - for (size_t i = 0; i < batch_a_ptrs.size(); i++) { - auto ret = cusolverDnXpotrf( - handle, dn_params, uplo, params.n, MatXTypeToCudaType(), - batch_a_ptrs[i], params.n, MatXTypeToCudaType(), - reinterpret_cast(d_workspace) + i * dspace, dspace, - reinterpret_cast(h_workspace) + i * hspace, hspace, - d_info + i); - - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - } - } - - /** - * Cholesky solver handle destructor - * - * Destroys any helper data used for provider type and any workspace memory - * created - * - */ - ~matxDnCholSolverPlan_t() {} - -private: - DnCholParams_t params; -}; - -/** - * Crude hash to get a reasonably good delta for collisions. 
This doesn't need - * to be perfect, but fast enough to not slow down lookups, and different enough - * so the common solver parameters change - */ -struct DnCholParamsKeyHash { - std::size_t operator()(const DnCholParams_t &k) const noexcept - { - return (std::hash()(k.n)) + (std::hash()(k.batch_size)); - } -}; - -/** - * Test cholesky parameters for equality. Unlike the hash, all parameters must - * match. - */ -struct DnCholParamsKeyEq { - bool operator()(const DnCholParams_t &l, const DnCholParams_t &t) const - noexcept - { - return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype; - } -}; - -using chol_cache_t = std::unordered_map; - - -/***************************************** LU FACTORIZATION - * *********************************************/ - -/** - * Parameters needed to execute an LU factorization. We distinguish unique - * factorizations mostly by the data pointer in A - */ -struct DnLUParams_t { - int64_t m; - int64_t n; - void *A; - void *piv; - size_t batch_size; - MatXDataType_t dtype; -}; - -template -class matxDnLUSolverPlan_t : public matxDnSolver_t { - using OutTensor_t = remove_cvref_t; - static constexpr int RANK = OutTensor_t::Rank(); - using T1 = typename OutTensor_t::value_type; - static_assert(RANK-1 == PivotTensor::Rank(), "Pivot tensor rank must be one less than output"); - static_assert(std::is_same_v, "Pivot tensor type must be int64_t"); - -public: - /** - * Plan for factoring A such that \f$\textbf{P} * \textbf{A} = \textbf{L} * - * \textbf{U}\f$ - * - * Creates a handle for factoring matrix A into the format above. Matrix must - * not be singular. - * - * @tparam T1 - * Data type of A matrix - * @tparam RANK - * Rank of A matrix - * - * @param piv - * Pivot indices - * @param a - * Input tensor view - * - */ - matxDnLUSolverPlan_t(PivotTensor &piv, - const ATensor &a) - { - static_assert(RANK >= 2); - - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - params = GetLUParams(piv, a); - GetWorkspaceSize(&hspace, &dspace); - AllocateWorkspace(params.batch_size); - } - - void GetWorkspaceSize(size_t *host, size_t *device) override - { - cusolverStatus_t ret = cusolverDnXgetrf_bufferSize(handle, dn_params, params.m, - params.n, MatXTypeToCudaType(), - params.A, params.m, - MatXTypeToCudaType(), device, - host); - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - } - - static DnLUParams_t GetLUParams(PivotTensor &piv, - const ATensor &a) noexcept - { - DnLUParams_t params; - params.batch_size = matxDnSolver_t::GetNumBatches(a); - params.m = a.Size(RANK - 2); - params.n = a.Size(RANK - 1); - params.A = a.Data(); - params.piv = piv.Data(); - params.dtype = TypeToInt(); - - return params; - } - - void Exec(OutputTensor &out, PivotTensor &piv, - const ATensor &a, const cudaStream_t stream = 0) - { - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - cusolverDnSetStream(handle, stream); - int info; - - batch_piv_ptrs.clear(); - - if constexpr (RANK == 2) { - batch_piv_ptrs.push_back(&piv(0)); - } - else if constexpr (RANK == 3) { - for (int i = 0; i < piv.Size(0); i++) { - batch_piv_ptrs.push_back(&piv(i, 0)); - } - } - else { - for (int i = 0; i < piv.Size(0); i++) { - for (int j = 0; j < piv.Size(1); j++) { - batch_piv_ptrs.push_back(&piv(i, j, 0)); - } - } - } - - SetBatchPointers(out); - - if (out.Data() != a.Data()) { - matx::copy(out, a, stream); - } - - // At this time cuSolver does not have a batched 64-bit LU interface. Change - // this to use the batched version once available. 
- for (size_t i = 0; i < batch_a_ptrs.size(); i++) { - auto ret = cusolverDnXgetrf( - handle, dn_params, params.m, params.n, MatXTypeToCudaType(), - batch_a_ptrs[i], params.m, batch_piv_ptrs[i], - MatXTypeToCudaType(), - reinterpret_cast(d_workspace) + i * dspace, dspace, - reinterpret_cast(h_workspace) + i * hspace, hspace, - d_info + i); - - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - - // This will block. Figure this out later - cudaMemcpy(&info, d_info + i, sizeof(info), cudaMemcpyDeviceToHost); - MATX_ASSERT(info == 0, matxSolverError); - } - } - - /** - * LU solver handle destructor - * - * Destroys any helper data used for provider type and any workspace memory - * created - * - */ - ~matxDnLUSolverPlan_t() {} - -private: - std::vector batch_piv_ptrs; - DnLUParams_t params; -}; - -/** - * Crude hash to get a reasonably good delta for collisions. This doesn't need - * to be perfect, but fast enough to not slow down lookups, and different enough - * so the common solver parameters change - */ -struct DnLUParamsKeyHash { - std::size_t operator()(const DnLUParams_t &k) const noexcept - { - return (std::hash()(k.m)) + (std::hash()(k.n)) + - (std::hash()(k.batch_size)); - } -}; - -/** - * Test LU parameters for equality. Unlike the hash, all parameters must match. - */ -struct DnLUParamsKeyEq { - bool operator()(const DnLUParams_t &l, const DnLUParams_t &t) const noexcept - { - return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size && - l.dtype == t.dtype; - } -}; - -// Static caches of LU handles -using lu_cache_t = std::unordered_map; - - -/***************************************** QR FACTORIZATION - * *********************************************/ - -/** - * Parameters needed to execute a QR factorization. We distinguish unique - * factorizations mostly by the data pointer in A - */ -struct DnQRParams_t { - int64_t m; - int64_t n; - void *A; - void *tau; - size_t batch_size; - MatXDataType_t dtype; -}; - -template -class matxDnQRSolverPlan_t : public matxDnSolver_t { - using out_type_t = remove_cvref_t; - using T1 = typename out_type_t::value_type; - static constexpr int RANK = out_type_t::Rank(); - static_assert(out_type_t::Rank()-1 == TauTensor::Rank(), "Tau tensor must be one rank less than output tensor"); - static_assert(out_type_t::Rank() == ATensor::Rank(), "Output tensor must match A tensor rank in SVD"); - -public: - /** - * Plan for factoring A such that \f$\textbf{A} = \textbf{Q} * \textbf{R}\f$ - * - * Creates a handle for factoring matrix A into the format above. QR - * decomposition in cuBLAS/cuSolver does not return the Q matrix directly, and - * it must be computed separately used the Householder reflections in the tau - * output, along with the overwritten A matrix input. The input and output - * parameters may be the same tensor. In that case, the input is destroyed and - * the output is stored in-place. 
- * - * @tparam T1 - * Data type of A matrix - * @tparam RANK - * Rank of A matrix - * - * @param tau - * Scaling factors for reflections - * @param a - * Input tensor view - * - */ - matxDnQRSolverPlan_t(TauTensor &tau, - const ATensor &a) - { - static_assert(RANK >= 2); - - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - params = GetQRParams(tau, a); - GetWorkspaceSize(&hspace, &dspace); - AllocateWorkspace(params.batch_size); - } - - void GetWorkspaceSize(size_t *host, size_t *device) override - { - cusolverStatus_t ret = cusolverDnXgeqrf_bufferSize( - handle, dn_params, params.m, params.n, MatXTypeToCudaType(), - params.A, params.m, MatXTypeToCudaType(), params.tau, - MatXTypeToCudaType(), device, host); - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - } - - static DnQRParams_t GetQRParams(TauTensor &tau, - const ATensor &a) - { - DnQRParams_t params; - - params.batch_size = matxDnSolver_t::GetNumBatches(a); - params.m = a.Size(RANK - 2); - params.n = a.Size(RANK - 1); - params.A = a.Data(); - params.tau = tau.Data(); - params.dtype = TypeToInt(); - - return params; - } - - void Exec(OutTensor &out, TauTensor &tau, - const ATensor &a, cudaStream_t stream = 0) - { - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - batch_tau_ptrs.clear(); - - // Ensure output size matches input - for (int i = 0; i < RANK; i++) { - MATX_ASSERT(out.Size(i) == a.Size(i), matxInvalidSize); - } - - SetBatchPointers(out); - - if constexpr (RANK == 2) { - batch_tau_ptrs.push_back(&tau(0)); - } - else if constexpr (RANK == 3) { - for (int i = 0; i < tau.Size(0); i++) { - batch_tau_ptrs.push_back(&tau(i, 0)); - } - } - else { - for (int i = 0; i < tau.Size(0); i++) { - for (int j = 0; j < tau.Size(1); j++) { - batch_tau_ptrs.push_back(&tau(i, j, 0)); - } - } - } - - if (out.Data() != a.Data()) { - matx::copy(out, a, stream); - } - - cusolverDnSetStream(handle, stream); - int info; - - // At this time cuSolver does not have a batched 64-bit LU interface. Change - // this to use the batched version once available. - for (size_t i = 0; i < batch_a_ptrs.size(); i++) { - auto ret = cusolverDnXgeqrf( - handle, dn_params, params.m, params.n, MatXTypeToCudaType(), - batch_a_ptrs[i], params.m, MatXTypeToCudaType(), - batch_tau_ptrs[i], MatXTypeToCudaType(), - reinterpret_cast(d_workspace) + i * dspace, dspace, - reinterpret_cast(h_workspace) + i * hspace, hspace, - d_info + i); - - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - - // This will block. Figure this out later - cudaMemcpy(&info, d_info + i, sizeof(info), cudaMemcpyDeviceToHost); - MATX_ASSERT(info == 0, matxSolverError); - } - } - - /** - * QR solver handle destructor - * - * Destroys any helper data used for provider type and any workspace memory - * created - * - */ - ~matxDnQRSolverPlan_t() {} - -private: - std::vector batch_tau_ptrs; - DnQRParams_t params; -}; - -/** - * Crude hash to get a reasonably good delta for collisions. This doesn't need - * to be perfect, but fast enough to not slow down lookups, and different enough - * so the common solver parameters change - */ -struct DnQRParamsKeyHash { - std::size_t operator()(const DnQRParams_t &k) const noexcept - { - return (std::hash()(k.m)) + (std::hash()(k.n)) + - (std::hash()(k.batch_size)); - } -}; - -/** - * Test QR parameters for equality. Unlike the hash, all parameters must match. 
- */ -struct DnQRParamsKeyEq { - bool operator()(const DnQRParams_t &l, const DnQRParams_t &t) const noexcept - { - return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size && - l.dtype == t.dtype; - } -}; - -using qr_cache_t = std::unordered_map; - - -/********************************************** SVD - * *********************************************/ - -/** - * Parameters needed to execute singular value decomposition. We distinguish - * unique factorizations mostly by the data pointer in A. - */ -struct DnSVDParams_t { - int64_t m; - int64_t n; - char jobu; - char jobvt; - void *A; - void *U; - void *V; - void *S; - size_t batch_size; - MatXDataType_t dtype; -}; - -template -class matxDnSVDSolverPlan_t : public matxDnSolver_t { - using T1 = typename ATensor::value_type; - using T2 = typename UTensor::value_type; - using T3 = typename STensor::value_type; - using T4 = typename VTensor::value_type; - static constexpr int RANK = UTensor::Rank(); - static_assert(UTensor::Rank()-1 == STensor::Rank(), "S tensor must be 1 rank lower than U tensor in SVD"); - static_assert(UTensor::Rank() == ATensor::Rank(), "U tensor must match A tensor rank in SVD"); - static_assert(UTensor::Rank() == VTensor::Rank(), "U tensor must match V tensor rank in SVD"); - static_assert(!is_complex_v, "S type must be real"); - -public: - /** - * Plan for factoring A such that \f$\textbf{A} = \textbf{U} * \textbf{\Sigma} - * * \textbf{V^{H}}\f$ - * - * Creates a handle for decomposing matrix A into the format above. - * - * @tparam T1 - * Data type of A matrix - * @tparam T2 - * Data type of U matrix - * @tparam T3 - * Data type of S vector - * @tparam T4 - * Data type of V matrix - * @tparam RANK - * Rank of A, U, and V matrices, and RANK-1 of S - * - * @param u - * Output tensor view for U matrix - * @param s - * Output tensor view for S matrix - * @param v - * Output tensor view for V matrix - * @param a - * Input tensor view for A matrix - * @param jobu - * Specifies options for computing all or part of the matrix U: = 'A'. See - * cuSolver documentation for more info - * @param jobvt - * specifies options for computing all or part of the matrix V**T. 
See - * cuSolver documentation for more info - * - */ - matxDnSVDSolverPlan_t(UTensor &u, - STensor &s, - VTensor &v, - const ATensor &a, const char jobu = 'A', - const char jobvt = 'A') - { - static_assert(RANK >= 2); - - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - make_tensor(scratch, a.Shape(), MATX_DEVICE_MEMORY); - params = GetSVDParams(u, s, v, scratch, jobu, jobvt); - - GetWorkspaceSize(&hspace, &dspace); - - SetBatchPointers(scratch); - AllocateWorkspace(params.batch_size); - } - - void GetWorkspaceSize(size_t *host, size_t *device) override - { - cusolverStatus_t ret = - cusolverDnXgesvd_bufferSize( - handle, dn_params, params.jobu, params.jobvt, params.m, params.n, - MatXTypeToCudaType(), params.A, params.m, - MatXTypeToCudaType(), params.S, MatXTypeToCudaType(), - params.U, params.m, MatXTypeToCudaType(), params.V, params.n, - MatXTypeToCudaType(), device, host); - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - } - - static DnSVDParams_t - GetSVDParams(UTensor &u, STensor &s, - VTensor &v, const ATensor &a, - const char jobu = 'A', const char jobvt = 'A') - { - DnSVDParams_t params; - params.batch_size = matxDnSolver_t::GetNumBatches(a); - params.m = a.Size(RANK - 2); - params.n = a.Size(RANK - 1); - params.A = a.Data(); - params.U = u.Data(); - params.V = v.Data(); - params.S = s.Data(); - params.jobu = jobu; - params.jobvt = jobvt; - params.dtype = TypeToInt(); - - return params; - } - - void Exec(UTensor &u, STensor &s, - VTensor &v, const ATensor &a, - const char jobu = 'A', const char jobvt = 'A', - cudaStream_t stream = 0) - { - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - batch_s_ptrs.clear(); - batch_v_ptrs.clear(); - batch_u_ptrs.clear(); - - if constexpr (RANK == 2) { - batch_s_ptrs.push_back(&s(0)); - batch_u_ptrs.push_back(&u(0, 0)); - batch_v_ptrs.push_back(&v(0, 0)); - } - else if constexpr (RANK == 3) { - for (int i = 0; i < a.Size(0); i++) { - batch_s_ptrs.push_back(&s(i, 0)); - batch_u_ptrs.push_back(&u(i, 0, 0)); - batch_v_ptrs.push_back(&v(i, 0, 0)); - } - } - else { - for (int i = 0; i < a.Size(0); i++) { - for (int j = 0; j < a.Size(1); j++) { - batch_s_ptrs.push_back(&s(i, j, 0)); - batch_u_ptrs.push_back(&u(i, j, 0, 0)); - batch_v_ptrs.push_back(&v(i, j, 0, 0)); - } - } - } - - cusolverDnSetStream(handle, stream); - matx::copy(scratch, a, stream); - int info; - - // At this time cuSolver does not have a batched 64-bit SVD interface. Change - // this to use the batched version once available. - for (size_t i = 0; i < batch_a_ptrs.size(); i++) { - - auto ret = cusolverDnXgesvd( - handle, dn_params, jobu, jobvt, params.m, params.n, - MatXTypeToCudaType(), batch_a_ptrs[i], params.m, - MatXTypeToCudaType(), batch_s_ptrs[i], MatXTypeToCudaType(), - batch_u_ptrs[i], params.m, MatXTypeToCudaType(), batch_v_ptrs[i], - params.n, MatXTypeToCudaType(), - reinterpret_cast(d_workspace) + i * dspace, dspace, - reinterpret_cast(h_workspace) + i * hspace, hspace, - d_info + i); - - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - - // This will block. 
Figure this out later - cudaMemcpy(&info, d_info + i, sizeof(info), cudaMemcpyDeviceToHost); - MATX_ASSERT(info == 0, matxSolverError); - } - } - - /** - * SVD solver handle destructor - * - * Destroys any helper data used for provider type and any workspace memory - * created - * - */ - ~matxDnSVDSolverPlan_t() {} - -private: - matx::tensor_t scratch; - std::vector batch_s_ptrs; - std::vector batch_v_ptrs; - std::vector batch_u_ptrs; - DnSVDParams_t params; -}; - -/** - * Crude hash to get a reasonably good delta for collisions. This doesn't need - * to be perfect, but fast enough to not slow down lookups, and different enough - * so the common solver parameters change - */ -struct DnSVDParamsKeyHash { - std::size_t operator()(const DnSVDParams_t &k) const noexcept - { - return (std::hash()(k.m)) + (std::hash()(k.n)) + - (std::hash()(k.batch_size)); - } -}; - -/** - * Test SVD parameters for equality. Unlike the hash, all parameters must match. - */ -struct DnSVDParamsKeyEq { - bool operator()(const DnSVDParams_t &l, const DnSVDParams_t &t) const noexcept - { - return l.n == t.n && l.m == t.m && l.jobu == t.jobu && l.jobvt == t.jobvt && - l.batch_size == t.batch_size && l.dtype == t.dtype; - } -}; - -using svd_cache_t = std::unordered_map; - -/*************************************** Eigenvalues and eigenvectors - * *************************************/ - -/** - * Parameters needed to execute singular value decomposition. We distinguish - * unique factorizations mostly by the data pointer in A. - */ -struct DnEigParams_t { - int64_t m; - cusolverEigMode_t jobz; - cublasFillMode_t uplo; - void *A; - void *out; - void *W; - size_t batch_size; - MatXDataType_t dtype; -}; - -template -class matxDnEigSolverPlan_t : public matxDnSolver_t { -public: - using T2 = typename WTensor::value_type; - using T1 = typename ATensor::value_type; - static constexpr int RANK = remove_cvref_t::Rank(); - static_assert(RANK == ATensor::Rank(), "Output and A tensor ranks must match for eigen solver"); - static_assert(RANK-1 == WTensor::Rank(), "W tensor must be one rank lower than output for eigen solver"); - - /** - * Plan computing eigenvalues/vectors on A such that: - * - * \f$\textbf{A} * textbf{V} = \textbf{V} * \textbf{\Lambda}\f$ - * - * - * @tparam T1 - * Data type of A matrix - * @tparam T2 - * Data type of W matrix - * @tparam RANK - * Rank of A matrix - * - * @param w - * Eigenvalues of A - * @param a - * Input tensor view - * @param jobz - * CUSOLVER_EIG_MODE_VECTOR to compute eigenvectors or - * CUSOLVER_EIG_MODE_NOVECTOR to not compute - * @param uplo - * Where to store data in A - * - */ - matxDnEigSolverPlan_t(WTensor &w, - const ATensor &a, - cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR, - cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) - { - static_assert(RANK >= 2); - - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - params = GetEigParams(w, a, jobz, uplo); - GetWorkspaceSize(&hspace, &dspace); - AllocateWorkspace(params.batch_size); - } - - void GetWorkspaceSize(size_t *host, size_t *device) override - { - cusolverStatus_t ret = cusolverDnXsyevd_bufferSize( - handle, dn_params, params.jobz, params.uplo, params.m, - MatXTypeToCudaType(), params.A, params.m, - MatXTypeToCudaType(), params.W, - MatXTypeToCudaType(), device, - host); - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - } - - static DnEigParams_t GetEigParams(WTensor &w, - const ATensor &a, - cusolverEigMode_t jobz, - cublasFillMode_t uplo) - { - DnEigParams_t params; - params.batch_size = 
matxDnSolver_t::GetNumBatches(a); - params.m = a.Size(RANK - 1); - params.A = a.Data(); - params.W = w.Data(); - params.jobz = jobz; - params.uplo = uplo; - params.dtype = TypeToInt(); - - return params; - } - - void Exec(OutputTensor &out, WTensor &w, - const ATensor &a, - cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR, - cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER, - cudaStream_t stream = 0) - { - // Ensure matrix is square - MATX_ASSERT(a.Size(RANK - 1) == a.Size(RANK - 2), matxInvalidSize); - - MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - batch_w_ptrs.clear(); - - // Ensure output size matches input - for (int i = 0; i < RANK; i++) { - MATX_ASSERT(out.Size(i) == a.Size(i), matxInvalidSize); - } - - if constexpr (RANK == 2) { - batch_w_ptrs.push_back(&w(0)); - } - else if constexpr (RANK == 3) { - for (int i = 0; i < a.Size(0); i++) { - batch_w_ptrs.push_back(&w(i, 0)); - } - } - else { - for (int i = 0; i < a.Size(0); i++) { - for (int j = 0; j < a.Size(1); j++) { - batch_w_ptrs.push_back(&w(i, j, 0)); - } - } - } - - SetBatchPointers(out); - - if (out.Data() != a.Data()) { - matx::copy(out, a, stream); - } - - cusolverDnSetStream(handle, stream); - int info; - - // At this time cuSolver does not have a batched 64-bit LU interface. Change - // this to use the batched version once available. - for (size_t i = 0; i < batch_a_ptrs.size(); i++) { - auto ret = cusolverDnXsyevd( - handle, dn_params, jobz, uplo, params.m, MatXTypeToCudaType(), - batch_a_ptrs[i], params.m, MatXTypeToCudaType(), batch_w_ptrs[i], - MatXTypeToCudaType(), - reinterpret_cast(d_workspace) + i * dspace, dspace, - reinterpret_cast(h_workspace) + i * hspace, hspace, - d_info + i); - - MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); - - // This will block. Figure this out later - cudaMemcpy(&info, d_info + i, sizeof(info), cudaMemcpyDeviceToHost); - MATX_ASSERT(info == 0, matxSolverError); - } - } - - /** - * Eigen solver handle destructor - * - * Destroys any helper data used for provider type and any workspace memory - * created - * - */ - ~matxDnEigSolverPlan_t() {} - -private: - std::vector batch_w_ptrs; - DnEigParams_t params; -}; - -/** - * Crude hash to get a reasonably good delta for collisions. This doesn't need - * to be perfect, but fast enough to not slow down lookups, and different enough - * so the common solver parameters change - */ -struct DnEigParamsKeyHash { - std::size_t operator()(const DnEigParams_t &k) const noexcept - { - return (std::hash()(k.m)) + (std::hash()(k.batch_size)); - } -}; - -/** - * Test Eigen parameters for equality. Unlike the hash, all parameters must - * match. - */ -struct DnEigParamsKeyEq { - bool operator()(const DnEigParams_t &l, const DnEigParams_t &t) const noexcept - { - return l.m == t.m && l.batch_size == t.batch_size && l.dtype == t.dtype; - } -}; - -using eig_cache_t = std::unordered_map; - -} - - -/** - * Perform a Cholesky decomposition using a cached plan - * - * See documentation of matxDnCholSolverPlan_t for a description of how the - * algorithm works. This function provides a simple interface to the cuSolver - * library by deducing all parameters needed to perform a Cholesky decomposition - * from only the matrix A. The input and output parameters may be the same - * tensor. In that case, the input is destroyed and the output is stored - * in-place. 
- * - * @tparam T1 - * Data type of matrix A - * @tparam RANK - * Rank of matrix A - * - * @param out - * Output tensor - * @param a - * Input tensor - * @param stream - * CUDA stream - * @param uplo - * Part of matrix to fill - */ -template -void chol_impl(OutputTensor &&out, const ATensor &a, - cudaStream_t stream = 0, - cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) -{ - MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) - - using OutputTensor_t = remove_cvref_t; - using T1 = typename OutputTensor_t::value_type; - - auto a_new = OpToTensor(a, stream); - - if(!a_new.isSameView(a)) { - (a_new = a).run(stream); - } - - // cuSolver assumes column-major matrices and MatX uses row-major matrices. - // One way to address this is to create a transposed copy of the input to - // use with the factorization, followed by transposing the output. However, - // for matrices with no additional padding, we can also change the value of - // uplo to effectively change the matrix to column-major. This allows us to - // compute the factorization without additional transposes. If we do not - // have contiguous input and output tensors, then we create a temporary - // contiguous tensor for use with cuSolver. - uplo = (uplo == CUBLAS_FILL_MODE_UPPER) ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; - - const bool allContiguous = a_new.IsContiguous() && out.IsContiguous(); - auto tv = [allContiguous, &a_new, &out, &stream]() -> auto { - if (allContiguous) { - (out = a_new).run(stream); - return out; - } else{ - auto t = make_tensor(a_new.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream); - matx::copy(t, a_new, stream); - return t; - } - }(); - - // Get parameters required by these tensors - auto params = detail::matxDnCholSolverPlan_t::GetCholParams(tv, uplo); - params.uplo = uplo; - - using cache_val_type = detail::matxDnCholSolverPlan_t; - detail::GetCache().LookupAndExec( - detail::GetCacheIdFromType(), - params, - [&]() { - return std::make_shared(tv, uplo); - }, - [&](std::shared_ptr ctype) { - ctype->Exec(tv, tv, stream, uplo); - } - ); - - if (! allContiguous) { - matx::copy(out, tv, stream); - } -} - - - -/** - * Perform an LU decomposition - * - * See documentation of matxDnLUSolverPlan_t for a description of how the - * algorithm works. This function provides a simple interface to the cuSolver - * library by deducing all parameters needed to perform an LU decomposition from - * only the matrix A. The input and output parameters may be the same tensor. In - * that case, the input is destroyed and the output is stored in-place. - * - * @tparam T1 - * Data type of matrix A - * @tparam RANK - * Rank of matrix A - * - * @param out - * Output tensor view - * @param piv - * Output of pivot indices - * @param a - * Input matrix A - * @param stream - * CUDA stream - */ -template -void lu_impl(OutputTensor &&out, PivotTensor &&piv, - const ATensor &a, const cudaStream_t stream = 0) -{ - MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) - - using T1 = typename remove_cvref_t::value_type; - - auto piv_new = OpToTensor(piv, stream); - auto a_new = OpToTensor(a, stream); - - if(!piv_new.isSameView(piv)) { - (piv_new = piv).run(stream); - } - if(!a_new.isSameView(a)) { - (a_new = a).run(stream); - } - - /* Temporary WAR - cuSolver doesn't support row-major layouts. Since we want to make the - library appear as though everything is row-major, we take a performance hit - to transpose in and out of the function. Eventually this may be fixed in - cuSolver. 
- */ - T1 *tp; - matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY, - stream); - auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a_new, stream); - auto tvt = tv.PermuteMatrix(); - - // Get parameters required by these tensors - auto params = detail::matxDnLUSolverPlan_t::GetLUParams(piv_new, tvt); - - // Get cache or new LU plan if it doesn't exist - using cache_val_type = detail::matxDnLUSolverPlan_t; - detail::GetCache().LookupAndExec( - detail::GetCacheIdFromType(), - params, - [&]() { - return std::make_shared(piv_new, tvt); - }, - [&](std::shared_ptr ctype) { - ctype->Exec(tvt, piv_new, tvt, stream); - } - ); - - /* Temporary WAR - * Copy and free async buffer for transpose */ - matx::copy(out, tv.PermuteMatrix(), stream); - matxFree(tp); -} - - - -/** - * Compute the determinant of a matrix - * - * Computes the terminant of a matrix by first computing the LU composition, - * then reduces the product of the diagonal elements of U. The input and output - * parameters may be the same tensor. In that case, the input is destroyed and - * the output is stored in-place. - * - * @tparam T1 - * Data type of matrix A - * @tparam RANK - * Rank of matrix A - * - * @param out - * Output tensor view - * @param a - * Input matrix A - * @param stream - * CUDA stream - */ -template -void det_impl(OutputTensor &out, const InputTensor &a, - const cudaStream_t stream = 0) -{ - MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) - - static_assert(OutputTensor::Rank() == InputTensor::Rank() - 2, "Output tensor rank must be 2 less than input for det()"); - constexpr int RANK = InputTensor::Rank(); - - auto a_new = OpToTensor(a, stream); - - if(!a_new.isSameView(a)) { - (a_new = a).run(stream); - } - - // Get parameters required by these tensors - cuda::std::array s; - - // Set batching dimensions of piv - for (int i = 0; i < RANK - 2; i++) { - s[i] = a_new.Size(i); - } - - s[RANK - 2] = cuda::std::min(a_new.Size(RANK - 1), a_new.Size(RANK - 2)); - - auto piv = make_tensor(s, MATX_ASYNC_DEVICE_MEMORY, stream); - auto ac = make_tensor(a_new.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream); - - lu_impl(ac, piv, a_new, stream); - (out = prod(diag(ac))).run(stream); -} - - -/** - * Perform a QR decomposition using a cached plan - * - * See documentation of matxDnQRSolverPlan_t for a description of how the - * algorithm works. This function provides a simple interface to the cuSolver - * library by deducing all parameters needed to perform a QR decomposition from - * only the matrix A. The input and output parameters may be the same tensor. In - * that case, the input is destroyed and the output is stored in-place. - * - * @tparam T1 - * Data type of matrix A - * @tparam RANK - * Rank of matrix A - * - * @param out - * Output tensor view - * @param tau - * Output of reflection scalar values - * @param a - * Input tensor A - * @param stream - * CUDA stream - */ -template -void cusolver_qr_impl(OutTensor &&out, TauTensor &&tau, - const ATensor &a, cudaStream_t stream = 0) -{ - MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) - - using T1 = typename remove_cvref_t::value_type; - - auto tau_new = OpToTensor(tau, stream); - auto a_new = OpToTensor(a, stream); - - if(!tau_new.isSameView(tau)) { - (tau_new = tau).run(stream); - } - if(!a_new.isSameView(a)) { - (a_new = a).run(stream); - } - - /* Temporary WAR - cuSolver doesn't support row-major layouts. Since we want to make the - library appear as though everything is row-major, we take a performance hit - to transpose in and out of the function. 
Eventually this may be fixed in - cuSolver. - */ - T1 *tp; - matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY, - stream); - auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a_new, stream); - auto tvt = tv.PermuteMatrix(); - - // Get parameters required by these tensors - auto params = detail::matxDnQRSolverPlan_t::GetQRParams(tau_new, tvt); - - // Get cache or new QR plan if it doesn't exist - using cache_val_type = detail::matxDnQRSolverPlan_t; - detail::GetCache().LookupAndExec( - detail::GetCacheIdFromType(), - params, - [&]() { - return std::make_shared(tau_new, tvt); - }, - [&](std::shared_ptr ctype) { - ctype->Exec(tvt, tau_new, tvt, stream); - } - ); - - /* Temporary WAR - * Copy and free async buffer for transpose */ - matx::copy(out, tv.PermuteMatrix(), stream); - matxFree(tp); -} - - - -/** - * Perform a SVD decomposition using a cached plan - * - * See documentation of matxDnSVDSolverPlan_t for a description of how the - * algorithm works. This function provides a simple interface to the cuSolver - * library by deducing all parameters needed to perform a SVD decomposition from - * only the matrix A. - * - * @tparam T1 - * Data type of matrix A - * @tparam RANK - * Rank of matrix A - * - * @param u - * U matrix output - * @param s - * Sigma matrix output - * @param v - * V matrix output - * @param a - * Input matrix A - * @param stream - * CUDA stream - * @param jobu - * Specifies options for computing all or part of the matrix U: = 'A'. See - * cuSolver documentation for more info - * @param jobvt - * specifies options for computing all or part of the matrix V**T. See - * cuSolver documentation for more info - * - */ -template -void svd_impl(UTensor &&u, STensor &&s, - VTensor &&v, const ATensor &a, - cudaStream_t stream = 0, const char jobu = 'A', const char jobvt = 'A') -{ - MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) - - using T1 = typename ATensor::value_type; - - auto u_new = OpToTensor(u, stream); - auto s_new = OpToTensor(s, stream); - auto v_new = OpToTensor(v, stream); - auto a_new = OpToTensor(a, stream); - - if(!u_new.isSameView(u)) { - (u_new = u).run(stream); - } - if(!s_new.isSameView(s)) { - (s_new = s).run(stream); - } - if(!v_new.isSameView(v)) { - (v_new = v).run(stream); - } - if(!a_new.isSameView(a)) { - (a_new = a).run(stream); - } - - /* Temporary WAR - cuSolver doesn't support row-major layouts. Since we want to make the - library appear as though everything is row-major, we take a performance hit - to transpose in and out of the function. Eventually this may be fixed in - cuSolver. - */ - T1 *tp; - matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream); - auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a_new, stream); - auto tvt = tv.PermuteMatrix(); - - // Get parameters required by these tensors - auto params = detail::matxDnSVDSolverPlan_t::GetSVDParams( - u_new, s_new, v_new, tvt, jobu, jobvt); - - // Get cache or new QR plan if it doesn't exist - using cache_val_type = detail::matxDnSVDSolverPlan_t; - detail::GetCache().LookupAndExec( - detail::GetCacheIdFromType(), - params, - [&]() { - return std::make_shared(u_new, s_new, v_new, tvt, jobu, jobvt); - }, - [&](std::shared_ptr ctype) { - ctype->Exec(u_new, s_new, v_new, tvt, jobu, jobvt, stream); - } - ); - - matxFree(tp); -} - - -/** - * Perform a Eig decomposition using a cached plan - * - * See documentation of matxDnEigSolverPlan_t for a description of how the - * algorithm works. 
This function provides a simple interface to the cuSolver - * library by deducing all parameters needed to perform a eigen decomposition - * from only the matrix A. The input and output parameters may be the same - * tensor. In that case, the input is destroyed and the output is stored - * in-place. - * - * @tparam T1 - * Data type of matrix A - * @tparam RANK - * Rank of matrix A - * - * @param out - * Output tensor view - * @param w - * Eigenvalues output - * @param a - * Input matrix A - * @param stream - * CUDA stream - * @param jobz - * CUSOLVER_EIG_MODE_VECTOR to compute eigenvectors or - * CUSOLVER_EIG_MODE_NOVECTOR to not compute - * @param uplo - * Where to store data in A - */ -template -void eig_impl(OutputTensor &&out, WTensor &&w, - const ATensor &a, cudaStream_t stream = 0, - cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR, - cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) -{ - MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) - - /* Temporary WAR - cuSolver doesn't support row-major layouts. Since we want to make the - library appear as though everything is row-major, we take a performance hit - to transpose in and out of the function. Eventually this may be fixed in - cuSolver. - */ - using T1 = typename remove_cvref_t::value_type; - - auto w_new = OpToTensor(w, stream); - auto a_new = OpToTensor(a, stream); - - if(!w_new.isSameView(w)) { - (w_new = w).run(stream); - } - if(!a_new.isSameView(a)) { - (a_new = a).run(stream); - } - - T1 *tp; - matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY, - stream); - auto tv = detail::matxDnSolver_t::TransposeCopy(tp, a_new, stream); - - // Get parameters required by these tensors - auto params = - detail::matxDnEigSolverPlan_t::GetEigParams(w_new, tv, jobz, uplo); - - // Get cache or new eigen plan if it doesn't exist - using cache_val_type = detail::matxDnEigSolverPlan_t; - detail::GetCache().LookupAndExec( - detail::GetCacheIdFromType(), - params, - [&]() { - return std::make_shared(w_new, tv, jobz, uplo); - }, - [&](std::shared_ptr ctype) { - ctype->Exec(tv, w_new, tv, jobz, uplo, stream); - } - ); - - /* Copy and free async buffer for transpose */ - matx::copy(out, tv.PermuteMatrix(), stream); - matxFree(tp); -} - - -} // end namespace matx diff --git a/include/matx/transforms/solver_common.h b/include/matx/transforms/solver_common.h new file mode 100644 index 000000000..37e17e42f --- /dev/null +++ b/include/matx/transforms/solver_common.h @@ -0,0 +1,275 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#ifdef MATX_EN_NVPL + #include + using lapack_int_t = nvpl_int_t; + using lapack_scomplex_t = nvpl_scomplex_t; + using lapack_dcomplex_t = nvpl_dcomplex_t; +#else + using lapack_int_t = int64_t; +#endif + +namespace matx { + +/* Parameter enums */ + +// Which part (lower or upper) of the dense matrix was filled +// and should be used by the function +enum class SolverFillMode { + UPPER, + LOWER +}; + +enum class EigenMode { + NO_VECTOR, // Only eigenvalues are computed + VECTOR // Both eigenvalues and eigenvectors are computed +}; + +// SVD job options for computing columns of U (jobu) and rows of VT (jobvt) +enum class SVDJob { + ALL, // For jobu: All M columns of U are computed + // For jobvt: All N rows of V^T are computed + REDUCED, // For jobu: The first min(m,n) columns of U are computed + // For jobvt: The first min(m,n) rows of V^T are computed + NONE // For jobu: No columns of U are computed + // For jobvt: No rows of V^T are computed +}; + +namespace detail { + +__MATX_INLINE__ char SVDJobToChar(SVDJob job) { + switch (job) { + case SVDJob::ALL: + return 'A'; + case SVDJob::REDUCED: + return 'S'; + case SVDJob::NONE: + return 'N'; + default: + MATX_ASSERT_STR(false, matxInvalidParameter, "Job for SVD not supported"); + return '\0'; + } +} + +/* Solver utility functions */ + +enum class BatchType { + VECTOR = 1, + MATRIX = 2 +}; + +/** + * @brief Sets batch pointers for a batched tensor of arbitrary rank. + * + * Clears the given batch pointers vector and then populates it + * with pointers to the data of the tensor for batched operations. + * Handles both batched matrices and vectors. + * + * @tparam BTYPE + * Whether the input is a batch of matrices or vectors + * @tparam TensorType + * Type of input tensor a + * @tparam PointerType + * Tensor value type + * + * @param a + * The tensor for which batch pointers are to be set. 
+ * @param batch_ptrs + * The vector to be filled with pointers + */ +template +__MATX_INLINE__ void SetBatchPointers(const TensorType &a, std::vector &batch_ptrs) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + batch_ptrs.clear(); + + if constexpr (BTYPE == BatchType::VECTOR && TensorType::Rank() == 1) { + // single vector case + batch_ptrs.push_back(&a(0)); + } + else if constexpr (BTYPE == BatchType::MATRIX && TensorType::Rank() == 2) { + // single matrix case + batch_ptrs.push_back(&a(0, 0)); + } + else { + // batched vectors or matrices + using shape_type = typename TensorType::desc_type::shape_type; + int batch_offset = static_cast(BTYPE); + cuda::std::array idx{0}; + auto a_shape = a.Shape(); + size_t total_iter = std::accumulate(a_shape.begin(), a_shape.begin() + TensorType::Rank() - batch_offset, 1, std::multiplies()); + for (size_t iter = 0; iter < total_iter; iter++) { + auto ap = cuda::std::apply([&a](auto... param) { return a.GetPointer(param...); }, idx); + batch_ptrs.push_back(ap); + UpdateIndices(a, idx, batch_offset); + } + } +} + +template +__MATX_INLINE__ uint32_t GetNumBatches(const TensorType &a) +{ + uint32_t cnt = 1; + for (int i = 3; i <= TensorType::Rank(); i++) { + cnt *= static_cast(a.Size(TensorType::Rank() - i)); + } + + return cnt; +} + + +/** + * Dense cuSolver base class that all dense cuda solver types inherit common methods + * and structures from. The dense solvers used in the 64-bit cuSolver API all + * use host and device workspace, as well as an "info" allocation to point to + * issues during solving. + */ +class matxDnCUDASolver_t { +public: + matxDnCUDASolver_t() + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + [[maybe_unused]] cusolverStatus_t ret; + ret = cusolverDnCreate(&handle); + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + + ret = cusolverDnCreateParams(&dn_params); + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + } + + matxError_t SetAdvancedOptions(cusolverDnFunction_t function, + cusolverAlgMode_t algo) + { + [[maybe_unused]] cusolverStatus_t ret = cusolverDnSetAdvOptions(dn_params, function, algo); + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + + return matxSuccess; + } + + virtual ~matxDnCUDASolver_t() + { + matxFree(d_workspace, cudaStreamDefault); + matxFree(h_workspace, cudaStreamDefault); + matxFree(d_info, cudaStreamDefault); + cusolverDnDestroy(handle); + } + + void AllocateWorkspace(size_t batches) + { + if (dspace > 0) { + matxAlloc(&d_workspace, batches * dspace, MATX_DEVICE_MEMORY); + } + + matxAlloc((void **)&d_info, batches * sizeof(*d_info), MATX_DEVICE_MEMORY); + + if (hspace > 0) { + matxAlloc(&h_workspace, batches * hspace, MATX_HOST_MEMORY); + } + } + + virtual void GetWorkspaceSize() = 0; + +protected: + cusolverDnHandle_t handle; + cusolverDnParams_t dn_params; + std::vector batch_a_ptrs; + int *d_info; + void *d_workspace = nullptr; + void *h_workspace = nullptr; + size_t hspace; + size_t dspace; +}; + +#if MATX_EN_CPU_SOLVER +/** + * Dense LAPACK base class that all dense host solver types inherit common methods + * and structures from. Depending on the decomposition, it may require different + * types of workspace arrays. 
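For instance, the complex-valued SVD and eigenvalue paths use the real-valued rwork array in addition to work, and the eigenvalue decompositions also use the integer iwork array.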
+ * + * @tparam ValueType + * Input tensor type + * + */ +template +class matxDnHostSolver_t { + +public: + matxDnHostSolver_t() + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + } + + virtual ~matxDnHostSolver_t() + { + matxFree(work); + matxFree(rwork); + matxFree(iwork); + } + + void AllocateWorkspace([[maybe_unused]] size_t batches) + { + if (lwork > 0) { + matxAlloc(&work, lwork * sizeof(ValueType), MATX_HOST_MALLOC_MEMORY); + } + + // used for eig and svd complex types + if (lrwork > 0) { + matxAlloc(&rwork, lrwork * sizeof(typename inner_op_type_t::type), MATX_HOST_MALLOC_MEMORY); + } + + // used for all eig types + if (liwork > 0) { + matxAlloc(&iwork, liwork * sizeof(lapack_int_t), MATX_HOST_MALLOC_MEMORY); + } + } + + virtual void GetWorkspaceSize() {}; + +protected: + std::vector batch_a_ptrs; + void *work = nullptr; // work array of input type + void *rwork = nullptr; // real valued work array + void *iwork = nullptr; // integer valued work array + lapack_int_t lwork = -1; + lapack_int_t lrwork = -1; + lapack_int_t liwork = -1; +}; +#endif + +} // end namespace detail + +} // end namespace matx \ No newline at end of file diff --git a/include/matx/transforms/svd.h b/include/matx/transforms/svd/svd_cuda.h similarity index 60% rename from include/matx/transforms/svd.h rename to include/matx/transforms/svd/svd_cuda.h index f8fc05f53..6bd68513d 100644 --- a/include/matx/transforms/svd.h +++ b/include/matx/transforms/svd/svd_cuda.h @@ -32,14 +32,63 @@ #pragma once +#include "cublas_v2.h" +#include "cusolverDn.h" + #include "matx/core/error.h" #include "matx/core/nvtx.h" #include "matx/core/tensor.h" +#include "matx/core/cache.h" +#include "matx/operators/slice.h" +#include "matx/transforms/solver_common.h" + #include #include namespace matx { +namespace detail { + +template +inline auto svdbpi_impl_workspace(const AType &A, cudaStream_t stream) { + using ATypeS = typename AType::value_type; + const int RANK = AType::Rank(); + + auto m = A.Size(RANK-2); // rows + auto n = A.Size(RANK-1); // cols + auto d = cuda::std::min(n,m); // dim for AAT or ATA + + auto ATShape = A.Shape(); + ATShape[RANK-2] = d; + ATShape[RANK-1] = d; + + auto QShape = A.Shape(); + QShape[RANK-1] = d; + QShape[RANK-2] = d; + + auto RShape = A.Shape(); + RShape[RANK-1] = d; + RShape[RANK-2] = d; + + cuda::std::array l2NormShape; + for(int i=0;i(ATShape, MATX_ASYNC_DEVICE_MEMORY, stream); + auto Q = make_tensor(QShape, MATX_ASYNC_DEVICE_MEMORY, stream); + auto Qold = make_tensor(QShape, MATX_ASYNC_DEVICE_MEMORY, stream); + auto R = make_tensor(RShape, MATX_ASYNC_DEVICE_MEMORY, stream); + auto Z = make_tensor(QShape, MATX_ASYNC_DEVICE_MEMORY, stream); + auto l2Norm = make_tensor(l2NormShape, MATX_ASYNC_DEVICE_MEMORY, stream); + auto converged = make_tensor({}, MATX_ASYNC_DEVICE_MEMORY, stream); + return cuda::std::tuple(AT, Q, Qold, R, Z, l2Norm, converged); +} + +} // end namespace detail + + /** * Perform a SVD decomposition using the power iteration. This version of * SVD works well on small n/m with large batch. 
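The workspace helper above gathers every scratch tensor the block power iteration needs and returns them as a tuple; a caller unpacks it with structured bindings, as in this minimal sketch (assuming A is a rank-2 or higher device tensor and stream is a valid cudaStream_t):

    // Illustrative only; mirrors how svdbpi_impl consumes the helper.
    auto [AT, Q, Qold, R, Z, l2Norm, converged] =
        detail::svdbpi_impl_workspace(A, stream);
    // AT, Q, Qold, R and Z are d x d scratch tensors per batch (d = min(m, n)),
    // l2Norm holds one residual per batch, and converged is a 0-D flag used to
    // stop the iteration early when a nonzero tolerance is supplied.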
@@ -303,43 +352,6 @@ void svdpi_impl(UType &U, SType &S, VTType &VT, AType &A, X0Type &x0, int iterat } } -template -inline auto svdbpi_impl_workspace(const AType &A, cudaStream_t stream) { - using ATypeS = typename AType::value_type; - const int RANK = AType::Rank(); - - auto m = A.Size(RANK-2); // rows - auto n = A.Size(RANK-1); // cols - auto d = cuda::std::min(n,m); // dim for AAT or ATA - - auto ATShape = A.Shape(); - ATShape[RANK-2] = d; - ATShape[RANK-1] = d; - - auto QShape = A.Shape(); - QShape[RANK-1] = d; - QShape[RANK-2] = d; - - auto RShape = A.Shape(); - RShape[RANK-1] = d; - RShape[RANK-2] = d; - - cuda::std::array l2NormShape; - for(int i=0;i(ATShape, MATX_ASYNC_DEVICE_MEMORY, stream); - auto Q = make_tensor(QShape, MATX_ASYNC_DEVICE_MEMORY, stream); - auto Qold = make_tensor(QShape, MATX_ASYNC_DEVICE_MEMORY, stream); - auto R = make_tensor(RShape, MATX_ASYNC_DEVICE_MEMORY, stream); - auto Z = make_tensor(QShape, MATX_ASYNC_DEVICE_MEMORY, stream); - auto l2Norm = make_tensor(l2NormShape, MATX_ASYNC_DEVICE_MEMORY, stream); - auto converged = make_tensor({}, MATX_ASYNC_DEVICE_MEMORY, stream); - return cuda::std::tuple(AT, Q, Qold, R, Z, l2Norm, converged); -} - /** * Perform a SVD decomposition using the block power iteration. This version of * SVD works well on small n/m with large batch. @@ -408,8 +420,8 @@ inline void svdbpi_impl(UType &U, SType &S, VTType &VT, const AType &A, int max_ cudaEventCreate(&event); } - auto [AT, Q, Qold, R, Z, l2Norm, converged] = svdbpi_impl_workspace(A, stream); - auto qr_workspace = qr_internal_workspace(Z, stream); + auto [AT, Q, Qold, R, Z, l2Norm, converged] = detail::svdbpi_impl_workspace(A, stream); + auto qr_workspace = detail::qr_internal_workspace(Z, stream); // create spd matrix if ( m >= n ) { @@ -432,10 +444,10 @@ inline void svdbpi_impl(UType &U, SType &S, VTType &VT, const AType &A, int max_ // double pump this iteration so we get Qold and Q for tolerance checking. // We might take an extra iteration but it will overheads associated with checking concergence. matmul_impl(Z, AT, Q, exec); - qr_internal(Qold, R, Z, qr_workspace, exec); + detail::qr_internal(Qold, R, Z, qr_workspace, exec); matmul_impl(Z, AT, Qold, exec); - qr_internal(Q, R, Z, qr_workspace, exec); + detail::qr_internal(Q, R, Z, qr_workspace, exec); if(tol!=0.0f) { @@ -503,5 +515,314 @@ inline void svdbpi_impl(UType &U, SType &S, VTType &VT, const AType &A, int max_ } } -} // end namespace matx +/********************************************** SOLVER SVD + * *********************************************/ + +namespace detail { + +/** + * Parameters needed to execute singular value decomposition. We distinguish + * unique factorizations mostly by the data pointer in A. + */ +struct DnSVDCUDAParams_t { + int64_t m; + int64_t n; + char jobu; + char jobvt; + void *A; + void *U; + void *VT; + void *S; + size_t batch_size; + MatXDataType_t dtype; +}; + +template +class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t { + using T1 = typename ATensor::value_type; + using T2 = typename UTensor::value_type; + using T3 = typename STensor::value_type; + using T4 = typename VtTensor::value_type; + static constexpr int RANK = UTensor::Rank(); + static_assert(RANK >= 2, "Input/Output tensor must be rank 2 or higher"); + +public: + /** + * Plan for factoring A such that \f$\textbf{A} = \textbf{U} * \textbf{\Sigma} + * * \textbf{V^{H}}\f$ + * + * Creates a handle for decomposing matrix A into the format above. 
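Here U is m-by-m, VT holds the n-by-n V^H factor, and S receives the min(m, n) singular values on the diagonal of Sigma.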
cuSolver destroys + * the contents of A, so a copy of the user input should be passed here. + * + * @tparam T1 + * Data type of A matrix + * @tparam T2 + * Data type of U matrix + * @tparam T3 + * Data type of S vector + * @tparam T4 + * Data type of VT matrix + * @tparam RANK + * Rank of A, U, and VT matrices, and RANK-1 of S + * + * @param u + * Output tensor view for U matrix + * @param s + * Output tensor view for S matrix + * @param vt + * Output tensor view for VT matrix + * @param a + * Input tensor view for A matrix + * @param jobu + * Specifies options for computing all or part of the matrix U: = 'A'. See + * SVDJob documentation for more info + * @param jobvt + * specifies options for computing all or part of the matrix V**T. See + * SVDJob documentation for more info + * + */ + matxDnSVDCUDAPlan_t(UTensor &u, + STensor &s, + VtTensor &vt, + const ATensor &a, const char jobu = 'A', + const char jobvt = 'A') + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Dim checks + MATX_STATIC_ASSERT_STR(UTensor::Rank()-1 == STensor::Rank(), matxInvalidDim, "S tensor must be 1 rank lower than U tensor in SVD"); + MATX_STATIC_ASSERT_STR(UTensor::Rank() == ATensor::Rank(), matxInvalidDim, "U tensor must match A tensor rank in SVD"); + MATX_STATIC_ASSERT_STR(UTensor::Rank() == VtTensor::Rank(), matxInvalidDim, "U tensor must match VT tensor rank in SVD"); + + // Type checks + MATX_STATIC_ASSERT_STR(!is_half_v, matxInvalidType, "SVD solver does not support half precision"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "A and U types must match"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "A and VT types must match"); + MATX_STATIC_ASSERT_STR(!is_complex_v, matxInvalidType, "S type must be real"); + MATX_STATIC_ASSERT_STR((std::is_same_v::type, T3>), matxInvalidType, "A and S inner types must match"); + + params = GetSVDParams(u, s, vt, a, jobu, jobvt); + this->GetWorkspaceSize(); + this->AllocateWorkspace(params.batch_size); + } + + void GetWorkspaceSize() override + { + // Use all mode for a larger workspace size that works for all modes + cusolverStatus_t ret = + cusolverDnXgesvd_bufferSize( + this->handle, this->dn_params, 'A', 'A', params.m, params.n, + MatXTypeToCudaType(), params.A, params.m, + MatXTypeToCudaType(), params.S, MatXTypeToCudaType(), + params.U, params.m, MatXTypeToCudaType(), params.VT, params.n, + MatXTypeToCudaType(), &this->dspace, &this->hspace); + + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + } + + static DnSVDCUDAParams_t + GetSVDParams(UTensor &u, STensor &s, + VtTensor &vt, const ATensor &a, + const char jobu = 'A', const char jobvt = 'A') + { + DnSVDCUDAParams_t params; + params.batch_size = GetNumBatches(a); + params.m = a.Size(RANK - 2); + params.n = a.Size(RANK - 1); + params.A = a.Data(); + params.U = u.Data(); + params.VT = vt.Data(); + params.S = s.Data(); + params.jobu = jobu; + params.jobvt = jobvt; + params.dtype = TypeToInt(); + + return params; + } + + void Exec(UTensor &u, STensor &s, VtTensor &vt, + const ATensor &a, const cudaExecutor &exec, + const char jobu = 'A', const char jobvt = 'A') + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Batch size checks + for(int i = 0 ; i < RANK-2; i++) { + MATX_ASSERT_STR(u.Size(i) == a.Size(i), matxInvalidDim, "U and A must have the same batch sizes"); + MATX_ASSERT_STR(vt.Size(i) == a.Size(i), matxInvalidDim, "VT and A must have the same batch sizes"); + MATX_ASSERT_STR(s.Size(i) == a.Size(i), matxInvalidDim, "S and A must have the 
same batch sizes"); + } + + // Inner size checks + MATX_ASSERT_STR((u.Size(RANK-1) == params.m) && (u.Size(RANK-2) == u.Size(RANK-1)), matxInvalidSize, "U must be ... x m x m"); + MATX_ASSERT_STR((vt.Size(RANK-1) == params.n) && (vt.Size(RANK-2) == vt.Size(RANK-1)), matxInvalidSize, "VT must be ... x n x n"); + MATX_ASSERT_STR(s.Size(RANK-2) == cuda::std::min(params.m, params.n), matxInvalidSize, "S must be ... x min(m,n)"); + + SetBatchPointers(a, this->batch_a_ptrs); + SetBatchPointers(u, this->batch_u_ptrs); + SetBatchPointers(vt, this->batch_vt_ptrs); + SetBatchPointers(s, this->batch_s_ptrs); + + cusolverDnSetStream(this->handle, exec.getStream()); + int info; + + // At this time cuSolver does not have a batched 64-bit SVD interface. Change + // this to use the batched version once available. + for (size_t i = 0; i < this->batch_a_ptrs.size(); i++) { + + auto ret = cusolverDnXgesvd( + this->handle, this->dn_params, jobu, jobvt, params.m, params.n, + MatXTypeToCudaType(), this->batch_a_ptrs[i], params.m, + MatXTypeToCudaType(), this->batch_s_ptrs[i], MatXTypeToCudaType(), + this->batch_u_ptrs[i], params.m, MatXTypeToCudaType(), this->batch_vt_ptrs[i], + params.n, MatXTypeToCudaType(), + reinterpret_cast(this->d_workspace) + i * this->dspace, this->dspace, + reinterpret_cast(this->h_workspace) + i * this->hspace, this->hspace, + this->d_info + i); + + MATX_ASSERT(ret == CUSOLVER_STATUS_SUCCESS, matxSolverError); + + // This will block. Figure this out later + cudaMemcpy(&info, this->d_info + i, sizeof(info), cudaMemcpyDeviceToHost); + MATX_ASSERT(info == 0, matxSolverError); + } + } + + /** + * SVD solver handle destructor + * + * Destroys any helper data used for provider type and any workspace memory + * created + * + */ + ~matxDnSVDCUDAPlan_t() {} + +private: + std::vector batch_u_ptrs; + std::vector batch_s_ptrs; + std::vector batch_vt_ptrs; + DnSVDCUDAParams_t params; +}; + +/** + * Crude hash to get a reasonably good delta for collisions. This doesn't need + * to be perfect, but fast enough to not slow down lookups, and different enough + * so the common solver parameters change + */ +struct DnSVDCUDAParamsKeyHash { + std::size_t operator()(const DnSVDCUDAParams_t &k) const noexcept + { + return (std::hash()(k.m)) + (std::hash()(k.n)) + + (std::hash()(k.batch_size)); + } +}; + +/** + * Test SVD parameters for equality. Unlike the hash, all parameters must match. + */ +struct DnSVDCUDAParamsKeyEq { + bool operator()(const DnSVDCUDAParams_t &l, const DnSVDCUDAParams_t &t) const noexcept + { + return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size && l.dtype == t.dtype; + } +}; + +using svd_cuda_cache_t = std::unordered_map; + +} + +/** + * Perform a SVD decomposition using a cached plan + * + * See documentation of matxDnSVDCUDAPlan_t for a description of how the + * algorithm works. This function provides a simple interface to the cuSolver + * library by deducing all parameters needed to perform a SVD decomposition from + * only the matrix A. cuSolver only support m >= n. + * + * @tparam T1 + * Data type of matrix A + * @tparam RANK + * Rank of matrix A + * + * @param u + * U matrix output + * @param s + * Sigma matrix output + * @param vt + * VT matrix output + * @param a + * Input matrix A + * @param exec + * CUDA Executor + * @param jobu + * Specifies options for computing all or part of the matrix U: = 'A'. See + * SVDJob documentation for more info + * @param jobvt + * specifies options for computing all or part of the matrix V**T. 
See + * SVDJob documentation for more info + * + */ +template +void svd_impl(UTensor &&u, STensor &&s, + VtTensor &&vt, const ATensor &a, + const cudaExecutor &exec, const SVDJob jobu = SVDJob::ALL, + const SVDJob jobvt = SVDJob::ALL) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + using T1 = typename ATensor::value_type; + constexpr int RANK = ATensor::Rank(); + const auto stream = exec.getStream(); + + auto u_new = OpToTensor(u, exec); + auto s_new = OpToTensor(s, exec); + auto vt_new = OpToTensor(vt, exec); + auto a_new = OpToTensor(a, exec); + + if(!a_new.isSameView(a)) { + (a_new = a).run(exec); + } + + /* Temporary WAR + cuSolver doesn't support row-major layouts. Since we want to make the + library appear as though everything is row-major, we take a performance hit + to transpose in and out of the function. Eventually this may be fixed in + cuSolver. + */ + + T1 *tp; + matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream); + auto tv = TransposeCopy(tp, a_new, exec); + auto tvt = tv.PermuteMatrix(); + + auto u_col_maj = make_tensor(u_new.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream); + auto vt_col_maj = make_tensor(vt_new.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream); + + const char jobu_cusolver = detail::SVDJobToChar(jobu); + const char jobvt_cusolver = detail::SVDJobToChar(jobvt); + + // Get parameters required by these tensors + auto params = detail::matxDnSVDCUDAPlan_t:: + GetSVDParams(u_col_maj, s_new, vt_col_maj, tvt, jobu_cusolver, jobvt_cusolver); + + // Get cache or new QR plan if it doesn't exist + using cache_val_type = detail::matxDnSVDCUDAPlan_t; + detail::GetCache().LookupAndExec( + detail::GetCacheIdFromType(), + params, + [&]() { + return std::make_shared(u_col_maj, s_new, vt_col_maj, tvt, jobu_cusolver, jobvt_cusolver); + }, + [&](std::shared_ptr ctype) { + ctype->Exec(u_col_maj, s_new, vt_col_maj, tvt, exec, jobu_cusolver, jobvt_cusolver); + } + ); + + // cuSolver writes to them in col-major format, so we need to transpose them back. + (u = transpose_matrix(u_col_maj)).run(exec); + (vt = transpose_matrix(vt_col_maj)).run(exec); + + matxFree(tp); +} + +} // end namespace matx \ No newline at end of file diff --git a/include/matx/transforms/svd/svd_lapack.h b/include/matx/transforms/svd/svd_lapack.h new file mode 100644 index 000000000..cd0939691 --- /dev/null +++ b/include/matx/transforms/svd/svd_lapack.h @@ -0,0 +1,388 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "matx/core/error.h" +#include "matx/core/nvtx.h" +#include "matx/core/tensor.h" +#include "matx/core/cache.h" +#include "matx/operators/slice.h" +#include "matx/executors/host.h" +#include "matx/executors/support.h" +#include "matx/transforms/solver_common.h" + +#include +#include + +namespace matx { + +namespace detail { + +#if MATX_EN_CPU_SOLVER +/** + * Parameters needed to execute singular value decomposition. We distinguish + * unique factorizations mostly by the data pointer in A. + */ +struct DnSVDHostParams_t { + lapack_int_t m; + lapack_int_t n; + char jobu; + char jobvt; + void *A; + void *U; + void *VT; + void *S; + size_t batch_size; + MatXDataType_t dtype; +}; + +template +class matxDnSVDHostPlan_t : matxDnHostSolver_t { + using T1 = typename ATensor::value_type; + using T2 = typename UTensor::value_type; + using T3 = typename STensor::value_type; + using T4 = typename VtTensor::value_type; + static constexpr int RANK = UTensor::Rank(); + static_assert(RANK >= 2, "Input/Output tensor must be rank 2 or higher"); + using lapack_type = std::conditional_t>, lapack_scomplex_t, + std::conditional_t>, lapack_dcomplex_t, T1>>; + +public: + /** + * Plan for factoring A such that \f$\textbf{A} = \textbf{U} * \textbf{\Sigma} + * * \textbf{V^{H}}\f$ + * + * Creates a handle for decomposing matrix A into the format above. LAPACK destroys + * the contents of A, so a copy of the user input should be passed here. + * + * @tparam T1 + * Data type of A matrix + * @tparam T2 + * Data type of U matrix + * @tparam T3 + * Data type of S vector + * @tparam T4 + * Data type of VT matrix + * @tparam RANK + * Rank of A, U, and VT matrices, and RANK-1 of S + * + * @param u + * Output tensor view for U matrix + * @param s + * Output tensor view for S matrix + * @param vt + * Output tensor view for VT matrix + * @param a + * Input tensor view for A matrix + * @param jobu + * Specifies options for computing all or part of the matrix U: = 'A'. See + * SVDJob documentation for more info + * @param jobvt + * specifies options for computing all or part of the matrix V**T. 
See + * SVDJob documentation for more info + * + */ + matxDnSVDHostPlan_t(UTensor &u, + STensor &s, + VtTensor &vt, + const ATensor &a, const char jobu = 'A', + const char jobvt = 'A') + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Dim checks + MATX_STATIC_ASSERT_STR(UTensor::Rank()-1 == STensor::Rank(), matxInvalidDim, "S tensor must be 1 rank lower than U tensor in SVD"); + MATX_STATIC_ASSERT_STR(UTensor::Rank() == ATensor::Rank(), matxInvalidDim, "U tensor must match A tensor rank in SVD"); + MATX_STATIC_ASSERT_STR(UTensor::Rank() == VtTensor::Rank(), matxInvalidDim, "U tensor must match VT tensor rank in SVD"); + + // Type checks + MATX_STATIC_ASSERT_STR(!is_half_v, matxInvalidType, "SVD solver does not support half precision"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "A and U types must match"); + MATX_STATIC_ASSERT_STR((std::is_same_v), matxInavlidType, "A and VT types must match"); + MATX_STATIC_ASSERT_STR(!is_complex_v, matxInvalidType, "S type must be real"); + MATX_STATIC_ASSERT_STR((std::is_same_v::type, T3>), matxInvalidType, "A and S inner types must match"); + + params = GetSVDParams(u, s, vt, a, jobu, jobvt); + this->GetWorkspaceSize(); + this->AllocateWorkspace(params.batch_size); + } + + void GetWorkspaceSize() override + { + // Perform a workspace query with lwork = -1. + + lapack_int_t info; + lapack_type work_query; + // Use all mode for a larger workspace size that works for all modes + gesvd_dispatch("A", "A", ¶ms.m, ¶ms.n, nullptr, + ¶ms.m, nullptr, nullptr, ¶ms.m, nullptr, ¶ms.n, + &work_query, &this->lwork, nullptr, &info); + + MATX_ASSERT(info == 0, matxSolverError); + + // the real part of the first elem of work holds the optimal lwork. + // rwork has size 5*min(M,N) and is only used for complex types + if constexpr (is_complex_v) { + this->lwork = static_cast(work_query.real); + this->lrwork = 5 * cuda::std::min(params.m, params.n); + } else { + this->lwork = static_cast(work_query); + this->lrwork = 0; // rwork is not used for real types + } + } + + static DnSVDHostParams_t + GetSVDParams(UTensor &u, STensor &s, + VtTensor &vt, const ATensor &a, + const char jobu = 'A', const char jobvt = 'A') + { + DnSVDHostParams_t params; + params.batch_size = GetNumBatches(a); + params.m = static_cast(a.Size(RANK - 2)); + params.n = static_cast(a.Size(RANK - 1)); + params.A = a.Data(); + params.U = u.Data(); + params.VT = vt.Data(); + params.S = s.Data(); + params.jobu = jobu; + params.jobvt = jobvt; + params.dtype = TypeToInt(); + + return params; + } + + template + void Exec(UTensor &u, STensor &s, VtTensor &vt, + const ATensor &a, [[maybe_unused]] const HostExecutor &exec, + const char jobu = 'A', const char jobvt = 'A') + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + // Batch size checks + for(int i = 0 ; i < RANK-2; i++) { + MATX_ASSERT_STR(u.Size(i) == a.Size(i), matxInvalidDim, "U and A must have the same batch sizes"); + MATX_ASSERT_STR(vt.Size(i) == a.Size(i), matxInvalidDim, "VT and A must have the same batch sizes"); + MATX_ASSERT_STR(s.Size(i) == a.Size(i), matxInvalidDim, "S and A must have the same batch sizes"); + } + + // Inner size checks + MATX_ASSERT_STR((u.Size(RANK-1) == params.m) && (u.Size(RANK-2) == u.Size(RANK-1)), matxInvalidSize, "U must be ... x m x m"); + MATX_ASSERT_STR((vt.Size(RANK-1) == params.n) && (vt.Size(RANK-2) == vt.Size(RANK-1)), matxInvalidSize, "VT must be ... x n x n"); + MATX_ASSERT_STR(s.Size(RANK-2) == cuda::std::min(params.m, params.n), matxInvalidSize, "S must be ... 
x min(m,n)"); + + SetBatchPointers(a, this->batch_a_ptrs); + SetBatchPointers(u, this->batch_u_ptrs); + SetBatchPointers(vt, this->batch_vt_ptrs); + SetBatchPointers(s, this->batch_s_ptrs); + + lapack_int_t info; + for (size_t i = 0; i < this->batch_a_ptrs.size(); i++) { + gesvd_dispatch(&jobu, &jobvt, ¶ms.m, ¶ms.n, + reinterpret_cast(this->batch_a_ptrs[i]), + ¶ms.m, reinterpret_cast(this->batch_s_ptrs[i]), + reinterpret_cast(this->batch_u_ptrs[i]), ¶ms.m, + reinterpret_cast(this->batch_vt_ptrs[i]), ¶ms.n, + reinterpret_cast(this->work), &this->lwork, + reinterpret_cast(this->rwork), &info); + + MATX_ASSERT(info == 0, matxSolverError); + } + } + + /** + * SVD solver handle destructor + * + * Destroys any helper data used for provider type and any workspace memory + * created + * + */ + ~matxDnSVDHostPlan_t() {} + +private: + void gesvd_dispatch(const char *jobu, const char *jobvt, const lapack_int_t *m, + const lapack_int_t *n, lapack_type *a, + const lapack_int_t *lda, T3 *s, lapack_type *u, + const lapack_int_t *ldu, lapack_type *vt, + const lapack_int_t *ldvt, lapack_type *work_in, + const lapack_int_t *lwork_in, [[maybe_unused]] T3 *rwork_in, lapack_int_t *info) + { + // TODO: remove warning suppression once gesvd is optimized in NVPL LAPACK +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + if constexpr (std::is_same_v) { + sgesvd_(jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work_in, lwork_in, info); + } else if constexpr (std::is_same_v) { + dgesvd_(jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work_in, lwork_in, info); + } else if constexpr (std::is_same_v) { + cgesvd_(jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work_in, lwork_in, rwork_in, info); + } else if constexpr (std::is_same_v) { + zgesvd_(jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work_in, lwork_in, rwork_in, info); + } +#pragma GCC diagnostic pop + } + + std::vector batch_u_ptrs; + std::vector batch_s_ptrs; + std::vector batch_vt_ptrs; + DnSVDHostParams_t params; +}; + +/** + * Crude hash to get a reasonably good delta for collisions. This doesn't need + * to be perfect, but fast enough to not slow down lookups, and different enough + * so the common solver parameters change + */ +struct DnSVDHostParamsKeyHash { + std::size_t operator()(const DnSVDHostParams_t &k) const noexcept + { + return (std::hash()(k.m)) + (std::hash()(k.n)) + + (std::hash()(k.batch_size)); + } +}; + +/** + * Test SVD parameters for equality. Unlike the hash, all parameters must match. + */ +struct DnSVDHostParamsKeyEq { + bool operator()(const DnSVDHostParams_t &l, const DnSVDHostParams_t &t) const noexcept + { + return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size && l.dtype == t.dtype; + } +}; + +using svd_Host_cache_t = std::unordered_map; +#endif + +} + +/** + * Perform a SVD decomposition using a cached plan + * + * See documentation of matxDnSVDHostPlan_t for a description of how the + * algorithm works. This function provides a simple interface to the LAPACK + * library by deducing all parameters needed to perform a SVD decomposition from + * only the matrix A. + * + * @tparam T1 + * Data type of matrix A + * @tparam RANK + * Rank of matrix A + * + * @param u + * U matrix output + * @param s + * Sigma matrix output + * @param vt + * VT matrix output + * @param a + * Input matrix A + * @param exec + * Host Executor + * @param jobu + * Specifies options for computing all or part of the matrix U: = 'A'. 
See + * SVDJob documentation for more info + * @param jobvt + * specifies options for computing all or part of the matrix V**T. See + * SVDJob documentation for more info + * + */ +template +void svd_impl([[maybe_unused]] UTensor &&u, + [[maybe_unused]] STensor &&s, + [[maybe_unused]] VtTensor &&vt, + [[maybe_unused]] const ATensor &a, + [[maybe_unused]] const HostExecutor &exec, + [[maybe_unused]] const SVDJob jobu = SVDJob::ALL, + [[maybe_unused]] const SVDJob jobvt = SVDJob::ALL) +{ + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + MATX_ASSERT_STR(MATX_EN_CPU_SOLVER, matxInvalidExecutor, + "Trying to run a host Solver executor but host Solver support is not configured"); +#if MATX_EN_CPU_SOLVER + + using T1 = typename ATensor::value_type; + constexpr int RANK = ATensor::Rank(); + + auto u_new = OpToTensor(u, exec); + auto s_new = OpToTensor(s, exec); + auto vt_new = OpToTensor(vt, exec); + auto a_new = OpToTensor(a, exec); + + if(!a_new.isSameView(a)) { + (a_new = a).run(exec); + } + + /* Temporary WAR + LAPACK doesn't support row-major layouts. Since we want to make the + library appear as though everything is row-major, we take a performance hit + to transpose in and out of the function. LAPACKE, however, supports both formats. + */ + + T1 *tp; + matxAlloc(reinterpret_cast(&tp), a_new.Bytes(), MATX_HOST_MALLOC_MEMORY); + auto tv = TransposeCopy(tp, a_new, exec); + auto tvt = tv.PermuteMatrix(); + + auto u_col_maj = make_tensor(u_new.Shape(), MATX_HOST_MALLOC_MEMORY); + auto vt_col_maj = make_tensor(vt_new.Shape(), MATX_HOST_MALLOC_MEMORY); + + const char jobu_lapack = detail::SVDJobToChar(jobu); + const char jobvt_lapack = detail::SVDJobToChar(jobvt); + + // Get parameters required by these tensors + auto params = detail::matxDnSVDHostPlan_t:: + GetSVDParams(u_col_maj, s_new, vt_col_maj, tvt, jobu_lapack, jobvt_lapack); + + // Get cache or new QR plan if it doesn't exist + using cache_val_type = detail::matxDnSVDHostPlan_t; + detail::GetCache().LookupAndExec( + detail::GetCacheIdFromType(), + params, + [&]() { + return std::make_shared(u_col_maj, s_new, vt_col_maj, tvt, jobu_lapack, jobvt_lapack); + }, + [&](std::shared_ptr ctype) { + ctype->Exec(u_col_maj, s_new, vt_col_maj, tvt, exec, jobu_lapack, jobvt_lapack); + } + ); + + // LAPACK writes to them in col-major format, so we need to transpose them back. 
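// Note that S needs no equivalent fix-up: the singular values come back as a
// 1-D vector, so row- versus column-major layout does not apply to it.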
+ (u = transpose_matrix(u_col_maj)).run(exec); + (vt = transpose_matrix(vt_col_maj)).run(exec); + + matxFree(tp); +#endif +} + +} // end namespace matx + diff --git a/include/matx/transforms/transpose.h b/include/matx/transforms/transpose.h index 7940bcba3..23fb96dc7 100644 --- a/include/matx/transforms/transpose.h +++ b/include/matx/transforms/transpose.h @@ -104,7 +104,7 @@ namespace matx template __MATX_INLINE__ void transpose_matrix_impl([[maybe_unused]] OutputTensor &out, - const InputTensor &in, HostExecutor &exec) + const InputTensor &in, const HostExecutor &exec) { static_assert(InputTensor::Rank() >= 2, "transpose_matrix operator must be on rank 2 or greater"); diff --git a/test/00_solver/Cholesky.cu b/test/00_solver/Cholesky.cu index 8c8b6cc4e..94edc35b0 100644 --- a/test/00_solver/Cholesky.cu +++ b/test/00_solver/Cholesky.cu @@ -40,14 +40,28 @@ using namespace matx; template class CholSolverTest : public ::testing::Test { protected: + using GTestType = cuda::std::tuple_element_t<0, T>; + using GExecType = cuda::std::tuple_element_t<1, T>; void SetUp() override { + if constexpr (!detail::CheckSolverSupport()) { + GTEST_SKIP(); + } + + // Use an arbitrary number of threads for the select threads host exec. + if constexpr (is_select_threads_host_executor_v) { + HostExecParams params{4}; + exec = SelectThreadsHostExecutor{params}; + } + pb = std::make_unique(); } void TearDown() override { pb.reset(); } std::unique_ptr pb; + GExecType exec{}; + float thresh = 0.001f; }; template @@ -55,14 +69,13 @@ class CholSolverTestNonHalfFloatTypes : public CholSolverTest { }; TYPED_TEST_SUITE(CholSolverTestNonHalfFloatTypes, - MatXFloatNonHalfTypesCUDAExec); + MatXFloatNonHalfTypesAllExecs); TYPED_TEST(CholSolverTestNonHalfFloatTypes, CholeskyBasic) { MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; - using ExecType = cuda::std::tuple_element_t<1, TypeParam>; - ExecType exec; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; const cuda::std::array dims { 16, @@ -81,23 +94,23 @@ TYPED_TEST(CholSolverTestNonHalfFloatTypes, CholeskyBasic) this->pb->NumpyToTensorView(Lv, "L"); // example-begin chol-test-1 - (Bv = chol(Bv, CUBLAS_FILL_MODE_LOWER)).run(exec); + (Bv = chol(Bv, SolverFillMode::LOWER)).run(this->exec); // example-end chol-test-1 - exec.sync(); + this->exec.sync(); - // Cholesky fills the lower triangular portion (due to CUBLAS_FILL_MODE_LOWER) + // Cholesky fills the lower triangular portion (due to SolverFillMode::LOWER) // and destroys the upper triangular portion. 
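// (With SolverFillMode::UPPER the roles are reversed: the factorization
// computes A = U^H * U, so the upper triangle of Bv would be the meaningful
// part to compare.)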
if constexpr (is_complex_v) { for (index_t i = 0; i < dims[k]; i++) { for (index_t j = 0; j <= i; j++) { - ASSERT_NEAR(Bv(i, j).real(), Lv(i, j).real(), 0.001); - ASSERT_NEAR(Bv(i, j).imag(), Lv(i, j).imag(), 0.001); + ASSERT_NEAR(Bv(i, j).real(), Lv(i, j).real(), this->thresh); + ASSERT_NEAR(Bv(i, j).imag(), Lv(i, j).imag(), this->thresh); } } } else { for (index_t i = 0; i < dims[k]; i++) { for (index_t j = 0; j <= i; j++) { - ASSERT_NEAR(Bv(i, j), Lv(i, j), 0.001); + ASSERT_NEAR(Bv(i, j), Lv(i, j), this->thresh); } } } @@ -106,12 +119,61 @@ TYPED_TEST(CholSolverTestNonHalfFloatTypes, CholeskyBasic) MATX_EXIT_HANDLER(); } +TYPED_TEST(CholSolverTestNonHalfFloatTypes, CholeskyBasicBatched) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + const cuda::std::array dims { + 16, + 50, + 100, + 130, + 200, + 1000 + }; + constexpr index_t batches = 10; + + for (size_t k = 0; k < dims.size(); k++) { + this->pb->template InitAndRunTVGenerator("00_solver", "cholesky", "run", {batches, dims[k]}); + auto Bv = make_tensor({batches, dims[k], dims[k]}); + auto Lv = make_tensor({batches, dims[k], dims[k]}); + this->pb->NumpyToTensorView(Bv, "B"); + this->pb->NumpyToTensorView(Lv, "L"); + + (Bv = chol(Bv, SolverFillMode::LOWER)).run(this->exec); + this->exec.sync(); + + // Cholesky fills the lower triangular portion (due to SolverFillMode::LOWER) + // and destroys the upper triangular portion. + for (index_t b = 0; b < batches; b++) { + if constexpr (is_complex_v) { + for (index_t i = 0; i < dims[k]; i++) { + for (index_t j = 0; j <= i; j++) { + ASSERT_NEAR(Bv(b, i, j).real(), Lv(b, i, j).real(), this->thresh); + ASSERT_NEAR(Bv(b, i, j).imag(), Lv(b, i, j).imag(), this->thresh); + } + } + } else { + for (index_t i = 0; i < dims[k]; i++) { + for (index_t j = 0; j <= i; j++) { + ASSERT_NEAR(Bv(b, i, j), Lv(b, i, j), this->thresh); + } + } + } + } + } + + MATX_EXIT_HANDLER(); +} + + TYPED_TEST(CholSolverTestNonHalfFloatTypes, CholeskyWindowed) { MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; - using ExecType = cuda::std::tuple_element_t<1, TypeParam>; - ExecType exec; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; const cuda::std::array dims { 50, @@ -129,25 +191,25 @@ TYPED_TEST(CholSolverTestNonHalfFloatTypes, CholeskyWindowed) auto Lv = make_tensor({dims[k], dims[k]}); this->pb->NumpyToTensorView(Cv, "B"); this->pb->NumpyToTensorView(Lv, "L"); - (Bslice = Cv).run(exec); - exec.sync(); + (Bslice = Cv).run(this->exec); + this->exec.sync(); - (Bslice = chol(Bslice, CUBLAS_FILL_MODE_LOWER)).run(exec); - exec.sync(); + (Bslice = chol(Bslice, SolverFillMode::LOWER)).run(this->exec); + this->exec.sync(); - // Cholesky fills the lower triangular portion (due to CUBLAS_FILL_MODE_LOWER) + // Cholesky fills the lower triangular portion (due to SolverFillMode::LOWER) // and destroys the upper triangular portion. 
if constexpr (is_complex_v) { for (index_t i = 0; i < dims[k]; i++) { for (index_t j = 0; j <= i; j++) { - ASSERT_NEAR(Bslice(i, j).real(), Lv(i, j).real(), 0.001); - ASSERT_NEAR(Bslice(i, j).imag(), Lv(i, j).imag(), 0.001); + ASSERT_NEAR(Bslice(i, j).real(), Lv(i, j).real(), this->thresh); + ASSERT_NEAR(Bslice(i, j).imag(), Lv(i, j).imag(), this->thresh); } } } else { for (index_t i = 0; i < dims[k]; i++) { for (index_t j = 0; j <= i; j++) { - ASSERT_NEAR(Bslice(i, j), Lv(i, j), 0.001); + ASSERT_NEAR(Bslice(i, j), Lv(i, j), this->thresh); } } } diff --git a/test/00_solver/Det.cu b/test/00_solver/Det.cu index a1bda8eff..7d4b53a89 100644 --- a/test/00_solver/Det.cu +++ b/test/00_solver/Det.cu @@ -46,39 +46,100 @@ template class DetSolverTest : public ::testing::Test { protected: void SetUp() override { + if constexpr (!detail::CheckSolverSupport()) { + GTEST_SKIP(); + } + + // Use an arbitrary number of threads for the select threads host exec. + if constexpr (is_select_threads_host_executor_v) { + HostExecParams params{4}; + exec = SelectThreadsHostExecutor{params}; + } + pb = std::make_unique(); - pb->InitAndRunTVGenerator("00_solver", "det", "run", {m}); - pb->NumpyToTensorView(Av, "A"); } void TearDown() override { pb.reset(); } GExecType exec{}; std::unique_ptr pb; - tensor_t Av{{m, m}}; - tensor_t Atv{{m, m}}; - tensor_t detv{{}}; + float relTol = 2e-5f; }; template -class DetSolverTestNonComplexFloatTypes : public DetSolverTest { +class DetSolverTestFloatTypes : public DetSolverTest { }; -TYPED_TEST_SUITE(DetSolverTestNonComplexFloatTypes, - MatXFloatNonComplexNonHalfTypesCUDAExec); +template +double getMaxMagnitude(const T& value) { + if constexpr (is_complex_v) { + return cuda::std::max(cuda::std::fabs(value.real()), cuda::std::fabs(value.imag())); + } else { + return cuda::std::fabs(value); + } +} + +TYPED_TEST_SUITE(DetSolverTestFloatTypes, + MatXFloatNonHalfTypesAllExecs); -TYPED_TEST(DetSolverTestNonComplexFloatTypes, Determinant) +TYPED_TEST(DetSolverTestFloatTypes, DeterminantBasic) { MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using inner_type = typename inner_op_type_t::type; + + auto Av = make_tensor({m, m}); + auto detv = make_tensor({}); + auto detv_ref = make_tensor({}); - // cuSolver only supports col-major solving today, so we need to transpose, - // solve, then transpose again to compare to Python - (this->Atv = transpose(this->Av)).run(this->exec); + this->pb->template InitAndRunTVGenerator("00_solver", "det", "run", {m}); + this->pb->NumpyToTensorView(Av, "A"); - (this->detv = det(this->Atv)).run(this->exec); - (this->Av = transpose(this->Atv)).run(this->exec); // Transpose back to row-major + // example-begin det-test-1 + (detv = det(Av)).run(this->exec); + // example-end det-test-1 this->exec.sync(); - MATX_TEST_ASSERT_COMPARE(this->pb, this->detv, "det", 0.1); + this->pb->NumpyToTensorView(detv_ref, "det"); + + // The relative error is on the order of 2e-5 compared to the Python result + // for float types + auto thresh = this->relTol * getMaxMagnitude(detv_ref()); + + MATX_TEST_ASSERT_COMPARE(this->pb, detv, "det", thresh); MATX_EXIT_HANDLER(); } + +TYPED_TEST(DetSolverTestFloatTypes, DeterminantBasicBatched) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using inner_type = typename inner_op_type_t::type; + + constexpr int batches = 10; + + auto Av = make_tensor({batches, m, m}); + auto detv = make_tensor({batches}); + auto detv_ref = make_tensor({batches}); + + 
this->pb->template InitAndRunTVGenerator("00_solver", "det", "run", {batches, m}); + this->pb->NumpyToTensorView(Av, "A"); + this->pb->NumpyToTensorView(detv_ref, "det"); + + (detv = det(Av)).run(this->exec); + this->exec.sync(); + + for (index_t b = 0; b < batches; b++) { + auto detv_ref_b = detv_ref(b); + auto thresh = this->relTol * getMaxMagnitude(detv_ref_b); + + if constexpr (is_complex_v) { + ASSERT_NEAR(detv(b).real(), detv_ref_b.real(), thresh); + ASSERT_NEAR(detv(b).imag(), detv_ref_b.imag(), thresh); + } else { + ASSERT_NEAR(detv(b), detv_ref_b, thresh); + } + } + + MATX_EXIT_HANDLER(); +} \ No newline at end of file diff --git a/test/00_solver/Eigen.cu b/test/00_solver/Eigen.cu index b3ce3648c..f8e4b7130 100644 --- a/test/00_solver/Eigen.cu +++ b/test/00_solver/Eigen.cu @@ -45,61 +45,132 @@ protected: using GExecType = cuda::std::tuple_element_t<1, T>; void SetUp() override { + if constexpr (!detail::CheckSolverSupport()) { + GTEST_SKIP(); + } + + // Use an arbitrary number of threads for the select threads host exec. + if constexpr (is_select_threads_host_executor_v) { + HostExecParams params{4}; + exec = SelectThreadsHostExecutor{params}; + } + pb = std::make_unique(); - pb->InitAndRunTVGenerator("00_solver", "eig", "run", {dim_size}); - pb->NumpyToTensorView(Bv, "B"); } void TearDown() override { pb.reset(); } std::unique_ptr pb; GExecType exec{}; - tensor_t Bv{{dim_size, dim_size}}; - tensor_t Btv{{dim_size, dim_size}}; - tensor_t Evv{{dim_size, dim_size}}; - - tensor_t Wv{{dim_size, 1}}; - tensor_t Wov{{dim_size}}; - - tensor_t Gtv{{dim_size, 1}}; - tensor_t Lvv{{dim_size, 1}}; + float thresh = 0.001f; }; template -class EigenSolverTestNonComplexFloatTypes : public EigenSolverTest { +class EigenSolverTestFloatTypes : public EigenSolverTest { }; -TYPED_TEST_SUITE(EigenSolverTestNonComplexFloatTypes, - MatXFloatNonComplexNonHalfTypesCUDAExec); +TYPED_TEST_SUITE(EigenSolverTestFloatTypes, + MatXFloatNonHalfTypesAllExecs); -TYPED_TEST(EigenSolverTestNonComplexFloatTypes, EigenBasic) +TYPED_TEST(EigenSolverTestFloatTypes, EigenBasic) { MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using value_type = typename inner_op_type_t::type; + + auto Bv = make_tensor({dim_size, dim_size}); + auto Evv = make_tensor({dim_size, dim_size}); + auto Wov = make_tensor({dim_size}); + + auto Gv = make_tensor({dim_size, 1}); + auto Lvv = make_tensor({dim_size, 1}); + + this->pb->template InitAndRunTVGenerator("00_solver", "eig", "run", {dim_size}); + this->pb->NumpyToTensorView(Bv, "B"); + // example-begin eig-test-1 // Note that eigenvalue/vector solutions are not necessarily ordered in the same way other libraries // may order them. When comparing against other libraries it's best to check A*v = lambda * v - (mtie(this->Evv, this->Wov) = eig(this->Bv)).run(this->exec); + (mtie(Evv, Wov) = eig(Bv)).run(this->exec); // example-end eig-test-1 // Now we need to go through all the eigenvectors and eigenvalues and make // sure the results match the equation A*v = lambda*v, where v are the // eigenvectors corresponding to the eigenvalue lambda. 
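// The loop below materializes both sides explicitly for every eigenpair: v is
// the i-th column of Evv, lambda is Wov(i), and the two dim_size-by-1 results
// are compared element-wise within this->thresh.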
for (index_t i = 0; i < dim_size; i++) { - auto v = this->Evv.template Slice<2>({0, i}, {matxEnd, i + 1}); - matx::copy(this->Wv, v, 0); + auto v = slice<2>(Evv, {0, i}, {matxEnd, i + 1}); // Compute lambda*v - auto b = v * this->Wov(i); - (this->Lvv = b).run(this->exec); - // Compute A*v + (Lvv = Wov(i) * v).run(this->exec); - (this->Gtv = matmul(this->Bv, this->Wv)).run(this->exec); + // Compute A*v + (Gv = matmul(Bv, v)).run(this->exec); this->exec.sync(); + // Compare for (index_t j = 0; j < dim_size; j++) { - ASSERT_NEAR(this->Gtv(j, 0), this->Lvv(j, 0), 0.001); + if constexpr (is_complex_v) { + ASSERT_NEAR(Gv(j, 0).real(), Lvv(j, 0).real(), this->thresh); + ASSERT_NEAR(Gv(j, 0).imag(), Lvv(j, 0).imag(), this->thresh); + } + else { + ASSERT_NEAR(Gv(j, 0), Lvv(j, 0), this->thresh); + } } } MATX_EXIT_HANDLER(); } + +TYPED_TEST(EigenSolverTestFloatTypes, EigenBasicBatched) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using value_type = typename inner_op_type_t::type; + + constexpr int batches = 10; + auto Bv = make_tensor({batches, dim_size, dim_size}); + auto Evv = make_tensor({batches, dim_size, dim_size}); + auto Wov = make_tensor({batches, dim_size}); + + auto Gv = make_tensor({dim_size, 1}); + auto Lvv = make_tensor({dim_size, 1}); + + this->pb->template InitAndRunTVGenerator("00_solver", "eig", "run", {batches, dim_size}); + this->pb->NumpyToTensorView(Bv, "B"); + + // Note that eigenvalue/vector solutions are not necessarily ordered in the same way other libraries + // may order them. When comparing against other libraries it's best to check A*v = lambda * v + (mtie(Evv, Wov) = eig(Bv)).run(this->exec); + + // Now we need to go through all the eigenvectors and eigenvalues and make + // sure the results match the equation A*v = lambda*v, where v are the + // eigenvectors corresponding to the eigenvalue lambda. 
+ for (index_t b = 0; b < batches; b++) { + for (index_t i = 0; i < dim_size; i++) { + // ith col vector from bth batch + auto v = slice<2>(Evv, {b, 0, i}, {matxDropDim, matxEnd, i + 1}); + + // Compute lambda*v + (Lvv = Wov(b, i) * v).run(this->exec); + + // Compute A*v + auto Bv_b = slice<2>(Bv, {b, 0, 0}, {matxDropDim, matxEnd, matxEnd}); + (Gv = matmul(Bv_b, v)).run(this->exec); + this->exec.sync(); + + // Compare + for (index_t j = 0; j < dim_size; j++) { + if constexpr (is_complex_v) { + ASSERT_NEAR(Gv(j, 0).real(), Lvv(j, 0).real(), this->thresh); + ASSERT_NEAR(Gv(j, 0).imag(), Lvv(j, 0).imag(), this->thresh); + } + else { + ASSERT_NEAR(Gv(j, 0), Lvv(j, 0), this->thresh); + } + } + } + } + + MATX_EXIT_HANDLER(); +} \ No newline at end of file diff --git a/test/00_solver/Inverse.cu b/test/00_solver/Inverse.cu index fb8728a43..78315b6fa 100644 --- a/test/00_solver/Inverse.cu +++ b/test/00_solver/Inverse.cu @@ -51,6 +51,7 @@ protected: void TearDown() override { pb.reset(); } GExecType exec{}; std::unique_ptr pb; + float thresh = 0.001f; }; template @@ -69,7 +70,7 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv4x4) auto Ainv = make_tensor({4, 4}); auto Ainv_ref = make_tensor({4, 4}); - this->pb->template InitAndRunTVGenerator("00_solver", "inv", "run", {4, 1}); + this->pb->template InitAndRunTVGenerator("00_solver", "inv", "run", {4}); this->pb->NumpyToTensorView(A, "A"); this->pb->NumpyToTensorView(Ainv_ref, "A_inv"); @@ -82,11 +83,11 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv4x4) for (index_t i = 0; i < A.Size(0); i++) { for (index_t j = 0; j <= i; j++) { if constexpr (is_complex_v) { - ASSERT_NEAR(Ainv_ref(i, j).real(), Ainv(i, j).real(), 0.001); - ASSERT_NEAR(Ainv_ref(i, j).imag(), Ainv(i, j).imag(), 0.001); + ASSERT_NEAR(Ainv_ref(i, j).real(), Ainv(i, j).real(), this->thresh); + ASSERT_NEAR(Ainv_ref(i, j).imag(), Ainv(i, j).imag(), this->thresh); } else { - ASSERT_NEAR(Ainv_ref(i, j), Ainv(i, j), 0.001); + ASSERT_NEAR(Ainv_ref(i, j), Ainv(i, j), this->thresh); } } } @@ -103,7 +104,7 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv4x4Batched) auto Ainv = make_tensor({100, 4, 4}); auto Ainv_ref = make_tensor({100, 4, 4}); - this->pb->template InitAndRunTVGenerator("00_solver", "inv", "run", {4, 100}); + this->pb->template InitAndRunTVGenerator("00_solver", "inv", "run", {100, 4}); this->pb->NumpyToTensorView(A, "A"); this->pb->NumpyToTensorView(Ainv_ref, "A_inv"); @@ -114,11 +115,11 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv4x4Batched) for (index_t i = 0; i < A.Size(1); i++) { for (index_t j = 0; j <= i; j++) { if constexpr (is_complex_v) { - ASSERT_NEAR(Ainv_ref(b, i, j).real(), Ainv(b, i, j).real(), 0.001); - ASSERT_NEAR(Ainv_ref(b, i, j).imag(), Ainv(b, i, j).imag(), 0.001); + ASSERT_NEAR(Ainv_ref(b, i, j).real(), Ainv(b, i, j).real(), this->thresh); + ASSERT_NEAR(Ainv_ref(b, i, j).imag(), Ainv(b, i, j).imag(), this->thresh); } else { - ASSERT_NEAR(Ainv_ref(b, i, j), Ainv(b, i, j), 0.001); + ASSERT_NEAR(Ainv_ref(b, i, j), Ainv(b, i, j), this->thresh); } } } @@ -136,7 +137,7 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv8x8) auto Ainv = make_tensor({8, 8}); auto Ainv_ref = make_tensor({8, 8}); - this->pb->template InitAndRunTVGenerator("00_solver", "inv", "run", {8, 1}); + this->pb->template InitAndRunTVGenerator("00_solver", "inv", "run", {8}); this->pb->NumpyToTensorView(A, "A"); this->pb->NumpyToTensorView(Ainv_ref, "A_inv"); @@ -146,11 +147,11 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv8x8) for (index_t i = 0; i < A.Size(0); i++) { for (index_t j = 0; j <= i; j++) { if constexpr 
(is_complex_v) { - ASSERT_NEAR(Ainv_ref(i, j).real(), Ainv(i, j).real(), 0.001); - ASSERT_NEAR(Ainv_ref(i, j).imag(), Ainv(i, j).imag(), 0.001); + ASSERT_NEAR(Ainv_ref(i, j).real(), Ainv(i, j).real(), this->thresh); + ASSERT_NEAR(Ainv_ref(i, j).imag(), Ainv(i, j).imag(), this->thresh); } else { - ASSERT_NEAR(Ainv_ref(i, j), Ainv(i, j), 0.001); + ASSERT_NEAR(Ainv_ref(i, j), Ainv(i, j), this->thresh); } } } @@ -167,7 +168,7 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv8x8Batched) auto Ainv = make_tensor({100, 8, 8}); auto Ainv_ref = make_tensor({100, 8, 8}); - this->pb->template InitAndRunTVGenerator("00_solver", "inv", "run", {8, 100}); + this->pb->template InitAndRunTVGenerator("00_solver", "inv", "run", {100, 8}); this->pb->NumpyToTensorView(A, "A"); this->pb->NumpyToTensorView(Ainv_ref, "A_inv"); @@ -178,11 +179,11 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv8x8Batched) for (index_t i = 0; i < A.Size(1); i++) { for (index_t j = 0; j <= i; j++) { if constexpr (is_complex_v) { - ASSERT_NEAR(Ainv_ref(b, i, j).real(), Ainv(b, i, j).real(), 0.001); - ASSERT_NEAR(Ainv_ref(b, i, j).imag(), Ainv(b, i, j).imag(), 0.001); + ASSERT_NEAR(Ainv_ref(b, i, j).real(), Ainv(b, i, j).real(), this->thresh); + ASSERT_NEAR(Ainv_ref(b, i, j).imag(), Ainv(b, i, j).imag(), this->thresh); } else { - ASSERT_NEAR(Ainv_ref(b, i, j), Ainv(b, i, j), 0.001); + ASSERT_NEAR(Ainv_ref(b, i, j), Ainv(b, i, j), this->thresh); } } } @@ -201,7 +202,7 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv256x256) auto Ainv = make_tensor({256, 256}); auto Ainv_ref = make_tensor({256, 256}); - this->pb->template InitAndRunTVGenerator("00_solver", "inv", "run", {256, 1}); + this->pb->template InitAndRunTVGenerator("00_solver", "inv", "run", {256}); this->pb->NumpyToTensorView(A, "A"); this->pb->NumpyToTensorView(Ainv_ref, "A_inv"); @@ -211,11 +212,11 @@ TYPED_TEST(InvSolverTestFloatTypes, Inv256x256) for (index_t i = 0; i < A.Size(0); i++) { for (index_t j = 0; j <= i; j++) { if constexpr (is_complex_v) { - ASSERT_NEAR(Ainv_ref(i, j).real(), Ainv(i, j).real(), 0.001); - ASSERT_NEAR(Ainv_ref(i, j).imag(), Ainv(i, j).imag(), 0.001); + ASSERT_NEAR(Ainv_ref(i, j).real(), Ainv(i, j).real(), this->thresh); + ASSERT_NEAR(Ainv_ref(i, j).imag(), Ainv(i, j).imag(), this->thresh); } else { - ASSERT_NEAR(Ainv_ref(i, j), Ainv(i, j), 0.001); + ASSERT_NEAR(Ainv_ref(i, j), Ainv(i, j), this->thresh); } } } diff --git a/test/00_solver/LU.cu b/test/00_solver/LU.cu index 9955e5aa0..83f6b8d12 100644 --- a/test/00_solver/LU.cu +++ b/test/00_solver/LU.cu @@ -46,51 +46,111 @@ protected: using GExecType = cuda::std::tuple_element_t<1, T>; void SetUp() override { + if constexpr (!detail::CheckSolverSupport()) { + GTEST_SKIP(); + } + + // Use an arbitrary number of threads for the select threads host exec. 
+ if constexpr (is_select_threads_host_executor_v) { + HostExecParams params{4}; + exec = SelectThreadsHostExecutor{params}; + } + pb = std::make_unique(); - pb->InitAndRunTVGenerator("00_solver", "lu", "run", {m, n}); - pb->NumpyToTensorView(Av, "A"); - pb->NumpyToTensorView(Lv, "L"); - pb->NumpyToTensorView(Uv, "U"); } void TearDown() override { pb.reset(); } GExecType exec{}; std::unique_ptr pb; - tensor_t Av{{m, n}}; - tensor_t Atv{{n, m}}; - tensor_t PivV{{std::min(m, n)}}; - tensor_t Lv{{m, std::min(m, n)}}; - tensor_t Uv{{std::min(m, n), n}}; + float thresh = 0.001f; }; template -class LUSolverTestNonComplexFloatTypes : public LUSolverTest { +class LUSolverTestFloatTypes : public LUSolverTest { }; -TYPED_TEST_SUITE(LUSolverTestNonComplexFloatTypes, - MatXFloatNonComplexNonHalfTypesCUDAExec); +TYPED_TEST_SUITE(LUSolverTestFloatTypes, + MatXFloatNonHalfTypesAllExecs); -TYPED_TEST(LUSolverTestNonComplexFloatTypes, LUBasic) +TYPED_TEST(LUSolverTestFloatTypes, LUBasic) { MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + using piv_value_type = std::conditional_t, int64_t, lapack_int_t>; + + auto Av = make_tensor({m, n}); + auto PivV = make_tensor({std::min(m, n)}); + auto Lv = make_tensor({m, std::min(m, n)}); + auto Uv = make_tensor({std::min(m, n), n}); + + this->pb->template InitAndRunTVGenerator("00_solver", "lu", "run", {m, n}); + this->pb->NumpyToTensorView(Av, "A"); + this->pb->NumpyToTensorView(Lv, "L"); + this->pb->NumpyToTensorView(Uv, "U"); + // example-begin lu-test-1 - (mtie(this->Av, this->PivV) = lu(this->Av)).run(this->exec); + (mtie(Av, PivV) = lu(Av)).run(this->exec); // example-end lu-test-1 this->exec.sync(); // The upper and lower triangle components are saved in Av. Python saves them // as separate matrices with the diagonal of the lower matrix set to 0 - for (index_t i = 0; i < this->Av.Size(0); i++) { - for (index_t j = 0; j < this->Av.Size(1); j++) { - if (i > j) { // Lower triangle - ASSERT_NEAR(this->Av(i, j), this->Lv(i, j), 0.001); + for (index_t i = 0; i < Av.Size(0); i++) { + for (index_t j = 0; j < Av.Size(1); j++) { + TestType refv = i > j ? Lv(i, j) : Uv(i, j); + if constexpr (is_complex_v) { + ASSERT_NEAR(Av(i, j).real(), refv.real(), this->thresh); + ASSERT_NEAR(Av(i, j).imag(), refv.imag(), this->thresh); } else { - ASSERT_NEAR(this->Av(i, j), this->Uv(i, j), 0.001); + ASSERT_NEAR(Av(i, j), refv, this->thresh); } } } MATX_EXIT_HANDLER(); } + +TYPED_TEST(LUSolverTestFloatTypes, LUBasicBatched) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + using piv_value_type = std::conditional_t, int64_t, lapack_int_t>; + constexpr int batches = 10; + + auto Av = make_tensor({batches, m, n}); + auto PivV = make_tensor({batches, std::min(m, n)}); + auto Lv = make_tensor({batches, m, std::min(m, n)}); + auto Uv = make_tensor({batches, std::min(m, n), n}); + + this->pb->template InitAndRunTVGenerator("00_solver", "lu", "run", {batches, m, n}); + this->pb->NumpyToTensorView(Av, "A"); + this->pb->NumpyToTensorView(Lv, "L"); + this->pb->NumpyToTensorView(Uv, "U"); + + (mtie(Av, PivV) = lu(Av)).run(this->exec); + this->exec.sync(); + + // The upper and lower triangle components are saved in Av. 
Python saves them + // as separate matrices with the diagonal of the lower matrix set to 0 + for (index_t b = 0; b < Av.Size(0); b++) { + for (index_t i = 0; i < Av.Size(1); i++) { + for (index_t j = 0; j < Av.Size(2); j++) { + TestType act = i > j ? Lv(b, i, j) : Uv(b, i, j); + if constexpr (is_complex_v) { + ASSERT_NEAR(Av(b, i, j).real(), act.real(), this->thresh); + ASSERT_NEAR(Av(b, i, j).imag(), act.imag(), this->thresh); + } + else { + ASSERT_NEAR(Av(b, i, j), act, this->thresh); + } + } + } + } + + MATX_EXIT_HANDLER(); +} \ No newline at end of file diff --git a/test/00_solver/QR.cu b/test/00_solver/QR.cu index 3c807ab10..b27714ce1 100644 --- a/test/00_solver/QR.cu +++ b/test/00_solver/QR.cu @@ -42,56 +42,114 @@ constexpr int n = 50; template class QRSolverTest : public ::testing::Test { protected: - using dtype = float; using GTestType = cuda::std::tuple_element_t<0, T>; using GExecType = cuda::std::tuple_element_t<1, T>; void SetUp() override { + if constexpr (!detail::CheckSolverSupport()) { + GTEST_SKIP(); + } + + // Use an arbitrary number of threads for the select threads host exec. + if constexpr (is_select_threads_host_executor_v) { + HostExecParams params{4}; + exec = SelectThreadsHostExecutor{params}; + } + pb = std::make_unique(); - pb->InitAndRunTVGenerator("00_solver", "qr", "run", {m, n}); - pb->NumpyToTensorView(Av, "A"); - pb->NumpyToTensorView(Qv, "Q"); - pb->NumpyToTensorView(Rv, "R"); } void TearDown() override { pb.reset(); } std::unique_ptr pb; - GExecType exec{}; - tensor_t Av{{m, n}}; - tensor_t Atv{{n, m}}; - tensor_t TauV{{std::min(m, n)}}; - tensor_t Qv{{m, std::min(m, n)}}; - tensor_t Rv{{std::min(m, n), n}}; + GExecType exec{}; + float thresh = 0.001f; }; template -class QRSolverTestNonComplexFloatTypes : public QRSolverTest { +class QRSolverTestFloatTypes : public QRSolverTest { }; -TYPED_TEST_SUITE(QRSolverTestNonComplexFloatTypes, - MatXFloatNonComplexNonHalfTypesCUDAExec); +TYPED_TEST_SUITE(QRSolverTestFloatTypes, + MatXFloatNonHalfTypesAllExecs); -TYPED_TEST(QRSolverTestNonComplexFloatTypes, QRBasic) +TYPED_TEST(QRSolverTestFloatTypes, QRBasic) { MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + + auto Av = make_tensor({m, n}); + auto TauV = make_tensor({std::min(m,n)}); + auto Qv = make_tensor({m, std::min(m, n)}); + auto Rv = make_tensor({std::min(m, n), n}); + + this->pb->template InitAndRunTVGenerator("00_solver", "qr", "run", {m, n}); + this->pb->NumpyToTensorView(Av, "A"); + this->pb->NumpyToTensorView(Qv, "Q"); + this->pb->NumpyToTensorView(Rv, "R"); - // example-begin cusolver_qr-test-1 - // cuSolver only supports col-major solving today, so we need to transpose, - // solve, then transpose again to compare to Python - (mtie(this->Av, this->TauV) = cusolver_qr(this->Av)).run(this->exec); - // example-end cusolver_qr-test-1 + // example-begin qr_solver-test-1 + (mtie(Av, TauV) = qr_solver(Av)).run(this->exec); + // example-end qr_solver-test-1 this->exec.sync(); // For now we're only verifying R. Q is a bit more complex to compute since // cuSolver/BLAS don't return Q, and instead return Householder reflections // that are used to compute Q. 
Eventually compute Q here and verify - for (index_t i = 0; i < this->Av.Size(0); i++) { - for (index_t j = 0; j < this->Av.Size(1); j++) { + for (index_t i = 0; i < Av.Size(0); i++) { + for (index_t j = 0; j < Av.Size(1); j++) { // R is stored only in the top triangle of A if (i <= j) { - ASSERT_NEAR(this->Av(i, j), this->Rv(i, j), 0.001); + if constexpr (is_complex_v) { + ASSERT_NEAR(Av(i, j).real(), Rv(i, j).real(), this->thresh); + ASSERT_NEAR(Av(i, j).imag(), Rv(i, j).imag(), this->thresh); + } + else { + ASSERT_NEAR(Av(i, j), Rv(i, j), this->thresh); + } + } + } + } + + MATX_EXIT_HANDLER(); +} + +TYPED_TEST(QRSolverTestFloatTypes, QRBasicBatched) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + + constexpr int batches = 10; + auto Av = make_tensor({batches, m, n}); + auto TauV = make_tensor({batches, std::min(m,n)}); + auto Qv = make_tensor({batches, m, std::min(m, n)}); + auto Rv = make_tensor({batches, std::min(m, n), n}); + + this->pb->template InitAndRunTVGenerator("00_solver", "qr", "run", {batches, m, n}); + this->pb->NumpyToTensorView(Av, "A"); + this->pb->NumpyToTensorView(Qv, "Q"); + this->pb->NumpyToTensorView(Rv, "R"); + + (mtie(Av, TauV) = qr_solver(Av)).run(this->exec); + this->exec.sync(); + + // For now we're only verifying R. Q is a bit more complex to compute since + // cuSolver/BLAS don't return Q, and instead return Householder reflections + // that are used to compute Q. Eventually compute Q here and verify + for (index_t b = 0; b < Av.Size(0); b++) { + for (index_t i = 0; i < Av.Size(1); i++) { + for (index_t j = 0; j < Av.Size(2); j++) { + // R is stored only in the top triangle of A + if (i <= j) { + if constexpr (is_complex_v) { + ASSERT_NEAR(Av(b, i, j).real(), Rv(b, i, j).real(), this->thresh); + ASSERT_NEAR(Av(b, i, j).imag(), Rv(b, i, j).imag(), this->thresh); + } + else { + ASSERT_NEAR(Av(b, i, j), Rv(b, i, j), this->thresh); + } + } } } } diff --git a/test/00_solver/SVD.cu b/test/00_solver/SVD.cu index 1f82c5bd1..8ed05eed0 100644 --- a/test/00_solver/SVD.cu +++ b/test/00_solver/SVD.cu @@ -40,100 +40,98 @@ using namespace matx; constexpr index_t m = 100; constexpr index_t n = 50; -template class SVDSolverTest : public ::testing::Test { +template class SVDTest : public ::testing::Test { protected: using GTestType = cuda::std::tuple_element_t<0, T>; using GExecType = cuda::std::tuple_element_t<1, T>; void SetUp() override { pb = std::make_unique(); - pb->InitAndRunTVGenerator("00_solver", "svd", "run", {m, n}); } void TearDown() override { pb.reset(); } GExecType exec{}; std::unique_ptr pb; + float thresh = 0.001f; +}; + +template class SVDSolverTest : public SVDTest { +protected: + using GTestType = cuda::std::tuple_element_t<0, T>; + using GExecType = cuda::std::tuple_element_t<1, T>; + void SetUp() override + { + if constexpr (!detail::CheckSolverSupport()) { + GTEST_SKIP(); + } + + // Use an arbitrary number of threads for the select threads host exec. 
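+    // Note: the executor is default-constructed above; this branch only swaps it out for the
+    // SelectThreadsHostExecutor case, so any other executor type in the suite is used as-is.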
+ if constexpr (is_select_threads_host_executor_v) { + HostExecParams params{4}; + this->exec = SelectThreadsHostExecutor{params}; + } + + this->pb = std::make_unique(); + } }; template class SVDSolverTestNonHalfTypes : public SVDSolverTest { }; -TYPED_TEST_SUITE(SVDSolverTestNonHalfTypes, - MatXFloatNonHalfTypesCUDAExec); +template +class SVDPISolverTestNonHalfTypes : public SVDTest { +}; + +TYPED_TEST_SUITE(SVDSolverTestNonHalfTypes, MatXFloatNonHalfTypesAllExecs); +TYPED_TEST_SUITE(SVDPISolverTestNonHalfTypes, MatXFloatNonHalfTypesCUDAExec); TYPED_TEST(SVDSolverTestNonHalfTypes, SVDBasic) { MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; using ExecType = cuda::std::tuple_element_t<1, TypeParam>; - using value_type = typename inner_op_type_t::type; + tensor_t Av{{m, n}}; - tensor_t Atv{{n, m}}; tensor_t Sv{{std::min(m, n)}}; tensor_t Uv{{m, m}}; - tensor_t Vv{{n, n}}; + tensor_t VTv{{n, n}}; - tensor_t Sav{{m, n}}; - tensor_t SSolav{{m, n}}; - tensor_t Uav{{m, m}}; - tensor_t Vav{{n, n}}; + tensor_t Dv{{m, n}}; + tensor_t UDVTv{{m, n}}; + this->pb->template InitAndRunTVGenerator("00_solver", "svd", "run", {m, n}); this->pb->NumpyToTensorView(Av, "A"); - // Used only for validation - auto tmpV = make_tensor({m, n}); - // example-begin svd-test-1 - // cuSolver only supports col-major solving today, so we need to transpose, - // solve, then transpose again to compare to Python - (Atv = transpose(Av)).run(this->exec); - - auto Atv2 = Atv.View({m, n}); - (mtie(Uv, Sv, Vv) = svd(Atv2)).run(this->exec); + (mtie(Uv, Sv, VTv) = svd(Av)).run(this->exec); // example-end svd-test-1 this->exec.sync(); // Since SVD produces a solution that's not necessarily unique, we cannot // compare against Python output. Instead, we just make sure that A = U*S*V'. - // However, U and V are in column-major format, so we have to transpose them - // back to verify the identity. 
- (Uav = transpose(Uv)).run(this->exec); - (Vav = transpose(Vv)).run(this->exec); - // Zero out s - (Sav = zeros::type>({m, n})).run(this->exec); + // Construct diagonal matrix D from the vector of singular values S + (Dv = zeros({m, n})).run(this->exec); this->exec.sync(); - // Construct S matrix since it's just a vector from cuSolver for (index_t i = 0; i < n; i++) { - Sav(i, i) = Sv(i); + Dv(i, i) = Sv(i); } - this->exec.sync(); - - (SSolav = 0).run(this->exec); - if constexpr (is_complex_v) { - (SSolav.RealView() = Sav).run(this->exec); - } - else { - (SSolav = Sav).run(this->exec); - } - - (tmpV = matmul(Uav, SSolav)).run(this->exec); // U * S - (SSolav = matmul(tmpV, Vav)).run(this->exec); // (U * S) * V' + (UDVTv = matmul(matmul(Uv, Dv), VTv)).run(this->exec); // (U * S) * V' this->exec.sync(); for (index_t i = 0; i < Av.Size(0); i++) { for (index_t j = 0; j < Av.Size(1); j++) { if constexpr (is_complex_v) { - ASSERT_NEAR(Av(i, j).real(), SSolav(i, j).real(), 0.001) << i << " " << j; - ASSERT_NEAR(Av(i, j).imag(), SSolav(i, j).imag(), 0.001) << i << " " << j; + ASSERT_NEAR(Av(i, j).real(), UDVTv(i, j).real(), this->thresh) << i << " " << j; + ASSERT_NEAR(Av(i, j).imag(), UDVTv(i, j).imag(), this->thresh) << i << " " << j; } else { - ASSERT_NEAR(Av(i, j), SSolav(i, j), 0.001) << i << " " << j; + ASSERT_NEAR(Av(i, j), UDVTv(i, j), this->thresh) << i << " " << j; } } } @@ -146,78 +144,50 @@ TYPED_TEST(SVDSolverTestNonHalfTypes, SVDBasicBatched) MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + using value_type = typename inner_op_type_t::type; constexpr index_t batches = 10; - using value_type = typename inner_op_type_t::type; - auto Av1 = make_tensor({m, n}); - this->pb->NumpyToTensorView(Av1, "A"); auto Av = make_tensor({batches, m, n}); - auto Atv = make_tensor({batches, n, m}); - (Av = Av1).run(this->exec); - auto Sv = make_tensor({batches, std::min(m, n)}); auto Uv = make_tensor({batches, m, m}); - auto Vv = make_tensor({batches, n, n}); - - auto Sav = make_tensor({batches, m, n}); - auto SSolav = make_tensor({batches, m, n}); - auto Uav = make_tensor({batches, m, m}); - auto Vav = make_tensor({batches, n, n}); + auto VTv = make_tensor({batches, n, n}); - // Used only for validation - auto tmpV = make_tensor({batches, m, n}); + auto Dv = make_tensor({batches, m, n}); + auto UDVTv = make_tensor({batches, m, n}); - // cuSolver only supports col-major solving today, so we need to transpose, - // solve, then transpose again to compare to Python - (Atv = transpose_matrix(Av)).run(this->exec); + this->pb->template InitAndRunTVGenerator("00_solver", "svd", "run", {batches, m, n}); + this->pb->NumpyToTensorView(Av, "A"); - auto Atv2 = Atv.View({batches, m, n}); - (mtie(Uv, Sv, Vv) = svd(Atv2)).run(this->exec); + (mtie(Uv, Sv, VTv) = svd(Av)).run(this->exec); this->exec.sync(); // Since SVD produces a solution that's not necessarily unique, we cannot // compare against Python output. Instead, we just make sure that A = U*S*V'. - // However, U and V are in column-major format, so we have to transpose them - // back to verify the identity. 
- (Uav = transpose_matrix(Uv)).run(this->exec); - (Vav = transpose_matrix(Vv)).run(this->exec); - // Zero out s - (Sav = zeros::type>({batches, m, n})).run(this->exec); + // Construct batched diagonal matrix D from the vector of singular values S + (Dv = zeros({m, n})).run(this->exec); this->exec.sync(); - // Construct S matrix since it's just a vector from cuSolver for (index_t b = 0; b < batches; b++) { for (index_t i = 0; i < n; i++) { - Sav(b, i, i) = Sv(b, i); + Dv(b, i, i) = Sv(b, i); } } - this->exec.sync(); - - (SSolav = 0).run(this->exec); - if constexpr (is_complex_v) { - (SSolav.RealView() = Sav).run(this->exec); - } - else { - (SSolav = Sav).run(this->exec); - } - - (tmpV = matmul(Uav, SSolav)).run(this->exec); // U * S - (SSolav = matmul(tmpV, Vav)).run(this->exec); // (U * S) * V' + (UDVTv = matmul(matmul(Uv, Dv), VTv)).run(this->exec); // (U * S) * V' this->exec.sync(); for (index_t b = 0; b < batches; b++) { - for (index_t i = 0; i < Av.Size(0); i++) { - for (index_t j = 0; j < Av.Size(1); j++) { + for (index_t i = 0; i < Av.Size(1); i++) { + for (index_t j = 0; j < Av.Size(2); j++) { if constexpr (is_complex_v) { - ASSERT_NEAR(Av(b, i, j).real(), SSolav(b, i, j).real(), 0.001) << i << " " << j; - ASSERT_NEAR(Av(b, i, j).imag(), SSolav(b, i, j).imag(), 0.001) << i << " " << j; + ASSERT_NEAR(Av(b, i, j).real(), UDVTv(b, i, j).real(), this->thresh) << i << " " << j; + ASSERT_NEAR(Av(b, i, j).imag(), UDVTv(b, i, j).imag(), this->thresh) << i << " " << j; } else { - ASSERT_NEAR(Av(b, i, j), SSolav(b, i, j), 0.001) << i << " " << j; + ASSERT_NEAR(Av(b, i, j), UDVTv(b, i, j), this->thresh) << i << " " << j; } } } @@ -336,7 +306,7 @@ void svdpi_test( const index_t (&AshapeA)[RANK], Executor exec) { ASSERT_NEAR( mdiffA(), SType(0), .00001); } -TYPED_TEST(SVDSolverTestNonHalfTypes, SVDPI) +TYPED_TEST(SVDPISolverTestNonHalfTypes, SVDPI) { MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; @@ -470,7 +440,7 @@ void svdbpi_test( const index_t (&AshapeA)[RANK], Executor exec) { exec.sync(); } -TYPED_TEST(SVDSolverTestNonHalfTypes, SVDBPI) +TYPED_TEST(SVDPISolverTestNonHalfTypes, SVDBPI) { MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; diff --git a/test/00_transform/MatMul.cu b/test/00_transform/MatMul.cu index fda7eba81..f6b3b16c1 100644 --- a/test/00_transform/MatMul.cu +++ b/test/00_transform/MatMul.cu @@ -52,6 +52,12 @@ protected: GTEST_SKIP(); } + // Use an arbitrary number of threads for the select threads host exec. 
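+    // Same pattern as the solver fixtures: when the typed executor is a SelectThreadsHostExecutor,
+    // rebuild it with an explicit thread count before the host matmul backends run.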
+ if constexpr (is_select_threads_host_executor_v) { + HostExecParams params{4}; + exec = SelectThreadsHostExecutor{params}; + } + pb = std::make_unique(); // Half precision needs a bit more // tolerance when compared to fp32 if constexpr (is_complex_half_v || is_matx_half_v) { diff --git a/test/test_vectors/generators/00_solver.py b/test/test_vectors/generators/00_solver.py index 9ecc2efc4..4061a117c 100644 --- a/test/test_vectors/generators/00_solver.py +++ b/test/test_vectors/generators/00_solver.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 import numpy as np -from scipy import linalg as slinalg +import cupy as cp +from cupyx.scipy import linalg as cplinalg from numpy import random import math import matx_common @@ -15,14 +16,15 @@ def __init__(self, dtype: str, size: List[int]): np.random.seed(1234) def run(self): - n = self.size[0] - batches = self.size[1] - - # Create a positive-definite matrix - if batches > 1: - A = matx_common.randn_ndarray((batches, n,n), self.dtype) + n = self.size[-1] + if len(self.size) == 1: + shape = (n, n) else: - A = matx_common.randn_ndarray((n,n), self.dtype) + batch_size = self.size[0] + shape = (batch_size, n, n) + + # Create an invertible matrix + A = matx_common.randn_ndarray(shape, self.dtype) A_inv = np.linalg.inv(A) return { @@ -38,13 +40,16 @@ def __init__(self, dtype: str, size: List[int]): np.random.seed(1234) def run(self): - n = self.size[0] - - if self.dtype in ('complex64', 'complex128'): - A = np.random.randn(n, n) + 1j*np.random.randn(n, n) + n = self.size[-1] + if len(self.size) == 1: + shape = (n, n) else: - A = np.random.randn(n, n) - B = np.matmul(A, A.conj().T) + batch_size = self.size[0] + shape = (batch_size, n, n) + + # Create a positive-definite matrix + A = matx_common.randn_ndarray(shape, self.dtype) + B = np.matmul(A, np.conj(A).swapaxes(-2, -1)) B = B + n*np.eye(n) L = np.linalg.cholesky(B) @@ -62,10 +67,35 @@ def __init__(self, dtype: str, size: List[int]): np.random.seed(1234) def run(self): - m, n = self.size[0], self.size[1] + m, n = self.size[-2:] + if len(self.size) == 2: + shape = (m, n) + else: + batch_size = self.size[0] + shape = (batch_size, m, n) - A = np.random.randn(m, n) - P, L, U = slinalg.lu(A) + A = matx_common.randn_ndarray(shape, self.dtype) + A_cp = cp.asarray(A) + + if len(self.size) == 2: + P_cp, L_cp, U_cp = cplinalg.lu(A_cp) + else: + P_list = [] + L_list = [] + U_list = [] + + for i in range(batch_size): + P_i, L_i, U_i = cplinalg.lu(A_cp[i]) + P_list.append(P_i) + L_list.append(L_i) + U_list.append(U_i) + + P_cp = cp.stack(P_list) + L_cp = cp.stack(L_list) + U_cp = cp.stack(U_list) + + cp.cuda.Stream.null.synchronize() + P, L, U = cp.asnumpy(P_cp), cp.asnumpy(L_cp), cp.asnumpy(U_cp) return { 'A': A, @@ -82,9 +112,14 @@ def __init__(self, dtype: str, size: List[int]): np.random.seed(1234) def run(self): - m, n = self.size[0], self.size[1] + m, n = self.size[-2:] + if len(self.size) == 2: + shape = (m, n) + else: + batch_size = self.size[0] + shape = (batch_size, m, n) - A = np.random.randn(m, n) + A = matx_common.randn_ndarray(shape, self.dtype) Q, R = np.linalg.qr(A) return { @@ -101,9 +136,14 @@ def __init__(self, dtype: str, size: List[int]): np.random.seed(1234) def run(self): - m, n = self.size[0], self.size[1] + m, n = self.size[-2:] + if len(self.size) == 2: + shape = (m, n) + else: + batch_size = self.size[0] + shape = (batch_size, m, n) - A = matx_common.randn_ndarray((m,n), self.dtype) + A = matx_common.randn_ndarray(shape, self.dtype) U, S, V = np.linalg.svd(A) return { @@ -121,10 +161,16 @@ 
def __init__(self, dtype: str, size: List[int]): np.random.seed(1234) def run(self): - n = self.size[0] + n = self.size[-1] + if len(self.size) == 1: + shape = (n, n) + else: + batch_size = self.size[0] + shape = (batch_size, n, n) + # Create a positive-definite matrix - A = np.random.randn(n, n) - B = np.matmul(A, A.conj().T) + A = matx_common.randn_ndarray(shape, self.dtype) + B = np.matmul(A, np.conj(A).swapaxes(-2, -1)) B = B + n*np.eye(n) W, V = np.linalg.eig(B) @@ -142,10 +188,15 @@ def __init__(self, dtype: str, size: List[int]): np.random.seed(1234) def run(self): - n = self.size[0] + n = self.size[-1] + if len(self.size) == 1: + shape = (n, n) + else: + batch_size = self.size[0] + shape = (batch_size, n, n) - # Create a positive-definite matrix - A = np.random.randn(n, n) + # Create an invertible matrix + A = matx_common.randn_ndarray(shape, self.dtype) det = np.linalg.det(A) return { diff --git a/test/test_vectors/generators/matx_common.py b/test/test_vectors/generators/matx_common.py index 2ddd9e944..96d1ce7ca 100755 --- a/test/test_vectors/generators/matx_common.py +++ b/test/test_vectors/generators/matx_common.py @@ -18,6 +18,6 @@ def to_file(var, name): def randn_ndarray(tshape, dtype): if np.issubdtype(dtype, np.floating): - return np.random.randn(*tshape) + return np.random.randn(*tshape).astype(dtype) else: - return np.random.randn(*tshape) + 1j*np.random.randn(*tshape) + return (np.random.randn(*tshape) + 1j*np.random.randn(*tshape)).astype(dtype)
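For reference, a minimal sketch (not part of this patch) of the host-executor solver path the updated tests exercise. It assumes a build where MATX_EN_CPU_SOLVER is enabled (an NVPL build per the CMake change); the matrix size, fill values, and thread count are illustrative only, and lapack_int_t mirrors the pivot type the tests select for host executors:

#include "matx.h"

using namespace matx;

int main() {
  constexpr index_t N = 8;

  // Multi-threaded host executor; 4 threads is an arbitrary, illustrative choice.
  SelectThreadsHostExecutor exec{HostExecParams{4}};

  // Host-resident tensors for the CPU solver path.
  auto A   = make_tensor<float>({N, N}, MATX_HOST_MALLOC_MEMORY);
  auto piv = make_tensor<lapack_int_t>({N}, MATX_HOST_MALLOC_MEMORY);

  // Fill A with a diagonally dominant (hence non-singular) matrix.
  for (index_t i = 0; i < N; i++) {
    for (index_t j = 0; j < N; j++) {
      A(i, j) = (i == j) ? static_cast<float>(N) : 1.0f;
    }
  }

  // In-place LU factorization on the CPU: L and U are packed into A, pivots into piv.
  (mtie(A, piv) = lu(A)).run(exec);
  exec.sync();

  print(A);
  return 0;
}

The same expression runs unchanged under a cudaExecutor; only the executor and the pivot integer type differ, which is why the tests choose piv_value_type via is_cuda_executor_v.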