Skip to content

Commit

Permalink
feat: add reduce sum, min and max kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
ManasviGoyal committed Jun 6, 2024
1 parent 02c03bc commit 34fc82b
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 27 deletions.
6 changes: 3 additions & 3 deletions kernel-test-data.json
Original file line number Diff line number Diff line change
Expand Up @@ -23510,7 +23510,7 @@
},
{
"name": "awkward_reduce_max",
"status": false,
"status": true,
"tests": [
{
"error": false,
Expand Down Expand Up @@ -24173,7 +24173,7 @@
},
{
"name": "awkward_reduce_sum",
"status": false,
"status": true,
"tests": [
{
"error": false,
Expand Down Expand Up @@ -25342,7 +25342,7 @@
},
{
"name": "awkward_reduce_min",
"status": false,
"status": true,
"tests": [
{
"error": false,
Expand Down
71 changes: 65 additions & 6 deletions src/awkward/_connect/cuda/cuda_kernels/awkward_reduce_max.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,20 @@
// BEGIN PYTHON
// def f(grid, block, args):
// (toptr, fromptr, parents, lenparents, outlength, identity, invocation_index, err_code) = args
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_max_a", toptr.dtype, fromptr.dtype, parents.dtype]))(grid, block, (toptr, fromptr, parents, lenparents, outlength, identity, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_max_b", toptr.dtype, fromptr.dtype, parents.dtype]))(grid, block, (toptr, fromptr, parents, lenparents, outlength, identity, invocation_index, err_code))
// if block[0] > 0:
// segment = math.floor((outlength + block[0] - 1) / block[0])
// grid_size = math.floor((lenparents + block[0] - 1) / block[0])
// else:
// segment = 0
// grid_size = 1
// partial = cupy.full(outlength * grid_size, identity, dtype=toptr.dtype)
// temp = cupy.zeros(lenparents, dtype=toptr.dtype)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_max_a", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, identity, partial, temp, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_max_b", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, identity, partial, temp, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_max_c", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((segment,), block, (toptr, fromptr, parents, lenparents, outlength, identity, partial, temp, invocation_index, err_code))
// out["awkward_reduce_max_a", {dtype_specializations}] = None
// out["awkward_reduce_max_b", {dtype_specializations}] = None
// out["awkward_reduce_max_c", {dtype_specializations}] = None
// END PYTHON

template <typename T, typename C, typename U>
Expand All @@ -18,10 +28,13 @@ awkward_reduce_max_a(
int64_t lenparents,
int64_t outlength,
T identity,
T* partial,
T* temp,
uint64_t invocation_index,
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_id < outlength) {
toptr[thread_id] = identity;
}
Expand All @@ -37,15 +50,61 @@ awkward_reduce_max_b(
int64_t lenparents,
int64_t outlength,
T identity,
T* partial,
T* temp,
uint64_t invocation_index,
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
int64_t idx = threadIdx.x;
int64_t thread_id = blockIdx.x * blockDim.x + idx;

if (thread_id < lenparents) {
temp[idx] = fromptr[thread_id];
}
__syncthreads();

for (int64_t stride = 1; stride < blockDim.x; stride *= 2) {
T val = identity;
if (idx >= stride && thread_id < lenparents && parents[thread_id] == parents[thread_id - stride]) {
val = temp[idx - stride];
}
__syncthreads();
temp[idx] = val > temp[idx] ? val : temp[idx];
__syncthreads();
}

if (thread_id < lenparents) {
C x = fromptr[thread_id];
toptr[parents[thread_id]] =
(x > toptr[parents[thread_id]] ? x : toptr[parents[thread_id]]);
int64_t parent = parents[thread_id];
if (idx == blockDim.x - 1 || thread_id == lenparents - 1 || parents[thread_id] != parents[thread_id + 1]) {
partial[blockIdx.x * outlength + parent] = temp[idx];
}
}
}
}

template <typename T, typename C, typename U>
__global__ void
awkward_reduce_max_c(
    T* toptr,
    const C* fromptr,
    const U* parents,
    int64_t lenparents,
    int64_t outlength,
    T identity,
    T* partial,
    T* temp,
    uint64_t invocation_index,
    uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // One thread per output slot: fold the per-block partial maxima that the
    // _b kernel stored (laid out as partial[block * outlength + parent]).
    int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < outlength) {
      // Number of blocks the _b kernel was launched with; assumes this kernel
      // uses the same blockDim.x as _b (the Python launcher guarantees this).
      int64_t num_blocks = (lenparents + blockDim.x - 1) / blockDim.x;
      T acc = identity;
      for (int64_t b = 0; b < num_blocks; b++) {
        T candidate = partial[b * outlength + tid];
        // Keep acc only when it strictly dominates; otherwise take the
        // candidate — same select as the original ternary.
        if (!(acc > candidate)) {
          acc = candidate;
        }
      }
      toptr[tid] = acc;
    }
  }
}
71 changes: 65 additions & 6 deletions src/awkward/_connect/cuda/cuda_kernels/awkward_reduce_min.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,20 @@
// BEGIN PYTHON
// def f(grid, block, args):
// (toptr, fromptr, parents, lenparents, outlength, identity, invocation_index, err_code) = args
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_min_a", toptr.dtype, fromptr.dtype, parents.dtype]))(grid, block, (toptr, fromptr, parents, lenparents, outlength, identity, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_min_b", toptr.dtype, fromptr.dtype, parents.dtype]))(grid, block, (toptr, fromptr, parents, lenparents, outlength, identity, invocation_index, err_code))
// if block[0] > 0:
// segment = math.floor((outlength + block[0] - 1) / block[0])
// grid_size = math.floor((lenparents + block[0] - 1) / block[0])
// else:
// segment = 0
// grid_size = 1
// partial = cupy.full(outlength * grid_size, identity, dtype=toptr.dtype)
// temp = cupy.zeros(lenparents, dtype=toptr.dtype)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_min_a", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, identity, partial, temp, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_min_b", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, identity, partial, temp, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_min_c", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((segment,), block, (toptr, fromptr, parents, lenparents, outlength, identity, partial, temp, invocation_index, err_code))
// out["awkward_reduce_min_a", {dtype_specializations}] = None
// out["awkward_reduce_min_b", {dtype_specializations}] = None
// out["awkward_reduce_min_c", {dtype_specializations}] = None
// END PYTHON

template <typename T, typename C, typename U>
Expand All @@ -18,10 +28,13 @@ awkward_reduce_min_a(
int64_t lenparents,
int64_t outlength,
T identity,
T* partial,
T* temp,
uint64_t invocation_index,
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_id < outlength) {
toptr[thread_id] = identity;
}
Expand All @@ -37,15 +50,61 @@ awkward_reduce_min_b(
int64_t lenparents,
int64_t outlength,
T identity,
T* partial,
T* temp,
uint64_t invocation_index,
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
int64_t idx = threadIdx.x;
int64_t thread_id = blockIdx.x * blockDim.x + idx;

if (thread_id < lenparents) {
temp[idx] = fromptr[thread_id];
}
__syncthreads();

for (int64_t stride = 1; stride < blockDim.x; stride *= 2) {
T val = identity;
if (idx >= stride && thread_id < lenparents && parents[thread_id] == parents[thread_id - stride]) {
val = temp[idx - stride];
}
__syncthreads();
temp[idx] = val < temp[idx] ? val : temp[idx];
__syncthreads();
}

if (thread_id < lenparents) {
C x = fromptr[thread_id];
toptr[parents[thread_id]] =
(x < toptr[parents[thread_id]] ? x : toptr[parents[thread_id]]);
int64_t parent = parents[thread_id];
if (idx == blockDim.x - 1 || thread_id == lenparents - 1 || parents[thread_id] != parents[thread_id + 1]) {
partial[blockIdx.x * outlength + parent] = temp[idx];
}
}
}
}

template <typename T, typename C, typename U>
__global__ void
awkward_reduce_min_c(
    T* toptr,
    const C* fromptr,
    const U* parents,
    int64_t lenparents,
    int64_t outlength,
    T identity,
    T* partial,
    T* temp,
    uint64_t invocation_index,
    uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // One thread per output slot: fold the per-block partial minima that the
    // _b kernel stored (laid out as partial[block * outlength + parent]).
    int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < outlength) {
      // Number of blocks the _b kernel was launched with; assumes this kernel
      // uses the same blockDim.x as _b (the Python launcher guarantees this).
      int64_t num_blocks = (lenparents + blockDim.x - 1) / blockDim.x;
      T acc = identity;
      for (int64_t b = 0; b < num_blocks; b++) {
        T candidate = partial[b * outlength + tid];
        // Keep acc only when it is strictly smaller; otherwise take the
        // candidate — same select as the original ternary.
        if (!(acc < candidate)) {
          acc = candidate;
        }
      }
      toptr[tid] = acc;
    }
  }
}
57 changes: 45 additions & 12 deletions src/awkward/_connect/cuda/cuda_kernels/awkward_reduce_sum.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@
// BEGIN PYTHON
// def f(grid, block, args):
// (toptr, fromptr, parents, lenparents, outlength, invocation_index, err_code) = args
// atomicAdd_toptr = cupy.array(toptr, dtype=cupy.uint64)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_sum_a", toptr.dtype, fromptr.dtype, parents.dtype]))(grid, block, (toptr, fromptr, parents, lenparents, outlength, atomicAdd_toptr, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_sum_b", toptr.dtype, fromptr.dtype, parents.dtype]))(grid, block, (toptr, fromptr, parents, lenparents, outlength, atomicAdd_toptr, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_sum_c", toptr.dtype, fromptr.dtype, parents.dtype]))(grid, block, (toptr, fromptr, parents, lenparents, outlength, atomicAdd_toptr, invocation_index, err_code))
// if block[0] > 0:
// segment = math.floor((outlength + block[0] - 1) / block[0])
// grid_size = math.floor((lenparents + block[0] - 1) / block[0])
// else:
// segment = 0
// grid_size = 1
// partial = cupy.zeros(outlength * grid_size, dtype=toptr.dtype)
// temp = cupy.zeros(lenparents, dtype=toptr.dtype)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_sum_a", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, partial, temp, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_sum_b", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, partial, temp, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_sum_c", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((segment,), block, (toptr, fromptr, parents, lenparents, outlength, partial, temp, invocation_index, err_code))
// out["awkward_reduce_sum_a", {dtype_specializations}] = None
// out["awkward_reduce_sum_b", {dtype_specializations}] = None
// out["awkward_reduce_sum_c", {dtype_specializations}] = None
Expand All @@ -20,14 +27,15 @@ awkward_reduce_sum_a(
const U* parents,
int64_t lenparents,
int64_t outlength,
uint64_t* atomicAdd_toptr,
T* partial,
T* temp,
uint64_t invocation_index,
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_id < outlength) {
atomicAdd_toptr[thread_id] = 0;
toptr[thread_id] = 0;
}
}
}
Expand All @@ -40,15 +48,34 @@ awkward_reduce_sum_b(
const U* parents,
int64_t lenparents,
int64_t outlength,
uint64_t* atomicAdd_toptr,
T* partial,
T* temp,
uint64_t invocation_index,
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
int64_t idx = threadIdx.x;
int64_t thread_id = blockIdx.x * blockDim.x + idx;

if (thread_id < lenparents) {
temp[idx] = fromptr[thread_id];
}
__syncthreads();

for (int64_t stride = 1; stride < blockDim.x; stride *= 2) {
T val = 0;
if (idx >= stride && thread_id < lenparents && parents[thread_id] == parents[thread_id - stride]) {
val = temp[idx - stride];
}
__syncthreads();
temp[idx] += val;
__syncthreads();
}

if (thread_id < lenparents) {
atomicAdd(atomicAdd_toptr + parents[thread_id],
(uint64_t)fromptr[thread_id]);
int64_t parent = parents[thread_id];
if (idx == blockDim.x - 1 || thread_id == lenparents - 1 || parents[thread_id] != parents[thread_id + 1]) {
partial[blockIdx.x * outlength + parent] = temp[idx];
}
}
}
}
Expand All @@ -61,14 +88,20 @@ awkward_reduce_sum_c(
const U* parents,
int64_t lenparents,
int64_t outlength,
uint64_t* atomicAdd_toptr,
T* partial,
T* temp,
uint64_t invocation_index,
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_id < outlength) {
toptr[thread_id] = (T)atomicAdd_toptr[thread_id];
T sum = 0;
int64_t blocks = (lenparents + blockDim.x - 1) / blockDim.x;
for (int64_t i = 0; i < blocks; ++i) {
sum += partial[i * outlength + thread_id];
}
toptr[thread_id] = sum;
}
}
}

0 comments on commit 34fc82b

Please sign in to comment.