
SNOW-1418533 handle dropping temp objects in post actions (#2405)
sfc-gh-aalam authored Oct 23, 2024
1 parent 9dd7a23 commit 37fe603
Showing 7 changed files with 224 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -16,7 +16,7 @@

#### Improvements

- Disables sql simplification when sort is performed after limit.
- Previously, `df.sort().limit()` and `df.limit().sort()` generated the same query, with sort in front of limit. Now, `df.limit().sort()` generates a query that keeps limit in front of sort.
- Improves performance of the generated query for `df.limit().sort()`, because limit stops table scanning as soon as the requested number of records is satisfied.

42 changes: 38 additions & 4 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -646,7 +646,12 @@ def large_local_relation_plan(
source_plan: Optional[LogicalPlan],
schema_query: Optional[str],
) -> SnowflakePlan:
temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled
temp_table_name = (
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.TABLE)
)
attributes = [
Attribute(attr.name, attr.datatype, attr.nullable) for attr in output
]
@@ -670,7 +675,13 @@ def large_local_relation_plan(
else:
schema_query = schema_query or schema_value_statement(attributes)
queries = [
Query(create_table_stmt, is_ddl_on_temp_object=True),
Query(
create_table_stmt,
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE)
if thread_safe_session_enabled
else None,
),
BatchInsertQuery(insert_stmt, data),
Query(select_stmt),
]
@@ -1184,6 +1195,7 @@ def read_file(
metadata_project: Optional[List[str]] = None,
metadata_schema: Optional[List[Attribute]] = None,
):
thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled
format_type_options, copy_options = get_copy_into_table_options(options)
format_type_options = self._merge_file_format_options(
format_type_options, options
@@ -1214,7 +1226,9 @@
queries: List[Query] = []
post_queries: List[Query] = []
format_name = self.session.get_fully_qualified_name_if_possible(
random_name_for_temp_object(TempObjectType.FILE_FORMAT)
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.FILE_FORMAT)
)
queries.append(
Query(
@@ -1228,6 +1242,9 @@
is_generated=True,
),
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(format_name, TempObjectType.FILE_FORMAT)
if thread_safe_session_enabled
else None,
)
)
post_queries.append(
@@ -1285,7 +1302,9 @@
)

temp_table_name = self.session.get_fully_qualified_name_if_possible(
random_name_for_temp_object(TempObjectType.TABLE)
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.TABLE)
)
queries = [
Query(
@@ -1298,6 +1317,9 @@
is_generated=True,
),
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE)
if thread_safe_session_enabled
else None,
),
Query(
copy_into_table(
@@ -1618,6 +1640,7 @@ def __init__(
*,
query_id_place_holder: Optional[str] = None,
is_ddl_on_temp_object: bool = False,
temp_obj_name_placeholder: Optional[Tuple[str, TempObjectType]] = None,
params: Optional[Sequence[Any]] = None,
) -> None:
self.sql = sql
@@ -1626,6 +1649,16 @@
if query_id_place_holder
else f"query_id_place_holder_{generate_random_alphanumeric()}"
)
# This is to handle the case when a snowflake plan is created in the following way
# in a multi-threaded environment:
# 1. Create a temp object
# 2. Use the temp object in a query
# 3. Drop the temp object
# When step 3 in thread A is executed before step 2 in thread B, the query in thread B will fail with
# temp object not found. To handle this, we replace temp object names with placeholders in the query
# and track the temp object placeholder name and temp object type here. During query execution, we replace
# the placeholders with the actual temp object names for the given execution.
self.temp_obj_name_placeholder = temp_obj_name_placeholder
self.is_ddl_on_temp_object = is_ddl_on_temp_object
self.params = params or []

@@ -1644,6 +1677,7 @@ def __eq__(self, other: "Query") -> bool:
self.sql == other.sql
and self.query_id_place_holder == other.query_id_place_holder
and self.is_ddl_on_temp_object == other.is_ddl_on_temp_object
and self.temp_obj_name_placeholder == other.temp_obj_name_placeholder
and self.params == other.params
)

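Taken together, the snowflake_plan.py hunks above swap concrete temp object names for placeholders at plan-build time. A minimal sketch of that idea follows, assuming simplified stand-ins: the `Query` class below is not Snowpark's real class, and `generate_random_alphanumeric` is re-implemented here only so the snippet runs on its own.

    # Simplified stand-ins for the diff above; illustrative, not Snowpark's API.
    import random
    import string
    from typing import Optional, Tuple

    def generate_random_alphanumeric(k: int = 10) -> str:
        # Stand-in for Snowpark's helper of the same name.
        return "".join(random.choices(string.ascii_lowercase + string.digits, k=k))

    class Query:
        def __init__(
            self,
            sql: str,
            is_ddl_on_temp_object: bool = False,
            temp_obj_name_placeholder: Optional[Tuple[str, str]] = None,
        ) -> None:
            self.sql = sql
            self.is_ddl_on_temp_object = is_ddl_on_temp_object
            # (placeholder_name, temp_object_type), resolved to a fresh real
            # name for every execution of the compiled plan.
            self.temp_obj_name_placeholder = temp_obj_name_placeholder

    # The plan records the placeholder, never a concrete temp table name.
    placeholder = f"temp_name_placeholder_{generate_random_alphanumeric()}"
    create_stmt = Query(
        f"CREATE SCOPED TEMPORARY TABLE {placeholder} (id INT)",
        is_ddl_on_temp_object=True,
        temp_obj_name_placeholder=(placeholder, "TABLE"),
    )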
55 changes: 53 additions & 2 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
@@ -29,6 +29,7 @@
plot_plan_if_enabled,
)
from snowflake.snowpark._internal.telemetry import TelemetryField
from snowflake.snowpark._internal.utils import random_name_for_temp_object
from snowflake.snowpark.mock._connection import MockServerConnection


@@ -156,10 +157,60 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
plan_uuid=self._plan.uuid,
compilation_stage_summary=summary_value,
)
return queries
else:
final_plan = self._plan
return {
queries = {
PlanQueryType.QUERIES: final_plan.queries,
PlanQueryType.POST_ACTIONS: final_plan.post_actions,
}

return self.replace_temp_obj_placeholders(queries)

def replace_temp_obj_placeholders(
self, queries: Dict[PlanQueryType, List[Query]]
) -> Dict[PlanQueryType, List[Query]]:
"""
When thread-safe session is enabled, we use temporary object name placeholders instead of a temporary name
when generating snowflake plan. We replace the temporary object name placeholders with actual temporary object
names here. This is done to prevent the following scenario:
1. A dataframe is created and resolved in main thread.
2. The resolve plan contains queries that create and drop temp objects.
3. If the plan with same temp object names is executed my multiple threads, the temp object names will conflict.
One thread can drop the object before another thread finished using it.
To prevent this, we generate queries with temp object name placeholders and replace them with actual temp object
here.
"""
session = self._plan.session
if session._conn._thread_safe_session_enabled:
# This dictionary will store the mapping between placeholder name and actual temp object name.
placeholders = {}
# Final execution queries
execution_queries = {}
for query_type, query_list in queries.items():
execution_queries[query_type] = []
for query in query_list:
# If the query contains a temp object name placeholder, we generate a random
# name for the temp object and add it to the placeholders dictionary.
if query.temp_obj_name_placeholder:
(
placeholder_name,
temp_obj_type,
) = query.temp_obj_name_placeholder
placeholders[placeholder_name] = random_name_for_temp_object(
temp_obj_type
)

# Copy the original query and replace all the placeholder names with the
# actual temp object names.
copied_query = copy.copy(query)
for placeholder_name, target_temp_name in placeholders.items():
copied_query.sql = copied_query.sql.replace(
placeholder_name, target_temp_name
)

execution_queries[query_type].append(copied_query)
return execution_queries

return queries
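For readers following the control flow, here is a hedged sketch of the substitution pass that `replace_temp_obj_placeholders` performs, reusing the `Query` stand-in and `generate_random_alphanumeric` from the earlier sketch; the generated name format is illustrative, not the exact output of Snowpark's `random_name_for_temp_object`.

    import copy
    from typing import Dict, List

    def replace_placeholders(queries: List[Query]) -> List[Query]:
        # Maps each placeholder to the real name chosen for this execution.
        placeholders: Dict[str, str] = {}
        resolved: List[Query] = []
        for query in queries:
            if query.temp_obj_name_placeholder:
                name, obj_type = query.temp_obj_name_placeholder
                placeholders[name] = (
                    f"SNOWPARK_TEMP_{obj_type}_{generate_random_alphanumeric().upper()}"
                )
            # Never mutate the cached plan; substitute on a shallow copy.
            copied = copy.copy(query)
            for name, real_name in placeholders.items():
                copied.sql = copied.sql.replace(name, real_name)
            resolved.append(copied)
        return resolved

Because the `placeholders` dictionary is rebuilt on every call, each execution of the same compiled plan gets its own set of temp object names.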
9 changes: 7 additions & 2 deletions src/snowflake/snowpark/_internal/server_connection.py
@@ -44,6 +44,7 @@
from snowflake.snowpark._internal.analyzer.snowflake_plan import (
BatchInsertQuery,
PlanQueryType,
Query,
SnowflakePlan,
)
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
@@ -459,6 +460,7 @@ def run_query(
params: Optional[Sequence[Any]] = None,
num_statements: Optional[int] = None,
ignore_results: bool = False,
async_post_actions: Optional[List[Query]] = None,
**kwargs,
) -> Union[Dict[str, Any], AsyncJob]:
try:
@@ -502,7 +504,7 @@
query,
async_job_plan.session,
data_type,
async_job_plan.post_actions,
async_post_actions,
log_on_exception,
case_sensitive=case_sensitive,
num_statements=num_statements,
@@ -633,6 +635,7 @@ def get_result_set(
kwargs["_statement_params"] = statement_params
try:
main_queries = plan_queries[PlanQueryType.QUERIES]
post_actions = plan_queries[PlanQueryType.POST_ACTIONS]
placeholders = {}
is_batch_insert = False
for q in main_queries:
@@ -666,6 +669,7 @@
num_statements=len(main_queries),
params=params,
ignore_results=ignore_results,
async_post_actions=post_actions,
**kwargs,
)

@@ -695,6 +699,7 @@
case_sensitive=case_sensitive,
params=query.params,
ignore_results=ignore_results,
async_post_actions=post_actions,
**kwargs,
)
placeholders[query.query_id_place_holder] = (
@@ -706,7 +711,7 @@
finally:
# delete created tmp object
if block:
for action in plan_queries[PlanQueryType.POST_ACTIONS]:
for action in post_actions:
self.run_query(
action.sql,
is_ddl_on_temp_object=action.is_ddl_on_temp_object,
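To see why per-execution names matter (and why the post actions handed to `run_query` must be the resolved ones), here is a small driver built on the two sketches above; with a single shared name, thread A's DROP post action could land between thread B's CREATE and SELECT. All names here are the hypothetical stand-ins from the earlier snippets.

    import threading

    plan = [
        create_stmt,                                  # CREATE ... <placeholder>
        Query(f"SELECT * FROM {placeholder}"),        # main query
        Query(f"DROP TABLE IF EXISTS {placeholder}",  # post action
              is_ddl_on_temp_object=True),
    ]

    def run(name: str) -> None:
        # Each thread resolves the shared plan to its own temp table name,
        # so its DROP can only ever remove its own table.
        for q in replace_placeholders(plan):
            print(f"{name}: {q.sql}")

    threads = [threading.Thread(target=run, args=(f"thread-{i}",)) for i in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()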
14 changes: 11 additions & 3 deletions tests/integ/test_large_query_breakdown.py
@@ -478,12 +478,20 @@ def test_optimization_skipped_with_no_active_db_or_schema(
assert called_with_reason == f"no active {db_or_schema}"


def test_async_job_with_large_query_breakdown(session, large_query_df):
def test_async_job_with_large_query_breakdown(large_query_df):
"""Test large query breakdown gives same result for async and non-async jobs"""
with SqlCounter(query_count=2):
with SqlCounter(query_count=3):
# 1 for current transaction
# 1 for created temp table; main query submitted as multi-statement query
# 1 for post action
job = large_query_df.collect(block=False)
result = job.result()
assert result == large_query_df.collect()
with SqlCounter(query_count=4):
# 1 for current transaction
# 1 for created temp table
# 1 for main query
# 1 for post action
assert result == large_query_df.collect()
assert len(large_query_df.queries["queries"]) == 2
assert large_query_df.queries["queries"][0].startswith(
"CREATE SCOPED TEMPORARY TABLE"
