Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-878135 is_permanent behavior fix #989

Merged
merged 8 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
- `array_flatten`
- Added support for replicating your local Python environment on Snowflake via `Session.replicate_local_environment`.

### Behavior Changes

- When creating stored procedures, UDFs, UDTFs, or UDAFs with `is_permanent=False`, temporary objects will now be created even when `stage_name` is provided. Previously, `is_permanent=False` with a non-None `stage_name` would attempt to create permanent objects at the given stage location.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we also need to call out that the default behavior is changing because the default value of is_permanent is None. So users of a UDF that only has stage_name specified will see their UDF become temporary.


## 1.6.1 (2023-08-02)

### New Features
Expand Down
5 changes: 5 additions & 0 deletions src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,11 @@ def check_register_args(
raise ValueError(
f"stage_location must be specified for permanent {get_error_message_abbr(object_type)}"
)
else:
if stage_location:
logger.warn(
"is_permanent is False therefore stage_location will be ignored"
)

if parallel < 1 or parallel > 99:
raise ValueError(
Expand Down
5 changes: 4 additions & 1 deletion src/snowflake/snowpark/stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ def register(
api_call_source="StoredProcedureRegistration.register",
source_code_display=source_code_display,
anonymous=kwargs.get("anonymous", False),
is_permanent=is_permanent,
)

def register_from_file(
Expand Down Expand Up @@ -687,6 +688,7 @@ def register_from_file(
api_call_source="StoredProcedureRegistration.register_from_file",
source_code_display=source_code_display,
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

def _do_register_sp(
Expand All @@ -709,6 +711,7 @@ def _do_register_sp(
anonymous: bool = False,
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
) -> StoredProcedure:
Expand Down Expand Up @@ -790,7 +793,7 @@ def _do_register_sp(
object_name=udf_name,
all_imports=all_imports,
all_packages=all_packages,
is_temporary=stage_location is None,
is_temporary=not is_permanent,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it makes sense to rename is_temporary to is_permanent? This way we only need to do negation once instead of in all four classes.

We probably should have made the public param is_temporary at the beginning so we don't need this awkward flip lol. But now all is done 🤷

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could rename is_temporary -> is_permanent. public API already has is_permanent so we shouldn't change that.

replace=replace,
if_not_exists=if_not_exists,
inline_python_code=code,
Expand Down
5 changes: 4 additions & 1 deletion src/snowflake/snowpark/udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ def register(
statement_params=statement_params,
source_code_display=source_code_display,
api_call_source="UDAFRegistration.register",
is_permanent=is_permanent,
)

def register_from_file(
Expand Down Expand Up @@ -553,6 +554,7 @@ def register_from_file(
source_code_display=source_code_display,
api_call_source="UDAFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

def _do_register_udaf(
Expand All @@ -572,6 +574,7 @@ def _do_register_udaf(
source_code_display: bool = True,
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
) -> UserDefinedAggregateFunction:
# get the udaf name, return and input types
(udaf_name, _, _, return_type, input_types,) = process_registration_inputs(
Expand Down Expand Up @@ -626,7 +629,7 @@ def _do_register_udaf(
object_name=udaf_name,
all_imports=all_imports,
all_packages=all_packages,
is_temporary=stage_location is None,
is_temporary=not is_permanent,
replace=replace,
if_not_exists=if_not_exists,
inline_python_code=code,
Expand Down
5 changes: 4 additions & 1 deletion src/snowflake/snowpark/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ def register(
source_code_display=source_code_display,
api_call_source="UDFRegistration.register"
+ ("[pandas_udf]" if _from_pandas else ""),
is_permanent=is_permanent,
)

def register_from_file(
Expand Down Expand Up @@ -763,6 +764,7 @@ def register_from_file(
source_code_display=source_code_display,
api_call_source="UDFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

def _do_register_udf(
Expand All @@ -788,6 +790,7 @@ def _do_register_udf(
source_code_display: bool = True,
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
) -> UserDefinedFunction:
# get the udf name, return and input types
(
Expand Down Expand Up @@ -854,7 +857,7 @@ def _do_register_udf(
object_name=udf_name,
all_imports=all_imports,
all_packages=all_packages,
is_temporary=stage_location is None,
is_temporary=not is_permanent,
replace=replace,
if_not_exists=if_not_exists,
inline_python_code=code,
Expand Down
5 changes: 4 additions & 1 deletion src/snowflake/snowpark/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ def register(
secrets=secrets,
statement_params=statement_params,
api_call_source="UDTFRegistration.register",
is_permanent=is_permanent,
)

def register_from_file(
Expand Down Expand Up @@ -657,6 +658,7 @@ def register_from_file(
statement_params=statement_params,
api_call_source="UDTFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
)

def _do_register_udtf(
Expand All @@ -679,6 +681,7 @@ def _do_register_udtf(
statement_params: Optional[Dict[str, str]] = None,
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
) -> UserDefinedTableFunction:

if isinstance(output_schema, StructType):
Expand Down Expand Up @@ -760,7 +763,7 @@ def _do_register_udtf(
object_name=udtf_name,
all_imports=all_imports,
all_packages=all_packages,
is_temporary=stage_location is None,
is_temporary=not is_permanent,
replace=replace,
if_not_exists=if_not_exists,
inline_python_code=code,
Expand Down
35 changes: 35 additions & 0 deletions tests/integ/test_stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def test_call_named_stored_procedure(session, temp_schema, db_parameters):
stage_location=unwrap_stage_location_single_quote(
tmp_stage_name_in_temp_schema
),
is_permanent=True,
)
assert new_session.call(full_sp_name, 13, 19) == 13 + 19
# one result in the temp schema
Expand Down Expand Up @@ -575,6 +576,40 @@ def test_permanent_sp(session, db_parameters):
Utils.drop_stage(session, stage_name)


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
def test_permanent_sp_negative(session, db_parameters, caplog):
    """Registering a stored procedure with ``is_permanent=False`` must create a
    temporary object even when ``stage_location`` is supplied.

    Verifies that:
      1. a warning is logged saying the stage_location is ignored, and
      2. the procedure is session-scoped: callable from the creating session
         (``new_session``) but unknown to the other session (``session``).
    """
    stage_name = Utils.random_stage_name()  # never created on purpose — it must be ignored
    sp_name = Utils.random_name_for_temp_object(TempObjectType.PROCEDURE)
    with Session.builder.configs(db_parameters).create() as new_session:
        new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
        new_session.add_packages("snowflake-snowpark-python")
        try:
            with caplog.at_level(logging.WARN):
                sproc(
                    lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[
                        0
                    ][0],
                    return_type=IntegerType(),
                    input_types=[IntegerType(), IntegerType()],
                    name=sp_name,
                    is_permanent=False,
                    stage_location=stage_name,
                    session=new_session,
                )
                assert (
                    "is_permanent is False therefore stage_location will be ignored"
                    in caplog.text
                )

            # Temporary object: the other session must not see it ...
            with pytest.raises(
                SnowparkSQLException, match=f"Unknown function {sp_name}"
            ):
                session.call(sp_name, 1, 2)
            # ... while the creating session can call it.
            assert new_session.call(sp_name, 8, 9) == 17
        finally:
            # Fix: the object under test is a stored procedure, so the cleanup
            # must use DROP PROCEDURE — "drop function if exists" would be a
            # silent no-op against a procedure.
            new_session._run_query(
                f"drop procedure if exists {sp_name}(int, int)"
            )

@pytest.mark.skipif(not is_pandas_available, reason="Requires pandas")
def test_sp_negative(session):
def f(_, x):
Expand Down
35 changes: 35 additions & 0 deletions tests/integ/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def test_call_named_udf(session, temp_schema, db_parameters):
stage_location=unwrap_stage_location_single_quote(
tmp_stage_name_in_temp_schema
),
is_permanent=True,
)
Utils.check_answer(
new_session.sql(f"select {full_udf_name}(13, 19)").collect(), [Row(13 + 19)]
Expand Down Expand Up @@ -967,6 +968,40 @@ def test_permanent_udf(session, db_parameters):
Utils.drop_stage(session, stage_name)


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
def test_permanent_udf_negative(session, db_parameters, caplog):
    """Registering a UDF with ``is_permanent=False`` must yield a temporary
    function even though a ``stage_location`` is passed.

    Checks both the logged warning about the ignored stage_location and the
    session-scoped visibility of the resulting UDF.
    """
    ignored_stage = Utils.random_stage_name()
    fn_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION)
    with Session.builder.configs(db_parameters).create() as other_session:
        other_session.sql_simplifier_enabled = session.sql_simplifier_enabled
        try:
            with caplog.at_level(logging.WARN):
                udf(
                    lambda a, b: a + b,
                    return_type=IntegerType(),
                    input_types=[IntegerType(), IntegerType()],
                    name=fn_name,
                    is_permanent=False,
                    stage_location=ignored_stage,
                    session=other_session,
                )
                assert (
                    "is_permanent is False therefore stage_location will be ignored"
                    in caplog.text
                )

            # The UDF is temporary, so the original session cannot resolve it ...
            with pytest.raises(
                SnowparkSQLException, match=f"Unknown function {fn_name}"
            ):
                session.sql(f"select {fn_name}(8, 9)").collect()

            # ... but the session that registered it can invoke it normally.
            Utils.check_answer(
                other_session.sql(f"select {fn_name}(8, 9)").collect(), [Row(17)]
            )
        finally:
            other_session._run_query(
                f"drop function if exists {fn_name}(int, int)"
            )


def test_udf_negative(session):
def f(x):
return x
Expand Down
Loading