Skip to content

Commit

Permalink
NVPL LAPACK Solver Support on ARM (#701)
Browse files Browse the repository at this point in the history
* NVPL LAPACK support for all decompositions on ARM

* Simplify SVD interface to hide column-major logic

* More comprehensive type and shape checks for all decompositions

* Batched tests and complex tests for remaining decompositions

* Fixed sign bug in determinant

* Split solver into separate files
  • Loading branch information
aayushg55 authored Aug 9, 2024
1 parent 1627463 commit 9070391
Show file tree
Hide file tree
Showing 42 changed files with 4,537 additions and 2,136 deletions.
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,12 @@ if (MATX_EN_NVPL OR MATX_EN_X86_FFTW OR MATX_EN_BLIS OR MATX_EN_OPENBLAS)

if (MATX_EN_NVPL)
message(STATUS "Enabling NVPL library support for ARM CPUs with ${INT_TYPE} interface")
find_package(nvpl REQUIRED COMPONENTS fft blas HINTS ${blas_DIR})
find_package(nvpl REQUIRED COMPONENTS fft blas lapack HINTS ${blas_DIR})
if (NOT MATX_BUILD_32_BIT)
target_compile_definitions(matx INTERFACE NVPL_ILP64)
endif()
target_link_libraries(matx INTERFACE nvpl::fftw nvpl::blas_${INT_TYPE}_omp)
target_compile_definitions(matx INTERFACE NVPL_LAPACK_COMPLEX_STRUCTURE)
target_link_libraries(matx INTERFACE nvpl::fftw nvpl::blas_${INT_TYPE}_omp nvpl::lapack_${INT_TYPE}_omp)
target_compile_definitions(matx INTERFACE MATX_EN_NVPL)
else()
# FFTW
Expand Down
6 changes: 3 additions & 3 deletions docs_input/api/linalg/decomp/qr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ Examples
:end-before: example-end qr-test-1
:dedent:

.. doxygenfunction:: cusolver_qr
.. doxygenfunction:: qr_solver

Examples
~~~~~~~~

.. literalinclude:: ../../../../test/00_solver/QR.cu
:language: cpp
:start-after: example-begin cusolver_qr-test-1
:end-before: example-end cusolver_qr-test-1
:start-after: example-begin qr_solver-test-1
:end-before: example-end qr_solver-test-1
:dedent:
2 changes: 1 addition & 1 deletion docs_input/api/linalg/decomp/svd.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
svd
###

Perform a singular value decomposition (SVD) using the power iteration method.
Perform a singular value decomposition (SVD).

.. doxygenfunction:: svd

Expand Down
2 changes: 1 addition & 1 deletion docs_input/api/logic/truth/allclose.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Reduce the closeness of two operators to a single scalar (0D) output. The output
from allclose is an ``int`` value since boolean reductions are not available in hardware.


.. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, HostExecutor<MODE> &exec)
.. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, const HostExecutor<MODE> &exec)
.. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, cudaExecutor exec = 0)

Examples
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Transpose the last two dimensions of an operator
Examples
~~~~~~~~

.. literalinclude:: ../../../../include/matx/transforms/svd.h
.. literalinclude:: ../../../../include/matx/transforms/svd/svd_cuda.h
:language: cpp
:start-after: example-begin transpose_matrix-test-1
:end-before: example-end transpose_matrix-test-1
Expand Down
34 changes: 30 additions & 4 deletions include/matx/core/tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
#include "matx/core/dlpack.h"
#include "matx/core/make_tensor.h"
#include "matx/kernels/utility.cuh"

#include "matx/transforms/copy.h"

namespace matx
{
Expand Down Expand Up @@ -1107,15 +1107,41 @@ void print(const Op &op)

#endif // not DOXYGEN_ONLY

/**
 * Return a tensor that can hold the values of an operator.
 *
 * When `Op` is already a tensor view, `op` itself is returned. Otherwise a new
 * tensor with the operator's value type and shape is created with make_tensor().
 *
 * NOTE(review): two signatures (cudaStream_t and generic Executor) and two
 * allocation paths appear below; this looks like the pre- and post-change
 * revisions listed side by side — confirm which lines are live.
 */
template <typename Op>
auto OpToTensor(Op &&op, [[maybe_unused]] cudaStream_t stream) {
template <typename Op, typename Executor>
auto OpToTensor(Op &&op, [[maybe_unused]] const Executor &exec) {
if constexpr (!is_tensor_view_v<Op>) {
return make_tensor<typename remove_cvref<Op>::value_type>(op.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream);
if constexpr (is_cuda_executor_v<Executor>) {
// CUDA executor: request MATX_ASYNC_DEVICE_MEMORY on the executor's stream
return make_tensor<typename remove_cvref<Op>::value_type>(op.Shape(), MATX_ASYNC_DEVICE_MEMORY, exec.getStream());
} else {
// Any other executor: request MATX_HOST_MALLOC_MEMORY (no stream argument)
return make_tensor<typename remove_cvref<Op>::value_type>(op.Shape(), MATX_HOST_MALLOC_MEMORY);
}
} else {
// Already a tensor view: hand it back as-is
return op;
}
}

/**
 * Copy the matrix-permuted form of a tensor into caller-provided memory
 *
 * The result of a.PermuteMatrix() is copied with matx::copy() into a tensor
 * created over `tp`, and that tensor is returned.
 *
 * @tparam TensorType
 *   Type of the input tensor
 * @tparam Executor
 *   Type of the executor
 * @param tp
 *   Pointer to pre-allocated memory (assumed large enough for every element
 *   of `a` — TODO confirm against callers)
 * @param a
 *   Tensor to transpose
 * @param exec
 *   Executor used to perform the copy
 * @return
 *   Tensor created over `tp` with the shape of the permuted input
 */
template <typename TensorType, typename Executor>
__MATX_INLINE__ auto
TransposeCopy(typename TensorType::value_type *tp, const TensorType &a, const Executor &exec)
{
  MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

  // Permuted view of the input; its shape defines the destination tensor
  auto permuted = a.PermuteMatrix();
  auto dest = make_tensor(tp, permuted.Shape());

  // Write the permuted data into the caller's buffer
  matx::copy(dest, permuted, exec);
  return dest;
}

/**
* @brief Set the print() precision for floating point values
*
Expand Down
24 changes: 20 additions & 4 deletions include/matx/executors/support.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,25 @@ namespace matx {

// FFT
#if defined(MATX_EN_NVPL) || defined(MATX_EN_X86_FFTW)
#define MATX_EN_CPU_FFT 1
#define MATX_EN_CPU_FFT 1
#else
#define MATX_EN_CPU_FFT 0
#define MATX_EN_CPU_FFT 0
#endif

// MatMul
#if defined(MATX_EN_NVPL) || defined(MATX_EN_OPENBLAS) || defined(MATX_EN_BLIS)
#define MATX_EN_CPU_MATMUL 1
#define MATX_EN_CPU_MATMUL 1
#else
#define MATX_EN_CPU_MATMUL 0
#define MATX_EN_CPU_MATMUL 0
#endif

// Solver
// Host (CPU) solver support is reported as available only when NVPL is enabled.
// NOTE(review): MATX_EN_CPU_SOLVER is always defined here (as 1 or 0), so
// consumers should test its value with `#if MATX_EN_CPU_SOLVER`; an
// `#ifdef MATX_EN_CPU_SOLVER` check is true in both cases once this header
// has been included.
#if defined(MATX_EN_NVPL)
#define MATX_EN_CPU_SOLVER 1
#else
#define MATX_EN_CPU_SOLVER 0
#endif

template <typename Exec, typename T>
constexpr bool CheckFFTSupport() {
if constexpr (is_host_executor_v<Exec>) {
Expand Down Expand Up @@ -114,5 +121,14 @@ constexpr bool CheckMatMulSupport() {
}
}

/**
 * Compile-time query for whether solver transforms are supported on an executor type.
 *
 * @tparam Exec
 *   Executor type to check
 * @return
 *   For host executors, the value of MATX_EN_CPU_SOLVER; true for every other
 *   executor type
 */
template <typename Exec>
constexpr bool CheckSolverSupport() {
  if constexpr (!is_host_executor_v<Exec>) {
    // Non-host executors are always reported as supported
    return true;
  } else {
    // Host executors depend on the build-time CPU solver switch
    return MATX_EN_CPU_SOLVER;
  }
}

}; // detail
}; // matx
15 changes: 8 additions & 7 deletions include/matx/operators/chol.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@

#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#include "matx/transforms/solver.h"
#include "matx/transforms/chol/chol_cuda.h"
#ifdef MATX_EN_CPU_SOLVER
#include "matx/transforms/chol/chol_lapack.h"
#endif

namespace matx {
namespace detail {
Expand All @@ -44,7 +47,7 @@ namespace detail {
{
private:
OpA a_;
cublasFillMode_t uplo_;
SolverFillMode uplo_;
mutable matx::tensor_t<typename OpA::value_type, OpA::Rank()> tmp_out_;

public:
Expand All @@ -54,17 +57,15 @@ namespace detail {
using chol_xform_op = bool;

__MATX_INLINE__ std::string str() const { return "chol()"; }
__MATX_INLINE__ CholOp(OpA a, cublasFillMode_t uplo) : a_(a), uplo_(uplo) { };
__MATX_INLINE__ CholOp(OpA a, SolverFillMode uplo) : a_(a), uplo_(uplo) { };

// This should never be called
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const = delete;

// Executes the transform: forwards the first element of `out`, the stored
// operand `a_`, and the stored fill mode `uplo_` to chol_impl().
// NOTE(review): a CUDA-only static_assert and a stream-based call appear next
// to an executor-based call; this looks like old and new revision lines side
// by side — confirm which are live.
template <typename Out, typename Executor>
void Exec(Out &&out, Executor &&ex) const {
static_assert(is_cuda_executor_v<Executor>, "chol() only supports the CUDA executor currently");

chol_impl(cuda::std::get<0>(out), a_, ex.getStream(), uplo_);
chol_impl(cuda::std::get<0>(out), a_, ex, uplo_);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
Expand Down Expand Up @@ -114,7 +115,7 @@ namespace detail {
}

// Constructs a detail::CholOp from operand `a` and fill mode `uplo`
// (the fill mode defaults to the UPPER enumerator).
// NOTE(review): two signature lines (cublasFillMode_t and SolverFillMode)
// appear below; presumably the old and new revisions — confirm which is live.
template<typename OpA>
__MATX_INLINE__ auto chol(const OpA &a, cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) {
__MATX_INLINE__ auto chol(const OpA &a, SolverFillMode uplo = SolverFillMode::UPPER) {
return detail::CholOp(a, uplo);
}

Expand Down
6 changes: 2 additions & 4 deletions include/matx/operators/det.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#include "matx/transforms/solver.h"
#include "matx/transforms/det.h"

namespace matx {
namespace detail {
Expand All @@ -62,9 +62,7 @@ namespace detail {

// Executes the transform: forwards the first element of `out` and the stored
// operand `a_` to det_impl().
// NOTE(review): a CUDA-only static_assert and a stream-based call appear next
// to an executor-based call; this looks like old and new revision lines side
// by side — confirm which are live.
template <typename Out, typename Executor>
void Exec(Out &&out, Executor &&ex) const{
static_assert(is_cuda_executor_v<Executor>, "det() only supports the CUDA executor currently");

det_impl(cuda::std::get<0>(out), a_, ex.getStream());
det_impl(cuda::std::get<0>(out), a_, ex);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
Expand Down
18 changes: 10 additions & 8 deletions include/matx/operators/eig.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@

#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#include "matx/transforms/solver.h"
#include "matx/transforms/eig/eig_cuda.h"
#ifdef MATX_EN_CPU_SOLVER
#include "matx/transforms/eig/eig_lapack.h"
#endif

namespace matx {

Expand All @@ -47,8 +50,8 @@ namespace detail {
{
private:
OpA a_;
cusolverEigMode_t jobz_;
cublasFillMode_t uplo_;
EigenMode jobz_;
SolverFillMode uplo_;

public:
using matxop = bool;
Expand All @@ -57,18 +60,17 @@ namespace detail {
using eig_xform_op = bool;

__MATX_INLINE__ std::string str() const { return "eig()"; }
__MATX_INLINE__ EigOp(OpA a, cusolverEigMode_t jobz, cublasFillMode_t uplo) : a_(a), jobz_(jobz), uplo_(uplo) { };
__MATX_INLINE__ EigOp(OpA a, EigenMode jobz, SolverFillMode uplo) : a_(a), jobz_(jobz), uplo_(uplo) { };

// This should never be called
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const = delete;

// Executes the transform: forwards the first two elements of `out`, the
// stored operand `a_`, and the stored `jobz_`/`uplo_` settings to eig_impl().
// The output tuple is required at compile time to have exactly 3 elements;
// the message speaks of 2 outputs, so the third element is presumably not a
// user-supplied output — TODO confirm.
// NOTE(review): a CUDA-only static_assert and a stream-based call appear next
// to an executor-based call; this looks like old and new revision lines side
// by side — confirm which are live.
template <typename Out, typename Executor>
void Exec(Out &&out, Executor &&ex) const {
static_assert(is_cuda_executor_v<Executor>, "eig () only supports the CUDA executor currently");
static_assert(cuda::std::tuple_size_v<remove_cvref_t<Out>> == 3, "Must use mtie with 2 outputs on eig(). ie: (mtie(O, w) = eig(A))");

eig_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex.getStream(), jobz_, uplo_);
eig_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex, jobz_, uplo_);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
Expand All @@ -94,8 +96,8 @@ namespace detail {

// Constructs a detail::EigOp from operand `a`, mode `jobz`, and fill mode
// `uplo` (defaults are the VECTOR and UPPER enumerators respectively).
// NOTE(review): two parameter lists (cuSOLVER/cuBLAS enums and
// EigenMode/SolverFillMode) appear below; presumably the old and new
// revisions — confirm which is live.
template<typename OpA>
__MATX_INLINE__ auto eig(const OpA &a,
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR,
cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER) {
EigenMode jobz = EigenMode::VECTOR,
SolverFillMode uplo = SolverFillMode::UPPER) {
return detail::EigOp(a, jobz, uplo);
}

Expand Down
1 change: 0 additions & 1 deletion include/matx/operators/frexp.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@

#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#include "matx/transforms/solver.h"

namespace matx {

Expand Down
11 changes: 7 additions & 4 deletions include/matx/operators/lu.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@

#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#include "matx/transforms/solver.h"
#include "matx/transforms/lu/lu_cuda.h"
#ifdef MATX_EN_CPU_SOLVER
#include "matx/transforms/lu/lu_lapack.h"
#endif


namespace matx {
namespace detail {
Expand All @@ -62,10 +66,9 @@ namespace detail {

// Executes the transform: forwards the first two elements of `out` and the
// stored operand `a_` to lu_impl().
// The output tuple is required at compile time to have exactly 3 elements;
// the message speaks of 2 outputs, so the third element is presumably not a
// user-supplied output — TODO confirm.
// NOTE(review): the tuple-size assert appears twice (one message names
// cusolver_qr(), the other lu()), and a stream-based call appears next to an
// executor-based call; this looks like old and new revision lines side by
// side — confirm which are live.
template <typename Out, typename Executor>
void Exec(Out &&out, Executor &&ex) const {
static_assert(is_cuda_executor_v<Executor>, "lu() only supports the CUDA executor currently");
static_assert(cuda::std::tuple_size_v<remove_cvref_t<Out>> == 3, "Must use mtie with 2 outputs on cusolver_qr(). ie: (mtie(O, piv) = lu(A))");
static_assert(cuda::std::tuple_size_v<remove_cvref_t<Out>> == 3, "Must use mtie with 2 outputs on lu(). ie: (mtie(O, piv) = lu(A))");

lu_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex.getStream());
lu_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
Expand Down
2 changes: 1 addition & 1 deletion include/matx/operators/prod.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ namespace detail {
template <typename InType, int D>
__MATX_INLINE__ auto prod(const InType &in, const int (&dims)[D])
{
static_assert(D < InType::Rank(), "reduction dimensions must be <= Rank of input");
static_assert(D <= InType::Rank(), "reduction dimensions must be <= Rank of input");
auto perm = detail::getPermuteDims<InType::Rank()>(dims);
auto permop = permute(in, perm);

Expand Down
Loading

0 comments on commit 9070391

Please sign in to comment.