Skip to content

Commit e9ec98c

Browse files
committed
Moved RowSum operator() and dependencies back to jaccard.cpp
1 parent 3d911f6 commit e9ec98c

File tree

4 files changed

+165
-158
lines changed

4 files changed

+165
-158
lines changed

jaccard.cpp

+117
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,95 @@
4747
#define EMULATE_ATOMIC_ADD_DOUBLE
4848
#endif
4949

50+
// From utilities/graph_utils.cuh
// Work-group cooperative gather-and-sum: returns (to every work-item) the total
//   sum_{k=0}^{n-1} w[ind[ind_off + k]]
// computed as a parallel prefix sum across the X dimension (dim 1) of the 2D
// nd_item, using the local-memory buffer shfl_temp in place of CUDA __shfl
// intrinsics. Each Y row (dim 0) scans independently; shfl_temp must hold one
// value_t per work-item in the group.
// FIXME Revisit the barriers and fences and local storage with subgroups
// FIXME revisit with SYCL group algorithms
template <typename count_t, typename index_t, typename value_t>
__inline__ value_t
parallel_prefix_sum(cl::sycl::nd_item<2> const &tid_info, count_t n,
                    cl::sycl::accessor<index_t, 1, cl::sycl::access::mode::read> ind,
                    count_t ind_off, cl::sycl::accessor<value_t, 1, cl::sycl::access::mode::read> w,
                    cl::sycl::accessor<value_t, 1, cl::sycl::access::mode::read_write,
                                       cl::sycl::access::target::local>
                        shfl_temp) {
  count_t i, j, mn;
  value_t v, last; // v: neighbor's partial read during the scan; last: carry from previous chunk
  value_t sum = 0.0;
  bool valid;

  // Parallel prefix sum (using __shfl)
  // Round n up to a multiple of the X-dimension local range so that every
  // work-item runs the same number of loop iterations and reaches every
  // group_barrier below.
  mn = (((n + tid_info.get_local_range(1) - 1) / tid_info.get_local_range(1)) *
        tid_info.get_local_range(1)); // n in multiple of blockDim.x
  for (i = tid_info.get_local_id(1); i < mn; i += tid_info.get_local_range(1)) {
    // All threads (especially the last one) must always participate
    // in the shfl instruction, otherwise their sum will be undefined.
    // So, the loop stopping condition is based on multiple of n in loop increments,
    // so that all threads enter into the loop and inside we make sure we do not
    // read out of bounds memory checking for the actual size n.

    // check if the thread is valid (i.e. has a real element to read this pass)
    valid = i < n;

    // Notice that the last thread is used to propagate the prefix sum.
    // For all the threads, in the first iteration the last is 0, in the following
    // iterations it is the value at the last thread of the previous iterations.

    // get the value of the last thread
    // FIXME: __shfl_sync
    // FIXME make sure everybody is here
    group_barrier(tid_info.get_group());
    // write your current sum
    // This is a 2D block, use a linear ID
    shfl_temp[tid_info.get_local_linear_id()] = sum;
    // FIXME make sure everybody has read from the top thread in the same Y-dimensional subgroup
    group_barrier(tid_info.get_group());
    // Read the running total produced by the last X lane of this work-item's
    // own Y row (local_range(1) * local_id(0) offsets into that row).
    last = shfl_temp[tid_info.get_local_range(1) - 1 +
                     (tid_info.get_local_range(1) * tid_info.get_local_id(0))];
    // Move forward
    // last = __shfl_sync(warp_full_mask(), sum, blockDim.x - 1, blockDim.x);

    // if you are valid read the value from memory, otherwise set your value to 0
    sum = (valid) ? w[ind[ind_off + i]] : 0.0;

    // do prefix sum (of size warpSize=blockDim.x =< 32)
    // Stride-doubling inclusive scan: each round adds the partial from j lanes back.
    for (j = 1; j < tid_info.get_local_range(1); j *= 2) {
      // FIXME: __shfl_up_warp
      // FIXME make sure everybody is here
      // Write your current sum
      group_barrier(tid_info.get_group());
      shfl_temp[tid_info.get_local_linear_id()] = sum;
      // FIXME Force writes to finish
      // read from tid-j
      // Using the x-dimension local id for the conditional protects from overflows to other
      // Y-subgroups Using the local_linear_id for the read saves us having to offset by x_range *
      // y_id
      group_barrier(tid_info.get_group());
      if (tid_info.get_local_id(1) >= j) v = shfl_temp[tid_info.get_local_linear_id() - j];
      // FIXME Force reads to finish
      // v = __shfl_up_sync(warp_full_mask(), sum, j, blockDim.x);
      if (tid_info.get_local_id(1) >= j) sum += v;
    }
    // shift by last (carry in the total from all previous chunks of the row)
    sum += last;
    // notice that no __threadfence or __syncthreads are needed in this implementation
  }
  // get the value of the last thread (to all threads)
  // FIXME: __shfl_sync
  // FIXME make sure everybody is here
  // write your current sum
  // This is a 2D block, use a linear ID
  group_barrier(tid_info.get_group());
  shfl_temp[tid_info.get_local_linear_id()] = sum;
  // FIXME make sure everybody has read from the top thread in the same Y-dimensional group
  group_barrier(tid_info.get_group());
  last = shfl_temp[tid_info.get_local_range(1) - 1 +
                   (tid_info.get_local_range(1) * tid_info.get_local_id(0))];
  // Move forward
  // last = __shfl_sync(warp_full_mask(), sum, blockDim.x - 1, blockDim.x);

  return last;
}
138+
50139
// From RAFT at commit 048063dc08
51140
constexpr inline int warp_size() {
52141
return 32;
@@ -142,6 +231,34 @@ double myAtomicAdd(cl::sycl::atomic<uint64_t> &address, double val) {
142231

143232
namespace sygraph {
144233
namespace detail {
234+
// Volume of neighboors (*weight_s)
235+
template <bool weighted, typename vertex_t, typename edge_t, typename weight_t>
236+
// Must be marked external since main.cpp uses it
237+
extern SYCL_EXTERNAL const void
238+
Jaccard_RowSumKernel<weighted, vertex_t, edge_t, weight_t>::operator()(
239+
cl::sycl::nd_item<2> tid_info) const {
240+
vertex_t row;
241+
edge_t start, end, length;
242+
weight_t sum;
243+
244+
vertex_t row_start = tid_info.get_global_id(0);
245+
vertex_t row_incr = tid_info.get_global_range(0);
246+
for (row = row_start; row < n; row += row_incr) {
247+
start = csrPtr[row];
248+
end = csrPtr[row + 1];
249+
length = end - start;
250+
251+
// compute row sums
252+
// Must be if constexpr so it doesn't try to evaluate v when it's a nullptr_t
253+
if constexpr (weighted) {
254+
sum = parallel_prefix_sum(tid_info, length, csrInd, start, v, shfl_temp);
255+
if (tid_info.get_local_id(1) == 0) work[row] = sum;
256+
} else {
257+
work[row] = static_cast<weight_t>(length);
258+
}
259+
}
260+
}
261+
145262
// Volume of intersections (*weight_i) and cumulative volume of neighbors (*weight_s)
146263
template <bool weighted, typename vertex_t, typename edge_t, typename weight_t>
147264
const void Jaccard_IsKernel<weighted, vertex_t, edge_t, weight_t>::operator()(

jaccard.hpp

+47
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,53 @@ class FillKernel {
5252

5353
namespace sygraph {
5454
namespace detail {
55+
template <bool weighted, typename vertex_t, typename edge_t, typename weight_t>
56+
class Jaccard_RowSumKernel {
57+
vertex_t n;
58+
cl::sycl::accessor<edge_t, 1, cl::sycl::access::mode::read> csrPtr;
59+
cl::sycl::accessor<vertex_t, 1, cl::sycl::access::mode::read> csrInd;
60+
// FIXME, with std::conditional_t we should be able to simplify out some of the code paths in the
61+
// other weight-branching kernels
62+
#ifdef NEEDS_NULL_DEVICE_PTR
63+
std::conditional_t<weighted, cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read>,
64+
cl::sycl::device_ptr<std::nullptr_t>>
65+
v;
66+
#else
67+
std::conditional_t<weighted, cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read>,
68+
std::nullptr_t>
69+
v;
70+
#endif
71+
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::discard_write> work;
72+
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read_write,
73+
cl::sycl::access::target::local>
74+
shfl_temp;
75+
76+
public:
77+
Jaccard_RowSumKernel<true>(
78+
vertex_t n, cl::sycl::accessor<edge_t, 1, cl::sycl::access::mode::read> csrPtr,
79+
cl::sycl::accessor<vertex_t, 1, cl::sycl::access::mode::read> csrInd,
80+
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read> v,
81+
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::discard_write> work,
82+
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read_write,
83+
cl::sycl::access::target::local>
84+
shfl_temp)
85+
: n{n}, csrInd{csrInd}, csrPtr{csrPtr}, v{v}, work{work}, shfl_temp{shfl_temp} {
86+
}
87+
// When not using weights, we don't care about v
88+
Jaccard_RowSumKernel<false>(
89+
vertex_t n, cl::sycl::accessor<edge_t, 1, cl::sycl::access::mode::read> csrPtr,
90+
cl::sycl::accessor<vertex_t, 1, cl::sycl::access::mode::read> csrInd,
91+
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::discard_write> work,
92+
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read_write,
93+
cl::sycl::access::target::local>
94+
shfl_temp)
95+
: n{n}, csrInd{csrInd}, csrPtr{csrPtr}, work{work}, shfl_temp{shfl_temp} {
96+
}
97+
// Volume of neighboors (*weight_s)
98+
// Must be marked external since main.cpp uses it
99+
SYCL_EXTERNAL const void operator()(cl::sycl::nd_item<2> tid_info) const;
100+
};
101+
55102
// Volume of intersections (*weight_i) and cumulative volume of neighbors (*weight_s)
56103
template <bool weighted, typename vertex_t, typename edge_t, typename weight_t>
57104
class Jaccard_IsKernel {

main.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616

1717
#include "filetypes.hpp"
18+
#include "jaccard.hpp"
1819
#include "readMtxToCSR.hpp" //implicitly includes standalone_csr.hpp
1920
#include "standalone_algorithms.hpp"
2021
#include "standalone_csr.hpp"

standalone_algorithms.hpp

-158
Original file line numberDiff line numberDiff line change
@@ -46,94 +46,6 @@
4646
}
4747
#endif // EVENT_PROFILE
4848

49-
// From utilities/graph_utils.cuh
// Work-group cooperative gather-and-sum: returns (to every work-item) the total
//   sum_{k=0}^{n-1} w[ind[ind_off + k]]
// computed as a parallel prefix sum across the X dimension (dim 1) of the 2D
// nd_item, using the local-memory buffer shfl_temp in place of CUDA __shfl
// intrinsics. Each Y row (dim 0) scans independently; shfl_temp must hold one
// value_t per work-item in the group.
// FIXME Revisit the barriers and fences and local storage with subgroups
// FIXME revisit with SYCL group algorithms
template <typename count_t, typename index_t, typename value_t>
__inline__ value_t
parallel_prefix_sum(cl::sycl::nd_item<2> const &tid_info, count_t n,
                    cl::sycl::accessor<index_t, 1, cl::sycl::access::mode::read> ind,
                    count_t ind_off, cl::sycl::accessor<value_t, 1, cl::sycl::access::mode::read> w,
                    cl::sycl::accessor<value_t, 1, cl::sycl::access::mode::read_write,
                                       cl::sycl::access::target::local>
                        shfl_temp) {
  count_t i, j, mn;
  value_t v, last; // v: neighbor's partial read during the scan; last: carry from previous chunk
  value_t sum = 0.0;
  bool valid;

  // Parallel prefix sum (using __shfl)
  // Round n up to a multiple of the X-dimension local range so that every
  // work-item runs the same number of loop iterations and reaches every
  // group_barrier below.
  mn = (((n + tid_info.get_local_range(1) - 1) / tid_info.get_local_range(1)) *
        tid_info.get_local_range(1)); // n in multiple of blockDim.x
  for (i = tid_info.get_local_id(1); i < mn; i += tid_info.get_local_range(1)) {
    // All threads (especially the last one) must always participate
    // in the shfl instruction, otherwise their sum will be undefined.
    // So, the loop stopping condition is based on multiple of n in loop increments,
    // so that all threads enter into the loop and inside we make sure we do not
    // read out of bounds memory checking for the actual size n.

    // check if the thread is valid (i.e. has a real element to read this pass)
    valid = i < n;

    // Notice that the last thread is used to propagate the prefix sum.
    // For all the threads, in the first iteration the last is 0, in the following
    // iterations it is the value at the last thread of the previous iterations.

    // get the value of the last thread
    // FIXME: __shfl_sync
    // FIXME make sure everybody is here
    group_barrier(tid_info.get_group());
    // write your current sum
    // This is a 2D block, use a linear ID
    shfl_temp[tid_info.get_local_linear_id()] = sum;
    // FIXME make sure everybody has read from the top thread in the same Y-dimensional subgroup
    group_barrier(tid_info.get_group());
    // Read the running total produced by the last X lane of this work-item's
    // own Y row (local_range(1) * local_id(0) offsets into that row).
    last = shfl_temp[tid_info.get_local_range(1) - 1 +
                     (tid_info.get_local_range(1) * tid_info.get_local_id(0))];
    // Move forward
    // last = __shfl_sync(warp_full_mask(), sum, blockDim.x - 1, blockDim.x);

    // if you are valid read the value from memory, otherwise set your value to 0
    sum = (valid) ? w[ind[ind_off + i]] : 0.0;

    // do prefix sum (of size warpSize=blockDim.x =< 32)
    // Stride-doubling inclusive scan: each round adds the partial from j lanes back.
    for (j = 1; j < tid_info.get_local_range(1); j *= 2) {
      // FIXME: __shfl_up_warp
      // FIXME make sure everybody is here
      // Write your current sum
      group_barrier(tid_info.get_group());
      shfl_temp[tid_info.get_local_linear_id()] = sum;
      // FIXME Force writes to finish
      // read from tid-j
      // Using the x-dimension local id for the conditional protects from overflows to other
      // Y-subgroups Using the local_linear_id for the read saves us having to offset by x_range *
      // y_id
      group_barrier(tid_info.get_group());
      if (tid_info.get_local_id(1) >= j) v = shfl_temp[tid_info.get_local_linear_id() - j];
      // FIXME Force reads to finish
      // v = __shfl_up_sync(warp_full_mask(), sum, j, blockDim.x);
      if (tid_info.get_local_id(1) >= j) sum += v;
    }
    // shift by last (carry in the total from all previous chunks of the row)
    sum += last;
    // notice that no __threadfence or __syncthreads are needed in this implementation
  }
  // get the value of the last thread (to all threads)
  // FIXME: __shfl_sync
  // FIXME make sure everybody is here
  // write your current sum
  // This is a 2D block, use a linear ID
  group_barrier(tid_info.get_group());
  shfl_temp[tid_info.get_local_linear_id()] = sum;
  // FIXME make sure everybody has read from the top thread in the same Y-dimensional group
  group_barrier(tid_info.get_group());
  last = shfl_temp[tid_info.get_local_range(1) - 1 +
                   (tid_info.get_local_range(1) * tid_info.get_local_id(0))];
  // Move forward
  // last = __shfl_sync(warp_full_mask(), sum, blockDim.x - 1, blockDim.x);

  return last;
}
13749
namespace sygraph {
13850

13951
/**
@@ -220,76 +132,6 @@ template <typename VT, typename ET, typename WT>
220132
void jaccard_list(GraphCSRView<VT, ET, WT> &graph, ET num_pairs, cl::sycl::buffer<VT> &first,
221133
cl::sycl::buffer<VT> &second, cl::sycl::buffer<WT> &result, cl::sycl::queue &q);
222134

223-
namespace detail {
224-
template <bool weighted, typename vertex_t, typename edge_t, typename weight_t>
225-
class Jaccard_RowSumKernel {
226-
vertex_t n;
227-
cl::sycl::accessor<edge_t, 1, cl::sycl::access::mode::read> csrPtr;
228-
cl::sycl::accessor<vertex_t, 1, cl::sycl::access::mode::read> csrInd;
229-
// FIXME, with std::conditional_t we should be able to simplify out some of the code paths in the
230-
// other weight-branching kernels
231-
#ifdef NEEDS_NULL_DEVICE_PTR
232-
std::conditional_t<weighted, cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read>,
233-
cl::sycl::device_ptr<std::nullptr_t>>
234-
v;
235-
#else
236-
std::conditional_t<weighted, cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read>,
237-
std::nullptr_t>
238-
v;
239-
#endif
240-
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::discard_write> work;
241-
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read_write,
242-
cl::sycl::access::target::local>
243-
shfl_temp;
244-
245-
public:
246-
Jaccard_RowSumKernel<true>(
247-
vertex_t n, cl::sycl::accessor<edge_t, 1, cl::sycl::access::mode::read> csrPtr,
248-
cl::sycl::accessor<vertex_t, 1, cl::sycl::access::mode::read> csrInd,
249-
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read> v,
250-
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::discard_write> work,
251-
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read_write,
252-
cl::sycl::access::target::local>
253-
shfl_temp)
254-
: n{n}, csrInd{csrInd}, csrPtr{csrPtr}, v{v}, work{work}, shfl_temp{shfl_temp} {
255-
}
256-
// When not using weights, we don't care about v
257-
Jaccard_RowSumKernel<false>(
258-
vertex_t n, cl::sycl::accessor<edge_t, 1, cl::sycl::access::mode::read> csrPtr,
259-
cl::sycl::accessor<vertex_t, 1, cl::sycl::access::mode::read> csrInd,
260-
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::discard_write> work,
261-
cl::sycl::accessor<weight_t, 1, cl::sycl::access::mode::read_write,
262-
cl::sycl::access::target::local>
263-
shfl_temp)
264-
: n{n}, csrInd{csrInd}, csrPtr{csrPtr}, work{work}, shfl_temp{shfl_temp} {
265-
}
266-
// Volume of neighboors (*weight_s)
267-
const void
268-
operator()(cl::sycl::nd_item<2> tid_info) const {
269-
vertex_t row;
270-
edge_t start, end, length;
271-
weight_t sum;
272-
273-
vertex_t row_start = tid_info.get_global_id(0);
274-
vertex_t row_incr = tid_info.get_global_range(0);
275-
for (row = row_start; row < n; row += row_incr) {
276-
start = csrPtr[row];
277-
end = csrPtr[row + 1];
278-
length = end - start;
279-
280-
// compute row sums
281-
// Must be if constexpr so it doesn't try to evaluate v when it's a nullptr_t
282-
if constexpr (weighted) {
283-
sum = parallel_prefix_sum(tid_info, length, csrInd, start, v, shfl_temp);
284-
if (tid_info.get_local_id(1) == 0) work[row] = sum;
285-
} else {
286-
work[row] = static_cast<weight_t>(length);
287-
}
288-
}
289-
}
290-
};
291-
292-
} // namespace detail
293135
} // namespace sygraph
294136

295137
#endif

0 commit comments

Comments
 (0)