
Commit

fix error
sfc-gh-yzou committed Oct 24, 2024
1 parent a8153c7 commit c3156e3
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions src/snowflake/snowpark/_internal/compiler/cte_utils.py
@@ -7,11 +7,10 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from snowflake.snowpark._internal.analyzer.snowflake_plan_node import WithQueryBlock

from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
get_complexity_score,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import WithQueryBlock
from snowflake.snowpark._internal.utils import is_sql_select_statement

if TYPE_CHECKING:
@@ -48,7 +47,7 @@ def find_duplicate_subtrees(
"""
id_count_map = defaultdict(int)
id_parents_map = defaultdict(set)
id_complexity_map = defaultdict(list)
id_complexity_map = defaultdict(int)

def traverse(root: "TreeNode") -> None:
"""
@@ -59,10 +58,14 @@ def traverse(root: "TreeNode") -> None:
next_level = []
for node in current_level:
id_count_map[node.encoded_node_id_with_query] += 1
if propagate_complexity_hist:
id_complexity_map[node.encoded_node_id_with_query].append(
get_complexity_score(node)
)
if propagate_complexity_hist and (
node.encoded_node_id_with_query not in id_complexity_map
):
# if propagate_complexity_hist is true, and the complexity score is not
# recorded for the current node id, record the complexity
id_complexity_map[
node.encoded_node_id_with_query
] = get_complexity_score(node)
for child in node.children_plan_nodes:
id_parents_map[child.encoded_node_id_with_query].add(
node.encoded_node_id_with_query
@@ -96,15 +99,17 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:
return (
duplicated_node_ids,
get_duplicated_node_complexity_distribution(
duplicated_node_ids, id_complexity_map
duplicated_node_ids, id_complexity_map, id_count_map
),
)
else:
return (duplicated_node_ids, None)


def get_duplicated_node_complexity_distribution(
duplicated_node_id_set: Set[str], id_complexity_map: Dict[str, List[int]]
duplicated_node_id_set: Set[str],
id_complexity_map: Dict[str, int],
id_count_map: Dict[str, int],
) -> List[int]:
"""
Calculate the complexity distribution for the detected repeated nodes. The complexity scores are categorized as follows:
@@ -120,21 +125,22 @@ def get_duplicated_node_complexity_distribution(
"""
node_complexity_dist = [0] * 7
for node_id in duplicated_node_id_set:
for complexity_score in id_complexity_map[node_id]:
if complexity_score <= 10000:
node_complexity_dist[0] += 1
elif 10000 < complexity_score <= 100000:
node_complexity_dist[1] += 1
elif 100000 < complexity_score <= 500000:
node_complexity_dist[2] += 1
elif 500000 < complexity_score <= 1000000:
node_complexity_dist[3] += 1
elif 1000000 < complexity_score <= 5000000:
node_complexity_dist[4] += 1
elif 5000000 < complexity_score <= 10000000:
node_complexity_dist[5] += 1
elif complexity_score > 10000000:
node_complexity_dist[6] += 1
complexity_score = id_complexity_map[node_id]
repeated_count = id_count_map[node_id]
if complexity_score <= 10000:
node_complexity_dist[0] += repeated_count
elif 10000 < complexity_score <= 100000:
node_complexity_dist[1] += repeated_count
elif 100000 < complexity_score <= 500000:
node_complexity_dist[2] += repeated_count
elif 500000 < complexity_score <= 1000000:
node_complexity_dist[3] += repeated_count
elif 1000000 < complexity_score <= 5000000:
node_complexity_dist[4] += repeated_count
elif 5000000 < complexity_score <= 10000000:
node_complexity_dist[5] += repeated_count
elif complexity_score > 10000000:
node_complexity_dist[6] += repeated_count

return node_complexity_dist
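
For illustration, here is a minimal, self-contained sketch of the bookkeeping this commit switches to: the traversal counts every occurrence of an encoded node id but records its complexity score only once, and the distribution function then weights each complexity bucket by the occurrence count. The node ids, scores, and counts below are hypothetical, not taken from the repository.

from collections import defaultdict

# Hypothetical traversal output: (encoded_node_id, complexity_score) per visit.
visits = [("a", 120), ("b", 2_000_000), ("a", 120), ("a", 120), ("b", 2_000_000)]

id_count_map = defaultdict(int)
id_complexity_map = defaultdict(int)
for node_id, score in visits:
    id_count_map[node_id] += 1
    if node_id not in id_complexity_map:
        # Record the score once per unique node id, as in the new traverse().
        id_complexity_map[node_id] = score

# Bucket boundaries taken from the elif ladder above.
bounds = [10_000, 100_000, 500_000, 1_000_000, 5_000_000, 10_000_000]
node_complexity_dist = [0] * 7
for node_id in id_count_map:  # for brevity, treat every id as duplicated
    score = id_complexity_map[node_id]
    bucket = next((i for i, b in enumerate(bounds) if score <= b), 6)
    node_complexity_dist[bucket] += id_count_map[node_id]

print(dict(id_count_map))    # {'a': 3, 'b': 2}
print(node_complexity_dist)  # [3, 0, 0, 0, 2, 0, 0]

In the changed file itself, the set of duplicated node ids is first narrowed down (via is_duplicate_subtree) before the distribution is built; the sketch simply treats every id as duplicated to keep the example short.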

