Skip to content

Commit

Permalink
SNOW-1370447: local testing fix bug in udf returning result containin…
Browse files Browse the repository at this point in the history
…g null value (#1615)
  • Loading branch information
sfc-gh-aling authored May 23, 2024
1 parent ece1310 commit 8cd50ca
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
- Fixed a bug where the fractional-second part was not handled properly when processing time formats.
- Fixed a bug that caused DecimalType data to have incorrect precision in some cases.
- Fixed a bug where referencing missing table or view raises confusing `IndexError`.
- Fixed a bug where the mocked function `to_timestamp_ntz` could not handle None data.
- Fixed a bug where mocked UDFs handled None output data improperly.

#### Improvements

Expand Down
10 changes: 7 additions & 3 deletions src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ def set_local_timezone(cls, tz: Optional[datetime.timezone] = None) -> None:
cls.LOCAL_TZ = tz

@classmethod
def to_local_timezone(cls, d: datetime.datetime) -> datetime.datetime:
def to_local_timezone(
cls, d: Optional[datetime.datetime]
) -> Optional[datetime.datetime]:
"""Converts an input datetime to the local timezone."""
return d.astimezone(tz=cls.LOCAL_TZ)
return d.astimezone(tz=cls.LOCAL_TZ) if d is not None else d

@classmethod
def replace_tz(cls, d: datetime.datetime) -> datetime.datetime:
Expand Down Expand Up @@ -822,7 +824,9 @@ def mock_timestamp_ntz(
result = _to_timestamp(column, fmt, try_cast, enforce_ltz=True)
# Cast to NTZ by removing tz data if present
return ColumnEmulator(
data=[x.replace(tzinfo=None) for x in result],
data=[
try_convert(lambda x: x.replace(tzinfo=None), try_cast, x) for x in result
],
sf_type=ColumnType(
TimestampType(TimestampTimeZone.NTZ), column.sf_type.nullable
),
Expand Down
19 changes: 15 additions & 4 deletions src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,21 @@ def cleanup_imports():
child, input_data, analyzer, expr_to_alias
)

res = function_input.apply(lambda row: udf_handler(*row), axis=1)
res.sf_type = ColumnType(exp.datatype, exp.nullable)
res.name = quote_name(
f"{exp.udf_name}({', '.join(input_data.columns)})".upper()
# We do not use pd.apply here because pd.apply automatically infers the dtype of the output column,
# which loses the distinction between NaN and None. Consider the following UDF definition:
# def udf(x): return numpy.sqrt(x) if x is not None else None
# When calling udf(-1) and udf(None), pd.apply infers a float dtype for the column and returns NaN for both;
# however, we want NaN for the former case and None for the latter case.
# Building the column with dtype=object and executing the function directly does not have this limitation.
# In the future, we could perhaps call fix_drift_between_column_sf_type_and_dtype in methods like set_sf_type,
# and this code would then look like:
# res=input.apply(...)
# res.set_sf_type(ColumnType(exp.datatype, exp.nullable)) # fixes the drift and removes NaN
res = ColumnEmulator(
data=[udf_handler(*row) for _, row in function_input.iterrows()],
sf_type=ColumnType(exp.datatype, exp.nullable),
name=quote_name(f"{exp.udf_name}({', '.join(input_data.columns)})".upper()),
dtype=object,
)

return res
Expand Down
55 changes: 51 additions & 4 deletions tests/integ/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2328,10 +2328,6 @@ def test_numpy_udf(session, func):
)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="SNOW-1370447: mock_timestamp_ntz raises error",
)
@pytest.mark.skipif(
not is_pandas_available, reason="pandas required for vectorized UDF"
)
Expand Down Expand Up @@ -2392,6 +2388,57 @@ def func_tz_udf(x: Timestamp[TZ]) -> Timestamp[TZ]:
)


# NOTE(review): none of the UDFs below is declared as vectorized, so the skip
# reason looks copied from a neighbouring test -- confirm whether pandas is
# really required here.
@pytest.mark.skipif(
    not is_pandas_available, reason="pandas required for vectorized UDF"
)
def test_udf_return_none(session):
    """Check that UDFs returning ``None`` produce ``None`` in the result rows.

    Builds a three-column dataframe (integer, string, variant) whose last row
    is entirely ``None``, applies one pass-through UDF per column, and compares
    the selected output with the expected rows.
    """
    # Two populated rows plus one all-None row; the None row is the case under test.
    data = [
        [
            1,
            "a",
            "a",
        ],
        [
            2,
            "b",
            "b",
        ],
        [None, None, None],
    ]
    # Column names are double-quoted so they are passed to Snowflake verbatim.
    schema = StructType(
        [
            StructField('"int"', IntegerType()),
            StructField('"str"', StringType()),
            StructField('"var"', VariantType()),
        ]
    )
    df = session.create_dataframe(data, schema=schema)

    # Shared pass-through helper: returns its argument, with an explicit None branch.
    def f(x):
        return x if x is not None else None

    # One UDF per column type; each simply delegates to f.
    @udf
    def func_int_udf(x: int) -> int:
        return f(x)

    @udf
    def func_str_udf(x: str) -> str:
        return f(x)

    @udf
    def func_var_udf(x: Variant) -> Variant:
        return f(x)

    # The variant column is expected back as a JSON-quoted string ('"a"'),
    # while the all-None input row is expected to stay all-None.
    Utils.check_answer(
        df.select(
            func_int_udf('"int"'),
            func_str_udf('"str"'),
            func_var_udf('"var"'),
        ),
        [Row(1, "a", '"a"'), Row(2, "b", '"b"'), Row(None, None, None)],
    )


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="Vectorized UDTF is not supported in Local Testing",
Expand Down

0 comments on commit 8cd50ca

Please sign in to comment.