From 32a507e307e6757164990e1a727b976780df75f9 Mon Sep 17 00:00:00 2001 From: Ianna Osborne Date: Mon, 7 Oct 2024 17:17:29 +0200 Subject: [PATCH] make sure it's a cupy zero dim array --- .../cuda_kernels/awkward_ListArray_getitem_next_at.cu | 7 +++++++ .../cuda_kernels/awkward_RegularArray_getitem_next_at.cu | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu index 7fc13f2ae8..1fa5333af5 100644 --- a/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu +++ b/src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_at.cu @@ -1,5 +1,12 @@ // BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE +// BEGIN PYTHON +// def f(grid, block, args): +// (tocarry, fromstarts, fromstops, lenstarts, at, invocation_index, err_code) = args +// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_next_at", tocarry.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tocarry, fromstarts, fromstops, lenstarts, cupy.array(at), invocation_index, err_code)) +// out["awkward_ListArray_getitem_next_at", {dtype_specializations}] = None +// END PYTHON + enum class LISTARRAY_GETITEM_NEXT_AT_ERRORS { IND_OUT_OF_RANGE, // message: "index out of range" }; diff --git a/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu b/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu index 1b8bd53b38..ddb8cd94ab 100644 --- a/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu +++ b/src/awkward/_connect/cuda/cuda_kernels/awkward_RegularArray_getitem_next_at.cu @@ -1,5 +1,13 @@ // BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE + +// BEGIN PYTHON +// def f(grid, block, args): +// (tocarry, at, length, size, invocation_index, err_code) = args +// cuda_kernel_templates.get_function(fetch_specialization(["awkward_RegularArray_getitem_next_at", tocarry.dtype]))(grid, block, (tocarry, cupy.array(at), length, size, invocation_index, err_code)) +// out["awkward_RegularArray_getitem_next_at", {dtype_specializations}] = None +// END PYTHON + enum class REGULARARRAY_GETITEM_NEXT_AT_ERRORS { IND_OUT_OF_RANGE // message: "index out of range" };