diff --git a/src/ekat/kokkos/ekat_subview_utils.hpp b/src/ekat/kokkos/ekat_subview_utils.hpp
index 582949d6..16af0d49 100644
--- a/src/ekat/kokkos/ekat_subview_utils.hpp
+++ b/src/ekat/kokkos/ekat_subview_utils.hpp
@@ -244,7 +244,6 @@ subview_1(const ViewLR& v,
   // Since we are keeping the first dimension, the stride is unchanged.
   auto vm = tmp.impl_map();
   vm.m_impl_offset.m_stride = v.impl_map().stride_0();
-  auto test = Unmanaged>(v.impl_track(),vm);
   return Unmanaged>(
       v.impl_track(),vm);
 }
@@ -267,7 +266,6 @@ subview_1(const ViewLR& v,
   // Since we are keeping the first dimension, the stride is unchanged.
   auto vm = tmp.impl_map();
   vm.m_impl_offset.m_stride = v.impl_map().stride_0();
-  auto test = Unmanaged>(v.impl_track(),vm);
   return Unmanaged>(
       v.impl_track(),vm);
 }
@@ -316,6 +314,175 @@ subview_1(const ViewLR& v,
       v.impl_track(),vm);
 }
 
+// ================ Multi-sliced Subviews ======================= //
+// e.g., instead of a single-entry slice like v(:, 42, :), we slice over a range
+// of values, as in v(:, 27:42, :)
+// Note that this obtains the entries whose index in the second dimension
+// lies in the half-open range [27, 42) == {v(i, j, k), where 27 <= j < 42}
+// Note also that this slicing means that the subview has the same rank
+// as the source view
+//
+// Arguments shared by all the overloads below:
+//   v    - source view; its data pointer must not be null
+//   kp0  - half-open index range [kp0.first, kp0.second) kept along idim;
+//          requires 0 <= kp0.first < kp0.second <= v.extent_int(idim)
+//   idim - zero-based index of the sliced dimension; requires
+//          0 <= idim < rank of v (the rank-1 overload only accepts 0)
+// The requirements are checked with assert, i.e., only when NDEBUG is
+// not defined.
+
+// --- Rank1 multi-slice --- //
+template
+KOKKOS_INLINE_FUNCTION
+Unmanaged>
+subview(const ViewLR& v,
+        const Kokkos::pair &kp0,
+        const int idim = 0) {
+  assert(v.data() != nullptr);
+  assert(idim == 0);
+  assert(kp0.first >= 0 && kp0.first < kp0.second
+         && kp0.second <= v.extent_int(idim));
+  return Unmanaged>(Kokkos::subview(v, kp0));
+}
+
+// --- Rank2 multi-slice --- //
+template
+KOKKOS_INLINE_FUNCTION
+Unmanaged>
+subview(const ViewLR& v,
+        const Kokkos::pair &kp0,
+        const int idim) {
+  assert(v.data() != nullptr);
+  assert(idim >= 0 && idim < v.rank);
+  assert(kp0.first >= 0 && kp0.first < kp0.second
+         && kp0.second <= v.extent_int(idim));
+  if (idim == 0) {
+    return Unmanaged>(Kokkos::subview(v, kp0, Kokkos::ALL));
+  } else {
+    assert(idim == 1);
+    return Unmanaged>(Kokkos::subview(v, Kokkos::ALL, kp0));
+  }
+}
+
+// --- Rank3 multi-slice --- //
+template
+KOKKOS_INLINE_FUNCTION
+Unmanaged>
+subview(const ViewLR& v,
+        const Kokkos::pair &kp0,
+        const int idim) {
+  assert(v.data() != nullptr);
+  assert(idim >= 0 && idim < v.rank);
+  assert(kp0.first >= 0 && kp0.first < kp0.second
+         && kp0.second <= v.extent_int(idim));
+  if (idim == 0) {
+    return Unmanaged>(
+      Kokkos::subview(v, kp0, Kokkos::ALL, Kokkos::ALL));
+  } else if (idim == 1) {
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, kp0, Kokkos::ALL));
+  } else {
+    assert(idim == 2);
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, kp0));
+  }
+}
+
+// --- Rank4 multi-slice --- //
+template
+KOKKOS_INLINE_FUNCTION
+Unmanaged>
+subview(const ViewLR& v,
+        const Kokkos::pair &kp0,
+        const int idim) {
+  assert(v.data() != nullptr);
+  assert(idim >= 0 && idim < v.rank);
+  assert(kp0.first >= 0 && kp0.first < kp0.second
+         && kp0.second <= v.extent_int(idim));
+  if (idim == 0) {
+    return Unmanaged>(
+      Kokkos::subview(v, kp0, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL));
+  } else if (idim == 1) {
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, kp0, Kokkos::ALL, Kokkos::ALL));
+  } else if (idim == 2) {
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, kp0, Kokkos::ALL));
+  } else {
+    assert(idim == 3);
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL, kp0));
+  }
+}
+
+// --- Rank5 multi-slice --- //
+template
+KOKKOS_INLINE_FUNCTION
+Unmanaged>
+subview(const ViewLR& v,
+        const Kokkos::pair &kp0,
+        const int idim) {
+  assert(v.data() != nullptr);
+  assert(idim >= 0 && idim < v.rank);
+  assert(kp0.first >= 0 && kp0.first < kp0.second
+         && kp0.second <= v.extent_int(idim));
+  if (idim == 0) {
+    return Unmanaged>(
+      Kokkos::subview(v, kp0, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL));
+  } else if (idim == 1) {
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, kp0, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL));
+  } else if (idim == 2) {
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, kp0, Kokkos::ALL, Kokkos::ALL));
+  } else if (idim == 3) {
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL, kp0, Kokkos::ALL));
+  } else {
+    assert(idim == 4);
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL, kp0));
+  }
+}
+
+// --- Rank6 multi-slice --- //
+template
+KOKKOS_INLINE_FUNCTION
+Unmanaged>
+subview(const ViewLR& v,
+        const Kokkos::pair &kp0,
+        const int idim) {
+  assert(v.data() != nullptr);
+  assert(idim >= 0 && idim < v.rank);
+  assert(kp0.first >= 0 && kp0.first < kp0.second
+         && kp0.second <= v.extent_int(idim));
+  if (idim == 0) {
+    return Unmanaged>(
+      Kokkos::subview(v, kp0, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL,
+                      Kokkos::ALL, Kokkos::ALL));
+  } else if (idim == 1) {
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, kp0, Kokkos::ALL, Kokkos::ALL,
+                      Kokkos::ALL, Kokkos::ALL));
+  } else if (idim == 2) {
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, kp0, Kokkos::ALL,
+                      Kokkos::ALL, Kokkos::ALL));
+  } else if (idim == 3) {
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL, kp0,
+                      Kokkos::ALL, Kokkos::ALL));
+  } else if (idim == 4) {
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL,
+                      kp0, Kokkos::ALL));
+  } else {
+    assert(idim == 5);
+    return Unmanaged>(
+      Kokkos::subview(v, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL,
+                      Kokkos::ALL, kp0));
+  }
+}
 } // namespace ekat
 
 #endif // EKAT_SUBVIEW_UTILS_HPP
diff --git a/tests/kokkos/kokkos_utils_tests.cpp b/tests/kokkos/kokkos_utils_tests.cpp
index 1378adbb..bcc4cc1d 100644
--- a/tests/kokkos/kokkos_utils_tests.cpp
+++ b/tests/kokkos/kokkos_utils_tests.cpp
@@ -544,4 +544,237 @@ TEST_CASE("subviews") {
   }
 }
 
+TEST_CASE("multi-slice subviews") {
+  using kt = ekat::KokkosTypes;
+
+  // Indices passed to ekat::subview to build the lower-rank views v5, ..., v1
+  const int i0 = 5;
+  const int i1 = 4;
+  const int i2 = 3;
+  const int i3 = 2;
+  const int i4 = 1;
+  // idx0[d], idx1[d]: begin/end of the half-open range used for the
+  // multi-slice pair p<d>
+  const int idx0[6] = {0, 3, 1, 0, 1, 0};
+  const int idx1[6] = {3, 5, 4, 2, 2, 1};
+
+  auto p0 = Kokkos::make_pair(idx0[0], idx1[0]);
+  auto p1 = Kokkos::make_pair(idx0[1], idx1[1]);
+  auto p2 = Kokkos::make_pair(idx0[2], idx1[2]);
+  auto p3 = Kokkos::make_pair(idx0[3], idx1[3]);
+  auto p4 = Kokkos::make_pair(idx0[4], idx1[4]);
+  auto p5 = Kokkos::make_pair(idx0[5], idx1[5]);
+
+  // Create input view
+  kt::view_ND v6("v6", 7, 6, 5, 4, 3, 2);
+  const int s = v6.size();
+  Kokkos::parallel_for(
+    kt::RangePolicy(0, s), KOKKOS_LAMBDA(int i) { *(v6.data() + i) = i; });
+
+  auto v5 = ekat::subview(v6, i0);
+  auto v4 = ekat::subview(v6, i0, i1);
+  auto v3 = ekat::subview(v6, i0, i1, i2);
+  auto v2 = ekat::subview(v6, i0, i1, i2, i3);
+  auto v1 = ekat::subview(v6, i0, i1, i2, i3, i4);
+
+  SECTION("subview_major") {
+    // Subviews of v6
+    auto v6_5 = ekat::subview(v6, p5, 5);
+    auto v6_4 = ekat::subview(v6, p4, 4);
+    auto v6_3 = ekat::subview(v6, p3, 3);
+    auto v6_2 = ekat::subview(v6, p2, 2);
+    auto v6_1 = ekat::subview(v6, p1, 1);
+    auto v6_0 = ekat::subview(v6, p0, 0);
+
+    // Subviews of v5
+    auto v5_4 = ekat::subview(v5, p5, 4);
+    auto v5_3 = ekat::subview(v5, p4, 3);
+    auto v5_2 = ekat::subview(v5, p3, 2);
+    auto v5_1 = ekat::subview(v5, p2, 1);
+    auto v5_0 = ekat::subview(v5, p1, 0);
+
+    // Subviews of v4
+    auto v4_3 = ekat::subview(v4, p5, 3);
+    auto v4_2 = ekat::subview(v4, p4, 2);
+    auto v4_1 = ekat::subview(v4, p3, 1);
+    auto v4_0 = ekat::subview(v4, p2, 0);
+
+    // Subviews of v3
+    auto v3_2 = ekat::subview(v3, p5, 2);
+    auto v3_1 = ekat::subview(v3, p4, 1);
+    auto v3_0 = ekat::subview(v3, p3, 0);
+
+    // Subviews of v2
+    auto v2_1 = ekat::subview(v2, p5, 1);
+    auto v2_0 = ekat::subview(v2, p4, 0);
+
+    // Subviews of v1
+    auto v1_0 = ekat::subview(v1, p5);
+
+    // Compare with original views and count diffs
+    Kokkos::View diffs("");
+    Kokkos::deep_copy(diffs, 0);
+    Kokkos::parallel_for(
+      kt::RangePolicy(0, 1), KOKKOS_LAMBDA(int) {
+        // Loop bounds for each dimension. The bounds of the first dimension
+        // are named ibeg/iend (not i1/i2) so that they do not shadow the
+        // captured slice indices i1 and i2 used by the final sanity check.
+        int ibeg, iend, j1, j2, k1, k2, l1, l2, m1, m2, n1, n2;
+
+        auto testv6(v6_0);
+        auto testv5(v5_0);
+        auto testv4(v4_0);
+        auto testv3(v3_0);
+        auto testv2(v2_0);
+        auto testv1(v1_0);
+
+        int& ndiffs = diffs();
+        for (int ens = 0; ens < 6; ens++) {
+          ibeg = (ens == 0) ? idx0[0] : 0;
+          iend = (ens == 0) ? idx1[0] : 7;
+          if (ens == 0)
+            testv6 = v6_0;
+          j1 = (ens == 1) ? idx0[1] : 0;
+          j2 = (ens == 1) ? idx1[1] : 6;
+          if (ens == 1)
+            testv6 = v6_1;
+          k1 = (ens == 2) ? idx0[2] : 0;
+          k2 = (ens == 2) ? idx1[2] : 5;
+          if (ens == 2)
+            testv6 = v6_2;
+          l1 = (ens == 3) ? idx0[3] : 0;
+          l2 = (ens == 3) ? idx1[3] : 4;
+          if (ens == 3)
+            testv6 = v6_3;
+          m1 = (ens == 4) ? idx0[4] : 0;
+          m2 = (ens == 4) ? idx1[4] : 3;
+          if (ens == 4)
+            testv6 = v6_4;
+          n1 = (ens == 5) ? idx0[5] : 0;
+          n2 = (ens == 5) ? idx1[5] : 2;
+          if (ens == 5)
+            testv6 = v6_5;
+          for (int n = n1; n < n2; n++)
+            for (int m = m1; m < m2; m++)
+              for (int l = l1; l < l2; l++)
+                for (int k = k1; k < k2; k++)
+                  for (int j = j1; j < j2; j++)
+                    for (int i = ibeg; i < iend; i++) {
+                      if (v6(i, j, k, l, m, n) != testv6(i - ibeg, j - j1,
+                                                         k - k1, l - l1,
+                                                         m - m1, n - n1))
+                        ++ndiffs;
+                    }
+        }
+        for (int ens = 1; ens < 6; ens++) {
+          j1 = (ens == 1) ? idx0[1] : 0;
+          j2 = (ens == 1) ? idx1[1] : 6;
+          if (ens == 1)
+            testv5 = v5_0;
+          k1 = (ens == 2) ? idx0[2] : 0;
+          k2 = (ens == 2) ? idx1[2] : 5;
+          if (ens == 2)
+            testv5 = v5_1;
+          l1 = (ens == 3) ? idx0[3] : 0;
+          l2 = (ens == 3) ? idx1[3] : 4;
+          if (ens == 3)
+            testv5 = v5_2;
+          m1 = (ens == 4) ? idx0[4] : 0;
+          m2 = (ens == 4) ? idx1[4] : 3;
+          if (ens == 4)
+            testv5 = v5_3;
+          n1 = (ens == 5) ? idx0[5] : 0;
+          n2 = (ens == 5) ? idx1[5] : 2;
+          if (ens == 5)
+            testv5 = v5_4;
+          for (int j = j1; j < j2; j++)
+            for (int k = k1; k < k2; k++)
+              for (int l = l1; l < l2; l++)
+                for (int m = m1; m < m2; m++)
+                  for (int n = n1; n < n2; n++) {
+                    if (v5(j, k, l, m, n) !=
+                        testv5(j - j1, k - k1, l - l1, m - m1, n - n1))
+                      ++ndiffs;
+                  }
+        }
+        for (int ens = 2; ens < 6; ens++) {
+          k1 = (ens == 2) ? idx0[2] : 0;
+          k2 = (ens == 2) ? idx1[2] : 5;
+          if (ens == 2)
+            testv4 = v4_0;
+          l1 = (ens == 3) ? idx0[3] : 0;
+          l2 = (ens == 3) ? idx1[3] : 4;
+          if (ens == 3)
+            testv4 = v4_1;
+          m1 = (ens == 4) ? idx0[4] : 0;
+          m2 = (ens == 4) ? idx1[4] : 3;
+          if (ens == 4)
+            testv4 = v4_2;
+          n1 = (ens == 5) ? idx0[5] : 0;
+          n2 = (ens == 5) ? idx1[5] : 2;
+          if (ens == 5)
+            testv4 = v4_3;
+          for (int k = k1; k < k2; k++)
+            for (int l = l1; l < l2; l++)
+              for (int m = m1; m < m2; m++)
+                for (int n = n1; n < n2; n++) {
+                  if (v4(k, l, m, n) !=
+                      testv4(k - k1, l - l1, m - m1, n - n1))
+                    ++ndiffs;
+                }
+        }
+        for (int ens = 3; ens < 6; ens++) {
+          l1 = (ens == 3) ? idx0[3] : 0;
+          l2 = (ens == 3) ? idx1[3] : 4;
+          if (ens == 3)
+            testv3 = v3_0;
+          m1 = (ens == 4) ? idx0[4] : 0;
+          m2 = (ens == 4) ? idx1[4] : 3;
+          if (ens == 4)
+            testv3 = v3_1;
+          n1 = (ens == 5) ? idx0[5] : 0;
+          n2 = (ens == 5) ? idx1[5] : 2;
+          if (ens == 5)
+            testv3 = v3_2;
+          for (int l = l1; l < l2; l++)
+            for (int m = m1; m < m2; m++)
+              for (int n = n1; n < n2; n++) {
+                if (v3(l, m, n) != testv3(l - l1, m - m1, n - n1))
+                  ++ndiffs;
+              }
+        }
+        for (int ens = 4; ens < 6; ens++) {
+          m1 = (ens == 4) ? idx0[4] : 0;
+          m2 = (ens == 4) ? idx1[4] : 3;
+          if (ens == 4)
+            testv2 = v2_0;
+          n1 = (ens == 5) ? idx0[5] : 0;
+          n2 = (ens == 5) ? idx1[5] : 2;
+          if (ens == 5)
+            testv2 = v2_1;
+          for (int m = m1; m < m2; m++)
+            for (int n = n1; n < n2; n++) {
+              if (v2(m, n) != testv2(m - m1, n - n1))
+                ++ndiffs;
+            }
+        }
+        n1 = idx0[5];
+        n2 = idx1[5];
+        testv1 = v1_0;
+        for (int n = n1; n < n2; n++) {
+          if (v1(n) != testv1(n - n1))
+            ++ndiffs;
+        }
+        // Make sure that our diffs counting strategy works
+        // by checking that two entries that should be different
+        // are indeed different. Here i1 and i2 are the captured slice
+        // indices (4 and 3), so the v6 access stays within its extents.
+        if (v1_0(0) != v6(i0, i1, i2, i3, i4, 1))
+          ++ndiffs;
+      });
+    auto diffs_h = Kokkos::create_mirror_view(diffs);
+    Kokkos::deep_copy(diffs_h, diffs);
+    REQUIRE(diffs_h() == 1);
+  }
+}
 } // anonymous namespace