@@ -224,38 +224,58 @@ struct TestSubMDSpan<
224
224
return Kokkos::full_extent;
225
225
}
226
226
227
- template <class SrcExtents , class SubExtents , class ... SliceArgs>
227
+ template <class SrcMDSpan , class SubMDSpan , size_t ... SrcIdx, size_t ... SubIdx , class ... SliceArgs>
228
228
MDSPAN_INLINE_FUNCTION
229
- static bool match_expected_extents (int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext , int , SliceArgs ... slices) {
230
- return match_expected_extents (++src_idx, sub_idx, src_ext, sub_ext , slices...);
229
+ static bool check_submdspan_match (int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...> , int , SliceArgs ... slices) {
230
+ return check_submdspan_match (++src_idx, sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx..., 2 >(), std::index_sequence<SubIdx...>() , slices...);
231
231
}
232
- template <class SrcExtents , class SubExtents , class ... SliceArgs>
232
+ template <class SrcMDSpan , class SubMDSpan , size_t ... SrcIdx, size_t ... SubIdx , class ... SliceArgs>
233
233
MDSPAN_INLINE_FUNCTION
234
- static bool match_expected_extents (int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext , std::pair<int ,int > p, SliceArgs ... slices) {
235
- using idx_t = typename SubExtents ::index_type;
236
- return (sub_ext .extent (sub_idx)==static_cast <idx_t >(p.second -p.first )) && match_expected_extents (++src_idx, ++sub_idx, src_ext, sub_ext , slices...);
234
+ static bool check_submdspan_match (int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...> , std::pair<int ,int > p, SliceArgs ... slices) {
235
+ using idx_t = typename SubMDSpan ::index_type;
236
+ return (sub_mds .extent (sub_idx)==static_cast <idx_t >(p.second -p.first )) && check_submdspan_match (++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx..., 2 >(), std::index_sequence<SubIdx..., 1 >() , slices...);
237
237
}
238
- template <class SrcExtents , class SubExtents , class ... SliceArgs>
238
+ template <class SrcMDSpan , class SubMDSpan , size_t ... SrcIdx, size_t ... SubIdx , class ... SliceArgs>
239
239
MDSPAN_INLINE_FUNCTION
240
- static bool match_expected_extents (int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext ,
240
+ static bool check_submdspan_match (int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...> ,
241
241
Kokkos::strided_slice<int ,int ,int > p, SliceArgs ... slices) {
242
- using idx_t = typename SubExtents ::index_type;
243
- return (sub_ext .extent (sub_idx)==static_cast <idx_t >((p.extent +p.stride -1 )/p.stride )) && match_expected_extents (++src_idx, ++sub_idx, src_ext, sub_ext , slices...);
242
+ using idx_t = typename SubMDSpan ::index_type;
243
+ return (sub_mds .extent (sub_idx)==static_cast <idx_t >((p.extent +p.stride -1 )/p.stride )) && check_submdspan_match (++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx..., 3 >(), std::index_sequence<SubIdx..., 1 >() , slices...);
244
244
}
245
- template <class SrcExtents , class SubExtents , class ... SliceArgs>
245
+ template <class SrcMDSpan , class SubMDSpan , size_t ... SrcIdx, size_t ... SubIdx , class ... SliceArgs>
246
246
MDSPAN_INLINE_FUNCTION
247
- static bool match_expected_extents (int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext ,
247
+ static bool check_submdspan_match (int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...> ,
248
248
Kokkos::strided_slice<int ,std::integral_constant<int , 0 >,std::integral_constant<int ,0 >>, SliceArgs ... slices) {
249
- return (sub_ext .extent (sub_idx)==0 ) && match_expected_extents (++src_idx, ++sub_idx, src_ext, sub_ext , slices...);
249
+ return (sub_mds .extent (sub_idx)==0 ) && check_submdspan_match (++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx..., 1 >(), std::index_sequence<SubIdx..., 0 >() , slices...);
250
250
}
251
- template <class SrcExtents , class SubExtents , class ... SliceArgs>
251
+ template <class SrcMDSpan , class SubMDSpan , size_t ... SrcIdx, size_t ... SubIdx , class ... SliceArgs>
252
252
MDSPAN_INLINE_FUNCTION
253
- static bool match_expected_extents (int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext , Kokkos::full_extent_t , SliceArgs ... slices) {
254
- return (sub_ext .extent (sub_idx)==src_ext .extent (src_idx)) && match_expected_extents (++src_idx, ++sub_idx, src_ext, sub_ext , slices...);
253
+ static bool check_submdspan_match (int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...> , Kokkos::full_extent_t , SliceArgs ... slices) {
254
+ return (sub_mds .extent (sub_idx)==src_mds .extent (src_idx)) && check_submdspan_match (++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx..., 1 >(), std::index_sequence<SubIdx..., 1 >() , slices...);
255
255
}
256
- template <class SrcExtents , class SubExtents >
256
+ template <class SrcMDSpan , class SubMDSpan , size_t ... SrcIdx, size_t ... SubIdx >
257
257
MDSPAN_INLINE_FUNCTION
258
- static bool match_expected_extents (int , int , SrcExtents, SubExtents) { return true ; }
258
+ static bool check_submdspan_match (int , int , SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>) {
259
+ #if MDSPAN_USE_BRACKET_OPERATOR
260
+ if constexpr (SrcMDSpan::rank () == 0 ) {
261
+ return (&src_mds[]==&sub_mds[]);
262
+ } else if constexpr (SubMDSpan::rank () == 0 ) {
263
+ return (&src_mds[SrcIdx...]==&sub_mds[]);
264
+ } else {
265
+ if (sub_mds.size () == 0 ) return true ;
266
+ return (&src_mds[SrcIdx...]==&sub_mds[SubIdx...]);
267
+ }
268
+ #else
269
+ if constexpr (SrcMDSpan::rank () == 0 ) {
270
+ return (&src_mds ()==&sub_mds ());
271
+ } else if constexpr (SubMDSpan::rank () == 0 ) {
272
+ return (&src_mds (SrcIdx...)==&sub_mds ());
273
+ } else {
274
+ if (sub_mds.size () == 0 ) return true ;
275
+ return (&src_mds (SrcIdx...)==&sub_mds (SubIdx...));
276
+ }
277
+ #endif
278
+ }
259
279
260
280
static void run () {
261
281
typename mds_org_t ::mapping_type map (typename mds_org_t::extents_type (ConstrArgs...));
@@ -265,7 +285,7 @@ struct TestSubMDSpan<
265
285
266
286
dispatch ([=] _MDSPAN_HOST_DEVICE () {
267
287
auto sub = Kokkos::submdspan (src, create_slice_arg (SubArgs ())...);
268
- bool match = match_expected_extents (0 , 0 , src. extents (), sub. extents (), create_slice_arg (SubArgs ())...);
288
+ bool match = check_submdspan_match (0 , 0 , src, sub, std::index_sequence<> (), std::index_sequence<> (), create_slice_arg (SubArgs ())...);
269
289
result[0 ] = match?1 :0 ;
270
290
});
271
291
EXPECT_EQ (result[0 ], 1 );
0 commit comments