Skip to content

Commit

Permalink
Add documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-stan committed Jul 20, 2023
1 parent 5401f77 commit 24c909d
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 0 deletions.
110 changes: 110 additions & 0 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7003,6 +7003,116 @@ def pandas_udf(
)


def pandas_udtf(
    handler: Optional[Callable] = None,
    *,
    output_schema: Union[StructType, List[str]],
    input_types: Optional[List[DataType]] = None,
    name: Optional[Union[str, Iterable[str]]] = None,
    is_permanent: bool = False,
    stage_location: Optional[str] = None,
    imports: Optional[List[Union[str, Tuple[str, str]]]] = None,
    packages: Optional[List[Union[str, ModuleType]]] = None,
    replace: bool = False,
    if_not_exists: bool = False,
    session: Optional["snowflake.snowpark.session.Session"] = None,
    parallel: int = 4,
    statement_params: Optional[Dict[str, str]] = None,
    strict: bool = False,
    secure: bool = False,
) -> Union[UserDefinedTableFunction, functools.partial]:
    """Registers a Python class as a vectorized Python UDTF and returns the UDTF.

    The arguments, return value and usage of this function are exactly the same as
    :func:`udtf`, but this function can only be used for registering vectorized UDTFs.
    See examples in :class:`~snowflake.snowpark.udtf.UDTFRegistration`.

    See Also:
        - :func:`udtf`
        - :meth:`UDTFRegistration.register() <snowflake.snowpark.udtf.UDTFRegistration.register>`

    Example::

        >>> from snowflake.snowpark.types import PandasDataFrameType, IntegerType, StringType, FloatType
        >>> class multiply:
        ...     def __init__(self):
        ...         self.multiplier = 10
        ...     def end_partition(self, df):
        ...         df.columns = ['id', 'col1', 'col2']
        ...         df.col1 = df.col1*self.multiplier
        ...         df.col2 = df.col2*self.multiplier
        ...         yield df
        >>> multiply_udtf = pandas_udtf(
        ...     multiply,
        ...     output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]),
        ...     input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])]
        ... )
        >>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"])
        >>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show()
        -----------------------------
        |"ID_"  |"COL1_"  |"COL2_"  |
        -----------------------------
        |x      |30       |359.0    |
        |x      |90       |205.0    |
        -----------------------------
        <BLANKLINE>

    Example::

        >>> @pandas_udtf(output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]), input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])])
        ... class _multiply:
        ...     def __init__(self):
        ...         self.multiplier = 10
        ...     def end_partition(self, df):
        ...         df.columns = ['id', 'col1', 'col2']
        ...         df.col1 = df.col1*self.multiplier
        ...         df.col2 = df.col2*self.multiplier
        ...         yield df
        >>> df.select(_multiply("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show()
        -----------------------------
        |"ID_"  |"COL1_"  |"COL2_"  |
        -----------------------------
        |x      |30       |359.0    |
        |x      |90       |205.0    |
        -----------------------------
        <BLANKLINE>
    """
    session = session or snowflake.snowpark.session._get_active_session()
    # Keyword arguments forwarded to UDTFRegistration.register; shared by both
    # the direct-call and the decorator-with-arguments code paths below.
    register_kwargs = dict(
        output_schema=output_schema,
        input_types=input_types,
        name=name,
        is_permanent=is_permanent,
        stage_location=stage_location,
        imports=imports,
        packages=packages,
        replace=replace,
        if_not_exists=if_not_exists,
        parallel=parallel,
        statement_params=statement_params,
        strict=strict,
        secure=secure,
    )
    if handler is None:
        # Used as ``@pandas_udtf(...)``: return a partial that will receive
        # the decorated class as ``handler`` when the decorator is applied.
        return functools.partial(session.udtf.register, **register_kwargs)
    return session.udtf.register(handler, **register_kwargs)


def call_udf(
udf_name: str,
*args: ColumnOrLiteral,
Expand Down
88 changes: 88 additions & 0 deletions src/snowflake/snowpark/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,94 @@ class UDTFRegistration:
- :meth:`~snowflake.snowpark.Session.add_packages`
- :meth:`~snowflake.snowpark.Session.table_function`
- :meth:`~snowflake.snowpark.DataFrame.join_table_function`
Compared to the default row-by-row processing pattern of a normal UDTF, which sometimes is
inefficient, a vectorized UDTF allows vectorized operations on a dataframe, with the input as a
`Pandas DataFrame <https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html>`_. In a
vectorized UDTF, you can operate on batches of rows by handling a Pandas DataFrame or Pandas
Series. You can use :func:`~snowflake.snowpark.functions.udtf`, :meth:`register` or
:func:`~snowflake.snowpark.functions.pandas_udtf` to create a vectorized UDTF by providing
appropriate return and input types. If you would like to use :meth:`register_from_file` to
create a vectorized UDTF, you would need to explicitly mark the handler method as vectorized using
either the decorator `@vectorized(input=pandas.DataFrame)` or setting `<class>.end_partition._sf_vectorized_input = pandas.DataFrame`
Example 11
Creating a vectorized UDTF by specifying a `PandasDataFrameType` as `input_types` and a `PandasDataFrameType` with column names as `output_schema`.
>>> from snowflake.snowpark.types import PandasDataFrameType, IntegerType, StringType, FloatType
>>> class multiply:
... def __init__(self):
... self.multiplier = 10
... def end_partition(self, df):
... df.columns = ['id', 'col1', 'col2']
... df.col1 = df.col1*self.multiplier
... df.col2 = df.col2*self.multiplier
... yield df
>>> multiply_udtf = session.udtf.register(
... multiply,
... output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]),
... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])]
... )
>>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"])
>>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show()
-----------------------------
|"ID_" |"COL1_" |"COL2_" |
-----------------------------
|x |30 |359.0 |
|x |90 |205.0 |
-----------------------------
<BLANKLINE>
Example 12
Creating a vectorized UDTF by specifying `PandasDataFrame` with nested types as type hints.
>>> from snowflake.snowpark.types import PandasDataFrame
>>> class multiply:
... def __init__(self):
... self.multiplier = 10
... def end_partition(self, df: PandasDataFrame[str, int, float]) -> PandasDataFrame[str, int, float]:
... df.columns = ['id', 'col1', 'col2']
... df.col1 = df.col1*self.multiplier
... df.col2 = df.col2*self.multiplier
... yield df
>>> multiply_udtf = session.udtf.register(
... multiply,
... output_schema=["id_", "col1_", "col2_"],
... )
>>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"])
>>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show()
-----------------------------
|"ID_" |"COL1_" |"COL2_" |
-----------------------------
|x |30 |359.0 |
|x |90 |205.0 |
-----------------------------
<BLANKLINE>
Example 13
Creating a vectorized UDTF by specifying a `pandas.DataFrame` as type hints and a `StructType` with type information and column names as `output_schema`.
>>> import pandas as pd
>>> from snowflake.snowpark.types import IntegerType, StringType, FloatType, StructType, StructField
>>> class multiply:
... def __init__(self):
... self.multiplier = 10
... def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
... df.columns = ['id', 'col1', 'col2']
... df.col1 = df.col1*self.multiplier
... df.col2 = df.col2*self.multiplier
... yield df
>>> multiply_udtf = session.udtf.register(
... multiply,
... output_schema=StructType([StructField("id_", StringType()), StructField("col1_", IntegerType()), StructField("col2_", FloatType())]),
... input_types=[StringType(), IntegerType(), FloatType()]
... )
>>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"])
>>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show()
-----------------------------
|"ID_" |"COL1_" |"COL2_" |
-----------------------------
|x |30 |359.0 |
|x |90 |205.0 |
-----------------------------
<BLANKLINE>
"""

def __init__(self, session: "snowflake.snowpark.Session") -> None:
Expand Down

0 comments on commit 24c909d

Please sign in to comment.