Skip to content

Commit

Permalink
Allow process to have return type hint of None
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-stan committed Jul 25, 2023
1 parent 2aac082 commit ab43f24
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 12 deletions.
35 changes: 24 additions & 11 deletions src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from snowflake.snowpark._internal.analyzer.datatype_mapper import to_sql
from snowflake.snowpark._internal.telemetry import TelemetryField
from snowflake.snowpark._internal.type_utils import (
NoneType,
convert_sp_to_sf_type,
infer_type,
python_type_str_to_object,
Expand Down Expand Up @@ -105,6 +106,16 @@ def get_python_types_dict_for_udaf(
return python_types_dict


def get_python_types_dict_for_udtf(
    process: Dict[str, Any], end_partition: Dict[str, Any]
) -> Dict[str, Any]:
    """Merge the type hints of a UDTF handler's two methods into one dict.

    Args:
        process: Type-hint mapping collected from the ``process`` method
            (may be empty).
        end_partition: Type-hint mapping collected from the ``end_partition``
            method (may be empty).

    Returns:
        A new dict holding every key from both mappings. On a key collision
        the value from ``process`` is kept, except for ``"return"``, which is
        taken from ``end_partition`` whenever ``end_partition`` provides it.
        Neither argument is modified.
    """
    # Start from end_partition so that its keys keep their position, then let
    # process override any entry the two mappings share.
    merged = dict(end_partition)
    merged.update(process)
    # The return hint is the one exception: end_partition's wins if present.
    if "return" in end_partition:
        merged["return"] = end_partition["return"]
    return merged


def extract_return_type_from_udtf_type_hints(
return_type_hint, output_schema, func_name
) -> Union[StructType, "PandasDataFrameType", None]:
Expand Down Expand Up @@ -169,9 +180,11 @@ def extract_return_type_from_udtf_type_hints(
return PandasDataFrameType(
[]
) # placeholder, indicating the return type is pandas DataFrame
elif return_type_hint is NoneType:
return None
else:
raise ValueError(
f"The return type hint for a UDTF handler must be a collection type or a PandasDataFrame. {return_type_hint} is used."
f"The return type hint for a UDTF handler must be a collection type or None or a PandasDataFrame. {return_type_hint} is used."
)


Expand Down Expand Up @@ -205,18 +218,20 @@ def get_types_from_type_hints(
raise AttributeError(
f"Neither `{TABLE_FUNCTION_PROCESS_METHOD}` nor `{TABLE_FUNCTION_END_PARTITION_METHOD}` is defined for class {func}"
)
python_types_dict = {}
process_types_dict = {}
end_partition_types_dict = {}
# PROCESS and END_PARTITION may declare different input and return types: PROCESS's input types take precedence, and END_PARTITION's return type is used whenever it declares one
if hasattr(func, TABLE_FUNCTION_PROCESS_METHOD):
python_types_dict = get_type_hints(
process_types_dict = get_type_hints(
getattr(func, TABLE_FUNCTION_PROCESS_METHOD)
)
if not python_types_dict and hasattr(
func, TABLE_FUNCTION_END_PARTITION_METHOD
):
python_types_dict = get_type_hints(
if hasattr(func, TABLE_FUNCTION_END_PARTITION_METHOD):
end_partition_types_dict = get_type_hints(
getattr(func, TABLE_FUNCTION_END_PARTITION_METHOD)
)
python_types_dict = get_python_types_dict_for_udtf(
process_types_dict, end_partition_types_dict
)
else:
python_types_dict = get_type_hints(func)
except TypeError:
Expand Down Expand Up @@ -251,10 +266,8 @@ def get_types_from_type_hints(
raise ValueError(
f"Neither {func_name}.{TABLE_FUNCTION_PROCESS_METHOD} or {func_name}.{TABLE_FUNCTION_END_PARTITION_METHOD} could be found from {filename}"
)
python_types_dict = (
process_types_dict
if process_types_dict is not None
else end_partition_types_dict
python_types_dict = get_python_types_dict_for_udtf(
process_types_dict or {}, end_partition_types_dict or {}
)
elif object_type in (TempObjectType.FUNCTION, TempObjectType.PROCEDURE):
python_types_dict = retrieve_func_type_hints_from_source(
Expand Down
38 changes: 38 additions & 0 deletions tests/integ/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,3 +547,41 @@ def end_partition(self, df: pandas.DataFrame) -> pandas.DataFrame:
)

assert_vectorized_udtf_result(session.table(vectorized_udtf_test_table), my_udtf)


@pytest.mark.parametrize("from_file", [True, False])
def test_register_udtf_where_process_returns_None(session, resources_path, from_file):
    """Register a UDTF whose ``process`` is annotated ``-> None`` and query it.

    The handler's ``end_partition`` yields the single row ``(1,)``, so the
    table function is expected to produce exactly ``Row(INT_=1)``. The handler
    is registered either from the resource file or from an inline class,
    depending on ``from_file``.
    """
    test_files = TestFiles(resources_path)
    output_columns = ["int_"]

    if not from_file:

        class ProcessReturnsNone:
            def process(self, a: int, b: int, c: int) -> None:
                pass

            def end_partition(self) -> Iterable[Tuple[int]]:
                yield (1,)

        my_udtf = udtf(ProcessReturnsNone, output_schema=output_columns)
    else:
        my_udtf = session.udtf.register_from_file(
            test_files.test_udtf_py_file,
            "ProcessReturnsNone",
            output_schema=output_columns,
        )
        # When registered from a file, the handler is expected to be a tuple.
        assert isinstance(my_udtf.handler, tuple)

    call_args = (lit(1), lit(2), lit(3))
    result_df = session.table_function(my_udtf(*call_args))
    Utils.check_answer(result_df, [Row(INT_=1)])
10 changes: 9 additions & 1 deletion tests/resources/test_udtf_dir/test_udtf_file.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import decimal
from typing import List, Tuple
from typing import Iterable, List, Tuple


class MyUDTFWithTypeHints:
Expand Down Expand Up @@ -54,3 +54,11 @@ class GeneratorUDTF:
def process(self, n):
for i in range(n):
yield (i,)


class ProcessReturnsNone:
    """Handler whose ``process`` returns nothing and whose ``end_partition`` yields one row."""

    def process(self, a: int, b: int, c: int) -> None:
        """Accept three integers and return ``None``."""

    def end_partition(self) -> Iterable[Tuple[int]]:
        """Yield the single row ``(1,)``."""
        yield from [(1,)]

0 comments on commit ab43f24

Please sign in to comment.