From 3abb52b0f33158a9f20c04fef7c19fec0bd751c6 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 31 Jan 2025 14:58:05 -0800 Subject: [PATCH 1/6] fix telemetry collection for all SnowflakePlans --- .../snowpark/_internal/server_connection.py | 38 ++++++ src/snowflake/snowpark/_internal/telemetry.py | 41 ++---- tests/integ/test_telemetry.py | 118 +++++------------- 3 files changed, 82 insertions(+), 115 deletions(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 0323b46979f..dcbebf5a575 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -36,6 +36,7 @@ ) from snowflake.snowpark._internal.analyzer.datatype_mapper import str_to_sql from snowflake.snowpark._internal.analyzer.expression import Attribute +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import PlanState from snowflake.snowpark._internal.analyzer.schema_utils import ( convert_result_meta_to_attribute, get_new_description, @@ -48,6 +49,9 @@ SnowflakePlan, ) from snowflake.snowpark._internal.ast.utils import DATAFRAME_AST_PARAMETER +from snowflake.snowpark._internal.compiler.telemetry_constants import ( + CompilationStageTelemetryField, +) from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages from snowflake.snowpark._internal.telemetry import TelemetryClient from snowflake.snowpark._internal.utils import ( @@ -661,6 +665,7 @@ def get_result_set( statement_params = kwargs.get("_statement_params", None) or {} statement_params["_PLAN_UUID"] = plan.uuid kwargs["_statement_params"] = statement_params + self.send_plan_metrics_telemetry(plan) try: main_queries = plan_queries[PlanQueryType.QUERIES] post_actions = plan_queries[PlanQueryType.POST_ACTIONS] @@ -765,6 +770,39 @@ def get_result_set( return result, result_meta + def send_plan_metrics_telemetry(self, plan: SnowflakePlan) -> None: + try: + data = {} + plan_state = plan.plan_state + data[CompilationStageTelemetryField.QUERY_PLAN_HEIGHT.value] = plan_state[ + PlanState.PLAN_HEIGHT + ] + data[ + CompilationStageTelemetryField.QUERY_PLAN_NUM_SELECTS_WITH_COMPLEXITY_MERGED.value + ] = plan_state[PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED] + data[ + CompilationStageTelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value + ] = plan_state[PlanState.NUM_CTE_NODES] + data[ + CompilationStageTelemetryField.QUERY_PLAN_DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION.value + ] = plan_state[PlanState.DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION] + + # The uuid for df._select_statement can be different from df._plan. Since plan + # can take both values, we cannot use plan.uuid. We always use df._plan.uuid + # to track the queries. + uuid = plan.uuid + data[CompilationStageTelemetryField.PLAN_UUID.value] = uuid + data[CompilationStageTelemetryField.QUERY_PLAN_COMPLEXITY.value] = { + key.value: value + for key, value in plan.cumulative_node_complexity.items() + } + + self._telemetry_client.send_plan_metrics_telemetry( + session_id=self.get_session_id(), data=data + ) + except Exception: + pass + def get_result_and_metadata( self, plan: SnowflakePlan, **kwargs ) -> Tuple[List[Row], List[Attribute]]: diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index f449ffc5cc6..d8bfcfbf579 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -15,7 +15,6 @@ TelemetryField as PCTelemetryField, ) from snowflake.connector.time_util import get_time_millis -from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import PlanState from snowflake.snowpark._internal.compiler.telemetry_constants import ( CompilationStageTelemetryField, ) @@ -184,33 +183,7 @@ def wrap(*args, **kwargs): api_calls[0][TelemetryField.SQL_SIMPLIFIER_ENABLED.value] = args[ 0 ]._session.sql_simplifier_enabled - try: - plan_state = plan.plan_state - api_calls[0][ - CompilationStageTelemetryField.QUERY_PLAN_HEIGHT.value - ] = plan_state[PlanState.PLAN_HEIGHT] - api_calls[0][ - CompilationStageTelemetryField.QUERY_PLAN_NUM_SELECTS_WITH_COMPLEXITY_MERGED.value - ] = plan_state[PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED] - api_calls[0][ - CompilationStageTelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value - ] = plan_state[PlanState.NUM_CTE_NODES] - api_calls[0][ - CompilationStageTelemetryField.QUERY_PLAN_DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION.value - ] = plan_state[PlanState.DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION] - - # The uuid for df._select_statement can be different from df._plan. Since plan - # can take both values, we cannot use plan.uuid. We always use df._plan.uuid - # to track the queries. - uuid = args[0]._plan.uuid - api_calls[0][CompilationStageTelemetryField.PLAN_UUID.value] = uuid - api_calls[0][CompilationStageTelemetryField.QUERY_PLAN_COMPLEXITY.value] = { - key.value: value - for key, value in plan.cumulative_node_complexity.items() - } - api_calls[0][TelemetryField.THREAD_IDENTIFIER.value] = threading.get_ident() - except Exception: - pass + api_calls[0][TelemetryField.THREAD_IDENTIFIER.value] = threading.get_ident() args[0]._session._conn._telemetry_client.send_function_usage_telemetry( f"action_{func.__name__}", TelemetryField.FUNC_CAT_ACTION.value, @@ -501,6 +474,18 @@ def send_query_compilation_stage_failed_telemetry( } self.send(message) + def send_plan_metrics_telemetry(self, session_id: int, data) -> None: + message = { + **self._create_basic_telemetry_data( + CompilationStageTelemetryField.TYPE_COMPILATION_STAGE_STATISTICS.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + **data, + }, + } + self.send(message) + def send_temp_table_cleanup_telemetry( self, session_id: str, diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py index 027bbf88778..df8dd2b3c80 100644 --- a/tests/integ/test_telemetry.py +++ b/tests/integ/test_telemetry.py @@ -589,30 +589,13 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): df.collect() # API calls don't change after query is executed - query_plan_height = 2 if sql_simplifier_enabled else 3 - filter = 1 if sql_simplifier_enabled else 2 - low_impact = 3 if sql_simplifier_enabled else 2 thread_ident = threading.get_ident() assert df._plan.api_calls == [ { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, - "plan_uuid": df._plan.uuid, "thread_ident": thread_ident, - "query_plan_height": query_plan_height, - "query_plan_num_duplicate_nodes": 0, - "query_plan_num_selects_with_complexity_merged": 0, - "query_plan_duplicated_node_complexity_distribution": [0, 0, 0, 0, 0, 0, 0], - "query_plan_complexity": { - "filter": filter, - "low_impact": low_impact, - "function": 3, - "column": 3, - "literal": 5, - "window": 1, - "order_by": 1, - }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -624,21 +607,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, - "plan_uuid": df._plan.uuid, "thread_ident": thread_ident, - "query_plan_height": query_plan_height, - "query_plan_num_duplicate_nodes": 0, - "query_plan_num_selects_with_complexity_merged": 0, - "query_plan_duplicated_node_complexity_distribution": [0, 0, 0, 0, 0, 0, 0], - "query_plan_complexity": { - "filter": filter, - "low_impact": low_impact, - "function": 3, - "column": 3, - "literal": 5, - "window": 1, - "order_by": 1, - }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -650,21 +619,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, - "plan_uuid": df._plan.uuid, "thread_ident": thread_ident, - "query_plan_height": query_plan_height, - "query_plan_num_duplicate_nodes": 0, - "query_plan_num_selects_with_complexity_merged": 0, - "query_plan_duplicated_node_complexity_distribution": [0, 0, 0, 0, 0, 0, 0], - "query_plan_complexity": { - "filter": filter, - "low_impact": low_impact, - "function": 3, - "column": 3, - "literal": 5, - "window": 1, - "order_by": 1, - }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -676,21 +631,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, - "plan_uuid": df._plan.uuid, "thread_ident": thread_ident, - "query_plan_height": query_plan_height, - "query_plan_num_duplicate_nodes": 0, - "query_plan_num_selects_with_complexity_merged": 0, - "query_plan_duplicated_node_complexity_distribution": [0, 0, 0, 0, 0, 0, 0], - "query_plan_complexity": { - "filter": filter, - "low_impact": low_impact, - "function": 3, - "column": 3, - "literal": 5, - "window": 1, - "order_by": 1, - }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -702,21 +643,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, - "plan_uuid": df._plan.uuid, "thread_ident": thread_ident, - "query_plan_height": query_plan_height, - "query_plan_num_duplicate_nodes": 0, - "query_plan_num_selects_with_complexity_merged": 0, - "query_plan_duplicated_node_complexity_distribution": [0, 0, 0, 0, 0, 0, 0], - "query_plan_complexity": { - "filter": filter, - "low_impact": low_impact, - "function": 3, - "column": 3, - "literal": 5, - "window": 1, - "order_by": 1, - }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -846,24 +773,16 @@ def test_dataframe_stat_functions_api_calls(session): # check to make sure that the original DF is unchanged assert df._plan.api_calls == [{"name": "Session.create_dataframe[values]"}] - column = 6 if session.sql_simplifier_enabled else 9 crosstab = df.stat.crosstab("empid", "month") # uuid here is generated by an intermediate dataframe in crosstab implementation # therefore we can't predict it. We check that the uuid for crosstab is same as # that for df. - uuid = df._plan.api_calls[0]["plan_uuid"] thread_ident = threading.get_ident() assert crosstab._plan.api_calls == [ { "name": "Session.create_dataframe[values]", "sql_simplifier_enabled": session.sql_simplifier_enabled, - "plan_uuid": uuid, "thread_ident": thread_ident, - "query_plan_height": 4, - "query_plan_num_duplicate_nodes": 0, - "query_plan_num_selects_with_complexity_merged": 0, - "query_plan_duplicated_node_complexity_distribution": [0, 0, 0, 0, 0, 0, 0], - "query_plan_complexity": {"group_by": 1, "column": column, "literal": 48}, }, { "name": "DataFrameStatFunctions.crosstab", @@ -880,13 +799,7 @@ def test_dataframe_stat_functions_api_calls(session): { "name": "Session.create_dataframe[values]", "sql_simplifier_enabled": session.sql_simplifier_enabled, - "plan_uuid": uuid, "thread_ident": thread_ident, - "query_plan_height": 4, - "query_plan_num_duplicate_nodes": 0, - "query_plan_num_selects_with_complexity_merged": 0, - "query_plan_duplicated_node_complexity_distribution": [0, 0, 0, 0, 0, 0, 0], - "query_plan_complexity": {"group_by": 1, "column": column, "literal": 48}, } ] @@ -1309,3 +1222,34 @@ def send_telemetry(): data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) assert data == expected_data assert type_ == "snowpark_describe_query_details" + + +def test_plan_metrics_telemetry(session): + client = session._conn._telemetry_client + telemetry_data = { + "plan_uuid": "plan_uuid_placeholder", + "query_plan_height": 10, + "query_plan_num_duplicate_nodes": 5, + "query_plan_num_selects_with_complexity_merged": 3, + "query_plan_duplicated_node_complexity_distribution": [1, 2, 3], + "query_plan_complexity": { + "filter": 1, + "low_impact": 2, + "function": 3, + "column": 4, + "literal": 5, + "window": 6, + "order_by": 7, + }, + } + + def send_telemetry(): + client.send_plan_metrics_telemetry(session.session_id, data=telemetry_data) + + telemetry_tracker = TelemetryDataTracker(session) + + expected_data = {"session_id": session.session_id, **telemetry_data} + + data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) + assert data == expected_data + assert type_ == "snowpark_compilation_stage_statistics" From e19517850fa7cec0c7746c12a698aea5bbe839fc Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 31 Jan 2025 15:06:16 -0800 Subject: [PATCH 2/6] fix documentation --- src/snowflake/snowpark/_internal/server_connection.py | 6 +++++- src/snowflake/snowpark/_internal/telemetry.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index dcbebf5a575..6edc3450697 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -36,7 +36,7 @@ ) from snowflake.snowpark._internal.analyzer.datatype_mapper import str_to_sql from snowflake.snowpark._internal.analyzer.expression import Attribute -from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import PlanState +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import PlanState, get_complexity_score from snowflake.snowpark._internal.analyzer.schema_utils import ( convert_result_meta_to_attribute, get_new_description, @@ -771,6 +771,9 @@ def get_result_set( return result, result_meta def send_plan_metrics_telemetry(self, plan: SnowflakePlan) -> None: + """Extract the SnowflakePlan's metrics and including plan_state, uuid identifiers, complexity + classification breakdown, and complexity score. + """ try: data = {} plan_state = plan.plan_state @@ -796,6 +799,7 @@ def send_plan_metrics_telemetry(self, plan: SnowflakePlan) -> None: key.value: value for key, value in plan.cumulative_node_complexity.items() } + data[CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value] = get_complexity_score(plan) self._telemetry_client.send_plan_metrics_telemetry( session_id=self.get_session_id(), data=data diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index d8bfcfbf579..731948341fc 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -474,7 +474,7 @@ def send_query_compilation_stage_failed_telemetry( } self.send(message) - def send_plan_metrics_telemetry(self, session_id: int, data) -> None: + def send_plan_metrics_telemetry(self, session_id: int, data: Dict[str, Any]) -> None: message = { **self._create_basic_telemetry_data( CompilationStageTelemetryField.TYPE_COMPILATION_STAGE_STATISTICS.value From 2a533cac036386ec23a9253c967a8e5820011337 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 31 Jan 2025 15:07:18 -0800 Subject: [PATCH 3/6] also add complexity score to data --- src/snowflake/snowpark/_internal/server_connection.py | 9 +++++++-- src/snowflake/snowpark/_internal/telemetry.py | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 6edc3450697..f834dc8a8ae 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -36,7 +36,10 @@ ) from snowflake.snowpark._internal.analyzer.datatype_mapper import str_to_sql from snowflake.snowpark._internal.analyzer.expression import Attribute -from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import PlanState, get_complexity_score +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanState, + get_complexity_score, +) from snowflake.snowpark._internal.analyzer.schema_utils import ( convert_result_meta_to_attribute, get_new_description, @@ -799,7 +802,9 @@ def send_plan_metrics_telemetry(self, plan: SnowflakePlan) -> None: key.value: value for key, value in plan.cumulative_node_complexity.items() } - data[CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value] = get_complexity_score(plan) + data[ + CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value + ] = get_complexity_score(plan) self._telemetry_client.send_plan_metrics_telemetry( session_id=self.get_session_id(), data=data diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 731948341fc..7cf44ca721d 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -474,7 +474,9 @@ def send_query_compilation_stage_failed_telemetry( } self.send(message) - def send_plan_metrics_telemetry(self, session_id: int, data: Dict[str, Any]) -> None: + def send_plan_metrics_telemetry( + self, session_id: int, data: Dict[str, Any] + ) -> None: message = { **self._create_basic_telemetry_data( CompilationStageTelemetryField.TYPE_COMPILATION_STAGE_STATISTICS.value From 479e7d61bfd7780d6eb528797e825b346554d71d Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 13 Feb 2025 15:01:18 -0800 Subject: [PATCH 4/6] add param protection --- .../_internal/analyzer/select_statement.py | 3 ++ .../_internal/analyzer/snowflake_plan.py | 36 +++++++++++++ .../_internal/compiler/telemetry_constants.py | 5 ++ .../snowpark/_internal/server_connection.py | 53 +++---------------- src/snowflake/snowpark/_internal/telemetry.py | 33 ++++++++++-- src/snowflake/snowpark/session.py | 9 ++++ 6 files changed, 88 insertions(+), 51 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index f8256365ce7..3bfc2e97f2f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -350,6 +350,9 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan: def plan_state(self) -> Dict[PlanState, Any]: return self.snowflake_plan.plan_state + def get_plan_telemetry_metrics(self) -> Dict[str, Any]: + return self.snowflake_plan.get_plan_telemetry_metrics() + @property def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: with self._session._plan_lock: diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 31982f04124..a0ce9ef4876 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -22,9 +22,14 @@ Union, ) +from snowflake.snowpark._internal.compiler.telemetry_constants import ( + CompilationStageTelemetryField, +) + from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, PlanState, + get_complexity_score, ) from snowflake.snowpark._internal.analyzer.table_function import ( GeneratorTableFunction, @@ -383,6 +388,37 @@ def output_dict(self) -> Dict[str, Any]: } return self._output_dict + def get_plan_telemetry_metrics(self) -> Dict[str, Any]: + data = {CompilationStageTelemetryField.PLAN_UUID.value: self.uuid} + try: + # plan state + plan_state = self.plan_state + data[CompilationStageTelemetryField.QUERY_PLAN_HEIGHT.value] = plan_state[ + PlanState.PLAN_HEIGHT + ] + data[ + CompilationStageTelemetryField.QUERY_PLAN_NUM_SELECTS_WITH_COMPLEXITY_MERGED.value + ] = plan_state[PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED] + data[ + CompilationStageTelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value + ] = plan_state[PlanState.NUM_CTE_NODES] + data[ + CompilationStageTelemetryField.QUERY_PLAN_DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION.value + ] = plan_state[PlanState.DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION] + + # plan complexity score + data[CompilationStageTelemetryField.QUERY_PLAN_COMPLEXITY.value] = { + key.value: value + for key, value in self.cumulative_node_complexity.items() + } + data[ + CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value + ] = get_complexity_score(self) + except Exception as e: + data[CompilationStageTelemetryField.ERROR_MESSAGE.value] = str(e) + + return data + @property def plan_state(self) -> Dict[PlanState, Any]: with self.session._plan_lock: diff --git a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py index bfabcc3115d..326828f29a3 100644 --- a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py +++ b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py @@ -24,6 +24,11 @@ class CompilationStageTelemetryField(Enum): "snowpark_large_query_breakdown_update_complexity_bounds" ) + # categories + CAT_COMPILATION_STAGE_STATS = "query_compilation_stage_statistics" + CAT_COMPILATION_STAGE_ERROR = "query_compilation_stage_error" + CAT_SNOWFLAKE_PLAN_METRICS = "snowflake_plan_metrics" + # keys KEY_REASON = "reason" PLAN_UUID = "plan_uuid" diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index e0574d42059..23adbcb4c6e 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -36,10 +36,6 @@ ) from snowflake.snowpark._internal.analyzer.datatype_mapper import str_to_sql from snowflake.snowpark._internal.analyzer.expression import Attribute -from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - PlanState, - get_complexity_score, -) from snowflake.snowpark._internal.analyzer.schema_utils import ( convert_result_meta_to_attribute, get_new_description, @@ -52,9 +48,6 @@ SnowflakePlan, ) from snowflake.snowpark._internal.ast.utils import DATAFRAME_AST_PARAMETER -from snowflake.snowpark._internal.compiler.telemetry_constants import ( - CompilationStageTelemetryField, -) from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages from snowflake.snowpark._internal.telemetry import TelemetryClient from snowflake.snowpark._internal.utils import ( @@ -682,7 +675,6 @@ def get_result_set( statement_params = kwargs.get("_statement_params", None) or {} statement_params["_PLAN_UUID"] = plan.uuid kwargs["_statement_params"] = statement_params - self.send_plan_metrics_telemetry(plan) try: main_queries = plan_queries[PlanQueryType.QUERIES] post_actions = plan_queries[PlanQueryType.POST_ACTIONS] @@ -785,50 +777,17 @@ def get_result_set( **kwargs, ) + if plan.session._collect_snowflake_plan_telemetry_at_critical_path: + data = plan.get_plan_telemetry_metrics() + self._telemetry_client.send_plan_metrics_telemetry( + session_id=self.get_session_id(), data=data + ) + if result is None: raise SnowparkClientExceptionMessages.SQL_LAST_QUERY_RETURN_RESULTSET() return result, result_meta - def send_plan_metrics_telemetry(self, plan: SnowflakePlan) -> None: - """Extract the SnowflakePlan's metrics and including plan_state, uuid identifiers, complexity - classification breakdown, and complexity score. - """ - try: - data = {} - plan_state = plan.plan_state - data[CompilationStageTelemetryField.QUERY_PLAN_HEIGHT.value] = plan_state[ - PlanState.PLAN_HEIGHT - ] - data[ - CompilationStageTelemetryField.QUERY_PLAN_NUM_SELECTS_WITH_COMPLEXITY_MERGED.value - ] = plan_state[PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED] - data[ - CompilationStageTelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value - ] = plan_state[PlanState.NUM_CTE_NODES] - data[ - CompilationStageTelemetryField.QUERY_PLAN_DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION.value - ] = plan_state[PlanState.DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION] - - # The uuid for df._select_statement can be different from df._plan. Since plan - # can take both values, we cannot use plan.uuid. We always use df._plan.uuid - # to track the queries. - uuid = plan.uuid - data[CompilationStageTelemetryField.PLAN_UUID.value] = uuid - data[CompilationStageTelemetryField.QUERY_PLAN_COMPLEXITY.value] = { - key.value: value - for key, value in plan.cumulative_node_complexity.items() - } - data[ - CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value - ] = get_complexity_score(plan) - - self._telemetry_client.send_plan_metrics_telemetry( - session_id=self.get_session_id(), data=data - ) - except Exception: - pass - def get_result_and_metadata( self, plan: SnowflakePlan, **kwargs ) -> Tuple[List[Row], List[Attribute]]: diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 7cf44ca721d..5a5f0152e02 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -169,9 +169,20 @@ def wrap(*args, **kwargs): def df_collect_api_telemetry(func): @functools.wraps(func) def wrap(*args, **kwargs): - with args[0]._session.query_history() as query_history: - result = func(*args, **kwargs) plan = args[0]._select_statement or args[0]._plan + with args[0]._session.query_history() as query_history: + try: + result = func(*args, **kwargs) + finally: + if not args[ + 0 + ]._session._collect_snowflake_plan_telemetry_at_critical_path: + args[ + 0 + ]._session._conn._telemetry_client.send_plan_metrics_telemetry( + session_id=args[0]._session.session_id, + data=plan.get_plan_metrics(), + ) api_calls = [ *plan.api_calls, {TelemetryField.NAME.value: f"DataFrame.{func.__name__}"}, @@ -198,9 +209,20 @@ def wrap(*args, **kwargs): def dfw_collect_api_telemetry(func): @functools.wraps(func) def wrap(*args, **kwargs): - with args[0]._dataframe._session.query_history() as query_history: - result = func(*args, **kwargs) plan = args[0]._dataframe._select_statement or args[0]._dataframe._plan + with args[0]._dataframe._session.query_history() as query_history: + try: + result = func(*args, **kwargs) + finally: + if not args[ + 0 + ]._session._collect_snowflake_plan_telemetry_at_critical_path: + args[ + 0 + ]._session._conn._telemetry_client.send_plan_metrics_telemetry( + session_id=args[0]._session.session_id, + data=plan.get_plan_metrics(), + ) api_calls = [ *plan.api_calls, {TelemetryField.NAME.value: f"DataFrameWriter.{func.__name__}"}, @@ -452,6 +474,7 @@ def send_query_compilation_summary_telemetry( ), TelemetryField.KEY_DATA.value: { TelemetryField.SESSION_ID.value: session_id, + TelemetryField.KEY_CATEGORY.value: CompilationStageTelemetryField.CAT_COMPILATION_STAGE_STATS.value, CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid, **compilation_stage_summary, }, @@ -467,6 +490,7 @@ def send_query_compilation_stage_failed_telemetry( ), TelemetryField.KEY_DATA.value: { TelemetryField.SESSION_ID.value: session_id, + TelemetryField.KEY_CATEGORY.value: CompilationStageTelemetryField.CAT_COMPILATION_STAGE_ERROR.value, CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid, CompilationStageTelemetryField.ERROR_TYPE.value: error_type, CompilationStageTelemetryField.ERROR_MESSAGE.value: error_message, @@ -483,6 +507,7 @@ def send_plan_metrics_telemetry( ), TelemetryField.KEY_DATA.value: { TelemetryField.SESSION_ID.value: session_id, + TelemetryField.KEY_CATEGORY.value: CompilationStageTelemetryField.CAT_SNOWFLAKE_PLAN_METRICS.value, **data, }, } diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 1a814a42b60..19df71a47b6 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -268,6 +268,10 @@ _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION = ( "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION" ) +# Flag to control sending snowflake plan telemetry data from get_result_set +_PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION = ( + "PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION" +) # Flag for controlling the usage of scoped temp read only table. _PYTHON_SNOWPARK_ENABLE_SCOPED_TEMP_READ_ONLY_TABLE = ( "PYTHON_SNOWPARK_ENABLE_SCOPED_TEMP_READ_ONLY_TABLE" @@ -662,6 +666,11 @@ def __init__( self._plan_lock = create_rlock(self._conn._thread_safe_session_enabled) self._custom_package_usage_config: Dict = {} + self._collect_snowflake_plan_telemetry_at_critical_path: bool = ( + self.is_feature_enabled_for_version( + _PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION + ) + ) self._conf = self.RuntimeConfig(self, options or {}) self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) From 457df020c9cd45fbf01b2d2370dfa6a75191a41f Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 13 Feb 2025 15:24:07 -0800 Subject: [PATCH 5/6] minor refactor --- .../_internal/analyzer/select_statement.py | 3 - .../_internal/analyzer/snowflake_plan.py | 36 -------- .../snowpark/_internal/server_connection.py | 7 +- src/snowflake/snowpark/_internal/telemetry.py | 83 +++++++++++++------ 4 files changed, 62 insertions(+), 67 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 3bfc2e97f2f..f8256365ce7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -350,9 +350,6 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan: def plan_state(self) -> Dict[PlanState, Any]: return self.snowflake_plan.plan_state - def get_plan_telemetry_metrics(self) -> Dict[str, Any]: - return self.snowflake_plan.get_plan_telemetry_metrics() - @property def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: with self._session._plan_lock: diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index a0ce9ef4876..31982f04124 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -22,14 +22,9 @@ Union, ) -from snowflake.snowpark._internal.compiler.telemetry_constants import ( - CompilationStageTelemetryField, -) - from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, PlanState, - get_complexity_score, ) from snowflake.snowpark._internal.analyzer.table_function import ( GeneratorTableFunction, @@ -388,37 +383,6 @@ def output_dict(self) -> Dict[str, Any]: } return self._output_dict - def get_plan_telemetry_metrics(self) -> Dict[str, Any]: - data = {CompilationStageTelemetryField.PLAN_UUID.value: self.uuid} - try: - # plan state - plan_state = self.plan_state - data[CompilationStageTelemetryField.QUERY_PLAN_HEIGHT.value] = plan_state[ - PlanState.PLAN_HEIGHT - ] - data[ - CompilationStageTelemetryField.QUERY_PLAN_NUM_SELECTS_WITH_COMPLEXITY_MERGED.value - ] = plan_state[PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED] - data[ - CompilationStageTelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value - ] = plan_state[PlanState.NUM_CTE_NODES] - data[ - CompilationStageTelemetryField.QUERY_PLAN_DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION.value - ] = plan_state[PlanState.DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION] - - # plan complexity score - data[CompilationStageTelemetryField.QUERY_PLAN_COMPLEXITY.value] = { - key.value: value - for key, value in self.cumulative_node_complexity.items() - } - data[ - CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value - ] = get_complexity_score(self) - except Exception as e: - data[CompilationStageTelemetryField.ERROR_MESSAGE.value] = str(e) - - return data - @property def plan_state(self) -> Dict[PlanState, Any]: with self.session._plan_lock: diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 23adbcb4c6e..9f579977bbc 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -49,7 +49,10 @@ ) from snowflake.snowpark._internal.ast.utils import DATAFRAME_AST_PARAMETER from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages -from snowflake.snowpark._internal.telemetry import TelemetryClient +from snowflake.snowpark._internal.telemetry import ( + TelemetryClient, + get_plan_telemetry_metrics, +) from snowflake.snowpark._internal.utils import ( create_rlock, create_thread_local, @@ -778,7 +781,7 @@ def get_result_set( ) if plan.session._collect_snowflake_plan_telemetry_at_critical_path: - data = plan.get_plan_telemetry_metrics() + data = get_plan_telemetry_metrics(plan) self._telemetry_client.send_plan_metrics_telemetry( session_id=self.get_session_id(), data=data ) diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 5a5f0152e02..89b80eb1d25 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -6,7 +6,7 @@ import functools import threading from enum import Enum, unique -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from snowflake.connector import SnowflakeConnection from snowflake.connector.telemetry import ( @@ -15,6 +15,12 @@ TelemetryField as PCTelemetryField, ) from snowflake.connector.time_util import get_time_millis +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanState, + get_complexity_score, +) +from snowflake.snowpark._internal.analyzer.select_statement import Selectable +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan from snowflake.snowpark._internal.compiler.telemetry_constants import ( CompilationStageTelemetryField, ) @@ -170,18 +176,15 @@ def df_collect_api_telemetry(func): @functools.wraps(func) def wrap(*args, **kwargs): plan = args[0]._select_statement or args[0]._plan - with args[0]._session.query_history() as query_history: + session = args[0]._session + with session.query_history() as query_history: try: result = func(*args, **kwargs) finally: - if not args[ - 0 - ]._session._collect_snowflake_plan_telemetry_at_critical_path: - args[ - 0 - ]._session._conn._telemetry_client.send_plan_metrics_telemetry( - session_id=args[0]._session.session_id, - data=plan.get_plan_metrics(), + if not session._collect_snowflake_plan_telemetry_at_critical_path: + session._conn._telemetry_client.send_plan_metrics_telemetry( + session_id=session.session_id, + data=get_plan_telemetry_metrics(plan), ) api_calls = [ *plan.api_calls, @@ -191,11 +194,11 @@ def wrap(*args, **kwargs): # - sql simplifier is enabled. # - height of the query plan # - number of unique duplicate subtrees in the query plan - api_calls[0][TelemetryField.SQL_SIMPLIFIER_ENABLED.value] = args[ - 0 - ]._session.sql_simplifier_enabled + api_calls[0][ + TelemetryField.SQL_SIMPLIFIER_ENABLED.value + ] = session.sql_simplifier_enabled api_calls[0][TelemetryField.THREAD_IDENTIFIER.value] = threading.get_ident() - args[0]._session._conn._telemetry_client.send_function_usage_telemetry( + session._conn._telemetry_client.send_function_usage_telemetry( f"action_{func.__name__}", TelemetryField.FUNC_CAT_ACTION.value, api_calls=api_calls, @@ -210,26 +213,21 @@ def dfw_collect_api_telemetry(func): @functools.wraps(func) def wrap(*args, **kwargs): plan = args[0]._dataframe._select_statement or args[0]._dataframe._plan - with args[0]._dataframe._session.query_history() as query_history: + session = args[0]._dataframe._session + with session.query_history() as query_history: try: result = func(*args, **kwargs) finally: - if not args[ - 0 - ]._session._collect_snowflake_plan_telemetry_at_critical_path: - args[ - 0 - ]._session._conn._telemetry_client.send_plan_metrics_telemetry( - session_id=args[0]._session.session_id, - data=plan.get_plan_metrics(), + if not session._collect_snowflake_plan_telemetry_at_critical_path: + args[0].session._conn._telemetry_client.send_plan_metrics_telemetry( + session_id=session.session_id, + data=get_plan_telemetry_metrics(plan), ) api_calls = [ *plan.api_calls, {TelemetryField.NAME.value: f"DataFrameWriter.{func.__name__}"}, ] - args[ - 0 - ]._dataframe._session._conn._telemetry_client.send_function_usage_telemetry( + session._conn._telemetry_client.send_function_usage_telemetry( f"action_{func.__name__}", TelemetryField.FUNC_CAT_ACTION.value, api_calls=api_calls, @@ -296,6 +294,39 @@ def wrap(*args, **kwargs): return wrap +def get_plan_telemetry_metrics( + plan: Union[SnowflakePlan, Selectable] +) -> Dict[str, Any]: + data = {CompilationStageTelemetryField.PLAN_UUID.value: plan.uuid} + try: + # plan state + plan_state = plan.plan_state + data[CompilationStageTelemetryField.QUERY_PLAN_HEIGHT.value] = plan_state[ + PlanState.PLAN_HEIGHT + ] + data[ + CompilationStageTelemetryField.QUERY_PLAN_NUM_SELECTS_WITH_COMPLEXITY_MERGED.value + ] = plan_state[PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED] + data[ + CompilationStageTelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value + ] = plan_state[PlanState.NUM_CTE_NODES] + data[ + CompilationStageTelemetryField.QUERY_PLAN_DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION.value + ] = plan_state[PlanState.DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION] + + # plan complexity score + data[CompilationStageTelemetryField.QUERY_PLAN_COMPLEXITY.value] = { + key.value: value for key, value in plan.cumulative_node_complexity.items() + } + data[ + CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value + ] = get_complexity_score(plan) + except Exception as e: + data[CompilationStageTelemetryField.ERROR_MESSAGE.value] = str(e) + + return data + + class TelemetryClient: def __init__(self, conn: SnowflakeConnection) -> None: self.telemetry: PCTelemetryClient = ( From d45cc057bd51b4c9e4d197806b55c3a14bd056fa Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 13 Feb 2025 16:57:46 -0800 Subject: [PATCH 6/6] minimalize refactor --- .../snowpark/_internal/server_connection.py | 4 ++-- src/snowflake/snowpark/_internal/telemetry.py | 22 ++++++++----------- tests/integ/test_telemetry.py | 2 ++ 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 9f579977bbc..d650cda89a2 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -781,9 +781,9 @@ def get_result_set( ) if plan.session._collect_snowflake_plan_telemetry_at_critical_path: - data = get_plan_telemetry_metrics(plan) self._telemetry_client.send_plan_metrics_telemetry( - session_id=self.get_session_id(), data=data + session_id=self.get_session_id(), + data=get_plan_telemetry_metrics(plan), ) if result is None: diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 89b80eb1d25..ea4c507e04e 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -6,7 +6,7 @@ import functools import threading from enum import Enum, unique -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from snowflake.connector import SnowflakeConnection from snowflake.connector.telemetry import ( @@ -19,7 +19,6 @@ PlanState, get_complexity_score, ) -from snowflake.snowpark._internal.analyzer.select_statement import Selectable from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan from snowflake.snowpark._internal.compiler.telemetry_constants import ( CompilationStageTelemetryField, @@ -175,7 +174,6 @@ def wrap(*args, **kwargs): def df_collect_api_telemetry(func): @functools.wraps(func) def wrap(*args, **kwargs): - plan = args[0]._select_statement or args[0]._plan session = args[0]._session with session.query_history() as query_history: try: @@ -184,16 +182,15 @@ def wrap(*args, **kwargs): if not session._collect_snowflake_plan_telemetry_at_critical_path: session._conn._telemetry_client.send_plan_metrics_telemetry( session_id=session.session_id, - data=get_plan_telemetry_metrics(plan), + data=get_plan_telemetry_metrics(args[0]._plan), ) + plan = args[0]._select_statement or args[0]._plan api_calls = [ *plan.api_calls, {TelemetryField.NAME.value: f"DataFrame.{func.__name__}"}, ] # The first api call will indicate following: # - sql simplifier is enabled. - # - height of the query plan - # - number of unique duplicate subtrees in the query plan api_calls[0][ TelemetryField.SQL_SIMPLIFIER_ENABLED.value ] = session.sql_simplifier_enabled @@ -212,17 +209,17 @@ def wrap(*args, **kwargs): def dfw_collect_api_telemetry(func): @functools.wraps(func) def wrap(*args, **kwargs): - plan = args[0]._dataframe._select_statement or args[0]._dataframe._plan session = args[0]._dataframe._session with session.query_history() as query_history: try: result = func(*args, **kwargs) finally: if not session._collect_snowflake_plan_telemetry_at_critical_path: - args[0].session._conn._telemetry_client.send_plan_metrics_telemetry( + session._conn._telemetry_client.send_plan_metrics_telemetry( session_id=session.session_id, - data=get_plan_telemetry_metrics(plan), + data=get_plan_telemetry_metrics(args[0]._dataframe._plan), ) + plan = args[0]._dataframe._select_statement or args[0]._dataframe._plan api_calls = [ *plan.api_calls, {TelemetryField.NAME.value: f"DataFrameWriter.{func.__name__}"}, @@ -294,11 +291,10 @@ def wrap(*args, **kwargs): return wrap -def get_plan_telemetry_metrics( - plan: Union[SnowflakePlan, Selectable] -) -> Dict[str, Any]: - data = {CompilationStageTelemetryField.PLAN_UUID.value: plan.uuid} +def get_plan_telemetry_metrics(plan: SnowflakePlan) -> Dict[str, Any]: + data = {} try: + data[CompilationStageTelemetryField.PLAN_UUID.value] = plan.uuid # plan state plan_state = plan.plan_state data[CompilationStageTelemetryField.QUERY_PLAN_HEIGHT.value] = plan_state[ diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py index df8dd2b3c80..27fef120cec 100644 --- a/tests/integ/test_telemetry.py +++ b/tests/integ/test_telemetry.py @@ -1117,6 +1117,7 @@ def send_telemetry(): expected_data = { "session_id": session.session_id, "plan_uuid": uuid_str, + "category": "query_compilation_stage_statistics", "cte_optimization_enabled": True, "large_query_breakdown_enabled": True, "complexity_score_bounds": (300, 600), @@ -1228,6 +1229,7 @@ def test_plan_metrics_telemetry(session): client = session._conn._telemetry_client telemetry_data = { "plan_uuid": "plan_uuid_placeholder", + "category": "snowflake_plan_metrics", "query_plan_height": 10, "query_plan_num_duplicate_nodes": 5, "query_plan_num_selects_with_complexity_merged": 3,