Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-874787: Prevent future flakiness of packaging tests #967

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 38 additions & 36 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,20 +912,10 @@ def _resolve_packages(
if not self.get_current_database():
package_table = f"snowflake.{package_table}"

valid_packages = (
{
p[0]: json.loads(p[1])
for p in self.table(package_table)
.filter(
(col("language") == "python")
& (col("package_name").in_([v[0] for v in package_dict.values()]))
)
.group_by("package_name")
.agg(array_agg("version"))
._internal_collect_with_tag()
}
if validate_package and package_dict
else None
valid_packages = self._get_available_versions_for_packages(
package_names=[v[0] for v in package_dict.values()],
package_table_name=package_table,
validate_package=validate_package,
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the code now sits in a utility function, it can be patched easily.

The goal here is to patch _get_available_versions_for_packages for certain package inputs (sktime, scikit-fuzzy) so that they are seen as custom packages forever (regardless of whether they are eventually added to Anaconda)

result_dict = (
Expand All @@ -938,10 +928,11 @@ def _resolve_packages(

if validate_package:
if package_name not in valid_packages or (
package_version_req and not any(
package_version_req.contains(v)
for v in valid_packages[package_name]
)
package_version_req
and not any(
package_version_req.contains(v)
for v in valid_packages[package_name]
)
):
if is_in_stored_procedure(): # pragma: no cover
raise RuntimeError(
Expand Down Expand Up @@ -1078,24 +1069,12 @@ def _upload_unsupported_packages(
downloaded_packages_dict = map_python_packages_to_files_and_folders(target)

# Fetch valid Snowflake Anaconda versions for all packages installed by pip (if present).
valid_downloaded_packages = {
p[0]: json.loads(p[1])
for p in self.table(package_table)
.filter(
(col("language") == "python")
& (
col("package_name").in_(
[
package.name
for package in downloaded_packages_dict.keys()
]
)
)
)
.group_by("package_name")
.agg(array_agg("version"))
._internal_collect_with_tag()
}
valid_downloaded_packages = self._get_available_versions_for_packages(
package_names=[
package.name for package in downloaded_packages_dict.keys()
],
package_table_name=package_table,
)

# Detect packages which use native code.
native_packages = detect_native_dependencies(
Expand Down Expand Up @@ -1148,6 +1127,29 @@ def _upload_unsupported_packages(
def _is_anaconda_terms_acknowledged(self) -> bool:
return self._run_query("select system$are_anaconda_terms_acknowledged()")[0][0]

def _get_available_versions_for_packages(
self,
package_names: List[str],
package_table_name: str,
validate_package: bool = True,
) -> Dict[str, List[str]]:
package_to_version_mapping = (
{
p[0]: json.loads(p[1])
for p in self.table(package_table_name)
.filter(
(col("language") == "python")
& (col("package_name").in_(package_names))
)
.group_by("package_name")
.agg(array_agg("version"))
._internal_collect_with_tag()
}
if validate_package and len(package_names) > 0
else None
)
return package_to_version_mapping

@property
def query_tag(self) -> Optional[str]:
"""
Expand Down
101 changes: 97 additions & 4 deletions tests/integ/test_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

import pytest

from snowflake.snowpark import Row
from snowflake.snowpark.functions import col, count_distinct, udf
from snowflake.snowpark import Row, Session
from snowflake.snowpark.functions import col, count_distinct, sproc, udf
from snowflake.snowpark.types import DateType
from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, TempObjectType, TestFiles, Utils

Expand Down Expand Up @@ -52,7 +52,34 @@ def clean_up(session):
session.clear_packages()
session.clear_imports()
session._runtime_version_from_requirement = None
yield


@pytest.fixture(autouse=True)
def get_available_versions_for_packages_patched(session):
# Save a reference to the original function
original_function = session._get_available_versions_for_packages
sentinel_version = "0.0.1"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sentinel version will be returned for sktime and scikit_fuzzy, if either of those packages are present.


with patch.object(session, "_get_available_versions_for_packages") as mock_function:

def side_effect(package_names, *args, **kwargs):
sktime_found = False
scikit_fuzzy_found = False
for name in package_names:
if name == "sktime":
sktime_found = True
elif name == "scikit-fuzzy":
scikit_fuzzy_found = True

result = original_function(package_names, *args, **kwargs)
if sktime_found:
result.update({"sktime": [sentinel_version]})
if scikit_fuzzy_found:
result.update({"scikit-fuzzy": [sentinel_version]})
return result

mock_function.side_effect = side_effect
yield


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -101,6 +128,27 @@ def ranged_yaml_file():
os.remove(file_path)


def test_patch_on_get_available_versions_for_packages(session):
"""
Assert that the utility function get_available_versions_for_packages() is patched for custom packages. This ensures
that if custom packages are eventually added to the Anaconda channel, the custom package tests will not fail.
"""
package_table = "information_schema.packages"
# TODO: Use the database from fully qualified UDF name
if not session.get_current_database():
package_table = f"snowflake.{package_table}"

packages = ["sktime", "scikit-fuzzy", "numpy", "pandas"]
returned = session._get_available_versions_for_packages(packages, package_table)
assert returned.keys() == set(packages)
for key in returned.keys():
assert len(returned[key]) > 0
assert returned["sktime"] == ["0.0.1"]
assert returned["scikit-fuzzy"] == ["0.0.1"]
assert returned["numpy"] != ["0.0.1"]
assert returned["pandas"] != ["0.0.1"]


@pytest.mark.skipif(
(not is_pandas_and_numpy_available) or IS_IN_STORED_PROC,
reason="numpy and pandas are required",
Expand Down Expand Up @@ -319,7 +367,7 @@ def test_add_packages_should_fail_if_dependency_package_already_added(session):
IS_IN_STORED_PROC,
reason="Subprocess calls are not allowed within stored procedures",
)
def test_add_requirements_unsupported(session, resources_path):
def test_add_requirements_unsupported_usable_by_udf(session, resources_path):
test_files = TestFiles(resources_path)

with patch.object(session, "_is_anaconda_terms_acknowledged", lambda: True):
Expand All @@ -343,6 +391,34 @@ def import_scikit_fuzzy() -> str:
Utils.check_answer(session.sql(f"select {udf_name}()"), [Row("0.4.2")])


@pytest.mark.skipif(
IS_IN_STORED_PROC,
reason="Subprocess calls are not allowed within stored procedures",
)
def test_add_requirements_unsupported_usable_by_sproc(session, resources_path):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is moved into test_packaging.py from test_stored_procedure.py file as it references the scikit-fuzzy package

test_files = TestFiles(resources_path)

with patch.object(session, "_is_anaconda_terms_acknowledged", lambda: True):
session.add_requirements(test_files.test_unsupported_requirements_file)
session.add_packages("snowflake-snowpark-python")
# Once scikit-fuzzy is supported, this test will break; change the test to a different unsupported module
assert set(session.get_packages().keys()) == {
"matplotlib",
"pyyaml",
"scipy",
"snowflake-snowpark-python",
"numpy",
}

@sproc
def run_scikit_fuzzy(_: Session) -> str:
import skfuzzy as fuzz

return fuzz.__version__

assert run_scikit_fuzzy(session) == "0.4.2"


@pytest.mark.skipif(
IS_IN_STORED_PROC or IS_WINDOWS,
reason="Subprocess calls are not allowed within stored procedures",
Expand Down Expand Up @@ -539,3 +615,20 @@ def plus_one_month(x):
Utils.check_answer(
df.select(plus_one_month_udf("a")).collect(), [Row(plus_one_month(d))]
)


def test_get_available_versions_for_packages(session):
"""
Assert that the utility function get_available_versions_for_packages() returns a list of versions available in Snowflake,
for some common packages.
"""
package_table = "information_schema.packages"
# TODO: Use the database from fully qualified UDF name
if not session.get_current_database():
Comment on lines +626 to +627
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to take care of the todo?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what the TODO is about, it was in the codebase prior to my changes.

package_table = f"snowflake.{package_table}"

packages = ["numpy", "pandas", "matplotlib"]
returned = session._get_available_versions_for_packages(packages, package_table)
assert returned.keys() == set(packages)
for key in returned.keys():
assert len(returned[key]) > 0
27 changes: 0 additions & 27 deletions tests/integ/test_stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,30 +1168,3 @@ def test_anonymous_stored_procedure(session):
)
assert add_sp._anonymous_sp_sql is not None
assert add_sp(1, 2) == 3


@pytest.mark.skipif(
IS_IN_STORED_PROC,
reason="Subprocess calls are not allowed within stored procedures",
)
def test_add_requirements_unsupported(session, resources_path):
test_files = TestFiles(resources_path)

with patch.object(session, "_is_anaconda_terms_acknowledged", lambda: True):
session.add_requirements(test_files.test_unsupported_requirements_file)
# Once scikit-fuzzy is supported, this test will break; change the test to a different unsupported module
assert set(session.get_packages().keys()) == {
"matplotlib",
"pyyaml",
"snowflake-snowpark-python",
"scipy",
"numpy",
}

@sproc
def run_scikit_fuzzy(_: Session) -> str:
import skfuzzy as fuzz

return fuzz.__version__

assert run_scikit_fuzzy(session) == "0.4.2"
Loading