diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 5d8a867db5..a236546aae 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -5,7 +5,6 @@ import datetime import decimal -import importlib import json import logging import os @@ -19,7 +18,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union import cloudpickle -from packaging.requirements import Requirement +import pkg_resources from snowflake.connector import ProgrammingError, SnowflakeConnection from snowflake.connector.options import installed_pandas, pandas @@ -811,7 +810,7 @@ def remove_package(self, package: str) -> None: >>> len(session.get_packages()) 0 """ - package_name = Requirement(package).name.lower().replace("_", "-") + package_name = pkg_resources.Requirement.parse(package).key if package_name in self._packages: self._packages.pop(package_name) else: @@ -892,18 +891,20 @@ def _resolve_packages( package_name = MODULE_NAME_TO_PACKAGE_NAME_MAP.get( package.__name__, package.__name__ ) - package = f"{package_name}=={importlib.metadata.distribution(package_name).version}" + package = f"{package_name}=={pkg_resources.get_distribution(package_name).version}" use_local_version = True else: package = package.strip().lower() if package.startswith("#"): continue use_local_version = False - package_req = Requirement(package) + package_req = pkg_resources.Requirement.parse(package) + # get the standard package name if there is no underscore + # underscores are discouraged in package names, but are still used in Anaconda channel + # pkg_resources.Requirement.parse will convert all underscores to dashes + package_name = ( - package - if not use_local_version and "_" in package - else package_req.name.lower().replace("_", "-") + package if not use_local_version and "_" in package else package_req.key ) package_dict[package] = (package_name, use_local_version, package_req) @@ -924,19 +925,21 @@ def _resolve_packages( unsupported_packages: List[str] = [] for package, package_info in package_dict.items(): package_name, use_local_version, package_req = package_info - package_version_req = package_req.specifier + package_version_req = package_req.specs[0][1] if package_req.specs else None if validate_package: if package_name not in valid_packages or ( package_version_req - and not any( - package_version_req.contains(v) - for v in valid_packages[package_name] - ) + and not any(v in package_req for v in valid_packages[package_name]) ): + version_text = ( + f"(version {package_version_req})" + if package_version_req is not None + else "" + ) if is_in_stored_procedure(): # pragma: no cover raise RuntimeError( - f"Cannot add package {package_req} because it is not available in Snowflake " + f"Cannot add package {package_name}{version_text} because it is not available in Snowflake " f"and it cannot be installed via pip as you are executing this code inside a stored " f"procedure. You can find the directory of these packages and add it via " f"session.add_import(). See details at " @@ -947,7 +950,7 @@ def _resolve_packages( and not self._is_anaconda_terms_acknowledged() ): raise RuntimeError( - f"Cannot add package {package_req} because Anaconda terms must be accepted " + f"Cannot add package {package_name}{version_text} because Anaconda terms must be accepted " "by ORGADMIN to use Anaconda 3rd party packages. Please follow the instructions at " "https://docs.snowflake.com/en/developer-guide/udf/python/udf-python-packages.html#using-third-party-packages-from-anaconda." ) @@ -955,7 +958,7 @@ def _resolve_packages( continue elif not use_local_version: try: - package_client_version = importlib.metadata.distribution( + package_client_version = pkg_resources.get_distribution( package_name ).version if package_client_version not in valid_packages[package_name]: @@ -965,7 +968,7 @@ def _resolve_packages( f"requirement '{package}'. Your UDF might not work when the package version " f"is different between the server and your local environment." ) - except importlib.metadata.PackageNotFoundError: + except pkg_resources.DistributionNotFound: _logger.warning( f"Package '{package_name}' is not installed in the local environment. " f"Your UDF might not work when the package is installed on the server " @@ -987,7 +990,7 @@ def _resolve_packages( else: result_dict[package_name] = package - dependency_packages: Optional[List[Requirement]] = None + dependency_packages: Optional[List[pkg_resources.Requirement]] = None if len(unsupported_packages) != 0: _logger.warning( f"The following packages are not available in Snowflake: {unsupported_packages}. They " @@ -1037,7 +1040,7 @@ def _upload_unsupported_packages( packages: List[str], package_table: str, force_push: bool = True, - ) -> List[Requirement]: + ) -> List[pkg_resources.Requirement]: """ Uploads a list of Pypi packages, which are unavailable in Snowflake, to session stage.