From 62f9dd6d48f4a708936ab1d00f459f1e11bcfb3b Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Tue, 7 May 2024 17:37:21 -0700 Subject: [PATCH] [FEAT] Allow returning of pyarrow arrays from UDFs (#2252) See relevant thread: https://linen.getdaft.io/t/18793130/hey-team-loving-daft-so-far-i-noticed-a-discrepencey-between#14aad685-8634-4670-9bda-b9bb74fc8a92 Co-authored-by: Jay Chia --- daft/udf.py | 9 +++++++++ tests/expressions/test_udf.py | 12 ++++++++++++ 2 files changed, 21 insertions(+) diff --git a/daft/udf.py b/daft/udf.py index 987855f2cf..3ba51b8a63 100644 --- a/daft/udf.py +++ b/daft/udf.py @@ -16,8 +16,15 @@ except ImportError: _NUMPY_AVAILABLE = False +_PYARROW_AVAILABLE = True +try: + import pyarrow as pa +except ImportError: + _PYARROW_AVAILABLE = False + if TYPE_CHECKING: import numpy as np + import pyarrow as pa UserProvidedPythonFunction = Callable[..., Union[Series, "np.ndarray", list]] @@ -114,6 +121,8 @@ def __call__(self, evaluated_expressions: list[Series]) -> PySeries: return Series.from_pylist(result, name=name, pyobj="allow").cast(self.udf.return_dtype)._series elif _NUMPY_AVAILABLE and isinstance(result, np.ndarray): return Series.from_numpy(result, name=name).cast(self.udf.return_dtype)._series + elif _PYARROW_AVAILABLE and isinstance(result, (pa.Array, pa.ChunkedArray)): + return Series.from_arrow(result, name=name).cast(self.udf.return_dtype)._series else: raise NotImplementedError(f"Return type not supported for UDF: {type(result)}") diff --git a/tests/expressions/test_udf.py b/tests/expressions/test_udf.py index bfbc6bcdf0..0b704eab3f 100644 --- a/tests/expressions/test_udf.py +++ b/tests/expressions/test_udf.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import pyarrow as pa import pytest from daft import col @@ -235,3 +236,14 @@ def add_cols_elementwise(*args, multiplier: float): result = table.eval_expression_list([expr]) assert result.to_pydict() == {"a": [6, 12, 18]} + + +def test_udf_return_pyarrow(): + table = MicroPartition.from_pydict({"a": [1, 2, 3]}) + + @udf(return_dtype=DataType.int64()) + def add_1(data): + return pa.compute.add(data.to_arrow(), 1) + + result = table.eval_expression_list([add_1(col("a"))]) + assert result.to_pydict() == {"a": [2, 3, 4]}