apacheGH-36672: [Python][C++] Add support for vector function UDF (apache#36673)

### Rationale for this change
In Arrow compute, there are four main types of functions: Scalar, Vector, ScalarAggregate and HashAggregate.

Previous work added support for Scalar, ScalarAggregate (apache#35515) and HashAggregate (apache#36252) functions. I think it makes sense to add support for vector functions as well, completing support for all non-decomposable UDF kernels.

Internally, we plan to extend Acero with a "SegmentVectorNode" that uses this API to invoke a vector UDF on a segment-by-segment basis, which will allow computing things like "rank the values across all rows per segment using a Python UDF" in constant memory.
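A minimal sketch of that idea, with a hypothetical `call_per_segment` driver standing in for the not-yet-written "SegmentVectorNode" (only existing pyarrow APIs are used):

```python
import pyarrow.compute as pc

def call_per_segment(table, segment_key, udf_name, value_column):
    # Hypothetical driver: invoke a registered vector UDF one segment at a
    # time, so peak memory tracks the largest segment, not the whole input.
    for key in pc.unique(table[segment_key]):
        segment = table.filter(pc.equal(table[segment_key], key))
        # combine_chunks() hands the UDF one contiguous array per segment.
        values = segment[value_column].combine_chunks()
        yield key, pc.call_function(udf_name, [values])
```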

### What changes are included in this PR?
The change is very similar to the existing aggregate-function support: it includes code to register the vector UDF and a kernel that invokes the vector UDF on the given inputs.
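For reference, the user-facing flow this adds, condensed from the docstring example in the diff below:

```python
import pyarrow as pa
import pyarrow.compute as pc

def list_flatten_udf(ctx, x):
    # The vector kernel passes the whole input array in a single call.
    return pc.list_flatten(x)

pc.register_vector_function(
    list_flatten_udf, "list_flatten_udf",
    {"summary": "flatten a list array", "description": "flatten list values"},
    {"array": pa.list_(pa.int64())}, pa.int64())

assert pc.call_function("list_flatten_udf",
                        [pa.array([[1, 2], [3, 4]])]).to_pylist() == [1, 2, 3, 4]
```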

### Are these changes tested?
Yes. New tests are added.

### Are there any user-facing changes?
Yes. This adds a user-facing API to register vector functions.

* Closes: apache#36672

Authored-by: Li Jin <ice.xelloss@gmail.com>
Signed-off-by: Li Jin <ice.xelloss@gmail.com>
icexelloss authored Aug 8, 2023
1 parent 59f30f0 commit b1e85a6
Showing 6 changed files with 188 additions and 17 deletions.
84 changes: 83 additions & 1 deletion python/pyarrow/_compute.pyx
@@ -1964,7 +1964,7 @@ class CumulativeOptions(_CumulativeOptions):
Parameters
----------
start : Scalar, default None
Starting value for the cumulative operation. If none is given,
a default value depending on the operation and input type is used.
skip_nulls : bool, default False
When false, the first encountered null is propagated.
@@ -2707,6 +2707,11 @@ cdef get_register_aggregate_function():
reg.register_func = RegisterAggregateFunction
return reg

cdef get_register_vector_function():
cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf)
reg.register_func = RegisterVectorFunction
return reg


def register_scalar_function(func, function_name, function_doc, in_types, out_type,
func_registry=None):
@@ -2789,6 +2794,83 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_type,
out_type, func_registry)


def register_vector_function(func, function_name, function_doc, in_types, out_type,
func_registry=None):
"""
Register a user-defined vector function.

This API is EXPERIMENTAL.

A vector function is a function that executes vector
operations on arrays. Vector functions are often used
when the computation does not fit other, more specific
function types (e.g., scalar or aggregate).

Parameters
----------
func : callable
A callable implementing the user-defined function.
The first argument is the context argument of type
UdfContext.
It must then take the same number of arguments as
defined in in_types. It must return an Array or Scalar
matching the out_type: a Scalar if all arguments are
scalar, otherwise an Array.
To define a varargs function, pass a callable that takes
*args. The last in_type will be the type of all varargs
arguments.
function_name : str
Name of the function. There should only be one function
registered with this name in the function registry.
function_doc : dict
A dictionary object with keys "summary" (str),
and "description" (str).
in_types : Dict[str, DataType]
A dictionary mapping function argument names to
their respective DataType.
The argument names will be used to generate
documentation for the function. The number of
arguments specified here determines the function
arity.
out_type : DataType
Output type of the function.
func_registry : FunctionRegistry
Optional function registry to use instead of the default global one.

Examples
--------
>>> import pyarrow as pa
>>> import pyarrow.compute as pc
>>>
>>> func_doc = {}
>>> func_doc["summary"] = "percent rank"
>>> func_doc["description"] = "compute percent rank"
>>>
>>> def list_flatten_udf(ctx, x):
... return pc.list_flatten(x)
>>>
>>> func_name = "list_flatten_udf"
>>> in_types = {"array": pa.list_(pa.int64())}
>>> out_type = pa.int64()
>>> pc.register_vector_function(list_flatten_udf, func_name, func_doc,
... in_types, out_type)
>>>
>>> answer = pc.call_function(func_name, [pa.array([[1, 2], [3, 4]])])
>>> answer
<pyarrow.lib.Int64Array object at ...>
[
1,
2,
3,
4
]
"""
return _register_user_defined_function(get_register_vector_function(),
func, function_name, function_doc, in_types,
out_type, func_registry)
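# A hedged sketch (illustrative only, not part of this commit) of the varargs
# case described in the docstring above: the last in_type ("x") is the type of
# every variadic argument. Assumes `import pyarrow as pa` and
# `import pyarrow.compute as pc`.
#
# def sum_all_udf(ctx, *args):
#     # Elementwise-sum an arbitrary number of float64 arrays.
#     total = args[0]
#     for arr in args[1:]:
#         total = pc.add(total, arr)
#     return total
#
# pc.register_vector_function(
#     sum_all_udf, "sum_all_udf",
#     {"summary": "elementwise sum", "description": "sum all argument arrays"},
#     {"x": pa.float64()}, pa.float64())
#
# pc.call_function("sum_all_udf",
#                  [pa.array([1.0]), pa.array([2.0]), pa.array([3.0])])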


def register_aggregate_function(func, function_name, function_doc, in_types, out_type,
func_registry=None):
"""
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/compute.py
@@ -87,6 +87,7 @@
register_scalar_function,
register_tabular_function,
register_aggregate_function,
register_vector_function,
UdfContext,
# Expressions
Expression,
4 changes: 4 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -2815,5 +2815,9 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil:
function[CallbackUdf] wrapper, const CUdfOptions& options,
CFunctionRegistry* registry)

CStatus RegisterVectorFunction(PyObject* function,
function[CallbackUdf] wrapper, const CUdfOptions& options,
CFunctionRegistry* registry)

CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction(
const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry)
41 changes: 25 additions & 16 deletions python/pyarrow/src/arrow/python/udf.cc
@@ -292,14 +292,14 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
return out;
}

Status Resize(KernelContext* ctx, int64_t new_num_groups) override {
// We only need to change num_groups in resize
// similar to other hash aggregate kernels
num_groups = new_num_groups;
return Status::OK();
}

Status Consume(KernelContext* ctx, const ExecSpan& batch) override {
ARROW_ASSIGN_OR_RAISE(
std::shared_ptr<RecordBatch> rb,
batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
Expand All @@ -316,7 +316,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
return Status::OK();
}
Status Merge(KernelContext* ctx, KernelState&& other_state,
const ArrayData& group_id_mapping) override {
// This is similar to GroupedListImpl
auto& other = checked_cast<PythonUdfHashAggregatorImpl&>(other_state);
auto& other_values = other.values;
Expand All @@ -336,7 +336,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
return Status::OK();
}

Status Finalize(KernelContext* ctx, Datum* out) override {
// Exclude the last column which is the group id
const int num_args = input_schema->num_fields() - 1;

@@ -484,24 +484,25 @@ Status PythonUdfExec(compute::KernelContext* ctx, const compute::ExecSpan& batch
return SafeCallIntoPython([&]() -> Status { return udf->Exec(ctx, batch, out); });
}

template <class Function, class Kernel>
Status RegisterUdf(PyObject* function, compute::KernelInit kernel_init,
UdfWrapperCallback cb, const UdfOptions& options,
compute::FunctionRegistry* registry) {
if (!PyCallable_Check(function)) {
return Status::TypeError("Expected a callable Python object.");
}
auto scalar_func =
std::make_shared<Function>(options.func_name, options.arity, options.func_doc);
Py_INCREF(function);
std::vector<compute::InputType> input_types;
for (const auto& in_dtype : options.input_types) {
input_types.emplace_back(in_dtype);
}
compute::OutputType output_type(options.output_type);
auto udf_data = std::make_shared<PythonUdf>(
std::make_shared<OwnedRefNoGIL>(function), cb,
TypeHolder::FromTypes(options.input_types), options.output_type);
Kernel kernel(
compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
options.arity.is_varargs),
PythonUdfExec, kernel_init);
@@ -522,9 +523,17 @@ Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init,
Status RegisterScalarFunction(PyObject* function, UdfWrapperCallback cb,
const UdfOptions& options,
compute::FunctionRegistry* registry) {
return RegisterUdf<compute::ScalarFunction, compute::ScalarKernel>(
function, PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb,
options, registry);
}

Status RegisterVectorFunction(PyObject* function, UdfWrapperCallback cb,
const UdfOptions& options,
compute::FunctionRegistry* registry) {
return RegisterUdf<compute::VectorFunction, compute::VectorKernel>(
function, PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb,
options, registry);
}

Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb,
@@ -536,7 +545,7 @@ Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb,
if (options.output_type->id() != Type::type::STRUCT) {
return Status::Invalid("tabular function with non-struct output");
}
return RegisterUdf<compute::ScalarFunction, compute::ScalarKernel>(
function, PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function), cb},
cb, options, registry);
}
5 changes: 5 additions & 0 deletions python/pyarrow/src/arrow/python/udf.h
@@ -67,6 +67,11 @@ Status ARROW_PYTHON_EXPORT RegisterAggregateFunction(
PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
compute::FunctionRegistry* registry = NULLPTR);

/// \brief Register a vector user-defined function from Python
Status ARROW_PYTHON_EXPORT RegisterVectorFunction(
PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
compute::FunctionRegistry* registry = NULLPTR);

Result<std::shared_ptr<RecordBatchReader>> ARROW_PYTHON_EXPORT
CallTabularFunction(const std::string& func_name, const std::vector<Datum>& args,
compute::FunctionRegistry* registry = NULLPTR);
70 changes: 70 additions & 0 deletions python/pyarrow/tests/test_udf.py
@@ -299,6 +299,44 @@ def raising_func(ctx):
return raising_func, func_name


@pytest.fixture(scope="session")
def unary_vector_func_fixture():
"""
Register a vector function
"""
def pct_rank(ctx, x):
# copy here to get around pandas 1.0 issue
return pa.array(x.to_pandas().copy().rank(pct=True))

func_name = "y=pct_rank(x)"
doc = empty_udf_doc
pc.register_vector_function(pct_rank, func_name, doc, {
'x': pa.float64()}, pa.float64())

return pct_rank, func_name


@pytest.fixture(scope="session")
def struct_vector_func_fixture():
"""
Register a vector function that returns a struct array
"""
def pivot(ctx, k, v, c):
df = pa.RecordBatch.from_arrays([k, v, c], names=['k', 'v', 'c']).to_pandas()
df_pivot = df.pivot(columns='c', values='v', index='k').reset_index()
return pa.RecordBatch.from_pandas(df_pivot).to_struct_array()

func_name = "y=pivot(x)"
doc = empty_udf_doc
pc.register_vector_function(
pivot, func_name, doc,
{'k': pa.int64(), 'v': pa.float64(), 'c': pa.utf8()},
pa.struct([('k', pa.int64()), ('v1', pa.float64()), ('v2', pa.float64())])
)

return pivot, func_name


def check_scalar_function(func_fixture,
inputs, *,
run_in_dataset=True,
@@ -797,3 +835,35 @@ def test_hash_agg_random(sum_agg_func_fixture):
[("value", "sum")]).rename_columns(['id', 'value_sum_udf'])

assert result.sort_by('id') == expected.sort_by('id')


@pytest.mark.pandas
def test_vector_basic(unary_vector_func_fixture):
arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
result = pc.call_function("y=pct_rank(x)", [arr])
expected = unary_vector_func_fixture[0](None, arr)
assert result == expected


@pytest.mark.pandas
def test_vector_empty(unary_vector_func_fixture):
arr = pa.array([1], pa.float64())
result = pc.call_function("y=pct_rank(x)", [arr])
expected = unary_vector_func_fixture[0](None, arr)
assert result == expected


@pytest.mark.pandas
def test_vector_struct(struct_vector_func_fixture):
k = pa.array([1, 1, 2, 2], pa.int64())
v = pa.array([1.0, 2.0, 3.0, 4.0], pa.float64())
c = pa.array(['v1', 'v2', 'v1', 'v2'])
result = pc.call_function("y=pivot(x)", [k, v, c])
expected = struct_vector_func_fixture[0](None, k, v, c)
assert result == expected
