From 944d93b9c6700706110324bbf13401940107dd5f Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 18 Oct 2024 11:35:12 -0700 Subject: [PATCH 1/9] yield regular session inside sp --- tests/integ/test_multithreading.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index b427ef8aa4..863f080b5f 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -37,17 +37,24 @@ @pytest.fixture(scope="module") def threadsafe_session( - db_parameters, sql_simplifier_enabled, cte_optimization_enabled, local_testing_mode + db_parameters, + session, + sql_simplifier_enabled, + cte_optimization_enabled, + local_testing_mode, ): - new_db_parameters = db_parameters.copy() - new_db_parameters["local_testing"] = local_testing_mode - new_db_parameters["session_parameters"] = { - _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True - } - with Session.builder.configs(new_db_parameters).create() as session: - session._sql_simplifier_enabled = sql_simplifier_enabled - session._cte_optimization_enabled = cte_optimization_enabled + if IS_IN_STORED_PROC: yield session + else: + new_db_parameters = db_parameters.copy() + new_db_parameters["local_testing"] = local_testing_mode + new_db_parameters["session_parameters"] = { + _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True + } + with Session.builder.configs(new_db_parameters).create() as session: + session._sql_simplifier_enabled = sql_simplifier_enabled + session._cte_optimization_enabled = cte_optimization_enabled + yield session @pytest.fixture(scope="function") From 6cd7dd0cd05a2791e80995b9d26471bc25dce38f Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 18 Oct 2024 11:47:58 -0700 Subject: [PATCH 2/9] add tests to ensure we use real locks --- tests/integ/test_multithreading.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 863f080b5f..4e137482fc 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -7,6 +7,7 @@ import logging import os import tempfile +import threading from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Tuple # noqa: F401 from unittest.mock import patch @@ -72,6 +73,12 @@ def threadsafe_temp_stage(threadsafe_session, resources_path, local_testing_mode Utils.drop_stage(threadsafe_session, tmp_stage_name) +def test_threadsafe_session_uses_locks(threadsafe_session): + assert isinstance(threadsafe_session._lock, threading.RLock) + assert isinstance(threadsafe_session._temp_table_auto_cleaner.lock, threading.RLock) + assert isinstance(threadsafe_session._conn._lock, threading.RLock) + + def test_concurrent_select_queries(threadsafe_session): def run_select(session_, thread_id): df = session_.sql(f"SELECT {thread_id} as A") From 60ab5eeba91b1a9be96b84ef15471002361f6c34 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 18 Oct 2024 11:50:23 -0700 Subject: [PATCH 3/9] ignore tests that may not run in SP --- tests/integ/test_multithreading.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 4e137482fc..12cecb43af 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -647,6 +647,9 @@ def change_config_value(session_): ) +@pytest.mark.skipif( + IS_IN_STORED_PROC, reason="Cannot create new session inside stored proc" +) @pytest.mark.parametrize("is_enabled", [True, False]) def test_num_cursors_created(db_parameters, is_enabled, local_testing_mode): if is_enabled and local_testing_mode: From 6dcaabc56e69747af451d152dfd2d295186bc75e Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 18 Oct 2024 12:54:43 -0700 Subject: [PATCH 4/9] read correct class --- tests/integ/test_multithreading.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 12cecb43af..a60cf8f2fe 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -74,9 +74,10 @@ def threadsafe_temp_stage(threadsafe_session, resources_path, local_testing_mode def test_threadsafe_session_uses_locks(threadsafe_session): - assert isinstance(threadsafe_session._lock, threading.RLock) - assert isinstance(threadsafe_session._temp_table_auto_cleaner.lock, threading.RLock) - assert isinstance(threadsafe_session._conn._lock, threading.RLock) + rlock_class = threading.RLock().__class__ + assert isinstance(threadsafe_session._lock, rlock_class) + assert isinstance(threadsafe_session._temp_table_auto_cleaner.lock, rlock_class) + assert isinstance(threadsafe_session._conn._lock, rlock_class) def test_concurrent_select_queries(threadsafe_session): From 0f31abeabf70f7d19668ef3f2f8f57bc2bf526e0 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 18 Oct 2024 16:14:07 -0700 Subject: [PATCH 5/9] simplify threadsafe session creation --- tests/integ/test_multithreading.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index a60cf8f2fe..a3df9dba7d 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -41,7 +41,6 @@ def threadsafe_session( db_parameters, session, sql_simplifier_enabled, - cte_optimization_enabled, local_testing_mode, ): if IS_IN_STORED_PROC: @@ -54,7 +53,6 @@ def threadsafe_session( } with Session.builder.configs(new_db_parameters).create() as session: session._sql_simplifier_enabled = sql_simplifier_enabled - session._cte_optimization_enabled = cte_optimization_enabled yield session From 344741f68dca580e162328cdfadd853b7826be50 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 18 Oct 2024 20:28:25 -0700 Subject: [PATCH 6/9] adjustment for SP environment --- tests/integ/test_multithreading.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index a3df9dba7d..083459747d 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -275,7 +275,7 @@ def test_concurrent_add_packages(threadsafe_session): "pandas", "scipy", "scikit-learn", - "matplotlib", + "pyyaml", } try: @@ -330,6 +330,7 @@ def remove_package(session_, package_name): @pytest.mark.skipif(not is_dateutil_available, reason="dateutil is not available") def test_concurrent_add_import(threadsafe_session, resources_path): test_files = TestFiles(resources_path) + existing_imports = set(threadsafe_session.get_imports()) import_files = [ test_files.test_udf_py_file, os.path.relpath(test_files.test_udf_py_file), @@ -349,7 +350,7 @@ def test_concurrent_add_import(threadsafe_session, resources_path): assert set(threadsafe_session.get_imports()) == { os.path.abspath(file) for file in import_files - } + }.union(existing_imports) finally: threadsafe_session.clear_imports() @@ -569,6 +570,9 @@ def finish(self): reason="session.sql is not supported in local testing mode", run=False, ) +@pytest.mark.skipif( + IS_IN_STORED_PROC, reason="SNOW-609328: support caplog in SP regression test" +) def test_auto_temp_table_cleaner(threadsafe_session, caplog): threadsafe_session._temp_table_auto_cleaner.ref_count_map.clear() original_auto_clean_up_temp_table_enabled = ( From 19a70badbba16cb850bbf1bf3de93f3875bc634f Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Sat, 19 Oct 2024 09:34:40 -0700 Subject: [PATCH 7/9] adjustment for SP environment --- tests/integ/test_multithreading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 083459747d..99fc017386 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -270,7 +270,7 @@ def test_concurrent_add_packages(threadsafe_session): # this is a list of packages available in snowflake anaconda. If this # test fails due to packages not being available, please update the list package_list = { - "graphviz", + "cloudpickle", "numpy", "pandas", "scipy", From af4167da232d97d14d4d2b9355dc94a06c77a714 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Sat, 19 Oct 2024 14:53:50 -0700 Subject: [PATCH 8/9] adjustment for SP environment --- tests/integ/test_multithreading.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 99fc017386..64d961f176 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -269,6 +269,7 @@ def put_and_get_file(upload_file_path, download_dir): def test_concurrent_add_packages(threadsafe_session): # this is a list of packages available in snowflake anaconda. If this # test fails due to packages not being available, please update the list + existing_packages = threadsafe_session.get_packages() package_list = { "cloudpickle", "numpy", @@ -280,19 +281,16 @@ def test_concurrent_add_packages(threadsafe_session): try: with ThreadPoolExecutor(max_workers=10) as executor: - futures = [ + for package in package_list: executor.submit(threadsafe_session.add_packages, package) - for package in package_list - ] - - for future in as_completed(futures): - future.result() - assert threadsafe_session.get_packages() == { - package: package for package in package_list - } + final_packages = threadsafe_session.get_packages() + for package in package_list: + assert package in final_packages finally: - threadsafe_session.clear_packages() + for package in package_list: + if package not in existing_packages: + threadsafe_session.remove_package(package) def test_concurrent_remove_package(threadsafe_session): From c21d1232abdb78898fc05e4bcf7777e225243900 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Sat, 19 Oct 2024 17:32:54 -0700 Subject: [PATCH 9/9] adjustment for SP environment --- tests/integ/test_multithreading.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 64d961f176..610510b821 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -33,7 +33,14 @@ from snowflake.snowpark.functions import lit from snowflake.snowpark.row import Row -from tests.utils import IS_IN_STORED_PROC, IS_LINUX, IS_WINDOWS, TestFiles, Utils +from tests.utils import ( + IS_IN_STORED_PROC, + IS_IN_STORED_PROC_LOCALFS, + IS_LINUX, + IS_WINDOWS, + TestFiles, + Utils, +) @pytest.fixture(scope="module") @@ -196,6 +203,7 @@ def test_action_ids_are_unique(threadsafe_session): assert len(action_ids) == 10 +@pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="Skip file IO tests in localfs") @pytest.mark.parametrize("use_stream", [True, False]) def test_file_io(threadsafe_session, resources_path, threadsafe_temp_stage, use_stream): stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}"