Skip to content

Commit 7d7d101

Browse files
committed
add the get_await_completion_adaptor query and use it in as_awaitable
1 parent ba86938 commit 7d7d101

File tree

3 files changed

+92
-43
lines changed

3 files changed

+92
-43
lines changed

include/stdexec/__detail/__as_awaitable.hpp

Lines changed: 69 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@
2222
#include "__concepts.hpp"
2323
#include "__connect.hpp"
2424
#include "__meta.hpp"
25+
#include "__queries.hpp"
2526
#include "__tag_invoke.hpp"
2627
#include "__type_traits.hpp"
28+
#include "__variant.hpp"
2729

2830
#include <exception>
31+
#include <functional> // for std::identity
2932
#include <system_error>
30-
#include <variant>
3133

3234
namespace STDEXEC
3335
{
@@ -49,6 +51,23 @@ namespace STDEXEC
4951
template <class _Sender, class _Promise>
5052
using __value_t = __decay_t<
5153
__value_types_of_t<_Sender, env_of_t<_Promise&>, __q<__single_value>, __msingle_or<void>>>;
54+
55+
inline constexpr auto __get_await_completion_adaptor =
56+
__with_default{get_await_completion_adaptor, std::identity{}};
57+
58+
template <class _Sender>
59+
using __adapt_completion_t = __result_of<__get_await_completion_adaptor, env_of_t<_Sender>>;
60+
61+
template <class _Sender>
62+
constexpr auto __adapt_sender_for_await(_Sender&& __sndr)
63+
noexcept(__nothrow_callable<__adapt_completion_t<_Sender>, _Sender>) -> decltype(auto)
64+
{
65+
return __get_await_completion_adaptor(get_env(__sndr))(static_cast<_Sender&&>(__sndr));
66+
}
67+
68+
template <class _Sender>
69+
using __adapted_sender_t =
70+
__remove_rvalue_reference_t<__call_result_t<__adapt_completion_t<_Sender>, _Sender>>;
5271
} // namespace __detail
5372

5473
/////////////////////////////////////////////////////////////////////////////
@@ -62,8 +81,7 @@ namespace STDEXEC
6281
using __value_or_void_t = __if_c<__same_as<_Value, void>, __void, _Value>;
6382

6483
template <class _Value>
65-
using __expected_t =
66-
std::variant<std::monostate, __value_or_void_t<_Value>, std::exception_ptr>;
84+
using __expected_t = __variant<__value_or_void_t<_Value>, std::exception_ptr>;
6785

6886
// Helper to cast a coroutine_handle<void> to coroutine_handle<_Promise>
6987
template <class _Promise>
@@ -84,7 +102,7 @@ namespace STDEXEC
84102
{
85103
STDEXEC_TRY
86104
{
87-
__result_->template emplace<1>(static_cast<_Us&&>(__us)...);
105+
__result_->template emplace<0>(static_cast<_Us&&>(__us)...);
88106
__continuation_.resume();
89107
}
90108
STDEXEC_CATCH_ALL
@@ -97,11 +115,11 @@ namespace STDEXEC
97115
void set_error(_Error&& __err) noexcept
98116
{
99117
if constexpr (__decays_to<_Error, std::exception_ptr>)
100-
__result_->template emplace<2>(static_cast<_Error&&>(__err));
118+
__result_->template emplace<1>(static_cast<_Error&&>(__err));
101119
else if constexpr (__decays_to<_Error, std::error_code>)
102-
__result_->template emplace<2>(std::make_exception_ptr(std::system_error(__err)));
120+
__result_->template emplace<1>(std::make_exception_ptr(std::system_error(__err)));
103121
else
104-
__result_->template emplace<2>(std::make_exception_ptr(static_cast<_Error&&>(__err)));
122+
__result_->template emplace<1>(std::make_exception_ptr(static_cast<_Error&&>(__err)));
105123
__continuation_.resume();
106124
}
107125

@@ -145,22 +163,22 @@ namespace STDEXEC
145163
{
146164
switch (__result_.index())
147165
{
148-
case 0: // receiver contract not satisfied
166+
case __variant_npos: // receiver contract not satisfied
149167
STDEXEC_ASSERT(false && +"_Should never get here" == nullptr);
150168
break;
151-
case 1: // set_value
169+
case 0: // set_value
152170
if constexpr (!__same_as<_Value, void>)
153-
return static_cast<_Value&&>(std::get<1>(__result_));
171+
return static_cast<_Value&&>(__var::__get<0>(__result_));
154172
else
155173
return;
156-
case 2: // set_error
157-
std::rethrow_exception(std::get<2>(__result_));
174+
case 1: // set_error
175+
std::rethrow_exception(__var::__get<1>(__result_));
158176
}
159177
std::terminate();
160178
}
161179

162180
protected:
163-
__expected_t<_Value> __result_;
181+
__expected_t<_Value> __result_{__no_init};
164182
};
165183

166184
template <class _Promise, class _Sender>
@@ -185,14 +203,23 @@ namespace STDEXEC
185203
};
186204

187205
template <class _Sender, class _Promise>
188-
concept __awaitable_sender = sender_in<_Sender, env_of_t<_Promise&>>
189-
&& __minvocable_q<__detail::__value_t, _Sender, _Promise>
190-
&& sender_to<_Sender, __receiver_t<_Sender, _Promise>>
191-
&& requires(_Promise& __promise) {
192-
{
193-
__promise.unhandled_stopped()
194-
} -> __std::convertible_to<__std::coroutine_handle<>>;
195-
};
206+
STDEXEC_HOST_DEVICE_DEDUCTION_GUIDE
207+
__sender_awaitable(_Sender&&, __std::coroutine_handle<_Promise>)
208+
-> __sender_awaitable<_Promise, _Sender>;
209+
210+
template <class _Sender, class _Promise>
211+
concept __awaitable_adapted_sender = sender_in<_Sender, env_of_t<_Promise&>>
212+
&& __minvocable_q<__detail::__value_t, _Sender, _Promise>
213+
&& sender_to<_Sender, __receiver_t<_Sender, _Promise>>
214+
&& requires(_Promise& __promise) {
215+
{
216+
__promise.unhandled_stopped()
217+
} -> __std::convertible_to<__std::coroutine_handle<>>;
218+
};
219+
220+
template <class _Sender, class _Promise>
221+
concept __awaitable_sender =
222+
__awaitable_adapted_sender<__detail::__adapted_sender_t<_Sender>, _Promise>;
196223

197224
struct __unspecified
198225
{
@@ -214,32 +241,33 @@ namespace STDEXEC
214241
template <class _Tp, class _Promise>
215242
static consteval auto __get_declfn() noexcept
216243
{
217-
using __as_awaitable::__unspecified;
244+
using namespace __as_awaitable;
218245
if constexpr (__connect_await::__has_as_awaitable_member<_Tp, _Promise>)
219246
{
220247
using __result_t = decltype(__declval<_Tp>().as_awaitable(__declval<_Promise&>()));
221248
constexpr bool __is_nothrow = noexcept(
222249
__declval<_Tp>().as_awaitable(__declval<_Promise&>()));
223250
return __declfn<__result_t, __is_nothrow>();
224-
// NOLINTNEXTLINE(bugprone-branch-clone)
225251
}
226-
else if constexpr (__awaitable<_Tp, __unspecified>)
227-
{ // NOT __awaitable<_Tp, _Promise> !!
252+
else if constexpr (__awaitable<_Tp, __unspecified>) // NOT __awaitable<_Tp, _Promise> !!
253+
{ // NOLINT(bugprone-branch-clone)
228254
return __declfn<_Tp&&>();
229255
}
230-
else if constexpr (__as_awaitable::__awaitable_sender<_Tp, _Promise>)
256+
else if constexpr (__awaitable_sender<_Tp, _Promise>)
231257
{
232-
using __result_t = __as_awaitable::__sender_awaitable<_Promise, _Tp>;
233-
constexpr bool __is_nothrow =
234-
__nothrow_constructible_from<__result_t, _Tp, __std::coroutine_handle<_Promise>>;
258+
using __result_t = decltype( //
259+
__sender_awaitable{__detail::__adapt_sender_for_await(__declval<_Tp>()),
260+
__std::coroutine_handle<_Promise>()});
261+
constexpr bool __is_nothrow = noexcept(
262+
__sender_awaitable{__detail::__adapt_sender_for_await(__declval<_Tp>()),
263+
__std::coroutine_handle<_Promise>()});
235264
return __declfn<__result_t, __is_nothrow>();
236-
// NOT TO SPEC
237265
}
238-
else if constexpr (__as_awaitable::__incompatible_sender<_Tp, _Promise>)
266+
else if constexpr (__incompatible_sender<_Tp, _Promise>)
239267
{
240-
// It's a sender, but it isn't a sender in the current promise's environment, so
241-
// we can return the error type that results from trying to compute the sender's
242-
// value type:
268+
// NOT TO SPEC: It's a sender, but it isn't a sender in the current promise's
269+
// environment, so we can return the error type that results from trying to
270+
// compute the sender's value type:
243271
return __declfn<__detail::__value_t<_Tp, _Promise>>();
244272
}
245273
else
@@ -253,24 +281,24 @@ namespace STDEXEC
253281
auto operator()(_Tp&& __t, _Promise& __promise) const noexcept(noexcept(_DeclFn()))
254282
-> decltype(_DeclFn())
255283
{
256-
using __as_awaitable::__unspecified;
284+
using namespace __as_awaitable;
257285
if constexpr (__connect_await::__has_as_awaitable_member<_Tp, _Promise>)
258286
{
259287
using __result_t = decltype(static_cast<_Tp&&>(__t).as_awaitable(__promise));
260288
static_assert(__awaitable<__result_t, _Promise>);
261289
return static_cast<_Tp&&>(__t).as_awaitable(__promise);
262-
// NOLINTNEXTLINE(bugprone-branch-clone)
263290
}
264-
else if constexpr (__awaitable<_Tp, __unspecified>)
265-
{ // NOT __awaitable<_Tp, _Promise> !!
291+
else if constexpr (__awaitable<_Tp, __unspecified>) // NOT __awaitable<_Tp, _Promise> !!
292+
{ // NOLINT(bugprone-branch-clone)
266293
return static_cast<_Tp&&>(__t);
267294
}
268-
else if constexpr (__as_awaitable::__awaitable_sender<_Tp, _Promise>)
295+
else if constexpr (__awaitable_sender<_Tp, _Promise>)
269296
{
270297
auto __hcoro = __std::coroutine_handle<_Promise>::from_promise(__promise);
271-
return __as_awaitable::__sender_awaitable<_Promise, _Tp>{static_cast<_Tp&&>(__t), __hcoro};
298+
return __sender_awaitable{__detail::__adapt_sender_for_await(static_cast<_Tp&&>(__t)),
299+
__hcoro};
272300
}
273-
else if constexpr (__as_awaitable::__incompatible_sender<_Tp, _Promise>)
301+
else if constexpr (__incompatible_sender<_Tp, _Promise>)
274302
{
275303
return __detail::__value_t<_Tp, _Promise>();
276304
}

include/stdexec/__detail/__execution_fwd.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ namespace STDEXEC
117117
template <__completion_tag _CPO>
118118
struct get_completion_behavior_t;
119119
struct get_domain_t;
120+
struct get_await_completion_adaptor_t;
120121

121122
struct __debug_env_t;
122123

@@ -127,8 +128,9 @@ namespace STDEXEC
127128
template <__completion_tag _CPO>
128129
extern get_completion_scheduler_t<_CPO> const get_completion_scheduler;
129130
template <class _CPO = void>
130-
extern get_completion_domain_t<_CPO> const get_completion_domain;
131-
extern get_domain_t const get_domain;
131+
extern get_completion_domain_t<_CPO> const get_completion_domain;
132+
extern get_domain_t const get_domain;
133+
extern get_await_completion_adaptor_t const get_await_completion_adaptor;
132134

133135
template <class _Env>
134136
concept __is_debug_env = __callable<__debug_env_t, _Env>;

include/stdexec/__detail/__queries.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,25 @@ namespace STDEXEC
2828
//////////////////////////////////////////////////////////////////////////////////////////////////
2929
// [exec.queries]
3030

31+
// [exec.get.await.adapt], see https://eel.is/c++draft/exec#get.await.adapt
32+
struct get_await_completion_adaptor_t : __query<get_await_completion_adaptor_t>
33+
{
34+
template <class _Env>
35+
STDEXEC_ATTRIBUTE(always_inline, host, device)
36+
static constexpr void __validate() noexcept
37+
{
38+
static_assert(STDEXEC::__nothrow_callable<get_await_completion_adaptor_t, _Env const &>);
39+
}
40+
41+
STDEXEC_ATTRIBUTE(nodiscard, always_inline, host, device)
42+
static consteval auto query(forwarding_query_t) noexcept -> bool
43+
{
44+
return true;
45+
}
46+
};
47+
48+
inline constexpr get_await_completion_adaptor_t get_await_completion_adaptor{};
49+
3150
// NOT TO SPEC:
3251
struct __is_scheduler_affine_t
3352
{

0 commit comments

Comments
 (0)