From 088c5602721d20abcf03b8e4c75784323e4acc5f Mon Sep 17 00:00:00 2001 From: aurianer Date: Sat, 12 Aug 2023 00:04:38 +0200 Subject: [PATCH] Create consume_rvalues separate from unwrapping --- include/dlaf/common/consume_rvalues.h | 43 +++++++++++++++++++++++++++ include/dlaf/common/unwrap.h | 13 +------- include/dlaf/sender/transform.h | 15 ++++++---- include/dlaf/sender/transform_mpi.h | 3 +- 4 files changed, 55 insertions(+), 19 deletions(-) create mode 100644 include/dlaf/common/consume_rvalues.h diff --git a/include/dlaf/common/consume_rvalues.h b/include/dlaf/common/consume_rvalues.h new file mode 100644 index 0000000000..38c15b01f2 --- /dev/null +++ b/include/dlaf/common/consume_rvalues.h @@ -0,0 +1,43 @@ +// +// Distributed Linear Algebra with Future (DLAF) +// +// Copyright (c) 2018-2023, ETH Zurich +// All rights reserved. +// +// Please, refer to the LICENSE file in the root directory. +// SPDX-License-Identifier: BSD-3-Clause +// +#pragma once + +/// @file + +#include +#include +#include + +namespace dlaf::common::internal { +/// ConsumeRvalues is a callable object wrapper that consumes rvalues passed as arguments +/// after calling the wrapped callable. +template +struct ConsumeRvalues { + std::decay_t f; + + template + auto operator()(Ts&&... ts) -> decltype(std::move(f)(std::forward(ts)...)) { + using result_type = decltype(std::move(f)(std::forward(ts)...)); + if constexpr (std::is_void_v) { + std::move(f)(std::forward(ts)...); + std::tuple(std::forward(ts)...); + } + else { + auto r = std::move(f)(std::forward(ts)...); + std::tuple(std::forward(ts)...); + return r; + } + } +}; + +template +ConsumeRvalues(F&&) -> ConsumeRvalues>; + +} diff --git a/include/dlaf/common/unwrap.h b/include/dlaf/common/unwrap.h index 47a7fcd6fd..dd92b1444a 100644 --- a/include/dlaf/common/unwrap.h +++ b/include/dlaf/common/unwrap.h @@ -59,18 +59,7 @@ struct Unwrapping { template auto operator()(Ts&&... ts) -> decltype(std::move(f)(Unwrapper>::unwrap(std::forward(ts))...)) { - using result_type = decltype(std::move(f)(Unwrapper>::unwrap(std::forward(ts))...)); - if constexpr(std::is_void_v) - { - std::move(f)(Unwrapper>::unwrap(std::forward(ts))...); - std::tuple(std::forward(ts)...); - } - else - { - auto r = std::move(f)(Unwrapper>::unwrap(std::forward(ts))...); - std::tuple(std::forward(ts)...); - return r; - } + return std::move(f)(Unwrapper>::unwrap(std::forward(ts))...); } }; diff --git a/include/dlaf/sender/transform.h b/include/dlaf/sender/transform.h index a114edeb41..8b184ed02c 100644 --- a/include/dlaf/sender/transform.h +++ b/include/dlaf/sender/transform.h @@ -11,6 +11,7 @@ #include +#include #include #include #include @@ -43,7 +44,7 @@ enum class TransformDispatchType { Plain, Blas, Lapack }; // allows choosing the priority. // // At its core, transform is a convenience wrapper around -// sender | transfer(with_priority(scheduler, priority)) | then(unwrapping(f)). +// sender | transfer(with_priority(scheduler, priority)) | then(ConsumeRvalues(unwrapping(f))). /// Lazy transform. This does not submit the work and returns a sender. template (policy.priority()); auto transfer_sender = transfer(std::forward(sender), std::move(scheduler)); + using dlaf::common::internal::ConsumeRvalues; + using dlaf::common::internal::Unwrapping; + if constexpr (B == Backend::MC) { - return then(std::move(transfer_sender), dlaf::common::internal::Unwrapping{std::forward(f)}); + return then(std::move(transfer_sender), ConsumeRvalues{Unwrapping{std::forward(f)}}); } else if constexpr (B == Backend::GPU) { #if defined(DLAF_WITH_GPU) @@ -67,16 +71,15 @@ template (f)}); + ConsumeRvalues{Unwrapping{std::forward(f)}}); } else if constexpr (Tag == TransformDispatchType::Blas) { return then_with_cublas(std::move(transfer_sender), - dlaf::common::internal::Unwrapping{std::forward(f)}, - CUBLAS_POINTER_MODE_HOST); + ConsumeRvalues{Unwrapping{std::forward(f)}}, CUBLAS_POINTER_MODE_HOST); } else if constexpr (Tag == TransformDispatchType::Lapack) { return then_with_cusolver(std::move(transfer_sender), - dlaf::common::internal::Unwrapping{std::forward(f)}); + ConsumeRvalues{Unwrapping{std::forward(f)}}); } else { DLAF_STATIC_FAIL( diff --git a/include/dlaf/sender/transform_mpi.h b/include/dlaf/sender/transform_mpi.h index 50f23cb0b7..91d6048ab9 100644 --- a/include/dlaf/sender/transform_mpi.h +++ b/include/dlaf/sender/transform_mpi.h @@ -11,6 +11,7 @@ #include +#include #include #include #include @@ -88,7 +89,7 @@ template (sender), dlaf::internal::getMPIScheduler()) | - ex::then(MPICallHelper{std::forward(f)}); + ex::then(dlaf::common::internal::ConsumeRvalues{MPICallHelper{std::forward(f)}}); } /// Fire-and-forget transformMPI. This submits the work and returns void.