diff --git a/CHANGELOG.md b/CHANGELOG.md index 2afebd9520..113617e20d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Added support for 'Service' domain to `session.lineage.trace` API. - Updated `Session` class to be thread-safe. This allows concurrent dataframe transformations, dataframe actions, UDF and store procedure registration, and concurrent file uploads. +- Added support for `copy_grants` parameter when registering UDxF and stored procedures. #### New Features diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index e5a437925d..76db9b8566 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1236,6 +1236,7 @@ def create_python_udf_or_sp( statement_params: Optional[Dict[str, str]] = None, comment: Optional[str] = None, native_app_params: Optional[Dict[str, Any]] = None, + copy_grants: bool = False, runtime_version: Optional[str] = None, ) -> None: runtime_version = runtime_version or f"{sys.version_info[0]}.{sys.version_info[1]}" @@ -1327,6 +1328,7 @@ def create_python_udf_or_sp( create_query = f""" CREATE{" OR REPLACE " if replace else ""} {"" if is_permanent else "TEMPORARY"} {"SECURE" if secure else ""} {object_type.value.replace("_", " ")} {"IF NOT EXISTS" if if_not_exists else ""} {object_name}({sql_func_args}) +{" COPY GRANTS " if copy_grants else ""} {return_sql} LANGUAGE PYTHON {strict_as_sql} {mutability} diff --git a/src/snowflake/snowpark/mock/_stored_procedure.py b/src/snowflake/snowpark/mock/_stored_procedure.py index 14abec358c..0d45ef018a 100644 --- a/src/snowflake/snowpark/mock/_stored_procedure.py +++ b/src/snowflake/snowpark/mock/_stored_procedure.py @@ -212,6 +212,7 @@ def _do_register_sp( force_inline_code: bool = False, comment: Optional[str] = None, native_app_params: Optional[Dict[str, Any]] = None, + copy_grants: bool = False, ) -> StoredProcedure: if is_permanent: diff --git a/src/snowflake/snowpark/mock/_udf.py b/src/snowflake/snowpark/mock/_udf.py index a7a17d9a03..c0427b528e 100644 --- a/src/snowflake/snowpark/mock/_udf.py +++ b/src/snowflake/snowpark/mock/_udf.py @@ -110,6 +110,7 @@ def _do_register_udf( skip_upload_on_content_match: bool = False, is_permanent: bool = False, native_app_params: Optional[Dict[str, Any]] = None, + copy_grants: bool = False, ) -> UserDefinedFunction: if is_permanent: self._session._conn.log_not_supported_error( diff --git a/src/snowflake/snowpark/stored_procedure.py b/src/snowflake/snowpark/stored_procedure.py index c2fcb02419..70e4fe623e 100644 --- a/src/snowflake/snowpark/stored_procedure.py +++ b/src/snowflake/snowpark/stored_procedure.py @@ -469,6 +469,7 @@ def register( external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, comment: Optional[str] = None, + copy_grants: bool = False, *, statement_params: Optional[Dict[str, str]] = None, source_code_display: bool = True, @@ -544,6 +545,8 @@ def register( retrieve the secrets using secret API. comment: Adds a comment for the created object. See `COMMENT `_ + copy_grants: Specifies to retain the access privileges from the original function when a new function is + created using CREATE OR REPLACE PROCEDURE. See Also: - :func:`~snowflake.snowpark.functions.sproc` @@ -578,6 +581,7 @@ def register( external_access_integrations=external_access_integrations, secrets=secrets, comment=comment, + copy_grants=copy_grants, statement_params=statement_params, execute_as=execute_as, api_call_source="StoredProcedureRegistration.register", @@ -610,6 +614,7 @@ def register_from_file( external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, comment: Optional[str] = None, + copy_grants: bool = False, *, statement_params: Optional[Dict[str, str]] = None, source_code_display: bool = True, @@ -694,6 +699,8 @@ def register_from_file( retrieve the secrets using secret API. comment: Adds a comment for the created object. See `COMMENT `_ + copy_grants: Specifies to retain the access privileges from the original function when a new function is + created using CREATE OR REPLACE PROCEDURE. Note:: The type hints can still be extracted from the source Python file if they @@ -730,6 +737,7 @@ def register_from_file( external_access_integrations=external_access_integrations, secrets=secrets, comment=comment, + copy_grants=copy_grants, statement_params=statement_params, execute_as=execute_as, api_call_source="StoredProcedureRegistration.register_from_file", @@ -764,6 +772,7 @@ def _do_register_sp( force_inline_code: bool = False, comment: Optional[str] = None, native_app_params: Optional[Dict[str, Any]] = None, + copy_grants: bool = False, ) -> StoredProcedure: ( udf_name, @@ -874,6 +883,7 @@ def _do_register_sp( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + copy_grants=copy_grants, runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a stored procedure diff --git a/src/snowflake/snowpark/udaf.py b/src/snowflake/snowpark/udaf.py index 889b5b6291..e4c184e5f7 100644 --- a/src/snowflake/snowpark/udaf.py +++ b/src/snowflake/snowpark/udaf.py @@ -332,6 +332,7 @@ def register( external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, comment: Optional[str] = None, + copy_grants: bool = False, *, statement_params: Optional[Dict[str, str]] = None, source_code_display: bool = True, @@ -411,6 +412,8 @@ def register( retrieve the secrets using secret API. comment: Adds a comment for the created object. See `COMMENT `_ + copy_grants: Specifies to retain the access privileges from the original function when a new function is + created using CREATE OR REPLACE FUNCTION. See Also: - :func:`~snowflake.snowpark.functions.udaf` @@ -455,6 +458,7 @@ def register( secrets=secrets, comment=comment, native_app_params=native_app_params, + copy_grants=copy_grants, ) def register_from_file( @@ -474,6 +478,7 @@ def register_from_file( external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, comment: Optional[str] = None, + copy_grants: bool = False, *, statement_params: Optional[Dict[str, str]] = None, source_code_display: bool = True, @@ -562,6 +567,8 @@ def register_from_file( retrieve the secrets using secret API. comment: Adds a comment for the created object. See `COMMENT `_ + copy_grants: Specifies to retain the access privileges from the original function when a new function is + created using CREATE OR REPLACE FUNCTION. Note:: The type hints can still be extracted from the local source Python file if they @@ -609,6 +616,7 @@ def register_from_file( is_permanent=is_permanent, immutable=immutable, comment=comment, + copy_grants=copy_grants, ) def _do_register_udaf( @@ -634,6 +642,7 @@ def _do_register_udaf( skip_upload_on_content_match: bool = False, is_permanent: bool = False, immutable: bool = False, + copy_grants: bool = False, ) -> UserDefinedAggregateFunction: # get the udaf name, return and input types ( @@ -714,6 +723,7 @@ def _do_register_udaf( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + copy_grants=copy_grants, runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a udaf diff --git a/src/snowflake/snowpark/udf.py b/src/snowflake/snowpark/udf.py index b71a40263f..b52d99c7da 100644 --- a/src/snowflake/snowpark/udf.py +++ b/src/snowflake/snowpark/udf.py @@ -505,6 +505,7 @@ def register( secrets: Optional[Dict[str, str]] = None, immutable: bool = False, comment: Optional[str] = None, + copy_grants: bool = False, *, statement_params: Optional[Dict[str, str]] = None, source_code_display: bool = True, @@ -593,6 +594,9 @@ def register( immutable: Whether the UDF result is deterministic or not for the same input. comment: Adds a comment for the created object. See `COMMENT `_ + copy_grants: Specifies to retain the access privileges from the original function when a new function is created + using CREATE OR REPLACE FUNCTION. + See Also: - :func:`~snowflake.snowpark.functions.udf` - :meth:`register_from_file` @@ -637,6 +641,7 @@ def register( api_call_source="UDFRegistration.register" + ("[pandas_udf]" if _from_pandas else ""), is_permanent=is_permanent, + copy_grants=copy_grants, ) def register_from_file( @@ -659,6 +664,7 @@ def register_from_file( secrets: Optional[Dict[str, str]] = None, immutable: bool = False, comment: Optional[str] = None, + copy_grants: bool = False, *, statement_params: Optional[Dict[str, str]] = None, source_code_display: bool = True, @@ -750,6 +756,8 @@ def register_from_file( immutable: Whether the UDF result is deterministic or not for the same input. comment: Adds a comment for the created object. See `COMMENT `_ + copy_grants: Specifies to retain the access privileges from the original function when a new function is created + using CREATE OR REPLACE FUNCTION. Note:: The type hints can still be extracted from the local source Python file if they @@ -792,6 +800,7 @@ def register_from_file( api_call_source="UDFRegistration.register_from_file", skip_upload_on_content_match=skip_upload_on_content_match, is_permanent=is_permanent, + copy_grants=copy_grants, ) def _do_register_udf( @@ -821,6 +830,7 @@ def _do_register_udf( api_call_source: str, skip_upload_on_content_match: bool = False, is_permanent: bool = False, + copy_grants: bool = False, ) -> UserDefinedFunction: # get the udf name, return and input types ( @@ -909,6 +919,7 @@ def _do_register_udf( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + copy_grants=copy_grants, runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a udf diff --git a/src/snowflake/snowpark/udtf.py b/src/snowflake/snowpark/udtf.py index 856007cdb6..f0d59eae1b 100644 --- a/src/snowflake/snowpark/udtf.py +++ b/src/snowflake/snowpark/udtf.py @@ -553,6 +553,7 @@ def register( immutable: bool = False, max_batch_size: Optional[int] = None, comment: Optional[str] = None, + copy_grants: bool = False, *, statement_params: Optional[Dict[str, str]] = None, **kwargs, @@ -634,6 +635,8 @@ def register( be ignored when registering a non-vectorized UDTF. comment: Adds a comment for the created object. See `COMMENT `_ + copy_grants: Specifies to retain the access privileges from the original function when a new function is created + using CREATE OR REPLACE FUNCTION. See Also: - :func:`~snowflake.snowpark.functions.udtf` @@ -682,6 +685,7 @@ def register( api_call_source="UDTFRegistration.register", is_permanent=is_permanent, native_app_params=native_app_params, + copy_grants=copy_grants, ) def register_from_file( @@ -705,6 +709,7 @@ def register_from_file( secrets: Optional[Dict[str, str]] = None, immutable: bool = False, comment: Optional[str] = None, + copy_grants: bool = False, *, statement_params: Optional[Dict[str, str]] = None, skip_upload_on_content_match: bool = False, @@ -789,6 +794,8 @@ def register_from_file( immutable: Whether the UDTF result is deterministic or not for the same input. comment: Adds a comment for the created object. See `COMMENT `_ + copy_grants: Specifies to retain the access privileges from the original function when a new function is created + using CREATE OR REPLACE FUNCTION. Note:: The type hints can still be extracted from the local source Python file if they @@ -838,6 +845,7 @@ def register_from_file( api_call_source="UDTFRegistration.register_from_file", skip_upload_on_content_match=skip_upload_on_content_match, is_permanent=is_permanent, + copy_grants=copy_grants, ) def _do_register_udtf( @@ -866,6 +874,7 @@ def _do_register_udtf( api_call_source: str, skip_upload_on_content_match: bool = False, is_permanent: bool = False, + copy_grants: bool = False, ) -> UserDefinedTableFunction: if isinstance(output_schema, StructType): @@ -970,6 +979,7 @@ def _do_register_udtf( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + copy_grants=copy_grants, runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a udtf diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index 09e389d0c2..50a94b579a 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -266,6 +266,45 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): assert result is None +@pytest.mark.parametrize("copy_grants", [True, False]) +@pytest.mark.parametrize( + "object_type", + [ + TempObjectType.FUNCTION, + TempObjectType.PROCEDURE, + TempObjectType.TABLE_FUNCTION, + TempObjectType.AGGREGATE_FUNCTION, + ], +) +def test_copy_grant_for_udf_or_sp_registration( + mock_session, mock_server_connection, copy_grants, object_type +): + mock_session._conn = mock_server_connection + mock_session._runtime_version_from_requirement = None + with mock.patch.object(mock_session, "_run_query") as mock_run_query: + create_python_udf_or_sp( + session=mock_session, + func=lambda: None, + return_type=StringType(), + input_args=[], + opt_arg_defaults=[], + handler="", + object_type=object_type, + object_name="", + all_imports="", + all_packages="", + raw_imports=None, + is_permanent=True, + replace=False, + if_not_exists=False, + copy_grants=copy_grants, + ) + if copy_grants: + mock_run_query.assert_called_once() + assert "COPY GRANTS" in mock_run_query.call_args[0][0] + pass + + def test_create_python_udf_or_sp_with_none_session(): mock_callback = mock.MagicMock(return_value=False)