diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6e85bfb12d..f4c14bda3e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,10 @@
 - `array_flatten`
 - Added support for replicating your local Python environment on Snowflake via `Session.replicate_local_environment`.
 
+### Behavior Changes
+
+- Creating stored procedures, UDFs, UDTFs, and UDAFs with `is_permanent=False` now creates temporary objects even when `stage_location` is provided. Because the default value of `is_permanent` is `False`, users who do not explicitly set it to `True` for permanent objects will notice a change in behavior.
+
 ## 1.6.1 (2023-08-02)
 
 ### New Features
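To illustrate the behavior change described in the changelog entry above, here is a minimal sketch, assuming an existing Snowpark `session`; the function name `add_one` and the stage `@my_stage` are hypothetical, while the `udf()` registration parameters follow the API exercised by the tests in this diff:

```python
from snowflake.snowpark.functions import udf
from snowflake.snowpark.types import IntegerType

# Hypothetical example: with is_permanent=False (the default), the function is now
# registered as a TEMPORARY object; stage_location is ignored and a warning is logged.
add_one = udf(
    lambda x: x + 1,
    return_type=IntegerType(),
    input_types=[IntegerType()],
    name="add_one",
    is_permanent=False,
    stage_location="@my_stage",  # ignored because is_permanent is False
    session=session,
)
```

Previously, supplying `stage_location` alone made the object permanent (the old `is_temporary=stage_location is None` logic removed below); after this change only `is_permanent=True` does.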
diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py
index 95028182c9..449c908a7c 100644
--- a/src/snowflake/snowpark/_internal/udf_utils.py
+++ b/src/snowflake/snowpark/_internal/udf_utils.py
@@ -341,6 +341,11 @@ def check_register_args(
             raise ValueError(
                 f"stage_location must be specified for permanent {get_error_message_abbr(object_type)}"
             )
+    else:
+        if stage_location:
+            logger.warn(
+                "is_permanent is False therefore stage_location will be ignored"
+            )
 
     if parallel < 1 or parallel > 99:
         raise ValueError(
@@ -800,10 +805,11 @@ def resolve_imports_and_packages(
     statement_params: Optional[Dict[str, str]] = None,
     source_code_display: bool = False,
     skip_upload_on_content_match: bool = False,
+    is_permanent: bool = False,
 ) -> Tuple[str, str, str, str, str, bool]:
     upload_stage = (
         unwrap_stage_location_single_quote(stage_location)
-        if stage_location
+        if stage_location and is_permanent
         else session.get_session_stage()
     )
 
@@ -947,7 +953,7 @@ def create_python_udf_or_sp(
     object_name: str,
     all_imports: str,
     all_packages: str,
-    is_temporary: bool,
+    is_permanent: bool,
     replace: bool,
     if_not_exists: bool,
     inline_python_code: Optional[str] = None,
@@ -1011,7 +1017,7 @@ def create_python_udf_or_sp(
 
     create_query = f"""
 CREATE{" OR REPLACE " if replace else ""}
-{"TEMPORARY" if is_temporary else ""} {"SECURE" if secure else ""} {object_type.value.replace("_", " ")} {"IF NOT EXISTS" if if_not_exists else ""} {object_name}({sql_func_args})
+{"" if is_permanent else "TEMPORARY"} {"SECURE" if secure else ""} {object_type.value.replace("_", " ")} {"IF NOT EXISTS" if if_not_exists else ""} {object_name}({sql_func_args})
 {return_sql}
 LANGUAGE PYTHON {strict_as_sql}
 RUNTIME_VERSION={runtime_version}
@@ -1022,7 +1028,7 @@
 HANDLER='{handler}'{execute_as_sql}
 {inline_python_code_in_sql}
 """
-    session._run_query(create_query, is_ddl_on_temp_object=is_temporary)
+    session._run_query(create_query, is_ddl_on_temp_object=not is_permanent)
 
     # fire telemetry after _run_query is successful
     api_call_source = api_call_source or "_internal.create_python_udf_or_sp"
diff --git a/src/snowflake/snowpark/stored_procedure.py b/src/snowflake/snowpark/stored_procedure.py
index 9c60703a0b..7c27d3ad02 100644
--- a/src/snowflake/snowpark/stored_procedure.py
+++ b/src/snowflake/snowpark/stored_procedure.py
@@ -548,6 +548,7 @@ def register(
             api_call_source="StoredProcedureRegistration.register",
             source_code_display=source_code_display,
             anonymous=kwargs.get("anonymous", False),
+            is_permanent=is_permanent,
         )
 
     def register_from_file(
@@ -687,6 +688,7 @@ def register_from_file(
             api_call_source="StoredProcedureRegistration.register_from_file",
             source_code_display=source_code_display,
             skip_upload_on_content_match=skip_upload_on_content_match,
+            is_permanent=is_permanent,
         )
 
     def _do_register_sp(
@@ -709,6 +711,7 @@ def _do_register_sp(
         anonymous: bool = False,
         api_call_source: str,
         skip_upload_on_content_match: bool = False,
+        is_permanent: bool = False,
         external_access_integrations: Optional[List[str]] = None,
         secrets: Optional[Dict[str, str]] = None,
     ) -> StoredProcedure:
@@ -756,6 +759,7 @@ def _do_register_sp(
             statement_params=statement_params,
             source_code_display=source_code_display,
             skip_upload_on_content_match=skip_upload_on_content_match,
+            is_permanent=is_permanent,
         )
 
         if not custom_python_runtime_version_allowed:
@@ -790,7 +794,7 @@ def _do_register_sp(
                 object_name=udf_name,
                 all_imports=all_imports,
                 all_packages=all_packages,
-                is_temporary=stage_location is None,
+                is_permanent=is_permanent,
                 replace=replace,
                 if_not_exists=if_not_exists,
                 inline_python_code=code,
diff --git a/src/snowflake/snowpark/udaf.py b/src/snowflake/snowpark/udaf.py
index e47dc04d3d..a59dff93f5 100644
--- a/src/snowflake/snowpark/udaf.py
+++ b/src/snowflake/snowpark/udaf.py
@@ -424,6 +424,7 @@ def register(
             statement_params=statement_params,
             source_code_display=source_code_display,
             api_call_source="UDAFRegistration.register",
+            is_permanent=is_permanent,
         )
 
     def register_from_file(
@@ -553,6 +554,7 @@ def register_from_file(
             source_code_display=source_code_display,
             api_call_source="UDAFRegistration.register_from_file",
             skip_upload_on_content_match=skip_upload_on_content_match,
+            is_permanent=is_permanent,
         )
 
     def _do_register_udaf(
@@ -572,6 +574,7 @@ def _do_register_udaf(
         source_code_display: bool = True,
         api_call_source: str,
         skip_upload_on_content_match: bool = False,
+        is_permanent: bool = False,
     ) -> UserDefinedAggregateFunction:
         # get the udaf name, return and input types
         (udaf_name, _, _, return_type, input_types,) = process_registration_inputs(
@@ -608,6 +611,7 @@ def _do_register_udaf(
             statement_params=statement_params,
             source_code_display=source_code_display,
             skip_upload_on_content_match=skip_upload_on_content_match,
+            is_permanent=is_permanent,
         )
 
         if not custom_python_runtime_version_allowed:
@@ -626,7 +630,7 @@ def _do_register_udaf(
                 object_name=udaf_name,
                 all_imports=all_imports,
                 all_packages=all_packages,
-                is_temporary=stage_location is None,
+                is_permanent=is_permanent,
                 replace=replace,
                 if_not_exists=if_not_exists,
                 inline_python_code=code,
diff --git a/src/snowflake/snowpark/udf.py b/src/snowflake/snowpark/udf.py
index 7667cb15f0..8ae5395101 100644
--- a/src/snowflake/snowpark/udf.py
+++ b/src/snowflake/snowpark/udf.py
@@ -619,6 +619,7 @@ def register(
             source_code_display=source_code_display,
             api_call_source="UDFRegistration.register"
             + ("[pandas_udf]" if _from_pandas else ""),
+            is_permanent=is_permanent,
         )
 
     def register_from_file(
@@ -763,6 +764,7 @@ def register_from_file(
             source_code_display=source_code_display,
             api_call_source="UDFRegistration.register_from_file",
             skip_upload_on_content_match=skip_upload_on_content_match,
+            is_permanent=is_permanent,
         )
 
     def _do_register_udf(
@@ -788,6 +790,7 @@ def _do_register_udf(
         source_code_display: bool = True,
         api_call_source: str,
         skip_upload_on_content_match: bool = False,
+        is_permanent: bool = False,
     ) -> UserDefinedFunction:
         # get the udf name, return and input types
         (
@@ -836,6 +839,7 @@ def _do_register_udf(
             statement_params=statement_params,
             source_code_display=source_code_display,
             skip_upload_on_content_match=skip_upload_on_content_match,
+            is_permanent=is_permanent,
         )
 
         if not custom_python_runtime_version_allowed:
@@ -854,7 +858,7 @@ def _do_register_udf(
                 object_name=udf_name,
                 all_imports=all_imports,
                 all_packages=all_packages,
-                is_temporary=stage_location is None,
+                is_permanent=is_permanent,
                 replace=replace,
                 if_not_exists=if_not_exists,
                 inline_python_code=code,
diff --git a/src/snowflake/snowpark/udtf.py b/src/snowflake/snowpark/udtf.py
index 12368d2d26..b4a7e38d79 100644
--- a/src/snowflake/snowpark/udtf.py
+++ b/src/snowflake/snowpark/udtf.py
@@ -521,6 +521,7 @@ def register(
             secrets=secrets,
             statement_params=statement_params,
             api_call_source="UDTFRegistration.register",
+            is_permanent=is_permanent,
         )
 
     def register_from_file(
@@ -657,6 +658,7 @@ def register_from_file(
             statement_params=statement_params,
             api_call_source="UDTFRegistration.register_from_file",
             skip_upload_on_content_match=skip_upload_on_content_match,
+            is_permanent=is_permanent,
         )
 
     def _do_register_udtf(
@@ -679,6 +681,7 @@ def _do_register_udtf(
         statement_params: Optional[Dict[str, str]] = None,
         api_call_source: str,
         skip_upload_on_content_match: bool = False,
+        is_permanent: bool = False,
     ) -> UserDefinedTableFunction:
 
         if isinstance(output_schema, StructType):
@@ -742,6 +745,7 @@ def _do_register_udtf(
             is_dataframe_input,
             statement_params=statement_params,
             skip_upload_on_content_match=skip_upload_on_content_match,
+            is_permanent=is_permanent,
         )
 
         if not custom_python_runtime_version_allowed:
@@ -760,7 +764,7 @@ def _do_register_udtf(
                 object_name=udtf_name,
                 all_imports=all_imports,
                 all_packages=all_packages,
-                is_temporary=stage_location is None,
+                is_permanent=is_permanent,
                 replace=replace,
                 if_not_exists=if_not_exists,
                 inline_python_code=code,
diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py
index c771cb139e..3523cdd581 100644
--- a/tests/integ/test_stored_procedure.py
+++ b/tests/integ/test_stored_procedure.py
@@ -189,6 +189,7 @@ def test_call_named_stored_procedure(session, temp_schema, db_parameters):
             stage_location=unwrap_stage_location_single_quote(
                 tmp_stage_name_in_temp_schema
             ),
+            is_permanent=True,
         )
         assert new_session.call(full_sp_name, 13, 19) == 13 + 19
         # one result in the temp schema
@@ -575,6 +576,40 @@ def test_permanent_sp(session, db_parameters):
         Utils.drop_stage(session, stage_name)
 
 
+@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
+def test_permanent_sp_negative(session, db_parameters, caplog):
+    stage_name = Utils.random_stage_name()
+    sp_name = Utils.random_name_for_temp_object(TempObjectType.PROCEDURE)
+    with Session.builder.configs(db_parameters).create() as new_session:
+        new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
+        new_session.add_packages("snowflake-snowpark-python")
+        try:
+            with caplog.at_level(logging.WARN):
+                sproc(
+                    lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[
+                        0
+                    ][0],
+                    return_type=IntegerType(),
+                    input_types=[IntegerType(), IntegerType()],
+                    name=sp_name,
+                    is_permanent=False,
+                    stage_location=stage_name,
+                    session=new_session,
+                )
+            assert (
+                "is_permanent is False therefore stage_location will be ignored"
+                in caplog.text
+            )
+
+            with pytest.raises(
+                SnowparkSQLException, match=f"Unknown function {sp_name}"
+            ):
+                session.call(sp_name, 1, 2)
+            assert new_session.call(sp_name, 8, 9) == 17
+        finally:
+            new_session._run_query(f"drop function if exists {sp_name}(int, int)")
+
+
 @pytest.mark.skipif(not is_pandas_available, reason="Requires pandas")
 def test_sp_negative(session):
     def f(_, x):
diff --git a/tests/integ/test_udaf.py b/tests/integ/test_udaf.py
index 7d6a1a36ed..bdcf540d90 100644
--- a/tests/integ/test_udaf.py
+++ b/tests/integ/test_udaf.py
@@ -4,14 +4,18 @@
 import datetime
 import decimal
+import logging
 from typing import Any, Dict, List
 
 import pytest
 
 from snowflake.snowpark import Row
+from snowflake.snowpark._internal.utils import TempObjectType
+from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import udaf
+from snowflake.snowpark.session import Session
 from snowflake.snowpark.types import IntegerType, Variant
-from tests.utils import TestFiles, Utils
+from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils
 
 pytestmark = pytest.mark.udf
@@ -407,6 +411,60 @@ def test_register_udaf_from_file_with_type_hints(session, resources_path):
     Utils.check_answer(df.group_by("a").agg(sum_udaf("b")), [Row(1, 7), Row(2, 11)])
 
 
+@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
+def test_permanent_udaf_negative(session, db_parameters, caplog):
+    stage_name = Utils.random_stage_name()
+    udaf_name = Utils.random_name_for_temp_object(TempObjectType.AGGREGATE_FUNCTION)
+    df1 = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b")
+
+    class PythonSumUDAFHandler:
+        def __init__(self) -> None:
+            self._sum = 0
+
+        @property
+        def aggregate_state(self):
+            return self._sum
+
+        def accumulate(self, input_value):
+            self._sum += input_value
+
+        def merge(self, other_sum):
+            self._sum += other_sum
+
+        def finish(self):
+            return self._sum
+
+    with Session.builder.configs(db_parameters).create() as new_session:
+        new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
+        df2 = new_session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df(
+            "a", "b"
+        )
+        try:
+            with caplog.at_level(logging.WARN):
+                sum_udaf = udaf(
+                    PythonSumUDAFHandler,
+                    return_type=IntegerType(),
+                    input_types=[IntegerType()],
+                    name=udaf_name,
+                    is_permanent=False,
+                    stage_location=stage_name,
+                    session=new_session,
+                )
+            assert (
+                "is_permanent is False therefore stage_location will be ignored"
+                in caplog.text
+            )
+
+            with pytest.raises(
+                SnowparkSQLException, match=f"Unknown function {udaf_name}"
+            ):
+                df1.agg(sum_udaf("a")).collect()
+
+            Utils.check_answer(df2.agg(sum_udaf("a")), [Row(6)])
+        finally:
+            new_session._run_query(f"drop function if exists {udaf_name}(int)")
+
+
 def test_udaf_negative(session):
     with pytest.raises(TypeError, match="Invalid handler: expecting a class type"):
         session.udaf.register(1)
diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py
index 44bec38380..fa2a7244c5 100644
--- a/tests/integ/test_udf.py
+++ b/tests/integ/test_udf.py
@@ -206,6 +206,7 @@ def test_call_named_udf(session, temp_schema, db_parameters):
             stage_location=unwrap_stage_location_single_quote(
                 tmp_stage_name_in_temp_schema
             ),
+            is_permanent=True,
         )
         Utils.check_answer(
             new_session.sql(f"select {full_udf_name}(13, 19)").collect(), [Row(13 + 19)]
         )
@@ -967,6 +968,40 @@ def test_permanent_udf(session, db_parameters):
         Utils.drop_stage(session, stage_name)
 
 
+@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
+def test_permanent_udf_negative(session, db_parameters, caplog):
+    stage_name = Utils.random_stage_name()
+    udf_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION)
+    with Session.builder.configs(db_parameters).create() as new_session:
+        new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
+        try:
+            with caplog.at_level(logging.WARN):
+                udf(
+                    lambda x, y: x + y,
+                    return_type=IntegerType(),
+                    input_types=[IntegerType(), IntegerType()],
+                    name=udf_name,
+                    is_permanent=False,
+                    stage_location=stage_name,
+                    session=new_session,
+                )
"is_permanent is False therefore stage_location will be ignored" + in caplog.text + ) + + with pytest.raises( + SnowparkSQLException, match=f"Unknown function {udf_name}" + ): + session.sql(f"select {udf_name}(8, 9)").collect() + + Utils.check_answer( + new_session.sql(f"select {udf_name}(8, 9)").collect(), [Row(17)] + ) + finally: + new_session._run_query(f"drop function if exists {udf_name}(int, int)") + + def test_udf_negative(session): def f(x): return x diff --git a/tests/integ/test_udtf.py b/tests/integ/test_udtf.py index 57265895b8..8e24799e3c 100644 --- a/tests/integ/test_udtf.py +++ b/tests/integ/test_udtf.py @@ -3,14 +3,17 @@ # import decimal +import logging import sys from typing import Tuple import pytest from snowflake.snowpark import Row, Table +from snowflake.snowpark._internal.utils import TempObjectType from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.functions import lit, udtf +from snowflake.snowpark.session import Session from snowflake.snowpark.types import ( BinaryType, BooleanType, @@ -278,6 +281,46 @@ def process( assert "SECURE" in session.sql(ddl_sql).collect()[0][0] +@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") +def test_permanent_udtf_negative(session, db_parameters, caplog): + stage_name = Utils.random_stage_name() + udtf_name = Utils.random_name_for_temp_object(TempObjectType.TABLE_FUNCTION) + + class UDTFEcho: + def process( + self, + num: int, + ) -> Iterable[Tuple[int]]: + return [(num,)] + + with Session.builder.configs(db_parameters).create() as new_session: + new_session.sql_simplifier_enabled = session.sql_simplifier_enabled + try: + with caplog.at_level(logging.WARN): + echo_udtf = udtf( + UDTFEcho, + output_schema=StructType([StructField("A", IntegerType())]), + input_types=[IntegerType()], + name=udtf_name, + is_permanent=False, + stage_location=stage_name, + session=new_session, + ) + assert ( + "is_permanent is False therefore stage_location will be ignored" + in caplog.text + ) + + with pytest.raises( + SnowparkSQLException, match=f"Unknown table function {udtf_name}" + ): + session.table_function(echo_udtf(lit(1))).collect() + + Utils.check_answer(new_session.table_function(echo_udtf(lit(1))), [Row(1)]) + finally: + new_session._run_query(f"drop function if exists {udtf_name}(int)") + + @pytest.mark.xfail(reason="SNOW-757054 flaky test", strict=False) @pytest.mark.skipif( IS_IN_STORED_PROC, reason="Named temporary udf is not supported in stored proc"