diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py index 12c5390e5b..eee666f155 100644 --- a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py @@ -5,7 +5,7 @@ import hashlib import logging from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Tuple, Union +from typing import TYPE_CHECKING, Optional, Set, Union from snowflake.snowpark._internal.analyzer.analyzer_utils import ( SPACE, @@ -24,11 +24,9 @@ TreeNode = Union[SnowflakePlan, Selectable] -def find_duplicate_subtrees( - root: "TreeNode", -) -> Tuple[Set["TreeNode"], Dict["TreeNode", Set["TreeNode"]]]: +def find_duplicate_subtrees(root: "TreeNode") -> Set[str]: """ - Returns a set containing all duplicate subtrees in query plan tree. + Returns a set of TreeNode encoded_id that indicates all duplicate subtrees in query plan tree. The root of a duplicate subtree is defined as a duplicate node, if - it appears more than once in the tree, AND - one of its parent is unique (only appear once) in the tree, OR @@ -49,8 +47,8 @@ def find_duplicate_subtrees( This function is used to only include nodes that should be converted to CTEs. """ - node_count_map = defaultdict(int) - node_parents_map = defaultdict(set) + id_count_map = defaultdict(int) + id_parents_map = defaultdict(set) def traverse(root: "TreeNode") -> None: """ @@ -60,32 +58,39 @@ def traverse(root: "TreeNode") -> None: while len(current_level) > 0: next_level = [] for node in current_level: - node_count_map[node] += 1 + id_count_map[node.encoded_node_id_with_query] += 1 for child in node.children_plan_nodes: - node_parents_map[child].add(node) + id_parents_map[child.encoded_node_id_with_query].add( + node.encoded_node_id_with_query + ) next_level.append(child) current_level = next_level - def is_duplicate_subtree(node: "TreeNode") -> bool: - is_duplicate_node = node_count_map[node] > 1 + def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: + is_duplicate_node = id_count_map[encoded_node_id_with_query] > 1 if is_duplicate_node: is_any_parent_unique_node = any( - node_count_map[n] == 1 for n in node_parents_map[node] + id_count_map[id] == 1 + for id in id_parents_map[encoded_node_id_with_query] ) if is_any_parent_unique_node: return True else: - has_multi_parents = len(node_parents_map[node]) > 1 + has_multi_parents = len(id_parents_map[encoded_node_id_with_query]) > 1 if has_multi_parents: return True return False traverse(root) - duplicated_node = {node for node in node_count_map if is_duplicate_subtree(node)} - return duplicated_node, node_parents_map + duplicated_node = { + encoded_node_id_with_query + for encoded_node_id_with_query in id_count_map + if is_duplicate_subtree(encoded_node_id_with_query) + } + return duplicated_node -def create_cte_query(root: "TreeNode", duplicate_plan_set: Set["TreeNode"]) -> str: +def create_cte_query(root: "TreeNode", duplicated_node_ids: Set[str]) -> str: from snowflake.snowpark._internal.analyzer.select_statement import Selectable plan_to_query_map = {} @@ -110,32 +115,41 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: while stack2: node = stack2.pop() - if node in plan_to_query_map: + if node.encoded_node_id_with_query in plan_to_query_map: continue if not node.children_plan_nodes or not node.placeholder_query: - plan_to_query_map[node] = ( + plan_to_query_map[node.encoded_node_id_with_query] = ( node.sql_query if isinstance(node, Selectable) else node.queries[-1].sql ) else: - plan_to_query_map[node] = node.placeholder_query + plan_to_query_map[ + node.encoded_node_id_with_query + ] = node.placeholder_query for child in node.children_plan_nodes: # replace the placeholder (id) with child query - plan_to_query_map[node] = plan_to_query_map[node].replace( - child._id, plan_to_query_map[child] + plan_to_query_map[ + node.encoded_node_id_with_query + ] = plan_to_query_map[node.encoded_node_id_with_query].replace( + child.encoded_query_id, + plan_to_query_map[child.encoded_node_id_with_query], ) # duplicate subtrees will be converted CTEs - if node in duplicate_plan_set: + if node.encoded_node_id_with_query in duplicated_node_ids: # when a subquery is converted a CTE to with clause, # it will be replaced by `SELECT * from TEMP_TABLE` in the original query table_name = random_name_for_temp_object(TempObjectType.CTE) select_stmt = project_statement([], table_name) - duplicate_plan_to_table_name_map[node] = table_name - duplicate_plan_to_cte_map[node] = plan_to_query_map[node] - plan_to_query_map[node] = select_stmt + duplicate_plan_to_table_name_map[ + node.encoded_node_id_with_query + ] = table_name + duplicate_plan_to_cte_map[ + node.encoded_node_id_with_query + ] = plan_to_query_map[node.encoded_node_id_with_query] + plan_to_query_map[node.encoded_node_id_with_query] = select_stmt build_plan_to_query_map_in_post_order(root) @@ -144,16 +158,53 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: list(duplicate_plan_to_cte_map.values()), list(duplicate_plan_to_table_name_map.values()), ) - final_query = with_stmt + SPACE + plan_to_query_map[root] + final_query = with_stmt + SPACE + plan_to_query_map[root.encoded_node_id_with_query] return final_query -def encode_id( - query: str, query_params: Optional[Sequence[Any]] = None -) -> Optional[str]: +def encoded_query_id(node) -> Optional[str]: + """ + Encode the query and its query parameter into an id using sha256. + + + Returns: + If encode succeed, return the first 10 encoded value. + Otherwise, return None + """ + from snowflake.snowpark._internal.analyzer.select_statement import SelectSQL + from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan + + if isinstance(node, SnowflakePlan): + query = node.queries[-1].sql + query_params = node.queries[-1].params + elif isinstance(node, SelectSQL): + # For SelectSql, The original SQL is used to encode its ID, + # which might be a non-select SQL. + query = node.original_sql + query_params = node.query_params + else: + query = node.sql_query + query_params = node.query_params + string = f"{query}#{query_params}" if query_params else query try: return hashlib.sha256(string.encode()).hexdigest()[:10] except Exception as ex: logging.warning(f"Encode SnowflakePlan ID failed: {ex}") return None + + +def encode_node_id_with_query(node: "TreeNode") -> str: + """ + Encode a for the given TreeNode. + + If query and query parameters can be encoded successfully using sha256, + return the encoded query id + node_type_name. + Otherwise, return the original node id. + """ + query_id = encoded_query_id(node) + if query_id is not None: + node_type_name = type(node).__name__ + return f"{query_id}_{node_type_name}" + else: + return str(id(node)) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 4d154c9ae0..e421c23218 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -22,7 +22,10 @@ ) import snowflake.snowpark._internal.utils -from snowflake.snowpark._internal.analyzer.cte_utils import encode_id +from snowflake.snowpark._internal.analyzer.cte_utils import ( + encode_node_id_with_query, + encoded_query_id, +) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, PlanState, @@ -239,17 +242,6 @@ def __init__( self._api_calls = api_calls.copy() if api_calls is not None else None self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None - def __eq__(self, other: "Selectable") -> bool: - if not isinstance(other, Selectable): - return False - if self._id is not None and other._id is not None: - return type(self) is type(other) and self._id == other._id - else: - return super().__eq__(other) - - def __hash__(self) -> int: - return hash(self._id) if self._id else super().__hash__() - @property @abstractmethod def sql_query(self) -> str: @@ -263,9 +255,20 @@ def placeholder_query(self) -> Optional[str]: pass @cached_property - def _id(self) -> Optional[str]: - """Returns the id of this Selectable logical plan.""" - return encode_id(self.sql_query, self.query_params) + def encoded_node_id_with_query(self) -> str: + """ + Returns an encoded node id of this Selectable logical plan. + + Note that the encoding algorithm uses queries as content, and returns the same id for + two selectable node with same queries. This is currently used by repeated subquery + elimination to detect two nodes with same query, please use it with careful. + """ + return encode_node_id_with_query(self) + + @cached_property + def encoded_query_id(self) -> Optional[str]: + """Returns an encoded id of the queries for this Selectable logical plan.""" + return encoded_query_id(self) @property @abstractmethod @@ -506,14 +509,6 @@ def sql_query(self) -> str: def placeholder_query(self) -> Optional[str]: return None - @property - def _id(self) -> Optional[str]: - """ - Returns the id of this SelectSQL logical plan. The original SQL is used to encode its ID, - which might be a non-select SQL. - """ - return encode_id(self.original_sql, self.query_params) - @property def query_params(self) -> Optional[Sequence[Any]]: return self._query_param @@ -591,9 +586,9 @@ def sql_query(self) -> str: def placeholder_query(self) -> Optional[str]: return self._snowflake_plan.placeholder_query - @property - def _id(self) -> Optional[str]: - return self._snowflake_plan._id + @cached_property + def encoded_query_id(self) -> Optional[str]: + return self._snowflake_plan.encoded_query_id @property def schema_query(self) -> Optional[str]: @@ -793,9 +788,9 @@ def sql_query(self) -> str: if ( self.analyzer.session._cte_optimization_enabled and (not self.analyzer.session._query_compilation_stage_enabled) - and self.from_._id + and self.from_.encoded_query_id ): - placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_._id}{analyzer_utils.RIGHT_PARENTHESIS}" + placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_query_id}{analyzer_utils.RIGHT_PARENTHESIS}" self._sql_query = self.placeholder_query.replace(placeholder, from_clause) else: where_clause = ( @@ -825,7 +820,7 @@ def sql_query(self) -> str: def placeholder_query(self) -> str: if self._placeholder_query: return self._placeholder_query - from_clause = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_._id}{analyzer_utils.RIGHT_PARENTHESIS}" + from_clause = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_query_id}{analyzer_utils.RIGHT_PARENTHESIS}" if not self.has_clause and not self.projection: self._placeholder_query = from_clause return self._placeholder_query @@ -1429,9 +1424,9 @@ def sql_query(self) -> str: @property def placeholder_query(self) -> Optional[str]: if not self._placeholder_query: - sql = f"({self.set_operands[0].selectable._id})" + sql = f"({self.set_operands[0].selectable.encoded_query_id})" for i in range(1, len(self.set_operands)): - sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable._id})" + sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable.encoded_query_id})" self._placeholder_query = sql return self._placeholder_query diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index c7729d991f..652ab71f74 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -85,7 +85,8 @@ ) from snowflake.snowpark._internal.analyzer.cte_utils import ( create_cte_query, - encode_id, + encode_node_id_with_query, + encoded_query_id, find_duplicate_subtrees, ) from snowflake.snowpark._internal.analyzer.expression import Attribute @@ -256,9 +257,13 @@ def __init__( # It is used for optimization, by replacing a subquery with a CTE self.placeholder_query = placeholder_query # encode an id for CTE optimization. This is generated based on the main - # query and the associated query parameters. We use this id for equality comparison - # to determine if two plans are the same. - self._id = encode_id(queries[-1].sql, queries[-1].params) + # query, query parameters and the node type. We use this id for equality + # comparison to determine if two plans are the same. + self.encoded_node_id_with_query = encode_node_id_with_query(self) + # encode id for the main query and query parameters, this is currently only used + # by the create_cte_query process. + # TODO (SNOW-1541096) remove this filed along removing the old cte implementation + self.encoded_query_id = encoded_query_id(self) self.referenced_ctes: Set[WithQueryBlock] = ( referenced_ctes.copy() if referenced_ctes else set() ) @@ -267,17 +272,6 @@ def __init__( # to UUID track queries that are generated from the same plan. self._uuid = str(uuid.uuid4()) - def __eq__(self, other: "SnowflakePlan") -> bool: - if not isinstance(other, SnowflakePlan): - return False - if self._id is not None and other._id is not None: - return isinstance(other, SnowflakePlan) and self._id == other._id - else: - return super().__eq__(other) - - def __hash__(self) -> int: - return hash(self._id) if self._id else super().__hash__() - @property def uuid(self) -> str: return self._uuid @@ -349,7 +343,7 @@ def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan": return self # if there is no duplicate node, no optimization will be performed - duplicate_plan_set, _ = find_duplicate_subtrees(self) + duplicate_plan_set = find_duplicate_subtrees(self) if not duplicate_plan_set: return self @@ -418,7 +412,7 @@ def output_dict(self) -> Dict[str, Any]: @cached_property def num_duplicate_nodes(self) -> int: - duplicated_nodes, _ = find_duplicate_subtrees(self) + duplicated_nodes = find_duplicate_subtrees(self) return len(duplicated_nodes) @cached_property @@ -590,8 +584,9 @@ def build( new_schema_query = schema_query or sql_generator(child.schema_query) placeholder_query = ( - sql_generator(select_child._id) - if self.session._cte_optimization_enabled and select_child._id is not None + sql_generator(select_child.encoded_query_id) + if self.session._cte_optimization_enabled + and select_child.encoded_query_id is not None else None ) @@ -629,10 +624,10 @@ def build_binary( schema_query = sql_generator(left_schema_query, right_schema_query) placeholder_query = ( - sql_generator(select_left._id, select_right._id) + sql_generator(select_left.encoded_query_id, select_right.encoded_query_id) if self.session._cte_optimization_enabled - and select_left._id is not None - and select_right._id is not None + and select_left.encoded_query_id is not None + and select_right.encoded_query_id is not None else None ) diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py index d979b8a93a..afb9626673 100644 --- a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -2,9 +2,11 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from collections import defaultdict from typing import Dict, List, Optional, Set from snowflake.snowpark._internal.analyzer.cte_utils import find_duplicate_subtrees +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( LogicalPlan, WithQueryBlock, @@ -87,10 +89,10 @@ def apply(self) -> RepeatedSubqueryEliminationResult: logical_plan = self._query_generator.resolve(logical_plan) # apply the CTE optimization on the resolved plan - duplicated_nodes, node_parents_map = find_duplicate_subtrees(logical_plan) - if len(duplicated_nodes) > 0: + duplicated_node_ids = find_duplicate_subtrees(logical_plan) + if len(duplicated_node_ids) > 0: deduplicated_plan = self._replace_duplicate_node_with_cte( - logical_plan, duplicated_nodes, node_parents_map + logical_plan, duplicated_node_ids ) final_logical_plans.append(deduplicated_plan) else: @@ -104,8 +106,7 @@ def apply(self) -> RepeatedSubqueryEliminationResult: def _replace_duplicate_node_with_cte( self, root: TreeNode, - duplicated_nodes: Set[TreeNode], - node_parents_map: Dict[TreeNode, Set[TreeNode]], + duplicated_node_ids: Set[str], ) -> LogicalPlan: """ Replace all duplicated nodes with a WithQueryBlock (CTE node), to enable @@ -117,17 +118,21 @@ def _replace_duplicate_node_with_cte( This function uses an iterative approach to avoid hitting Python's maximum recursion depth limit. """ + node_parents_map: Dict[TreeNode, Set[TreeNode]] = defaultdict(set) stack1, stack2 = [root], [] while stack1: node = stack1.pop() stack2.append(node) for child in reversed(node.children_plan_nodes): + node_parents_map[child].add(node) stack1.append(child) - # tack node that is already visited to avoid repeated operation on the same node + # track node that is already visited to avoid repeated operation on the same node visited_nodes: Set[TreeNode] = set() updated_nodes: Set[TreeNode] = set() + # track the resolved WithQueryBlock node has been created for each duplicated node + resolved_with_block_map: Dict[str, SnowflakePlan] = {} def _update_parents( node: TreeNode, @@ -151,18 +156,28 @@ def _update_parents( # if the node is a duplicated node and deduplication is not done for the node, # start the deduplication transformation use CTE - if node in duplicated_nodes: - # create a WithQueryBlock node - with_block = WithQueryBlock( - name=random_name_for_temp_object(TempObjectType.CTE), child=node - ) - with_block._is_valid_for_replacement = True - - resolved_with_block = self._query_generator.resolve(with_block) + if node.encoded_node_id_with_query in duplicated_node_ids: + if node.encoded_node_id_with_query in resolved_with_block_map: + # if the corresponding CTE block has been created, use the existing + # one. + resolved_with_block = resolved_with_block_map[ + node.encoded_node_id_with_query + ] + else: + # create a WithQueryBlock node + with_block = WithQueryBlock( + name=random_name_for_temp_object(TempObjectType.CTE), child=node + ) + with_block._is_valid_for_replacement = True + + resolved_with_block = self._query_generator.resolve(with_block) + resolved_with_block_map[ + node.encoded_node_id_with_query + ] = resolved_with_block + self._total_number_ctes += 1 _update_parents( node, should_replace_child=True, new_child=resolved_with_block ) - self._total_number_ctes += 1 elif node in updated_nodes: # if the node is updated, make sure all nodes up to parent is updated _update_parents(node, should_replace_child=False) diff --git a/tests/integ/test_deepcopy.py b/tests/integ/test_deepcopy.py index e0fa5fe74e..a07aa49950 100644 --- a/tests/integ/test_deepcopy.py +++ b/tests/integ/test_deepcopy.py @@ -401,19 +401,8 @@ def test_deepcopy_no_duplicate(session, generator): copied_plan = copy.deepcopy(final_df._plan) check_copied_plan(copied_plan, final_df._plan) - # we will traverse the plan to assert that the tuple (plan._id, type(plan)) have a unique id(plan) - # note that two nodes with same plan._id can have different id(plan) since SnowflakePlan inherits plan._id - # from its source selectable. - # plan._id: calculated by hashing the query sql and query params of plan - # type(plan): the type of the plan - # id(plan): the memory address of the plan object - # If the same plan._id has multiple id(plan), it means the deepcopy is duplicating source nodes - def traverse_plan(plan, plan_id_map): - plan_id = plan._id - plan_type = type(plan) plan_memo = id(plan) - identifier_tuple = (plan_id, plan_type) local_deepcopy_memo = {} first_deepcopy = copy.deepcopy(plan, local_deepcopy_memo) @@ -421,11 +410,6 @@ def traverse_plan(plan, plan_id_map): assert plan_memo in local_deepcopy_memo assert first_deepcopy is second_deepcopy - if identifier_tuple not in plan_id_map: - plan_id_map[identifier_tuple] = plan_memo - else: - assert plan_id_map[identifier_tuple] == plan_memo - for child in plan.children_plan_nodes: traverse_plan(child, plan_id_map) diff --git a/tests/unit/compiler/test_replace_child_and_update_node.py b/tests/unit/compiler/test_replace_child_and_update_node.py index 05098165a1..76c3a0fec4 100644 --- a/tests/unit/compiler/test_replace_child_and_update_node.py +++ b/tests/unit/compiler/test_replace_child_and_update_node.py @@ -441,7 +441,7 @@ def test_select_statement( assert_precondition(plan, new_plan, mock_query_generator, using_deep_copy=True) plan = copy.deepcopy(plan) - replace_child(plan, from_, new_plan, mock_query_generator) + replace_child(plan, plan.children_plan_nodes[0], new_plan, mock_query_generator) assert len(plan.children_plan_nodes) == 1 new_replaced_plan = plan.children_plan_nodes[0] @@ -566,11 +566,11 @@ def test_set_statement( assert_precondition(plan, new_plan, mock_analyzer, using_deep_copy=True) plan = copy.deepcopy(plan) - replace_child(plan, selectable1, new_plan, mock_query_generator) + replace_child(plan, plan.children_plan_nodes[0], new_plan, mock_query_generator) assert len(plan.children_plan_nodes) == 2 new_replaced_plan = plan.children_plan_nodes[0] - assert isinstance(new_replaced_plan, SelectSnowflakePlan) - assert plan.children_plan_nodes[1] == selectable2 + assert isinstance(plan.children_plan_nodes[0], SelectSnowflakePlan) + assert isinstance(plan.children_plan_nodes[1], SelectSQL) mocked_snowflake_plan = mock_snowflake_plan() verify_snowflake_plan(new_replaced_plan.snowflake_plan, mocked_snowflake_plan) diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index bab9d5680b..77024d1272 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -13,7 +13,7 @@ def test_case1(): nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] for i, node in enumerate(nodes): - node._id = i + node.encoded_node_id_with_query = i node.source_plan = None nodes[0].children_plan_nodes = [nodes[1], nodes[3]] nodes[1].children_plan_nodes = [nodes[2], nodes[2]] @@ -30,7 +30,7 @@ def test_case1(): def test_case2(): nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] for i, node in enumerate(nodes): - node._id = i + node.encoded_node_id_with_query = i node.source_plan = None nodes[0].children_plan_nodes = [nodes[1], nodes[3]] nodes[1].children_plan_nodes = [nodes[2], nodes[2]] @@ -47,5 +47,5 @@ def test_case2(): @pytest.mark.parametrize("test_case", [test_case1(), test_case2()]) def test_find_duplicate_subtrees(test_case): plan, expected_duplicate_subtree_ids = test_case - duplicate_subtrees, _ = find_duplicate_subtrees(plan) - assert {node._id for node in duplicate_subtrees} == expected_duplicate_subtree_ids + duplicate_subtrees_ids = find_duplicate_subtrees(plan) + assert duplicate_subtrees_ids == expected_duplicate_subtree_ids