Skip to content

Commit 90f7bd6

Browse files
authored
Added new pinv() operator (#740)
1 parent d9053d6 commit 90f7bd6

File tree

17 files changed

+728
-44
lines changed

17 files changed

+728
-44
lines changed

docs_input/api/linalg/decomp/pinv.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
.. _pinv_func:
2+
3+
pinv
4+
####
5+
6+
Compute the Moore-Penrose pseudo-inverse of a matrix.
7+
8+
.. doxygenfunction:: pinv(const OpA &a, float rcond = get_default_rcond<typename OpA::value_type>())
9+
10+
Examples
11+
~~~~~~~~
12+
13+
.. literalinclude:: ../../../../test/00_solver/Pinv.cu
14+
:language: cpp
15+
:start-after: example-begin pinv-test-1
16+
:end-before: example-end pinv-test-1
17+
:dedent:
18+
19+

docs_input/api/linalg/other/det.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
.. _det_func:
2+
3+
det
4+
=====
5+
6+
Compute the determinant of a tensor.
7+
8+
.. doxygenfunction:: det(const OpA &a)
9+
10+
Examples
11+
~~~~~~~~
12+
13+
.. literalinclude:: ../../../../test/00_solver/Det.cu
14+
:language: cpp
15+
:start-after: example-begin det-test-1
16+
:end-before: example-end det-test-1
17+
:dedent:
18+

include/matx/core/tensor.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,9 +658,10 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
658658
{
659659
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
660660

661+
[[maybe_unused]] stride_type prod = std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<stride_type>());
661662
// Ensure new shape's total size is not larger than the original
662663
MATX_ASSERT_STR(
663-
sizeof(M) * shape.TotalSize() <= storage_.Bytes(), matxInvalidSize,
664+
sizeof(M) * prod <= storage_.Bytes(), matxInvalidSize,
664665
"Total size of new tensor must not be larger than the original");
665666

666667
// This could be loosened up to make sure only the fastest changing dims
@@ -877,7 +878,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
877878
{
878879
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
879880

880-
static_assert(RANK >= 2, "Only tensors of rank 2 and higher can be permuted.");
881+
static_assert(RANK >= 1, "Only tensors of rank 1 and higher can be permuted.");
881882
cuda::std::array<shape_type, RANK> n;
882883
cuda::std::array<stride_type, RANK> s;
883884
[[maybe_unused]] bool done[RANK] = {0};

include/matx/operators/chol.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,6 @@ namespace detail {
104104
}
105105
}
106106

107-
// Size is not relevant in eig() since there are multiple return values and it
108-
// is not allowed to be called in larger expressions
109107
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
110108
{
111109
return a_.Size(dim);

include/matx/operators/det.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,6 @@ namespace detail {
9696
}
9797
}
9898

99-
// Size is not relevant in det() since there are multiple return values and it
100-
// is not allowed to be called in larger expressions
10199
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
102100
{
103101
return a_.Size(dim);
@@ -106,6 +104,13 @@ namespace detail {
106104
};
107105
}
108106

107+
/**
108+
* Computes the determinant by performing an LU factorization of the input,
109+
* and then calculating the product of diagonal entries of the U factor.
110+
*
111+
* For tensors of rank > 2, batching is performed.
112+
*
113+
*/
109114
template<typename OpA>
110115
__MATX_INLINE__ auto det(const OpA &a) {
111116
return detail::DetOp(a);

include/matx/operators/operators.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
#include "matx/operators/outer.h"
8181
#include "matx/operators/overlap.h"
8282
#include "matx/operators/percentile.h"
83+
#include "matx/operators/pinv.h"
8384
#include "matx/operators/permute.h"
8485
#include "matx/operators/planar.h"
8586
#include "matx/operators/polyval.h"

include/matx/operators/pinv.h

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
////////////////////////////////////////////////////////////////////////////////
2+
// BSD 3-Clause License
3+
//
4+
// Copyright (c) 2021, NVIDIA Corporation
5+
// All rights reserved.
6+
//
7+
// Redistribution and use in source and binary forms, with or without
8+
// modification, are permitted provided that the following conditions are met:
9+
//
10+
// 1. Redistributions of source code must retain the above copyright notice, this
11+
// list of conditions and the following disclaimer.
12+
//
13+
// 2. Redistributions in binary form must reproduce the above copyright notice,
14+
// this list of conditions and the following disclaimer in the documentation
15+
// and/or other materials provided with the distribution.
16+
//
17+
// 3. Neither the name of the copyright holder nor the names of its
18+
// contributors may be used to endorse or promote products derived from
19+
// this software without specific prior written permission.
20+
//
21+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
/////////////////////////////////////////////////////////////////////////////////
32+
33+
#pragma once
34+
35+
36+
#include "matx/core/type_utils.h"
37+
#include "matx/operators/base_operator.h"
38+
#include "matx/transforms/pinv.h"
39+
40+
namespace matx {
41+
namespace detail {
42+
template<typename OpA>
43+
class PinvOp : public BaseOp<PinvOp<OpA>>
44+
{
45+
private:
46+
OpA a_;
47+
float rcond_;
48+
cuda::std::array<index_t, OpA::Rank()> out_dims_;
49+
mutable detail::tensor_impl_t<typename remove_cvref_t<OpA>::value_type, OpA::Rank()> tmp_out_;
50+
mutable typename remove_cvref_t<OpA>::value_type *ptr;
51+
52+
public:
53+
using matxop = bool;
54+
using value_type = typename OpA::value_type;
55+
using matx_transform_op = bool;
56+
using pinv_xform_op = bool;
57+
58+
__MATX_INLINE__ std::string str() const { return "pinv()"; }
59+
__MATX_INLINE__ PinvOp(OpA a, float rcond) : a_(a), rcond_(rcond) {
60+
for (int r = 0; r < Rank(); r++) {
61+
if (r >= Rank() - 2) {
62+
out_dims_[r] = (r == Rank() - 1) ? a_.Size(Rank() - 2) : a_.Size(Rank() - 1);
63+
}
64+
else {
65+
out_dims_[r] = a_.Size(r);
66+
}
67+
}
68+
};
69+
70+
template <typename... Is>
71+
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const
72+
{
73+
return tmp_out_(indices...);
74+
}
75+
76+
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
77+
{
78+
return OpA::Rank();
79+
}
80+
81+
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
82+
{
83+
return out_dims_.Size(dim);
84+
}
85+
86+
template <typename Out, typename Executor>
87+
void Exec(Out &&out, Executor &&ex) const{
88+
pinv_impl(cuda::std::get<0>(out), a_, ex, rcond_);
89+
}
90+
91+
template <typename ShapeType, typename Executor>
92+
__MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
93+
{
94+
if constexpr (is_matx_op<OpA>()) {
95+
a_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
96+
}
97+
}
98+
99+
template <typename ShapeType, typename Executor>
100+
__MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
101+
{
102+
InnerPreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
103+
104+
detail::AllocateTempTensor(tmp_out_, std::forward<Executor>(ex), out_dims_, &ptr);
105+
106+
Exec(cuda::std::make_tuple(tmp_out_), std::forward<Executor>(ex));
107+
}
108+
109+
template <typename ShapeType, typename Executor>
110+
__MATX_INLINE__ void PostRun(ShapeType &&shape, Executor &&ex) const noexcept
111+
{
112+
if constexpr (is_matx_op<OpA>()) {
113+
a_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
114+
}
115+
}
116+
117+
};
118+
}
119+
120+
/**
121+
* Perform a generalized inverse of a matrix using its singular-value decomposition (SVD).
122+
* It automatically removes small singular values for stability.
123+
*
124+
* For tensors of rank > 2, batching is performed.
125+
*
126+
* @tparam OpA
127+
* Tensor or operator type of input A
128+
*
129+
* @param a
130+
* Input tensor or operator of shape `... x m x n`
131+
* @param rcond
132+
* Cutoff for small singular values. For stability, singular values
133+
* smaller than `rcond * largest_singular_value` are set to 0 for each matrix
134+
* in the batch. By default, `rcond` is approximately the machine epsilon of the tensor dtype.
135+
*
136+
* @return
137+
* An operator that gives a tensor of size `... x n x m` representing the pseudo-inverse of the input
138+
*/
139+
template<typename OpA>
140+
__MATX_INLINE__ auto pinv(const OpA &a, float rcond = get_default_rcond<typename OpA::value_type>()) {
141+
return detail::PinvOp(a, rcond);
142+
}
143+
144+
}

include/matx/operators/svd.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ namespace detail {
6565
template <typename... Is>
6666
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const = delete;
6767

68+
// TODO: Handle SVDMode::NONE case better to not require U & VT
6869
template <typename Out, typename Executor>
6970
void Exec(Out &&out, Executor &&ex) const {
7071
static_assert(cuda::std::tuple_size_v<remove_cvref_t<Out>> == 4, "Must use mtie with 3 outputs on svd(). ie: (mtie(U, S, VT) = svd(A))");
@@ -99,6 +100,10 @@ namespace detail {
99100
/**
100101
* Perform a singular value decomposition (SVD) using cuSolver or a LAPACK host
101102
* library.
103+
*
104+
* The singular values within each vector are sorted in descending order.
105+
*
106+
* For tensors of Rank > 2, batching is performed.
102107
*
103108
* @tparam OpA
104109
* Operator input type

include/matx/operators/trace.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,6 @@ namespace detail {
110110
}
111111

112112
/**
113-
* Computes the trace of a tensor
114-
*
115113
* Computes the trace of a square matrix by summing the diagonal
116114
*
117115
* @tparam InputOperator

include/matx/transforms/det.h

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,8 @@ namespace matx {
5151
/**
5252
* Compute the determinant of a matrix
5353
*
54-
* Computes the terminant of a matrix by first computing the LU composition,
55-
* then reduces the product of the diagonal elements of U. The input and output
56-
* parameters may be the same tensor. In that case, the input is destroyed and
57-
* the output is stored in-place.
54+
* Computes the determinant of a matrix by first computing the LU decomposition,
55+
* then reduces the product of the diagonal elements of U.
5856
*
5957
* @tparam T1
6058
* Data type of matrix A
@@ -80,22 +78,16 @@ void det_impl(OutputTensor &out, const InputTensor &a,
8078
constexpr int RANK = InputTensor::Rank();
8179
using value_type = typename OutputTensor::value_type;
8280
using piv_value_type = std::conditional_t<is_cuda_executor_v<Executor>, int64_t, lapack_int_t>;
83-
84-
auto a_new = OpToTensor(a, exec);
85-
86-
if(!a_new.isSameView(a)) {
87-
(a_new = a).run(exec);
88-
}
8981

9082
// Get parameters required by these tensors
9183
cuda::std::array<index_t, RANK - 1> s;
9284

9385
// Set batching dimensions of piv
9486
for (int i = 0; i < RANK - 2; i++) {
95-
s[i] = a_new.Size(i);
87+
s[i] = a.Size(i);
9688
}
9789

98-
index_t piv_len = cuda::std::min(a_new.Size(RANK - 1), a_new.Size(RANK - 2));
90+
index_t piv_len = cuda::std::min(a.Size(RANK - 1), a.Size(RANK - 2));
9991
s[RANK - 2] = piv_len;
10092

10193
tensor_t<piv_value_type, RANK-1> piv;
@@ -104,13 +96,13 @@ void det_impl(OutputTensor &out, const InputTensor &a,
10496
if constexpr (is_cuda_executor_v<Executor>) {
10597
const auto stream = exec.getStream();
10698
make_tensor(piv, s, MATX_ASYNC_DEVICE_MEMORY, stream);
107-
make_tensor(ac, a_new.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream);
99+
make_tensor(ac, a.Shape(), MATX_ASYNC_DEVICE_MEMORY, stream);
108100
} else {
109101
make_tensor(piv, s, MATX_HOST_MALLOC_MEMORY);
110-
make_tensor(ac, a_new.Shape(), MATX_HOST_MALLOC_MEMORY);
102+
make_tensor(ac, a.Shape(), MATX_HOST_MALLOC_MEMORY);
111103
}
112104

113-
lu_impl(ac, piv, a_new, exec);
105+
lu_impl(ac, piv, a, exec);
114106

115107
// Determinant sign adjustment based on piv permutation
116108
// Create indices corresponding to no permutation to compare against

0 commit comments

Comments
 (0)