Skip to content

Commit 7434f4c

Browse files
authored
Support operators in the einsum interface (#845)
1 parent 68a8c12 commit 7434f4c

File tree

3 files changed

+90
-14
lines changed

3 files changed

+90
-14
lines changed

include/matx/core/type_utils.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,20 @@
5353

5454
namespace matx {
5555

56+
namespace detail {
// Walks the input operator pack and copies each operator into the matching
// element of tuple `t`, skipping any element that already aliases the same
// view. `N` is the current tuple index; the walk recurses at compile time
// until the pack is exhausted. `exec` is the executor the copy runs on.
template <int N, typename Executor, typename TupleType, typename... Ops>
void assign_tuple_tensors(const Executor &exec, TupleType &t, Ops... ops)
{
  if constexpr (N < sizeof...(Ops)) {
    auto in_tup = cuda::std::make_tuple(ops...);
    if (!cuda::std::get<N>(t).isSameView(cuda::std::get<N>(in_tup))) {
      (cuda::std::get<N>(t) = cuda::std::get<N>(in_tup)).run(exec);
    }
    // Bug fix: always advance to the next element. Previously this recursion
    // was inside the isSameView() check above, so the first input that was
    // already the same view terminated the walk and every remaining tensor
    // in the pack was silently left unassigned.
    assign_tuple_tensors<N + 1>(exec, t, ops...);
  }
}
};
69+
5670
enum {
5771
matxNoRank = -1
5872
};

include/matx/transforms/einsum.h

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,21 @@ struct EinsumParams_t {
7171
cudaStream_t stream;
7272
};
7373

74+
// Returns a tensor usable by the cuTENSOR einsum path for the operator `in`.
// If `in` is already in a supported form it is passed through; otherwise
// GetSupportedTensor materializes it into async device memory on `stream`.
template <typename Op>
__MATX_INLINE__ auto getEinsumSupportedTensor( const Op &in, cudaStream_t stream) {
  // Every operator/tensor type is currently accepted by this transform, so
  // the support predicate unconditionally reports true. The previous version
  // had an `if constexpr (is_tensor_view_v<Op>)` whose branches both returned
  // true (dead branching) and an unused `&in` capture that trips
  // -Wunused-lambda-capture; both are removed with identical behavior.
  const auto support_func = []() {
    return true;
  };

  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
}
88+
7489
template <typename OutputTensor, typename... InT>
7590
class matxEinsumHandle_t {
7691
public:
@@ -312,7 +327,7 @@ class matxEinsumHandle_t {
312327
((params.nmodes_[i++] = tensors.Rank()), ...);
313328

314329
i = 0;
315-
MATX_ASSERT_STR(((tokens[i++].length() == static_cast<size_t>(tensors.Rank())) && ...), matxInvalidDim,
330+
MATX_ASSERT_STR(((tokens[i++].length() == static_cast<size_t>(tensors.Rank())), ...), matxInvalidDim,
316331
"Tensor rank must match number of einsum subscripts");
317332

318333
auto set_sizes = [](auto &t, std::vector<int64_t> &sizes) {
@@ -460,7 +475,6 @@ struct EinsumParamsKeyEq {
460475

461476
namespace matx {
462477
namespace cutensor {
463-
464478
/**
465479
* @brief Evaluates the Einstein summation on the operands
466480
*
@@ -489,22 +503,44 @@ namespace cutensor {
489503
#ifdef MATX_EN_CUTENSOR
490504
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
491505

506+
auto out_n = detail::cutensor::getEinsumSupportedTensor(out, stream);
507+
auto in_t = cuda::std::make_tuple(detail::cutensor::getEinsumSupportedTensor(tensors, stream)...);
508+
509+
using einsum_cache_t = std::unordered_map<
510+
detail::cutensor::EinsumParams_t<decltype(detail::cutensor::getEinsumSupportedTensor(tensors, stream))...>,
511+
std::any,
512+
detail::cutensor::EinsumParamsKeyHash<decltype(detail::cutensor::getEinsumSupportedTensor(tensors, stream))...>,
513+
detail::cutensor::EinsumParamsKeyEq<decltype(detail::cutensor::getEinsumSupportedTensor(tensors, stream))...>
514+
>;
515+
516+
detail::assign_tuple_tensors<0, cudaStream_t>(stream, in_t, tensors...);
517+
518+
using cache_val_type = matx::detail::cutensor::matxEinsumHandle_t<decltype(out_n),
519+
decltype(detail::cutensor::getEinsumSupportedTensor(tensors, stream))...>;
520+
492521
// Get parameters required by these tensors
493-
auto params = matx::detail::cutensor::matxEinsumHandle_t<OutputType, InT...>::GetEinsumParams(out, subscripts, tensors...);
522+
auto params = cuda::std::apply(
523+
[&](auto&&... args) {
524+
return cache_val_type::GetEinsumParams(out_n, subscripts, args...);
525+
},
526+
in_t
527+
);
528+
494529
params.stream = stream;
495530

496-
using einsum_cache_t = std::unordered_map<detail::cutensor::EinsumParams_t<InT...>, std::any, detail::cutensor::EinsumParamsKeyHash<InT...>, detail::cutensor::EinsumParamsKeyEq<InT...>>;
497-
using cache_val_type = matx::detail::cutensor::matxEinsumHandle_t<OutputType, InT...>;
498531
detail::GetCache().LookupAndExec<einsum_cache_t>(
499-
detail::GetCacheIdFromType<einsum_cache_t>(),
500-
params,
501-
[&]() {
502-
auto tmp = std::make_shared<cache_val_type>(out, subscripts, stream, tensors...);
503-
return tmp;
504-
},
505-
[&](std::shared_ptr<cache_val_type> ctype) {
506-
ctype->Exec(out, stream, tensors...);
507-
}
532+
detail::GetCacheIdFromType<einsum_cache_t>(),
533+
params,
534+
[&]() {
535+
return cuda::std::apply([&](auto&&... args) {
536+
return std::make_shared<cache_val_type>(out_n, subscripts, stream, args...);
537+
}, in_t);
538+
},
539+
[&](std::shared_ptr<cache_val_type> ctype) {
540+
cuda::std::apply([&](auto&&... args) {
541+
ctype->Exec(out_n, stream, args...);
542+
}, in_t);
543+
}
508544
);
509545
#else
510546
MATX_THROW(matxNotSupported, "einsum() currently requires MATX_EN_CUTENSOR=ON but MATX_EN_CUTENSOR=OFF");

test/00_tensor/EinsumTests.cu

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,32 @@ TYPED_TEST(EinsumTestsFloatNonComplexNonHalfTypes, Contraction3D)
127127
MATX_EXIT_HANDLER();
128128
}
129129

130+
TYPED_TEST(EinsumTestsFloatNonComplexNonHalfTypes, Contraction3DOperator)
131+
{
132+
MATX_ENTER_HANDLER();
133+
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
134+
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;
135+
136+
ExecType exec{};
137+
138+
this->pb->template InitAndRunTVGenerator<TestType>(
139+
"00_operators", "contraction", "run", {});
140+
141+
auto a1 = make_tensor<TestType>({60});
142+
auto b1 = make_tensor<TestType>({24});
143+
auto c2 = make_tensor<TestType>({5,2});
144+
145+
// Perform a 3D tensor contraction
146+
(c2 = cutensor::einsum("ijk,jil->kl",
147+
reshape(linspace<0>(a1.Shape(), (TestType)0, static_cast<TestType>(a1.Size(0) - 1)), {3,4,5}),
148+
reshape(linspace<0>(b1.Shape(), (TestType)0, static_cast<TestType>(b1.Size(0) - 1)), {4,3,2}))).run(exec);
149+
150+
exec.sync();
151+
MATX_TEST_ASSERT_COMPARE(this->pb, c2, "c_float3d", 0.01);
152+
153+
MATX_EXIT_HANDLER();
154+
}
155+
130156
TYPED_TEST(EinsumTestsFloatNonComplexNonHalfTypes, Dot)
131157
{
132158
MATX_ENTER_HANDLER();

0 commit comments

Comments
 (0)