Skip to content

Commit 92a1297

Browse files
authored
MDSpan issues expose by Kokkos View refactor (#358)
* Add test mapping(other_mapping) where other_mapping has none-convertible extents Specifically the 1D layout_left_padded <-> layout_right_padded ctors didn't compile for cases where the new mapping has a static extent, but the source mapping has dynamic extent. * Fix some layout_padded conversions and compilation with CUDA * Use proper index_pair_like constraint for submdspan * Don't use get for std::complex slice specifier and test that complex<double> works like pair as slice specifier *barf* * Remove out of date comment
1 parent c2494ad commit 92a1297

File tree

6 files changed

+65
-8
lines changed

6 files changed

+65
-8
lines changed

include/experimental/__p2630_bits/submdspan_extents.hpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pragma once
1818

1919
#include <tuple>
20+
#include <complex>
2021

2122
#include "strided_slice.hpp"
2223
namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
@@ -52,6 +53,31 @@ template <class OffsetType, class ExtentType, class StrideType>
5253
struct is_strided_slice<
5354
strided_slice<OffsetType, ExtentType, StrideType>> : std::true_type {};
5455

56+
// Helper for identifying valid pair like things
57+
template <class T, class IndexType> struct index_pair_like : std::false_type {};
58+
59+
template <class IdxT1, class IdxT2, class IndexType>
60+
struct index_pair_like<std::pair<IdxT1, IdxT2>, IndexType> {
61+
static constexpr bool value = std::is_convertible_v<IdxT1, IndexType> &&
62+
std::is_convertible_v<IdxT2, IndexType>;
63+
};
64+
65+
template <class IdxT1, class IdxT2, class IndexType>
66+
struct index_pair_like<std::tuple<IdxT1, IdxT2>, IndexType> {
67+
static constexpr bool value = std::is_convertible_v<IdxT1, IndexType> &&
68+
std::is_convertible_v<IdxT2, IndexType>;
69+
};
70+
71+
template <class IdxT, class IndexType>
72+
struct index_pair_like<std::complex<IdxT>, IndexType> {
73+
static constexpr bool value = std::is_convertible_v<IdxT, IndexType>;
74+
};
75+
76+
template <class IdxT, class IndexType>
77+
struct index_pair_like<std::array<IdxT, 2>, IndexType> {
78+
static constexpr bool value = std::is_convertible_v<IdxT, IndexType>;
79+
};
80+
5581
// first_of(slice): getting begin of slice specifier range
5682
MDSPAN_TEMPLATE_REQUIRES(
5783
class Integral,
@@ -70,13 +96,19 @@ first_of(const ::MDSPAN_IMPL_STANDARD_NAMESPACE::full_extent_t &) {
7096

7197
MDSPAN_TEMPLATE_REQUIRES(
7298
class Slice,
73-
/* requires */(std::is_convertible_v<Slice, std::tuple<size_t, size_t>>)
99+
/* requires */(index_pair_like<Slice, size_t>::value)
74100
)
75101
MDSPAN_INLINE_FUNCTION
76102
constexpr auto first_of(const Slice &i) {
77103
return std::get<0>(i);
78104
}
79105

106+
template<class T>
107+
MDSPAN_INLINE_FUNCTION
108+
constexpr auto first_of(const std::complex<T> &i) {
109+
return i.real();
110+
}
111+
80112
template <class OffsetType, class ExtentType, class StrideType>
81113
MDSPAN_INLINE_FUNCTION
82114
constexpr OffsetType
@@ -100,14 +132,20 @@ constexpr Integral
100132

101133
MDSPAN_TEMPLATE_REQUIRES(
102134
size_t k, class Extents, class Slice,
103-
/* requires */(std::is_convertible_v<Slice, std::tuple<size_t, size_t>>)
135+
/* requires */(index_pair_like<Slice, size_t>::value)
104136
)
105137
MDSPAN_INLINE_FUNCTION
106138
constexpr auto last_of(std::integral_constant<size_t, k>, const Extents &,
107139
const Slice &i) {
108140
return std::get<1>(i);
109141
}
110142

143+
template<size_t k, class Extents, class T>
144+
MDSPAN_INLINE_FUNCTION
145+
constexpr auto last_of(std::integral_constant<size_t, k>, const Extents &, const std::complex<T> &i) {
146+
return i.imag();
147+
}
148+
111149
// Suppress spurious warning with NVCC about no return statement.
112150
// This is a known issue in NVCC and NVC++
113151
// Depending on the CUDA and GCC version we need both the builtin

include/experimental/__p2630_bits/submdspan_mapping.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ template<class SliceSpecifier, class IndexType>
9898
struct is_range_slice {
9999
constexpr static bool value =
100100
std::is_same_v<SliceSpecifier, full_extent_t> ||
101-
std::is_convertible_v<SliceSpecifier,
102-
std::tuple<IndexType, IndexType>>;
101+
index_pair_like<SliceSpecifier, IndexType>::value;
103102
};
104103

105104
template<class SliceSpecifier, class IndexType>

include/experimental/__p2642_bits/layout_padded.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ struct padded_extent {
9595
using static_array_type = typename static_array_type_for_padded_extent<
9696
padding_value, _Extents, _ExtentToPadIdx, _Extents::rank()>::type;
9797

98+
MDSPAN_INLINE_FUNCTION
9899
static constexpr auto static_value() { return static_array_type::static_value(0); }
99100

100101
MDSPAN_INLINE_FUNCTION
@@ -203,7 +204,7 @@ class layout_left_padded<PaddingValue>::mapping {
203204
}
204205

205206
public:
206-
#if !MDSPAN_HAS_CXX_20
207+
#if !MDSPAN_HAS_CXX_20 || defined(__NVCC__)
207208
MDSPAN_INLINE_FUNCTION_DEFAULTED
208209
constexpr mapping()
209210
: mapping(extents_type{})
@@ -347,7 +348,7 @@ class layout_left_padded<PaddingValue>::mapping {
347348
MDSPAN_INLINE_FUNCTION
348349
constexpr mapping(const _Mapping &other_mapping) noexcept
349350
: padded_stride(padded_stride_type::init_padding(
350-
other_mapping.extents(),
351+
static_cast<extents_type>(other_mapping.extents()),
351352
other_mapping.extents().extent(extent_to_pad_idx))),
352353
exts(other_mapping.extents()) {}
353354

@@ -566,7 +567,7 @@ class layout_right_padded<PaddingValue>::mapping {
566567
}
567568

568569
public:
569-
#if !MDSPAN_HAS_CXX_20
570+
#if !MDSPAN_HAS_CXX_20 || defined(__NVCC__)
570571
MDSPAN_INLINE_FUNCTION_DEFAULTED
571572
constexpr mapping()
572573
: mapping(extents_type{})
@@ -707,7 +708,7 @@ class layout_right_padded<PaddingValue>::mapping {
707708
MDSPAN_INLINE_FUNCTION
708709
constexpr mapping(const _Mapping &other_mapping) noexcept
709710
: padded_stride(padded_stride_type::init_padding(
710-
other_mapping.extents(),
711+
static_cast<extents_type>(other_mapping.extents()),
711712
other_mapping.extents().extent(extent_to_pad_idx))),
712713
exts(other_mapping.extents()) {}
713714

tests/test_layout_padded_left.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ TEST(LayoutLeftTests, construction)
296296
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
297297
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).extents()), (Kokkos::extents<std::size_t, 4, 7>()));
298298
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).padded_stride.value(0)), 4);
299+
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>(4))).extents()), (Kokkos::extents<std::size_t, 4, 7>()));
300+
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>(4))).padded_stride.value(0)), 4);
299301
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).extents()), (Kokkos::extents<std::size_t, 4, 7>()));
300302
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).padded_stride.value(0)), 4);
301303

@@ -311,6 +313,8 @@ TEST(LayoutLeftTests, construction)
311313
ASSERT_EQ(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t>>()).extents(), Kokkos::extents<std::size_t>());
312314
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).extents()), (Kokkos::extents<std::size_t, 3>()));
313315
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
316+
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t,1>(3))).extents()), (Kokkos::extents<std::size_t, 3>()));
317+
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t,1>(3))).padded_stride.value(0)), 0);
314318
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).extents()), (Kokkos::extents<std::size_t, 3>()));
315319
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).padded_stride.value(0)), 0);
316320

tests/test_layout_padded_right.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,13 +304,17 @@ TEST(LayoutrightTests, construction)
304304
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
305305
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).extents()), (Kokkos::extents<std::size_t, 7, 5>()));
306306
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).padded_stride.value(0)), 8);
307+
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 5>>(Kokkos::extents<size_t, Kokkos::dynamic_extent, 5>(7))).extents()), (Kokkos::extents<std::size_t, 7, 5>()));
308+
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 5>>(Kokkos::extents<size_t, Kokkos::dynamic_extent, 5>(7))).padded_stride.value(0)), 8);
307309
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, Kokkos::dynamic_extent>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).extents()), (Kokkos::extents<std::size_t, 7, 5>()));
308310
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, Kokkos::dynamic_extent>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).padded_stride.value(0)), 8);
309311

310312
// Construct layout_right_padded mapping from layout_left_padded mapping
311313
ASSERT_EQ(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t>>()).extents(), Kokkos::extents<std::size_t>());
312314
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).extents()), (Kokkos::extents<std::size_t, 3>()));
313315
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
316+
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t, 1>(3))).extents()), (Kokkos::extents<std::size_t, 3>()));
317+
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t, 1>(3))).padded_stride.value(0)), 0);
314318
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).extents()), (Kokkos::extents<std::size_t, 3>()));
315319
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).padded_stride.value(0)), 0);
316320

tests/test_submdspan.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ using submdspan_test_types =
140140
// layout_right to layout_right Check Extents Preservation
141141
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t,10>, Kokkos::full_extent_t>
142142
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t,dyn>, std::pair<int,int>>
143+
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t,dyn>, std::complex<double>>
143144
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t>, int>
144145
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10,20>, args_t<10,20>, Kokkos::extents<size_t,10,20>, Kokkos::full_extent_t, Kokkos::full_extent_t>
145146
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10,20>, args_t<10,20>, Kokkos::extents<size_t,dyn,20>, std::pair<int,int>, Kokkos::full_extent_t>
@@ -274,6 +275,10 @@ struct TestSubMDSpan<
274275
return std::pair<int,int>(1,3);
275276
}
276277
MDSPAN_INLINE_FUNCTION
278+
static auto create_slice_arg(std::complex<double>) {
279+
return std::complex<double>{1.,3.};
280+
}
281+
MDSPAN_INLINE_FUNCTION
277282
static auto create_slice_arg(Kokkos::strided_slice<int,int,int>) {
278283
return Kokkos::strided_slice<int,int,int>{1,3,2};
279284
}
@@ -300,6 +305,12 @@ struct TestSubMDSpan<
300305
}
301306
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
302307
MDSPAN_INLINE_FUNCTION
308+
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>, std::complex<double> p, SliceArgs ... slices) {
309+
using idx_t = typename SubMDSpan::index_type;
310+
return (sub_mds.extent(sub_idx)==static_cast<idx_t>(p.imag()-p.real())) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,2>(), std::index_sequence<SubIdx...,1>(), slices...);
311+
}
312+
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
313+
MDSPAN_INLINE_FUNCTION
303314
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>,
304315
Kokkos::strided_slice<int,int,int> p, SliceArgs ... slices) {
305316
using idx_t = typename SubMDSpan::index_type;

0 commit comments

Comments
 (0)