Skip to content

Commit faa6105

Browse files
committed
Add pitched allocation example
This is the example I wrote for the user question here: kokkos#117. Fixes kokkos#248.
1 parent b31a635 commit faa6105

File tree

3 files changed

+237
-0
lines changed

3 files changed

+237
-0
lines changed

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ add_subdirectory(dot_product)
1616
add_subdirectory(tiled_layout)
1717
add_subdirectory(restrict_accessor)
1818
add_subdirectory(aligned_accessor)
19+
add_subdirectory(pitched_allocation)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
mdspan_add_example(pitched_allocation)
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
#include <experimental/mdspan>
2+
#include <cassert>
3+
#include <cstring>
4+
#include <cstdint>
5+
#include <memory>
6+
7+
// This example shows how to deal with "pitched" allocations. These
8+
// are multidimensional array allocations where the size of each
9+
// element might not necessarily evenly divide the number of bytes per
10+
// "row" of the contiguous dimension. The commented-out example below
11+
// uses cudaMallocPitch to allocate a 4 x 5 two-dimensional array of
12+
// T, where sizeof(T) is 12. Each row (the contiguous dimension) has
13+
// 64 bytes. The last 4 bytes of each row are padding that does not
14+
// belong to any element.
15+
16+
// void* ptr = nullptr;
17+
// size_t pitch = 0;
18+
//
19+
// constexpr size_t num_cols = 5;
20+
// constexpr size_t num_rows = 4;
21+
//
22+
// cudaMallocPitch(&ptr, &pitch, sizeof(T) * num_cols, num_rows);
23+
// extents<int, num_rows, num_cols> exts{};
24+
// layout_stride::mapping mapping{exts, std::array{pitch, sizeof(T)}};
25+
// mdspan m{reinterpret_cast<char*>(ptr), mapping, aligned_byte_accessor<T>{}};
26+
27+
namespace stdex = std::experimental;
28+
29+
// This is the element type. "tbs" stands for Twelve-Byte Struct.
30+
// In this example, the struct includes a mixture of float and int,
31+
// just to make aliasing more interesting.
32+
struct tbs {
  float f0;        // first 4-byte field
  std::int32_t i;  // second 4-byte field
  float f1;        // third 4-byte field
};

// The proxy reference types in this file locate each field by adding the
// sizes of the preceding fields to a raw byte pointer, and the accessors
// step from element to element in units of sizeof(tbs). Both are only
// correct if the compiler inserts no padding, so fail the build instead of
// silently reading the wrong bytes on a platform where that does not hold.
static_assert(sizeof(tbs) == 2 * sizeof(float) + sizeof(std::int32_t),
              "tbs must not contain padding between or after its fields");
37+
38+
// Use of the proxy reference types is only required
39+
// if access to each element is not aligned.
40+
// That should not be the case here.
41+
42+
class const_tbs_proxy;
43+
class nonconst_tbs_proxy;
44+
45+
// Read-only proxy for one field of type T whose bytes start at a given
// address inside a char buffer.
//
// Instances can only be created by const_tbs_proxy (the constructor is
// private and that class is the sole friend). Users observe the field by
// converting the proxy to T.
template<class T>
class const_proxy {
  friend class const_tbs_proxy;

  // Address of the first byte of the field.
  const char* p_ = nullptr;

  constexpr const_proxy(const char* p) : p_(p) {}

public:
  // Returns a copy of the field.
  //
  // The bytes are copied out with std::memcpy rather than read through
  //   *reinterpret_cast<const T*>(p_)
  // because Standard C++ does not permit that dereference when p_ lacks
  // the alignment required to point to a T.
  constexpr operator T () const {
    T value;
    std::memcpy(&value, p_, sizeof(T));
    return value;
  }
};
66+
67+
// Read/write proxy for one field of type T whose bytes start at a given
// address inside a char buffer.
//
// Instances can only be created by nonconst_tbs_proxy (the constructor is
// private and that class is the sole friend). Assigning to the proxy
// writes the field; converting the proxy to T reads it. Both directions
// go through std::memcpy, so the address does not need to be aligned
// for T.
template<class T>
class nonconst_proxy {
private:
  friend class nonconst_tbs_proxy;
  constexpr nonconst_proxy(char* p) : p_(p) {}

public:
  // Copying a proxy yields a second proxy for the same bytes. Declared
  // explicitly because a user-provided copy assignment follows.
  nonconst_proxy(const nonconst_proxy&) = default;

  // Writes `f` into the referenced bytes.
  constexpr nonconst_proxy& operator=(const T& f) {
    std::memcpy(p_, &f, sizeof(T));
    return *this;
  }

  // Proxy-to-proxy assignment copies the referenced VALUE. Without this
  // overload the implicit copy assignment would merely re-point p_, so
  // `a.f0 = b.f0;` would compile but leave the underlying buffer untouched.
  // The source is read into a temporary first, which also makes
  // self-assignment safe.
  constexpr nonconst_proxy& operator=(const nonconst_proxy& other) {
    return *this = static_cast<T>(other);
  }

  // Returns a copy of the referenced field.
  constexpr operator T () const {
    T f{};  // value-initialized; fully overwritten by the memcpy below
    std::memcpy(&f, p_, sizeof(T));
    return f;
  }

private:
  // Address of the first byte of the field.
  char* p_ = nullptr;
};
86+
87+
class nonconst_tbs_proxy {
88+
private:
89+
char* p_ = nullptr;
90+
91+
public:
92+
constexpr nonconst_tbs_proxy(char* p)
93+
: p_(p), f0(p), i(p + sizeof(float)), f1(p + sizeof(float) + sizeof(int))
94+
{}
95+
96+
constexpr nonconst_tbs_proxy& operator=(const tbs& s) {
97+
this->f0 = s.f0;
98+
this->i = s.i;
99+
this->f1 = s.f1;
100+
return *this;
101+
}
102+
103+
constexpr operator tbs() const {
104+
return {float(f0), std::int32_t(i), float(f1)};
105+
};
106+
107+
nonconst_proxy<float> f0;
108+
nonconst_proxy<std::int32_t> i;
109+
nonconst_proxy<float> f1;
110+
};
111+
112+
// tbs is a struct, so users want to access its fields
113+
// with the usual dot notation. The two proxy reference types,
114+
// const_tbs_proxy and nonconst_tbs_proxy, preserve this behavior.
115+
116+
class const_tbs_proxy {
117+
private:
118+
const char* p_ = nullptr;
119+
120+
public:
121+
constexpr const_tbs_proxy(const char* p)
122+
: p_(p), f0(p), i(p + sizeof(float)), f1(p + sizeof(float) + sizeof(int))
123+
{}
124+
125+
constexpr operator tbs() const {
126+
return {float(f0), std::int32_t(i), float(f1)};
127+
};
128+
129+
const_proxy<float> f0;
130+
const_proxy<std::int32_t> i;
131+
const_proxy<float> f1;
132+
};
133+
134+
135+
struct const_tbs_accessor {
136+
using offset_policy = const_tbs_accessor;
137+
138+
using data_handle_type = const char*;
139+
using element_type = const tbs;
140+
// In the const reference case, we can use
141+
// either const_tbs_proxy or tbs (a value).
142+
//using reference = const_tbs_proxy;
143+
using reference = tbs;
144+
145+
constexpr const_tbs_accessor() noexcept = default;
146+
147+
constexpr reference
148+
access(data_handle_type p, size_t i) const noexcept {
149+
//return {p + i * sizeof(tbs)}; // for const_tbs_proxy
150+
tbs t;
151+
std::memcpy(&t, p + i * sizeof(tbs), sizeof(tbs));
152+
return t;
153+
}
154+
155+
constexpr typename offset_policy::data_handle_type
156+
offset(data_handle_type p, size_t i) const noexcept {
157+
return p + i * sizeof(tbs);
158+
}
159+
};
160+
161+
struct nonconst_tbs_accessor {
162+
using offset_policy = nonconst_tbs_accessor;
163+
164+
using data_handle_type = char*;
165+
using element_type = tbs;
166+
using reference = nonconst_tbs_proxy;
167+
168+
constexpr nonconst_tbs_accessor() noexcept = default;
169+
170+
constexpr reference
171+
access(data_handle_type p, size_t i) const noexcept {
172+
return {p + i * sizeof(tbs)};
173+
}
174+
175+
constexpr typename offset_policy::data_handle_type
176+
offset(data_handle_type p, size_t i) const noexcept {
177+
return p + i * sizeof(tbs);
178+
}
179+
};
180+
181+
int main() {
182+
constexpr std::size_t num_elements = 5;
183+
184+
std::array<char, num_elements * sizeof(tbs)> data;
185+
auto* ptr = reinterpret_cast<tbs*>(data.data());
186+
187+
std::uninitialized_fill_n(ptr, num_elements, tbs{1.0, 2, 3.0});
188+
189+
for(std::size_t k = 0; k < num_elements; ++k) {
190+
assert(ptr[k].f0 == 1.0);
191+
assert(ptr[k].i == 2);
192+
assert(ptr[k].f1 == 3.0);
193+
}
194+
195+
const tbs* ptr_c = ptr;
196+
stdex::mdspan<const tbs, stdex::extents<int, num_elements>,
197+
stdex::layout_right, const_tbs_accessor> m{data.data()};
198+
for(std::size_t k = 0; k < num_elements; ++k) {
199+
assert(m[k].f0 == 1.0f);
200+
assert(m[k].i == 2);
201+
assert(m[k].f1 == 3.0f);
202+
}
203+
204+
stdex::mdspan<tbs, stdex::extents<int, num_elements>,
205+
stdex::layout_right, nonconst_tbs_accessor> m_nc{data.data()};
206+
for(std::size_t k = 0; k < num_elements; ++k) {
207+
m_nc[k].f0 = 4.0f;
208+
m_nc[k].i = 5;
209+
m_nc[k].f1 = 6.0f;
210+
}
211+
212+
for(std::size_t k = 0; k < num_elements; ++k) {
213+
// Be careful returning a proxy reference from a function via auto.
214+
auto m_k = m[k];
215+
assert(m_k.f0 == 4.0f);
216+
assert(m_k.i == 5);
217+
assert(m_k.f1 == 6.0f);
218+
}
219+
220+
for(std::size_t k = 0; k < num_elements; ++k) {
221+
auto m_nc_k = m_nc[k];
222+
m_nc_k.f0 = 7.0f;
223+
m_nc_k.i = 8;
224+
m_nc_k.f1 = 9.0f;
225+
}
226+
227+
for(std::size_t k = 0; k < num_elements; ++k) {
228+
auto m_k = m[k];
229+
assert(m_k.f0 == 7.0f);
230+
assert(m_k.i == 8);
231+
assert(m_k.f1 == 9.0f);
232+
}
233+
234+
return 0;
235+
}

0 commit comments

Comments
 (0)