@@ -71,6 +71,21 @@ struct EinsumParams_t {
   cudaStream_t stream;
 };
 
+template <typename Op>
+__MATX_INLINE__ auto getEinsumSupportedTensor(const Op &in, cudaStream_t stream) {
+  // This would be better as a templated lambda, but we don't have those in C++17 yet
+  const auto support_func = [&in]() {
+    if constexpr (is_tensor_view_v<Op>) {
+      return true;
+    }
+    else {
+      return true;
+    }
+  };
+
+  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
+}
+
 template <typename OutputTensor, typename... InT>
 class matxEinsumHandle_t {
 public:
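The comment inside the new helper refers to C++20 explicit template parameter lists on lambdas, which C++17 lacks. Below is a rough, hypothetical sketch of what the same helper could look like as a templated lambda; it is not part of this change, and it assumes the MatX detail symbols it references (`GetSupportedTensor`, `is_tensor_view_v`, `MATX_ASYNC_DEVICE_MEMORY`) are in scope.

```cpp
// Hypothetical C++20 form of getEinsumSupportedTensor using a templated lambda,
// which avoids a separate function template. Sketch only; the MatX detail
// helpers referenced here are assumed to be available in this namespace.
auto getEinsumSupportedTensorSketch = []<typename Op>(const Op &in, cudaStream_t stream) {
  const auto support_func = []() {
    if constexpr (is_tensor_view_v<Op>) {
      return true;
    }
    else {
      return true;
    }
  };
  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
};
```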
@@ -312,7 +327,7 @@ class matxEinsumHandle_t {
     ((params.nmodes_[i++] = tensors.Rank()), ...);
 
     i = 0;
-    MATX_ASSERT_STR(((tokens[i++].length() == static_cast<size_t>(tensors.Rank())) && ...), matxInvalidDim,
+    MATX_ASSERT_STR(((tokens[i++].length() == static_cast<size_t>(tensors.Rank())), ...), matxInvalidDim,
       "Tensor rank must match number of einsum subscripts");
 
     auto set_sizes = [](auto &t, std::vector<int64_t> &sizes) {
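For context on the assertion change above: the old line combined the per-tensor rank checks with an `&&` fold over the parameter pack, while the new line uses a comma fold, and the two forms produce different results. A standalone (non-MatX) illustration of the difference:

```cpp
// Standalone demo of the two fold forms: an && fold ANDs every element's check,
// while a comma fold evaluates each element but yields only the last result.
#include <cstdio>

template <typename... Ts>
bool check_and(Ts... vals) {
  return ((vals > 0) && ...);   // true only if every value is positive
}

template <typename... Ts>
bool check_comma(Ts... vals) {
  return (((vals > 0)), ...);   // evaluates all comparisons, returns the last one
}

int main() {
  std::printf("%d %d\n", check_and(1, -2, 3), check_comma(1, -2, 3)); // prints: 0 1
  return 0;
}
```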
@@ -460,7 +475,6 @@ struct EinsumParamsKeyEq {
 
 namespace matx {
 namespace cutensor {
-
 /**
  * @brief Evaluates the Einstein summation on the operands
  *
@@ -489,22 +503,44 @@ namespace cutensor {
 #ifdef MATX_EN_CUTENSOR
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
 
+  auto out_n = detail::cutensor::getEinsumSupportedTensor(out, stream);
+  auto in_t = cuda::std::make_tuple(detail::cutensor::getEinsumSupportedTensor(tensors, stream)...);
+
+  using einsum_cache_t = std::unordered_map<
+    detail::cutensor::EinsumParams_t<decltype(detail::cutensor::getEinsumSupportedTensor(tensors, stream))...>,
+    std::any,
+    detail::cutensor::EinsumParamsKeyHash<decltype(detail::cutensor::getEinsumSupportedTensor(tensors, stream))...>,
+    detail::cutensor::EinsumParamsKeyEq<decltype(detail::cutensor::getEinsumSupportedTensor(tensors, stream))...>
+  >;
+
+  detail::assign_tuple_tensors<0, cudaStream_t>(stream, in_t, tensors...);
+
+  using cache_val_type = matx::detail::cutensor::matxEinsumHandle_t<decltype(out_n),
+    decltype(detail::cutensor::getEinsumSupportedTensor(tensors, stream))...>;
+
   // Get parameters required by these tensors
-  auto params = matx::detail::cutensor::matxEinsumHandle_t<OutputType, InT...>::GetEinsumParams(out, subscripts, tensors...);
+  auto params = cuda::std::apply(
+    [&](auto &&... args) {
+      return cache_val_type::GetEinsumParams(out_n, subscripts, args...);
+    },
+    in_t
+  );
+
   params.stream = stream;
 
-  using einsum_cache_t = std::unordered_map<detail::cutensor::EinsumParams_t<InT...>, std::any, detail::cutensor::EinsumParamsKeyHash<InT...>, detail::cutensor::EinsumParamsKeyEq<InT...>>;
-  using cache_val_type = matx::detail::cutensor::matxEinsumHandle_t<OutputType, InT...>;
   detail::GetCache().LookupAndExec<einsum_cache_t>(
-      detail::GetCacheIdFromType<einsum_cache_t>(),
-      params,
-      [&]() {
-        auto tmp = std::make_shared<cache_val_type>(out, subscripts, stream, tensors...);
-        return tmp;
-      },
-      [&](std::shared_ptr<cache_val_type> ctype) {
-        ctype->Exec(out, stream, tensors...);
-      }
+    detail::GetCacheIdFromType<einsum_cache_t>(),
+    params,
+    [&]() {
+      return cuda::std::apply([&](auto &&... args) {
+        return std::make_shared<cache_val_type>(out_n, subscripts, stream, args...);
+      }, in_t);
+    },
+    [&](std::shared_ptr<cache_val_type> ctype) {
+      cuda::std::apply([&](auto &&... args) {
+        ctype->Exec(out_n, stream, args...);
+      }, in_t);
+    }
   );
 #else
   MATX_THROW(matxNotSupported, "einsum() currently requires MATX_EN_CUTENSOR=ON but MATX_EN_CUTENSOR=OFF");
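A minimal call-site sketch for the `einsum()` entry point modified above, assuming a build with `MATX_EN_CUTENSOR=ON` and the `(out, subscripts, stream, tensors...)` parameter order used inside this file; the tensor setup is illustrative only and not taken from the diff.

```cpp
// Hypothetical usage sketch: a matrix multiply expressed as the einsum
// contraction "mk,kn->mn" (sum over k). Assumes MATX_EN_CUTENSOR=ON.
#include "matx.h"

int main() {
  auto a = matx::make_tensor<float>({8, 16});
  auto b = matx::make_tensor<float>({16, 4});
  auto c = matx::make_tensor<float>({8, 4});

  // Fill inputs on the host (tensors use managed memory by default)
  for (int i = 0; i < 8; i++)
    for (int k = 0; k < 16; k++)
      a(i, k) = 1.0f;
  for (int k = 0; k < 16; k++)
    for (int j = 0; j < 4; j++)
      b(k, j) = 2.0f;

  matx::cutensor::einsum(c, "mk,kn->mn", 0, a, b);  // contract over k on stream 0
  cudaStreamSynchronize(0);                         // each c(i, j) should be 32
  return 0;
}
```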