[SNOW-1566363] Add telemetry for new compilation stage (#2394)
<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
   accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.

   Note that if a corresponding GitHub issue exists, you should still include
   the Snowflake Jira issue number. For example, for GitHub issue #1400, you
   should add "SNOW-1335071" here.
   --->

SNOW-1566363

2. Fill out the following pre-review checklist:

   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from
     @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing
     parity changes.

3. Please describe how your code solves the related issue.

   1. Add a stat for the number of SelectStatement nodes with complexity merged.
   2. Add a stat for the number of CTEs created during repeated subquery
      elimination (see the payload sketch below).
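As a rough illustration (not part of the diff), here is a minimal sketch of
where the new stat lands in the compilation-stage summary. The keys are taken
from the diff below; every value is made up:

```python
# Hypothetical shape of the compilation-stage summary; values are illustrative.
summary_value = {
    "complexity_score_after_cte_optimization": [150],
    "complexity_score_after_large_query_breakdown": [150],
    # New in this PR: number of CTE nodes created by repeated subquery elimination.
    "cte_node_created": 2,
}
```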
sfc-gh-yzou authored Oct 7, 2024
1 parent 56a8ae2 commit bd1c498
Showing 10 changed files with 130 additions and 37 deletions.
src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py
@@ -44,6 +44,19 @@ def __repr__(self) -> str:
return self.name


class PlanState(Enum):
"""
This is an enum class for the states that are extracted for a given SnowflakePlan
or SelectStatement.
"""

# the height of the given plan
PLAN_HEIGHT = "plan_height"
# the number of SelectStatement nodes in the plan that have
# _merge_projection_complexity_with_subquery set to True
NUM_SELECTS_WITH_COMPLEXITY_MERGED = "num_selects_with_complexity_merged"


def sum_node_complexities(
*node_complexities: Dict[PlanNodeCategory, int]
) -> Dict[PlanNodeCategory, int]:
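For orientation (not part of the diff): `plan_state` returns a dict keyed by
this enum. A minimal consumption sketch, assuming `plan` is a SnowflakePlan or
SelectStatement carrying the new property:

```python
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import PlanState

state = plan.plan_state  # Dict[PlanState, Any]
height = state[PlanState.PLAN_HEIGHT]
num_merged = state[PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED]
```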
5 changes: 3 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -25,6 +25,7 @@
from snowflake.snowpark._internal.analyzer.cte_utils import encode_id
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
PlanState,
subtract_complexities,
sum_node_complexities,
)
@@ -327,8 +328,8 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan:
return self._snowflake_plan

@property
def plan_height(self) -> int:
return self.snowflake_plan.plan_height
def plan_state(self) -> Dict[PlanState, Any]:
return self.snowflake_plan.plan_state

@property
def num_duplicate_nodes(self) -> int:
28 changes: 21 additions & 7 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -25,6 +25,7 @@

from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
PlanState,
)
from snowflake.snowpark._internal.analyzer.table_function import (
GeneratorTableFunction,
@@ -416,21 +417,34 @@ def output_dict(self) -> Dict[str, Any]:
return self._output_dict

@cached_property
def plan_height(self) -> int:
def num_duplicate_nodes(self) -> int:
duplicated_nodes, _ = find_duplicate_subtrees(self)
return len(duplicated_nodes)

@cached_property
def plan_state(self) -> Dict[PlanState, Any]:
from snowflake.snowpark._internal.analyzer.select_statement import (
SelectStatement,
)

height = 0
num_selects_with_complexity_merged = 0
current_level = [self]
while len(current_level) > 0:
next_level = []
for node in current_level:
next_level.extend(node.children_plan_nodes)
if (
isinstance(node, SelectStatement)
and node._merge_projection_complexity_with_subquery
):
num_selects_with_complexity_merged += 1
height += 1
current_level = next_level
return height

@cached_property
def num_duplicate_nodes(self) -> int:
duplicated_nodes, _ = find_duplicate_subtrees(self)
return len(duplicated_nodes)
return {
PlanState.PLAN_HEIGHT: height,
PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED: num_selects_with_complexity_merged,
}

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
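The new `plan_state` walks the tree level by level, so the height and the
merged-SelectStatement count come out of a single traversal. A self-contained
toy sketch of that walk, using a hypothetical `Node` class in place of real
plan nodes:

```python
# Toy stand-in for plan nodes; only the two attributes used by the walk exist.
class Node:
    def __init__(self, children=None, merged=False):
        self.children_plan_nodes = children or []
        self._merge_projection_complexity_with_subquery = merged

leaf = Node()
root = Node(children=[Node(children=[leaf], merged=True), Node()])

height = 0
num_selects_with_complexity_merged = 0
current_level = [root]
while current_level:
    next_level = []
    for node in current_level:
        next_level.extend(node.children_plan_nodes)
        # The real code also requires isinstance(node, SelectStatement);
        # here every toy node is eligible.
        if node._merge_projection_complexity_with_subquery:
            num_selects_with_complexity_merged += 1
    height += 1
    current_level = next_level

assert height == 3
assert num_selects_with_complexity_merged == 1
```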
12 changes: 10 additions & 2 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
@@ -4,7 +4,7 @@

import copy
import time
from typing import Dict, List
from typing import Any, Dict, List

from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
get_complexity_score,
@@ -88,14 +88,20 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
# 2. create a code generator with the original plan
query_generator = create_query_generator(self._plan)

extra_optimization_status: Dict[str, Any] = {}
# 3. apply each optimizations if needed
# CTE optimization
cte_start_time = time.time()
if self._plan.session.cte_optimization_enabled:
repeated_subquery_eliminator = RepeatedSubqueryElimination(
logical_plans, query_generator
)
logical_plans = repeated_subquery_eliminator.apply()
elimination_result = repeated_subquery_eliminator.apply()
logical_plans = elimination_result.logical_plans
# add the extra repeated subquery elimination status
extra_optimization_status[
CompilationStageTelemetryField.CTE_NODE_CREATED.value
] = elimination_result.total_num_of_ctes

cte_end_time = time.time()
complexity_scores_after_cte = [
@@ -139,6 +145,8 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION.value: complexity_scores_after_cte,
CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN.value: complexity_scores_after_large_query_breakdown,
}
# add the extra optimization status
summary_value.update(extra_optimization_status)
session._conn._telemetry_client.send_query_compilation_summary_telemetry(
session_id=session.session_id,
plan_uuid=self._plan.uuid,
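The change follows an accumulate-then-merge pattern: optional stats are
collected in a side dict and merged into the summary only when the
corresponding optimization actually ran. A self-contained sketch of the
pattern — the function and variable names are illustrative; only the
`cte_node_created` key comes from the diff:

```python
from typing import Any, Dict

def build_summary(cte_optimization_enabled: bool, num_ctes: int) -> Dict[str, Any]:
    extra_optimization_status: Dict[str, Any] = {}
    if cte_optimization_enabled:
        # Report the stat only when repeated subquery elimination ran.
        extra_optimization_status["cte_node_created"] = num_ctes

    summary_value: Dict[str, Any] = {
        "cte_optimization_enabled": cte_optimization_enabled,
    }
    summary_value.update(extra_optimization_status)
    return summary_value

assert "cte_node_created" not in build_summary(False, 0)
assert build_summary(True, 2)["cte_node_created"] == 2
```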
src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py
@@ -21,6 +21,21 @@
)


class RepeatedSubqueryEliminationResult:
# the resulting logical plans after repeated subquery elimination
logical_plans: List[LogicalPlan]
# total number of CTE nodes created during the transformation
total_num_of_ctes: int

def __init__(
self,
logical_plans: List[LogicalPlan],
total_num_ctes: int,
) -> None:
self.logical_plans = logical_plans
self.total_num_of_ctes = total_num_ctes


class RepeatedSubqueryElimination:
"""
Optimization that is used to eliminate duplicated queries in the plan.
@@ -44,6 +59,7 @@ class RepeatedSubqueryElimination:
# original logical plans to apply the optimization on
_logical_plans: List[LogicalPlan]
_query_generator: QueryGenerator
_total_number_ctes: int

def __init__(
self,
@@ -52,8 +68,9 @@ def __init__(
) -> None:
self._logical_plans = logical_plans
self._query_generator = query_generator
self._total_number_ctes = 0

def apply(self) -> List[LogicalPlan]:
def apply(self) -> RepeatedSubqueryEliminationResult:
"""
Applies Common SubDataframe elimination on the set of logical plans one after another.
@@ -79,8 +96,10 @@ def apply(self) -> List[LogicalPlan]:
else:
final_logical_plans.append(logical_plan)

# TODO (SNOW-1566363): Add telemetry for CTE
return final_logical_plans
return RepeatedSubqueryEliminationResult(
logical_plans=final_logical_plans,
total_num_ctes=self._total_number_ctes,
)

def _replace_duplicate_node_with_cte(
self,
Expand Down Expand Up @@ -143,6 +162,7 @@ def _update_parents(
_update_parents(
node, should_replace_child=True, new_child=resolved_with_block
)
self._total_number_ctes += 1
elif node in updated_nodes:
# if the node is updated, make sure all nodes up to the parent are updated
_update_parents(node, should_replace_child=False)
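Call sites that previously got a bare list back from `apply()` now unpack the
result object. A minimal sketch, assuming `optimizer` is a
`RepeatedSubqueryElimination` constructed as in plan_compiler.py above:

```python
result = optimizer.apply()
logical_plans = result.logical_plans  # the transformed plans
num_ctes = result.total_num_of_ctes   # fed into compilation-stage telemetry
```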
12 changes: 12 additions & 0 deletions src/snowflake/snowpark/_internal/compiler/telemetry_constants.py
@@ -6,6 +6,15 @@


class CompilationStageTelemetryField(Enum):
# dataframe query stats that are used for the
# new compilation stage optimizations
QUERY_PLAN_HEIGHT = "query_plan_height"
QUERY_PLAN_NUM_SELECTS_WITH_COMPLEXITY_MERGED = (
"query_plan_num_selects_with_complexity_merged"
)
QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes"
QUERY_PLAN_COMPLEXITY = "query_plan_complexity"

# types
TYPE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED = (
"snowpark_large_query_breakdown_optimization_skipped"
@@ -29,6 +38,9 @@ class CompilationStageTelemetryField(Enum):
"complexity_score_after_large_query_breakdown"
)

# keys for repeated subquery elimination
CTE_NODE_CREATED = "cte_node_created"


class SkipLargeQueryBreakdownCategory(Enum):
ACTIVE_TRANSACTION = "active transaction"
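Since the enum values are the literal key strings that appear in telemetry
payloads, they can be checked directly; a small sketch using two keys from
this diff:

```python
from snowflake.snowpark._internal.compiler.telemetry_constants import (
    CompilationStageTelemetryField,
)

assert CompilationStageTelemetryField.QUERY_PLAN_HEIGHT.value == "query_plan_height"
assert CompilationStageTelemetryField.CTE_NODE_CREATED.value == "cte_node_created"
```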
16 changes: 9 additions & 7 deletions src/snowflake/snowpark/_internal/telemetry.py
@@ -14,6 +14,7 @@
TelemetryField as PCTelemetryField,
)
from snowflake.connector.time_util import get_time_millis
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import PlanState
from snowflake.snowpark._internal.compiler.telemetry_constants import (
CompilationStageTelemetryField,
)
@@ -75,10 +76,6 @@ class TelemetryField(Enum):
SQL_SIMPLIFIER_ENABLED = "sql_simplifier_enabled"
CTE_OPTIMIZATION_ENABLED = "cte_optimization_enabled"
LARGE_QUERY_BREAKDOWN_ENABLED = "large_query_breakdown_enabled"
# dataframe query stats
QUERY_PLAN_HEIGHT = "query_plan_height"
QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes"
QUERY_PLAN_COMPLEXITY = "query_plan_complexity"
# temp table cleanup
TYPE_TEMP_TABLE_CLEANUP = "snowpark_temp_table_cleanup"
NUM_TEMP_TABLES_CLEANED = "num_temp_tables_cleaned"
@@ -181,16 +178,21 @@ def wrap(*args, **kwargs):
0
]._session.sql_simplifier_enabled
try:
api_calls[0][TelemetryField.QUERY_PLAN_HEIGHT.value] = plan.plan_height
api_calls[0][
CompilationStageTelemetryField.QUERY_PLAN_HEIGHT.value
] = plan.plan_state[PlanState.PLAN_HEIGHT]
api_calls[0][
CompilationStageTelemetryField.QUERY_PLAN_NUM_SELECTS_WITH_COMPLEXITY_MERGED.value
] = plan.plan_state[PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED]
# The uuid for df._select_statement can be different from df._plan. Since plan
# can take both values, we cannot use plan.uuid. We always use df._plan.uuid
# to track the queries.
uuid = args[0]._plan.uuid
api_calls[0][CompilationStageTelemetryField.PLAN_UUID.value] = uuid
api_calls[0][
TelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value
CompilationStageTelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value
] = plan.num_duplicate_nodes
api_calls[0][TelemetryField.QUERY_PLAN_COMPLEXITY.value] = {
api_calls[0][CompilationStageTelemetryField.QUERY_PLAN_COMPLEXITY.value] = {
key.value: value
for key, value in plan.cumulative_node_complexity.items()
}
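Net effect of this hunk: per-API-call plan stats are now read from
`plan_state` and reported under `CompilationStageTelemetryField` keys instead
of the removed `TelemetryField` ones. A condensed, hypothetical shape of the
resulting entry — the keys come from the diff, the values and complexity
categories are made up:

```python
# Hypothetical api_calls[0] after the change above.
api_calls_entry = {
    "query_plan_height": 6,
    "query_plan_num_selects_with_complexity_merged": 1,
    "query_plan_num_duplicate_nodes": 0,
    "query_plan_complexity": {"filter": 2, "join": 1},  # illustrative categories
}
```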
10 changes: 7 additions & 3 deletions src/snowflake/snowpark/mock/_plan.py
@@ -12,7 +12,7 @@
from collections.abc import Iterable
from enum import Enum
from functools import cached_property, partial, reduce
from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union
from unittest.mock import MagicMock

from snowflake.snowpark._internal.analyzer.table_merge_expression import (
@@ -96,6 +96,7 @@
UnresolvedAttribute,
WithinGroup,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import PlanState
from snowflake.snowpark._internal.analyzer.snowflake_plan import (
PlanQueryType,
Query,
@@ -206,9 +207,12 @@ def output(self) -> List[Attribute]:
return [Attribute(a.name, a.datatype, a.nullable) for a in self.attributes]

@cached_property
def plan_height(self) -> int:
def plan_state(self) -> Dict[PlanState, Any]:
# dummy return
return -1
return {
PlanState.PLAN_HEIGHT: -1,
PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED: -1,
}

@cached_property
def num_duplicate_nodes(self) -> int:
35 changes: 22 additions & 13 deletions tests/integ/scala/test_snowflake_plan_suite.py
@@ -10,6 +10,7 @@
from snowflake.snowpark import Row
from snowflake.snowpark._internal.analyzer.analyzer_utils import schema_value_statement
from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import PlanState
from snowflake.snowpark._internal.analyzer.snowflake_plan import (
PlanQueryType,
Query,
@@ -189,46 +190,54 @@ def check_plan_queries(
)
def test_plan_height(session, temp_table, sql_simplifier_enabled):
df1 = session.table(temp_table)
assert df1._plan.plan_height == 1
assert df1._plan.plan_state[PlanState.PLAN_HEIGHT] == 1

df2 = session.create_dataframe([(1, 20), (3, 40)], schema=["a", "c"])
df3 = session.create_dataframe(
[(2, "twenty two"), (4, "forty four"), (4, "forty four")], schema=["b", "d"]
)
assert df2._plan.plan_height == 2
assert df3._plan.plan_height == 2
assert df2._plan.plan_state[PlanState.PLAN_HEIGHT] == 2
assert df3._plan.plan_state[PlanState.PLAN_HEIGHT] == 2

filter1 = df1.where(col("a") > 1)
assert filter1._plan.plan_height == 2
assert filter1._plan.plan_state[PlanState.PLAN_HEIGHT] == 2

join1 = filter1.join(df2, on=["a"])
assert join1._plan.plan_height == 4
assert join1._plan.plan_state[PlanState.PLAN_HEIGHT] == 4

aggregate1 = df3.distinct()
if sql_simplifier_enabled:
assert aggregate1._plan.plan_height == 4
assert aggregate1._plan.plan_state[PlanState.PLAN_HEIGHT] == 4
else:
assert aggregate1._plan.plan_height == 3
assert aggregate1._plan.plan_state[PlanState.PLAN_HEIGHT] == 3

join2 = join1.join(aggregate1, on=["b"])
assert join2._plan.plan_height == 6
assert join2._plan.plan_state[PlanState.PLAN_HEIGHT] == 6

split_to_table = table_function("split_to_table")
table_function1 = join2.select("a", "b", split_to_table("d", lit(" ")))
assert table_function1._plan.plan_height == 8
assert table_function1._plan.plan_state[PlanState.PLAN_HEIGHT] == 8

filter3 = join2.where(col("a") > 1)
filter4 = join2.where(col("a") < 1)
if sql_simplifier_enabled:
assert filter3._plan.plan_height == filter4._plan.plan_height == 6
assert (
filter3._plan.plan_state[PlanState.PLAN_HEIGHT]
== filter4._plan.plan_state[PlanState.PLAN_HEIGHT]
== 6
)
else:
assert filter3._plan.plan_height == filter4._plan.plan_height == 7
assert (
filter3._plan.plan_state[PlanState.PLAN_HEIGHT]
== filter4._plan.plan_state[PlanState.PLAN_HEIGHT]
== 7
)

union1 = filter3.union_all_by_name(filter4)
if sql_simplifier_enabled:
assert union1._plan.plan_height == 8
assert union1._plan.plan_state[PlanState.PLAN_HEIGHT] == 8
else:
assert union1._plan.plan_height == 9
assert union1._plan.plan_state[PlanState.PLAN_HEIGHT] == 9


def test_plan_num_duplicate_nodes_describe_query(session, temp_table):