Skip to content

Commit

Permalink
Fix lnino/array_remove (#2372)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-lninobrijaldo authored Oct 25, 2024
1 parent 8383698 commit aaf2cdc
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#### Improvements

- Improved the following new capability for function `snowflake.snowpark.functions.array_remove` it is now possible to use in python.
- Disables sql simplification when sort is performed after limit.
- Previously, `df.sort().limit()` and `df.limit().sort()` generates the same query with sort in front of limit. Now, `df.limit().sort()` will generate query that reads `df.limit().sort()`.
- Improve performance of generated query for `df.limit().sort()`, because limit stops table scanning as soon as the number of records is satisfied.
Expand Down
13 changes: 12 additions & 1 deletion src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5377,11 +5377,22 @@ def array_remove(array: ColumnOrName, element: ColumnOrLiteral) -> Column:
-------------
<BLANKLINE>
>>> df.select(array_remove(array_remove(df.data, 1), "2").alias("objects")).show()
-------------
|"OBJECTS" |
-------------
|[ |
| 3.1 |
|] |
-------------
<BLANKLINE>
See Also:
- `ARRAY <https://docs.snowflake.com/en/sql-reference/data-types-semistructured#label-data-type-array>`_ for more details on semi-structured arrays.
"""
a = _to_col_if_str(array, "array_remove")
return builtin("array_remove")(a, element)
e = lit(element).cast("VARIANT") if isinstance(element, str) else element
return builtin("array_remove")(a, e)


def array_cat(array1: ColumnOrName, array2: ColumnOrName) -> Column:
Expand Down
76 changes: 76 additions & 0 deletions tests/integ/scala/test_function_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2829,6 +2829,82 @@ def test_array_append(session):
reason="array_remove is not yet supported in local testing mode.",
)
def test_array_remove(session):
actual = session.createDataFrame([([1, 2, 4, 4, 3],), ([],)], ["data"])
actual = actual.select(array_remove(actual.data, 4))
Utils.check_answer(
actual,
[
Row("[\n 1,\n 2,\n 3\n]"),
Row("[]"),
],
)

actual = session.createDataFrame([(["a", "b", "c", "a", "a"],), ([],)], ["data"])
actual = actual.select(array_remove(actual.data, "a"))
Utils.check_answer(
actual,
[
Row('[\n "b",\n "c"\n]'),
Row("[]"),
],
)

actual = session.createDataFrame(
[(["apple", "banana", "apple", "orange"],), ([],)], ["data"]
)
actual = actual.select(array_remove(actual.data, "apple"))
Utils.check_answer(
actual,
[
Row('[\n "banana",\n "orange"\n]'),
Row("[]"),
],
)

actual = session.createDataFrame([([1, "2", 3.1, 1, 3],), ([],)], ["data"])
actual = actual.select(array_remove(actual.data, 1))
Utils.check_answer(
actual,
[
Row('[\n "2",\n 3.1,\n 3\n]'),
Row("[]"),
],
)

actual = session.createDataFrame([(["@", ";", "3.1", 1, 5 / 3],), ([],)], ["data"])
actual = actual.select(array_remove(actual.data, 1))
Utils.check_answer(
actual,
[
Row('[\n "@",\n ";",\n "3.1",\n 1.6666666666666667\n]'),
Row("[]"),
],
)

actual = session.createDataFrame([([-1, -2, -4, -4, -3],), ([],)], ["data"])
actual = actual.select(array_remove(actual.data, 1))
Utils.check_answer(
actual,
[
Row("[\n -1,\n -2,\n -4,\n -4,\n -3\n]"),
Row("[]"),
],
)

actual = session.createDataFrame([([4.4, 5.5, 1.1],), ([],)], ["data"])
actual = actual.select(array_remove(actual.data, 5.5))
Utils.check_answer(
actual,
[
Row("[\n 4.4,\n 1.1\n]"),
Row("[]"),
],
)

actual = TestData.array1(session).select(
array_remove(array_remove(col("arr1"), lit(1)), lit(8))
)

Utils.check_answer(
[
Row("[\n 2,\n 3\n]"),
Expand Down

0 comments on commit aaf2cdc

Please sign in to comment.