From fe755448748e9218252611e75dd814b397997132 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Thu, 10 Oct 2024 17:44:13 -0700 Subject: [PATCH 1/9] refactor --- .../snowpark/_internal/analyzer/cte_utils.py | 50 +++++++++---------- .../_internal/analyzer/select_statement.py | 21 +++----- .../_internal/analyzer/snowflake_plan.py | 17 ++----- .../compiler/repeated_subquery_elimination.py | 45 ++++++++++------- tests/integ/test_cte.py | 8 ++- 5 files changed, 68 insertions(+), 73 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py index 12c5390e5b..aca304c581 100644 --- a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py @@ -4,6 +4,7 @@ import hashlib import logging +import uuid from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Tuple, Union @@ -24,9 +25,7 @@ TreeNode = Union[SnowflakePlan, Selectable] -def find_duplicate_subtrees( - root: "TreeNode", -) -> Tuple[Set["TreeNode"], Dict["TreeNode", Set["TreeNode"]]]: +def find_duplicate_subtrees(root: "TreeNode") -> Set[str]: """ Returns a set containing all duplicate subtrees in query plan tree. The root of a duplicate subtree is defined as a duplicate node, if @@ -49,8 +48,9 @@ def find_duplicate_subtrees( This function is used to only include nodes that should be converted to CTEs. 
""" - node_count_map = defaultdict(int) - node_parents_map = defaultdict(set) + id_count_map = defaultdict(int) + id_parents_map = defaultdict(set) + # node_parents_map = defaultdict(set) def traverse(root: "TreeNode") -> None: """ @@ -60,32 +60,32 @@ def traverse(root: "TreeNode") -> None: while len(current_level) > 0: next_level = [] for node in current_level: - node_count_map[node] += 1 + id_count_map[node.encoded_id] += 1 for child in node.children_plan_nodes: - node_parents_map[child].add(node) + id_parents_map[child.encoded_id].add(node) next_level.append(child) current_level = next_level - def is_duplicate_subtree(node: "TreeNode") -> bool: - is_duplicate_node = node_count_map[node] > 1 + def is_duplicate_subtree(node_id: str) -> bool: + is_duplicate_node = id_count_map[node_id] > 1 if is_duplicate_node: is_any_parent_unique_node = any( - node_count_map[n] == 1 for n in node_parents_map[node] + id_count_map[node.encoded_id] == 1 for node in id_parents_map[node_id] ) if is_any_parent_unique_node: return True else: - has_multi_parents = len(node_parents_map[node]) > 1 + has_multi_parents = len(id_parents_map[node_id]) > 1 if has_multi_parents: return True return False traverse(root) - duplicated_node = {node for node in node_count_map if is_duplicate_subtree(node)} - return duplicated_node, node_parents_map + duplicated_node = {node_id for node_id in id_count_map if is_duplicate_subtree(node_id)} + return duplicated_node -def create_cte_query(root: "TreeNode", duplicate_plan_set: Set["TreeNode"]) -> str: +def create_cte_query(root: "TreeNode", duplicated_node_ids: Set[str]) -> str: from snowflake.snowpark._internal.analyzer.select_statement import Selectable plan_to_query_map = {} @@ -110,32 +110,32 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: while stack2: node = stack2.pop() - if node in plan_to_query_map: + if node.encoded_id in plan_to_query_map: continue if not node.children_plan_nodes or not node.placeholder_query: - 
plan_to_query_map[node] = ( + plan_to_query_map[node.encoded_id] = ( node.sql_query if isinstance(node, Selectable) else node.queries[-1].sql ) else: - plan_to_query_map[node] = node.placeholder_query + plan_to_query_map[node.encoded_id] = node.placeholder_query for child in node.children_plan_nodes: # replace the placeholder (id) with child query - plan_to_query_map[node] = plan_to_query_map[node].replace( - child._id, plan_to_query_map[child] + plan_to_query_map[node.encoded_id] = plan_to_query_map[node.encoded_id].replace( + child.encoded_id, plan_to_query_map[child.encoded_id] ) # duplicate subtrees will be converted CTEs - if node in duplicate_plan_set: + if node.encoded_id in duplicated_node_ids: # when a subquery is converted a CTE to with clause, # it will be replaced by `SELECT * from TEMP_TABLE` in the original query table_name = random_name_for_temp_object(TempObjectType.CTE) select_stmt = project_statement([], table_name) - duplicate_plan_to_table_name_map[node] = table_name - duplicate_plan_to_cte_map[node] = plan_to_query_map[node] - plan_to_query_map[node] = select_stmt + duplicate_plan_to_table_name_map[node.encoded_id] = table_name + duplicate_plan_to_cte_map[node.encoded_id] = plan_to_query_map[node.encoded_id] + plan_to_query_map[node.encoded_id] = select_stmt build_plan_to_query_map_in_post_order(root) @@ -150,10 +150,10 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: def encode_id( query: str, query_params: Optional[Sequence[Any]] = None -) -> Optional[str]: +) -> str: string = f"{query}#{query_params}" if query_params else query try: return hashlib.sha256(string.encode()).hexdigest()[:10] except Exception as ex: logging.warning(f"Encode SnowflakePlan ID failed: {ex}") - return None + return str(uuid.uuid4()) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 4d154c9ae0..d77111252a 100644 --- 
a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -5,9 +5,9 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict +from functools import cached_property from copy import copy, deepcopy from enum import Enum -from functools import cached_property from typing import ( TYPE_CHECKING, AbstractSet, @@ -239,16 +239,7 @@ def __init__( self._api_calls = api_calls.copy() if api_calls is not None else None self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None - def __eq__(self, other: "Selectable") -> bool: - if not isinstance(other, Selectable): - return False - if self._id is not None and other._id is not None: - return type(self) is type(other) and self._id == other._id - else: - return super().__eq__(other) - - def __hash__(self) -> int: - return hash(self._id) if self._id else super().__hash__() + # self.encoded_id = encode_id(self.sql_query, self.query_params) @property @abstractmethod @@ -263,7 +254,7 @@ def placeholder_query(self) -> Optional[str]: pass @cached_property - def _id(self) -> Optional[str]: + def encoded_id(self) -> Optional[str]: """Returns the id of this Selectable logical plan.""" return encode_id(self.sql_query, self.query_params) @@ -592,8 +583,8 @@ def placeholder_query(self) -> Optional[str]: return self._snowflake_plan.placeholder_query @property - def _id(self) -> Optional[str]: - return self._snowflake_plan._id + def encoded_id(self) -> Optional[str]: + return self._snowflake_plan.encoded_id @property def schema_query(self) -> Optional[str]: @@ -825,7 +816,7 @@ def sql_query(self) -> str: def placeholder_query(self) -> str: if self._placeholder_query: return self._placeholder_query - from_clause = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_._id}{analyzer_utils.RIGHT_PARENTHESIS}" + from_clause = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_id}{analyzer_utils.RIGHT_PARENTHESIS}" if 
not self.has_clause and not self.projection: self._placeholder_query = from_clause return self._placeholder_query diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index c7729d991f..e22d64fc69 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -258,7 +258,7 @@ def __init__( # encode an id for CTE optimization. This is generated based on the main # query and the associated query parameters. We use this id for equality comparison # to determine if two plans are the same. - self._id = encode_id(queries[-1].sql, queries[-1].params) + self.encoded_id = encode_id(queries[-1].sql, queries[-1].params) self.referenced_ctes: Set[WithQueryBlock] = ( referenced_ctes.copy() if referenced_ctes else set() ) @@ -267,17 +267,6 @@ def __init__( # to UUID track queries that are generated from the same plan. self._uuid = str(uuid.uuid4()) - def __eq__(self, other: "SnowflakePlan") -> bool: - if not isinstance(other, SnowflakePlan): - return False - if self._id is not None and other._id is not None: - return isinstance(other, SnowflakePlan) and self._id == other._id - else: - return super().__eq__(other) - - def __hash__(self) -> int: - return hash(self._id) if self._id else super().__hash__() - @property def uuid(self) -> str: return self._uuid @@ -590,8 +579,8 @@ def build( new_schema_query = schema_query or sql_generator(child.schema_query) placeholder_query = ( - sql_generator(select_child._id) - if self.session._cte_optimization_enabled and select_child._id is not None + sql_generator(select_child.encoded_id) + if self.session._cte_optimization_enabled and select_child.encoded_id is not None else None ) diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py index d979b8a93a..b593225c50 100644 --- 
a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -3,8 +3,10 @@ # from typing import Dict, List, Optional, Set +from collections import defaultdict from snowflake.snowpark._internal.analyzer.cte_utils import find_duplicate_subtrees +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( LogicalPlan, WithQueryBlock, @@ -87,10 +89,10 @@ def apply(self) -> RepeatedSubqueryEliminationResult: logical_plan = self._query_generator.resolve(logical_plan) # apply the CTE optimization on the resolved plan - duplicated_nodes, node_parents_map = find_duplicate_subtrees(logical_plan) - if len(duplicated_nodes) > 0: + duplicated_node_ids = find_duplicate_subtrees(logical_plan) + if len(duplicated_node_ids) > 0: deduplicated_plan = self._replace_duplicate_node_with_cte( - logical_plan, duplicated_nodes, node_parents_map + logical_plan, duplicated_node_ids ) final_logical_plans.append(deduplicated_plan) else: @@ -104,8 +106,8 @@ def apply(self) -> RepeatedSubqueryEliminationResult: def _replace_duplicate_node_with_cte( self, root: TreeNode, - duplicated_nodes: Set[TreeNode], - node_parents_map: Dict[TreeNode, Set[TreeNode]], + duplicated_node_ids: Set[str], + # node_parents_map: Dict[TreeNode, Set[TreeNode]], ) -> LogicalPlan: """ Replace all duplicated nodes with a WithQueryBlock (CTE node), to enable @@ -117,17 +119,20 @@ def _replace_duplicate_node_with_cte( This function uses an iterative approach to avoid hitting Python's maximum recursion depth limit. 
""" + node_parents_map: Dict[TreeNode, Set[TreeNode]] = defaultdict(set) stack1, stack2 = [root], [] while stack1: node = stack1.pop() stack2.append(node) for child in reversed(node.children_plan_nodes): + node_parents_map[child].add(node) stack1.append(child) # tack node that is already visited to avoid repeated operation on the same node - visited_nodes: Set[TreeNode] = set() + # visited_nodes: Set[TreeNode] = set() updated_nodes: Set[TreeNode] = set() + resolved_with_block_map: Dict[str, SnowflakePlan] = {} def _update_parents( node: TreeNode, @@ -146,27 +151,31 @@ def _update_parents( while stack2: node = stack2.pop() - if node in visited_nodes: - continue + # if node in visited_nodes: + # continue # if the node is a duplicated node and deduplication is not done for the node, # start the deduplication transformation use CTE - if node in duplicated_nodes: - # create a WithQueryBlock node - with_block = WithQueryBlock( - name=random_name_for_temp_object(TempObjectType.CTE), child=node - ) - with_block._is_valid_for_replacement = True - - resolved_with_block = self._query_generator.resolve(with_block) + if node.encoded_id in duplicated_node_ids: + if node.encoded_id in resolved_with_block_map: + resolved_with_block = resolved_with_block_map[node.encoded_id] + else: + # create a WithQueryBlock node + with_block = WithQueryBlock( + name=random_name_for_temp_object(TempObjectType.CTE), child=node + ) + with_block._is_valid_for_replacement = True + + resolved_with_block = self._query_generator.resolve(with_block) + resolved_with_block_map[node.encoded_id] = resolved_with_block + self._total_number_ctes += 1 _update_parents( node, should_replace_child=True, new_child=resolved_with_block ) - self._total_number_ctes += 1 elif node in updated_nodes: # if the node is updated, make sure all nodes up to parent is updated _update_parents(node, should_replace_child=False) - visited_nodes.add(node) + # visited_nodes.add(node) return root diff --git a/tests/integ/test_cte.py 
b/tests/integ/test_cte.py index 591b8c458f..d36b11246e 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -42,7 +42,9 @@ WITH = "WITH" -paramList = [False, True] +# paramList = [False, True] + +paramList = [True] @pytest.fixture(params=paramList, autouse=True) @@ -155,6 +157,9 @@ def count_number_of_ctes(query): def test_unary(session, action): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) df_action = action(df) + df_res = df_action.union_all(df_action) + print(df_res.queries) + """ check_result( session, df_action, @@ -173,6 +178,7 @@ def test_unary(session, action): union_count=1, join_count=0, ) + """ @pytest.mark.parametrize("type, action", binary_operations) From 85d20a8ed5901409ab516a3d67650f6e3196b0a8 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Thu, 10 Oct 2024 20:06:10 -0700 Subject: [PATCH 2/9] refactor --- .../snowpark/_internal/analyzer/cte_utils.py | 26 +++++++++++-------- .../_internal/analyzer/select_statement.py | 24 ++++++++--------- .../_internal/analyzer/snowflake_plan.py | 15 ++++++----- .../compiler/repeated_subquery_elimination.py | 10 +++---- tests/integ/test_cte.py | 6 +---- 5 files changed, 42 insertions(+), 39 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py index aca304c581..8af449d319 100644 --- a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py @@ -6,7 +6,7 @@ import logging import uuid from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, Sequence, Set, Union from snowflake.snowpark._internal.analyzer.analyzer_utils import ( SPACE, @@ -62,7 +62,7 @@ def traverse(root: "TreeNode") -> None: for node in current_level: id_count_map[node.encoded_id] += 1 for child in node.children_plan_nodes: - 
id_parents_map[child.encoded_id].add(node) + id_parents_map[child.encoded_id].add(node.encoded_id) next_level.append(child) current_level = next_level @@ -70,7 +70,7 @@ def is_duplicate_subtree(node_id: str) -> bool: is_duplicate_node = id_count_map[node_id] > 1 if is_duplicate_node: is_any_parent_unique_node = any( - id_count_map[node.encoded_id] == 1 for node in id_parents_map[node_id] + id_count_map[id] == 1 for id in id_parents_map[node_id] ) if is_any_parent_unique_node: return True @@ -81,7 +81,9 @@ def is_duplicate_subtree(node_id: str) -> bool: return False traverse(root) - duplicated_node = {node_id for node_id in id_count_map if is_duplicate_subtree(node_id)} + duplicated_node = { + node_id for node_id in id_count_map if is_duplicate_subtree(node_id) + } return duplicated_node @@ -123,9 +125,9 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: plan_to_query_map[node.encoded_id] = node.placeholder_query for child in node.children_plan_nodes: # replace the placeholder (id) with child query - plan_to_query_map[node.encoded_id] = plan_to_query_map[node.encoded_id].replace( - child.encoded_id, plan_to_query_map[child.encoded_id] - ) + plan_to_query_map[node.encoded_id] = plan_to_query_map[ + node.encoded_id + ].replace(child.encoded_id, plan_to_query_map[child.encoded_id]) # duplicate subtrees will be converted CTEs if node.encoded_id in duplicated_node_ids: @@ -134,7 +136,9 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: table_name = random_name_for_temp_object(TempObjectType.CTE) select_stmt = project_statement([], table_name) duplicate_plan_to_table_name_map[node.encoded_id] = table_name - duplicate_plan_to_cte_map[node.encoded_id] = plan_to_query_map[node.encoded_id] + duplicate_plan_to_cte_map[node.encoded_id] = plan_to_query_map[ + node.encoded_id + ] plan_to_query_map[node.encoded_id] = select_stmt build_plan_to_query_map_in_post_order(root) @@ -144,16 +148,16 @@ def 
build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: list(duplicate_plan_to_cte_map.values()), list(duplicate_plan_to_table_name_map.values()), ) - final_query = with_stmt + SPACE + plan_to_query_map[root] + final_query = with_stmt + SPACE + plan_to_query_map[root.encoded_id] return final_query def encode_id( - query: str, query_params: Optional[Sequence[Any]] = None + node_type_name: str, query: str, query_params: Optional[Sequence[Any]] = None ) -> str: string = f"{query}#{query_params}" if query_params else query try: - return hashlib.sha256(string.encode()).hexdigest()[:10] + return hashlib.sha256(string.encode()).hexdigest()[:10] + node_type_name except Exception as ex: logging.warning(f"Encode SnowflakePlan ID failed: {ex}") return str(uuid.uuid4()) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index d77111252a..55056012fe 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -5,9 +5,9 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from functools import cached_property from copy import copy, deepcopy from enum import Enum +from functools import cached_property from typing import ( TYPE_CHECKING, AbstractSet, @@ -254,9 +254,9 @@ def placeholder_query(self) -> Optional[str]: pass @cached_property - def encoded_id(self) -> Optional[str]: + def encoded_id(self) -> str: """Returns the id of this Selectable logical plan.""" - return encode_id(self.sql_query, self.query_params) + return encode_id(type(self).__name__, self.sql_query, self.query_params) @property @abstractmethod @@ -497,13 +497,13 @@ def sql_query(self) -> str: def placeholder_query(self) -> Optional[str]: return None - @property - def _id(self) -> Optional[str]: + @cached_property + def encoded_id(self) -> str: """ Returns the id of this SelectSQL logical 
plan. The original SQL is used to encode its ID, which might be a non-select SQL. """ - return encode_id(self.original_sql, self.query_params) + return encode_id(type(self).__name__, self.original_sql, self.query_params) @property def query_params(self) -> Optional[Sequence[Any]]: @@ -582,8 +582,8 @@ def sql_query(self) -> str: def placeholder_query(self) -> Optional[str]: return self._snowflake_plan.placeholder_query - @property - def encoded_id(self) -> Optional[str]: + @cached_property + def encoded_id(self) -> str: return self._snowflake_plan.encoded_id @property @@ -784,9 +784,9 @@ def sql_query(self) -> str: if ( self.analyzer.session._cte_optimization_enabled and (not self.analyzer.session._query_compilation_stage_enabled) - and self.from_._id + and self.from_.encoded_id ): - placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_._id}{analyzer_utils.RIGHT_PARENTHESIS}" + placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_id}{analyzer_utils.RIGHT_PARENTHESIS}" self._sql_query = self.placeholder_query.replace(placeholder, from_clause) else: where_clause = ( @@ -1420,9 +1420,9 @@ def sql_query(self) -> str: @property def placeholder_query(self) -> Optional[str]: if not self._placeholder_query: - sql = f"({self.set_operands[0].selectable._id})" + sql = f"({self.set_operands[0].selectable.encoded_id})" for i in range(1, len(self.set_operands)): - sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable._id})" + sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable.encoded_id})" self._placeholder_query = sql return self._placeholder_query diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index e22d64fc69..afe0973a8a 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -258,7 +258,9 @@ def __init__( # encode an id for CTE 
optimization. This is generated based on the main # query and the associated query parameters. We use this id for equality comparison # to determine if two plans are the same. - self.encoded_id = encode_id(queries[-1].sql, queries[-1].params) + self.encoded_id = encode_id( + type(self).__name__, queries[-1].sql, queries[-1].params + ) self.referenced_ctes: Set[WithQueryBlock] = ( referenced_ctes.copy() if referenced_ctes else set() ) @@ -338,7 +340,7 @@ def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan": return self # if there is no duplicate node, no optimization will be performed - duplicate_plan_set, _ = find_duplicate_subtrees(self) + duplicate_plan_set = find_duplicate_subtrees(self) if not duplicate_plan_set: return self @@ -580,7 +582,8 @@ def build( placeholder_query = ( sql_generator(select_child.encoded_id) - if self.session._cte_optimization_enabled and select_child.encoded_id is not None + if self.session._cte_optimization_enabled + and select_child.encoded_id is not None else None ) @@ -618,10 +621,10 @@ def build_binary( schema_query = sql_generator(left_schema_query, right_schema_query) placeholder_query = ( - sql_generator(select_left._id, select_right._id) + sql_generator(select_left.encoded_id, select_right.encoded_id) if self.session._cte_optimization_enabled - and select_left._id is not None - and select_right._id is not None + and select_left.encoded_id is not None + and select_right.encoded_id is not None else None ) diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py index b593225c50..d0cfeaea31 100644 --- a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -2,8 +2,8 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# -from typing import Dict, List, Optional, Set from collections import defaultdict +from typing import Dict, List, Optional, Set from snowflake.snowpark._internal.analyzer.cte_utils import find_duplicate_subtrees from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan @@ -130,7 +130,7 @@ def _replace_duplicate_node_with_cte( stack1.append(child) # tack node that is already visited to avoid repeated operation on the same node - # visited_nodes: Set[TreeNode] = set() + visited_nodes: Set[TreeNode] = set() updated_nodes: Set[TreeNode] = set() resolved_with_block_map: Dict[str, SnowflakePlan] = {} @@ -151,8 +151,8 @@ def _update_parents( while stack2: node = stack2.pop() - # if node in visited_nodes: - # continue + if node in visited_nodes: + continue # if the node is a duplicated node and deduplication is not done for the node, # start the deduplication transformation use CTE @@ -176,6 +176,6 @@ def _update_parents( # if the node is updated, make sure all nodes up to parent is updated _update_parents(node, should_replace_child=False) - # visited_nodes.add(node) + visited_nodes.add(node) return root diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index d36b11246e..794d67e4f7 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -44,7 +44,7 @@ # paramList = [False, True] -paramList = [True] +paramList = [False] @pytest.fixture(params=paramList, autouse=True) @@ -157,9 +157,6 @@ def count_number_of_ctes(query): def test_unary(session, action): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) df_action = action(df) - df_res = df_action.union_all(df_action) - print(df_res.queries) - """ check_result( session, df_action, @@ -178,7 +175,6 @@ def test_unary(session, action): union_count=1, join_count=0, ) - """ @pytest.mark.parametrize("type, action", binary_operations) From 4dc41364a844fd52e35f6809bbda589b863f64b3 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Fri, 11 Oct 2024 10:27:49 -0700 Subject: 
[PATCH 3/9] fix error --- tests/integ/test_cte.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 794d67e4f7..08dbb6e923 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -146,17 +146,21 @@ def count_number_of_ctes(query): "action", [ lambda x: x.select("a", "b").select("b"), - lambda x: x.filter(col("a") == 1).select("b"), - lambda x: x.select("a").filter(col("a") == 1), - lambda x: x.select_expr("sum(a) as a").with_column("b", seq1()), - lambda x: x.drop("b").sort("a", ascending=False), - lambda x: x.rename(col("a"), "new_a").limit(1), - lambda x: x.to_df("a1", "b1").alias("L"), + # lambda x: x.filter(col("a") == 1).select("b"), + # lambda x: x.select("a").filter(col("a") == 1), + # lambda x: x.select_expr("sum(a) as a").with_column("b", seq1()), + # lambda x: x.drop("b").sort("a", ascending=False), + # lambda x: x.rename(col("a"), "new_a").limit(1), + # lambda x: x.to_df("a1", "b1").alias("L"), ], ) def test_unary(session, action): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) df_action = action(df) + df_res = df_action.union_all(df_action) + df_res.count() + # print(df_res.queries) + """ check_result( session, df_action, @@ -175,6 +179,7 @@ def test_unary(session, action): union_count=1, join_count=0, ) + """ @pytest.mark.parametrize("type, action", binary_operations) From 4dbd23fcabedc7807603749d64ca2f1fc7e81ffe Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Fri, 11 Oct 2024 13:42:08 -0700 Subject: [PATCH 4/9] fix error --- .../snowpark/_internal/analyzer/cte_utils.py | 15 ++++++++- .../_internal/analyzer/select_statement.py | 33 ++++++++++++++----- .../_internal/analyzer/snowflake_plan.py | 12 ++++--- tests/integ/test_cte.py | 23 ++++++------- 4 files changed, 56 insertions(+), 27 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py index 
8af449d319..6b149d84e9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py @@ -127,7 +127,9 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: # replace the placeholder (id) with child query plan_to_query_map[node.encoded_id] = plan_to_query_map[ node.encoded_id - ].replace(child.encoded_id, plan_to_query_map[child.encoded_id]) + ].replace( + child.encoded_query_id, plan_to_query_map[child.encoded_id] + ) # duplicate subtrees will be converted CTEs if node.encoded_id in duplicated_node_ids: @@ -152,6 +154,17 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: return final_query +def encoded_query_id( + query: str, query_params: Optional[Sequence[Any]] = None +) -> Optional[str]: + string = f"{query}#{query_params}" if query_params else query + try: + return hashlib.sha256(string.encode()).hexdigest()[:10] + except Exception as ex: + logging.warning(f"Encode SnowflakePlan ID failed: {ex}") + return None + + def encode_id( node_type_name: str, query: str, query_params: Optional[Sequence[Any]] = None ) -> str: diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 55056012fe..58b2140e67 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -22,7 +22,7 @@ ) import snowflake.snowpark._internal.utils -from snowflake.snowpark._internal.analyzer.cte_utils import encode_id +from snowflake.snowpark._internal.analyzer.cte_utils import encode_id, encoded_query_id from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, PlanState, @@ -258,6 +258,11 @@ def encoded_id(self) -> str: """Returns the id of this Selectable logical plan.""" return encode_id(type(self).__name__, self.sql_query, self.query_params) + @cached_property + def 
encoded_query_id(self) -> str: + """Returns the id of this Selectable logical plan.""" + return encoded_query_id(self.sql_query, self.query_params) + @property @abstractmethod def query_params(self) -> Optional[Sequence[Any]]: @@ -505,6 +510,14 @@ def encoded_id(self) -> str: """ return encode_id(type(self).__name__, self.original_sql, self.query_params) + @cached_property + def encoded_query_id(self) -> str: + """ + Returns the id of this SelectSQL logical plan. The original SQL is used to encode its ID, + which might be a non-select SQL. + """ + return encoded_query_id(self.original_sql, self.query_params) + @property def query_params(self) -> Optional[Sequence[Any]]: return self._query_param @@ -582,9 +595,13 @@ def sql_query(self) -> str: def placeholder_query(self) -> Optional[str]: return self._snowflake_plan.placeholder_query + # @cached_property + # def encoded_id(self) -> str: + # return self._snowflake_plan.encoded_id + @cached_property - def encoded_id(self) -> str: - return self._snowflake_plan.encoded_id + def encoded_query_id(self) -> str: + return self._snowflake_plan.encoded_query_id @property def schema_query(self) -> Optional[str]: @@ -784,9 +801,9 @@ def sql_query(self) -> str: if ( self.analyzer.session._cte_optimization_enabled and (not self.analyzer.session._query_compilation_stage_enabled) - and self.from_.encoded_id + and self.from_.encoded_query_id ): - placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_id}{analyzer_utils.RIGHT_PARENTHESIS}" + placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_query_id}{analyzer_utils.RIGHT_PARENTHESIS}" self._sql_query = self.placeholder_query.replace(placeholder, from_clause) else: where_clause = ( @@ -816,7 +833,7 @@ def sql_query(self) -> str: def placeholder_query(self) -> str: if self._placeholder_query: return self._placeholder_query - from_clause = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_id}{analyzer_utils.RIGHT_PARENTHESIS}" + from_clause = 
f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_.encoded_query_id}{analyzer_utils.RIGHT_PARENTHESIS}" if not self.has_clause and not self.projection: self._placeholder_query = from_clause return self._placeholder_query @@ -1420,9 +1437,9 @@ def sql_query(self) -> str: @property def placeholder_query(self) -> Optional[str]: if not self._placeholder_query: - sql = f"({self.set_operands[0].selectable.encoded_id})" + sql = f"({self.set_operands[0].selectable.encoded_query_id})" for i in range(1, len(self.set_operands)): - sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable.encoded_id})" + sql = f"{sql}{self.set_operands[i].operator}({self.set_operands[i].selectable.encoded_query_id})" self._placeholder_query = sql return self._placeholder_query diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index afe0973a8a..cf67c88ad8 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -86,6 +86,7 @@ from snowflake.snowpark._internal.analyzer.cte_utils import ( create_cte_query, encode_id, + encoded_query_id, find_duplicate_subtrees, ) from snowflake.snowpark._internal.analyzer.expression import Attribute @@ -261,6 +262,7 @@ def __init__( self.encoded_id = encode_id( type(self).__name__, queries[-1].sql, queries[-1].params ) + self.encoded_query_id = encoded_query_id(queries[-1].sql, queries[-1].params) self.referenced_ctes: Set[WithQueryBlock] = ( referenced_ctes.copy() if referenced_ctes else set() ) @@ -581,9 +583,9 @@ def build( new_schema_query = schema_query or sql_generator(child.schema_query) placeholder_query = ( - sql_generator(select_child.encoded_id) + sql_generator(select_child.encoded_query_id) if self.session._cte_optimization_enabled - and select_child.encoded_id is not None + and select_child.encoded_query_id is not None else None ) @@ -621,10 +623,10 @@ def 
build_binary( schema_query = sql_generator(left_schema_query, right_schema_query) placeholder_query = ( - sql_generator(select_left.encoded_id, select_right.encoded_id) + sql_generator(select_left.encoded_query_id, select_right.encoded_query_id) if self.session._cte_optimization_enabled - and select_left.encoded_id is not None - and select_right.encoded_id is not None + and select_left.encoded_query_id is not None + and select_right.encoded_query_id is not None else None ) diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 08dbb6e923..0006303359 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -42,9 +42,7 @@ WITH = "WITH" -# paramList = [False, True] - -paramList = [False] +paramList = [False, True] @pytest.fixture(params=paramList, autouse=True) @@ -146,21 +144,20 @@ def count_number_of_ctes(query): "action", [ lambda x: x.select("a", "b").select("b"), - # lambda x: x.filter(col("a") == 1).select("b"), - # lambda x: x.select("a").filter(col("a") == 1), - # lambda x: x.select_expr("sum(a) as a").with_column("b", seq1()), - # lambda x: x.drop("b").sort("a", ascending=False), - # lambda x: x.rename(col("a"), "new_a").limit(1), - # lambda x: x.to_df("a1", "b1").alias("L"), + lambda x: x.filter(col("a") == 1).select("b"), + lambda x: x.select("a").filter(col("a") == 1), + lambda x: x.select_expr("sum(a) as a").with_column("b", seq1()), + lambda x: x.drop("b").sort("a", ascending=False), + lambda x: x.rename(col("a"), "new_a").limit(1), + lambda x: x.to_df("a1", "b1").alias("L"), ], ) def test_unary(session, action): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) df_action = action(df) - df_res = df_action.union_all(df_action) - df_res.count() + # df_res = df_action.union_all(df_action) + # df_res.count() # print(df_res.queries) - """ check_result( session, df_action, @@ -179,7 +176,6 @@ def test_unary(session, action): union_count=1, join_count=0, ) - """ @pytest.mark.parametrize("type, action", 
binary_operations) @@ -812,6 +808,7 @@ def test_join_table_function(session): ) df1 = df.join_table_function("split_to_table", df["addresses"], lit(" ")) df_result = df1.join(df1.select("name", "addresses"), rsuffix="_y") + df_result.collect() check_result( session, df_result, From de4bba0a4ae646377622fc6d0e3ed513ba9a6c4d Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Fri, 11 Oct 2024 14:36:39 -0700 Subject: [PATCH 5/9] updat code --- .../snowpark/_internal/analyzer/cte_utils.py | 27 ++++++++++++++----- .../_internal/analyzer/select_statement.py | 6 +---- .../_internal/analyzer/snowflake_plan.py | 7 +++-- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py index 6b149d84e9..ca3adfc076 100644 --- a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py @@ -27,7 +27,7 @@ def find_duplicate_subtrees(root: "TreeNode") -> Set[str]: """ - Returns a set containing all duplicate subtrees in query plan tree. + Returns a set of TreeNode encoded_id that indicates all duplicate subtrees in query plan tree. The root of a duplicate subtree is defined as a duplicate node, if - it appears more than once in the tree, AND - one of its parent is unique (only appear once) in the tree, OR @@ -50,7 +50,6 @@ def find_duplicate_subtrees(root: "TreeNode") -> Set[str]: """ id_count_map = defaultdict(int) id_parents_map = defaultdict(set) - # node_parents_map = defaultdict(set) def traverse(root: "TreeNode") -> None: """ @@ -157,6 +156,14 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: def encoded_query_id( query: str, query_params: Optional[Sequence[Any]] = None ) -> Optional[str]: + """ + Encode the query and its query parameter into an id using sha256. + + + Returns: + If encode succeed, return the first 10 encoded value. 
+ Otherwise, return None + """ string = f"{query}#{query_params}" if query_params else query try: return hashlib.sha256(string.encode()).hexdigest()[:10] @@ -168,9 +175,15 @@ def encoded_query_id( def encode_id( node_type_name: str, query: str, query_params: Optional[Sequence[Any]] = None ) -> str: - string = f"{query}#{query_params}" if query_params else query - try: - return hashlib.sha256(string.encode()).hexdigest()[:10] + node_type_name - except Exception as ex: - logging.warning(f"Encode SnowflakePlan ID failed: {ex}") + """ + Encode given query, query parameters and the node type into an id. + + If query and query parameters can be encoded successfully using sha256, + return the encoded query id + node_type_name. + Otherwise, generate a uuid. + """ + query_id = encoded_query_id(query, query_params) + if query_id is not None: + return query_id + node_type_name + else: return str(uuid.uuid4()) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 58b2140e67..498c2ac0ad 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -260,7 +260,7 @@ def encoded_id(self) -> str: @cached_property def encoded_query_id(self) -> str: - """Returns the id of this Selectable logical plan.""" + """Returns the id of the queries for this Selectable logical plan.""" return encoded_query_id(self.sql_query, self.query_params) @property @@ -595,10 +595,6 @@ def sql_query(self) -> str: def placeholder_query(self) -> Optional[str]: return self._snowflake_plan.placeholder_query - # @cached_property - # def encoded_id(self) -> str: - # return self._snowflake_plan.encoded_id - @cached_property def encoded_query_id(self) -> str: return self._snowflake_plan.encoded_query_id diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 
cf67c88ad8..eb0c1c25d0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -257,11 +257,14 @@ def __init__( # It is used for optimization, by replacing a subquery with a CTE self.placeholder_query = placeholder_query # encode an id for CTE optimization. This is generated based on the main - # query and the associated query parameters. We use this id for equality comparison - # to determine if two plans are the same. + # query, query parameters and the node type. We use this id for equality + # comparison to determine if two plans are the same. self.encoded_id = encode_id( type(self).__name__, queries[-1].sql, queries[-1].params ) + # encode id for the main query and query parameters, this is currently only used + # by the create_cte_query process. + # TODO (SNOW-1541096) remove this filed along removing the old cte implementation self.encoded_query_id = encoded_query_id(queries[-1].sql, queries[-1].params) self.referenced_ctes: Set[WithQueryBlock] = ( referenced_ctes.copy() if referenced_ctes else set() From cfea33d0978db151c36b182fa7b151e9cdbafd06 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Fri, 11 Oct 2024 15:18:34 -0700 Subject: [PATCH 6/9] add comment --- .../snowpark/_internal/analyzer/select_statement.py | 2 -- .../_internal/compiler/repeated_subquery_elimination.py | 6 ++++-- tests/integ/test_cte.py | 4 ---- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 498c2ac0ad..c5c63d7a9a 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -239,8 +239,6 @@ def __init__( self._api_calls = api_calls.copy() if api_calls is not None else None self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None - # self.encoded_id = 
encode_id(self.sql_query, self.query_params) - @property @abstractmethod def sql_query(self) -> str: diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py index d0cfeaea31..4cdec32a53 100644 --- a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -107,7 +107,6 @@ def _replace_duplicate_node_with_cte( self, root: TreeNode, duplicated_node_ids: Set[str], - # node_parents_map: Dict[TreeNode, Set[TreeNode]], ) -> LogicalPlan: """ Replace all duplicated nodes with a WithQueryBlock (CTE node), to enable @@ -129,9 +128,10 @@ def _replace_duplicate_node_with_cte( node_parents_map[child].add(node) stack1.append(child) - # tack node that is already visited to avoid repeated operation on the same node + # track node that is already visited to avoid repeated operation on the same node visited_nodes: Set[TreeNode] = set() updated_nodes: Set[TreeNode] = set() + # track the resolved WithQueryBlock node has been created for each duplicated node resolved_with_block_map: Dict[str, SnowflakePlan] = {} def _update_parents( @@ -158,6 +158,8 @@ def _update_parents( # start the deduplication transformation use CTE if node.encoded_id in duplicated_node_ids: if node.encoded_id in resolved_with_block_map: + # if the corresponding CTE block has been created, use the existing + # one. 
resolved_with_block = resolved_with_block_map[node.encoded_id] else: # create a WithQueryBlock node diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 0006303359..591b8c458f 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -155,9 +155,6 @@ def count_number_of_ctes(query): def test_unary(session, action): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) df_action = action(df) - # df_res = df_action.union_all(df_action) - # df_res.count() - # print(df_res.queries) check_result( session, df_action, @@ -808,7 +805,6 @@ def test_join_table_function(session): ) df1 = df.join_table_function("split_to_table", df["addresses"], lit(" ")) df_result = df1.join(df1.select("name", "addresses"), rsuffix="_y") - df_result.collect() check_result( session, df_result, From 3005470faa1ffe6422f2cc2774cf2d030cbdf61d Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Fri, 11 Oct 2024 15:48:55 -0700 Subject: [PATCH 7/9] fix error --- .../snowpark/_internal/analyzer/select_statement.py | 6 +++--- tests/unit/compiler/test_replace_child_and_update_node.py | 8 ++++---- tests/unit/test_cte.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index c5c63d7a9a..8d16598152 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -257,7 +257,7 @@ def encoded_id(self) -> str: return encode_id(type(self).__name__, self.sql_query, self.query_params) @cached_property - def encoded_query_id(self) -> str: + def encoded_query_id(self) -> Optional[str]: """Returns the id of the queries for this Selectable logical plan.""" return encoded_query_id(self.sql_query, self.query_params) @@ -509,7 +509,7 @@ def encoded_id(self) -> str: return encode_id(type(self).__name__, self.original_sql, self.query_params) 
@cached_property - def encoded_query_id(self) -> str: + def encoded_query_id(self) -> Optional[str]: """ Returns the id of this SelectSQL logical plan. The original SQL is used to encode its ID, which might be a non-select SQL. @@ -594,7 +594,7 @@ def placeholder_query(self) -> Optional[str]: return self._snowflake_plan.placeholder_query @cached_property - def encoded_query_id(self) -> str: + def encoded_query_id(self) -> Optional[str]: return self._snowflake_plan.encoded_query_id @property diff --git a/tests/unit/compiler/test_replace_child_and_update_node.py b/tests/unit/compiler/test_replace_child_and_update_node.py index 05098165a1..76c3a0fec4 100644 --- a/tests/unit/compiler/test_replace_child_and_update_node.py +++ b/tests/unit/compiler/test_replace_child_and_update_node.py @@ -441,7 +441,7 @@ def test_select_statement( assert_precondition(plan, new_plan, mock_query_generator, using_deep_copy=True) plan = copy.deepcopy(plan) - replace_child(plan, from_, new_plan, mock_query_generator) + replace_child(plan, plan.children_plan_nodes[0], new_plan, mock_query_generator) assert len(plan.children_plan_nodes) == 1 new_replaced_plan = plan.children_plan_nodes[0] @@ -566,11 +566,11 @@ def test_set_statement( assert_precondition(plan, new_plan, mock_analyzer, using_deep_copy=True) plan = copy.deepcopy(plan) - replace_child(plan, selectable1, new_plan, mock_query_generator) + replace_child(plan, plan.children_plan_nodes[0], new_plan, mock_query_generator) assert len(plan.children_plan_nodes) == 2 new_replaced_plan = plan.children_plan_nodes[0] - assert isinstance(new_replaced_plan, SelectSnowflakePlan) - assert plan.children_plan_nodes[1] == selectable2 + assert isinstance(plan.children_plan_nodes[0], SelectSnowflakePlan) + assert isinstance(plan.children_plan_nodes[1], SelectSQL) mocked_snowflake_plan = mock_snowflake_plan() verify_snowflake_plan(new_replaced_plan.snowflake_plan, mocked_snowflake_plan) diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index 
bab9d5680b..66f41c8b5d 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -13,7 +13,7 @@ def test_case1(): nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] for i, node in enumerate(nodes): - node._id = i + node.encoded_id = i node.source_plan = None nodes[0].children_plan_nodes = [nodes[1], nodes[3]] nodes[1].children_plan_nodes = [nodes[2], nodes[2]] @@ -30,7 +30,7 @@ def test_case1(): def test_case2(): nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] for i, node in enumerate(nodes): - node._id = i + node.encoded_id = i node.source_plan = None nodes[0].children_plan_nodes = [nodes[1], nodes[3]] nodes[1].children_plan_nodes = [nodes[2], nodes[2]] @@ -47,5 +47,5 @@ def test_case2(): @pytest.mark.parametrize("test_case", [test_case1(), test_case2()]) def test_find_duplicate_subtrees(test_case): plan, expected_duplicate_subtree_ids = test_case - duplicate_subtrees, _ = find_duplicate_subtrees(plan) - assert {node._id for node in duplicate_subtrees} == expected_duplicate_subtree_ids + duplicate_subtrees_ids = find_duplicate_subtrees(plan) + assert duplicate_subtrees_ids == expected_duplicate_subtree_ids From dcf510045f65f6acb18b8746109c9b98f512eb9e Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Mon, 14 Oct 2024 19:18:49 -0700 Subject: [PATCH 8/9] address feedabck --- .../snowpark/_internal/analyzer/cte_utils.py | 89 ++++++++++++------- .../_internal/analyzer/select_statement.py | 37 ++++---- .../_internal/analyzer/snowflake_plan.py | 10 +-- .../compiler/repeated_subquery_elimination.py | 12 ++- tests/integ/test_deepcopy.py | 16 ---- 5 files changed, 82 insertions(+), 82 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py index ca3adfc076..eee666f155 100644 --- a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py @@ -4,9 +4,8 @@ import hashlib import logging -import uuid from 
collections import defaultdict -from typing import TYPE_CHECKING, Any, Optional, Sequence, Set, Union +from typing import TYPE_CHECKING, Optional, Set, Union from snowflake.snowpark._internal.analyzer.analyzer_utils import ( SPACE, @@ -59,29 +58,34 @@ def traverse(root: "TreeNode") -> None: while len(current_level) > 0: next_level = [] for node in current_level: - id_count_map[node.encoded_id] += 1 + id_count_map[node.encoded_node_id_with_query] += 1 for child in node.children_plan_nodes: - id_parents_map[child.encoded_id].add(node.encoded_id) + id_parents_map[child.encoded_node_id_with_query].add( + node.encoded_node_id_with_query + ) next_level.append(child) current_level = next_level - def is_duplicate_subtree(node_id: str) -> bool: - is_duplicate_node = id_count_map[node_id] > 1 + def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: + is_duplicate_node = id_count_map[encoded_node_id_with_query] > 1 if is_duplicate_node: is_any_parent_unique_node = any( - id_count_map[id] == 1 for id in id_parents_map[node_id] + id_count_map[id] == 1 + for id in id_parents_map[encoded_node_id_with_query] ) if is_any_parent_unique_node: return True else: - has_multi_parents = len(id_parents_map[node_id]) > 1 + has_multi_parents = len(id_parents_map[encoded_node_id_with_query]) > 1 if has_multi_parents: return True return False traverse(root) duplicated_node = { - node_id for node_id in id_count_map if is_duplicate_subtree(node_id) + encoded_node_id_with_query + for encoded_node_id_with_query in id_count_map + if is_duplicate_subtree(encoded_node_id_with_query) } return duplicated_node @@ -111,36 +115,41 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: while stack2: node = stack2.pop() - if node.encoded_id in plan_to_query_map: + if node.encoded_node_id_with_query in plan_to_query_map: continue if not node.children_plan_nodes or not node.placeholder_query: - plan_to_query_map[node.encoded_id] = ( + 
plan_to_query_map[node.encoded_node_id_with_query] = ( node.sql_query if isinstance(node, Selectable) else node.queries[-1].sql ) else: - plan_to_query_map[node.encoded_id] = node.placeholder_query + plan_to_query_map[ + node.encoded_node_id_with_query + ] = node.placeholder_query for child in node.children_plan_nodes: # replace the placeholder (id) with child query - plan_to_query_map[node.encoded_id] = plan_to_query_map[ - node.encoded_id - ].replace( - child.encoded_query_id, plan_to_query_map[child.encoded_id] + plan_to_query_map[ + node.encoded_node_id_with_query + ] = plan_to_query_map[node.encoded_node_id_with_query].replace( + child.encoded_query_id, + plan_to_query_map[child.encoded_node_id_with_query], ) # duplicate subtrees will be converted CTEs - if node.encoded_id in duplicated_node_ids: + if node.encoded_node_id_with_query in duplicated_node_ids: # when a subquery is converted a CTE to with clause, # it will be replaced by `SELECT * from TEMP_TABLE` in the original query table_name = random_name_for_temp_object(TempObjectType.CTE) select_stmt = project_statement([], table_name) - duplicate_plan_to_table_name_map[node.encoded_id] = table_name - duplicate_plan_to_cte_map[node.encoded_id] = plan_to_query_map[ - node.encoded_id - ] - plan_to_query_map[node.encoded_id] = select_stmt + duplicate_plan_to_table_name_map[ + node.encoded_node_id_with_query + ] = table_name + duplicate_plan_to_cte_map[ + node.encoded_node_id_with_query + ] = plan_to_query_map[node.encoded_node_id_with_query] + plan_to_query_map[node.encoded_node_id_with_query] = select_stmt build_plan_to_query_map_in_post_order(root) @@ -149,13 +158,11 @@ def build_plan_to_query_map_in_post_order(root: "TreeNode") -> None: list(duplicate_plan_to_cte_map.values()), list(duplicate_plan_to_table_name_map.values()), ) - final_query = with_stmt + SPACE + plan_to_query_map[root.encoded_id] + final_query = with_stmt + SPACE + plan_to_query_map[root.encoded_node_id_with_query] return final_query -def 
encoded_query_id(
-    query: str, query_params: Optional[Sequence[Any]] = None
-) -> Optional[str]:
+def encoded_query_id(node) -> Optional[str]:
     """
     Encode the query and its query parameter into an id using sha256.
 
@@ -164,6 +171,21 @@ def encoded_query_id(
         If encode succeed, return the first 10 encoded value.
         Otherwise, return None
     """
+    from snowflake.snowpark._internal.analyzer.select_statement import SelectSQL
+    from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan
+
+    if isinstance(node, SnowflakePlan):
+        query = node.queries[-1].sql
+        query_params = node.queries[-1].params
+    elif isinstance(node, SelectSQL):
+        # For SelectSql, the original SQL is used to encode its ID,
+        # which might be a non-select SQL.
+        query = node.original_sql
+        query_params = node.query_params
+    else:
+        query = node.sql_query
+        query_params = node.query_params
+
     string = f"{query}#{query_params}" if query_params else query
     try:
         return hashlib.sha256(string.encode()).hexdigest()[:10]
@@ -172,18 +194,17 @@
         return None
 
 
-def encode_id(
-    node_type_name: str, query: str, query_params: Optional[Sequence[Any]] = None
-) -> str:
+def encode_node_id_with_query(node: "TreeNode") -> str:
     """
-    Encode given query, query parameters and the node type into an id.
+    Encode an id for the given TreeNode.
 
     If query and query parameters can be encoded successfully using sha256,
     return the encoded query id + node_type_name.
-    Otherwise, generate a uuid.
+    Otherwise, return the original node id.
     """
-    query_id = encoded_query_id(query, query_params)
+    query_id = encoded_query_id(node)
     if query_id is not None:
-        return query_id + node_type_name
+        node_type_name = type(node).__name__
+        return f"{query_id}_{node_type_name}"
     else:
-        return str(uuid.uuid4())
+        return str(id(node))
diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
index 8d16598152..e421c23218 100644
--- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py
+++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -22,7 +22,10 @@
 )
 
 import snowflake.snowpark._internal.utils
-from snowflake.snowpark._internal.analyzer.cte_utils import encode_id, encoded_query_id
+from snowflake.snowpark._internal.analyzer.cte_utils import (
+    encode_node_id_with_query,
+    encoded_query_id,
+)
 from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
     PlanNodeCategory,
     PlanState,
@@ -252,14 +255,20 @@ def placeholder_query(self) -> Optional[str]:
         pass
 
     @cached_property
-    def encoded_id(self) -> str:
-        """Returns the id of this Selectable logical plan."""
-        return encode_id(type(self).__name__, self.sql_query, self.query_params)
+    def encoded_node_id_with_query(self) -> str:
+        """
+        Returns an encoded node id of this Selectable logical plan.
+
+        Note that the encoding algorithm uses queries as content, and returns the same id for
+        two selectable nodes with the same queries. This is currently used by repeated subquery
+        elimination to detect two nodes with the same query, please use it with care.
+ """ + return encode_node_id_with_query(self) @cached_property def encoded_query_id(self) -> Optional[str]: - """Returns the id of the queries for this Selectable logical plan.""" - return encoded_query_id(self.sql_query, self.query_params) + """Returns an encoded id of the queries for this Selectable logical plan.""" + return encoded_query_id(self) @property @abstractmethod @@ -500,22 +509,6 @@ def sql_query(self) -> str: def placeholder_query(self) -> Optional[str]: return None - @cached_property - def encoded_id(self) -> str: - """ - Returns the id of this SelectSQL logical plan. The original SQL is used to encode its ID, - which might be a non-select SQL. - """ - return encode_id(type(self).__name__, self.original_sql, self.query_params) - - @cached_property - def encoded_query_id(self) -> Optional[str]: - """ - Returns the id of this SelectSQL logical plan. The original SQL is used to encode its ID, - which might be a non-select SQL. - """ - return encoded_query_id(self.original_sql, self.query_params) - @property def query_params(self) -> Optional[Sequence[Any]]: return self._query_param diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index eb0c1c25d0..652ab71f74 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -85,7 +85,7 @@ ) from snowflake.snowpark._internal.analyzer.cte_utils import ( create_cte_query, - encode_id, + encode_node_id_with_query, encoded_query_id, find_duplicate_subtrees, ) @@ -259,13 +259,11 @@ def __init__( # encode an id for CTE optimization. This is generated based on the main # query, query parameters and the node type. We use this id for equality # comparison to determine if two plans are the same. 
- self.encoded_id = encode_id( - type(self).__name__, queries[-1].sql, queries[-1].params - ) + self.encoded_node_id_with_query = encode_node_id_with_query(self) # encode id for the main query and query parameters, this is currently only used # by the create_cte_query process. # TODO (SNOW-1541096) remove this filed along removing the old cte implementation - self.encoded_query_id = encoded_query_id(queries[-1].sql, queries[-1].params) + self.encoded_query_id = encoded_query_id(self) self.referenced_ctes: Set[WithQueryBlock] = ( referenced_ctes.copy() if referenced_ctes else set() ) @@ -414,7 +412,7 @@ def output_dict(self) -> Dict[str, Any]: @cached_property def num_duplicate_nodes(self) -> int: - duplicated_nodes, _ = find_duplicate_subtrees(self) + duplicated_nodes = find_duplicate_subtrees(self) return len(duplicated_nodes) @cached_property diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py index 4cdec32a53..afb9626673 100644 --- a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -156,11 +156,13 @@ def _update_parents( # if the node is a duplicated node and deduplication is not done for the node, # start the deduplication transformation use CTE - if node.encoded_id in duplicated_node_ids: - if node.encoded_id in resolved_with_block_map: + if node.encoded_node_id_with_query in duplicated_node_ids: + if node.encoded_node_id_with_query in resolved_with_block_map: # if the corresponding CTE block has been created, use the existing # one. 
- resolved_with_block = resolved_with_block_map[node.encoded_id] + resolved_with_block = resolved_with_block_map[ + node.encoded_node_id_with_query + ] else: # create a WithQueryBlock node with_block = WithQueryBlock( @@ -169,7 +171,9 @@ def _update_parents( with_block._is_valid_for_replacement = True resolved_with_block = self._query_generator.resolve(with_block) - resolved_with_block_map[node.encoded_id] = resolved_with_block + resolved_with_block_map[ + node.encoded_node_id_with_query + ] = resolved_with_block self._total_number_ctes += 1 _update_parents( node, should_replace_child=True, new_child=resolved_with_block diff --git a/tests/integ/test_deepcopy.py b/tests/integ/test_deepcopy.py index e0fa5fe74e..a07aa49950 100644 --- a/tests/integ/test_deepcopy.py +++ b/tests/integ/test_deepcopy.py @@ -401,19 +401,8 @@ def test_deepcopy_no_duplicate(session, generator): copied_plan = copy.deepcopy(final_df._plan) check_copied_plan(copied_plan, final_df._plan) - # we will traverse the plan to assert that the tuple (plan._id, type(plan)) have a unique id(plan) - # note that two nodes with same plan._id can have different id(plan) since SnowflakePlan inherits plan._id - # from its source selectable. 
- # plan._id: calculated by hashing the query sql and query params of plan - # type(plan): the type of the plan - # id(plan): the memory address of the plan object - # If the same plan._id has multiple id(plan), it means the deepcopy is duplicating source nodes - def traverse_plan(plan, plan_id_map): - plan_id = plan._id - plan_type = type(plan) plan_memo = id(plan) - identifier_tuple = (plan_id, plan_type) local_deepcopy_memo = {} first_deepcopy = copy.deepcopy(plan, local_deepcopy_memo) @@ -421,11 +410,6 @@ def traverse_plan(plan, plan_id_map): assert plan_memo in local_deepcopy_memo assert first_deepcopy is second_deepcopy - if identifier_tuple not in plan_id_map: - plan_id_map[identifier_tuple] = plan_memo - else: - assert plan_id_map[identifier_tuple] == plan_memo - for child in plan.children_plan_nodes: traverse_plan(child, plan_id_map) From f42030dc06910e3216bee315556d93981cd16d8f Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Tue, 15 Oct 2024 09:30:40 -0700 Subject: [PATCH 9/9] fix test failure --- tests/unit/test_cte.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index 66f41c8b5d..77024d1272 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -13,7 +13,7 @@ def test_case1(): nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] for i, node in enumerate(nodes): - node.encoded_id = i + node.encoded_node_id_with_query = i node.source_plan = None nodes[0].children_plan_nodes = [nodes[1], nodes[3]] nodes[1].children_plan_nodes = [nodes[2], nodes[2]] @@ -30,7 +30,7 @@ def test_case1(): def test_case2(): nodes = [mock.create_autospec(SnowflakePlan) for _ in range(7)] for i, node in enumerate(nodes): - node.encoded_id = i + node.encoded_node_id_with_query = i node.source_plan = None nodes[0].children_plan_nodes = [nodes[1], nodes[3]] nodes[1].children_plan_nodes = [nodes[2], nodes[2]]