diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index b427ef8aa4..610510b821 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -7,6 +7,7 @@ import logging import os import tempfile +import threading from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Tuple # noqa: F401 from unittest.mock import patch @@ -32,22 +33,34 @@ from snowflake.snowpark.functions import lit from snowflake.snowpark.row import Row -from tests.utils import IS_IN_STORED_PROC, IS_LINUX, IS_WINDOWS, TestFiles, Utils +from tests.utils import ( + IS_IN_STORED_PROC, + IS_IN_STORED_PROC_LOCALFS, + IS_LINUX, + IS_WINDOWS, + TestFiles, + Utils, +) @pytest.fixture(scope="module") def threadsafe_session( - db_parameters, sql_simplifier_enabled, cte_optimization_enabled, local_testing_mode + db_parameters, + session, + sql_simplifier_enabled, + local_testing_mode, ): - new_db_parameters = db_parameters.copy() - new_db_parameters["local_testing"] = local_testing_mode - new_db_parameters["session_parameters"] = { - _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True - } - with Session.builder.configs(new_db_parameters).create() as session: - session._sql_simplifier_enabled = sql_simplifier_enabled - session._cte_optimization_enabled = cte_optimization_enabled + if IS_IN_STORED_PROC: yield session + else: + new_db_parameters = db_parameters.copy() + new_db_parameters["local_testing"] = local_testing_mode + new_db_parameters["session_parameters"] = { + _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True + } + with Session.builder.configs(new_db_parameters).create() as session: + session._sql_simplifier_enabled = sql_simplifier_enabled + yield session @pytest.fixture(scope="function") @@ -65,6 +78,13 @@ def threadsafe_temp_stage(threadsafe_session, resources_path, local_testing_mode Utils.drop_stage(threadsafe_session, tmp_stage_name) +def test_threadsafe_session_uses_locks(threadsafe_session): + rlock_class = threading.RLock().__class__ + assert isinstance(threadsafe_session._lock, rlock_class) + assert isinstance(threadsafe_session._temp_table_auto_cleaner.lock, rlock_class) + assert isinstance(threadsafe_session._conn._lock, rlock_class) + + def test_concurrent_select_queries(threadsafe_session): def run_select(session_, thread_id): df = session_.sql(f"SELECT {thread_id} as A") @@ -183,6 +203,7 @@ def test_action_ids_are_unique(threadsafe_session): assert len(action_ids) == 10 +@pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="Skip file IO tests in localfs") @pytest.mark.parametrize("use_stream", [True, False]) def test_file_io(threadsafe_session, resources_path, threadsafe_temp_stage, use_stream): stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}" @@ -256,30 +277,28 @@ def put_and_get_file(upload_file_path, download_dir): def test_concurrent_add_packages(threadsafe_session): # this is a list of packages available in snowflake anaconda. If this # test fails due to packages not being available, please update the list + existing_packages = threadsafe_session.get_packages() package_list = { - "graphviz", + "cloudpickle", "numpy", "pandas", "scipy", "scikit-learn", - "matplotlib", + "pyyaml", } try: with ThreadPoolExecutor(max_workers=10) as executor: - futures = [ + for package in package_list: executor.submit(threadsafe_session.add_packages, package) - for package in package_list - ] - for future in as_completed(futures): - future.result() - - assert threadsafe_session.get_packages() == { - package: package for package in package_list - } + final_packages = threadsafe_session.get_packages() + for package in package_list: + assert package in final_packages finally: - threadsafe_session.clear_packages() + for package in package_list: + if package not in existing_packages: + threadsafe_session.remove_package(package) def test_concurrent_remove_package(threadsafe_session): @@ -317,6 +336,7 @@ def remove_package(session_, package_name): @pytest.mark.skipif(not is_dateutil_available, reason="dateutil is not available") def test_concurrent_add_import(threadsafe_session, resources_path): test_files = TestFiles(resources_path) + existing_imports = set(threadsafe_session.get_imports()) import_files = [ test_files.test_udf_py_file, os.path.relpath(test_files.test_udf_py_file), @@ -336,7 +356,7 @@ def test_concurrent_add_import(threadsafe_session, resources_path): assert set(threadsafe_session.get_imports()) == { os.path.abspath(file) for file in import_files - } + }.union(existing_imports) finally: threadsafe_session.clear_imports() @@ -556,6 +576,9 @@ def finish(self): reason="session.sql is not supported in local testing mode", run=False, ) +@pytest.mark.skipif( + IS_IN_STORED_PROC, reason="SNOW-609328: support caplog in SP regression test" +) def test_auto_temp_table_cleaner(threadsafe_session, caplog): threadsafe_session._temp_table_auto_cleaner.ref_count_map.clear() original_auto_clean_up_temp_table_enabled = ( @@ -633,6 +656,9 @@ def change_config_value(session_): ) +@pytest.mark.skipif( + IS_IN_STORED_PROC, reason="Cannot create new session inside stored proc" +) @pytest.mark.parametrize("is_enabled", [True, False]) def test_num_cursors_created(db_parameters, is_enabled, local_testing_mode): if is_enabled and local_testing_mode: