Skip to content

Commit

Permalink
SNOW-1659276 add multithreading tests to sp env (#2479)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam authored Oct 21, 2024
1 parent 051d566 commit f73980e
Showing 1 changed file with 49 additions and 23 deletions.
72 changes: 49 additions & 23 deletions tests/integ/test_multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import os
import tempfile
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Tuple # noqa: F401
from unittest.mock import patch
Expand All @@ -32,22 +33,34 @@

from snowflake.snowpark.functions import lit
from snowflake.snowpark.row import Row
from tests.utils import IS_IN_STORED_PROC, IS_LINUX, IS_WINDOWS, TestFiles, Utils
from tests.utils import (
IS_IN_STORED_PROC,
IS_IN_STORED_PROC_LOCALFS,
IS_LINUX,
IS_WINDOWS,
TestFiles,
Utils,
)


@pytest.fixture(scope="module")
def threadsafe_session(
db_parameters, sql_simplifier_enabled, cte_optimization_enabled, local_testing_mode
db_parameters,
session,
sql_simplifier_enabled,
local_testing_mode,
):
new_db_parameters = db_parameters.copy()
new_db_parameters["local_testing"] = local_testing_mode
new_db_parameters["session_parameters"] = {
_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True
}
with Session.builder.configs(new_db_parameters).create() as session:
session._sql_simplifier_enabled = sql_simplifier_enabled
session._cte_optimization_enabled = cte_optimization_enabled
if IS_IN_STORED_PROC:
yield session
else:
new_db_parameters = db_parameters.copy()
new_db_parameters["local_testing"] = local_testing_mode
new_db_parameters["session_parameters"] = {
_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True
}
with Session.builder.configs(new_db_parameters).create() as session:
session._sql_simplifier_enabled = sql_simplifier_enabled
yield session


@pytest.fixture(scope="function")
Expand All @@ -65,6 +78,13 @@ def threadsafe_temp_stage(threadsafe_session, resources_path, local_testing_mode
Utils.drop_stage(threadsafe_session, tmp_stage_name)


def test_threadsafe_session_uses_locks(threadsafe_session):
rlock_class = threading.RLock().__class__
assert isinstance(threadsafe_session._lock, rlock_class)
assert isinstance(threadsafe_session._temp_table_auto_cleaner.lock, rlock_class)
assert isinstance(threadsafe_session._conn._lock, rlock_class)


def test_concurrent_select_queries(threadsafe_session):
def run_select(session_, thread_id):
df = session_.sql(f"SELECT {thread_id} as A")
Expand Down Expand Up @@ -183,6 +203,7 @@ def test_action_ids_are_unique(threadsafe_session):
assert len(action_ids) == 10


@pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="Skip file IO tests in localfs")
@pytest.mark.parametrize("use_stream", [True, False])
def test_file_io(threadsafe_session, resources_path, threadsafe_temp_stage, use_stream):
stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}"
Expand Down Expand Up @@ -256,30 +277,28 @@ def put_and_get_file(upload_file_path, download_dir):
def test_concurrent_add_packages(threadsafe_session):
# this is a list of packages available in snowflake anaconda. If this
# test fails due to packages not being available, please update the list
existing_packages = threadsafe_session.get_packages()
package_list = {
"graphviz",
"cloudpickle",
"numpy",
"pandas",
"scipy",
"scikit-learn",
"matplotlib",
"pyyaml",
}

try:
with ThreadPoolExecutor(max_workers=10) as executor:
futures = [
for package in package_list:
executor.submit(threadsafe_session.add_packages, package)
for package in package_list
]

for future in as_completed(futures):
future.result()

assert threadsafe_session.get_packages() == {
package: package for package in package_list
}
final_packages = threadsafe_session.get_packages()
for package in package_list:
assert package in final_packages
finally:
threadsafe_session.clear_packages()
for package in package_list:
if package not in existing_packages:
threadsafe_session.remove_package(package)


def test_concurrent_remove_package(threadsafe_session):
Expand Down Expand Up @@ -317,6 +336,7 @@ def remove_package(session_, package_name):
@pytest.mark.skipif(not is_dateutil_available, reason="dateutil is not available")
def test_concurrent_add_import(threadsafe_session, resources_path):
test_files = TestFiles(resources_path)
existing_imports = set(threadsafe_session.get_imports())
import_files = [
test_files.test_udf_py_file,
os.path.relpath(test_files.test_udf_py_file),
Expand All @@ -336,7 +356,7 @@ def test_concurrent_add_import(threadsafe_session, resources_path):

assert set(threadsafe_session.get_imports()) == {
os.path.abspath(file) for file in import_files
}
}.union(existing_imports)
finally:
threadsafe_session.clear_imports()

Expand Down Expand Up @@ -556,6 +576,9 @@ def finish(self):
reason="session.sql is not supported in local testing mode",
run=False,
)
@pytest.mark.skipif(
IS_IN_STORED_PROC, reason="SNOW-609328: support caplog in SP regression test"
)
def test_auto_temp_table_cleaner(threadsafe_session, caplog):
threadsafe_session._temp_table_auto_cleaner.ref_count_map.clear()
original_auto_clean_up_temp_table_enabled = (
Expand Down Expand Up @@ -633,6 +656,9 @@ def change_config_value(session_):
)


@pytest.mark.skipif(
IS_IN_STORED_PROC, reason="Cannot create new session inside stored proc"
)
@pytest.mark.parametrize("is_enabled", [True, False])
def test_num_cursors_created(db_parameters, is_enabled, local_testing_mode):
if is_enabled and local_testing_mode:
Expand Down

0 comments on commit f73980e

Please sign in to comment.