diff --git a/CHANGELOG.md b/CHANGELOG.md
index c68e010adf..366e010803 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -66,6 +66,17 @@
 - Fixed a bug where `pd.to_numeric()` would leave `Timedelta` inputs as `Timedelta` instead of converting them to integers.
 - Fixed `loc` set when setting a single row, or multiple rows, of a DataFrame with a Series value.
 
+### Snowpark Local Testing Updates
+
+#### Bug Fixes
+
+- Fixed a bug where nullable columns were annotated wrongly.
+- Fixed a bug where the `date_add` and `date_sub` functions failed for `NULL` values.
+- Fixed a bug where `equal_null` could fail inside a merge statement.
+- Fixed a bug where `row_number` could fail inside a Window function.
+- Fixed a bug where updates could fail when the source is the result of a join.
+
+
 ## 1.22.1 (2024-09-11)
 
 This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content.
diff --git a/src/snowflake/snowpark/mock/_functions.py b/src/snowflake/snowpark/mock/_functions.py
index 3842f6fda3..ef469cec91 100644
--- a/src/snowflake/snowpark/mock/_functions.py
+++ b/src/snowflake/snowpark/mock/_functions.py
@@ -1530,6 +1530,8 @@ def add_months(scalar, date, duration):
 
 
 def add_timedelta(unit, date, duration, scalar=1):
+    if date is None:
+        return date
     return date + datetime.timedelta(**{f"{unit}s": float(duration) * scalar})
 
 
diff --git a/src/snowflake/snowpark/mock/_plan.py b/src/snowflake/snowpark/mock/_plan.py
index 5e8ae3f07a..d1fc060fb0 100644
--- a/src/snowflake/snowpark/mock/_plan.py
+++ b/src/snowflake/snowpark/mock/_plan.py
@@ -1293,7 +1293,7 @@ def outer_join(base_df):
         join_condition = calculate_expression(
             source_plan.join_expr, cartesian_product, analyzer, expr_to_alias
         )
-        join_result = cartesian_product[join_condition]
+        join_result = cartesian_product[join_condition].reset_index(drop=True)
         join_result.sf_types = cartesian_product.sf_types
         # TODO [GA]:
         # ERROR_ON_NONDETERMINISTIC_MERGE is by default True, raise error if
@@ -1715,7 +1715,7 @@ def calculate_expression(
             exp.datatype = StringType(len(exp.value))
         res = ColumnEmulator(
             data=[exp.value for _ in range(len(input_data))],
-            sf_type=ColumnType(exp.datatype, False),
+            sf_type=ColumnType(exp.datatype, nullable=exp.value is None),
             dtype=object,
         )
         res.index = input_data.index
@@ -2005,8 +2005,10 @@ def _match_pattern(row) -> bool:
 
     # Process partition_by clause
     if window_spec.partition_spec:
+        # Remove duplicate keys while maintaining order
+        keys = list(dict.fromkeys([exp.name for exp in window_spec.partition_spec]))
         res = res.groupby(
-            [exp.name for exp in window_spec.partition_spec],
+            keys,
            sort=False,
             as_index=False,
         )
@@ -2033,10 +2035,14 @@
             indexer = EntireWindowIndexer()
             rolling = res.rolling(indexer)
             windows = [ordered.loc[w.index] for w in rolling]
+            # rolling can unpredictably change the index of the data
+            # apply a trivial function to materialize the final index
+            res_index = list(rolling.count().index)
 
         elif isinstance(window_spec.frame_spec.frame_type, RowFrame):
             indexer = RowFrameIndexer(frame_spec=window_spec.frame_spec)
             res = res.rolling(indexer)
+            res_index = list(res.count().index)
             windows = [w for w in res]
 
         elif isinstance(window_spec.frame_spec.frame_type, RangeFrame):
@@ -2066,6 +2072,7 @@
                 isinstance(lower, UnboundedPreceding),
                 isinstance(upper, UnboundedFollowing),
             )
+
         # compute window function:
         if isinstance(window_function, (FunctionExpression,)):
             res_cols = []
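Reviewer note: several of the `_plan.py` hunks above hinge on pandas index and key handling. A minimal standalone sketch (illustrative only, not part of the patch) of the two idioms involved: boolean filtering keeps the original, now-gappy index labels until `reset_index(drop=True)` is applied, and `dict.fromkeys` de-duplicates keys while preserving first-seen order.

```python
import pandas as pd

# Boolean filtering keeps the original index labels, leaving gaps that can
# desynchronize later label-based alignment; reset_index(drop=True) restores
# a contiguous 0..n-1 RangeIndex, as in the join_result change above.
df = pd.DataFrame({"a": [1, 2, 3]})
filtered = df[df["a"] != 2]
assert list(filtered.index) == [0, 2]
assert list(filtered.reset_index(drop=True).index) == [0, 1]

# dict.fromkeys drops repeated keys while keeping first-seen order, unlike
# set(), which loses ordering; this is the de-duplication applied to the
# groupby partition keys above.
keys = ["id", "row_date", "id"]
assert list(dict.fromkeys(keys)) == ["id", "row_date"]
```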
diff --git a/tests/integ/scala/test_dataframe_writer_suite.py b/tests/integ/scala/test_dataframe_writer_suite.py
index 275f32e955..3d6a6ff689 100644
--- a/tests/integ/scala/test_dataframe_writer_suite.py
+++ b/tests/integ/scala/test_dataframe_writer_suite.py
@@ -11,7 +11,7 @@
 from snowflake.snowpark import Row
 from snowflake.snowpark._internal.utils import TempObjectType, parse_table_name
 from snowflake.snowpark.exceptions import SnowparkSQLException
-from snowflake.snowpark.functions import col, parse_json
+from snowflake.snowpark.functions import col, lit, parse_json
 from snowflake.snowpark.mock.exceptions import SnowparkLocalTestingException
 from snowflake.snowpark.types import (
     DoubleType,
@@ -89,6 +89,17 @@ def test_write_with_target_column_name_order(session, local_testing_mode):
         Utils.drop_table(session, special_table_name)
 
 
+def test_snow_1668862_repro_save_null_data(session):
+    table_name = Utils.random_table_name()
+    test_data = session.create_dataframe([(1,), (2,)], ["A"])
+    df = test_data.with_column("b", lit(None))
+    try:
+        df.write.save_as_table(table_name=table_name, mode="truncate")
+        assert session.table(table_name).collect() == [Row(1, None), Row(2, None)]
+    finally:
+        Utils.drop_table(session, table_name)
+
+
 def test_write_truncate_with_less_columns(session):
     # test truncate mode saving dataframe with fewer columns than the target table but column name in the same order
     schema1 = StructType(
diff --git a/tests/integ/scala/test_update_delete_merge_suite.py b/tests/integ/scala/test_update_delete_merge_suite.py
index c0afd3356d..41d195a15d 100644
--- a/tests/integ/scala/test_update_delete_merge_suite.py
+++ b/tests/integ/scala/test_update_delete_merge_suite.py
@@ -8,6 +8,7 @@
 
 import pytest
 
+from snowflake.connector.options import installed_pandas, pandas as pd
 from snowflake.snowpark import (
     DeleteResult,
     MergeResult,
@@ -21,6 +22,7 @@
 from snowflake.snowpark.exceptions import SnowparkTableException
 from snowflake.snowpark.functions import (
     col,
+    lit,
     max as max_,
     mean,
     min as min_,
@@ -668,3 +670,27 @@ def test_merge_multi_operation(session):
         ],
     )
     assert target.sort(col("id")).collect() == [Row(1, "a")]
+
+
+@pytest.mark.skipif(
+    not installed_pandas,
+    reason="Test requires pandas.",
+)
+def test_snow_1694649_repro_merge_with_equal_null(session):
+    # Force temp table
+    df1 = session.create_dataframe(pd.DataFrame({"A": [0, 1], "B": ["a", "b"]}))
+    df2 = session.create_dataframe(pd.DataFrame({"A": [0, 1], "B": ["a", "c"]}))
+
+    df1.merge(
+        source=df2,
+        join_expr=df1["A"].equal_null(df2["A"]),
+        clauses=[
+            when_matched(
+                ~(df1["A"].equal_null(df2["A"])) & (df1["B"].equal_null(df2["B"]))
+            ).update({"A": lit(3)})
+        ],
+    )
+    assert session.table(df1.table_name).order_by("A").collect() == [
+        Row(0, "a"),
+        Row(1, "b"),
+    ]
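Reviewer note: the merge repro above relies on `equal_null`, Snowflake's NULL-safe equality: two NULLs compare as equal and a NULL against a non-NULL compares as unequal, whereas plain `=` yields NULL in both cases. A plain-Python sketch (illustrative only) of the scalar semantics the local-testing mock has to reproduce inside the merge:

```python
def equal_null(a, b):
    # NULL-safe equality, with None standing in for SQL NULL:
    # NULL vs NULL     -> True
    # NULL vs non-NULL -> False
    # otherwise ordinary equality (plain `=` would yield NULL here instead).
    if a is None and b is None:
        return True
    if a is None or b is None:
        return False
    return a == b


assert equal_null(None, None) is True
assert equal_null(None, "a") is False
assert equal_null("a", "a") is True
assert equal_null("a", "b") is False
```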
diff --git a/tests/integ/scala/test_window_spec_suite.py b/tests/integ/scala/test_window_spec_suite.py
index 217100cca1..c2bb91360b 100644
--- a/tests/integ/scala/test_window_spec_suite.py
+++ b/tests/integ/scala/test_window_spec_suite.py
@@ -3,6 +3,7 @@
 # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
 #
 
+from datetime import date
 from decimal import Decimal
 
 import pytest
@@ -274,6 +275,30 @@ def test_window_function_should_fail_if_order_by_clause_is_not_specified(session):
     assert "requires ORDER BY in window specification" in str(ex_info)
 
 
+def test_snow_1360263_repro(session):
+    data = [
+        Row(id=1, row_date=date(2024, 1, 1), value=1),
+        Row(id=2, row_date=date(2024, 1, 1), value=1),
+        Row(id=1, row_date=date(2024, 1, 2), value=1),
+        Row(id=1, row_date=date(2024, 1, 2), value=100),
+        Row(id=2, row_date=date(2024, 1, 2), value=1),
+    ]
+
+    test_data = session.create_dataframe(data)
+
+    # partition over id and row_date and get the records with the largest values
+    window = Window.partition_by("id", "row_date").order_by(col("value").desc())
+    df = test_data.with_column("row_num", row_number().over(window)).where(
+        col("row_num") == 1
+    )
+    assert df.order_by("ID", "ROW_DATE").collect() == [
+        Row(1, date(2024, 1, 1), 1, 1),
+        Row(1, date(2024, 1, 2), 100, 1),
+        Row(2, date(2024, 1, 1), 1, 1),
+        Row(2, date(2024, 1, 2), 1, 1),
+    ]
+
+
 @pytest.mark.skipif(
     "config.getoption('local_testing_mode', default=False)",
     reason="corr is not yet supported in local testing mode.",
diff --git a/tests/integ/test_function.py b/tests/integ/test_function.py
index 070f57d45a..5797d11fbe 100644
--- a/tests/integ/test_function.py
+++ b/tests/integ/test_function.py
@@ -1816,17 +1816,27 @@ def test_date_operations_negative(session):
 
 def test_date_add_date_sub(session):
     df = session.createDataFrame(
-        [("2019-01-23"), ("2019-06-24"), ("2019-09-20")], ["date"]
+        [
+            ("2019-01-23"),
+            ("2019-06-24"),
+            ("2019-09-20"),
+            (None),
+        ],
+        ["date"],
     )
     df = df.withColumn("date", to_date("date"))
-    res = df.withColumn("date", date_add("date", 4)).collect()
-    assert res[0].DATE == datetime.date(2019, 1, 27)
-    assert res[1].DATE == datetime.date(2019, 6, 28)
-    assert res[2].DATE == datetime.date(2019, 9, 24)
-    res = df.withColumn("date", date_sub("date", 4)).collect()
-    assert res[0].DATE == datetime.date(2019, 1, 19)
-    assert res[1].DATE == datetime.date(2019, 6, 20)
-    assert res[2].DATE == datetime.date(2019, 9, 16)
+    assert df.withColumn("date", date_add("date", 4)).collect() == [
+        Row(datetime.date(2019, 1, 27)),
+        Row(datetime.date(2019, 6, 28)),
+        Row(datetime.date(2019, 9, 24)),
+        Row(None),
+    ]
+    assert df.withColumn("date", date_sub("date", 4)).collect() == [
+        Row(datetime.date(2019, 1, 19)),
+        Row(datetime.date(2019, 6, 20)),
+        Row(datetime.date(2019, 9, 16)),
+        Row(None),
+    ]
 
 
 @pytest.mark.skipif(
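Reviewer note: the `Row(None)` expectations above are exercised by the `add_timedelta` guard added in `_functions.py`: a NULL date now propagates through `date_add` and `date_sub` instead of raising a `TypeError` from `None + timedelta`. A standalone sketch of the patched helper (the function body is taken from the diff; the surrounding assertions, and the assumption that `date_sub` is wired through `scalar=-1`, are illustrative only):

```python
import datetime


def add_timedelta(unit, date, duration, scalar=1):
    # NULL (None) dates propagate unchanged instead of raising
    # TypeError: unsupported operand type(s) for +: 'NoneType' and 'timedelta'
    if date is None:
        return date
    return date + datetime.timedelta(**{f"{unit}s": float(duration) * scalar})


# date_add adds N units; negating scalar subtracts them (assumed date_sub wiring).
assert add_timedelta("day", datetime.date(2019, 1, 23), 4) == datetime.date(2019, 1, 27)
assert add_timedelta("day", datetime.date(2019, 1, 23), 4, scalar=-1) == datetime.date(2019, 1, 19)
assert add_timedelta("day", None, 4) is None
```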