diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 8a1cf5bd5..032701f63 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -165,8 +165,9 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: cls._initialize() assert cls._executor is not None - # Use the unique session ID as the key - key = connection.get_session_id_hex() + # Cache at HOST level - share feature flags across connections to same host + # Feature flags are per-host, not per-session + key = connection.session.host if key not in cls._context_map: cls._context_map[key] = FeatureFlagsContext( connection, cls._executor, connection.session.http_client @@ -177,7 +178,8 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: def remove_instance(cls, connection: "Connection"): """Removes the context for a given connection and shuts down the executor if no clients remain.""" with cls._lock: - key = connection.get_session_id_hex() + # Use host as key to match get_instance + key = connection.session.host if key in cls._context_map: cls._context_map.pop(key, None) diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 12cacd851..36ebee2b8 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -1,6 +1,6 @@ import time import functools -from typing import Optional +from typing import Optional, Dict, Any import logging from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.telemetry.models.event import ( @@ -11,127 +11,141 @@ logger = logging.getLogger(__name__) -class TelemetryExtractor: +def _extract_cursor_data(cursor) -> Dict[str, Any]: """ - Base class for extracting telemetry information from various object types. + Extract telemetry data directly from a Cursor object. - This class serves as a proxy that delegates attribute access to the wrapped object - while providing a common interface for extracting telemetry-related data. - """ - - def __init__(self, obj): - self._obj = obj - - def __getattr__(self, name): - return getattr(self._obj, name) - - def get_session_id_hex(self): - pass - - def get_statement_id(self): - pass - - def get_is_compressed(self): - pass - - def get_execution_result_format(self): - pass - - def get_retry_count(self): - pass - - def get_chunk_id(self): - pass + OPTIMIZATION: Uses direct attribute access instead of wrapper objects. + This eliminates object creation overhead and method call indirection. + Args: + cursor: The Cursor object to extract data from -class CursorExtractor(TelemetryExtractor): + Returns: + Dict with telemetry data (values may be None if extraction fails) """ - Telemetry extractor specialized for Cursor objects. - - Extracts telemetry information from database cursor objects, including - statement IDs, session information, compression settings, and result formats. + data = {} + + # Extract statement_id (query_id) - direct attribute access + try: + data["statement_id"] = cursor.query_id + except (AttributeError, Exception): + data["statement_id"] = None + + # Extract session_id_hex - direct method call + try: + data["session_id_hex"] = cursor.connection.get_session_id_hex() + except (AttributeError, Exception): + data["session_id_hex"] = None + + # Extract is_compressed - direct attribute access + try: + data["is_compressed"] = cursor.connection.lz4_compression + except (AttributeError, Exception): + data["is_compressed"] = False + + # Extract execution_result_format - inline logic + try: + if cursor.active_result_set is None: + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + else: + from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue + + results = cursor.active_result_set.results + if isinstance(results, ColumnQueue): + data["execution_result"] = ExecutionResultFormat.COLUMNAR_INLINE + elif isinstance(results, CloudFetchQueue): + data["execution_result"] = ExecutionResultFormat.EXTERNAL_LINKS + elif isinstance(results, ArrowQueue): + data["execution_result"] = ExecutionResultFormat.INLINE_ARROW + else: + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + except (AttributeError, Exception): + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + + # Extract retry_count - direct attribute access + try: + if hasattr(cursor.backend, "retry_policy") and cursor.backend.retry_policy: + data["retry_count"] = len(cursor.backend.retry_policy.history) + else: + data["retry_count"] = 0 + except (AttributeError, Exception): + data["retry_count"] = 0 + + # chunk_id is always None for Cursor + data["chunk_id"] = None + + return data + + +def _extract_result_set_handler_data(handler) -> Dict[str, Any]: """ + Extract telemetry data directly from a ResultSetDownloadHandler object. - def get_statement_id(self) -> Optional[str]: - return self.query_id - - def get_session_id_hex(self) -> Optional[str]: - return self.connection.get_session_id_hex() - - def get_is_compressed(self) -> bool: - return self.connection.lz4_compression - - def get_execution_result_format(self) -> ExecutionResultFormat: - if self.active_result_set is None: - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue - - if isinstance(self.active_result_set.results, ColumnQueue): - return ExecutionResultFormat.COLUMNAR_INLINE - elif isinstance(self.active_result_set.results, CloudFetchQueue): - return ExecutionResultFormat.EXTERNAL_LINKS - elif isinstance(self.active_result_set.results, ArrowQueue): - return ExecutionResultFormat.INLINE_ARROW - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - def get_retry_count(self) -> int: - if hasattr(self.backend, "retry_policy") and self.backend.retry_policy: - return len(self.backend.retry_policy.history) - return 0 - - def get_chunk_id(self): - return None + OPTIMIZATION: Uses direct attribute access instead of wrapper objects. + Args: + handler: The ResultSetDownloadHandler object to extract data from -class ResultSetDownloadHandlerExtractor(TelemetryExtractor): - """ - Telemetry extractor specialized for ResultSetDownloadHandler objects. + Returns: + Dict with telemetry data (values may be None if extraction fails) """ + data = {} - def get_session_id_hex(self) -> Optional[str]: - return self._obj.session_id_hex + # Extract session_id_hex - direct attribute access + try: + data["session_id_hex"] = handler.session_id_hex + except (AttributeError, Exception): + data["session_id_hex"] = None - def get_statement_id(self) -> Optional[str]: - return self._obj.statement_id + # Extract statement_id - direct attribute access + try: + data["statement_id"] = handler.statement_id + except (AttributeError, Exception): + data["statement_id"] = None - def get_is_compressed(self) -> bool: - return self._obj.settings.is_lz4_compressed + # Extract is_compressed - direct attribute access + try: + data["is_compressed"] = handler.settings.is_lz4_compressed + except (AttributeError, Exception): + data["is_compressed"] = False - def get_execution_result_format(self) -> ExecutionResultFormat: - return ExecutionResultFormat.EXTERNAL_LINKS + # execution_result is always EXTERNAL_LINKS for result set handlers + data["execution_result"] = ExecutionResultFormat.EXTERNAL_LINKS - def get_retry_count(self) -> Optional[int]: - # standard requests and urllib3 libraries don't expose retry count - return None + # retry_count is not available for result set handlers + data["retry_count"] = None + + # Extract chunk_id - direct attribute access + try: + data["chunk_id"] = handler.chunk_id + except (AttributeError, Exception): + data["chunk_id"] = None - def get_chunk_id(self) -> Optional[int]: - return self._obj.chunk_id + return data -def get_extractor(obj): +def _extract_telemetry_data(obj) -> Optional[Dict[str, Any]]: """ - Factory function to create the appropriate telemetry extractor for an object. + Extract telemetry data from an object based on its type. - Determines the object type and returns the corresponding specialized extractor - that can extract telemetry information from that object type. + OPTIMIZATION: Returns a simple dict instead of creating wrapper objects. + This dict will be used to create the SqlExecutionEvent in the background thread. Args: - obj: The object to create an extractor for. Can be a Cursor, - ResultSetDownloadHandler, or any other object. + obj: The object to extract data from (Cursor, ResultSetDownloadHandler, etc.) Returns: - TelemetryExtractor: A specialized extractor instance: - - CursorExtractor for Cursor objects - - ResultSetDownloadHandlerExtractor for ResultSetDownloadHandler objects - - None for all other objects + Dict with telemetry data, or None if object type is not supported """ - if obj.__class__.__name__ == "Cursor": - return CursorExtractor(obj) - elif obj.__class__.__name__ == "ResultSetDownloadHandler": - return ResultSetDownloadHandlerExtractor(obj) + obj_type = obj.__class__.__name__ + + if obj_type == "Cursor": + return _extract_cursor_data(obj) + elif obj_type == "ResultSetDownloadHandler": + return _extract_result_set_handler_data(obj) else: - logger.debug("No extractor found for %s", obj.__class__.__name__) + logger.debug("No telemetry extraction available for %s", obj_type) return None @@ -143,12 +157,6 @@ def log_latency(statement_type: StatementType = StatementType.NONE): data about the operation, including latency, statement information, and execution context. - The decorator automatically: - - Measures execution time using high-precision performance counters - - Extracts telemetry information from the method's object (self) - - Creates a SqlExecutionEvent with execution details - - Sends the telemetry data asynchronously via TelemetryClient - Args: statement_type (StatementType): The type of SQL statement being executed. @@ -162,54 +170,49 @@ def execute(self, query): function: A decorator that wraps methods to add latency logging. Note: - The wrapped method's object (self) must be compatible with the - telemetry extractor system (e.g., Cursor or ResultSet objects). + The wrapped method's object (self) must be a Cursor or + ResultSetDownloadHandler for telemetry data extraction. """ def decorator(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): - start_time = time.perf_counter() - result = None + start_time = time.monotonic() try: - result = func(self, *args, **kwargs) - return result + return func(self, *args, **kwargs) finally: - - def _safe_call(func_to_call): - """Calls a function and returns a default value on any exception.""" - try: - return func_to_call() - except Exception: - return None - - end_time = time.perf_counter() - duration_ms = int((end_time - start_time) * 1000) - - extractor = get_extractor(self) - - if extractor is not None: - session_id_hex = _safe_call(extractor.get_session_id_hex) - statement_id = _safe_call(extractor.get_statement_id) - - sql_exec_event = SqlExecutionEvent( - statement_type=statement_type, - is_compressed=_safe_call(extractor.get_is_compressed), - execution_result=_safe_call( - extractor.get_execution_result_format - ), - retry_count=_safe_call(extractor.get_retry_count), - chunk_id=_safe_call(extractor.get_chunk_id), - ) - - telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex - ) - telemetry_client.export_latency_log( - latency_ms=duration_ms, - sql_execution_event=sql_exec_event, - sql_statement_id=statement_id, - ) + duration_ms = int((time.monotonic() - start_time) * 1000) + + # Always log for debugging + logger.debug("%s completed in %dms", func.__name__, duration_ms) + + # Fast check: use cached telemetry_enabled flag from connection + # Avoids dictionary lookup + instance check on every operation + connection = getattr(self, "connection", None) + if connection and getattr(connection, "telemetry_enabled", False): + session_id_hex = connection.get_session_id_hex() + if session_id_hex: + # Telemetry enabled - extract and send + telemetry_data = _extract_telemetry_data(self) + if telemetry_data: + sql_exec_event = SqlExecutionEvent( + statement_type=statement_type, + is_compressed=telemetry_data.get("is_compressed"), + execution_result=telemetry_data.get("execution_result"), + retry_count=telemetry_data.get("retry_count"), + chunk_id=telemetry_data.get("chunk_id"), + ) + + telemetry_client = ( + TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + ) + telemetry_client.export_latency_log( + latency_ms=duration_ms, + sql_execution_event=sql_exec_event, + sql_statement_id=telemetry_data.get("statement_id"), + ) return wrapper diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 177d5445c..d5f5b575c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -2,6 +2,7 @@ import time import logging import json +from queue import Queue, Full from concurrent.futures import ThreadPoolExecutor from concurrent.futures import Future from datetime import datetime, timezone @@ -114,18 +115,21 @@ def get_auth_flow(auth_provider): @staticmethod def is_telemetry_enabled(connection: "Connection") -> bool: + # Fast path: force enabled - skip feature flag fetch entirely if connection.force_enable_telemetry: return True - if connection.enable_telemetry: - context = FeatureFlagsContextFactory.get_instance(connection) - flag_value = context.get_flag_value( - TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False - ) - return str(flag_value).lower() == "true" - else: + # Fast path: disabled - no need to check feature flag + if not connection.enable_telemetry: return False + # Only fetch feature flags when enable_telemetry=True and not forced + context = FeatureFlagsContextFactory.get_instance(connection) + flag_value = context.get_flag_value( + TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False + ) + return str(flag_value).lower() == "true" + class NoopTelemetryClient(BaseTelemetryClient): """ @@ -185,8 +189,11 @@ def __init__( self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None - self._events_batch: list = [] - self._lock = threading.RLock() + + # OPTIMIZATION: Use lock-free Queue instead of list + lock + # Queue is thread-safe internally and has better performance under concurrency + self._events_queue: Queue[TelemetryFrontendLog] = Queue(maxsize=batch_size * 2) + self._driver_connection_params = None self._host_url = host_url self._executor = executor @@ -196,7 +203,8 @@ def __init__( # Create telemetry push client based on circuit breaker enabled flag if client_context.telemetry_circuit_breaker_enabled: - # Create circuit breaker telemetry push client (circuit breakers created on-demand) + # Create circuit breaker telemetry push client + # (circuit breakers created on-demand) self._telemetry_push_client: ITelemetryPushClient = ( CircuitBreakerTelemetryPushClient( TelemetryPushClient(self._http_client), @@ -210,9 +218,24 @@ def __init__( def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._session_id_hex) - with self._lock: - self._events_batch.append(event) - if len(self._events_batch) >= self._batch_size: + + # OPTIMIZATION: Use non-blocking put with queue + # No explicit lock needed - Queue is thread-safe internally + try: + self._events_queue.put_nowait(event) + except Full: + # Queue is full, trigger immediate flush + logger.debug("Event queue full, triggering flush") + self._flush() + # Try again after flush + try: + self._events_queue.put_nowait(event) + except Full: + # Still full, drop event (acceptable for telemetry) + logger.debug("Dropped telemetry event - queue still full") + + # Check if we should flush based on queue size + if self._events_queue.qsize() >= self._batch_size: logger.debug( "Batch size limit reached (%s), flushing events", self._batch_size ) @@ -220,9 +243,16 @@ def _export_event(self, event): def _flush(self): """Flush the current batch of events to the server""" - with self._lock: - events_to_flush = self._events_batch.copy() - self._events_batch = [] + # OPTIMIZATION: Drain queue without locks + # Collect all events currently in the queue + events_to_flush = [] + while not self._events_queue.empty(): + try: + event = self._events_queue.get_nowait() + events_to_flush.append(event) + except: + # Queue is empty + break if events_to_flush: logger.debug("Flushing %s telemetry events to server", len(events_to_flush)) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 6f5a01c7b..96a2f87d8 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -10,6 +10,10 @@ TelemetryClientFactory, TelemetryHelper, ) +from databricks.sql.common.feature_flag import ( + FeatureFlagsContextFactory, + FeatureFlagsContext, +) from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow, DatabricksClientType from databricks.sql.telemetry.models.event import ( TelemetryEvent, @@ -82,12 +86,12 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client): client._export_event("event1") client._export_event("event2") mock_send.assert_not_called() - assert len(client._events_batch) == 2 + assert client._events_queue.qsize() == 2 # Third event should trigger flush client._export_event("event3") mock_send.assert_called_once() - assert len(client._events_batch) == 0 # Batch cleared after flush + assert client._events_queue.qsize() == 0 # Queue cleared after flush @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") def test_network_request_flow(self, mock_http_request, mock_telemetry_client): @@ -817,7 +821,67 @@ def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_sess mock_export.assert_called_once() driver_params = mock_export.call_args.kwargs.get("driver_connection_params") - + # CF proxy not yet supported - should be False/None assert driver_params.use_cf_proxy is False assert driver_params.cf_proxy_host_info is None + + +class TestFeatureFlagsContextFactory: + """Tests for FeatureFlagsContextFactory host-level caching.""" + + @pytest.fixture(autouse=True) + def reset_factory(self): + """Reset factory state before/after each test.""" + FeatureFlagsContextFactory._context_map.clear() + if FeatureFlagsContextFactory._executor: + FeatureFlagsContextFactory._executor.shutdown(wait=False) + FeatureFlagsContextFactory._executor = None + yield + FeatureFlagsContextFactory._context_map.clear() + if FeatureFlagsContextFactory._executor: + FeatureFlagsContextFactory._executor.shutdown(wait=False) + FeatureFlagsContextFactory._executor = None + + @pytest.mark.parametrize( + "hosts,expected_contexts", + [ + (["host1.com", "host1.com"], 1), # Same host shares context + (["host1.com", "host2.com"], 2), # Different hosts get separate contexts + (["host1.com", "host1.com", "host2.com"], 2), # Mixed scenario + ], + ) + def test_host_level_caching(self, hosts, expected_contexts): + """Test that contexts are cached by host correctly.""" + contexts = [] + for host in hosts: + conn = MagicMock() + conn.session.host = host + conn.session.http_client = MagicMock() + contexts.append(FeatureFlagsContextFactory.get_instance(conn)) + + assert len(FeatureFlagsContextFactory._context_map) == expected_contexts + if expected_contexts == 1: + assert all(ctx is contexts[0] for ctx in contexts) + + def test_remove_instance_and_executor_cleanup(self): + """Test removal uses host key and cleans up executor when empty.""" + conn1 = MagicMock() + conn1.session.host = "host1.com" + conn1.session.http_client = MagicMock() + + conn2 = MagicMock() + conn2.session.host = "host2.com" + conn2.session.http_client = MagicMock() + + FeatureFlagsContextFactory.get_instance(conn1) + FeatureFlagsContextFactory.get_instance(conn2) + assert FeatureFlagsContextFactory._executor is not None + + FeatureFlagsContextFactory.remove_instance(conn1) + assert len(FeatureFlagsContextFactory._context_map) == 1 + assert FeatureFlagsContextFactory._executor is not None + + FeatureFlagsContextFactory.remove_instance(conn2) + assert len(FeatureFlagsContextFactory._context_map) == 0 + assert FeatureFlagsContextFactory._executor is None