def test_array_functions():
    """Check the array/list scalar functions against pure-Python references.

    A single column ``arr`` is built from three float lists of different
    lengths. Every expression in ``test_items`` is evaluated with
    ``df.select(...)``; column 0 of the first collected batch is then compared
    row by row with the value produced by the paired reference callable.

    Alias functions (``array_cat``, ``list_dims``, ``array_extract``,
    ``list_element``, ``list_extract``, ``list_length``) are paired with the
    same reference as their ``array_*`` counterpart.
    """
    data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
    ctx = SessionContext()
    # The rows have different lengths, so NumPy can only hold them as an
    # object array of Python lists rather than a rectangular float array.
    batch = pa.RecordBatch.from_arrays(
        [np.array(data, dtype=object)], names=["arr"]
    )
    df = ctx.create_dataframe([[batch]])

    col = column("arr")
    # Each entry pairs an expression under test with a zero-argument callable
    # returning the expected per-row values.
    test_items = [
        [
            f.array_append(col, literal(99.0)),
            lambda: [np.append(arr, 99.0) for arr in data],
        ],
        [
            f.array_concat(col, col),
            lambda: [np.concatenate([arr, arr]) for arr in data],
        ],
        [
            f.array_cat(col, col),
            lambda: [np.concatenate([arr, arr]) for arr in data],
        ],
        [
            f.array_dims(col),
            lambda: [[len(r)] for r in data],
        ],
        [
            f.list_dims(col),
            lambda: [[len(r)] for r in data],
        ],
        # The element accessors are queried with position 1 and are expected
        # to yield the first item of every row (r[0]).
        [
            f.array_element(col, literal(1)),
            lambda: [r[0] for r in data],
        ],
        [
            f.array_extract(col, literal(1)),
            lambda: [r[0] for r in data],
        ],
        [
            f.list_element(col, literal(1)),
            lambda: [r[0] for r in data],
        ],
        [
            f.list_extract(col, literal(1)),
            lambda: [r[0] for r in data],
        ],
        [
            f.array_length(col),
            lambda: [len(r) for r in data],
        ],
        [
            f.list_length(col),
            lambda: [len(r) for r in data],
        ],
    ]

    for stmt, py_expr in test_items:
        query_result = df.select(stmt).collect()[0].column(0).tolist()
        expected = py_expr()
        # zip() stops at the shorter sequence, so a result with missing or
        # extra rows would otherwise be accepted without any comparison.
        assert len(query_result) == len(expected), (
            f"{stmt}: got {len(query_result)} rows, "
            f"expected {len(expected)}"
        )
        for a, b in zip(query_result, expected):
            # Name the failing expression; the bare array diff does not.
            np.testing.assert_array_almost_equal(
                a, b, err_msg=f"unexpected result for {stmt}"
            )
scalar_function!(decode, Decode); +// Array Functions +scalar_function!(array_append, ArrayAppend); +scalar_function!(array_concat, ArrayConcat); +scalar_function!(array_cat, ArrayConcat); +scalar_function!(array_dims, ArrayDims); +scalar_function!(list_dims, ArrayDims); +scalar_function!(array_element, ArrayElement); +scalar_function!(array_extract, ArrayElement); +scalar_function!(list_element, ArrayElement); +scalar_function!(list_extract, ArrayElement); +scalar_function!(array_length, ArrayLength); +scalar_function!(list_length, ArrayLength); + aggregate_function!(approx_distinct, ApproxDistinct); aggregate_function!(approx_median, ApproxMedian); aggregate_function!(approx_percentile_cont, ApproxPercentileCont); @@ -546,5 +559,19 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { //Binary String Functions m.add_wrapped(wrap_pyfunction!(encode))?; m.add_wrapped(wrap_pyfunction!(decode))?; + + // Array Functions + m.add_wrapped(wrap_pyfunction!(array_append))?; + m.add_wrapped(wrap_pyfunction!(array_concat))?; + m.add_wrapped(wrap_pyfunction!(array_cat))?; + m.add_wrapped(wrap_pyfunction!(array_dims))?; + m.add_wrapped(wrap_pyfunction!(list_dims))?; + m.add_wrapped(wrap_pyfunction!(array_element))?; + m.add_wrapped(wrap_pyfunction!(array_extract))?; + m.add_wrapped(wrap_pyfunction!(list_element))?; + m.add_wrapped(wrap_pyfunction!(list_extract))?; + m.add_wrapped(wrap_pyfunction!(array_length))?; + m.add_wrapped(wrap_pyfunction!(list_length))?; + Ok(()) }