Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1659276 add multithreading tests to sp env #2479

Merged
72 changes: 49 additions & 23 deletions tests/integ/test_multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import os
import tempfile
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Tuple # noqa: F401
from unittest.mock import patch
Expand All @@ -32,22 +33,34 @@

from snowflake.snowpark.functions import lit
from snowflake.snowpark.row import Row
from tests.utils import IS_IN_STORED_PROC, IS_LINUX, IS_WINDOWS, TestFiles, Utils
from tests.utils import (
IS_IN_STORED_PROC,
IS_IN_STORED_PROC_LOCALFS,
IS_LINUX,
IS_WINDOWS,
TestFiles,
Utils,
)


@pytest.fixture(scope="module")
def threadsafe_session(
db_parameters, sql_simplifier_enabled, cte_optimization_enabled, local_testing_mode
db_parameters,
session,
sql_simplifier_enabled,
local_testing_mode,
):
new_db_parameters = db_parameters.copy()
new_db_parameters["local_testing"] = local_testing_mode
new_db_parameters["session_parameters"] = {
_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True
}
with Session.builder.configs(new_db_parameters).create() as session:
session._sql_simplifier_enabled = sql_simplifier_enabled
session._cte_optimization_enabled = cte_optimization_enabled
if IS_IN_STORED_PROC:
yield session
else:
new_db_parameters = db_parameters.copy()
new_db_parameters["local_testing"] = local_testing_mode
new_db_parameters["session_parameters"] = {
_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True
}
with Session.builder.configs(new_db_parameters).create() as session:
session._sql_simplifier_enabled = sql_simplifier_enabled
yield session


@pytest.fixture(scope="function")
Expand All @@ -65,6 +78,13 @@ def threadsafe_temp_stage(threadsafe_session, resources_path, local_testing_mode
Utils.drop_stage(threadsafe_session, tmp_stage_name)


def test_threadsafe_session_uses_locks(threadsafe_session):
    """Check that a thread-safe session guards its state with re-entrant locks.

    Asserts that the session's own ``_lock``, the temp-table auto cleaner's
    ``lock`` and the connection's ``_lock`` are all instances of the concrete
    type produced by ``threading.RLock()``.
    """
    # threading.RLock is documented as a factory function rather than a class,
    # so take the concrete type from an instance for use with isinstance().
    rlock_class = type(threading.RLock())
    assert isinstance(threadsafe_session._lock, rlock_class)
    assert isinstance(threadsafe_session._temp_table_auto_cleaner.lock, rlock_class)
    assert isinstance(threadsafe_session._conn._lock, rlock_class)


def test_concurrent_select_queries(threadsafe_session):
def run_select(session_, thread_id):
df = session_.sql(f"SELECT {thread_id} as A")
Expand Down Expand Up @@ -183,6 +203,7 @@ def test_action_ids_are_unique(threadsafe_session):
assert len(action_ids) == 10


@pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="Skip file IO tests in localfs")
@pytest.mark.parametrize("use_stream", [True, False])
def test_file_io(threadsafe_session, resources_path, threadsafe_temp_stage, use_stream):
stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}"
Expand Down Expand Up @@ -256,30 +277,28 @@ def put_and_get_file(upload_file_path, download_dir):
def test_concurrent_add_packages(threadsafe_session):
# this is a list of packages available in snowflake anaconda. If this
# test fails due to packages not being available, please update the list
existing_packages = threadsafe_session.get_packages()
package_list = {
"graphviz",
"cloudpickle",
"numpy",
"pandas",
"scipy",
"scikit-learn",
"matplotlib",
"pyyaml",
}

try:
with ThreadPoolExecutor(max_workers=10) as executor:
futures = [
for package in package_list:
executor.submit(threadsafe_session.add_packages, package)
for package in package_list
]

for future in as_completed(futures):
future.result()
sfc-gh-aling marked this conversation as resolved.
Show resolved Hide resolved

assert threadsafe_session.get_packages() == {
package: package for package in package_list
}
final_packages = threadsafe_session.get_packages()
for package in package_list:
assert package in final_packages
finally:
threadsafe_session.clear_packages()
for package in package_list:
if package not in existing_packages:
threadsafe_session.remove_package(package)


def test_concurrent_remove_package(threadsafe_session):
Expand Down Expand Up @@ -317,6 +336,7 @@ def remove_package(session_, package_name):
@pytest.mark.skipif(not is_dateutil_available, reason="dateutil is not available")
def test_concurrent_add_import(threadsafe_session, resources_path):
test_files = TestFiles(resources_path)
existing_imports = set(threadsafe_session.get_imports())
import_files = [
test_files.test_udf_py_file,
os.path.relpath(test_files.test_udf_py_file),
Expand All @@ -336,7 +356,7 @@ def test_concurrent_add_import(threadsafe_session, resources_path):

assert set(threadsafe_session.get_imports()) == {
os.path.abspath(file) for file in import_files
}
}.union(existing_imports)
finally:
threadsafe_session.clear_imports()

Expand Down Expand Up @@ -556,6 +576,9 @@ def finish(self):
reason="session.sql is not supported in local testing mode",
run=False,
)
@pytest.mark.skipif(
IS_IN_STORED_PROC, reason="SNOW-609328: support caplog in SP regression test"
)
def test_auto_temp_table_cleaner(threadsafe_session, caplog):
threadsafe_session._temp_table_auto_cleaner.ref_count_map.clear()
original_auto_clean_up_temp_table_enabled = (
Expand Down Expand Up @@ -633,6 +656,9 @@ def change_config_value(session_):
)


@pytest.mark.skipif(
IS_IN_STORED_PROC, reason="Cannot create new session inside stored proc"
)
@pytest.mark.parametrize("is_enabled", [True, False])
def test_num_cursors_created(db_parameters, is_enabled, local_testing_mode):
if is_enabled and local_testing_mode:
Expand Down
Loading