diff --git a/CHANGELOG.md b/CHANGELOG.md index 6aef4f0e67..441673db92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ #### Improvements +- Improved the function `snowflake.snowpark.functions.array_remove`: it is now possible to use it in Python. - Disables sql simplification when sort is performed after limit. - Previously, `df.sort().limit()` and `df.limit().sort()` generates the same query with sort in front of limit. Now, `df.limit().sort()` will generate query that reads `df.limit().sort()`. - Improve performance of generated query for `df.limit().sort()`, because limit stops table scanning as soon as the number of records is satisfied. diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index f68dce8cdd..db98cac2c5 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -5377,11 +5377,22 @@ def array_remove(array: ColumnOrName, element: ColumnOrLiteral) -> Column: ------------- + >>> df.select(array_remove(array_remove(df.data, 1), "2").alias("objects")).show() + ------------- + |"OBJECTS" | + ------------- + |[ | + | 3.1 | + |] | + ------------- + + See Also: - `ARRAY `_ for more details on semi-structured arrays. 
""" a = _to_col_if_str(array, "array_remove") - return builtin("array_remove")(a, element) + e = lit(element).cast("VARIANT") if isinstance(element, str) else element + return builtin("array_remove")(a, e) def array_cat(array1: ColumnOrName, array2: ColumnOrName) -> Column: diff --git a/tests/integ/scala/test_function_suite.py b/tests/integ/scala/test_function_suite.py index bd667f2a8c..665eb9c80e 100644 --- a/tests/integ/scala/test_function_suite.py +++ b/tests/integ/scala/test_function_suite.py @@ -2829,6 +2829,82 @@ def test_array_append(session): reason="array_remove is not yet supported in local testing mode.", ) def test_array_remove(session): + actual = session.createDataFrame([([1, 2, 4, 4, 3],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, 4)) + Utils.check_answer( + actual, + [ + Row("[\n 1,\n 2,\n 3\n]"), + Row("[]"), + ], + ) + + actual = session.createDataFrame([(["a", "b", "c", "a", "a"],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, "a")) + Utils.check_answer( + actual, + [ + Row('[\n "b",\n "c"\n]'), + Row("[]"), + ], + ) + + actual = session.createDataFrame( + [(["apple", "banana", "apple", "orange"],), ([],)], ["data"] + ) + actual = actual.select(array_remove(actual.data, "apple")) + Utils.check_answer( + actual, + [ + Row('[\n "banana",\n "orange"\n]'), + Row("[]"), + ], + ) + + actual = session.createDataFrame([([1, "2", 3.1, 1, 3],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, 1)) + Utils.check_answer( + actual, + [ + Row('[\n "2",\n 3.1,\n 3\n]'), + Row("[]"), + ], + ) + + actual = session.createDataFrame([(["@", ";", "3.1", 1, 5 / 3],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, 1)) + Utils.check_answer( + actual, + [ + Row('[\n "@",\n ";",\n "3.1",\n 1.6666666666666667\n]'), + Row("[]"), + ], + ) + + actual = session.createDataFrame([([-1, -2, -4, -4, -3],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, 1)) + 
Utils.check_answer( + actual, + [ + Row("[\n -1,\n -2,\n -4,\n -4,\n -3\n]"), + Row("[]"), + ], + ) + + actual = session.createDataFrame([([4.4, 5.5, 1.1],), ([],)], ["data"]) + actual = actual.select(array_remove(actual.data, 5.5)) + Utils.check_answer( + actual, + [ + Row("[\n 4.4,\n 1.1\n]"), + Row("[]"), + ], + ) + + actual = TestData.array1(session).select( + array_remove(array_remove(col("arr1"), lit(1)), lit(8)) + ) + Utils.check_answer( [ Row("[\n 2,\n 3\n]"),