Skip to content

Commit

Permalink
[WIP] late-bound stream scheduler algorithm customizations
Browse files Browse the repository at this point in the history
  • Loading branch information
ericniebler committed Sep 28, 2023
1 parent 9af4974 commit f8c27c5
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 38 deletions.
10 changes: 6 additions & 4 deletions examples/nvexec/maxwell/snr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,12 @@ auto maxwell_eqs_snr(
return ex::just()
| exec::on(
computer,
repeat_n(
n_iterations,
ex::bulk(accessor.cells, update_h(accessor))
| ex::bulk(accessor.cells, update_e(time, dt, accessor))))
ex::bulk(accessor.cells, update_h(accessor))
| ex::bulk(accessor.cells, update_e(time, dt, accessor)))
// repeat_n(
// n_iterations,
// ex::bulk(accessor.cells, update_h(accessor))
// | ex::bulk(accessor.cells, update_e(time, dt, accessor))))
| ex::then(dump_vtk(write_results, accessor));
}

Expand Down
27 changes: 23 additions & 4 deletions include/nvexec/stream/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ namespace nvexec {
}
};

struct stream_scheduler;

struct context_state_t {
std::pmr::memory_resource* pinned_resource_{nullptr};
std::pmr::memory_resource* managed_resource_{nullptr};
Expand Down Expand Up @@ -195,14 +197,22 @@ namespace nvexec {
void return_stream(cudaStream_t stream) {
stream_pools_->return_stream(stream, priority_);
}

stream_scheduler make_stream_scheduler() const noexcept;
};

struct stream_scheduler;
template <class = stream_scheduler>
struct stream_domain;

// Common base for stream senders. The _non_stream_sender concept tests for
// derivation from this type, and the nested `is_sender` alias is inherited by
// every sender that derives from it.
struct stream_sender_base {
using is_sender = void;
};

// Satisfied when the decayed Sender type does NOT derive from stream_sender_base.
// The Env parameter does not take part in the definition.
// needed for subsumption purposes
template <class Sender, class Env>
concept _non_stream_sender = //
!derived_from<__decay_t<Sender>, stream_sender_base>;

// Base type for stream receivers; extends __receiver_base with a
// memory_allocation_size constant that is zero here.
// NOTE(review): presumably derived receivers shadow this constant to request
// storage -- confirm against the receivers that use it.
struct stream_receiver_base : __receiver_base {
constexpr static std::size_t memory_allocation_size = 0;
};
Expand Down Expand Up @@ -265,6 +275,10 @@ namespace nvexec {
stream_provider_t* operator()(const Env& env) const noexcept {
return tag_invoke(get_stream_provider_t{}, env);
}

// forwarding_query customization for get_stream_provider_t: always answers true.
// NOTE(review): presumably this lets environment adaptors forward the
// get_stream_provider query through to a wrapped environment -- confirm.
friend constexpr bool tag_invoke(forwarding_query_t, const get_stream_provider_t&) noexcept {
return true;
}
};

template <class... Ts>
Expand Down Expand Up @@ -308,7 +322,10 @@ namespace nvexec {
using variant_storage_t = //
__minvoke< __minvoke<
__mfold_right<
__mbind_front_q<stream_storage_impl::variant, ::cuda::std::tuple<set_noop>>,
__mbind_front_q<
stream_storage_impl::variant,
::cuda::std::tuple<set_noop>,
::cuda::std::tuple<set_error_t, cudaError_t>>,
__mbind_front_q<stream_storage_impl::__bind_completions_t, _Sender, _Env>>,
set_value_t,
set_error_t,
Expand Down Expand Up @@ -570,7 +587,8 @@ namespace nvexec {

template <__decays_to<cudaError_t> Error>
void propagate_completion_signal(set_error_t, Error&& status) noexcept {
if constexpr (stream_receiver<outer_receiver_t>) {
using Domain = __env_domain_of_t<env_of_t<outer_receiver_t>>;
if constexpr (stream_receiver<outer_receiver_t> || same_as<Domain, stream_domain<>>) {
set_error((outer_receiver_t&&) rcvr_, (cudaError_t&&) status);
} else {
// pass a cudaError_t by value:
Expand All @@ -581,7 +599,8 @@ namespace nvexec {

template <class Tag, class... As>
void propagate_completion_signal(Tag, As&&... as) noexcept {
if constexpr (stream_receiver<outer_receiver_t>) {
using Domain = __env_domain_of_t<env_of_t<outer_receiver_t>>;
if constexpr (stream_receiver<outer_receiver_t> || same_as<Domain, stream_domain<>>) {
Tag()((outer_receiver_t&&) rcvr_, (As&&) as...);
} else {
continuation_kernel<outer_receiver_t, As&&...> // by reference
Expand Down
7 changes: 7 additions & 0 deletions include/nvexec/stream/then.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,10 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
};
};
}

namespace stdexec::__detail {
// Specialization of __name_of_v for the stream then_sender_t. The value is an
// __mconst holding then_sender_t re-instantiated with __name_of applied to the
// sender type that SenderId identifies; Fun is passed through unchanged.
// NOTE(review): presumably consumed by the __name_of machinery that pretty-prints
// sender type names in compiler diagnostics -- confirm in __basic_sender.hpp.
template <class SenderId, class Fun>
inline constexpr __mconst<
nvexec::STDEXEC_STREAM_DETAIL_NS::then_sender_t<__name_of<__t<SenderId>>, Fun>>
__name_of_v<nvexec::STDEXEC_STREAM_DETAIL_NS::then_sender_t<SenderId, Fun>>{};
}
98 changes: 98 additions & 0 deletions include/nvexec/stream/wrap.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright (c) 2022 NVIDIA Corporation
*
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://llvm.org/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "../../stdexec/execution.hpp"
#include <type_traits>

#include "common.cuh"

namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
namespace _wrap {
// Wraps an arbitrary sender so that it derives from stream_sender_base and
// carries a context_state_t. Connecting and completion-signature queries are
// delegated to the wrapped sender's own tag_invoke customizations; the wrapper's
// environment reports a stream scheduler built from the stored context state.
template <class SenderId>
struct sender : stream_sender_base {
using is_sender = void;
// The wrapped sender type identified by SenderId.
using Sender = stdexec::__t<SenderId>;
// This type serves as both its own __t and its own __id.
using __t = sender;
using __id = sender;

// Takes ownership of `sndr` and stores a copy of `context_state` in the environment.
sender(Sender sndr, context_state_t context_state)
: sndr_(std::move(sndr))
, env_{context_state} {
}

// Environment returned by get_env on the wrapper.
struct environment {
context_state_t context_state_;

// The set_value completion scheduler is a stream scheduler created from the
// stored context state.
template <same_as<environment> Self>
friend auto tag_invoke(get_completion_scheduler_t<set_value_t>, const Self& env) noexcept {
return env.context_state_.make_stream_scheduler();
}
};

// Completion signatures of the wrapped sender (with the cv/ref qualification of
// Self applied), obtained through its get_completion_signatures tag_invoke.
// BUGBUG this doesn't handle the case where the sender has a nested
// type alias named completion_signatures.
template <class Self, class Env>
using completions_t =
tag_invoke_result_t<get_completion_signatures_t, __copy_cvref_t<Self, Sender>, Env>;

// connect: forwards directly to the wrapped sender's tag_invoke(connect, ...),
// bypassing the connect customization point itself.
// test for tag_invocable instead of sender_to because the connect customization
// point would convert the stdexec::just sender back into this nvexec::just sender,
// causing recursion.
template <__decays_to<sender> Self, receiver Receiver>
requires receiver_of<Receiver, completions_t<Self, env_of_t<Receiver>>> &&
tag_invocable<connect_t, __copy_cvref_t<Self, Sender>, Receiver>
friend auto tag_invoke(connect_t, Self&& self, Receiver rcvr) //
noexcept(nothrow_tag_invocable<connect_t, __copy_cvref_t<Self, Sender>, Receiver>)
-> tag_invoke_result_t<connect_t, __copy_cvref_t<Self, Sender>, Receiver> {
return tag_invoke(connect, ((Self&&) self).sndr_, (Receiver&&) rcvr);
}

// get_completion_signatures: only the return type matters; the arguments are
// not inspected and a value-initialized completions_t is returned.
template <__decays_to<sender> Self, class Env>
friend auto tag_invoke(get_completion_signatures_t, Self&& self, Env&& env) noexcept
-> completions_t<Self, Env> {
return {};
}

// get_env: exposes the stored environment by const reference.
template <same_as<sender> Self>
friend const environment& tag_invoke(get_env_t, const Self& self) noexcept {
return self.env_;
}

Sender sndr_;      // the wrapped sender
environment env_;  // environment holding the context state
};
} // namespace _wrap

// Pass-through overload: the sender is handed back unchanged and the context
// state is ignored. Selected when the constrained overload below does not apply.
template <class Env, class Sender>
Sender as_stream_sender(Sender sndr, const context_state_t& /*context_state*/) {
  return sndr;
}

// Wrapping overload: a sender that does not derive from stream_sender_base is
// placed inside a _wrap::sender together with the given context state. The
// constraint makes this overload the more specialized of the two.
template <class Env, class Sender>
  requires _non_stream_sender<Sender, Env>
_wrap::sender<__id<Sender>>
  as_stream_sender(Sender sndr, const context_state_t& context_state) {
  using wrapped_t = _wrap::sender<__id<Sender>>;
  return wrapped_t{std::move(sndr), context_state};
}
}

namespace stdexec::__detail {
// Specialization of __name_of_v for the stream _wrap::sender. The value is an
// __mconst holding _wrap::sender re-instantiated with __name_of applied to the
// sender type that SenderId identifies.
// NOTE(review): presumably consumed by the __name_of machinery that pretty-prints
// sender type names in compiler diagnostics -- confirm in __basic_sender.hpp.
template <class SenderId>
inline constexpr __mconst<
nvexec::STDEXEC_STREAM_DETAIL_NS::_wrap::sender<__name_of<__t<SenderId>>>>
__name_of_v<nvexec::STDEXEC_STREAM_DETAIL_NS::_wrap::sender<SenderId>>{};
}
105 changes: 89 additions & 16 deletions include/nvexec/stream_context.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "stream/when_all.cuh"
#include "stream/reduce.cuh"
#include "stream/ensure_started.cuh"
#include "stream/wrap.cuh"

#include "stream/common.cuh"
#include "detail/queue.cuh"
Expand Down Expand Up @@ -87,41 +88,109 @@ namespace nvexec {
template <sender Sender>
using ensure_started_th = __t<ensure_started_sender_t<__id<Sender>>>;

// needed for subsumption purposes
template <class Sender, class Env>
concept _non_stream_sender = //
!derived_from<__decay_t<Sender>, stream_sender_base>;

struct stream_scheduler;

template <class = stream_scheduler>
// template <class = stream_scheduler>
// struct stream_domain;

// template <class Tag>
// struct _just_t : Tag {
// static __prop<get_domain_t, stream_domain<>> get_env(auto&&) noexcept {
// return __mkprop
// }
// }

template <class /*= stream_scheduler*/>
struct stream_domain : private __default_domain<context_state_t> {
using __default_domain::__default_domain;
using __default_domain::transform_sender;
//using __default_domain::transform_sender;

// Lazy algorithm customizations require a recursive tree transformation
template <sender_expr Sender, class Env>
requires _non_stream_sender<Sender, Env> // no need to transform it a second time
auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
return stdexec::apply_sender(
//print<__name_of<Sender>>();
auto result = stdexec::apply_sender(
(Sender&&) sndr,
[&]<class Tag, class Data, class... Children>(Tag, Data&& data, Children&&... children) {
return make_sender_expr<Tag, stream_domain>(
(Data&&) data, transform_sender((Children&&) children, env)...);
return //as_stream_sender<Env>(
make_sender_expr<Tag, stream_domain>(
(Data&&) data,
stdexec::transform_sender(*this, (Children&&) children, env)...); //,
//base());
});
//print<__name_of<decltype(result)>>();
return result;
}

// reduce senders get a special transformation
template <sender_expr_for<reduce_t> Sender, class Env>
template <sender_expr_for<schedule_from_t> Sender, class Env>
requires _non_stream_sender<Sender, Env> // no need to transform it a second time
auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
return stdexec::apply_sender(
(Sender&&) sndr,
[&]<class Tag, class Data, class Child>(Tag, Data&& data, Child&& child) {
auto [init, fun] = (Data&&) data;
auto next = transform_sender((Child&&) child, env);
return reduce_sender_t<decltype(next), decltype(init), decltype(fun)>(
std::move(next), init, fun);
auto sched = get_scheduler(env);
auto next = stdexec::transform_sender(*this, (Child&&) child, env);
return stdexec::__t<
schedule_from_sender_t<stream_scheduler, stdexec::__id<decltype(next)>>>{
sched.context_state_, std::move(next)};
});
}

// // reduce senders get a special transformation
// template <sender_expr_for<reduce_t> Sender, class Env>
// requires _non_stream_sender<Sender, Env> // no need to transform it a second time
// auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
// return stdexec::apply_sender(
// (Sender&&) sndr,
// [&]<class Tag, class Data, class Child>(Tag, Data&& data, Child&& child) {
// auto [init, fun] = (Data&&) data;
// auto next = stdexec::transform_sender(*this, (Child&&) child, env);
// return reduce_sender_t<decltype(next), decltype(init), decltype(fun)>(
// std::move(next), init, fun);
// });
// }

// transfer senders get a special transformation:
// the child is transformed recursively, wrapped in a transfer_sender_t, and that
// is in turn wrapped in a schedule_from_sender_t for stream_scheduler. Both
// wrappers are constructed with the context state of the scheduler found in `env`.
template <sender_expr_for<transfer_t> Sender, class Env>
requires _non_stream_sender<Sender, Env> // no need to transform it a second time
auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
return stdexec::apply_sender(
(Sender&&) sndr, [&]<class Data, class Child>(__ignore, Data&& data, Child&& child) {
// Scheduler taken from the environment the sender is being transformed in.
auto from = get_scheduler(env);
// NOTE(review): `to` is computed but never used below -- the result is built
// only from `from.context_state_`. Confirm whether the destination scheduler
// is meant to be used here.
auto to = get_completion_scheduler<set_value_t>(data);
auto next = stdexec::transform_sender(*this, (Child&&) child, env);
auto transfer = __t<transfer_sender_t<decltype(next)>>(
from.context_state_, std::move(next));
return __t< schedule_from_sender_t<stream_scheduler, __id<decltype(transfer)>>>{
from.context_state_, std::move(transfer)};
});
}

// template <sender_expr_for<just_t, just_error_t, just_stopped_t> Sender, class Env>
// requires _non_stream_sender<Sender, Env> // no need to transform it a second time
// auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
// return stdexec::apply_sender(
// (Sender&&) sndr, [&]<class Tag, class Data>(Tag, Data&& data) {
// return make_sender_expr<Tag, stream_domain>(
// (Data&&) data, get_completion_scheduler<Tag>(data));
// });
// }
// template <sender_expr_for<just_t> Sender, class Env>
// requires _non_stream_sender<Sender, Env> // no need to transform it a second time
// auto transform_sender(Sender&& sndr, const Env&) const noexcept {
// return just_sender<__decay_t<Sender>>{(Sender&&) sndr, base()};
// }

// bulk senders: the child is transformed recursively and the result is rebuilt
// as a bulk_sender_th carrying the original shape and function.
template <sender_expr_for<bulk_t> Sender, class Env>
requires _non_stream_sender<Sender, Env> // no need to transform it a second time
auto transform_sender(Sender&& sndr, const Env& env) const noexcept {
return stdexec::apply_sender(
(Sender&&) sndr, [&]<class Data, class Child>(__ignore, Data&& data, Child&& child) {
// The bulk sender's data decomposes into the iteration shape and the callable.
auto&& [shape, fun] = (Data&&) data;
auto next = stdexec::transform_sender(*this, (Child&&) child, env);
// NOTE(review): the leading `{}` presumably initializes a base subobject of
// bulk_sender_th -- confirm against its definition.
return bulk_sender_th<decltype(next), decltype(shape), decltype(fun)>{
{}, std::move(next), shape, fun};
});
}

Expand Down Expand Up @@ -338,6 +407,10 @@ namespace nvexec {
return {base()};
}

// Builds a stream_scheduler from a copy of this context state.
//
// Defined out of line because stream_scheduler is only forward declared at the
// point where context_state_t declares this member.
//
// Marked `inline`: this is a non-template function defined in a header, so
// without it every translation unit including the header would emit its own
// definition, violating the one-definition rule and failing at link time.
inline stream_scheduler context_state_t::make_stream_scheduler() const noexcept {
  return {*this};
}

template <stream_completing_sender Sender>
void tag_invoke(start_detached_t, Sender&& sndr) noexcept(false) {
_submit::submit_t{}((Sender&&) sndr, _start_detached::detached_receiver_t{});
Expand Down
4 changes: 2 additions & 2 deletions include/stdexec/__detail/__basic_sender.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,9 @@ namespace stdexec {
concept sender_expr = //
__mvalid<__tag_of, _Sender>;

template <class _Sender, class _Tag>
template <class _Sender, class... _Tags>
concept sender_expr_for = //
sender_expr<_Sender> && same_as<__tag_of<_Sender>, _Tag>;
sender_expr<_Sender> && __one_of<__tag_of<_Sender>, _Tags...>;

// The __name_of utility defined below is used to pretty-print the type names of
// senders in compiler diagnostics.
Expand Down
4 changes: 4 additions & 0 deletions include/stdexec/__detail/__execution_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,7 @@ namespace stdexec {
using __on_v2::on_t;
}
}

// Debugging aids declared in the global namespace. Both are [[deprecated]],
// presumably so that any use makes the compiler emit a diagnostic that spells
// out the template arguments / argument types involved.
template <class...>
[[deprecated]] void print() {
}

// `__global__` is only understood by CUDA compilers, while this forward-declaration
// header lives in the core stdexec include tree. Declare the kernel only when a
// CUDA compiler is active so host-only builds keep compiling.
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
[[deprecated]] __global__ void kernel(auto&&...) {
}
#endif
Loading

0 comments on commit f8c27c5

Please sign in to comment.