From 37fe60335c17e0594831bb14632c7391fd9561f7 Mon Sep 17 00:00:00 2001
From: Afroz Alam
Date: Wed, 23 Oct 2024 13:45:22 -0700
Subject: [PATCH] SNOW-1418533 handle dropping temp objects in post actions (#2405)

---
 CHANGELOG.md                                  |   2 +-
 .../_internal/analyzer/snowflake_plan.py      |  42 ++++++-
 .../_internal/compiler/plan_compiler.py       |  55 ++++++++-
 .../snowpark/_internal/server_connection.py   |   9 +-
 tests/integ/test_large_query_breakdown.py     |  14 ++-
 tests/integ/test_multithreading.py            | 114 +++++++++++++++++-
 tests/unit/test_dataframe.py                  |   1 +
 7 files changed, 224 insertions(+), 13 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3d890ef532..7ce7fee72f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,7 +16,7 @@

 #### Improvements

-- Disables sql simplification when sort is performed after limit.
+- Disables sql simplification when sort is performed after limit.
   - Previously, `df.sort().limit()` and `df.limit().sort()` generates the same query with sort in front of limit. Now, `df.limit().sort()` will generate query that reads `df.limit().sort()`.
   - Improve performance of generated query for `df.limit().sort()`, because limit stops table scanning as soon as the number of records is satisfied.

diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
index f1d1caef0a..4298c79903 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -646,7 +646,12 @@ def large_local_relation_plan(
         source_plan: Optional[LogicalPlan],
         schema_query: Optional[str],
     ) -> SnowflakePlan:
-        temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
+        thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled
+        temp_table_name = (
+            f"temp_name_placeholder_{generate_random_alphanumeric()}"
+            if thread_safe_session_enabled
+            else random_name_for_temp_object(TempObjectType.TABLE)
+        )
         attributes = [
             Attribute(attr.name, attr.datatype, attr.nullable) for attr in output
         ]
@@ -670,7 +675,13 @@ def large_local_relation_plan(
         else:
             schema_query = schema_query or schema_value_statement(attributes)
         queries = [
-            Query(create_table_stmt, is_ddl_on_temp_object=True),
+            Query(
+                create_table_stmt,
+                is_ddl_on_temp_object=True,
+                temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE)
+                if thread_safe_session_enabled
+                else None,
+            ),
             BatchInsertQuery(insert_stmt, data),
             Query(select_stmt),
         ]
@@ -1184,6 +1195,7 @@ def read_file(
         metadata_project: Optional[List[str]] = None,
         metadata_schema: Optional[List[Attribute]] = None,
     ):
+        thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled
         format_type_options, copy_options = get_copy_into_table_options(options)
         format_type_options = self._merge_file_format_options(
             format_type_options, options
@@ -1214,7 +1226,9 @@ def read_file(
         queries: List[Query] = []
         post_queries: List[Query] = []
         format_name = self.session.get_fully_qualified_name_if_possible(
-            random_name_for_temp_object(TempObjectType.FILE_FORMAT)
+            f"temp_name_placeholder_{generate_random_alphanumeric()}"
+            if thread_safe_session_enabled
+            else random_name_for_temp_object(TempObjectType.FILE_FORMAT)
         )
         queries.append(
             Query(
@@ -1228,6 +1242,9 @@ def read_file(
                     is_generated=True,
                 ),
                 is_ddl_on_temp_object=True,
+                temp_obj_name_placeholder=(format_name, TempObjectType.FILE_FORMAT)
+                if thread_safe_session_enabled
+                else None,
             )
         )
         post_queries.append(
@@ -1285,7 +1302,9 @@ def read_file(
         )

         temp_table_name = self.session.get_fully_qualified_name_if_possible(
-            random_name_for_temp_object(TempObjectType.TABLE)
+            f"temp_name_placeholder_{generate_random_alphanumeric()}"
+            if thread_safe_session_enabled
+            else random_name_for_temp_object(TempObjectType.TABLE)
         )
         queries = [
             Query(
@@ -1298,6 +1317,9 @@ def read_file(
                     is_generated=True,
                 ),
                 is_ddl_on_temp_object=True,
+                temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE)
+                if thread_safe_session_enabled
+                else None,
             ),
             Query(
                 copy_into_table(
@@ -1618,6 +1640,7 @@ def __init__(
         *,
         query_id_place_holder: Optional[str] = None,
         is_ddl_on_temp_object: bool = False,
+        temp_obj_name_placeholder: Optional[Tuple[str, TempObjectType]] = None,
         params: Optional[Sequence[Any]] = None,
     ) -> None:
         self.sql = sql
@@ -1626,6 +1649,16 @@ def __init__(
             if query_id_place_holder
             else f"query_id_place_holder_{generate_random_alphanumeric()}"
         )
+        # This is to handle the case when a snowflake plan is created in the following way
+        # in a multi-threaded environment:
+        # 1. Create a temp object
+        # 2. Use the temp object in a query
+        # 3. Drop the temp object
+        # When step 3 in thread A is executed before step 2 in thread B, the query in thread B will fail with
+        # temp object not found. To handle this, we replace temp object names with placeholders in the query
+        # and track the temp object placeholder name and temp object type here. During query execution, we replace
+        # the placeholders with the actual temp object names for the given execution.
+        self.temp_obj_name_placeholder = temp_obj_name_placeholder
         self.is_ddl_on_temp_object = is_ddl_on_temp_object
         self.params = params or []
@@ -1644,6 +1677,7 @@ def __eq__(self, other: "Query") -> bool:
             self.sql == other.sql
             and self.query_id_place_holder == other.query_id_place_holder
             and self.is_ddl_on_temp_object == other.is_ddl_on_temp_object
+            and self.temp_obj_name_placeholder == other.temp_obj_name_placeholder
             and self.params == other.params
         )
diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py
index aa0f65a45b..6636badfa3 100644
--- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py
+++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py
@@ -29,6 +29,7 @@
     plot_plan_if_enabled,
 )
 from snowflake.snowpark._internal.telemetry import TelemetryField
+from snowflake.snowpark._internal.utils import random_name_for_temp_object
 from snowflake.snowpark.mock._connection import MockServerConnection
@@ -156,10 +157,60 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
                     plan_uuid=self._plan.uuid,
                     compilation_stage_summary=summary_value,
                 )
-            return queries
         else:
             final_plan = self._plan
-            return {
+            queries = {
                 PlanQueryType.QUERIES: final_plan.queries,
                 PlanQueryType.POST_ACTIONS: final_plan.post_actions,
             }
+
+        return self.replace_temp_obj_placeholders(queries)
+
+    def replace_temp_obj_placeholders(
+        self, queries: Dict[PlanQueryType, List[Query]]
+    ) -> Dict[PlanQueryType, List[Query]]:
+        """
+        When thread-safe session is enabled, we use temporary object name placeholders instead of a temporary name
+        when generating the snowflake plan. We replace the temporary object name placeholders with actual temporary object
+        names here. This is done to prevent the following scenario:
+
+        1. A dataframe is created and resolved in the main thread.
+        2. The resolved plan contains queries that create and drop temp objects.
+        3. If the plan with the same temp object names is executed by multiple threads, the temp object names will conflict.
+           One thread can drop the object before another thread has finished using it.
+
+        To prevent this, we generate queries with temp object name placeholders and replace them with the actual temp object names
+        here.
+        """
+        session = self._plan.session
+        if session._conn._thread_safe_session_enabled:
+            # This dictionary will store the mapping between placeholder name and actual temp object name.
+            placeholders = {}
+            # Final execution queries
+            execution_queries = {}
+            for query_type, query_list in queries.items():
+                execution_queries[query_type] = []
+                for query in query_list:
+                    # If the query contains a temp object name placeholder, we generate a random
+                    # name for the temp object and add it to the placeholders dictionary.
+                    if query.temp_obj_name_placeholder:
+                        (
+                            placeholder_name,
+                            temp_obj_type,
+                        ) = query.temp_obj_name_placeholder
+                        placeholders[placeholder_name] = random_name_for_temp_object(
+                            temp_obj_type
+                        )
+
+                    copied_query = copy.copy(query)
+                    for placeholder_name, target_temp_name in placeholders.items():
+                        # Copy the original query and replace all the placeholder names with the
+                        # actual temp object names.
+                        copied_query.sql = copied_query.sql.replace(
+                            placeholder_name, target_temp_name
+                        )
+
+                    execution_queries[query_type].append(copied_query)
+            return execution_queries
+
+        return queries
diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py
index ae7c635a88..fff270afa8 100644
--- a/src/snowflake/snowpark/_internal/server_connection.py
+++ b/src/snowflake/snowpark/_internal/server_connection.py
@@ -44,6 +44,7 @@
 from snowflake.snowpark._internal.analyzer.snowflake_plan import (
     BatchInsertQuery,
     PlanQueryType,
+    Query,
     SnowflakePlan,
 )
 from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
@@ -459,6 +460,7 @@ def run_query(
         params: Optional[Sequence[Any]] = None,
         num_statements: Optional[int] = None,
         ignore_results: bool = False,
+        async_post_actions: Optional[List[Query]] = None,
         **kwargs,
     ) -> Union[Dict[str, Any], AsyncJob]:
         try:
@@ -502,7 +504,7 @@ def run_query(
                 query,
                 async_job_plan.session,
                 data_type,
-                async_job_plan.post_actions,
+                async_post_actions,
                 log_on_exception,
                 case_sensitive=case_sensitive,
                 num_statements=num_statements,
@@ -633,6 +635,7 @@ def get_result_set(
             kwargs["_statement_params"] = statement_params
         try:
             main_queries = plan_queries[PlanQueryType.QUERIES]
+            post_actions = plan_queries[PlanQueryType.POST_ACTIONS]
             placeholders = {}
             is_batch_insert = False
             for q in main_queries:
@@ -666,6 +669,7 @@ def get_result_set(
                 num_statements=len(main_queries),
                 params=params,
                 ignore_results=ignore_results,
+                async_post_actions=post_actions,
                 **kwargs,
             )

@@ -695,6 +699,7 @@ def get_result_set(
                         case_sensitive=case_sensitive,
                         params=query.params,
                         ignore_results=ignore_results,
+                        async_post_actions=post_actions,
                         **kwargs,
                     )
                     placeholders[query.query_id_place_holder] = (
@@ -706,7 +711,7 @@ def get_result_set(
         finally:
             # delete created tmp object
             if block:
-                for action in plan_queries[PlanQueryType.POST_ACTIONS]:
+                for action in post_actions:
                     self.run_query(
                         action.sql,
                         is_ddl_on_temp_object=action.is_ddl_on_temp_object,
diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py
index 3c962dae01..1f906765b4 100644
--- a/tests/integ/test_large_query_breakdown.py
+++ b/tests/integ/test_large_query_breakdown.py
@@ -478,12 +478,20 @@ def test_optimization_skipped_with_no_active_db_or_schema(
     assert called_with_reason == f"no active {db_or_schema}"


-def test_async_job_with_large_query_breakdown(session, large_query_df):
+def test_async_job_with_large_query_breakdown(large_query_df):
     """Test large query breakdown gives same result for async and non-async jobs"""
-    with SqlCounter(query_count=2):
+    with SqlCounter(query_count=3):
+        # 1 for current transaction
+        # 1 for created temp table; main query submitted as multi-statement query
+        # 1 for post action
         job = large_query_df.collect(block=False)
         result = job.result()
-    assert result == large_query_df.collect()
+    with SqlCounter(query_count=4):
+        # 1 for current transaction
+        # 1 for created temp table
+        # 1 for main query
+        # 1 for post action
+        assert result == large_query_df.collect()
     assert len(large_query_df.queries["queries"]) == 2
     assert large_query_df.queries["queries"][0].startswith(
         "CREATE SCOPED TEMPORARY TABLE"
diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py
index 610510b821..00fb2e1666 100644
--- a/tests/integ/test_multithreading.py
+++ b/tests/integ/test_multithreading.py
@@ -6,6 +6,7 @@
 import hashlib
 import logging
 import os
+import re
 import tempfile
 import threading
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -18,7 +19,14 @@
     _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION,
     Session,
 )
-from snowflake.snowpark.types import IntegerType
+from snowflake.snowpark.types import (
+    DoubleType,
+    IntegerType,
+    LongType,
+    StringType,
+    StructField,
+    StructType,
+)
 from tests.integ.test_temp_table_cleanup import wait_for_drop_table_sql_done

 try:
@@ -656,6 +664,110 @@ def change_config_value(session_):
     )


+@pytest.mark.xfail(
+    "config.getoption('local_testing_mode', default=False)",
+    reason="local testing does not execute sql queries",
+    run=False,
+)
+def test_temp_name_placeholder_for_sync(threadsafe_session):
+    from snowflake.snowpark._internal.analyzer import analyzer
+
+    original_value = analyzer.ARRAY_BIND_THRESHOLD
+
+    def process_data(df_, thread_id):
+        df_cleaned = df_.filter(df.A == thread_id)
+        return df_cleaned.collect()
+
+    try:
+        analyzer.ARRAY_BIND_THRESHOLD = 4
+        df = threadsafe_session.create_dataframe([[1, 2], [3, 4]], ["A", "B"])
+
+        with threadsafe_session.query_history() as history:
+            with ThreadPoolExecutor(max_workers=5) as executor:
+                for i in range(10):
+                    executor.submit(process_data, df, i)
+
+        queries_sent = [query.sql_text for query in history.queries]
+        unique_create_table_queries = set()
+        unique_drop_table_queries = set()
+        for query in queries_sent:
+            assert "temp_name_placeholder" not in query
+            if query.startswith("CREATE OR REPLACE"):
+                match = re.search(r"SNOWPARK_TEMP_TABLE_[\w]+", query)
+                assert match is not None, query
+                table_name = match.group()
+                unique_create_table_queries.add(table_name)
+            elif query.startswith("DROP TABLE"):
+                match = re.search(r"SNOWPARK_TEMP_TABLE_[\w]+", query)
+                assert match is not None, query
+                table_name = match.group()
+                unique_drop_table_queries.add(table_name)
+        assert len(unique_create_table_queries) == 10, queries_sent
+        assert len(unique_drop_table_queries) == 10, queries_sent
+
+    finally:
+        analyzer.ARRAY_BIND_THRESHOLD = original_value
+
+
+@pytest.mark.xfail(
+    "config.getoption('local_testing_mode', default=False)",
+    reason="local testing does not execute sql queries",
+    run=False,
+)
+@pytest.mark.skipif(IS_IN_STORED_PROC_LOCALFS, reason="Skip file IO tests in localfs")
+def test_temp_name_placeholder_for_async(
+    threadsafe_session, resources_path, threadsafe_temp_stage
+):
+    stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}"
f"prefix_{Utils.random_alphanumeric_str(10)}" + stage_with_prefix = f"@{threadsafe_temp_stage}/{stage_prefix}/" + test_files = TestFiles(resources_path) + threadsafe_session.file.put( + test_files.test_file_csv, stage_with_prefix, auto_compress=False + ) + filename = os.path.basename(test_files.test_file_csv) + + def process_data(df_, thread_id): + df_cleaned = df_.filter(df.A == thread_id) + job = df_cleaned.collect(block=False) + job.result() + + df = threadsafe_session.read.schema( + StructType( + [ + StructField("A", LongType()), + StructField("B", StringType()), + StructField("C", DoubleType()), + ] + ) + ).csv(f"{stage_with_prefix}/{filename}") + + with threadsafe_session.query_history() as history: + with ThreadPoolExecutor(max_workers=5) as executor: + for i in range(10): + executor.submit(process_data, df, i) + + queries_sent = [query.sql_text for query in history.queries] + + unique_create_file_format_queries = set() + unique_drop_file_format_queries = set() + for query in queries_sent: + assert "temp_name_placeholder" not in query + if query.startswith(" CREATE SCOPED TEMPORARY FILE FORMAT"): + match = re.search(r"SNOWPARK_TEMP_FILE_FORMAT_[\w]+", query) + assert match is not None, query + file_format_name = match.group() + unique_create_file_format_queries.add(file_format_name) + else: + assert query.startswith("DROP FILE FORMAT") + match = re.search(r"SNOWPARK_TEMP_FILE_FORMAT_[\w]+", query) + assert match is not None, query + file_format_name = match.group() + unique_drop_file_format_queries.add(file_format_name) + + assert len(unique_create_file_format_queries) == 10 + assert len(unique_drop_file_format_queries) == 10 + + @pytest.mark.skipif( IS_IN_STORED_PROC, reason="Cannot create new session inside stored proc" ) diff --git a/tests/unit/test_dataframe.py b/tests/unit/test_dataframe.py index e9a77564e8..0aaaff8824 100644 --- a/tests/unit/test_dataframe.py +++ b/tests/unit/test_dataframe.py @@ -118,6 +118,7 @@ def nop(name): fake_session._cte_optimization_enabled = False fake_session._query_compilation_stage_enabled = False fake_session._conn = mock.create_autospec(ServerConnection) + fake_session._conn._thread_safe_session_enabled = False fake_session._plan_builder = SnowflakePlanBuilder(fake_session) fake_session._analyzer = Analyzer(fake_session) fake_session._use_scoped_temp_objects = True