Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1906607: Fix telemetry collection for all SnowflakePlan #2967

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ class CompilationStageTelemetryField(Enum):
"snowpark_large_query_breakdown_update_complexity_bounds"
)

# categories
CAT_COMPILATION_STAGE_STATS = "query_compilation_stage_statistics"
CAT_COMPILATION_STAGE_ERROR = "query_compilation_stage_error"
CAT_SNOWFLAKE_PLAN_METRICS = "snowflake_plan_metrics"

# keys
KEY_REASON = "reason"
PLAN_UUID = "plan_uuid"
Expand Down
11 changes: 10 additions & 1 deletion src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@
)
from snowflake.snowpark._internal.ast.utils import DATAFRAME_AST_PARAMETER
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import TelemetryClient
from snowflake.snowpark._internal.telemetry import (
TelemetryClient,
get_plan_telemetry_metrics,
)
from snowflake.snowpark._internal.utils import (
create_rlock,
create_thread_local,
Expand Down Expand Up @@ -777,6 +780,12 @@ def get_result_set(
**kwargs,
)

if plan.session._collect_snowflake_plan_telemetry_at_critical_path:
self._telemetry_client.send_plan_metrics_telemetry(
session_id=self.get_session_id(),
data=get_plan_telemetry_metrics(plan),
)

if result is None:
raise SnowparkClientExceptionMessages.SQL_LAST_QUERY_RETURN_RESULTSET()

Expand Down
121 changes: 80 additions & 41 deletions src/snowflake/snowpark/_internal/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
TelemetryField as PCTelemetryField,
)
from snowflake.connector.time_util import get_time_millis
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import PlanState
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanState,
get_complexity_score,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan
from snowflake.snowpark._internal.compiler.telemetry_constants import (
CompilationStageTelemetryField,
)
Expand Down Expand Up @@ -170,48 +174,28 @@ def wrap(*args, **kwargs):
def df_collect_api_telemetry(func):
@functools.wraps(func)
def wrap(*args, **kwargs):
with args[0]._session.query_history() as query_history:
result = func(*args, **kwargs)
session = args[0]._session
with session.query_history() as query_history:
try:
result = func(*args, **kwargs)
finally:
if not session._collect_snowflake_plan_telemetry_at_critical_path:
session._conn._telemetry_client.send_plan_metrics_telemetry(
session_id=session.session_id,
data=get_plan_telemetry_metrics(args[0]._plan),
)
plan = args[0]._select_statement or args[0]._plan
api_calls = [
*plan.api_calls,
{TelemetryField.NAME.value: f"DataFrame.{func.__name__}"},
]
# The first api call will indicate following:
# - sql simplifier is enabled.
# - height of the query plan
# - number of unique duplicate subtrees in the query plan
api_calls[0][TelemetryField.SQL_SIMPLIFIER_ENABLED.value] = args[
0
]._session.sql_simplifier_enabled
try:
plan_state = plan.plan_state
api_calls[0][
CompilationStageTelemetryField.QUERY_PLAN_HEIGHT.value
] = plan_state[PlanState.PLAN_HEIGHT]
api_calls[0][
CompilationStageTelemetryField.QUERY_PLAN_NUM_SELECTS_WITH_COMPLEXITY_MERGED.value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sfc-gh-aalam what is the kind of discrepancy you are referring to in the slack thread https://snowflake.slack.com/archives/C03MJ5AA8CS/p1738347392255399?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In our telemetry collected for job ETL logs, we have SQL statements submitted by the Snowpark client that failed with error codes we are interested in, but the corresponding telemetry for the same plan UUID does not exist in the client telemetry sent from this part of the code. This is a decorator — df_collect_api_telemetry is only applied to a select few APIs and is not applied to all functions. As a result, we will miss some important cases, like dataframe.write.save_as_table.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated, but do we think that any of the existing logic in the decorator may also be relevant to not-currently-decorated functions like save_as_table?

] = plan_state[PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED]
api_calls[0][
CompilationStageTelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value
] = plan_state[PlanState.NUM_CTE_NODES]
api_calls[0][
CompilationStageTelemetryField.QUERY_PLAN_DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION.value
] = plan_state[PlanState.DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION]

# The uuid for df._select_statement can be different from df._plan. Since plan
# can take both values, we cannot use plan.uuid. We always use df._plan.uuid
# to track the queries.
uuid = args[0]._plan.uuid
api_calls[0][CompilationStageTelemetryField.PLAN_UUID.value] = uuid
api_calls[0][CompilationStageTelemetryField.QUERY_PLAN_COMPLEXITY.value] = {
key.value: value
for key, value in plan.cumulative_node_complexity.items()
}
api_calls[0][TelemetryField.THREAD_IDENTIFIER.value] = threading.get_ident()
except Exception:
pass
args[0]._session._conn._telemetry_client.send_function_usage_telemetry(
api_calls[0][
TelemetryField.SQL_SIMPLIFIER_ENABLED.value
] = session.sql_simplifier_enabled
api_calls[0][TelemetryField.THREAD_IDENTIFIER.value] = threading.get_ident()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would threading.get_ident() throw an exception under any case? if yes, let's keep the try catch part

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a method of the built-in Python module `threading`, which is unlikely to fail. The only issue I could find relating to it throwing an exception is python/cpython#128189, which looks like it was more of a user error than a library error.

session._conn._telemetry_client.send_function_usage_telemetry(
f"action_{func.__name__}",
TelemetryField.FUNC_CAT_ACTION.value,
api_calls=api_calls,
Expand All @@ -225,16 +209,22 @@ def wrap(*args, **kwargs):
def dfw_collect_api_telemetry(func):
@functools.wraps(func)
def wrap(*args, **kwargs):
with args[0]._dataframe._session.query_history() as query_history:
result = func(*args, **kwargs)
session = args[0]._dataframe._session
with session.query_history() as query_history:
try:
result = func(*args, **kwargs)
finally:
if not session._collect_snowflake_plan_telemetry_at_critical_path:
session._conn._telemetry_client.send_plan_metrics_telemetry(
session_id=session.session_id,
data=get_plan_telemetry_metrics(args[0]._dataframe._plan),
)
plan = args[0]._dataframe._select_statement or args[0]._dataframe._plan
api_calls = [
*plan.api_calls,
{TelemetryField.NAME.value: f"DataFrameWriter.{func.__name__}"},
]
args[
0
]._dataframe._session._conn._telemetry_client.send_function_usage_telemetry(
session._conn._telemetry_client.send_function_usage_telemetry(
f"action_{func.__name__}",
TelemetryField.FUNC_CAT_ACTION.value,
api_calls=api_calls,
Expand Down Expand Up @@ -301,6 +291,38 @@ def wrap(*args, **kwargs):
return wrap


def get_plan_telemetry_metrics(plan: SnowflakePlan) -> Dict[str, Any]:
    """Collect telemetry metrics describing a SnowflakePlan.

    Best-effort: any exception raised while gathering metrics is captured
    under the error-message key rather than propagated, and whatever was
    collected before the failure is kept in the result.

    Args:
        plan: the SnowflakePlan to describe.

    Returns:
        A dict keyed by CompilationStageTelemetryField values containing the
        plan uuid, plan-state statistics, the per-node complexity breakdown,
        and the pre-compilation complexity score.
    """
    metrics: Dict[str, Any] = {}
    try:
        metrics[CompilationStageTelemetryField.PLAN_UUID.value] = plan.uuid

        # Statistics derived from the plan's state.
        state = plan.plan_state
        state_fields = (
            (
                CompilationStageTelemetryField.QUERY_PLAN_HEIGHT,
                PlanState.PLAN_HEIGHT,
            ),
            (
                CompilationStageTelemetryField.QUERY_PLAN_NUM_SELECTS_WITH_COMPLEXITY_MERGED,
                PlanState.NUM_SELECTS_WITH_COMPLEXITY_MERGED,
            ),
            (
                CompilationStageTelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES,
                PlanState.NUM_CTE_NODES,
            ),
            (
                CompilationStageTelemetryField.QUERY_PLAN_DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION,
                PlanState.DUPLICATED_NODE_COMPLEXITY_DISTRIBUTION,
            ),
        )
        for telemetry_field, state_key in state_fields:
            metrics[telemetry_field.value] = state[state_key]

        # Complexity: per node-category breakdown plus a single aggregate score.
        metrics[CompilationStageTelemetryField.QUERY_PLAN_COMPLEXITY.value] = {
            node_key.value: node_count
            for node_key, node_count in plan.cumulative_node_complexity.items()
        }
        metrics[
            CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value
        ] = get_complexity_score(plan)
    except Exception as exc:  # deliberate best-effort: never break the caller
        metrics[CompilationStageTelemetryField.ERROR_MESSAGE.value] = str(exc)

    return metrics


class TelemetryClient:
def __init__(self, conn: SnowflakeConnection) -> None:
self.telemetry: PCTelemetryClient = (
Expand Down Expand Up @@ -479,6 +501,7 @@ def send_query_compilation_summary_telemetry(
),
TelemetryField.KEY_DATA.value: {
TelemetryField.SESSION_ID.value: session_id,
TelemetryField.KEY_CATEGORY.value: CompilationStageTelemetryField.CAT_COMPILATION_STAGE_STATS.value,
CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid,
**compilation_stage_summary,
},
Expand All @@ -494,13 +517,29 @@ def send_query_compilation_stage_failed_telemetry(
),
TelemetryField.KEY_DATA.value: {
TelemetryField.SESSION_ID.value: session_id,
TelemetryField.KEY_CATEGORY.value: CompilationStageTelemetryField.CAT_COMPILATION_STAGE_ERROR.value,
CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid,
CompilationStageTelemetryField.ERROR_TYPE.value: error_type,
CompilationStageTelemetryField.ERROR_MESSAGE.value: error_message,
},
}
self.send(message)

def send_plan_metrics_telemetry(
    self, session_id: int, data: Dict[str, Any]
) -> None:
    """Send SnowflakePlan metric telemetry for the given session.

    The payload is tagged with the snowflake-plan-metrics category so it can
    be distinguished from other compilation-stage telemetry records.

    Args:
        session_id: id of the session the plan metrics belong to.
        data: plan metrics to include in the telemetry payload
            (e.g. the output of get_plan_telemetry_metrics).
    """
    payload: Dict[str, Any] = {
        TelemetryField.SESSION_ID.value: session_id,
        TelemetryField.KEY_CATEGORY.value: (
            CompilationStageTelemetryField.CAT_SNOWFLAKE_PLAN_METRICS.value
        ),
    }
    payload.update(data)

    message = dict(
        self._create_basic_telemetry_data(
            CompilationStageTelemetryField.TYPE_COMPILATION_STAGE_STATISTICS.value
        )
    )
    message[TelemetryField.KEY_DATA.value] = payload
    self.send(message)

def send_temp_table_cleanup_telemetry(
self,
session_id: str,
Expand Down
9 changes: 9 additions & 0 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@
_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION = (
"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION"
)
# Flag to control sending snowflake plan telemetry data from get_result_set
_PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION = (
"PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION"
)
# Flag for controlling the usage of scoped temp read only table.
_PYTHON_SNOWPARK_ENABLE_SCOPED_TEMP_READ_ONLY_TABLE = (
"PYTHON_SNOWPARK_ENABLE_SCOPED_TEMP_READ_ONLY_TABLE"
Expand Down Expand Up @@ -662,6 +666,11 @@ def __init__(
self._plan_lock = create_rlock(self._conn._thread_safe_session_enabled)

self._custom_package_usage_config: Dict = {}
self._collect_snowflake_plan_telemetry_at_critical_path: bool = (
self.is_feature_enabled_for_version(
_PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION
)
)
self._conf = self.RuntimeConfig(self, options or {})
self._runtime_version_from_requirement: str = None
self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self)
Expand Down
Loading
Loading