SNOW-1360263: [Local Testing] Tests for several issues. (#2410)

sfc-gh-jrose authored Oct 10, 2024
1 parent 6555b0c commit 43d37a3
Showing 7 changed files with 105 additions and 13 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
@@ -66,6 +66,17 @@
- Fixed a bug where `pd.to_numeric()` would leave `Timedelta` inputs as `Timedelta` instead of converting them to integers.
- Fixed `loc` set when setting a single row, or multiple rows, of a DataFrame with a Series value.

### Snowpark Local Testing Updates

#### Bug Fixes

- Fixed a bug where nullable columns were annotated incorrectly.
- Fixed a bug where the `date_add` and `date_sub` functions failed for `NULL` values.
- Fixed a bug where `equal_null` could fail inside a merge statement.
- Fixed a bug where `row_number` could fail inside a Window function.
- Fixed a bug where updates could fail when the source is the result of a join.


## 1.22.1 (2024-09-11)
This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content.

2 changes: 2 additions & 0 deletions src/snowflake/snowpark/mock/_functions.py
@@ -1530,6 +1530,8 @@ def add_months(scalar, date, duration):


def add_timedelta(unit, date, duration, scalar=1):
    if date is None:
        return date
    return date + datetime.timedelta(**{f"{unit}s": float(duration) * scalar})
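The guard added above makes SQL `NULL` propagate through interval arithmetic. A self-contained sketch of the same behavior (a hypothetical standalone copy for illustration, not the mock module itself):

    import datetime

    def add_timedelta(unit, date, duration, scalar=1):
        # NULL in, NULL out: date arithmetic on SQL NULL yields NULL
        if date is None:
            return date
        return date + datetime.timedelta(**{f"{unit}s": float(duration) * scalar})

    assert add_timedelta("day", None, 4) is None
    assert add_timedelta("day", datetime.date(2019, 1, 23), 4) == datetime.date(2019, 1, 27)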


13 changes: 10 additions & 3 deletions src/snowflake/snowpark/mock/_plan.py
@@ -1293,7 +1293,7 @@ def outer_join(base_df):
    join_condition = calculate_expression(
        source_plan.join_expr, cartesian_product, analyzer, expr_to_alias
    )
-   join_result = cartesian_product[join_condition]
+   join_result = cartesian_product[join_condition].reset_index(drop=True)
    join_result.sf_types = cartesian_product.sf_types

    # TODO [GA]: # ERROR_ON_NONDETERMINISTIC_MERGE is by default True, raise error if
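The `reset_index(drop=True)` matters because pandas boolean-mask filtering keeps the original row labels, leaving a gappy index on the join result. A plain-pandas sketch (independent of the mock's `ColumnEmulator` types):

    import pandas as pd

    cartesian_product = pd.DataFrame({"a": [1, 2, 3]})
    join_condition = cartesian_product["a"] > 1

    join_result = cartesian_product[join_condition]
    print(list(join_result.index))  # [1, 2] -- gaps can break positional logic downstream

    join_result = join_result.reset_index(drop=True)
    print(list(join_result.index))  # [0, 1]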
@@ -1715,7 +1715,7 @@ def calculate_expression(
    exp.datatype = StringType(len(exp.value))
    res = ColumnEmulator(
        data=[exp.value for _ in range(len(input_data))],
-       sf_type=ColumnType(exp.datatype, False),
+       sf_type=ColumnType(exp.datatype, nullable=exp.value is None),
        dtype=object,
    )
    res.index = input_data.index
@@ -2005,8 +2005,10 @@ def _match_pattern(row) -> bool:

    # Process partition_by clause
    if window_spec.partition_spec:
        # Remove duplicate keys while maintaining order
        keys = list(dict.fromkeys([exp.name for exp in window_spec.partition_spec]))
        res = res.groupby(
-           [exp.name for exp in window_spec.partition_spec],
+           keys,
            sort=False,
            as_index=False,
        )
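The `dict.fromkeys` idiom relies on insertion-ordered dicts (Python 3.7+) to drop repeated partition keys, which could otherwise trip up the pandas `groupby` — the `row_number` fix noted in the changelog. In isolation (illustrative column names):

    # First occurrence of each key wins; original order is preserved
    keys = list(dict.fromkeys(["ID", "ROW_DATE", "ID"]))
    assert keys == ["ID", "ROW_DATE"]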
@@ -2033,10 +2035,14 @@ def _match_pattern(row) -> bool:
        indexer = EntireWindowIndexer()
        rolling = res.rolling(indexer)
        windows = [ordered.loc[w.index] for w in rolling]
        # rolling can unpredictably change the index of the data
        # apply a trivial function to materialize the final index
        res_index = list(rolling.count().index)

    elif isinstance(window_spec.frame_spec.frame_type, RowFrame):
        indexer = RowFrameIndexer(frame_spec=window_spec.frame_spec)
        res = res.rolling(indexer)
        res_index = list(res.count().index)
        windows = [w for w in res]
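The `res_index` bookkeeping exists because, as the comment above notes, rolling with a custom indexer can change the row labels; applying a cheap aggregation like `count()` materializes the index the real computation will see. A toy pandas indexer in the same spirit as `EntireWindowIndexer` (the class here is my own illustration, not the mock's actual implementation; assumes pandas >= 1.5 for the `step` argument):

    import numpy as np
    import pandas as pd
    from pandas.api.indexers import BaseIndexer

    class WholeFrameIndexer(BaseIndexer):
        # every row's window spans the entire frame
        def get_window_bounds(self, num_values=0, min_periods=None,
                              center=None, closed=None, step=None):
            start = np.zeros(num_values, dtype=np.int64)
            end = np.full(num_values, num_values, dtype=np.int64)
            return start, end

    s = pd.Series([10, 20, 30], index=[7, 3, 5])  # deliberately non-default index
    rolling = s.rolling(WholeFrameIndexer(), min_periods=1)
    print(rolling.sum().tolist())       # [60.0, 60.0, 60.0]
    print(list(rolling.count().index))  # [7, 3, 5] -- the materialized final index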

elif isinstance(window_spec.frame_spec.frame_type, RangeFrame):
@@ -2066,6 +2072,7 @@ def _match_pattern(row) -> bool:
        isinstance(lower, UnboundedPreceding),
        isinstance(upper, UnboundedFollowing),
    )

    # compute window function:
    if isinstance(window_function, (FunctionExpression,)):
        res_cols = []
13 changes: 12 additions & 1 deletion tests/integ/scala/test_dataframe_writer_suite.py
@@ -11,7 +11,7 @@
from snowflake.snowpark import Row
from snowflake.snowpark._internal.utils import TempObjectType, parse_table_name
from snowflake.snowpark.exceptions import SnowparkSQLException
-from snowflake.snowpark.functions import col, parse_json
+from snowflake.snowpark.functions import col, lit, parse_json
from snowflake.snowpark.mock.exceptions import SnowparkLocalTestingException
from snowflake.snowpark.types import (
    DoubleType,
@@ -89,6 +89,17 @@ def test_write_with_target_column_name_order(session, local_testing_mode):
        Utils.drop_table(session, special_table_name)


def test_snow_1668862_repro_save_null_data(session):
    table_name = Utils.random_table_name()
    test_data = session.create_dataframe([(1,), (2,)], ["A"])
    df = test_data.with_column("b", lit(None))
    try:
        df.write.save_as_table(table_name=table_name, mode="truncate")
        assert session.table(table_name).collect() == [Row(1, None), Row(2, None)]
    finally:
        Utils.drop_table(session, table_name)
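This repro appears to exercise the literal-nullability change in `_plan.py` above: `lit(None)` produces a literal column whose `ColumnType` is now marked nullable, so truncating into a table with an all-`NULL` column no longer fails under local testing.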


def test_write_truncate_with_less_columns(session):
    # test truncate mode saving a dataframe with fewer columns than the target table, but with column names in the same order
    schema1 = StructType(
26 changes: 26 additions & 0 deletions tests/integ/scala/test_update_delete_merge_suite.py
@@ -8,6 +8,7 @@

import pytest

from snowflake.connector.options import installed_pandas, pandas as pd
from snowflake.snowpark import (
    DeleteResult,
    MergeResult,
@@ -21,6 +22,7 @@
from snowflake.snowpark.exceptions import SnowparkTableException
from snowflake.snowpark.functions import (
    col,
    lit,
    max as max_,
    mean,
    min as min_,
@@ -668,3 +670,27 @@ def test_merge_multi_operation(session):
        ],
    )
    assert target.sort(col("id")).collect() == [Row(1, "a")]


@pytest.mark.skipif(
    not installed_pandas,
    reason="Test requires pandas.",
)
def test_snow_1694649_repro_merge_with_equal_null(session):
    # Force temp table
    df1 = session.create_dataframe(pd.DataFrame({"A": [0, 1], "B": ["a", "b"]}))
    df2 = session.create_dataframe(pd.DataFrame({"A": [0, 1], "B": ["a", "c"]}))

    df1.merge(
        source=df2,
        join_expr=df1["A"].equal_null(df2["A"]),
        clauses=[
            when_matched(
                ~(df1["A"].equal_null(df2["A"])) & (df1["B"].equal_null(df2["B"]))
            ).update({"A": lit(3)})
        ],
    )
    assert session.table(df1.table_name).order_by("A").collect() == [
        Row(0, "a"),
        Row(1, "b"),
    ]
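For reference, `equal_null` is NULL-safe equality, which is why it behaves differently from `==` inside a merge condition. A quick sketch (column name and values are illustrative):

    from snowflake.snowpark.functions import col, lit

    df = session.create_dataframe([[None]], schema=["a"])
    df.select(
        (col("a") == lit(None)).alias("eq"),          # plain equality with NULL yields NULL
        col("a").equal_null(lit(None)).alias("eqn"),  # NULL-safe equality yields True
    ).collect()
    # expected: [Row(EQ=None, EQN=True)]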
25 changes: 25 additions & 0 deletions tests/integ/scala/test_window_spec_suite.py
@@ -3,6 +3,7 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from datetime import date
from decimal import Decimal

import pytest
@@ -274,6 +275,30 @@ def test_window_function_should_fail_if_order_by_clause_is_not_specified(session
assert "requires ORDER BY in window specification" in str(ex_info)


def test_snow_1360263_repro(session):
    data = [
        Row(id=1, row_date=date(2024, 1, 1), value=1),
        Row(id=2, row_date=date(2024, 1, 1), value=1),
        Row(id=1, row_date=date(2024, 1, 2), value=1),
        Row(id=1, row_date=date(2024, 1, 2), value=100),
        Row(id=2, row_date=date(2024, 1, 2), value=1),
    ]

    test_data = session.create_dataframe(data)

    # partition over id and row_date and get the records with the largest values
    window = Window.partition_by("id", "row_date").order_by(col("value").desc())
    df = test_data.with_column("row_num", row_number().over(window)).where(
        col("row_num") == 1
    )
    assert df.order_by("ID", "ROW_DATE").collect() == [
        Row(1, date(2024, 1, 1), 1, 1),
        Row(1, date(2024, 1, 2), 100, 1),
        Row(2, date(2024, 1, 1), 1, 1),
        Row(2, date(2024, 1, 2), 1, 1),
    ]


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="corr is not yet supported in local testing mode.",
28 changes: 19 additions & 9 deletions tests/integ/test_function.py
@@ -1816,17 +1816,27 @@ def test_date_operations_negative(session):

def test_date_add_date_sub(session):
    df = session.createDataFrame(
-       [("2019-01-23"), ("2019-06-24"), ("2019-09-20")], ["date"]
+       [
+           ("2019-01-23"),
+           ("2019-06-24"),
+           ("2019-09-20"),
+           (None),
+       ],
+       ["date"],
    )
    df = df.withColumn("date", to_date("date"))
-   res = df.withColumn("date", date_add("date", 4)).collect()
-   assert res[0].DATE == datetime.date(2019, 1, 27)
-   assert res[1].DATE == datetime.date(2019, 6, 28)
-   assert res[2].DATE == datetime.date(2019, 9, 24)
-   res = df.withColumn("date", date_sub("date", 4)).collect()
-   assert res[0].DATE == datetime.date(2019, 1, 19)
-   assert res[1].DATE == datetime.date(2019, 6, 20)
-   assert res[2].DATE == datetime.date(2019, 9, 16)
+   assert df.withColumn("date", date_add("date", 4)).collect() == [
+       Row(datetime.date(2019, 1, 27)),
+       Row(datetime.date(2019, 6, 28)),
+       Row(datetime.date(2019, 9, 24)),
+       Row(None),
+   ]
+   assert df.withColumn("date", date_sub("date", 4)).collect() == [
+       Row(datetime.date(2019, 1, 19)),
+       Row(datetime.date(2019, 6, 20)),
+       Row(datetime.date(2019, 9, 16)),
+       Row(None),
+   ]


@pytest.mark.skipif(
Expand Down
