diff --git a/CHANGELOG.md b/CHANGELOG.md index 04082df9c8..6ea6d01677 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,8 @@ - Fixed a bug that when processing time format, fractional second part is not handled properly. - Fixed a bug that caused DecimalType data to have incorrect precision in some cases. - Fixed a bug where referencing missing table or view raises confusing `IndexError`. +- Fixed a bug where the mocked function `to_timestamp_ntz` could not handle None data. +- Fixed a bug where mocked UDFs handled None output data improperly. #### Improvements diff --git a/src/snowflake/snowpark/mock/_functions.py b/src/snowflake/snowpark/mock/_functions.py index 251d827c19..350c5cbc40 100644 --- a/src/snowflake/snowpark/mock/_functions.py +++ b/src/snowflake/snowpark/mock/_functions.py @@ -80,9 +80,11 @@ def set_local_timezone(cls, tz: Optional[datetime.timezone] = None) -> None: cls.LOCAL_TZ = tz @classmethod - def to_local_timezone(cls, d: datetime.datetime) -> datetime.datetime: + def to_local_timezone( + cls, d: Optional[datetime.datetime] + ) -> Optional[datetime.datetime]: """Converts an input datetime to the local timezone.""" - return d.astimezone(tz=cls.LOCAL_TZ) + return d.astimezone(tz=cls.LOCAL_TZ) if d is not None else d @classmethod def replace_tz(cls, d: datetime.datetime) -> datetime.datetime: @@ -822,7 +824,9 @@ def mock_timestamp_ntz( result = _to_timestamp(column, fmt, try_cast, enforce_ltz=True) # Cast to NTZ by removing tz data if present return ColumnEmulator( - data=[x.replace(tzinfo=None) for x in result], + data=[ + try_convert(lambda x: x.replace(tzinfo=None), try_cast, x) for x in result + ], sf_type=ColumnType( TimestampType(TimestampTimeZone.NTZ), column.sf_type.nullable ), diff --git a/src/snowflake/snowpark/mock/_plan.py b/src/snowflake/snowpark/mock/_plan.py index 6af6e43274..f7ed9b40b1 100644 --- a/src/snowflake/snowpark/mock/_plan.py +++ b/src/snowflake/snowpark/mock/_plan.py @@ -536,10 +536,21 @@ def cleanup_imports(): child, 
input_data, analyzer, expr_to_alias ) - res = function_input.apply(lambda row: udf_handler(*row), axis=1) - res.sf_type = ColumnType(exp.datatype, exp.nullable) - res.name = quote_name( - f"{exp.udf_name}({', '.join(input_data.columns)})".upper() + # We do not use pd.apply here because pd.apply automatically infers the dtype of the output column. + # This loses the distinction between NaN and None. Consider the following udf definition: + # def udf(x): return numpy.sqrt(x) if x is not None else None + # When calling udf(-1) and udf(None), pd.apply infers a numeric dtype for the column, which yields NaN for both. + # However, we want NaN for the former case and None for the latter case. + # Using dtype object + function execution does not have this limitation. + # In the future maybe we could call fix_drift_between_column_sf_type_and_dtype in methods like set_sf_type. + # And this code would look like: + # res=input.apply(...) + # res.set_sf_type(ColumnType(exp.datatype, exp.nullable)) # fixes the drift and removes NaN + res = ColumnEmulator( + data=[udf_handler(*row) for _, row in function_input.iterrows()], + sf_type=ColumnType(exp.datatype, exp.nullable), + name=quote_name(f"{exp.udf_name}({', '.join(input_data.columns)})".upper()), + dtype=object, ) return res diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index f1c7d5d1b2..38ecc39f2b 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -2328,10 +2328,6 @@ def test_numpy_udf(session, func): ) -@pytest.mark.skipif( - "config.getoption('local_testing_mode', default=False)", - reason="SNOW-1370447: mock_timestamp_ntz raises error", -) @pytest.mark.skipif( not is_pandas_available, reason="pandas required for vectorized UDF" ) @@ -2392,6 +2388,57 @@ def func_tz_udf(x: Timestamp[TZ]) -> Timestamp[TZ]: ) +@pytest.mark.skipif( + not is_pandas_available, reason="pandas required for vectorized UDF" +) +def test_udf_return_none(session): + data = [ + [ + 1, + "a", + "a", + ], + [ + 2, + "b", 
"b", + ], + [None, None, None], + ] + schema = StructType( + [ + StructField('"int"', IntegerType()), + StructField('"str"', StringType()), + StructField('"var"', VariantType()), + ] + ) + df = session.create_dataframe(data, schema=schema) + + def f(x): + return x if x is not None else None + + @udf + def func_int_udf(x: int) -> int: + return f(x) + + @udf + def func_str_udf(x: str) -> str: + return f(x) + + @udf + def func_var_udf(x: Variant) -> Variant: + return f(x) + + Utils.check_answer( + df.select( + func_int_udf('"int"'), + func_str_udf('"str"'), + func_var_udf('"var"'), + ), + [Row(1, "a", '"a"'), Row(2, "b", '"b"'), Row(None, None, None)], + ) + + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="Vectorized UDTF is not supported in Local Testing",