From 31fe259741c458027c72a4e13250e0b6213ec021 Mon Sep 17 00:00:00 2001
From: wingding12
Date: Tue, 27 Jan 2026 11:39:19 -0500
Subject: [PATCH] feat: add LiteLLM callback handler for all providers

Fixes #1028, #1079

Implements a LiteLLM callback handler that properly tracks LLM calls
across all LiteLLM-supported providers, including OpenAI and Anthropic.

Root Cause Analysis:
- #1079: Anthropic models weren't being tracked because there was no
  LiteLLM callback handler, only direct provider instrumentation
- #1028: The Responses API wasn't supported because there was no
  callback handler to intercept litellm.responses() calls

Solution:
- Created LiteLLMCallbackHandler extending litellm's CustomLogger
- Properly extracts the provider from model strings (anthropic/, openai/, etc.)
- Handles both completion API (choices) and responses API (output) formats
- Records token usage with input_tokens/output_tokens mapping
- Creates proper spans with gen_ai.* semantic conventions

Features:
- Works with litellm.callbacks = [handler]
- Supports all LiteLLM providers (OpenAI, Anthropic, Google, Mistral, etc.)
- Handles both sync and async API calls
- Records prompt/completion tokens and cost
- Properly nests LLM spans under the session span

Files Changed:
- agentops/integration/callbacks/litellm/__init__.py (new)
- agentops/integration/callbacks/litellm/callback.py (new)
- agentops/__init__.py (export LiteLLMCallbackHandler)
- tests/unit/integration/callbacks/litellm/test_litellm_callback.py (new)
---
 agentops/__init__.py                          |   5 +
 .../integration/callbacks/litellm/__init__.py |   9 +
 .../integration/callbacks/litellm/callback.py | 409 ++++++++++++++++++
 tests/unit/integration/__init__.py            |   0
 tests/unit/integration/callbacks/__init__.py  |   0
 .../integration/callbacks/litellm/__init__.py |   1 +
 .../litellm/test_litellm_callback.py          | 264 +++++++++++
 7 files changed, 688 insertions(+)
 create mode 100644 agentops/integration/callbacks/litellm/__init__.py
 create mode 100644 agentops/integration/callbacks/litellm/callback.py
 create mode 100644 tests/unit/integration/__init__.py
 create mode 100644 tests/unit/integration/callbacks/__init__.py
 create mode 100644 tests/unit/integration/callbacks/litellm/__init__.py
 create mode 100644 tests/unit/integration/callbacks/litellm/test_litellm_callback.py

diff --git a/agentops/__init__.py b/agentops/__init__.py
index 816e77443..449f9a004 100755
--- a/agentops/__init__.py
+++ b/agentops/__init__.py
@@ -37,6 +37,9 @@
 # Import validation functions
 from agentops.validation import validate_trace_spans, print_validation_summary, ValidationError
 
+# Import callback handlers for external integrations
+from agentops.integration.callbacks.litellm import LiteLLMCallbackHandler
+
 # Thread-safe client management
 _client_lock = threading.Lock()
 _client = None
@@ -485,4 +488,6 @@ def extract_key_from_attr(attr_value: str) -> str:
     "validate_trace_spans",
     "print_validation_summary",
     "ValidationError",
+    # Callback handlers for external integrations
+    "LiteLLMCallbackHandler",
 ]
diff --git a/agentops/integration/callbacks/litellm/__init__.py b/agentops/integration/callbacks/litellm/__init__.py
new file mode 100644
index 000000000..32b239198
--- /dev/null
+++ b/agentops/integration/callbacks/litellm/__init__.py
@@ -0,0 +1,9 @@
+"""
+LiteLLM callback handler for AgentOps.
+
+This module provides the LiteLLM callback handler for AgentOps tracing and monitoring.
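+
+A minimal usage sketch (assumes litellm is installed and a valid AgentOps API
+key is configured via argument or environment; the model name is illustrative):
+
+    import litellm
+    from agentops.integration.callbacks.litellm import LiteLLMCallbackHandler
+
+    handler = LiteLLMCallbackHandler()
+    litellm.callbacks = [handler]
+    litellm.completion(
+        model="openai/gpt-4o",
+        messages=[{"role": "user", "content": "Hello"}],
+    )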
+""" + +from agentops.integration.callbacks.litellm.callback import LiteLLMCallbackHandler + +__all__ = ["LiteLLMCallbackHandler"] diff --git a/agentops/integration/callbacks/litellm/callback.py b/agentops/integration/callbacks/litellm/callback.py new file mode 100644 index 000000000..4164f048a --- /dev/null +++ b/agentops/integration/callbacks/litellm/callback.py @@ -0,0 +1,409 @@ +""" +LiteLLM callback handler for AgentOps. + +This module provides the LiteLLM callback handler for AgentOps tracing and monitoring. +It handles both completion and responses API calls across all LiteLLM-supported providers +including OpenAI, Anthropic, and others. + +Usage: + from agentops.integration.callbacks.litellm import LiteLLMCallbackHandler + import litellm + + handler = LiteLLMCallbackHandler(api_key="your-api-key") + litellm.callbacks = [handler] + + # Or use the string-based callback registration (after importing agentops) + import agentops + agentops.init() + litellm.success_callback = ["agentops"] +""" + +from typing import Any, Dict, List, Optional, Union +from datetime import datetime + +from opentelemetry import trace +from opentelemetry.context import attach, detach +from opentelemetry.trace import set_span_in_context +from opentelemetry.sdk.trace import Span as SDKSpan + +from agentops.helpers.serialization import safe_serialize +from agentops.logging import logger +from agentops.sdk.core import tracer +from agentops.semconv import SpanKind, SpanAttributes, AgentOpsSpanKindValues + +try: + from litellm.integrations.custom_logger import CustomLogger +except ImportError: + # Create a stub class if litellm is not installed + class CustomLogger: + pass + + +class LiteLLMCallbackHandler(CustomLogger): + """ + AgentOps callback handler for LiteLLM. + + This handler creates spans for LLM calls made through LiteLLM, supporting + all providers (OpenAI, Anthropic, etc.) including both completion and + responses API calls. 
+
+    Args:
+        api_key (str, optional): AgentOps API key
+        tags (List[str], optional): Tags to add to the session
+        auto_session (bool, optional): Whether to automatically create a session span
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        tags: Optional[List[str]] = None,
+        auto_session: bool = True,
+    ):
+        """Initialize the callback handler."""
+        super().__init__()
+        self.active_spans: Dict[str, SDKSpan] = {}
+        self.api_key = api_key
+        self.tags = tags or []
+        self.session_span = None
+        self.session_token = None
+        self.context_tokens: Dict[str, Any] = {}
+
+        # Initialize AgentOps
+        if auto_session:
+            self._initialize_agentops()
+
+    def _initialize_agentops(self):
+        """Initialize AgentOps if not already initialized."""
+        import agentops
+
+        if not tracer.initialized:
+            init_kwargs = {
+                "auto_start_session": False,
+                "instrument_llm_calls": False,  # We handle LLM calls via callback
+            }
+
+            if self.api_key:
+                init_kwargs["api_key"] = self.api_key
+
+            agentops.init(**init_kwargs)
+            logger.debug("AgentOps initialized from LiteLLM callback handler")
+
+        if not tracer.initialized:
+            logger.warning("AgentOps not initialized, session span will not be created")
+            return
+
+        otel_tracer = tracer.get_tracer()
+
+        span_name = f"session.{SpanKind.SESSION}"
+
+        attributes = {
+            SpanAttributes.AGENTOPS_SPAN_KIND: SpanKind.SESSION,
+            "session.tags": self.tags,
+            "agentops.operation.name": "session",
+            "span.kind": SpanKind.SESSION,
+        }
+
+        # Create a root session span
+        self.session_span = otel_tracer.start_span(span_name, attributes=attributes)
+
+        # Attach session span to the current context
+        self.session_token = attach(set_span_in_context(self.session_span))
+
+        logger.debug("Created session span as root span for LiteLLM")
+
+    def _get_call_id(self, kwargs: Dict[str, Any]) -> str:
+        """Generate a unique call ID for tracking spans."""
+        litellm_call_id = kwargs.get("litellm_call_id", "")
+        if litellm_call_id:
+            return str(litellm_call_id)
+        # Fall back to a combination of the model name and the kwargs object's id
+        model = kwargs.get("model", "unknown")
+        return f"{model}_{id(kwargs)}"
+
+    def _extract_provider(self, model: str) -> str:
+        """Extract the provider from the model string."""
+        if "/" in model:
+            return model.split("/")[0]
+        # Default provider mapping based on common model names
+        model_lower = model.lower()
+        if "claude" in model_lower:
+            return "anthropic"
+        elif "gpt" in model_lower or "o1" in model_lower:
+            return "openai"
+        elif "gemini" in model_lower:
+            return "google"
+        elif "mistral" in model_lower or "mixtral" in model_lower:
+            return "mistral"
+        elif "command" in model_lower:
+            return "cohere"
+        return "unknown"
+
+    def _create_span(
+        self,
+        call_id: str,
+        model: str,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Optional[SDKSpan]:
+        """Create a span for an LLM call."""
+        if not tracer.initialized:
+            logger.debug("Tracer not initialized, skipping span creation")
+            return None
+
+        otel_tracer = tracer.get_tracer()
+        provider = self._extract_provider(model)
+
+        span_name = f"litellm.{provider}.{model.replace('/', '_')}"
+
+        attributes: Dict[str, Any] = {
+            SpanAttributes.AGENTOPS_SPAN_KIND: AgentOpsSpanKindValues.LLM.value,
+            "agentops.operation.name": "llm_call",
+            "gen_ai.system": provider,
+            "gen_ai.request.model": model,
+            "litellm.provider": provider,
+        }
+
+        # Add input messages if available
+        if messages:
+            try:
+                for i, msg in enumerate(messages):
+                    if isinstance(msg, dict):
+                        role = msg.get("role", "unknown")
+                        content = msg.get("content", "")
+                        attributes[f"gen_ai.prompt.{i}.role"] = str(role)
+                        if content:
+                            attributes[f"gen_ai.prompt.{i}.content"] = safe_serialize(content)[:1000]
+            except Exception as e:
+                logger.debug(f"Failed to extract messages: {e}")
+
+        # Add additional kwargs
+        if kwargs:
+            if "temperature" in kwargs:
+                attributes["gen_ai.request.temperature"] = kwargs["temperature"]
+            if "max_tokens" in kwargs:
+                attributes["gen_ai.request.max_tokens"] = kwargs["max_tokens"]
+            if "top_p" in kwargs:
+                attributes["gen_ai.request.top_p"] = kwargs["top_p"]
+
+        # Create span with parent context
+        if self.session_span:
+            parent_ctx = set_span_in_context(self.session_span)
+            span = otel_tracer.start_span(span_name, context=parent_ctx, attributes=attributes)
+        else:
+            span = otel_tracer.start_span(span_name, attributes=attributes)
+
+        if isinstance(span, SDKSpan):
+            self.active_spans[call_id] = span
+            token = attach(set_span_in_context(span))
+            self.context_tokens[call_id] = token
+
+        return span
+
+    def _end_span(
+        self,
+        call_id: str,
+        response_obj: Any = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+        exception: Optional[Exception] = None,
+    ):
+        """End a span for an LLM call."""
+        if call_id not in self.active_spans:
+            return
+
+        span = self.active_spans.pop(call_id)
+        token = self.context_tokens.pop(call_id, None)
+
+        try:
+            # Add response attributes
+            if response_obj:
+                # Handle different response types
+                if hasattr(response_obj, "model"):
+                    span.set_attribute("gen_ai.response.model", str(response_obj.model))
+
+                # Extract usage information
+                usage = None
+                if hasattr(response_obj, "usage") and response_obj.usage:
+                    usage = response_obj.usage
+                elif isinstance(response_obj, dict) and "usage" in response_obj:
+                    usage = response_obj["usage"]
+
+                if usage:
+                    if hasattr(usage, "prompt_tokens"):
+                        span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, usage.prompt_tokens)
+                    elif isinstance(usage, dict) and "prompt_tokens" in usage:
+                        span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, usage["prompt_tokens"])
+                    elif isinstance(usage, dict) and "input_tokens" in usage:
+                        span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, usage["input_tokens"])
+
+                    if hasattr(usage, "completion_tokens"):
+                        span.set_attribute(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, usage.completion_tokens)
+                    elif isinstance(usage, dict) and "completion_tokens" in usage:
+                        span.set_attribute(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, usage["completion_tokens"])
+                    elif isinstance(usage, dict) and "output_tokens" in usage:
+                        span.set_attribute(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, usage["output_tokens"])
+
+                    if hasattr(usage, "total_tokens"):
+                        span.set_attribute(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, usage.total_tokens)
+                    elif isinstance(usage, dict) and "total_tokens" in usage:
+                        span.set_attribute(SpanAttributes.LLM_USAGE_TOTAL_TOKENS, usage["total_tokens"])
+
+                # Extract completion content
+                choices = None
+                if hasattr(response_obj, "choices"):
+                    choices = response_obj.choices
+                elif isinstance(response_obj, dict) and "choices" in response_obj:
+                    choices = response_obj["choices"]
+
+                # Handle responses API format (output instead of choices)
+                output = None
+                if hasattr(response_obj, "output"):
+                    output = response_obj.output
+                elif isinstance(response_obj, dict) and "output" in response_obj:
+                    output = response_obj["output"]
+
+                if choices:
+                    for i, choice in enumerate(choices):
+                        try:
+                            message = choice.message if hasattr(choice, "message") else choice.get("message", {})
+                            if message:
+                                role = message.role if hasattr(message, "role") else message.get("role", "assistant")
+                                content = message.content if hasattr(message, "content") else message.get("content", "")
+                                span.set_attribute(f"gen_ai.completion.{i}.role", str(role))
+                                if content:
+                                    span.set_attribute(f"gen_ai.completion.{i}.content", str(content)[:1000])
+
+                                # Handle finish reason
+                                finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else choice.get("finish_reason")
+                                if finish_reason:
+                                    span.set_attribute(f"gen_ai.completion.{i}.finish_reason", str(finish_reason))
+                        except Exception as e:
+                            logger.debug(f"Failed to extract choice {i}: {e}")
+                elif output:
+                    # Handle responses API format
+                    for i, item in enumerate(output if isinstance(output, list) else [output]):
+                        try:
+                            if hasattr(item, "content"):
+                                content = item.content
+                            elif isinstance(item, dict):
+                                content = item.get("content", item.get("text", ""))
+                            else:
+                                content = str(item)
+
+                            if content:
+                                # Handle content that's a list of content blocks
+                                if isinstance(content, list):
+                                    text_parts = []
+                                    for block in content:
+                                        if isinstance(block, dict) and "text" in block:
+                                            text_parts.append(block["text"])
+                                        elif hasattr(block, "text"):
+                                            text_parts.append(block.text)
+                                    content = " ".join(text_parts)
+
+                                span.set_attribute(f"gen_ai.completion.{i}.content", str(content)[:1000])
+                        except Exception as e:
+                            logger.debug(f"Failed to extract output {i}: {e}")
+
+            # Add cost information if available
+            if kwargs and "response_cost" in kwargs:
+                span.set_attribute("gen_ai.usage.cost", kwargs["response_cost"])
+
+            # Handle exception
+            if exception:
+                span.record_exception(exception)
+                span.set_attribute("error.type", type(exception).__name__)
+                span.set_attribute("error.message", str(exception))
+
+        except Exception as e:
+            logger.warning(f"Error setting span attributes: {e}")
+
+        # Detach context and end span
+        if token:
+            detach(token)
+
+        try:
+            span.end()
+        except Exception as e:
+            logger.warning(f"Error ending span: {e}")
+
+    # LiteLLM CustomLogger interface methods
+
+    def log_pre_api_call(self, model: str, messages: List[Dict[str, Any]], kwargs: Dict[str, Any]):
+        """Called before an API call is made."""
+        call_id = self._get_call_id(kwargs)
+        self._create_span(call_id, model, messages, kwargs)
+        logger.debug(f"LiteLLM pre-API call: model={model}, call_id={call_id}")
+
+    def log_post_api_call(self, kwargs: Dict[str, Any], response_obj: Any, start_time: datetime, end_time: datetime):
+        """Called after an API call completes (success or failure)."""
+        # Nothing to do here; success and failure are handled separately
+        pass
+
+    def log_success_event(self, kwargs: Dict[str, Any], response_obj: Any, start_time: datetime, end_time: datetime):
+        """Called when an API call succeeds."""
+        call_id = self._get_call_id(kwargs)
+        self._end_span(call_id, response_obj, kwargs)
+        logger.debug(f"LiteLLM success: call_id={call_id}")
+
+    def log_failure_event(self, kwargs: Dict[str, Any], response_obj: Any, start_time: datetime, end_time: datetime):
+        """Called when an API call fails."""
+        call_id = self._get_call_id(kwargs)
+        exception = kwargs.get("exception", None)
+        self._end_span(call_id, response_obj, kwargs, exception)
+        logger.debug(f"LiteLLM failure: call_id={call_id}")
+
+    async def async_log_pre_api_call(self, model: str, messages: List[Dict[str, Any]], kwargs: Dict[str, Any]):
+        """Async version of log_pre_api_call."""
+        self.log_pre_api_call(model, messages, kwargs)
+
+    async def async_log_success_event(self, kwargs: Dict[str, Any], response_obj: Any, start_time: datetime, end_time: datetime):
+        """Async version of log_success_event."""
+        self.log_success_event(kwargs, response_obj, start_time, end_time)
+
+    async def async_log_failure_event(self, kwargs: Dict[str, Any], response_obj: Any, start_time: datetime, end_time: datetime):
+        """Async version of log_failure_event."""
+        self.log_failure_event(kwargs, response_obj, start_time, end_time)
+
+    def end_session(self):
+        """End the session span and clean up resources."""
+        # End any remaining active spans
+        for call_id in list(self.active_spans.keys()):
+            self._end_span(call_id)
+
+        # End session span
+        if self.session_span:
+            try:
+                self.session_span.end()
+            except Exception as e:
+                logger.warning(f"Error ending session span: {e}")
+            self.session_span = None
+
+        # Detach session token
+        if self.session_token:
+            try:
+                detach(self.session_token)
+            except Exception as e:
+                logger.warning(f"Error detaching session token: {e}")
+            self.session_token = None
+
+
+# Register as a string callback for litellm.success_callback = ["agentops"]
+def _register_litellm_callback():
+    """Register the AgentOps callback with LiteLLM's callback registry."""
+    try:
+        import litellm
+
+        # Check if agentops is already registered
+        if hasattr(litellm, "_known_custom_logger_compatible_callbacks"):
+            if "agentops" not in litellm._known_custom_logger_compatible_callbacks:
+                litellm._known_custom_logger_compatible_callbacks["agentops"] = LiteLLMCallbackHandler
+
+        logger.debug("Registered AgentOps callback with LiteLLM")
+    except ImportError:
+        pass
+    except Exception as e:
+        logger.debug(f"Could not register LiteLLM callback: {e}")
+
+
+# Auto-register when module is imported
+_register_litellm_callback()
diff --git a/tests/unit/integration/__init__.py b/tests/unit/integration/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit/integration/callbacks/__init__.py b/tests/unit/integration/callbacks/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit/integration/callbacks/litellm/__init__.py b/tests/unit/integration/callbacks/litellm/__init__.py
new file mode 100644
index 000000000..6872d9d70
--- /dev/null
+++ b/tests/unit/integration/callbacks/litellm/__init__.py
@@ -0,0 +1 @@
+"""Tests for LiteLLM callback handler."""
diff --git a/tests/unit/integration/callbacks/litellm/test_litellm_callback.py b/tests/unit/integration/callbacks/litellm/test_litellm_callback.py
new file mode 100644
index 000000000..106f9e8c1
--- /dev/null
+++ b/tests/unit/integration/callbacks/litellm/test_litellm_callback.py
@@ -0,0 +1,264 @@
+"""
+Tests for LiteLLM callback handler.
+
+Tests the LiteLLMCallbackHandler for proper span creation, attribute extraction,
+and handling of both completion and responses API calls.
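+
+The handler's hooks are driven directly, with the module-level tracer mocked
+as in the fixtures below. A minimal sketch of the pattern used throughout
+(call id and model values are illustrative):
+
+    handler = LiteLLMCallbackHandler(auto_session=False)
+    handler.log_pre_api_call(
+        "openai/gpt-4o",
+        [{"role": "user", "content": "Hi"}],
+        {"litellm_call_id": "call-1"},
+    )
+    assert "call-1" in handler.active_spans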
+""" + +import pytest +from unittest.mock import MagicMock, patch +from datetime import datetime + +from agentops.integration.callbacks.litellm.callback import LiteLLMCallbackHandler +from agentops.semconv import SpanAttributes, AgentOpsSpanKindValues + + +class TestLiteLLMCallbackHandler: + """Tests for the LiteLLM callback handler.""" + + @pytest.fixture + def mock_tracer(self): + """Mock the tracer for testing.""" + with patch("agentops.integration.callbacks.litellm.callback.tracer") as mock: + mock.initialized = True + mock_otel_tracer = MagicMock() + mock.get_tracer.return_value = mock_otel_tracer + yield mock + + @pytest.fixture + def handler(self, mock_tracer): + """Create a handler with mocked dependencies.""" + with patch("agentops.integration.callbacks.litellm.callback.tracer", mock_tracer): + handler = LiteLLMCallbackHandler(auto_session=False) + yield handler + + def test_handler_initialization(self): + """Test that the handler initializes correctly.""" + with patch("agentops.integration.callbacks.litellm.callback.tracer") as mock_tracer: + mock_tracer.initialized = False + + handler = LiteLLMCallbackHandler( + api_key="test-key", + tags=["test-tag"], + auto_session=False, + ) + + assert handler.api_key == "test-key" + assert handler.tags == ["test-tag"] + assert handler.active_spans == {} + assert handler.context_tokens == {} + + def test_extract_provider_from_model_string(self, handler): + """Test provider extraction from model strings.""" + assert handler._extract_provider("anthropic/claude-3-5-sonnet-20240620") == "anthropic" + assert handler._extract_provider("openai/gpt-4o") == "openai" + assert handler._extract_provider("google/gemini-pro") == "google" + assert handler._extract_provider("gpt-4o") == "openai" + assert handler._extract_provider("claude-3-sonnet") == "anthropic" + assert handler._extract_provider("gemini-pro") == "google" + assert handler._extract_provider("mixtral-8x7b") == "mistral" + assert handler._extract_provider("command-r") == "cohere" + assert handler._extract_provider("unknown-model") == "unknown" + + def test_get_call_id(self, handler): + """Test call ID generation.""" + # With litellm_call_id + kwargs_with_id = {"litellm_call_id": "test-call-123"} + assert handler._get_call_id(kwargs_with_id) == "test-call-123" + + # Without litellm_call_id (fallback) + kwargs_without_id = {"model": "gpt-4"} + call_id = handler._get_call_id(kwargs_without_id) + assert "gpt-4" in call_id + + def test_log_pre_api_call_creates_span(self, handler, mock_tracer): + """Test that log_pre_api_call creates a span.""" + mock_span = MagicMock() + mock_tracer.get_tracer().start_span.return_value = mock_span + + model = "anthropic/claude-3-5-sonnet-20240620" + messages = [{"role": "user", "content": "Hello"}] + kwargs = {"litellm_call_id": "test-123", "temperature": 0.7} + + handler.log_pre_api_call(model, messages, kwargs) + + # Verify span was created and stored + assert "test-123" in handler.active_spans + mock_tracer.get_tracer().start_span.assert_called_once() + + def test_log_success_event_ends_span(self, handler, mock_tracer): + """Test that log_success_event ends the span correctly.""" + mock_span = MagicMock() + mock_tracer.get_tracer().start_span.return_value = mock_span + + # First create a span + model = "gpt-4o" + messages = [{"role": "user", "content": "Hello"}] + kwargs = {"litellm_call_id": "test-123"} + + handler.log_pre_api_call(model, messages, kwargs) + + # Create mock response + mock_response = MagicMock() + mock_response.model = "gpt-4o" + 
mock_response.usage = MagicMock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 20 + mock_response.usage.total_tokens = 30 + mock_response.choices = [ + MagicMock( + message=MagicMock(role="assistant", content="Hi there!"), + finish_reason="stop" + ) + ] + + # Now end the span + handler.log_success_event( + kwargs, + mock_response, + datetime.now(), + datetime.now(), + ) + + # Verify span was ended + assert "test-123" not in handler.active_spans + mock_span.end.assert_called_once() + + def test_log_failure_event_records_exception(self, handler, mock_tracer): + """Test that log_failure_event records the exception.""" + mock_span = MagicMock() + mock_tracer.get_tracer().start_span.return_value = mock_span + + # First create a span + model = "gpt-4o" + messages = [{"role": "user", "content": "Hello"}] + kwargs = {"litellm_call_id": "test-123"} + + handler.log_pre_api_call(model, messages, kwargs) + + # Create exception + test_exception = ValueError("API Error") + kwargs["exception"] = test_exception + + # Now handle failure + handler.log_failure_event( + kwargs, + None, + datetime.now(), + datetime.now(), + ) + + # Verify exception was recorded + mock_span.record_exception.assert_called_once_with(test_exception) + + def test_handles_responses_api_format(self, handler, mock_tracer): + """Test handling of responses API format (output instead of choices).""" + mock_span = MagicMock() + mock_tracer.get_tracer().start_span.return_value = mock_span + + # Create a span + model = "gpt-4o" + messages = [{"role": "user", "content": "Hello"}] + kwargs = {"litellm_call_id": "test-123"} + + handler.log_pre_api_call(model, messages, kwargs) + + # Create responses API format response + mock_response = MagicMock() + mock_response.model = "gpt-4o" + mock_response.usage = {"input_tokens": 10, "output_tokens": 20} + mock_response.choices = None + mock_response.output = [ + {"content": [{"type": "text", "text": "Hello from responses API!"}]} + ] + + handler.log_success_event( + kwargs, + mock_response, + datetime.now(), + datetime.now(), + ) + + # Verify span attributes were set + mock_span.set_attribute.assert_any_call( + SpanAttributes.LLM_USAGE_PROMPT_TOKENS, 10 + ) + mock_span.set_attribute.assert_any_call( + SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, 20 + ) + + def test_handles_anthropic_model(self, handler, mock_tracer): + """Test that Anthropic models are properly handled.""" + mock_span = MagicMock() + mock_tracer.get_tracer().start_span.return_value = mock_span + + model = "anthropic/claude-3-5-sonnet-20240620" + messages = [{"role": "user", "content": "Hello"}] + kwargs = {"litellm_call_id": "test-123"} + + handler.log_pre_api_call(model, messages, kwargs) + + # Verify provider was extracted correctly + call_args = mock_tracer.get_tracer().start_span.call_args + attributes = call_args[1].get("attributes", {}) + + assert attributes.get("litellm.provider") == "anthropic" + assert attributes.get("gen_ai.system") == "anthropic" + + def test_handles_openai_model(self, handler, mock_tracer): + """Test that OpenAI models are properly handled.""" + mock_span = MagicMock() + mock_tracer.get_tracer().start_span.return_value = mock_span + + model = "gpt-4o" + messages = [{"role": "user", "content": "Hello"}] + kwargs = {"litellm_call_id": "test-123"} + + handler.log_pre_api_call(model, messages, kwargs) + + # Verify provider was extracted correctly + call_args = mock_tracer.get_tracer().start_span.call_args + attributes = call_args[1].get("attributes", {}) + + assert 
attributes.get("litellm.provider") == "openai" + assert attributes.get("gen_ai.system") == "openai" + + def test_end_session_cleans_up(self, handler, mock_tracer): + """Test that end_session properly cleans up resources.""" + mock_span = MagicMock() + mock_tracer.get_tracer().start_span.return_value = mock_span + + # Create some spans + handler.log_pre_api_call("gpt-4", [], {"litellm_call_id": "test-1"}) + handler.log_pre_api_call("gpt-4", [], {"litellm_call_id": "test-2"}) + + assert len(handler.active_spans) == 2 + + # End session + handler.end_session() + + # Verify cleanup + assert len(handler.active_spans) == 0 + assert len(handler.context_tokens) == 0 + + +class TestLiteLLMCallbackIntegration: + """Integration-style tests for LiteLLM callback handler.""" + + def test_callback_handler_can_be_imported(self): + """Test that the callback handler can be imported from agentops.""" + from agentops import LiteLLMCallbackHandler + + assert LiteLLMCallbackHandler is not None + + def test_callback_inherits_from_custom_logger(self): + """Test that the callback handler inherits from CustomLogger if available.""" + try: + from litellm.integrations.custom_logger import CustomLogger + from agentops.integration.callbacks.litellm.callback import LiteLLMCallbackHandler + + # Should inherit from CustomLogger + assert issubclass(LiteLLMCallbackHandler, CustomLogger) + except ImportError: + # LiteLLM not installed, skip this test + pytest.skip("LiteLLM not installed")