Skip to content

Commit

Permalink
feat: add awkward_ListArray_getitem_jagged_shrink kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
ManasviGoyal committed Jun 4, 2024
1 parent 0e76fba commit c429c52
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 3 deletions.
1 change: 1 addition & 0 deletions dev/generate-kernel-signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"awkward_ListArray_getitem_jagged_descend",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_numvalid",
"awkward_ListArray_getitem_jagged_shrink",
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
Expand Down
1 change: 1 addition & 0 deletions dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,7 @@ def gencpuunittests(specdict):
"awkward_ListArray_getitem_jagged_descend",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_numvalid",
"awkward_ListArray_getitem_jagged_shrink",
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
Expand Down
1 change: 1 addition & 0 deletions src/awkward/_connect/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def fetch_template_specializations(kernel_dict):
"awkward_ListArray_getitem_jagged_carrylen",
"awkward_ListArray_getitem_jagged_descend",
"awkward_ListArray_getitem_jagged_numvalid",
"awkward_ListArray_getitem_jagged_shrink",
"awkward_ListArray_getitem_next_range",
"awkward_ListArray_getitem_next_range_carrylength",
"awkward_ListArray_min_range",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ awkward_ListArray_getitem_jagged_apply_b(
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_GET_LEN)
}
int64_t count = stop - start;
for (int64_t j = slicestart; j < slicestop; j++) {
for (int64_t j = slicestart + threadIdx.y; j < slicestop; j += blockDim.y) {
int64_t index = (int64_t) sliceindex[j];
if (index < -count || index > count) {
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::IND_OUT_OF_RANGE)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

// BEGIN PYTHON
// def f(grid, block, args):
// (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, invocation_index, err_code) = args
// if length > 0 and length < int(slicestops[length - 1]):
// len_array = int(slicestops[length - 1])
// else:
// len_array = length
// scan_in_array_k = cupy.zeros(len_array, dtype=cupy.int64)
// scan_in_array_tosmalloffsets = cupy.zeros(length + 1, dtype=cupy.int64)
// scan_in_array_tolargeoffsets = cupy.zeros(length + 1, dtype=cupy.int64)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_shrink_a", tocarry.dtype, tosmalloffsets.dtype, tolargeoffsets.dtype, slicestarts.dtype, slicestops.dtype, missing.dtype]))(grid, block, (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, scan_in_array_k, scan_in_array_tosmalloffsets, scan_in_array_tolargeoffsets, invocation_index, err_code))
// scan_in_array_k = cupy.cumsum(scan_in_array_k)
// scan_in_array_tosmalloffsets = cupy.cumsum(scan_in_array_tosmalloffsets)
// scan_in_array_tolargeoffsets = cupy.cumsum(scan_in_array_tolargeoffsets)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_shrink_b", tocarry.dtype, tosmalloffsets.dtype, tolargeoffsets.dtype, slicestarts.dtype, slicestops.dtype, missing.dtype]))(grid, block, (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, scan_in_array_k, scan_in_array_tosmalloffsets, scan_in_array_tolargeoffsets, invocation_index, err_code))
// out["awkward_ListArray_getitem_jagged_shrink_a", {dtype_specializations}] = None
// out["awkward_ListArray_getitem_jagged_shrink_b", {dtype_specializations}] = None
// END PYTHON

template <typename T, typename C, typename U, typename V, typename W, typename X>
__global__ void
awkward_ListArray_getitem_jagged_shrink_a(
T* tocarry,
C* tosmalloffsets,
U* tolargeoffsets,
const V* slicestarts,
const W* slicestops,
int64_t length,
const X* missing,
int64_t* scan_in_array_k,
int64_t* scan_in_array_tosmalloffsets,
int64_t* scan_in_array_tolargeoffsets,
uint64_t invocation_index,
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id < length) {
if (thread_id == 0) {
scan_in_array_tosmalloffsets[0] = slicestarts[0];
scan_in_array_tolargeoffsets[0] = slicestarts[0];
}
V slicestart = slicestarts[thread_id];
W slicestop = slicestops[thread_id];
if (slicestart != slicestop) {
C smallcount = 0;
for (int64_t j = slicestart + threadIdx.y; j < slicestop; j += blockDim.y) {
if (missing[j] >= 0) {
smallcount++;
}
}
scan_in_array_k[thread_id + 1] = smallcount;
scan_in_array_tosmalloffsets[thread_id + 1] = smallcount;
}
scan_in_array_tolargeoffsets[thread_id + 1] = slicestop - slicestart;
}
}
}

template <typename T, typename C, typename U, typename V, typename W, typename X>
__global__ void
awkward_ListArray_getitem_jagged_shrink_b(
T* tocarry,
C* tosmalloffsets,
U* tolargeoffsets,
const V* slicestarts,
const W* slicestops,
int64_t length,
const X* missing,
int64_t* scan_in_array_k,
int64_t* scan_in_array_tosmalloffsets,
int64_t* scan_in_array_tolargeoffsets,
uint64_t invocation_index,
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (length == 0) {
tosmalloffsets[0] = 0;
tolargeoffsets[0] = 0;
}
else {
tosmalloffsets[0] = slicestarts[0];
tolargeoffsets[0] = slicestarts[0];
}
if (thread_id < length) {
V slicestart = slicestarts[thread_id];
W slicestop = slicestops[thread_id];
int64_t k = scan_in_array_k[thread_id] - scan_in_array_k[0];
if (slicestart != slicestop) {
for (int64_t j = slicestart + threadIdx.y; j < slicestop; j += blockDim.y) {
if (missing[j] >= 0) {
tocarry[k] = j;
k++;
}
}
tosmalloffsets[thread_id + 1] = scan_in_array_tosmalloffsets[thread_id + 1];
}
else {
tosmalloffsets[thread_id + 1] = scan_in_array_tosmalloffsets[thread_id];
}
tolargeoffsets[thread_id + 1] = scan_in_array_tolargeoffsets[thread_id + 1];
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ awkward_ListArray_getitem_next_range_carrylength_a(
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t kMaxInt64 = 9223372036854775806; // 2**63 - 2: see below
const int64_t kSliceNone = kMaxInt64 + 1; // for Slice::none()
if (thread_id < lenstarts) {
int64_t length = fromstops[thread_id] - fromstarts[thread_id];
int64_t regular_start = start;
Expand Down

0 comments on commit c429c52

Please sign in to comment.