From c429c5266aaeb3e792b50bd3433eb3235f1d922a Mon Sep 17 00:00:00 2001
From: ManasviGoyal
Date: Tue, 4 Jun 2024 14:04:04 +0200
Subject: [PATCH] feat: add awkward_ListArray_getitem_jagged_shrink kernel

---
Note: the new-file hunk for awkward_ListArray_getitem_jagged_shrink.cu was
edited after export, so the blob id on its "index" line is out of date.

 dev/generate-kernel-signatures.py             |   1 +
 dev/generate-tests.py                         |   1 +
 src/awkward/_connect/cuda/__init__.py         |   1 +
 .../awkward_ListArray_getitem_jagged_apply.cu |   2 +-
 ...awkward_ListArray_getitem_jagged_shrink.cu | 135 ++++++++++++++++++
 ...istArray_getitem_next_range_carrylength.cu |   2 -
 6 files changed, 139 insertions(+), 3 deletions(-)
 create mode 100644 src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_jagged_shrink.cu

diff --git a/dev/generate-kernel-signatures.py b/dev/generate-kernel-signatures.py
index 6a61074505..67d1e69a9f 100644
--- a/dev/generate-kernel-signatures.py
+++ b/dev/generate-kernel-signatures.py
@@ -58,6 +58,7 @@
     "awkward_ListArray_getitem_jagged_descend",
     "awkward_ListArray_getitem_jagged_expand",
     "awkward_ListArray_getitem_jagged_numvalid",
+    "awkward_ListArray_getitem_jagged_shrink",
     "awkward_ListArray_getitem_next_array_advanced",
     "awkward_ListArray_getitem_next_array",
     "awkward_ListArray_getitem_next_at",
diff --git a/dev/generate-tests.py b/dev/generate-tests.py
index 6672460c99..e434c2093e 100644
--- a/dev/generate-tests.py
+++ b/dev/generate-tests.py
@@ -843,6 +843,7 @@ def gencpuunittests(specdict):
     "awkward_ListArray_getitem_jagged_descend",
     "awkward_ListArray_getitem_jagged_expand",
     "awkward_ListArray_getitem_jagged_numvalid",
+    "awkward_ListArray_getitem_jagged_shrink",
     "awkward_ListArray_getitem_next_array_advanced",
     "awkward_ListArray_getitem_next_array",
     "awkward_ListArray_getitem_next_at",
diff --git a/src/awkward/_connect/cuda/__init__.py b/src/awkward/_connect/cuda/__init__.py
index a9a590fecc..098d1c4525 100644
--- a/src/awkward/_connect/cuda/__init__.py
+++ b/src/awkward/_connect/cuda/__init__.py
@@ -100,6 +100,7 @@ def fetch_template_specializations(kernel_dict):
         "awkward_ListArray_getitem_jagged_carrylen",
         "awkward_ListArray_getitem_jagged_descend",
         "awkward_ListArray_getitem_jagged_numvalid",
+        "awkward_ListArray_getitem_jagged_shrink",
         "awkward_ListArray_getitem_next_range",
         "awkward_ListArray_getitem_next_range_carrylength",
         "awkward_ListArray_min_range",
diff --git a/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_jagged_apply.cu b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_jagged_apply.cu
index 88a7398352..b3017113b4 100644
--- a/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_jagged_apply.cu
+++ b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_jagged_apply.cu
@@ -104,7 +104,7 @@ awkward_ListArray_getitem_jagged_apply_b(
         RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_GET_LEN)
       }
       int64_t count = stop - start;
-      for (int64_t j = slicestart; j < slicestop; j++) {
+      for (int64_t j = slicestart + threadIdx.y; j < slicestop; j += blockDim.y) {
         int64_t index = (int64_t) sliceindex[j];
         if (index < -count || index > count) {
           RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::IND_OUT_OF_RANGE)
diff --git a/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_jagged_shrink.cu b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_jagged_shrink.cu
new file mode 100644
index 0000000000..477af8e65c
--- /dev/null
+++ b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_jagged_shrink.cu
@@ -0,0 +1,135 @@
+// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
+
+// CUDA implementation of awkward_ListArray_getitem_jagged_shrink, split into two
+// kernel launches with a host-side prefix sum between them.
+//
+// Pass "a" stores, for every list i, the number of entries with missing[j] >= 0
+// and the full length of the list at position i + 1 of zero-initialised scan
+// buffers. The launcher below turns those buffers into running totals with
+// cupy.cumsum, and pass "b" copies the totals into the output offsets and uses
+// them to find where each list starts writing in tocarry.
+//
+// Every scan buffer is written at position thread_id + 1 for thread_id < length,
+// so each one is allocated with length + 1 elements.
+//
+// Indexing: one thread per list along x. Each thread walks its own
+// [slicestarts[i], slicestops[i]) range sequentially, because the running
+// position in tocarry is only meaningful when a single thread visits the whole
+// range in order. The y and z block dimensions take no part in the indexing.
+
+// BEGIN PYTHON
+// def f(grid, block, args):
+//     (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, invocation_index, err_code) = args
+//     scan_in_array_k = cupy.zeros(length + 1, dtype=cupy.int64)
+//     scan_in_array_tosmalloffsets = cupy.zeros(length + 1, dtype=cupy.int64)
+//     scan_in_array_tolargeoffsets = cupy.zeros(length + 1, dtype=cupy.int64)
+//     cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_shrink_a", tocarry.dtype, tosmalloffsets.dtype, tolargeoffsets.dtype, slicestarts.dtype, slicestops.dtype, missing.dtype]))(grid, block, (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, scan_in_array_k, scan_in_array_tosmalloffsets, scan_in_array_tolargeoffsets, invocation_index, err_code))
+//     scan_in_array_k = cupy.cumsum(scan_in_array_k)
+//     scan_in_array_tosmalloffsets = cupy.cumsum(scan_in_array_tosmalloffsets)
+//     scan_in_array_tolargeoffsets = cupy.cumsum(scan_in_array_tolargeoffsets)
+//     cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_shrink_b", tocarry.dtype, tosmalloffsets.dtype, tolargeoffsets.dtype, slicestarts.dtype, slicestops.dtype, missing.dtype]))(grid, block, (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, scan_in_array_k, scan_in_array_tosmalloffsets, scan_in_array_tolargeoffsets, invocation_index, err_code))
+// out["awkward_ListArray_getitem_jagged_shrink_a", {dtype_specializations}] = None
+// out["awkward_ListArray_getitem_jagged_shrink_b", {dtype_specializations}] = None
+// END PYTHON
+
+// Pass "a": record per-list increments for the prefix sums.
+//   scan_in_array_k[i + 1] and scan_in_array_tosmalloffsets[i + 1] receive the
+//     number of j in [slicestarts[i], slicestops[i]) with missing[j] >= 0.
+//   scan_in_array_tolargeoffsets[i + 1] receives slicestops[i] - slicestarts[i].
+//   scan_in_array_tosmalloffsets[0] and scan_in_array_tolargeoffsets[0] receive
+//     slicestarts[0], so both running totals begin at that value.
+// tocarry, tosmalloffsets and tolargeoffsets are not accessed in this pass.
+template <typename T, typename C, typename U, typename V, typename W, typename X>
+__global__ void
+awkward_ListArray_getitem_jagged_shrink_a(
+    T* tocarry,
+    C* tosmalloffsets,
+    U* tolargeoffsets,
+    const V* slicestarts,
+    const W* slicestops,
+    int64_t length,
+    const X* missing,
+    int64_t* scan_in_array_k,
+    int64_t* scan_in_array_tosmalloffsets,
+    int64_t* scan_in_array_tolargeoffsets,
+    uint64_t invocation_index,
+    uint64_t* err_code) {
+  if (err_code[0] == NO_ERROR) {
+    // Widen before multiplying so the product cannot wrap in 32 bits.
+    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+    if (thread_id < length) {
+      if (thread_id == 0) {
+        scan_in_array_tosmalloffsets[0] = slicestarts[0];
+        scan_in_array_tolargeoffsets[0] = slicestarts[0];
+      }
+      int64_t slicestart = (int64_t)slicestarts[thread_id];
+      int64_t slicestop = (int64_t)slicestops[thread_id];
+      // An empty range skips the loop body and records a count of zero.
+      int64_t smallcount = 0;
+      for (int64_t j = slicestart; j < slicestop; j++) {
+        if (missing[j] >= 0) {
+          smallcount++;
+        }
+      }
+      scan_in_array_k[thread_id + 1] = smallcount;
+      scan_in_array_tosmalloffsets[thread_id + 1] = smallcount;
+      scan_in_array_tolargeoffsets[thread_id + 1] = slicestop - slicestart;
+    }
+  }
+}
+
+// Pass "b": consume the prefix sums computed from pass "a".
+//   tosmalloffsets[0] and tolargeoffsets[0] receive slicestarts[0], or 0 when
+//     length == 0.
+//   tosmalloffsets[i + 1] and tolargeoffsets[i + 1] receive entry i + 1 of the
+//     corresponding scanned buffer.
+//   tocarry receives, for list i, every j in [slicestarts[i], slicestops[i])
+//     with missing[j] >= 0, in increasing j, starting at position
+//     scan_in_array_k[i] - scan_in_array_k[0].
+template <typename T, typename C, typename U, typename V, typename W, typename X>
+__global__ void
+awkward_ListArray_getitem_jagged_shrink_b(
+    T* tocarry,
+    C* tosmalloffsets,
+    U* tolargeoffsets,
+    const V* slicestarts,
+    const W* slicestops,
+    int64_t length,
+    const X* missing,
+    int64_t* scan_in_array_k,
+    int64_t* scan_in_array_tosmalloffsets,
+    int64_t* scan_in_array_tolargeoffsets,
+    uint64_t invocation_index,
+    uint64_t* err_code) {
+  if (err_code[0] == NO_ERROR) {
+    // Widen before multiplying so the product cannot wrap in 32 bits.
+    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+    if (thread_id == 0) {
+      // The leading offset belongs to no particular list; one writer suffices.
+      if (length == 0) {
+        tosmalloffsets[0] = 0;
+        tolargeoffsets[0] = 0;
+      }
+      else {
+        tosmalloffsets[0] = slicestarts[0];
+        tolargeoffsets[0] = slicestarts[0];
+      }
+    }
+    if (thread_id < length) {
+      int64_t slicestart = (int64_t)slicestarts[thread_id];
+      int64_t slicestop = (int64_t)slicestops[thread_id];
+      // Kept entries of all earlier lists precede this list's entries in tocarry.
+      int64_t k = scan_in_array_k[thread_id] - scan_in_array_k[0];
+      for (int64_t j = slicestart; j < slicestop; j++) {
+        if (missing[j] >= 0) {
+          tocarry[k] = j;
+          k++;
+        }
+      }
+      // Pass "a" records a zero increment for an empty list, so entry i + 1 of
+      // the scanned buffer already repeats entry i in that case.
+      tosmalloffsets[thread_id + 1] = scan_in_array_tosmalloffsets[thread_id + 1];
+      tolargeoffsets[thread_id + 1] = scan_in_array_tolargeoffsets[thread_id + 1];
+    }
+  }
+}
diff --git a/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_range_carrylength.cu b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_range_carrylength.cu
index dde19b6e8b..a69aecb732 100644
--- a/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_range_carrylength.cu
+++ b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_range_carrylength.cu
@@ -26,8 +26,6 @@ awkward_ListArray_getitem_next_range_carrylength_a(
     uint64_t* err_code) {
   if (err_code[0] == NO_ERROR) {
     int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
-    const int64_t kMaxInt64 = 9223372036854775806; // 2**63 - 2: see below
-    const int64_t kSliceNone = kMaxInt64 + 1; // for Slice::none()
     if (thread_id < lenstarts) {
       int64_t length = fromstops[thread_id] - fromstarts[thread_id];
       int64_t regular_start = start;