Skip to content

Commit

Permalink
SNOW-1662692: Add copy_grants option when registering SP/UDXF (#2393)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam authored Oct 16, 2024
1 parent 8d50130 commit 2e33301
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

- Added support for 'Service' domain to `session.lineage.trace` API.
- Updated `Session` class to be thread-safe. This allows concurrent dataframe transformations, dataframe actions, UDF and store procedure registration, and concurrent file uploads.
- Added support for `copy_grants` parameter when registering UDxF and stored procedures.

#### New Features

Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,7 @@ def create_python_udf_or_sp(
statement_params: Optional[Dict[str, str]] = None,
comment: Optional[str] = None,
native_app_params: Optional[Dict[str, Any]] = None,
copy_grants: bool = False,
runtime_version: Optional[str] = None,
) -> None:
runtime_version = runtime_version or f"{sys.version_info[0]}.{sys.version_info[1]}"
Expand Down Expand Up @@ -1327,6 +1328,7 @@ def create_python_udf_or_sp(
create_query = f"""
CREATE{" OR REPLACE " if replace else ""}
{"" if is_permanent else "TEMPORARY"} {"SECURE" if secure else ""} {object_type.value.replace("_", " ")} {"IF NOT EXISTS" if if_not_exists else ""} {object_name}({sql_func_args})
{" COPY GRANTS " if copy_grants else ""}
{return_sql}
LANGUAGE PYTHON {strict_as_sql}
{mutability}
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/snowpark/mock/_stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def _do_register_sp(
force_inline_code: bool = False,
comment: Optional[str] = None,
native_app_params: Optional[Dict[str, Any]] = None,
copy_grants: bool = False,
) -> StoredProcedure:

if is_permanent:
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/snowpark/mock/_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _do_register_udf(
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
native_app_params: Optional[Dict[str, Any]] = None,
copy_grants: bool = False,
) -> UserDefinedFunction:
if is_permanent:
self._session._conn.log_not_supported_error(
Expand Down
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ def register(
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
comment: Optional[str] = None,
copy_grants: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
Expand Down Expand Up @@ -544,6 +545,8 @@ def register(
retrieve the secrets using secret API.
comment: Adds a comment for the created object. See
`COMMENT <https://docs.snowflake.com/en/sql-reference/sql/comment>`_
copy_grants: Specifies to retain the access privileges from the original function when a new function is
created using CREATE OR REPLACE PROCEDURE.
See Also:
- :func:`~snowflake.snowpark.functions.sproc`
Expand Down Expand Up @@ -578,6 +581,7 @@ def register(
external_access_integrations=external_access_integrations,
secrets=secrets,
comment=comment,
copy_grants=copy_grants,
statement_params=statement_params,
execute_as=execute_as,
api_call_source="StoredProcedureRegistration.register",
Expand Down Expand Up @@ -610,6 +614,7 @@ def register_from_file(
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
comment: Optional[str] = None,
copy_grants: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
Expand Down Expand Up @@ -694,6 +699,8 @@ def register_from_file(
retrieve the secrets using secret API.
comment: Adds a comment for the created object. See
`COMMENT <https://docs.snowflake.com/en/sql-reference/sql/comment>`_
copy_grants: Specifies to retain the access privileges from the original function when a new function is
created using CREATE OR REPLACE PROCEDURE.
Note::
The type hints can still be extracted from the source Python file if they
Expand Down Expand Up @@ -730,6 +737,7 @@ def register_from_file(
external_access_integrations=external_access_integrations,
secrets=secrets,
comment=comment,
copy_grants=copy_grants,
statement_params=statement_params,
execute_as=execute_as,
api_call_source="StoredProcedureRegistration.register_from_file",
Expand Down Expand Up @@ -764,6 +772,7 @@ def _do_register_sp(
force_inline_code: bool = False,
comment: Optional[str] = None,
native_app_params: Optional[Dict[str, Any]] = None,
copy_grants: bool = False,
) -> StoredProcedure:
(
udf_name,
Expand Down Expand Up @@ -874,6 +883,7 @@ def _do_register_sp(
statement_params=statement_params,
comment=comment,
native_app_params=native_app_params,
copy_grants=copy_grants,
runtime_version=runtime_version_from_requirement,
)
# an exception might happen during registering a stored procedure
Expand Down
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def register(
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
comment: Optional[str] = None,
copy_grants: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
Expand Down Expand Up @@ -411,6 +412,8 @@ def register(
retrieve the secrets using secret API.
comment: Adds a comment for the created object. See
`COMMENT <https://docs.snowflake.com/en/sql-reference/sql/comment>`_
copy_grants: Specifies to retain the access privileges from the original function when a new function is
created using CREATE OR REPLACE FUNCTION.
See Also:
- :func:`~snowflake.snowpark.functions.udaf`
Expand Down Expand Up @@ -455,6 +458,7 @@ def register(
secrets=secrets,
comment=comment,
native_app_params=native_app_params,
copy_grants=copy_grants,
)

def register_from_file(
Expand All @@ -474,6 +478,7 @@ def register_from_file(
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
comment: Optional[str] = None,
copy_grants: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
Expand Down Expand Up @@ -562,6 +567,8 @@ def register_from_file(
retrieve the secrets using secret API.
comment: Adds a comment for the created object. See
`COMMENT <https://docs.snowflake.com/en/sql-reference/sql/comment>`_
copy_grants: Specifies to retain the access privileges from the original function when a new function is
created using CREATE OR REPLACE FUNCTION.
Note::
The type hints can still be extracted from the local source Python file if they
Expand Down Expand Up @@ -609,6 +616,7 @@ def register_from_file(
is_permanent=is_permanent,
immutable=immutable,
comment=comment,
copy_grants=copy_grants,
)

def _do_register_udaf(
Expand All @@ -634,6 +642,7 @@ def _do_register_udaf(
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
immutable: bool = False,
copy_grants: bool = False,
) -> UserDefinedAggregateFunction:
# get the udaf name, return and input types
(
Expand Down Expand Up @@ -714,6 +723,7 @@ def _do_register_udaf(
statement_params=statement_params,
comment=comment,
native_app_params=native_app_params,
copy_grants=copy_grants,
runtime_version=runtime_version_from_requirement,
)
# an exception might happen during registering a udaf
Expand Down
11 changes: 11 additions & 0 deletions src/snowflake/snowpark/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ def register(
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
comment: Optional[str] = None,
copy_grants: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
Expand Down Expand Up @@ -593,6 +594,9 @@ def register(
immutable: Whether the UDF result is deterministic or not for the same input.
comment: Adds a comment for the created object. See
`COMMENT <https://docs.snowflake.com/en/sql-reference/sql/comment>`_
copy_grants: Specifies to retain the access privileges from the original function when a new function is created
using CREATE OR REPLACE FUNCTION.
See Also:
- :func:`~snowflake.snowpark.functions.udf`
- :meth:`register_from_file`
Expand Down Expand Up @@ -637,6 +641,7 @@ def register(
api_call_source="UDFRegistration.register"
+ ("[pandas_udf]" if _from_pandas else ""),
is_permanent=is_permanent,
copy_grants=copy_grants,
)

def register_from_file(
Expand All @@ -659,6 +664,7 @@ def register_from_file(
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
comment: Optional[str] = None,
copy_grants: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
Expand Down Expand Up @@ -750,6 +756,8 @@ def register_from_file(
immutable: Whether the UDF result is deterministic or not for the same input.
comment: Adds a comment for the created object. See
`COMMENT <https://docs.snowflake.com/en/sql-reference/sql/comment>`_
copy_grants: Specifies to retain the access privileges from the original function when a new function is created
using CREATE OR REPLACE FUNCTION.
Note::
The type hints can still be extracted from the local source Python file if they
Expand Down Expand Up @@ -792,6 +800,7 @@ def register_from_file(
api_call_source="UDFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
copy_grants=copy_grants,
)

def _do_register_udf(
Expand Down Expand Up @@ -821,6 +830,7 @@ def _do_register_udf(
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
copy_grants: bool = False,
) -> UserDefinedFunction:
# get the udf name, return and input types
(
Expand Down Expand Up @@ -909,6 +919,7 @@ def _do_register_udf(
statement_params=statement_params,
comment=comment,
native_app_params=native_app_params,
copy_grants=copy_grants,
runtime_version=runtime_version_from_requirement,
)
# an exception might happen during registering a udf
Expand Down
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ def register(
immutable: bool = False,
max_batch_size: Optional[int] = None,
comment: Optional[str] = None,
copy_grants: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
**kwargs,
Expand Down Expand Up @@ -634,6 +635,8 @@ def register(
be ignored when registering a non-vectorized UDTF.
comment: Adds a comment for the created object. See
`COMMENT <https://docs.snowflake.com/en/sql-reference/sql/comment>`_
copy_grants: Specifies to retain the access privileges from the original function when a new function is created
using CREATE OR REPLACE FUNCTION.
See Also:
- :func:`~snowflake.snowpark.functions.udtf`
Expand Down Expand Up @@ -682,6 +685,7 @@ def register(
api_call_source="UDTFRegistration.register",
is_permanent=is_permanent,
native_app_params=native_app_params,
copy_grants=copy_grants,
)

def register_from_file(
Expand All @@ -705,6 +709,7 @@ def register_from_file(
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
comment: Optional[str] = None,
copy_grants: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
skip_upload_on_content_match: bool = False,
Expand Down Expand Up @@ -789,6 +794,8 @@ def register_from_file(
immutable: Whether the UDTF result is deterministic or not for the same input.
comment: Adds a comment for the created object. See
`COMMENT <https://docs.snowflake.com/en/sql-reference/sql/comment>`_
copy_grants: Specifies to retain the access privileges from the original function when a new function is created
using CREATE OR REPLACE FUNCTION.
Note::
The type hints can still be extracted from the local source Python file if they
Expand Down Expand Up @@ -838,6 +845,7 @@ def register_from_file(
api_call_source="UDTFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
copy_grants=copy_grants,
)

def _do_register_udtf(
Expand Down Expand Up @@ -866,6 +874,7 @@ def _do_register_udtf(
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
copy_grants: bool = False,
) -> UserDefinedTableFunction:

if isinstance(output_schema, StructType):
Expand Down Expand Up @@ -970,6 +979,7 @@ def _do_register_udtf(
statement_params=statement_params,
comment=comment,
native_app_params=native_app_params,
copy_grants=copy_grants,
runtime_version=runtime_version_from_requirement,
)
# an exception might happen during registering a udtf
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/test_udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,45 @@ def test_add_snowpark_package_to_sproc_packages_to_session():
assert result is None


@pytest.mark.parametrize("copy_grants", [True, False])
@pytest.mark.parametrize(
"object_type",
[
TempObjectType.FUNCTION,
TempObjectType.PROCEDURE,
TempObjectType.TABLE_FUNCTION,
TempObjectType.AGGREGATE_FUNCTION,
],
)
def test_copy_grant_for_udf_or_sp_registration(
mock_session, mock_server_connection, copy_grants, object_type
):
mock_session._conn = mock_server_connection
mock_session._runtime_version_from_requirement = None
with mock.patch.object(mock_session, "_run_query") as mock_run_query:
create_python_udf_or_sp(
session=mock_session,
func=lambda: None,
return_type=StringType(),
input_args=[],
opt_arg_defaults=[],
handler="",
object_type=object_type,
object_name="",
all_imports="",
all_packages="",
raw_imports=None,
is_permanent=True,
replace=False,
if_not_exists=False,
copy_grants=copy_grants,
)
if copy_grants:
mock_run_query.assert_called_once()
assert "COPY GRANTS" in mock_run_query.call_args[0][0]
pass


def test_create_python_udf_or_sp_with_none_session():
mock_callback = mock.MagicMock(return_value=False)

Expand Down

0 comments on commit 2e33301

Please sign in to comment.