Skip to content

Commit

Permalink
Revert PR 930 (#970)
Browse files Browse the repository at this point in the history
* revert PR 930

* fix pkg_resources.get_distribution
  • Loading branch information
sfc-gh-aalam authored Jul 25, 2023
1 parent ce49dec commit 2aac082
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import datetime
import decimal
import importlib
import json
import logging
import os
Expand All @@ -19,7 +18,7 @@
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union

import cloudpickle
from packaging.requirements import Requirement
import pkg_resources

from snowflake.connector import ProgrammingError, SnowflakeConnection
from snowflake.connector.options import installed_pandas, pandas
Expand Down Expand Up @@ -811,7 +810,7 @@ def remove_package(self, package: str) -> None:
>>> len(session.get_packages())
0
"""
package_name = Requirement(package).name.lower().replace("_", "-")
package_name = pkg_resources.Requirement.parse(package).key
if package_name in self._packages:
self._packages.pop(package_name)
else:
Expand Down Expand Up @@ -892,18 +891,20 @@ def _resolve_packages(
package_name = MODULE_NAME_TO_PACKAGE_NAME_MAP.get(
package.__name__, package.__name__
)
package = f"{package_name}=={importlib.metadata.distribution(package_name).version}"
package = f"{package_name}=={pkg_resources.get_distribution(package_name).version}"
use_local_version = True
else:
package = package.strip().lower()
if package.startswith("#"):
continue
use_local_version = False
package_req = Requirement(package)
package_req = pkg_resources.Requirement.parse(package)
# get the standard package name if there is no underscore
# underscores are discouraged in package names, but are still used in Anaconda channel
# pkg_resources.Requirement.parse will convert all underscores to dashes

package_name = (
package
if not use_local_version and "_" in package
else package_req.name.lower().replace("_", "-")
package if not use_local_version and "_" in package else package_req.key
)
package_dict[package] = (package_name, use_local_version, package_req)

Expand All @@ -924,19 +925,21 @@ def _resolve_packages(
unsupported_packages: List[str] = []
for package, package_info in package_dict.items():
package_name, use_local_version, package_req = package_info
package_version_req = package_req.specifier
package_version_req = package_req.specs[0][1] if package_req.specs else None

if validate_package:
if package_name not in valid_packages or (
package_version_req
and not any(
package_version_req.contains(v)
for v in valid_packages[package_name]
)
and not any(v in package_req for v in valid_packages[package_name])
):
version_text = (
f"(version {package_version_req})"
if package_version_req is not None
else ""
)
if is_in_stored_procedure(): # pragma: no cover
raise RuntimeError(
f"Cannot add package {package_req} because it is not available in Snowflake "
f"Cannot add package {package_name}{version_text} because it is not available in Snowflake "
f"and it cannot be installed via pip as you are executing this code inside a stored "
f"procedure. You can find the directory of these packages and add it via "
f"session.add_import(). See details at "
Expand All @@ -947,15 +950,15 @@ def _resolve_packages(
and not self._is_anaconda_terms_acknowledged()
):
raise RuntimeError(
f"Cannot add package {package_req} because Anaconda terms must be accepted "
f"Cannot add package {package_name}{version_text} because Anaconda terms must be accepted "
"by ORGADMIN to use Anaconda 3rd party packages. Please follow the instructions at "
"https://docs.snowflake.com/en/developer-guide/udf/python/udf-python-packages.html#using-third-party-packages-from-anaconda."
)
unsupported_packages.append(package)
continue
elif not use_local_version:
try:
package_client_version = importlib.metadata.distribution(
package_client_version = pkg_resources.get_distribution(
package_name
).version
if package_client_version not in valid_packages[package_name]:
Expand All @@ -965,7 +968,7 @@ def _resolve_packages(
f"requirement '{package}'. Your UDF might not work when the package version "
f"is different between the server and your local environment."
)
except importlib.metadata.PackageNotFoundError:
except pkg_resources.DistributionNotFound:
_logger.warning(
f"Package '{package_name}' is not installed in the local environment. "
f"Your UDF might not work when the package is installed on the server "
Expand All @@ -987,7 +990,7 @@ def _resolve_packages(
else:
result_dict[package_name] = package

dependency_packages: Optional[List[Requirement]] = None
dependency_packages: Optional[List[pkg_resources.Requirement]] = None
if len(unsupported_packages) != 0:
_logger.warning(
f"The following packages are not available in Snowflake: {unsupported_packages}. They "
Expand Down Expand Up @@ -1037,7 +1040,7 @@ def _upload_unsupported_packages(
packages: List[str],
package_table: str,
force_push: bool = True,
) -> List[Requirement]:
) -> List[pkg_resources.Requirement]:
"""
Uploads a list of Pypi packages, which are unavailable in Snowflake, to session stage.
Expand Down

0 comments on commit 2aac082

Please sign in to comment.