diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 9d285b75e9f..081c8a60f25 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -247,11 +247,16 @@ def get_result_attributes(self, query: str) -> List[Attribute]: self._run_new_describe(self._cursor, query), self.max_string_size ) - def _run_new_describe(self, cursor: SnowflakeCursor, query: str) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]: + def _run_new_describe( + self, cursor: SnowflakeCursor, query: str + ) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]: result_metadata = run_new_describe(cursor, query) - describe_listeners = [listener for listener in self._query_listener if hasattr(listener, 'describe_listener')] - for listener in describe_listeners: + for listener in filter( + lambda listener: hasattr(listener, "include_describe") + and listener.include_describe, + self._query_listener, + ): listener._add_query(QueryRecord(cursor.sfqid, query, True)) return result_metadata diff --git a/src/snowflake/snowpark/query_history.py b/src/snowflake/snowpark/query_history.py index 73a0f9b09ae..69a3e3aa48d 100644 --- a/src/snowflake/snowpark/query_history.py +++ b/src/snowflake/snowpark/query_history.py @@ -22,9 +22,14 @@ class QueryHistory: :meth:`snowflake.snowpark.Session.query_history`. """ - def __init__(self, session: "snowflake.snowpark.session.Session") -> None: + def __init__( + self, + session: "snowflake.snowpark.session.Session", + include_describe: bool = False, + ) -> None: self.session = session self._queries: List[QueryRecord] = [] + self._include_describe = include_describe def __enter__(self): return self @@ -38,3 +43,7 @@ def _add_query(self, query_record: QueryRecord): @property def queries(self) -> List[QueryRecord]: return self._queries + + @property + def include_describe(self) -> bool: + return self._include_describe diff --git a/tests/integ/modin/sql_counter.py b/tests/integ/modin/sql_counter.py index 3734b5b4ba8..6c254ddc5be 100644 --- a/tests/integ/modin/sql_counter.py +++ b/tests/integ/modin/sql_counter.py @@ -158,6 +158,11 @@ def __init__( # Add SqlCounter as a snowpark query listener. self.session._conn.add_query_listener(self) + # The query history listener will include describe queries if this is true. + @property + def include_describe(self) -> bool: + return True + @staticmethod def set_record_mode(record_mode): """Record mode means the SqlCounter does not assert any results, but rather collects them so they can @@ -187,10 +192,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): def _add_query(self, query_record: QueryRecord): self._queries.append(query_record) - # This attribute signals we also want to collect describe queries. - def describe_listener(self): - pass - def expects(self, **kwargs): """ Validate expectation of sql counts. We avoid using asserts because we do not want to interrupt the @@ -274,8 +275,12 @@ def _get_actual_queries(self): for fw in FILTER_OUT_QUERIES ] ), - list(map(lambda q: q.sql_text, - [q for q in self._queries if not q.is_describe])), + list( + map( + lambda q: q.sql_text, + [q for q in self._queries if not q.is_describe], + ) + ), ) )