 }
 #endif // EVENT_PROFILE

-// From utilties/graph_utils.cuh
-// FIXME Revisit the barriers and fences and local storage with subgroups
-// FIXME revisit with SYCL group algorithms
-template <typename count_t, typename index_t, typename value_t>
-__inline__ value_t
-parallel_prefix_sum(cl::sycl::nd_item<2> const &tid_info, count_t n,
-                    cl::sycl::accessor<index_t, 1, cl::sycl::access::mode::read> ind,
-                    count_t ind_off, cl::sycl::accessor<value_t, 1, cl::sycl::access::mode::read> w,
-                    cl::sycl::accessor<value_t, 1, cl::sycl::access::mode::read_write,
-                                       cl::sycl::access::target::local>
-                        shfl_temp) {
-  count_t i, j, mn;
-  value_t v, last;
-  value_t sum = 0.0;
-  bool valid;
-
-  // Parallel prefix sum (using __shfl)
-  mn = (((n + tid_info.get_local_range(1) - 1) / tid_info.get_local_range(1)) *
-        tid_info.get_local_range(1)); // n in multiple of blockDim.x
-  for (i = tid_info.get_local_id(1); i < mn; i += tid_info.get_local_range(1)) {
-    // All threads (especially the last one) must always participate
-    // in the shfl instruction, otherwise their sum will be undefined.
-    // So, the loop stopping condition is based on multiple of n in loop increments,
-    // so that all threads enter into the loop and inside we make sure we do not
-    // read out of bounds memory checking for the actual size n.
-
-    // check if the thread is valid
-    valid = i < n;
-
-    // Notice that the last thread is used to propagate the prefix sum.
-    // For all the threads, in the first iteration the last is 0, in the following
-    // iterations it is the value at the last thread of the previous iterations.
-
-    // get the value of the last thread
-    // FIXME: __shfl_sync
-    // FIXME make sure everybody is here
-    group_barrier(tid_info.get_group());
-    // write your current sum
-    // This is a 2D block, use a linear ID
-    shfl_temp[tid_info.get_local_linear_id()] = sum;
-    // FIXME make sure everybody has read from the top thread in the same Y-dimensional subgroup
-    group_barrier(tid_info.get_group());
-    last = shfl_temp[tid_info.get_local_range(1) - 1 +
-                     (tid_info.get_local_range(1) * tid_info.get_local_id(0))];
-    // Move forward
-    // last = __shfl_sync(warp_full_mask(), sum, blockDim.x - 1, blockDim.x);
-
-    // if you are valid read the value from memory, otherwise set your value to 0
-    sum = (valid) ? w[ind[ind_off + i]] : 0.0;
-
-    // do prefix sum (of size warpSize=blockDim.x =< 32)
-    for (j = 1; j < tid_info.get_local_range(1); j *= 2) {
-      // FIXME: __shfl_up_warp
-      // FIXME make sure everybody is here
-      // Write your current sum
-      group_barrier(tid_info.get_group());
-      shfl_temp[tid_info.get_local_linear_id()] = sum;
-      // FIXME Force writes to finish
-      // read from tid-j
-      // Using the x-dimension local id for the conditional protects from overflows to other
-      // Y-subgroups Using the local_linear_id for the read saves us having to offset by x_range *
-      // y_id
-      group_barrier(tid_info.get_group());
-      if (tid_info.get_local_id(1) >= j) v = shfl_temp[tid_info.get_local_linear_id() - j];
-      // FIXME Force reads to finish
-      // v = __shfl_up_sync(warp_full_mask(), sum, j, blockDim.x);
-      if (tid_info.get_local_id(1) >= j) sum += v;
-    }
-    // shift by last
-    sum += last;
-    // notice that no __threadfence or __syncthreads are needed in this implementation
-  }
-  // get the value of the last thread (to all threads)
-  // FIXME: __shfl_sync
-  // FIXME make sure everybody is here
-  // write your current sum
-  // This is a 2D block, use a linear ID
-  group_barrier(tid_info.get_group());
-  shfl_temp[tid_info.get_local_linear_id()] = sum;
-  // FIXME make sure everybody has read from the top thread in the same Y-dimensional group
-  group_barrier(tid_info.get_group());
-  last = shfl_temp[tid_info.get_local_range(1) - 1 +
-                   (tid_info.get_local_range(1) * tid_info.get_local_id(0))];
-  // Move forward
-  // last = __shfl_sync(warp_full_mask(), sum, blockDim.x - 1, blockDim.x);
-
-  return last;
-}
 namespace sygraph {

 /**
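Note on the removed `parallel_prefix_sum`: it emulates CUDA warp shuffles with local memory plus explicit `group_barrier` calls, and its own FIXMEs say it should be revisited with SYCL group algorithms. Since the caller in this file consumes only the per-row total, a minimal sketch of that direction could use `sycl::reduce_over_group` from SYCL 2020 (or `sycl::inclusive_scan_over_group` if per-element prefixes were actually needed). The sketch below is illustrative only and not part of this change; it assumes USM pointers, a 1-D `nd_item`, and the `sycl::` namespace rather than the accessors, 2-D work-groups, and `cl::sycl::` spellings used in this file.

```cpp
#include <sycl/sycl.hpp>

// Hypothetical alternative: compute the row total with a SYCL 2020 group
// reduction instead of the manual local-memory shuffle emulation.
template <typename count_t, typename index_t, typename value_t>
inline value_t row_sum_with_group_reduce(sycl::nd_item<1> item, count_t n,
                                         const index_t *ind, count_t ind_off,
                                         const value_t *w) {
  value_t partial = 0;
  // Strided loop over the row's n entries; out-of-range work-items simply
  // contribute 0, so no padding to a multiple of the work-group size is needed.
  for (count_t i = item.get_local_id(0); i < n; i += item.get_local_range(0)) {
    partial += w[ind[ind_off + i]];
  }
  // Every work-item must reach this call; the total is returned to all of them,
  // which matches the "broadcast the last thread's value" step of the removed code.
  return sycl::reduce_over_group(item.get_group(), partial, sycl::plus<value_t>());
}
```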
@@ -220,76 +132,6 @@ template <typename VT, typename ET, typename WT>
 void jaccard_list(GraphCSRView<VT, ET, WT> &graph, ET num_pairs, cl::sycl::buffer<VT> &first,
                   cl::sycl::buffer<VT> &second, cl::sycl::buffer<WT> &result, cl::sycl::queue &q);

-namespace detail {
-template <bool weighted, typename vertex_t, typename edge_t, typename weight_t>
-class Jaccard_RowSumKernel {
-  vertex_t n;
-  cl::sycl::accessor<edge_t, 1, cl::sycl::access::mode::read> csrPtr;
-  cl::sycl::accessor<vertex_t, 1, cl::sycl::access::mode::read> csrInd;
-  // FIXME, with std::conditional_t we should be able to simplify out some of the code paths in the
-  // other weight-branching kernels
-#ifdef NEEDS_NULL_DEVICE_PTR
-  std::conditional_t<weighted, cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read>,
-                     cl::sycl::device_ptr<std::nullptr_t>>
-      v;
-#else
-  std::conditional_t<weighted, cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read>,
-                     std::nullptr_t>
-      v;
-#endif
-  cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::discard_write> work;
-  cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read_write,
-                     cl::sycl::access::target::local>
-      shfl_temp;
-
-public:
-  Jaccard_RowSumKernel<true>(
-      vertex_t n, cl::sycl::accessor<edge_t, 1, cl::sycl::access::mode::read> csrPtr,
-      cl::sycl::accessor<vertex_t, 1, cl::sycl::access::mode::read> csrInd,
-      cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read> v,
-      cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::discard_write> work,
-      cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read_write,
-                         cl::sycl::access::target::local>
-          shfl_temp)
-      : n{n}, csrInd{csrInd}, csrPtr{csrPtr}, v{v}, work{work}, shfl_temp{shfl_temp} {
-  }
-  // When not using weights, we don't care about v
-  Jaccard_RowSumKernel<false>(
-      vertex_t n, cl::sycl::accessor<edge_t, 1, cl::sycl::access::mode::read> csrPtr,
-      cl::sycl::accessor<vertex_t, 1, cl::sycl::access::mode::read> csrInd,
-      cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::discard_write> work,
-      cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read_write,
-                         cl::sycl::access::target::local>
-          shfl_temp)
-      : n{n}, csrInd{csrInd}, csrPtr{csrPtr}, work{work}, shfl_temp{shfl_temp} {
-  }
-  // Volume of neighboors (*weight_s)
-  const void
-  operator()(cl::sycl::nd_item<2> tid_info) const {
-    vertex_t row;
-    edge_t start, end, length;
-    weight_t sum;
-
-    vertex_t row_start = tid_info.get_global_id(0);
-    vertex_t row_incr = tid_info.get_global_range(0);
-    for (row = row_start; row < n; row += row_incr) {
-      start = csrPtr[row];
-      end = csrPtr[row + 1];
-      length = end - start;
-
-      // compute row sums
-      // Must be if constexpr so it doesn't try to evaluate v when it's a nullptr_t
-      if constexpr (weighted) {
-        sum = parallel_prefix_sum(tid_info, length, csrInd, start, v, shfl_temp);
-        if (tid_info.get_local_id(1) == 0) work[row] = sum;
-      } else {
-        work[row] = static_cast<weight_t>(length);
-      }
-    }
-  }
-};
-
-} // namespace detail
 } // namespace sygraph

 #endif
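Note on the removed `Jaccard_RowSumKernel`: it is a functor meant for a 2-D `nd_range` whose dimension 0 strides over rows and whose dimension 1 cooperates within a row. A hypothetical submission sketch for the unweighted specialization is shown below; the function name `submit_row_sums`, the 4x32 work-group shape, and the global size are illustrative assumptions, not values taken from this commit, and it assumes the unweighted constructor above compiles as intended.

```cpp
#include <CL/sycl.hpp>

// Hypothetical driver for the unweighted Jaccard_RowSumKernel removed above.
// Work-group shape and sizes are illustrative assumptions.
template <typename vertex_t, typename edge_t, typename weight_t>
void submit_row_sums(cl::sycl::queue &q, vertex_t n, cl::sycl::buffer<edge_t> &csrPtr,
                     cl::sycl::buffer<vertex_t> &csrInd, cl::sycl::buffer<weight_t> &work) {
  constexpr size_t y = 4;  // rows handled concurrently per work-group (dimension 0)
  constexpr size_t x = 32; // work-items cooperating on one row (dimension 1)
  q.submit([&](cl::sycl::handler &cgh) {
    auto csrPtr_acc = csrPtr.template get_access<cl::sycl::access::mode::read>(cgh);
    auto csrInd_acc = csrInd.template get_access<cl::sycl::access::mode::read>(cgh);
    auto work_acc = work.template get_access<cl::sycl::access::mode::discard_write>(cgh);
    // Local scratch sized to the full 2-D work-group, mirroring shfl_temp.
    cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read_write,
                       cl::sycl::access::target::local>
        shfl_temp{cl::sycl::range<1>{y * x}, cgh};
    cgh.parallel_for(cl::sycl::nd_range<2>{{y * 64, x}, {y, x}},
                     sygraph::detail::Jaccard_RowSumKernel<false, vertex_t, edge_t, weight_t>{
                         n, csrPtr_acc, csrInd_acc, work_acc, shfl_temp});
  });
}
```

Mapping dimension 1 to the intra-row direction matches the kernel's use of `get_local_id(1)` for lane indexing and `get_global_id(0)`/`get_global_range(0)` for the strided row loop.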