Skip to content

Commit e01c4c6

Browse files
authored
Merge pull request #343 from crtrott/submdspan-padded-layouts-rebase2
Adds submdspan_mapping for padded layouts
2 parents 97c8eef + 451cce6 commit e01c4c6

File tree

4 files changed

+286
-48
lines changed

4 files changed

+286
-48
lines changed

include/experimental/__p2630_bits/submdspan_mapping.hpp

Lines changed: 184 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,21 @@ struct deduce_layout_left_submapping<
182182
}
183183
};
184184

185+
// We are reusing the same thing for layout_left and layout_left_padded
186+
// For layout_left as source StaticStride is static_extent(0)
187+
template<class Extents, size_t NumGaps, size_t StaticStride>
188+
struct compute_s_static_layout_left {
189+
// Neither StaticStride nor any of the provided extents can be zero.
190+
// StaticStride can never be zero, the static_extents we are looking at are associated with
191+
// integral slice specifiers - which wouldn't be valid for zero extent
192+
template<size_t ... Idx>
193+
MDSPAN_INLINE_FUNCTION
194+
static constexpr size_t value(std::index_sequence<Idx...>) {
195+
size_t val = ((Idx>0 && Idx<=NumGaps ? (Extents::static_extent(Idx) == dynamic_extent?0:Extents::static_extent(Idx)) : 1) * ... * (StaticStride == dynamic_extent?0:StaticStride));
196+
return val == 0?dynamic_extent:val;
197+
}
198+
};
199+
185200
} // namespace detail
186201

187202
// Actual submdspan mapping call
@@ -202,14 +217,6 @@ layout_left::mapping<Extents>::submdspan_mapping_impl(
202217
std::make_index_sequence<src_ext_t::rank()>,
203218
SliceSpecifiers...>;
204219

205-
using dst_layout_t = std::conditional_t<
206-
deduce_layout::layout_left_value(), layout_left,
207-
std::conditional_t<
208-
deduce_layout::layout_left_padded_value(),
209-
MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_left_padded<dynamic_extent>,
210-
layout_stride>>;
211-
using dst_mapping_t = typename dst_layout_t::template mapping<dst_ext_t>;
212-
213220
// Figure out if any slice's lower bound equals the corresponding extent.
214221
// If so, bypass evaluating the layout mapping. This fixes LWG Issue 4060.
215222
const bool out_of_bounds =
@@ -218,17 +225,19 @@ layout_left::mapping<Extents>::submdspan_mapping_impl(
218225
out_of_bounds ? this->required_span_size()
219226
: this->operator()(detail::first_of(slices)...));
220227

221-
if constexpr (std::is_same_v<dst_layout_t, layout_left>) {
228+
if constexpr (deduce_layout::layout_left_value()) {
222229
// layout_left case
230+
using dst_mapping_t = typename layout_left::template mapping<dst_ext_t>;
223231
return submdspan_mapping_result<dst_mapping_t>{dst_mapping_t(dst_ext),
224232
offset};
225-
} else if constexpr (std::is_same_v<dst_layout_t,
226-
MDSPAN_IMPL_PROPOSED_NAMESPACE::
227-
layout_left_padded<dynamic_extent>>) {
233+
} else if constexpr (deduce_layout::layout_left_padded_value()) {
234+
constexpr size_t S_static = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::compute_s_static_layout_left<Extents, deduce_layout::gap_len, Extents::static_extent(0)>::value(std::make_index_sequence<Extents::rank()>());
235+
using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_left_padded<S_static>::template mapping<dst_ext_t>;
228236
return submdspan_mapping_result<dst_mapping_t>{
229237
dst_mapping_t(dst_ext, stride(1 + deduce_layout::gap_len)), offset};
230238
} else {
231239
// layout_stride case
240+
using dst_mapping_t = typename layout_stride::mapping<dst_ext_t>;
232241
auto inv_map = detail::inv_map_rank(std::integral_constant<size_t, 0>(),
233242
std::index_sequence<>(), slices...);
234243
return submdspan_mapping_result<dst_mapping_t> {
@@ -253,6 +262,77 @@ layout_left::mapping<Extents>::submdspan_mapping_impl(
253262
#endif
254263
}
255264

265+
template <size_t PaddingValue>
266+
template <class Extents>
267+
template <class... SliceSpecifiers>
268+
MDSPAN_INLINE_FUNCTION constexpr auto
269+
MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_left_padded<PaddingValue>::mapping<Extents>::submdspan_mapping_impl(
270+
SliceSpecifiers... slices) const {
271+
272+
// compute sub extents
273+
using src_ext_t = Extents;
274+
auto dst_ext = submdspan_extents(extents(), slices...);
275+
using dst_ext_t = decltype(dst_ext);
276+
277+
if constexpr (Extents::rank() == 0) { // rank-0 case
278+
using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_left_padded<PaddingValue>::template mapping<Extents>;
279+
return submdspan_mapping_result<dst_mapping_t>{*this, 0};
280+
} else {
281+
const bool out_of_bounds =
282+
MDSPAN_IMPL_STANDARD_NAMESPACE::detail::any_slice_out_of_bounds(this->extents(), slices...);
283+
auto offset = static_cast<size_t>(
284+
out_of_bounds ? this->required_span_size()
285+
: this->operator()(MDSPAN_IMPL_STANDARD_NAMESPACE::detail::first_of(slices)...));
286+
if constexpr (dst_ext_t::rank() == 0) { // result rank-0
287+
using dst_mapping_t = typename layout_left::template mapping<dst_ext_t>;
288+
return submdspan_mapping_result<dst_mapping_t>{dst_mapping_t{dst_ext}, offset};
289+
} else { // general case
290+
// Figure out if any slice's lower bound equals the corresponding extent.
291+
// If so, bypass evaluating the layout mapping. This fixes LWG Issue 4060.
292+
// figure out sub layout type
293+
using deduce_layout = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::deduce_layout_left_submapping<
294+
typename dst_ext_t::index_type, dst_ext_t::rank(),
295+
decltype(std::make_index_sequence<src_ext_t::rank()>()),
296+
SliceSpecifiers...>;
297+
298+
if constexpr (deduce_layout::layout_left_value() && dst_ext_t::rank() == 1) { // getting rank-1 from leftmost
299+
using dst_mapping_t = typename layout_left::template mapping<dst_ext_t>;
300+
return submdspan_mapping_result<dst_mapping_t>{dst_mapping_t{dst_ext}, offset};
301+
} else if constexpr (deduce_layout::layout_left_padded_value()) { // can keep layout_left_padded
302+
constexpr size_t S_static = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::compute_s_static_layout_left<Extents, deduce_layout::gap_len, static_padding_stride>::value(std::make_index_sequence<Extents::rank()>());
303+
using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_left_padded<S_static>::template mapping<dst_ext_t>;
304+
return submdspan_mapping_result<dst_mapping_t>{
305+
dst_mapping_t(dst_ext, stride(1 + deduce_layout::gap_len)), offset};
306+
} else { // layout_stride
307+
auto inv_map = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::inv_map_rank(std::integral_constant<size_t, 0>(),
308+
std::index_sequence<>(), slices...);
309+
using dst_mapping_t = typename layout_stride::template mapping<dst_ext_t>;
310+
return submdspan_mapping_result<dst_mapping_t> {
311+
dst_mapping_t(dst_ext,
312+
MDSPAN_IMPL_STANDARD_NAMESPACE::detail::construct_sub_strides(
313+
*this, inv_map,
314+
// HIP needs deduction guides to have markups so we need to be explicit
315+
// NVCC 11.0 has a bug with deduction guide here, tested that 11.2 does not have
316+
// the issue But Clang-CUDA also doesn't accept the use of deduction guide so
317+
// disable it for CUDA alltogether
318+
#if defined(_MDSPAN_HAS_HIP) || defined(_MDSPAN_HAS_CUDA)
319+
std::tuple<decltype(MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices))...>{
320+
MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices)...})),
321+
#else
322+
std::tuple{MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices)...})),
323+
#endif
324+
offset
325+
};
326+
}
327+
}
328+
}
329+
330+
331+
#if defined(__NVCC__) && !defined(__CUDA_ARCH__) && defined(__GNUC__)
332+
__builtin_unreachable();
333+
#endif
334+
}
335+
256336
//**********************************
257337
// layout_right submdspan_mapping
258338
//*********************************
@@ -322,6 +402,21 @@ struct deduce_layout_right_submapping<
322402
}
323403
};
324404

405+
// We are reusing the same thing for layout_right and layout_right_padded
406+
// For layout_right as source StaticStride is static_extent(Rank-1)
407+
template<class Extents, size_t NumGaps, size_t StaticStride>
408+
struct compute_s_static_layout_right {
409+
// Neither StaticStride nor any of the provided extents can be zero.
410+
// StaticStride can never be zero, the static_extents we are looking at are associated with
411+
// integral slice specifiers - which wouldn't be valid for zero extent
412+
template<size_t ... Idx>
413+
MDSPAN_INLINE_FUNCTION
414+
static constexpr size_t value(std::index_sequence<Idx...>) {
415+
size_t val = ((Idx >= Extents::rank() - 1 - NumGaps && Idx < Extents::rank() - 1 ? (Extents::static_extent(Idx) == dynamic_extent?0:Extents::static_extent(Idx)) : 1) * ... * (StaticStride == dynamic_extent?0:StaticStride));
416+
return val == 0?dynamic_extent:val;
417+
}
418+
};
419+
325420
} // namespace detail
326421

327422
// Actual submdspan mapping call
@@ -342,14 +437,6 @@ layout_right::mapping<Extents>::submdspan_mapping_impl(
342437
std::make_index_sequence<src_ext_t::rank()>,
343438
SliceSpecifiers...>;
344439

345-
using dst_layout_t = std::conditional_t<
346-
deduce_layout::layout_right_value(), layout_right,
347-
std::conditional_t<
348-
deduce_layout::layout_right_padded_value(),
349-
MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<dynamic_extent>,
350-
layout_stride>>;
351-
using dst_mapping_t = typename dst_layout_t::template mapping<dst_ext_t>;
352-
353440
// Figure out if any slice's lower bound equals the corresponding extent.
354441
// If so, bypass evaluating the layout mapping. This fixes LWG Issue 4060.
355442
const bool out_of_bounds =
@@ -358,20 +445,21 @@ layout_right::mapping<Extents>::submdspan_mapping_impl(
358445
out_of_bounds ? this->required_span_size()
359446
: this->operator()(detail::first_of(slices)...));
360447

361-
if constexpr (std::is_same_v<dst_layout_t, layout_right>) {
448+
if constexpr (deduce_layout::layout_right_value()) {
362449
// layout_right case
450+
using dst_mapping_t = typename layout_right::mapping<dst_ext_t>;
363451
return submdspan_mapping_result<dst_mapping_t>{dst_mapping_t(dst_ext),
364452
offset};
365-
} else if constexpr (std::is_same_v<
366-
dst_layout_t,
367-
MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<
368-
dynamic_extent>>) {
453+
} else if constexpr (deduce_layout::layout_right_padded_value()) {
454+
// Static padding stride of the resulting layout_right_padded mapping: the source's last static
// extent multiplied by the static extents in the gap immediately to its left. This must use the
// layout_right helper, which multiplies the extents at indices [rank-1-gap_len, rank-1). The
// layout_left helper multiplies the extents at indices [1, gap_len] instead, which differs
// whenever those two index ranges do not coincide.
constexpr size_t S_static = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::compute_s_static_layout_right<Extents, deduce_layout::gap_len, Extents::static_extent(Extents::rank() - 1)>::value(std::make_index_sequence<Extents::rank()>());
455+
using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<S_static>::template mapping<dst_ext_t>;
369456
return submdspan_mapping_result<dst_mapping_t>{
370457
dst_mapping_t(dst_ext,
371458
stride(src_ext_t::rank() - 2 - deduce_layout::gap_len)),
372459
offset};
373460
} else {
374461
// layout_stride case
462+
using dst_mapping_t = typename layout_stride::mapping<dst_ext_t>;
375463
auto inv_map = detail::inv_map_rank(std::integral_constant<size_t, 0>(),
376464
std::index_sequence<>(), slices...);
377465
return submdspan_mapping_result<dst_mapping_t> {
@@ -396,6 +484,77 @@ layout_right::mapping<Extents>::submdspan_mapping_impl(
396484
#endif
397485
}
398486

487+
template <size_t PaddingValue>
488+
template <class Extents>
489+
template <class... SliceSpecifiers>
490+
MDSPAN_INLINE_FUNCTION constexpr auto
491+
MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<PaddingValue>::mapping<Extents>::submdspan_mapping_impl(
492+
SliceSpecifiers... slices) const {
493+
494+
// compute sub extents
495+
using src_ext_t = Extents;
496+
auto dst_ext = submdspan_extents(extents(), slices...);
497+
using dst_ext_t = decltype(dst_ext);
498+
499+
if constexpr (Extents::rank() == 0) { // rank-0 case
500+
using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<PaddingValue>::template mapping<Extents>;
501+
return submdspan_mapping_result<dst_mapping_t>{*this, 0};
502+
} else {
503+
// Figure out if any slice's lower bound equals the corresponding extent.
504+
// If so, bypass evaluating the layout mapping. This fixes LWG Issue 4060.
505+
// figure out sub layout type
506+
const bool out_of_bounds =
507+
MDSPAN_IMPL_STANDARD_NAMESPACE::detail::any_slice_out_of_bounds(this->extents(), slices...);
508+
auto offset = static_cast<size_t>(
509+
out_of_bounds ? this->required_span_size()
510+
: this->operator()(MDSPAN_IMPL_STANDARD_NAMESPACE::detail::first_of(slices)...));
511+
if constexpr (dst_ext_t::rank() == 0) { // result rank-0
512+
using dst_mapping_t = typename layout_right::template mapping<dst_ext_t>;
513+
return submdspan_mapping_result<dst_mapping_t>{dst_mapping_t{dst_ext}, offset};
514+
} else { // general case
515+
using deduce_layout = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::deduce_layout_right_submapping<
516+
typename dst_ext_t::index_type, dst_ext_t::rank(),
517+
decltype(std::make_index_sequence<src_ext_t::rank()>()),
518+
SliceSpecifiers...>;
519+
520+
if constexpr (deduce_layout::layout_right_value() && dst_ext_t::rank() == 1) { // getting rank-1 from rightmost
521+
using dst_mapping_t = typename layout_right::template mapping<dst_ext_t>;
522+
return submdspan_mapping_result<dst_mapping_t>{dst_mapping_t{dst_ext}, offset};
523+
} else if constexpr (deduce_layout::layout_right_padded_value()) { // can keep layout_right_padded
524+
constexpr size_t S_static = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::compute_s_static_layout_right<Extents, deduce_layout::gap_len, static_padding_stride>::value(std::make_index_sequence<Extents::rank()>());
525+
using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<S_static>::template mapping<dst_ext_t>;
526+
return submdspan_mapping_result<dst_mapping_t>{
527+
dst_mapping_t(dst_ext, stride(Extents::rank() - 2 - deduce_layout::gap_len)), offset};
528+
} else { // layout_stride
529+
auto inv_map = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::inv_map_rank(std::integral_constant<size_t, 0>(),
530+
std::index_sequence<>(), slices...);
531+
using dst_mapping_t = typename layout_stride::template mapping<dst_ext_t>;
532+
return submdspan_mapping_result<dst_mapping_t> {
533+
dst_mapping_t(dst_ext,
534+
MDSPAN_IMPL_STANDARD_NAMESPACE::detail::construct_sub_strides(
535+
*this, inv_map,
536+
// HIP needs deduction guides to have markups so we need to be explicit
537+
// NVCC 11.0 has a bug with deduction guide here, tested that 11.2 does not have
538+
// the issue But Clang-CUDA also doesn't accept the use of deduction guide so
539+
// disable it for CUDA alltogether
540+
#if defined(_MDSPAN_HAS_HIP) || defined(_MDSPAN_HAS_CUDA)
541+
std::tuple<decltype(MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices))...>{
542+
MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices)...})),
543+
#else
544+
std::tuple{MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices)...})),
545+
#endif
546+
offset
547+
};
548+
}
549+
}
550+
}
551+
552+
553+
#if defined(__NVCC__) && !defined(__CUDA_ARCH__) && defined(__GNUC__)
554+
__builtin_unreachable();
555+
#endif
556+
}
557+
399558
//**********************************
400559
// layout_stride submdspan_mapping
401560
//*********************************

include/experimental/__p2642_bits/layout_padded.hpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ class layout_left_padded<PaddingValue>::mapping {
221221
#endif
222222

223223
MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping(const mapping&) noexcept = default;
224-
MDSPAN_INLINE_FUNCTION_DEFAULTED mapping& operator=(const mapping&) noexcept = default;
224+
MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping& operator=(const mapping&) noexcept = default;
225225

226226
/**
227227
* Initializes the mapping with the given extents.
@@ -497,10 +497,12 @@ class layout_left_padded<PaddingValue>::mapping {
497497

498498
// [mdspan.submdspan.mapping], submdspan mapping specialization
499499
// Declaration of the const member that performs the slicing for this mapping.
// It takes the slice specifiers by value; the return type is deduced (`auto`) from
// the definition, which is not part of this declaration.
template<class... SliceSpecifiers>
500+
MDSPAN_INLINE_FUNCTION
500501
constexpr auto submdspan_mapping_impl(
501502
SliceSpecifiers... slices) const;
502503

503504
template<class... SliceSpecifiers>
505+
MDSPAN_INLINE_FUNCTION
504506
friend constexpr auto submdspan_mapping(
505507
const mapping& src, SliceSpecifiers... slices) {
506508
return src.submdspan_mapping_impl(slices...);
@@ -582,7 +584,7 @@ class layout_right_padded<PaddingValue>::mapping {
582584
#endif
583585

584586
MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping(const mapping&) noexcept = default;
585-
MDSPAN_INLINE_FUNCTION_DEFAULTED mapping& operator=(const mapping&) noexcept = default;
587+
MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping& operator=(const mapping&) noexcept = default;
586588

587589
/**
588590
* Initializes the mapping with the given extents.
@@ -847,6 +849,19 @@ class layout_right_padded<PaddingValue>::mapping {
847849
return !(left == right);
848850
}
849851
#endif
852+
853+
// [mdspan.submdspan.mapping], submdspan mapping specialization
854+
// Declaration of the const member that performs the slicing for this mapping.
// It takes the slice specifiers by value; the return type is deduced (`auto`) from
// the definition, which is not part of this declaration.
template<class... SliceSpecifiers>
855+
MDSPAN_INLINE_FUNCTION
856+
constexpr auto submdspan_mapping_impl(
857+
SliceSpecifiers... slices) const;
858+
859+
// Friend free function defined inside the class: the submdspan_mapping entry point
// for this mapping type. It forwards `slices` unchanged to
// src.submdspan_mapping_impl and returns whatever that call returns.
template<class... SliceSpecifiers>
860+
MDSPAN_INLINE_FUNCTION
861+
friend constexpr auto submdspan_mapping(
862+
const mapping& src, SliceSpecifiers... slices) {
863+
return src.submdspan_mapping_impl(slices...);
864+
}
850865
};
851866
}
852867
}

0 commit comments

Comments
 (0)