SNOW-878135 is_permanent behavior fix (#989)
* SNOW-878135: Prioritize is_permanent over non-None stage_name for permanent sprocs/udfs

* changelog updates

* fix test

* address comments

* fix lint

* use is_permanent in helper function; update changelog

* add tests
sfc-gh-aalam authored Aug 8, 2023
1 parent 2830b9a commit 10af52d
Showing 10 changed files with 206 additions and 9 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,10 @@
- `array_flatten`
- Added support for replicating your local Python environment on Snowflake via `Session.replicate_local_environment`.

### Behavior Changes

- Creating stored procedures, UDFs, UDTFs, and UDAFs with `is_permanent=False` now creates temporary objects even when `stage_name` is provided. Because `is_permanent` defaults to `False`, objects are no longer made permanent by supplying a stage alone; set `is_permanent=True` explicitly for permanent objects, otherwise you will notice a change in behavior.

## 1.6.1 (2023-08-02)

### New Features
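For illustration, here is a minimal sketch of the new behavior described in the entry above, modeled on the tests added in this commit. It assumes an existing Snowpark `session`; the procedure name and stage name are hypothetical.

    from snowflake.snowpark.functions import sproc
    from snowflake.snowpark.types import IntegerType

    session.add_packages("snowflake-snowpark-python")

    add_sp = sproc(
        lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[0][0],
        return_type=IntegerType(),
        input_types=[IntegerType(), IntegerType()],
        name="my_add_sp",           # hypothetical name
        is_permanent=False,         # the default, shown here for emphasis
        stage_location="my_stage",  # hypothetical stage; ignored, and a warning is logged
        session=session,
    )

    # The procedure is created as a TEMPORARY object scoped to this session.
    # To keep creating a permanent object (previously triggered by stage_location
    # alone), pass is_permanent=True explicitly.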
14 changes: 10 additions & 4 deletions src/snowflake/snowpark/_internal/udf_utils.py
@@ -341,6 +341,11 @@ def check_register_args(
raise ValueError(
f"stage_location must be specified for permanent {get_error_message_abbr(object_type)}"
)
else:
if stage_location:
logger.warn(
"is_permanent is False therefore stage_location will be ignored"
)

if parallel < 1 or parallel > 99:
raise ValueError(
@@ -800,10 +805,11 @@ def resolve_imports_and_packages(
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = False,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
) -> Tuple[str, str, str, str, str, bool]:
upload_stage = (
unwrap_stage_location_single_quote(stage_location)
if stage_location
if stage_location and is_permanent
else session.get_session_stage()
)

@@ -947,7 +953,7 @@ def create_python_udf_or_sp(
object_name: str,
all_imports: str,
all_packages: str,
is_temporary: bool,
is_permanent: bool,
replace: bool,
if_not_exists: bool,
inline_python_code: Optional[str] = None,
@@ -1011,7 +1017,7 @@ def create_python_udf_or_sp(

create_query = f"""
CREATE{" OR REPLACE " if replace else ""}
{"TEMPORARY" if is_temporary else ""} {"SECURE" if secure else ""} {object_type.value.replace("_", " ")} {"IF NOT EXISTS" if if_not_exists else ""} {object_name}({sql_func_args})
{"" if is_permanent else "TEMPORARY"} {"SECURE" if secure else ""} {object_type.value.replace("_", " ")} {"IF NOT EXISTS" if if_not_exists else ""} {object_name}({sql_func_args})
{return_sql}
LANGUAGE PYTHON {strict_as_sql}
RUNTIME_VERSION={runtime_version}
@@ -1022,7 +1028,7 @@
HANDLER='{handler}'{execute_as_sql}
{inline_python_code_in_sql}
"""
session._run_query(create_query, is_ddl_on_temp_object=is_temporary)
session._run_query(create_query, is_ddl_on_temp_object=not is_permanent)

# fire telemetry after _run_query is successful
api_call_source = api_call_source or "_internal.create_python_udf_or_sp"
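To summarize the udf_utils.py changes above, here is a simplified, self-contained sketch (these are not the library's actual helpers) of how `is_permanent` now drives both the upload-stage selection and the TEMPORARY keyword in the generated DDL:

    def choose_upload_stage(stage_location, is_permanent, session_stage):
        # A user-supplied stage is honored only for permanent objects;
        # temporary objects always fall back to the session stage.
        if stage_location and is_permanent:
            return stage_location
        return session_stage

    def create_prefix(is_permanent, replace=False):
        # TEMPORARY is now keyed off is_permanent rather than the old
        # "stage_location is None" (is_temporary) check.
        or_replace = " OR REPLACE" if replace else ""
        temporary = "" if is_permanent else " TEMPORARY"
        return f"CREATE{or_replace}{temporary}"

    assert choose_upload_stage("my_stage", False, "@session_stage") == "@session_stage"
    assert choose_upload_stage("my_stage", True, "@session_stage") == "my_stage"
    assert create_prefix(False) == "CREATE TEMPORARY"
    assert create_prefix(True, replace=True) == "CREATE OR REPLACE"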
6 changes: 5 additions & 1 deletion src/snowflake/snowpark/stored_procedure.py
@@ -548,6 +548,7 @@ def register(
api_call_source="StoredProcedureRegistration.register",
source_code_display=source_code_display,
anonymous=kwargs.get("anonymous", False),
is_permanent=is_permanent,
)

def register_from_file(
@@ -687,6 +688,7 @@ def register_from_file(
api_call_source="StoredProcedureRegistration.register_from_file",
source_code_display=source_code_display,
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

def _do_register_sp(
@@ -709,6 +711,7 @@ def _do_register_sp(
anonymous: bool = False,
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
) -> StoredProcedure:
@@ -756,6 +759,7 @@ def _do_register_sp(
statement_params=statement_params,
source_code_display=source_code_display,
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

if not custom_python_runtime_version_allowed:
@@ -790,7 +794,7 @@ def _do_register_sp(
object_name=udf_name,
all_imports=all_imports,
all_packages=all_packages,
is_temporary=stage_location is None,
is_permanent=is_permanent,
replace=replace,
if_not_exists=if_not_exists,
inline_python_code=code,
6 changes: 5 additions & 1 deletion src/snowflake/snowpark/udaf.py
@@ -424,6 +424,7 @@ def register(
statement_params=statement_params,
source_code_display=source_code_display,
api_call_source="UDAFRegistration.register",
is_permanent=is_permanent,
)

def register_from_file(
@@ -553,6 +554,7 @@ def register_from_file(
source_code_display=source_code_display,
api_call_source="UDAFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

def _do_register_udaf(
@@ -572,6 +574,7 @@ def _do_register_udaf(
source_code_display: bool = True,
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
) -> UserDefinedAggregateFunction:
# get the udaf name, return and input types
(udaf_name, _, _, return_type, input_types,) = process_registration_inputs(
@@ -608,6 +611,7 @@ def _do_register_udaf(
statement_params=statement_params,
source_code_display=source_code_display,
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

if not custom_python_runtime_version_allowed:
@@ -626,7 +630,7 @@ def _do_register_udaf(
object_name=udaf_name,
all_imports=all_imports,
all_packages=all_packages,
is_temporary=stage_location is None,
is_permanent=is_permanent,
replace=replace,
if_not_exists=if_not_exists,
inline_python_code=code,
6 changes: 5 additions & 1 deletion src/snowflake/snowpark/udf.py
@@ -619,6 +619,7 @@ def register(
source_code_display=source_code_display,
api_call_source="UDFRegistration.register"
+ ("[pandas_udf]" if _from_pandas else ""),
is_permanent=is_permanent,
)

def register_from_file(
@@ -763,6 +764,7 @@ def register_from_file(
source_code_display=source_code_display,
api_call_source="UDFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

def _do_register_udf(
@@ -788,6 +790,7 @@ def _do_register_udf(
source_code_display: bool = True,
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
) -> UserDefinedFunction:
# get the udf name, return and input types
(
@@ -836,6 +839,7 @@ def _do_register_udf(
statement_params=statement_params,
source_code_display=source_code_display,
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

if not custom_python_runtime_version_allowed:
@@ -854,7 +858,7 @@ def _do_register_udf(
object_name=udf_name,
all_imports=all_imports,
all_packages=all_packages,
is_temporary=stage_location is None,
is_permanent=is_permanent,
replace=replace,
if_not_exists=if_not_exists,
inline_python_code=code,
6 changes: 5 additions & 1 deletion src/snowflake/snowpark/udtf.py
@@ -521,6 +521,7 @@ def register(
secrets=secrets,
statement_params=statement_params,
api_call_source="UDTFRegistration.register",
is_permanent=is_permanent,
)

def register_from_file(
@@ -657,6 +658,7 @@ def register_from_file(
statement_params=statement_params,
api_call_source="UDTFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

def _do_register_udtf(
@@ -679,6 +681,7 @@ def _do_register_udtf(
statement_params: Optional[Dict[str, str]] = None,
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
) -> UserDefinedTableFunction:

if isinstance(output_schema, StructType):
@@ -742,6 +745,7 @@ def _do_register_udtf(
is_dataframe_input,
statement_params=statement_params,
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

if not custom_python_runtime_version_allowed:
@@ -760,7 +764,7 @@ def _do_register_udtf(
object_name=udtf_name,
all_imports=all_imports,
all_packages=all_packages,
is_temporary=stage_location is None,
is_permanent=is_permanent,
replace=replace,
if_not_exists=if_not_exists,
inline_python_code=code,
35 changes: 35 additions & 0 deletions tests/integ/test_stored_procedure.py
@@ -189,6 +189,7 @@ def test_call_named_stored_procedure(session, temp_schema, db_parameters):
stage_location=unwrap_stage_location_single_quote(
tmp_stage_name_in_temp_schema
),
is_permanent=True,
)
assert new_session.call(full_sp_name, 13, 19) == 13 + 19
# one result in the temp schema
@@ -575,6 +576,40 @@ def test_permanent_sp(session, db_parameters):
Utils.drop_stage(session, stage_name)


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
def test_permanent_sp_negative(session, db_parameters, caplog):
stage_name = Utils.random_stage_name()
sp_name = Utils.random_name_for_temp_object(TempObjectType.PROCEDURE)
with Session.builder.configs(db_parameters).create() as new_session:
new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
new_session.add_packages("snowflake-snowpark-python")
try:
with caplog.at_level(logging.WARN):
sproc(
lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[
0
][0],
return_type=IntegerType(),
input_types=[IntegerType(), IntegerType()],
name=sp_name,
is_permanent=False,
stage_location=stage_name,
session=new_session,
)
assert (
"is_permanent is False therefore stage_location will be ignored"
in caplog.text
)

with pytest.raises(
SnowparkSQLException, match=f"Unknown function {sp_name}"
):
session.call(sp_name, 1, 2)
assert new_session.call(sp_name, 8, 9) == 17
finally:
new_session._run_query(f"drop function if exists {sp_name}(int, int)")


@pytest.mark.skipif(not is_pandas_available, reason="Requires pandas")
def test_sp_negative(session):
def f(_, x):
60 changes: 59 additions & 1 deletion tests/integ/test_udaf.py
@@ -4,14 +4,18 @@

import datetime
import decimal
import logging
from typing import Any, Dict, List

import pytest

from snowflake.snowpark import Row
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.functions import udaf
from snowflake.snowpark.session import Session
from snowflake.snowpark.types import IntegerType, Variant
from tests.utils import TestFiles, Utils
from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils

pytestmark = pytest.mark.udf

@@ -407,6 +411,60 @@ def test_register_udaf_from_file_with_type_hints(session, resources_path):
Utils.check_answer(df.group_by("a").agg(sum_udaf("b")), [Row(1, 7), Row(2, 11)])


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
def test_permanent_udaf_negative(session, db_parameters, caplog):
stage_name = Utils.random_stage_name()
udaf_name = Utils.random_name_for_temp_object(TempObjectType.AGGREGATE_FUNCTION)
df1 = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b")

class PythonSumUDAFHandler:
def __init__(self) -> None:
self._sum = 0

@property
def aggregate_state(self):
return self._sum

def accumulate(self, input_value):
self._sum += input_value

def merge(self, other_sum):
self._sum += other_sum

def finish(self):
return self._sum

with Session.builder.configs(db_parameters).create() as new_session:
new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
df2 = new_session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df(
"a", "b"
)
try:
with caplog.at_level(logging.WARN):
sum_udaf = udaf(
PythonSumUDAFHandler,
return_type=IntegerType(),
input_types=[IntegerType()],
name=udaf_name,
is_permanent=False,
stage_location=stage_name,
session=new_session,
)
assert (
"is_permanent is False therefore stage_location will be ignored"
in caplog.text
)

with pytest.raises(
SnowparkSQLException, match=f"Unknown function {udaf_name}"
):
df1.agg(sum_udaf("a")).collect()

Utils.check_answer(df2.agg(sum_udaf("a")), [Row(6)])
finally:
new_session._run_query(f"drop function if exists {udaf_name}(int)")


def test_udaf_negative(session):
with pytest.raises(TypeError, match="Invalid handler: expecting a class type"):
session.udaf.register(1)