diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index 2c143eaf74..9ca94fd049 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -7003,6 +7003,116 @@ def pandas_udf( ) +def pandas_udtf( + handler: Optional[Callable] = None, + *, + output_schema: Union[StructType, List[str]], + input_types: Optional[List[DataType]] = None, + name: Optional[Union[str, Iterable[str]]] = None, + is_permanent: bool = False, + stage_location: Optional[str] = None, + imports: Optional[List[Union[str, Tuple[str, str]]]] = None, + packages: Optional[List[Union[str, ModuleType]]] = None, + replace: bool = False, + if_not_exists: bool = False, + session: Optional["snowflake.snowpark.session.Session"] = None, + parallel: int = 4, + statement_params: Optional[Dict[str, str]] = None, + strict: bool = False, + secure: bool = False, +) -> Union[UserDefinedTableFunction, functools.partial]: + """Registers a Python class as a vectorized Python UDTF and returns the UDTF. + + The arguments, return value and usage of this function are exactly the same as + :func:`udtf`, but this function can only be used for registering vectorized UDTFs. + See examples in :class:`~snowflake.snowpark.udtf.UDTFRegistration`. + + See Also: + - :func:`udtf` + - :meth:`UDTFRegistration.register() <snowflake.snowpark.udtf.UDTFRegistration.register>` + + Example:: + >>> from snowflake.snowpark.types import PandasDataFrameType, IntegerType, StringType, FloatType + >>> class multiply: + ... def __init__(self): + ... self.multiplier = 10 + ... def end_partition(self, df): + ... df.columns = ['id', 'col1', 'col2'] + ... df.col1 = df.col1*self.multiplier + ... df.col2 = df.col2*self.multiplier + ... yield df + >>> multiply_udtf = pandas_udtf( + ... multiply, + ... output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]), + ... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])] + ... 
) + >>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"]) + >>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show() + ----------------------------- + |"ID_" |"COL1_" |"COL2_" | + ----------------------------- + |x |30 |359.0 | + |x |90 |205.0 | + ----------------------------- + + + Example:: + + >>> @pandas_udtf(output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]), input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])]) + ... class _multiply: + ... def __init__(self): + ... self.multiplier = 10 + ... def end_partition(self, df): + ... df.columns = ['id', 'col1', 'col2'] + ... df.col1 = df.col1*self.multiplier + ... df.col2 = df.col2*self.multiplier + ... yield df + >>> df.select(_multiply("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show() + ----------------------------- + |"ID_" |"COL1_" |"COL2_" | + ----------------------------- + |x |30 |359.0 | + |x |90 |205.0 | + ----------------------------- + + """ + session = session or snowflake.snowpark.session._get_active_session() + if handler is None: + return functools.partial( + session.udtf.register, + output_schema=output_schema, + input_types=input_types, + name=name, + is_permanent=is_permanent, + stage_location=stage_location, + imports=imports, + packages=packages, + replace=replace, + if_not_exists=if_not_exists, + parallel=parallel, + statement_params=statement_params, + strict=strict, + secure=secure, + ) + else: + return session.udtf.register( + handler, + output_schema=output_schema, + input_types=input_types, + name=name, + is_permanent=is_permanent, + stage_location=stage_location, + imports=imports, + packages=packages, + replace=replace, + if_not_exists=if_not_exists, + parallel=parallel, + statement_params=statement_params, + strict=strict, + secure=secure, + ) + + def call_udf( udf_name: str, *args: ColumnOrLiteral, diff 
--git a/src/snowflake/snowpark/udtf.py b/src/snowflake/snowpark/udtf.py index 35266c261f..ff19ab8df5 100644 --- a/src/snowflake/snowpark/udtf.py +++ b/src/snowflake/snowpark/udtf.py @@ -297,6 +297,94 @@ class UDTFRegistration: - :meth:`~snowflake.snowpark.Session.add_packages` - :meth:`~snowflake.snowpark.Session.table_function` - :meth:`~snowflake.snowpark.DataFrame.join_table_function` + + Compared to the default row-by-row processing pattern of a normal UDTF, which sometimes is + inefficient, a vectorized UDTF allows vectorized operations on a dataframe, with the input as a + `Pandas DataFrame <https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html>`_. In a + vectorized UDTF, you can operate on batches of rows by handling a Pandas DataFrame or Pandas + Series. You can use :func:`~snowflake.snowpark.functions.udtf`, :meth:`register` or + :func:`~snowflake.snowpark.functions.pandas_udtf` to create a vectorized UDTF by providing + appropriate return and input types. If you would like to use :meth:`register_from_file` to + create a vectorized UDTF, you would need to explicitly mark the handler method as vectorized using + either the decorator `@vectorized(input=pandas.DataFrame)` or setting `.end_partition._sf_vectorized_input = pandas.DataFrame`. + + Example 11 + Creating a vectorized UDTF by specifying a `PandasDataFrameType` as `input_types` and a `PandasDataFrameType` with column names as `output_schema`. + >>> from snowflake.snowpark.types import PandasDataFrameType, IntegerType, StringType, FloatType + >>> class multiply: + ... def __init__(self): + ... self.multiplier = 10 + ... def end_partition(self, df): + ... df.columns = ['id', 'col1', 'col2'] + ... df.col1 = df.col1*self.multiplier + ... df.col2 = df.col2*self.multiplier + ... yield df + >>> multiply_udtf = session.udtf.register( + ... multiply, + ... output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]), + ... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])] + ... 
) + >>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"]) + >>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show() + ----------------------------- + |"ID_" |"COL1_" |"COL2_" | + ----------------------------- + |x |30 |359.0 | + |x |90 |205.0 | + ----------------------------- + + + Example 12 + Creating a vectorized UDTF by specifying `PandasDataFrame` with nested types as type hints. + >>> from snowflake.snowpark.types import PandasDataFrame + >>> class multiply: + ... def __init__(self): + ... self.multiplier = 10 + ... def end_partition(self, df: PandasDataFrame[str, int, float]) -> PandasDataFrame[str, int, float]: + ... df.columns = ['id', 'col1', 'col2'] + ... df.col1 = df.col1*self.multiplier + ... df.col2 = df.col2*self.multiplier + ... yield df + >>> multiply_udtf = session.udtf.register( + ... multiply, + ... output_schema=["id_", "col1_", "col2_"], + ... ) + >>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"]) + >>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show() + ----------------------------- + |"ID_" |"COL1_" |"COL2_" | + ----------------------------- + |x |30 |359.0 | + |x |90 |205.0 | + ----------------------------- + + + Example 13 + Creating a vectorized UDTF by specifying a `pandas.DataFrame` as type hints and a `StructType` with type information and column names as `output_schema`. + >>> import pandas as pd + >>> from snowflake.snowpark.types import IntegerType, StringType, FloatType, StructType, StructField + >>> class multiply: + ... def __init__(self): + ... self.multiplier = 10 + ... def end_partition(self, df: pd.DataFrame) -> pd.DataFrame: + ... df.columns = ['id', 'col1', 'col2'] + ... df.col1 = df.col1*self.multiplier + ... df.col2 = df.col2*self.multiplier + ... yield df + >>> multiply_udtf = session.udtf.register( + ... multiply, + ... 
output_schema=StructType([StructField("id_", StringType()), StructField("col1_", IntegerType()), StructField("col2_", FloatType())]), + ... input_types=[StringType(), IntegerType(), FloatType()] + ... ) + >>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"]) + >>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show() + ----------------------------- + |"ID_" |"COL1_" |"COL2_" | + ----------------------------- + |x |30 |359.0 | + |x |90 |205.0 | + ----------------------------- + """ def __init__(self, session: "snowflake.snowpark.Session") -> None: