Skip to content

[SNOW-1844465] Avoid creating a CTE out of simple select start on top of a select entity #2713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 46 additions & 21 deletions src/snowflake/snowpark/_internal/compiler/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,41 @@ def find_duplicate_subtrees(

This function is used to only include nodes that should be converted to CTEs.
"""
id_count_map = defaultdict(int)
id_node_map = defaultdict(list)
id_parents_map = defaultdict(set)
id_complexity_map = defaultdict(int)

from snowflake.snowpark._internal.analyzer.select_statement import (
Selectable,
SelectStatement,
SelectableEntity,
SelectSnowflakePlan,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan

def is_simple_select_entity(node: "TreeNode") -> bool:
"""
Check if the current node is a simple select on top of a SelectEntity, for example:
select * from TABLE. This check only works with selectable when sql simplifier is enabled.
"""
if isinstance(node, SelectableEntity):
return True
if (
isinstance(node, SelectStatement)
and (node.projection is None)
and isinstance(node.from_, SelectableEntity)
):
return True
if (
isinstance(node, SnowflakePlan)
and (node.source_plan is not None)
and isinstance(node.source_plan, (SnowflakePlan, Selectable))
):
return is_simple_select_entity(node.source_plan)

if isinstance(node, SelectSnowflakePlan):
return is_simple_select_entity(node.snowflake_plan)

return False

def traverse(root: "TreeNode") -> None:
"""
Expand All @@ -57,15 +89,7 @@ def traverse(root: "TreeNode") -> None:
while len(current_level) > 0:
next_level = []
for node in current_level:
id_count_map[node.encoded_node_id_with_query] += 1
if propagate_complexity_hist and (
node.encoded_node_id_with_query not in id_complexity_map
):
# if propagate_complexity_hist is true, and the complexity score is not
# recorded for the current node id, record the complexity
id_complexity_map[
node.encoded_node_id_with_query
] = get_complexity_score(node)
id_node_map[node.encoded_node_id_with_query].append(node)
for child in node.children_plan_nodes:
id_parents_map[child.encoded_node_id_with_query].add(
node.encoded_node_id_with_query
Expand All @@ -77,13 +101,15 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:
# when a sql query is a select statement, its encoded_node_id_with_query
# contains _, which is used to separate the query id and node type name.
is_valid_candidate = "_" in encoded_node_id_with_query
if not is_valid_candidate:
if not is_valid_candidate or is_simple_select_entity(
id_node_map[encoded_node_id_with_query][0]
):
return False

is_duplicate_node = id_count_map[encoded_node_id_with_query] > 1
is_duplicate_node = len(id_node_map[encoded_node_id_with_query]) > 1
if is_duplicate_node:
is_any_parent_unique_node = any(
id_count_map[id] == 1
len(id_node_map[id]) == 1
for id in id_parents_map[encoded_node_id_with_query]
)
if is_any_parent_unique_node:
Expand All @@ -97,15 +123,15 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:
traverse(root)
duplicated_node_ids = {
encoded_node_id_with_query
for encoded_node_id_with_query in id_count_map
for encoded_node_id_with_query in id_node_map
if is_duplicate_subtree(encoded_node_id_with_query)
}

if propagate_complexity_hist:
return (
duplicated_node_ids,
get_duplicated_node_complexity_distribution(
duplicated_node_ids, id_complexity_map, id_count_map
duplicated_node_ids, id_node_map
),
)
else:
Expand All @@ -114,8 +140,7 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:

def get_duplicated_node_complexity_distribution(
duplicated_node_id_set: Set[str],
id_complexity_map: Dict[str, int],
id_count_map: Dict[str, int],
id_node_map: Dict[str, List["TreeNode"]],
) -> List[int]:
"""
Calculate the complexity distribution for the detected repeated node. The complexity are categorized as following:
Expand All @@ -131,8 +156,8 @@ def get_duplicated_node_complexity_distribution(
"""
node_complexity_dist = [0] * 7
for node_id in duplicated_node_id_set:
complexity_score = id_complexity_map[node_id]
repeated_count = id_count_map[node_id]
complexity_score = get_complexity_score(id_node_map[node_id][0])
repeated_count = len(id_node_map[node_id])
if complexity_score <= 10000:
node_complexity_dist[0] += repeated_count
elif 10000 < complexity_score <= 100000:
Expand All @@ -151,7 +176,7 @@ def get_duplicated_node_complexity_distribution(
return node_complexity_dist


def encode_query_id(node) -> Optional[str]:
def encode_query_id(node: "TreeNode") -> Optional[str]:
"""
Encode the query and its query parameter into an id using sha256.

Expand Down
24 changes: 22 additions & 2 deletions tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,13 +700,14 @@ def test_table(session):
check_result(
session,
df_result,
expect_cte_optimized=True,
expect_cte_optimized=False if session.sql_simplifier_enabled else True,
query_count=1,
describe_count=0,
union_count=1,
join_count=0,
)
assert count_number_of_ctes(df_result.queries["queries"][-1]) == 1
if not session.sql_simplifier_enabled:
assert count_number_of_ctes(df_result.queries["queries"][-1]) == 1


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1005,6 +1006,25 @@ def test_time_series_aggregation_grouping(session, enable_sql_simplifier):
session.sql_simplifier_enabled = original_sql_simplifier_enabled


def test_table_select_cte(session):
table_name = random_name_for_temp_object(TempObjectType.TABLE)
df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
df.write.save_as_table(table_name, table_type="temp")
df = session.table(table_name)
df_result = df.with_column("add_one", col("a") + 1).union(
df.with_column("add_two", col("a") + 2)
)
check_result(
session,
df_result,
expect_cte_optimized=False if session.sql_simplifier_enabled else True,
query_count=1,
describe_count=0,
union_count=1,
join_count=0,
)


@pytest.mark.skipif(
IS_IN_STORED_PROC, reason="SNOW-609328: support caplog in SP regression test"
)
Expand Down
Loading