Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SNOW-1731783] Refactor node query comparison for Repeated subquery elimination #2437

Merged
merged 9 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 80 additions & 29 deletions src/snowflake/snowpark/_internal/analyzer/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import hashlib
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Tuple, Union
from typing import TYPE_CHECKING, Optional, Set, Union

from snowflake.snowpark._internal.analyzer.analyzer_utils import (
SPACE,
Expand All @@ -24,11 +24,9 @@
TreeNode = Union[SnowflakePlan, Selectable]


def find_duplicate_subtrees(
root: "TreeNode",
) -> Tuple[Set["TreeNode"], Dict["TreeNode", Set["TreeNode"]]]:
def find_duplicate_subtrees(root: "TreeNode") -> Set[str]:
"""
Returns a set containing all duplicate subtrees in query plan tree.
Returns a set of TreeNode encoded_id that indicates all duplicate subtrees in query plan tree.
The root of a duplicate subtree is defined as a duplicate node, if
- it appears more than once in the tree, AND
- one of its parent is unique (only appear once) in the tree, OR
Expand All @@ -49,8 +47,8 @@ def find_duplicate_subtrees(

This function is used to only include nodes that should be converted to CTEs.
"""
node_count_map = defaultdict(int)
node_parents_map = defaultdict(set)
id_count_map = defaultdict(int)
id_parents_map = defaultdict(set)

def traverse(root: "TreeNode") -> None:
"""
Expand All @@ -60,32 +58,39 @@ def traverse(root: "TreeNode") -> None:
while len(current_level) > 0:
next_level = []
for node in current_level:
node_count_map[node] += 1
id_count_map[node.encoded_node_id_with_query] += 1
for child in node.children_plan_nodes:
node_parents_map[child].add(node)
id_parents_map[child.encoded_node_id_with_query].add(
node.encoded_node_id_with_query
)
next_level.append(child)
current_level = next_level

def is_duplicate_subtree(node: "TreeNode") -> bool:
is_duplicate_node = node_count_map[node] > 1
def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:
is_duplicate_node = id_count_map[encoded_node_id_with_query] > 1
if is_duplicate_node:
is_any_parent_unique_node = any(
node_count_map[n] == 1 for n in node_parents_map[node]
id_count_map[id] == 1
for id in id_parents_map[encoded_node_id_with_query]
)
if is_any_parent_unique_node:
return True
else:
has_multi_parents = len(node_parents_map[node]) > 1
has_multi_parents = len(id_parents_map[encoded_node_id_with_query]) > 1
if has_multi_parents:
return True
return False

traverse(root)
duplicated_node = {node for node in node_count_map if is_duplicate_subtree(node)}
return duplicated_node, node_parents_map
duplicated_node = {
encoded_node_id_with_query
for encoded_node_id_with_query in id_count_map
if is_duplicate_subtree(encoded_node_id_with_query)
}
return duplicated_node


def create_cte_query(root: "TreeNode", duplicate_plan_set: Set["TreeNode"]) -> str:
def create_cte_query(root: "TreeNode", duplicated_node_ids: Set[str]) -> str:
from snowflake.snowpark._internal.analyzer.select_statement import Selectable

plan_to_query_map = {}
Expand All @@ -110,32 +115,41 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None:

while stack2:
node = stack2.pop()
if node in plan_to_query_map:
if node.encoded_node_id_with_query in plan_to_query_map:
continue

if not node.children_plan_nodes or not node.placeholder_query:
plan_to_query_map[node] = (
plan_to_query_map[node.encoded_node_id_with_query] = (
node.sql_query
if isinstance(node, Selectable)
else node.queries[-1].sql
)
else:
plan_to_query_map[node] = node.placeholder_query
plan_to_query_map[
node.encoded_node_id_with_query
] = node.placeholder_query
for child in node.children_plan_nodes:
# replace the placeholder (id) with child query
plan_to_query_map[node] = plan_to_query_map[node].replace(
child._id, plan_to_query_map[child]
plan_to_query_map[
node.encoded_node_id_with_query
] = plan_to_query_map[node.encoded_node_id_with_query].replace(
child.encoded_query_id,
plan_to_query_map[child.encoded_node_id_with_query],
)

# duplicate subtrees will be converted CTEs
if node in duplicate_plan_set:
if node.encoded_node_id_with_query in duplicated_node_ids:
# when a subquery is converted a CTE to with clause,
# it will be replaced by `SELECT * from TEMP_TABLE` in the original query
table_name = random_name_for_temp_object(TempObjectType.CTE)
select_stmt = project_statement([], table_name)
duplicate_plan_to_table_name_map[node] = table_name
duplicate_plan_to_cte_map[node] = plan_to_query_map[node]
plan_to_query_map[node] = select_stmt
duplicate_plan_to_table_name_map[
node.encoded_node_id_with_query
] = table_name
duplicate_plan_to_cte_map[
node.encoded_node_id_with_query
] = plan_to_query_map[node.encoded_node_id_with_query]
plan_to_query_map[node.encoded_node_id_with_query] = select_stmt

build_plan_to_query_map_in_post_order(root)

Expand All @@ -144,16 +158,53 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None:
list(duplicate_plan_to_cte_map.values()),
list(duplicate_plan_to_table_name_map.values()),
)
final_query = with_stmt + SPACE + plan_to_query_map[root]
final_query = with_stmt + SPACE + plan_to_query_map[root.encoded_node_id_with_query]
return final_query


def encode_id(
query: str, query_params: Optional[Sequence[Any]] = None
) -> Optional[str]:
def encoded_query_id(node) -> Optional[str]:
"""
Encode the query and its query parameter into an id using sha256.


Returns:
If encode succeed, return the first 10 encoded value.
Otherwise, return None
"""
from snowflake.snowpark._internal.analyzer.select_statement import SelectSQL
from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan

if isinstance(node, SnowflakePlan):
query = node.queries[-1].sql
query_params = node.queries[-1].params
elif isinstance(node, SelectSQL):
# For SelectSql, The original SQL is used to encode its ID,
# which might be a non-select SQL.
query = node.original_sql
query_params = node.query_params
else:
query = node.sql_query
query_params = node.query_params

string = f"{query}#{query_params}" if query_params else query
try:
return hashlib.sha256(string.encode()).hexdigest()[:10]
except Exception as ex:
logging.warning(f"Encode SnowflakePlan ID failed: {ex}")
return None


def encode_node_id_with_query(node: "TreeNode") -> str:
"""
Encode a for the given TreeNode.

If query and query parameters can be encoded successfully using sha256,
return the encoded query id + node_type_name.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need node_type_name?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is following the previous equivalence check, we require the node type to be the same, for example, a snowflake plan node and selectstatement node with the same query is not counted as the same

Otherwise, return the original node id.
"""
query_id = encoded_query_id(node)
if query_id is not None:
node_type_name = type(node).__name__
return f"{query_id}_{node_type_name}"
else:
return str(id(node))
57 changes: 26 additions & 31 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
)

import snowflake.snowpark._internal.utils
from snowflake.snowpark._internal.analyzer.cte_utils import encode_id
from snowflake.snowpark._internal.analyzer.cte_utils import (
encode_node_id_with_query,
encoded_query_id,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
PlanState,
Expand Down Expand Up @@ -239,17 +242,6 @@ def __init__(
self._api_calls = api_calls.copy() if api_calls is not None else None
self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None

def __eq__(self, other: "Selectable") -> bool:
if not isinstance(other, Selectable):
return False
if self._id is not None and other._id is not None:
return type(self) is type(other) and self._id == other._id
else:
return super().__eq__(other)

def __hash__(self) -> int:
return hash(self._id) if self._id else super().__hash__()

@property
@abstractmethod
def sql_query(self) -> str:
Expand All @@ -263,9 +255,20 @@ def placeholder_query(self) -> Optional[str]:
pass

@cached_property
def _id(self) -> Optional[str]:
"""Returns the id of this Selectable logical plan."""
return encode_id(self.sql_query, self.query_params)
def encoded_node_id_with_query(self) -> str:
"""
Returns an encoded node id of this Selectable logical plan.

Note that the encoding algorithm uses queries as content, and returns the same id for
two selectable node with same queries. This is currently used by repeated subquery
elimination to detect two nodes with same query, please use it with careful.
"""
return encode_node_id_with_query(self)

@cached_property
def encoded_query_id(self) -> Optional[str]:
"""Returns an encoded id of the queries for this Selectable logical plan."""
return encoded_query_id(self)

@property
@abstractmethod
Expand Down Expand Up @@ -506,14 +509,6 @@ def sql_query(self) -> str:
def placeholder_query(self) -> Optional[str]:
return None

@property
def _id(self) -> Optional[str]:
"""
Returns the id of this SelectSQL logical plan. The original SQL is used to encode its ID,
which might be a non-select SQL.
"""
return encode_id(self.original_sql, self.query_params)

@property
def query_params(self) -> Optional[Sequence[Any]]:
return self._query_param
Expand Down Expand Up @@ -591,9 +586,9 @@ def sql_query(self) -> str:
def placeholder_query(self) -> Optional[str]:
return self._snowflake_plan.placeholder_query

@property
def _id(self) -> Optional[str]:
return self._snowflake_plan._id
@cached_property
def encoded_query_id(self) -> Optional[str]:
return self._snowflake_plan.encoded_query_id

@property
def schema_query(self) -> Optional[str]:
Expand Down Expand Up @@ -793,9 +788,9 @@ def sql_query(self) -> str:
if (
self.analyzer.session._cte_optimization_enabled
and (not self.analyzer.session._query_compilation_stage_enabled)
and self.from_._id
and self.from_.encoded_query_id
):
placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_._id}{analyzer_utils.RIGHT_PARENTHESIS}"
placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_query_id}{analyzer_utils.RIGHT_PARENTHESIS}"
self._sql_query = self.placeholder_query.replace(placeholder, from_clause)
else:
where_clause = (
Expand Down Expand Up @@ -825,7 +820,7 @@ def sql_query(self) -> str:
def placeholder_query(self) -> str:
if self._placeholder_query:
return self._placeholder_query
from_clause = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_._id}{analyzer_utils.RIGHT_PARENTHESIS}"
from_clause = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_query_id}{analyzer_utils.RIGHT_PARENTHESIS}"
if not self.has_clause and not self.projection:
self._placeholder_query = from_clause
return self._placeholder_query
Expand Down Expand Up @@ -1429,9 +1424,9 @@ def sql_query(self) -> str:
@property
def placeholder_query(self) -> Optional[str]:
if not self._placeholder_query:
sql = f"({self.set_operands[0].selectable._id})"
sql = f"({self.set_operands[0].selectable.encoded_query_id})"
for i in range(1, len(self.set_operands)):
sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable._id})"
sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable.encoded_query_id})"
self._placeholder_query = sql
return self._placeholder_query

Expand Down
39 changes: 17 additions & 22 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@
)
from snowflake.snowpark._internal.analyzer.cte_utils import (
create_cte_query,
encode_id,
encode_node_id_with_query,
encoded_query_id,
find_duplicate_subtrees,
)
from snowflake.snowpark._internal.analyzer.expression import Attribute
Expand Down Expand Up @@ -256,9 +257,13 @@ def __init__(
# It is used for optimization, by replacing a subquery with a CTE
self.placeholder_query = placeholder_query
# encode an id for CTE optimization. This is generated based on the main
# query and the associated query parameters. We use this id for equality comparison
# to determine if two plans are the same.
self._id = encode_id(queries[-1].sql, queries[-1].params)
# query, query parameters and the node type. We use this id for equality
# comparison to determine if two plans are the same.
self.encoded_node_id_with_query = encode_node_id_with_query(self)
# encode id for the main query and query parameters, this is currently only used
# by the create_cte_query process.
# TODO (SNOW-1541096) remove this filed along removing the old cte implementation
self.encoded_query_id = encoded_query_id(self)
self.referenced_ctes: Set[WithQueryBlock] = (
referenced_ctes.copy() if referenced_ctes else set()
)
Expand All @@ -267,17 +272,6 @@ def __init__(
# to UUID track queries that are generated from the same plan.
self._uuid = str(uuid.uuid4())

def __eq__(self, other: "SnowflakePlan") -> bool:
if not isinstance(other, SnowflakePlan):
return False
if self._id is not None and other._id is not None:
return isinstance(other, SnowflakePlan) and self._id == other._id
else:
return super().__eq__(other)

def __hash__(self) -> int:
return hash(self._id) if self._id else super().__hash__()

@property
def uuid(self) -> str:
return self._uuid
Expand Down Expand Up @@ -349,7 +343,7 @@ def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan":
return self

# if there is no duplicate node, no optimization will be performed
duplicate_plan_set, _ = find_duplicate_subtrees(self)
duplicate_plan_set = find_duplicate_subtrees(self)
if not duplicate_plan_set:
return self

Expand Down Expand Up @@ -418,7 +412,7 @@ def output_dict(self) -> Dict[str, Any]:

@cached_property
def num_duplicate_nodes(self) -> int:
duplicated_nodes, _ = find_duplicate_subtrees(self)
duplicated_nodes = find_duplicate_subtrees(self)
return len(duplicated_nodes)

@cached_property
Expand Down Expand Up @@ -590,8 +584,9 @@ def build(
new_schema_query = schema_query or sql_generator(child.schema_query)

placeholder_query = (
sql_generator(select_child._id)
if self.session._cte_optimization_enabled and select_child._id is not None
sql_generator(select_child.encoded_query_id)
if self.session._cte_optimization_enabled
and select_child.encoded_query_id is not None
else None
)

Expand Down Expand Up @@ -629,10 +624,10 @@ def build_binary(
schema_query = sql_generator(left_schema_query, right_schema_query)

placeholder_query = (
sql_generator(select_left._id, select_right._id)
sql_generator(select_left.encoded_query_id, select_right.encoded_query_id)
if self.session._cte_optimization_enabled
and select_left._id is not None
and select_right._id is not None
and select_left.encoded_query_id is not None
and select_right.encoded_query_id is not None
else None
)

Expand Down
Loading
Loading