Skip to content

Commit

Permalink
feat: add slicing CUDA kernels (#3140)
Browse files Browse the repository at this point in the history
* feat: add awkward_ListArray_getitem_jagged_apply kernel

* fix: remove print statements

* feat: add awkward_ListArray_getitem_jagged_shrink kernel

* test: cuda integration tests

* test: more slicing integration tests

* fix: ndarray error for cupy array shape

* fix: remove unused variable
  • Loading branch information
ManasviGoyal authored Jun 5, 2024
1 parent 0b9f6f4 commit 81085be
Show file tree
Hide file tree
Showing 9 changed files with 2,433 additions and 3 deletions.
2 changes: 2 additions & 0 deletions dev/generate-kernel-signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@
"awkward_missing_repeat",
"awkward_RegularArray_getitem_jagged_expand",
"awkward_ListArray_combinations_length",
"awkward_ListArray_getitem_jagged_apply",
"awkward_ListArray_getitem_jagged_carrylen",
"awkward_ListArray_getitem_jagged_descend",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_numvalid",
"awkward_ListArray_getitem_jagged_shrink",
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
Expand Down
2 changes: 2 additions & 0 deletions dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,10 +838,12 @@ def gencpuunittests(specdict):
"awkward_missing_repeat",
"awkward_RegularArray_getitem_jagged_expand",
"awkward_ListArray_combinations_length",
"awkward_ListArray_getitem_jagged_apply",
"awkward_ListArray_getitem_jagged_carrylen",
"awkward_ListArray_getitem_jagged_descend",
"awkward_ListArray_getitem_jagged_expand",
"awkward_ListArray_getitem_jagged_numvalid",
"awkward_ListArray_getitem_jagged_shrink",
"awkward_ListArray_getitem_next_array_advanced",
"awkward_ListArray_getitem_next_array",
"awkward_ListArray_getitem_next_at",
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_connect/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ def fetch_template_specializations(kernel_dict):
"awkward_ListArray_broadcast_tooffsets",
"awkward_ListArray_combinations_length",
"awkward_ListArray_compact_offsets",
"awkward_ListArray_getitem_jagged_apply",
"awkward_ListArray_getitem_jagged_carrylen",
"awkward_ListArray_getitem_jagged_descend",
"awkward_ListArray_getitem_jagged_numvalid",
"awkward_ListArray_getitem_jagged_shrink",
"awkward_ListArray_getitem_next_range",
"awkward_ListArray_getitem_next_range_carrylength",
"awkward_ListArray_min_range",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

// BEGIN PYTHON
// def f(grid, block, args):
// (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, invocation_index, err_code) = args
// scan_in_array = cupy.zeros(sliceouterlen + 1, dtype=cupy.int64)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_apply_a", tooffsets.dtype, tocarry.dtype, slicestarts.dtype, slicestops.dtype, sliceindex.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, scan_in_array, invocation_index, err_code))
// scan_in_array = cupy.cumsum(scan_in_array)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_apply_b", tooffsets.dtype, tocarry.dtype, slicestarts.dtype, slicestops.dtype, sliceindex.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, scan_in_array, invocation_index, err_code))
// out["awkward_ListArray_getitem_jagged_apply_a", {dtype_specializations}] = None
// out["awkward_ListArray_getitem_jagged_apply_b", {dtype_specializations}] = None
// END PYTHON

// Failure codes handed to RAISE_ERROR by the two
// awkward_ListArray_getitem_jagged_apply kernels (_a and _b) below.
// NOTE(review): the quoted text trailing each enumerator looks like it is
// consumed by tooling outside this file -- presumably to build the
// user-facing error strings. Keep those trailers verbatim and confirm with
// the kernel loader before reformatting them.
enum class LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS {
JAG_STOP_LT_START, // message: "jagged slice's stops[i] < starts[i]"
OFF_GET_CON, // message: "jagged slice's offsets extend beyond its content"
STOP_LT_START, // message: "stops[i] < starts[i]"
STOP_GET_LEN, // message: "stops[i] > len(content)"
IND_OUT_OF_RANGE, // message: "index out of range"
};

// First pass of the jagged-apply operation.
//
// One x-thread handles one outer slice entry (thread_id < sliceouterlen). For
// every non-empty slice it validates the slice bounds and the matching
// fromstarts/fromstops pair, then stores the slice length in
// scan_in_array[thread_id + 1]. The launcher in the BEGIN PYTHON header runs
// cupy.cumsum over scan_in_array between this kernel and the _b kernel.
//
// tooffsets, tocarry and sliceindex are not touched by this pass; they are
// part of the signature shared with the _b kernel.
//
// NOTE(review): RAISE_ERROR and NO_ERROR are defined outside this file;
// presumably RAISE_ERROR records the code in err_code using
// invocation_index -- confirm. Execution order of the checks is unchanged.
template <typename T, typename C, typename U, typename V, typename W, typename X, typename Y>
__global__ void
awkward_ListArray_getitem_jagged_apply_a(
    T* tooffsets,
    C* tocarry,
    const U* slicestarts,
    const V* slicestops,
    int64_t sliceouterlen,
    const W* sliceindex,
    int64_t sliceinnerlen,
    const X* fromstarts,
    const Y* fromstops,
    int64_t contentlen,
    int64_t* scan_in_array,
    uint64_t invocation_index,
    uint64_t* err_code) {
  // An earlier failure is already recorded: do nothing.
  if (err_code[0] != NO_ERROR) {
    return;
  }

  int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

  // Every thread stores the leading zero of the scan input.
  scan_in_array[0] = 0;

  // Threads past the end of the outer slice have no work.
  if (thread_id >= sliceouterlen) {
    return;
  }

  U jag_begin = slicestarts[thread_id];
  V jag_end = slicestops[thread_id];

  // Empty slices contribute nothing and are not validated.
  if (jag_begin == jag_end) {
    return;
  }

  if (jag_end < jag_begin) {
    RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::JAG_STOP_LT_START)
  }
  if (jag_end > sliceinnerlen) {
    RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::OFF_GET_CON)
  }

  int64_t list_begin = (int64_t)fromstarts[thread_id];
  int64_t list_end = (int64_t)fromstops[thread_id];
  if (list_end < list_begin) {
    RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_LT_START)
  }
  if (list_begin != list_end && list_end > contentlen) {
    RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_GET_LEN)
  }

  // Per-entry length; turned into offsets by the cumulative sum on the host side.
  scan_in_array[thread_id + 1] = jag_end - jag_begin;
}

// Second pass of the jagged-apply operation.
//
// Expects scan_in_array to hold the cumulative sum produced by the launcher
// in the BEGIN PYTHON header after the _a kernel ran. One x-thread handles
// one outer slice entry: it copies its offset into tooffsets[thread_id] and,
// for a non-empty slice, re-runs the bounds checks and writes one tocarry
// element per sliceindex entry. Entries of a slice are spread over the block's
// y-dimension (start at threadIdx.y, step blockDim.y); each entry's
// destination is derived from j, so y-threads write disjoint positions.
//
// Negative indices are shifted by the list length (stop - start) before being
// added to start. Every thread that entered without a recorded error also
// stores the final offset tooffsets[sliceouterlen].
//
// NOTE(review): RAISE_ERROR and NO_ERROR are defined outside this file;
// presumably RAISE_ERROR records the code in err_code using
// invocation_index -- confirm. Execution order of the checks is unchanged.
template <typename T, typename C, typename U, typename V, typename W, typename X, typename Y>
__global__ void
awkward_ListArray_getitem_jagged_apply_b(
    T* tooffsets,
    C* tocarry,
    const U* slicestarts,
    const V* slicestops,
    int64_t sliceouterlen,
    const W* sliceindex,
    int64_t sliceinnerlen,
    const X* fromstarts,
    const Y* fromstops,
    int64_t contentlen,
    int64_t* scan_in_array,
    uint64_t invocation_index,
    uint64_t* err_code) {
  // An earlier failure is already recorded: do nothing.
  if (err_code[0] != NO_ERROR) {
    return;
  }

  int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

  if (thread_id < sliceouterlen) {
    U jag_begin = slicestarts[thread_id];
    V jag_end = slicestops[thread_id];

    // Offset of this entry's first output element.
    int64_t out_base = scan_in_array[thread_id];
    tooffsets[thread_id] = (T)(out_base);

    if (jag_begin != jag_end) {
      if (jag_end < jag_begin) {
        RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::JAG_STOP_LT_START)
      }
      if (jag_end > sliceinnerlen) {
        RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::OFF_GET_CON)
      }

      int64_t list_begin = (int64_t)fromstarts[thread_id];
      int64_t list_end = (int64_t)fromstops[thread_id];
      if (list_end < list_begin) {
        RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_LT_START)
      }
      if (list_begin != list_end && list_end > contentlen) {
        RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_GET_LEN)
      }

      int64_t list_len = list_end - list_begin;
      for (int64_t j = jag_begin + threadIdx.y; j < jag_end; j += blockDim.y) {
        int64_t raw = (int64_t) sliceindex[j];
        if (raw < -list_len || raw > list_len) {
          RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::IND_OUT_OF_RANGE)
        }
        // Negative indices count from the end of the list.
        int64_t resolved = (raw < 0) ? raw + list_len : raw;
        tocarry[out_base + j - jag_begin] = list_begin + resolved;
      }
    }
  }

  // Closing offset: total number of carried elements.
  tooffsets[sliceouterlen] = scan_in_array[sliceouterlen];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

// BEGIN PYTHON
// def f(grid, block, args):
// (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, invocation_index, err_code) = args
// if length > 0 and length < int(slicestops[length - 1]):
// len_array = int(slicestops[length - 1])
// else:
// len_array = length
// scan_in_array_k = cupy.zeros(len_array, dtype=cupy.int64)
// scan_in_array_tosmalloffsets = cupy.zeros(length + 1, dtype=cupy.int64)
// scan_in_array_tolargeoffsets = cupy.zeros(length + 1, dtype=cupy.int64)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_shrink_a", tocarry.dtype, tosmalloffsets.dtype, tolargeoffsets.dtype, slicestarts.dtype, slicestops.dtype, missing.dtype]))(grid, block, (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, scan_in_array_k, scan_in_array_tosmalloffsets, scan_in_array_tolargeoffsets, invocation_index, err_code))
// scan_in_array_k = cupy.cumsum(scan_in_array_k)
// scan_in_array_tosmalloffsets = cupy.cumsum(scan_in_array_tosmalloffsets)
// scan_in_array_tolargeoffsets = cupy.cumsum(scan_in_array_tolargeoffsets)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_shrink_b", tocarry.dtype, tosmalloffsets.dtype, tolargeoffsets.dtype, slicestarts.dtype, slicestops.dtype, missing.dtype]))(grid, block, (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, scan_in_array_k, scan_in_array_tosmalloffsets, scan_in_array_tolargeoffsets, invocation_index, err_code))
// out["awkward_ListArray_getitem_jagged_shrink_a", {dtype_specializations}] = None
// out["awkward_ListArray_getitem_jagged_shrink_b", {dtype_specializations}] = None
// END PYTHON

// First pass of the jagged-shrink operation: per-slice counting.
//
// One x-thread handles one slice (thread_id < length). It counts the entries
// of missing[slicestart .. slicestop) that are >= 0 and stores
//   * that count in scan_in_array_k[thread_id + 1] and
//     scan_in_array_tosmalloffsets[thread_id + 1] (non-empty slices only), and
//   * the full slice length in scan_in_array_tolargeoffsets[thread_id + 1].
// Thread 0 additionally seeds element 0 of both offset scan arrays with
// slicestarts[0]. The launcher in the BEGIN PYTHON header then applies
// cupy.cumsum to all three arrays before the _b kernel runs.
//
// tocarry, tosmalloffsets and tolargeoffsets are not written by this pass;
// they are part of the signature shared with the _b kernel.
//
// Fixes relative to the previous revision:
//   * The count loop used to stride by threadIdx.y / blockDim.y, but every
//     y-thread then stored its own partial count to the same slot, so the
//     result was wrong whenever blockDim.y > 1. Each thread now walks its
//     whole slice; with blockDim.y == 1 the result is unchanged.
//   * scan_in_array_k is allocated by the launcher with max(length,
//     slicestops[length - 1]) elements, so index `length` can be one past the
//     end. The _b kernel only reads scan_in_array_k[thread_id] for
//     thread_id < length, so the last thread's store is skipped.
//   * The thread index is widened to 64 bits before the multiply.
//
// NOTE(review): NO_ERROR is defined outside this file.
template <typename T, typename C, typename U, typename V, typename W, typename X>
__global__ void
awkward_ListArray_getitem_jagged_shrink_a(
    T* tooffsets_unused_name_guard = nullptr,
    C* tosmalloffsets = nullptr,
    U* tolargeoffsets = nullptr,
    const V* slicestarts = nullptr,
    const W* slicestops = nullptr,
    int64_t length = 0,
    const X* missing = nullptr,
    int64_t* scan_in_array_k = nullptr,
    int64_t* scan_in_array_tosmalloffsets = nullptr,
    int64_t* scan_in_array_tolargeoffsets = nullptr,
    uint64_t invocation_index = 0,
    uint64_t* err_code = nullptr);

template <typename T, typename C, typename U, typename V, typename W, typename X>
__global__ void
awkward_ListArray_getitem_jagged_shrink_a(
    T* tocarry,
    C* tosmalloffsets,
    U* tolargeoffsets,
    const V* slicestarts,
    const W* slicestops,
    int64_t length,
    const X* missing,
    int64_t* scan_in_array_k,
    int64_t* scan_in_array_tosmalloffsets,
    int64_t* scan_in_array_tolargeoffsets,
    uint64_t invocation_index,
    uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (thread_id < length) {
      if (thread_id == 0) {
        // Both output offset arrays start at slicestarts[0].
        scan_in_array_tosmalloffsets[0] = slicestarts[0];
        scan_in_array_tolargeoffsets[0] = slicestarts[0];
      }
      V slicestart = slicestarts[thread_id];
      W slicestop = slicestops[thread_id];
      if (slicestart != slicestop) {
        // Count the non-missing entries of this slice.
        int64_t smallcount = 0;
        for (int64_t j = slicestart; j < slicestop; j++) {
          if (missing[j] >= 0) {
            smallcount++;
          }
        }
        // Slot `length` is never read back and may not exist.
        if (thread_id + 1 < length) {
          scan_in_array_k[thread_id + 1] = smallcount;
        }
        scan_in_array_tosmalloffsets[thread_id + 1] = smallcount;
      }
      scan_in_array_tolargeoffsets[thread_id + 1] = slicestop - slicestart;
    }
  }
}

// Second pass of the jagged-shrink operation: compaction.
//
// Expects the three scan arrays to hold the cumulative sums produced by the
// launcher in the BEGIN PYTHON header after the _a kernel ran. Every thread
// stores element 0 of tosmalloffsets / tolargeoffsets (0 when length == 0,
// otherwise slicestarts[0]). One x-thread per slice (thread_id < length) then
//   * writes the positions j of all entries with missing[j] >= 0 into tocarry,
//     starting at k = scan_in_array_k[thread_id] - scan_in_array_k[0], and
//   * copies its scanned offsets into tosmalloffsets[thread_id + 1] and
//     tolargeoffsets[thread_id + 1].
//
// Fix relative to the previous revision: the compaction loop used to stride by
// threadIdx.y / blockDim.y while every y-thread started from the same k, so
// with blockDim.y > 1 the writes to tocarry collided and landed at wrong
// positions. Each thread now walks its whole slice in order; with
// blockDim.y == 1 the output is unchanged, and extra y-threads merely repeat
// identical stores. The thread index is also widened to 64 bits before the
// multiply.
//
// NOTE(review): NO_ERROR is defined outside this file.
template <typename T, typename C, typename U, typename V, typename W, typename X>
__global__ void
awkward_ListArray_getitem_jagged_shrink_b(
    T* tocarry,
    C* tosmalloffsets,
    U* tolargeoffsets,
    const V* slicestarts,
    const W* slicestops,
    int64_t length,
    const X* missing,
    int64_t* scan_in_array_k,
    int64_t* scan_in_array_tosmalloffsets,
    int64_t* scan_in_array_tolargeoffsets,
    uint64_t invocation_index,
    uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (length == 0) {
      tosmalloffsets[0] = 0;
      tolargeoffsets[0] = 0;
    }
    else {
      tosmalloffsets[0] = slicestarts[0];
      tolargeoffsets[0] = slicestarts[0];
    }
    if (thread_id < length) {
      V slicestart = slicestarts[thread_id];
      W slicestop = slicestops[thread_id];
      // Number of non-missing entries in all earlier slices.
      int64_t k = scan_in_array_k[thread_id] - scan_in_array_k[0];
      if (slicestart != slicestop) {
        // Sequential on purpose: k depends on every earlier entry of the slice.
        for (int64_t j = slicestart; j < slicestop; j++) {
          if (missing[j] >= 0) {
            tocarry[k] = j;
            k++;
          }
        }
        tosmalloffsets[thread_id + 1] = scan_in_array_tosmalloffsets[thread_id + 1];
      }
      else {
        // Empty slice: offset is unchanged from the previous entry.
        tosmalloffsets[thread_id + 1] = scan_in_array_tosmalloffsets[thread_id];
      }
      tolargeoffsets[thread_id + 1] = scan_in_array_tolargeoffsets[thread_id + 1];
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ awkward_ListArray_getitem_next_range_carrylength_a(
uint64_t* err_code) {
if (err_code[0] == NO_ERROR) {
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t kMaxInt64 = 9223372036854775806; // 2**63 - 2: see below
const int64_t kSliceNone = kMaxInt64 + 1; // for Slice::none()
if (thread_id < lenstarts) {
int64_t length = fromstops[thread_id] - fromstarts[thread_id];
int64_t regular_start = start;
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def _normalise_item_bool_to_int(item: Content, backend: Backend) -> Content:

# outindex fits into the lists; non-missing are sequential
outindex = ak.index.Index64(
item_backend.index_nplike.full(nextoffsets.data[-1], -1, dtype=np.int64)
item_backend.index_nplike.full(nextoffsets[-1], -1, dtype=np.int64)
)
outindex.data[~isnegative[expanded]] = item_backend.index_nplike.arange(
nextcontent.shape[0], dtype=np.int64
Expand Down
Loading

0 comments on commit 81085be

Please sign in to comment.