diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index ac7efeff41aba..bc3b9e8c558e0 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -1964,7 +1964,7 @@ class CumulativeOptions(_CumulativeOptions):
     Parameters
     ----------
     start : Scalar, default None
-        Starting value for the cumulative operation. If none is given, 
+        Starting value for the cumulative operation. If none is given,
         a default value depending on the operation and input type is used.
     skip_nulls : bool, default False
         When false, the first encountered null is propagated.
@@ -2707,6 +2707,11 @@ cdef get_register_aggregate_function():
     reg.register_func = RegisterAggregateFunction
     return reg
 
+cdef get_register_vector_function():
+    cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf)
+    reg.register_func = RegisterVectorFunction
+    return reg
+
 
 def register_scalar_function(func, function_name, function_doc, in_types, out_type,
                              func_registry=None):
@@ -2789,6 +2794,83 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty
                                             out_type, func_registry)
 
 
+def register_vector_function(func, function_name, function_doc, in_types, out_type,
+                             func_registry=None):
+    """
+    Register a user-defined vector function.
+
+    This API is EXPERIMENTAL.
+
+    A vector function is a function that executes vector
+    operations on whole arrays. Vector functions are typically
+    used when the computation does not fit the other, more
+    specific function kinds (e.g., scalar and aggregate).
+
+    Parameters
+    ----------
+    func : callable
+        A callable implementing the user-defined function.
+        The first argument is the context argument of type
+        UdfContext.
+        Then, it must take arguments equal to the number of
+        in_types defined. It must return an Array or Scalar
+        matching the out_type. It must return a Scalar if
+        all arguments are scalar, else it must return an Array.
+
+        To define a varargs function, pass a callable that takes
+        *args. The last in_type will be the type of all varargs
+        arguments.
+    function_name : str
+        Name of the function. There should only be one function
+        registered with this name in the function registry.
+    function_doc : dict
+        A dictionary object with keys "summary" (str)
+        and "description" (str).
+    in_types : Dict[str, DataType]
+        A dictionary mapping function argument names to
+        their respective DataType.
+        The argument names will be used to generate
+        documentation for the function. The number of
+        arguments specified here determines the function
+        arity.
+    out_type : DataType
+        Output type of the function.
+    func_registry : FunctionRegistry
+        Optional function registry to use instead of the default global one.
+
+    Examples
+    --------
+    >>> import pyarrow as pa
+    >>> import pyarrow.compute as pc
+    >>>
+    >>> func_doc = {}
+    >>> func_doc["summary"] = "list flatten"
+    >>> func_doc["description"] = "flatten a list array into its values"
+    >>>
+    >>> def list_flatten_udf(ctx, x):
+    ...     return pc.list_flatten(x)
+    >>>
+    >>> func_name = "list_flatten_udf"
+    >>> in_types = {"array": pa.list_(pa.int64())}
+    >>> out_type = pa.int64()
+    >>> pc.register_vector_function(list_flatten_udf, func_name, func_doc,
+    ...                             in_types, out_type)
+    >>>
+    >>> answer = pc.call_function(func_name, [pa.array([[1, 2], [3, 4]])])
+    >>> answer
+    <pyarrow.lib.Int64Array object at ...>
+    [
+      1,
+      2,
+      3,
+      4
+    ]
+    """
+    return _register_user_defined_function(get_register_vector_function(),
+                                           func, function_name, function_doc, in_types,
+                                           out_type, func_registry)
+
+
 def register_aggregate_function(func, function_name, function_doc, in_types, out_type,
                                 func_registry=None):
     """
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 0fefa18dd1136..7b8983cbb98d2 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -87,6 +87,7 @@
     register_scalar_function,
     register_tabular_function,
     register_aggregate_function,
+    register_vector_function,
     UdfContext,
     # Expressions
     Expression,
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index da46cdcb750d5..f4d6541fa724c 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2815,5 +2815,9 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil:
                                    function[CallbackUdf] wrapper, const CUdfOptions& options,
                                    CFunctionRegistry* registry)
 
+    CStatus RegisterVectorFunction(PyObject* function,
+                                   function[CallbackUdf] wrapper, const CUdfOptions& options,
+                                   CFunctionRegistry* registry)
+
     CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction(
         const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry)
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index 435c89f596d48..f7761a9277f0e 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -292,14 +292,14 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
     return out;
   }
 
-  Status Resize(KernelContext* ctx, int64_t new_num_groups) {
+  Status Resize(KernelContext* ctx, int64_t new_num_groups) override {
     // We only need to change num_groups in resize
     // similar to other hash aggregate kernels
     num_groups = new_num_groups;
     return Status::OK();
   }
 
-  Status Consume(KernelContext* ctx, const ExecSpan& batch) {
+  Status Consume(KernelContext* ctx, const ExecSpan& batch) override {
     ARROW_ASSIGN_OR_RAISE(
         std::shared_ptr<RecordBatch> rb,
         batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
@@ -316,7 +316,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
     return Status::OK();
   }
   Status Merge(KernelContext* ctx, KernelState&& other_state,
-               const ArrayData& group_id_mapping) {
+               const ArrayData& group_id_mapping) override {
     // This is similar to GroupedListImpl
     auto& other = checked_cast<PythonUdfHashAggregatorImpl&>(other_state);
     auto& other_values = other.values;
@@ -336,7 +336,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
     return Status::OK();
  }
 
-  Status Finalize(KernelContext* ctx, Datum* out) {
+  Status Finalize(KernelContext* ctx, Datum* out) override {
     // Exclude the last column which is the group id
     const int num_args = input_schema->num_fields() - 1;
 
@@ -484,24 +484,25 @@ Status PythonUdfExec(compute::KernelContext* ctx, const compute::ExecSpan& batch
   return SafeCallIntoPython([&]() -> Status { return udf->Exec(ctx, batch, out); });
 }
 
-Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init,
-                   UdfWrapperCallback wrapper, const UdfOptions& options,
+template <typename Function, typename Kernel>
+Status RegisterUdf(PyObject* function, compute::KernelInit kernel_init,
+                   UdfWrapperCallback cb, const UdfOptions& options,
                    compute::FunctionRegistry* registry) {
-  if (!PyCallable_Check(user_function)) {
+  if (!PyCallable_Check(function)) {
Status::TypeError("Expected a callable Python object."); } - auto scalar_func = std::make_shared( - options.func_name, options.arity, options.func_doc); - Py_INCREF(user_function); + auto scalar_func = + std::make_shared(options.func_name, options.arity, options.func_doc); + Py_INCREF(function); std::vector input_types; for (const auto& in_dtype : options.input_types) { input_types.emplace_back(in_dtype); } compute::OutputType output_type(options.output_type); auto udf_data = std::make_shared( - std::make_shared(user_function), wrapper, + std::make_shared(function), cb, TypeHolder::FromTypes(options.input_types), options.output_type); - compute::ScalarKernel kernel( + Kernel kernel( compute::KernelSignature::Make(std::move(input_types), std::move(output_type), options.arity.is_varargs), PythonUdfExec, kernel_init); @@ -522,9 +523,17 @@ Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init, Status RegisterScalarFunction(PyObject* function, UdfWrapperCallback cb, const UdfOptions& options, compute::FunctionRegistry* registry) { - return RegisterUdf(function, - PythonUdfKernelInit{std::make_shared(function)}, cb, - options, registry); + return RegisterUdf( + function, PythonUdfKernelInit{std::make_shared(function)}, cb, + options, registry); +} + +Status RegisterVectorFunction(PyObject* function, UdfWrapperCallback cb, + const UdfOptions& options, + compute::FunctionRegistry* registry) { + return RegisterUdf( + function, PythonUdfKernelInit{std::make_shared(function)}, cb, + options, registry); } Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb, @@ -536,7 +545,7 @@ Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb, if (options.output_type->id() != Type::type::STRUCT) { return Status::Invalid("tabular function with non-struct output"); } - return RegisterUdf( + return RegisterUdf( function, PythonTableUdfKernelInit{std::make_shared(function), cb}, cb, options, registry); } diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index 682cbb2ffe8d5..d8c4e430e53d4 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -67,6 +67,11 @@ Status ARROW_PYTHON_EXPORT RegisterAggregateFunction( PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); +/// \brief register a Vector user-defined-function from Python +Status ARROW_PYTHON_EXPORT RegisterVectorFunction( + PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options, + compute::FunctionRegistry* registry = NULLPTR); + Result> ARROW_PYTHON_EXPORT CallTabularFunction(const std::string& func_name, const std::vector& args, compute::FunctionRegistry* registry = NULLPTR); diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 5631e19455c06..62d1eb5bafd4f 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -299,6 +299,44 @@ def raising_func(ctx): return raising_func, func_name +@pytest.fixture(scope="session") +def unary_vector_func_fixture(): + """ + Reigster a vector function + """ + def pct_rank(ctx, x): + # copy here to get around pandas 1.0 issue + return pa.array(x.to_pandas().copy().rank(pct=True)) + + func_name = "y=pct_rank(x)" + doc = empty_udf_doc + pc.register_vector_function(pct_rank, func_name, doc, { + 'x': pa.float64()}, pa.float64()) + + return pct_rank, func_name + + +@pytest.fixture(scope="session") +def struct_vector_func_fixture(): 
+    """
+    Register a vector function that returns a struct array
+    """
+    def pivot(ctx, k, v, c):
+        df = pa.RecordBatch.from_arrays([k, v, c], names=['k', 'v', 'c']).to_pandas()
+        df_pivot = df.pivot(columns='c', values='v', index='k').reset_index()
+        return pa.RecordBatch.from_pandas(df_pivot).to_struct_array()
+
+    func_name = "y=pivot(x)"
+    doc = empty_udf_doc
+    pc.register_vector_function(
+        pivot, func_name, doc,
+        {'k': pa.int64(), 'v': pa.float64(), 'c': pa.utf8()},
+        pa.struct([('k', pa.int64()), ('v1', pa.float64()), ('v2', pa.float64())])
+    )
+
+    return pivot, func_name
+
+
 def check_scalar_function(func_fixture,
                           inputs, *,
                           run_in_dataset=True,
@@ -797,3 +835,35 @@ def test_hash_agg_random(sum_agg_func_fixture):
                            [("value", "sum")]).rename_columns(['id', 'value_sum_udf'])
 
     assert result.sort_by('id') == expected.sort_by('id')
+
+
+@pytest.mark.pandas
+def test_vector_basic(unary_vector_func_fixture):
+    arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
+    result = pc.call_function("y=pct_rank(x)", [arr])
+    expected = unary_vector_func_fixture[0](None, arr)
+    assert result == expected
+
+
+@pytest.mark.pandas
+def test_vector_empty(unary_vector_func_fixture):
+    arr = pa.array([1], pa.float64())
+    result = pc.call_function("y=pct_rank(x)", [arr])
+    expected = unary_vector_func_fixture[0](None, arr)
+    assert result == expected
+
+
+@pytest.mark.pandas
+def test_vector_struct(struct_vector_func_fixture):
+    k = pa.array([1, 1, 2, 2], pa.int64())
+    v = pa.array([1.0, 2.0, 3.0, 4.0], pa.float64())
+    c = pa.array(['v1', 'v2', 'v1', 'v2'])
+    result = pc.call_function("y=pivot(x)", [k, v, c])
+    expected = struct_vector_func_fixture[0](None, k, v, c)
+    assert result == expected
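
For review context, a minimal end-to-end sketch of the API this patch adds, runnable against a patched build. The `demean` function, its registered name, and the doc dict are illustrative only (not part of the patch); the registration call follows the signature documented in _compute.pyx above.

    import pyarrow as pa
    import pyarrow.compute as pc

    def demean(ctx, x):
        # A vector UDF receives the full argument array at once, so
        # whole-column statistics such as the mean are well defined here.
        # A scalar UDF offers no such guarantee: the engine may invoke it
        # on arbitrary slices of the input.
        return pc.subtract(x, pc.mean(x))

    doc = {"summary": "demean", "description": "subtract the column mean"}
    pc.register_vector_function(demean, "demean", doc,
                                {"x": pa.float64()}, pa.float64())

    print(pc.call_function("demean", [pa.array([1.0, 2.0, 3.0])]))
    # prints a float64 array: [-1, 0, 1]

The whole-array semantics is also why the tests register pct_rank as a vector function rather than a scalar one: percent rank depends on every value in the column, which a per-element scalar kernel cannot express.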