Skip to content

Commit

Permalink
Add array functions (#560)
Browse files Browse the repository at this point in the history
* Add array_has, array_has_all and array_has_any

* Add array_position, array_indexof, list_position and list_indexof

* Add array_to_string, array_join, list_to_string and list_join

* Add array_ndims and list_ndims

* Add array_push_back, list_append and list_push_back

* Add array_prepend, array_push_front, list_prepend and list_push_front

* Add array_pop_back and array_pop_front

* Add array_positions and list_positions

* Add array_remove, list_remove, array_remove_n, list_remove_n, array_remove_all and list_remove_all

* Add array_repeat

* Add array_replace, list_replace, array_replace_n, list_replace_n, array_replace_all, list_replace_all

* Add array_slice and list_slice
  • Loading branch information
ongchi authored Feb 12, 2024
1 parent 476ca22 commit 697ca2c
Show file tree
Hide file tree
Showing 2 changed files with 286 additions and 3 deletions.
211 changes: 208 additions & 3 deletions datafusion/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,19 +200,62 @@ def test_math_functions():


def test_array_functions():
data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
[np.array(data, dtype=object)], names=["arr"]
)
df = ctx.create_dataframe([[batch]])

def py_indexof(arr, v):
try:
return arr.index(v) + 1
except ValueError:
return np.nan

def py_arr_remove(arr, v, n=None):
new_arr = arr[:]
found = 0
while found != n:
try:
new_arr.remove(v)
found += 1
except ValueError:
break

return new_arr

def py_arr_replace(arr, from_, to, n=None):
new_arr = arr[:]
found = 0
while found != n:
try:
idx = new_arr.index(from_)
new_arr[idx] = to
found += 1
except ValueError:
break

return new_arr

col = column("arr")
test_items = [
[
f.array_append(col, literal(99.0)),
lambda: [np.append(arr, 99.0) for arr in data],
],
[
f.array_push_back(col, literal(99.0)),
lambda: [np.append(arr, 99.0) for arr in data],
],
[
f.list_append(col, literal(99.0)),
lambda: [np.append(arr, 99.0) for arr in data],
],
[
f.list_push_back(col, literal(99.0)),
lambda: [np.append(arr, 99.0) for arr in data],
],
[
f.array_concat(col, col),
lambda: [np.concatenate([arr, arr]) for arr in data],
Expand Down Expand Up @@ -253,12 +296,174 @@ def test_array_functions():
f.list_length(col),
lambda: [len(r) for r in data],
],
[
f.array_has(col, literal(1.0)),
lambda: [1.0 in r for r in data],
],
[
f.array_has_all(
col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
),
lambda: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
],
[
f.array_has_any(
col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
),
lambda: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
],
[
f.array_position(col, literal(1.0)),
lambda: [py_indexof(r, 1.0) for r in data],
],
[
f.array_indexof(col, literal(1.0)),
lambda: [py_indexof(r, 1.0) for r in data],
],
[
f.list_position(col, literal(1.0)),
lambda: [py_indexof(r, 1.0) for r in data],
],
[
f.list_indexof(col, literal(1.0)),
lambda: [py_indexof(r, 1.0) for r in data],
],
[
f.array_positions(col, literal(1.0)),
lambda: [
[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data
],
],
[
f.list_positions(col, literal(1.0)),
lambda: [
[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data
],
],
[
f.array_ndims(col),
lambda: [np.array(r).ndim for r in data],
],
[
f.list_ndims(col),
lambda: [np.array(r).ndim for r in data],
],
[
f.array_prepend(literal(99.0), col),
lambda: [np.insert(arr, 0, 99.0) for arr in data],
],
[
f.array_push_front(literal(99.0), col),
lambda: [np.insert(arr, 0, 99.0) for arr in data],
],
[
f.list_prepend(literal(99.0), col),
lambda: [np.insert(arr, 0, 99.0) for arr in data],
],
[
f.list_push_front(literal(99.0), col),
lambda: [np.insert(arr, 0, 99.0) for arr in data],
],
[
f.array_pop_back(col),
lambda: [arr[:-1] for arr in data],
],
[
f.array_pop_front(col),
lambda: [arr[1:] for arr in data],
],
[
f.array_remove(col, literal(3.0)),
lambda: [py_arr_remove(arr, 3.0, 1) for arr in data],
],
[
f.list_remove(col, literal(3.0)),
lambda: [py_arr_remove(arr, 3.0, 1) for arr in data],
],
[
f.array_remove_n(col, literal(3.0), literal(2)),
lambda: [py_arr_remove(arr, 3.0, 2) for arr in data],
],
[
f.list_remove_n(col, literal(3.0), literal(2)),
lambda: [py_arr_remove(arr, 3.0, 2) for arr in data],
],
[
f.array_remove_all(col, literal(3.0)),
lambda: [py_arr_remove(arr, 3.0) for arr in data],
],
[
f.list_remove_all(col, literal(3.0)),
lambda: [py_arr_remove(arr, 3.0) for arr in data],
],
[
f.array_repeat(col, literal(2)),
lambda: [[arr] * 2 for arr in data],
],
[
f.array_replace(col, literal(3.0), literal(4.0)),
lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
],
[
f.list_replace(col, literal(3.0), literal(4.0)),
lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
],
[
f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)),
lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
],
[
f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)),
lambda: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data],
],
[
f.array_replace_all(col, literal(3.0), literal(4.0)),
lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
],
[
f.list_replace_all(col, literal(3.0), literal(4.0)),
lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
],
[
f.array_slice(col, literal(2), literal(4)),
lambda: [arr[1:4] for arr in data],
],
[
f.list_slice(col, literal(-1), literal(2)),
lambda: [arr[-1:2] for arr in data],
],
]

for stmt, py_expr in test_items:
query_result = df.select(stmt).collect()[0].column(0).tolist()
query_result = df.select(stmt).collect()[0].column(0)
for a, b in zip(query_result, py_expr()):
np.testing.assert_array_almost_equal(
np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
)

obj_test_items = [
[
f.array_to_string(col, literal(",")),
lambda: [",".join([str(int(v)) for v in r]) for r in data],
],
[
f.array_join(col, literal(",")),
lambda: [",".join([str(int(v)) for v in r]) for r in data],
],
[
f.list_to_string(col, literal(",")),
lambda: [",".join([str(int(v)) for v in r]) for r in data],
],
[
f.list_join(col, literal(",")),
lambda: [",".join([str(int(v)) for v in r]) for r in data],
],
]

for stmt, py_expr in obj_test_items:
query_result = np.array(df.select(stmt).collect()[0].column(0))
for a, b in zip(query_result, py_expr()):
np.testing.assert_array_almost_equal(a, b)
assert a == b


def test_string_functions(df):
Expand Down
78 changes: 78 additions & 0 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ scalar_function!(decode, Decode);

// Array Functions
scalar_function!(array_append, ArrayAppend);
scalar_function!(array_push_back, ArrayAppend);
scalar_function!(list_append, ArrayAppend);
scalar_function!(list_push_back, ArrayAppend);
scalar_function!(array_concat, ArrayConcat);
scalar_function!(array_cat, ArrayConcat);
scalar_function!(array_dims, ArrayDims);
Expand All @@ -370,6 +373,42 @@ scalar_function!(list_element, ArrayElement);
scalar_function!(list_extract, ArrayElement);
scalar_function!(array_length, ArrayLength);
scalar_function!(list_length, ArrayLength);
scalar_function!(array_has, ArrayHas);
scalar_function!(array_has_all, ArrayHasAll);
scalar_function!(array_has_any, ArrayHasAny);
scalar_function!(array_position, ArrayPosition);
scalar_function!(array_indexof, ArrayPosition);
scalar_function!(list_position, ArrayPosition);
scalar_function!(list_indexof, ArrayPosition);
scalar_function!(array_positions, ArrayPositions);
scalar_function!(list_positions, ArrayPositions);
scalar_function!(array_to_string, ArrayToString);
scalar_function!(array_join, ArrayToString);
scalar_function!(list_to_string, ArrayToString);
scalar_function!(list_join, ArrayToString);
scalar_function!(array_ndims, ArrayNdims);
scalar_function!(list_ndims, ArrayNdims);
scalar_function!(array_prepend, ArrayPrepend);
scalar_function!(array_push_front, ArrayPrepend);
scalar_function!(list_prepend, ArrayPrepend);
scalar_function!(list_push_front, ArrayPrepend);
scalar_function!(array_pop_back, ArrayPopBack);
scalar_function!(array_pop_front, ArrayPopFront);
scalar_function!(array_remove, ArrayRemove);
scalar_function!(list_remove, ArrayRemove);
scalar_function!(array_remove_n, ArrayRemoveN);
scalar_function!(list_remove_n, ArrayRemoveN);
scalar_function!(array_remove_all, ArrayRemoveAll);
scalar_function!(list_remove_all, ArrayRemoveAll);
scalar_function!(array_repeat, ArrayRepeat);
scalar_function!(array_replace, ArrayReplace);
scalar_function!(list_replace, ArrayReplace);
scalar_function!(array_replace_n, ArrayReplaceN);
scalar_function!(list_replace_n, ArrayReplaceN);
scalar_function!(array_replace_all, ArrayReplaceAll);
scalar_function!(list_replace_all, ArrayReplaceAll);
scalar_function!(array_slice, ArraySlice);
scalar_function!(list_slice, ArraySlice);

aggregate_function!(approx_distinct, ApproxDistinct);
aggregate_function!(approx_median, ApproxMedian);
Expand Down Expand Up @@ -563,6 +602,9 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {

// Array Functions
m.add_wrapped(wrap_pyfunction!(array_append))?;
m.add_wrapped(wrap_pyfunction!(array_push_back))?;
m.add_wrapped(wrap_pyfunction!(list_append))?;
m.add_wrapped(wrap_pyfunction!(list_push_back))?;
m.add_wrapped(wrap_pyfunction!(array_concat))?;
m.add_wrapped(wrap_pyfunction!(array_cat))?;
m.add_wrapped(wrap_pyfunction!(array_dims))?;
Expand All @@ -573,6 +615,42 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(list_extract))?;
m.add_wrapped(wrap_pyfunction!(array_length))?;
m.add_wrapped(wrap_pyfunction!(list_length))?;
m.add_wrapped(wrap_pyfunction!(array_has))?;
m.add_wrapped(wrap_pyfunction!(array_has_all))?;
m.add_wrapped(wrap_pyfunction!(array_has_any))?;
m.add_wrapped(wrap_pyfunction!(array_position))?;
m.add_wrapped(wrap_pyfunction!(array_indexof))?;
m.add_wrapped(wrap_pyfunction!(list_position))?;
m.add_wrapped(wrap_pyfunction!(list_indexof))?;
m.add_wrapped(wrap_pyfunction!(array_positions))?;
m.add_wrapped(wrap_pyfunction!(list_positions))?;
m.add_wrapped(wrap_pyfunction!(array_to_string))?;
m.add_wrapped(wrap_pyfunction!(array_join))?;
m.add_wrapped(wrap_pyfunction!(list_to_string))?;
m.add_wrapped(wrap_pyfunction!(list_join))?;
m.add_wrapped(wrap_pyfunction!(array_ndims))?;
m.add_wrapped(wrap_pyfunction!(list_ndims))?;
m.add_wrapped(wrap_pyfunction!(array_prepend))?;
m.add_wrapped(wrap_pyfunction!(array_push_front))?;
m.add_wrapped(wrap_pyfunction!(list_prepend))?;
m.add_wrapped(wrap_pyfunction!(list_push_front))?;
m.add_wrapped(wrap_pyfunction!(array_pop_back))?;
m.add_wrapped(wrap_pyfunction!(array_pop_front))?;
m.add_wrapped(wrap_pyfunction!(array_remove))?;
m.add_wrapped(wrap_pyfunction!(list_remove))?;
m.add_wrapped(wrap_pyfunction!(array_remove_n))?;
m.add_wrapped(wrap_pyfunction!(list_remove_n))?;
m.add_wrapped(wrap_pyfunction!(array_remove_all))?;
m.add_wrapped(wrap_pyfunction!(list_remove_all))?;
m.add_wrapped(wrap_pyfunction!(array_repeat))?;
m.add_wrapped(wrap_pyfunction!(array_replace))?;
m.add_wrapped(wrap_pyfunction!(list_replace))?;
m.add_wrapped(wrap_pyfunction!(array_replace_n))?;
m.add_wrapped(wrap_pyfunction!(list_replace_n))?;
m.add_wrapped(wrap_pyfunction!(array_replace_all))?;
m.add_wrapped(wrap_pyfunction!(list_replace_all))?;
m.add_wrapped(wrap_pyfunction!(array_slice))?;
m.add_wrapped(wrap_pyfunction!(list_slice))?;

Ok(())
}

0 comments on commit 697ca2c

Please sign in to comment.