Skip to content

Commit 4d54ab0

Browse files
authored
Improve submdspan testing (#342)
This now tests that the elements of the submdspan point to the correct element of the src. That caught a mistake in the layout_foo test layout. Disable bracket operator for icpx due to compiler crash just in the test config. Note: for icpx 2024 we could turn it on again.
1 parent 46f8270 commit 4d54ab0

File tree

3 files changed

+42
-21
lines changed

3 files changed

+42
-21
lines changed

.github/workflows/cmake.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ jobs:
2626
# To get new URL, look here:
2727
# https://www.intel.com/content/www/us/en/developer/articles/tool/oneapi-standalone-components.html#inpage-nav-6-undefined
2828
compiler_url: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/ebf5d9aa-17a7-46a4-b5df-ace004227c0e/l_dpcpp-cpp-compiler_p_2023.2.1.8_offline.sh
29+
cxx_flags_extra: "-DMDSPAN_USE_BRACKET_OPERATOR=0"
2930
- enable_benchmark: ON
3031
- stdcxx: 14
3132
enable_benchmark: OFF

tests/foo_customizations.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ class layout_foo::mapping {
193193
template<class Indx0, class Indx1>
194194
MDSPAN_INLINE_FUNCTION
195195
constexpr index_type operator()(Indx0 idx0, Indx1 idx1) const noexcept {
196-
return static_cast<index_type>(idx0 * __extents.extent(0) + idx1);
196+
return static_cast<index_type>(idx0 * __extents.extent(1) + idx1);
197197
}
198198

199199
MDSPAN_INLINE_FUNCTION static constexpr bool is_always_unique() noexcept { return true; }

tests/test_submdspan.cpp

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -224,38 +224,58 @@ struct TestSubMDSpan<
224224
return Kokkos::full_extent;
225225
}
226226

227-
template<class SrcExtents, class SubExtents, class ... SliceArgs>
227+
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
228228
MDSPAN_INLINE_FUNCTION
229-
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext, int, SliceArgs ... slices) {
230-
return match_expected_extents(++src_idx, sub_idx, src_ext, sub_ext, slices...);
229+
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>, int, SliceArgs ... slices) {
230+
return check_submdspan_match(++src_idx, sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,2>(), std::index_sequence<SubIdx...>(), slices...);
231231
}
232-
template<class SrcExtents, class SubExtents, class ... SliceArgs>
232+
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
233233
MDSPAN_INLINE_FUNCTION
234-
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext, std::pair<int,int> p, SliceArgs ... slices) {
235-
using idx_t = typename SubExtents::index_type;
236-
return (sub_ext.extent(sub_idx)==static_cast<idx_t>(p.second-p.first)) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
234+
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>, std::pair<int,int> p, SliceArgs ... slices) {
235+
using idx_t = typename SubMDSpan::index_type;
236+
return (sub_mds.extent(sub_idx)==static_cast<idx_t>(p.second-p.first)) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,2>(), std::index_sequence<SubIdx...,1>(), slices...);
237237
}
238-
template<class SrcExtents, class SubExtents, class ... SliceArgs>
238+
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
239239
MDSPAN_INLINE_FUNCTION
240-
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext,
240+
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>,
241241
Kokkos::strided_slice<int,int,int> p, SliceArgs ... slices) {
242-
using idx_t = typename SubExtents::index_type;
243-
return (sub_ext.extent(sub_idx)==static_cast<idx_t>((p.extent+p.stride-1)/p.stride)) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
242+
using idx_t = typename SubMDSpan::index_type;
243+
return (sub_mds.extent(sub_idx)==static_cast<idx_t>((p.extent+p.stride-1)/p.stride)) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,3>(), std::index_sequence<SubIdx...,1>(), slices...);
244244
}
245-
template<class SrcExtents, class SubExtents, class ... SliceArgs>
245+
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
246246
MDSPAN_INLINE_FUNCTION
247-
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext,
247+
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>,
248248
Kokkos::strided_slice<int,std::integral_constant<int, 0>,std::integral_constant<int,0>>, SliceArgs ... slices) {
249-
return (sub_ext.extent(sub_idx)==0) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
249+
return (sub_mds.extent(sub_idx)==0) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,1>(), std::index_sequence<SubIdx...,0>(), slices...);
250250
}
251-
template<class SrcExtents, class SubExtents, class ... SliceArgs>
251+
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
252252
MDSPAN_INLINE_FUNCTION
253-
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext, Kokkos::full_extent_t, SliceArgs ... slices) {
254-
return (sub_ext.extent(sub_idx)==src_ext.extent(src_idx)) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
253+
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>, Kokkos::full_extent_t, SliceArgs ... slices) {
254+
return (sub_mds.extent(sub_idx)==src_mds.extent(src_idx)) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,1>(), std::index_sequence<SubIdx...,1>(), slices...);
255255
}
256-
template<class SrcExtents, class SubExtents>
256+
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx>
257257
MDSPAN_INLINE_FUNCTION
258-
static bool match_expected_extents(int, int, SrcExtents, SubExtents) { return true; }
258+
static bool check_submdspan_match(int, int, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>) {
259+
#if MDSPAN_USE_BRACKET_OPERATOR
260+
if constexpr (SrcMDSpan::rank() == 0) {
261+
return (&src_mds[]==&sub_mds[]);
262+
} else if constexpr (SubMDSpan::rank() == 0) {
263+
return (&src_mds[SrcIdx...]==&sub_mds[]);
264+
} else {
265+
if(sub_mds.size() == 0) return true;
266+
return (&src_mds[SrcIdx...]==&sub_mds[SubIdx...]);
267+
}
268+
#else
269+
if constexpr (SrcMDSpan::rank() == 0) {
270+
return (&src_mds()==&sub_mds());
271+
} else if constexpr (SubMDSpan::rank() == 0) {
272+
return (&src_mds(SrcIdx...)==&sub_mds());
273+
} else {
274+
if(sub_mds.size() == 0) return true;
275+
return (&src_mds(SrcIdx...)==&sub_mds(SubIdx...));
276+
}
277+
#endif
278+
}
259279

260280
static void run() {
261281
typename mds_org_t::mapping_type map(typename mds_org_t::extents_type(ConstrArgs...));
@@ -265,7 +285,7 @@ struct TestSubMDSpan<
265285

266286
dispatch([=] _MDSPAN_HOST_DEVICE () {
267287
auto sub = Kokkos::submdspan(src, create_slice_arg(SubArgs())...);
268-
bool match = match_expected_extents(0, 0, src.extents(), sub.extents(), create_slice_arg(SubArgs())...);
288+
bool match = check_submdspan_match(0, 0, src, sub, std::index_sequence<>(), std::index_sequence<>(), create_slice_arg(SubArgs())...);
269289
result[0] = match?1:0;
270290
});
271291
EXPECT_EQ(result[0], 1);

0 commit comments

Comments
 (0)