diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py index 3c64e12919..f8aa62ba31 100644 --- a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py +++ b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py @@ -59,6 +59,9 @@ RawContentBlockDeltaEvent, RawContentBlockStartEvent, RawContentBlockStopEvent, + RawMessageDeltaEvent, + RawMessageStartEvent, + RawMessageStopEvent, TextBlock, TextDelta, ThinkingBlock, @@ -462,6 +465,9 @@ def gen() -> Generator[AnthropicChatResponse, None, None]: cur_citations: List[Dict[str, Any]] = [] tracked_citations: Set[str] = set() role = MessageRole.ASSISTANT + # Track usage metadata and stop_reason from RawMessage events + usage_metadata: Dict[str, Any] = {} + stop_reason: Optional[str] = None for r in response: if isinstance(r, (ContentBlockDeltaEvent, RawContentBlockDeltaEvent)): if isinstance(r.delta, TextDelta): @@ -550,6 +556,10 @@ def gen() -> Generator[AnthropicChatResponse, None, None]: message=ChatMessage( role=role, blocks=content, + additional_kwargs={ + "usage": usage_metadata if usage_metadata else None, + "stop_reason": stop_reason, + }, ), citations=cur_citations, delta=content_delta, @@ -584,11 +594,49 @@ def gen() -> Generator[AnthropicChatResponse, None, None]: message=ChatMessage( role=role, blocks=content, + additional_kwargs={ + "usage": usage_metadata if usage_metadata else None, + "stop_reason": stop_reason, + }, ), citations=cur_citations, delta="", raw=dict(r), ) + elif isinstance(r, RawMessageStartEvent): + # Capture initial usage metadata from message_start + if hasattr(r.message, "usage") and r.message.usage: + usage_metadata = { + "input_tokens": r.message.usage.input_tokens, + "output_tokens": r.message.usage.output_tokens, + } + elif isinstance(r, RawMessageDeltaEvent): + # Update usage metadata and capture stop_reason from message_delta + if hasattr(r, "usage") and r.usage: + usage_metadata = { + "input_tokens": r.usage.input_tokens, + "output_tokens": r.usage.output_tokens, + } + if hasattr(r, "delta") and hasattr(r.delta, "stop_reason"): + stop_reason = r.delta.stop_reason + + # Yield a final chunk with updated metadata including stop_reason + yield AnthropicChatResponse( + message=ChatMessage( + role=role, + blocks=content, + additional_kwargs={ + "usage": usage_metadata if usage_metadata else None, + "stop_reason": stop_reason, + }, + ), + citations=cur_citations, + delta="", + raw=dict(r), + ) + elif isinstance(r, RawMessageStopEvent): + # Final event - no additional data to capture + pass return gen() @@ -664,6 +712,9 @@ async def gen() -> ChatResponseAsyncGen: cur_citations: List[Dict[str, Any]] = [] tracked_citations: Set[str] = set() role = MessageRole.ASSISTANT + # Track usage metadata and stop_reason from RawMessage events + usage_metadata: Dict[str, Any] = {} + stop_reason: Optional[str] = None async for r in response: if isinstance(r, (ContentBlockDeltaEvent, RawContentBlockDeltaEvent)): if isinstance(r.delta, TextDelta): @@ -752,6 +803,10 @@ async def gen() -> ChatResponseAsyncGen: message=ChatMessage( role=role, blocks=content, + additional_kwargs={ + "usage": usage_metadata if usage_metadata else None, + "stop_reason": stop_reason, + }, ), citations=cur_citations, delta=content_delta, @@ -786,11 +841,49 @@ async def gen() -> ChatResponseAsyncGen: message=ChatMessage( role=role, blocks=content, + additional_kwargs={ + "usage": usage_metadata if usage_metadata else None, + "stop_reason": stop_reason, + }, + ), + citations=cur_citations, + delta="", + raw=dict(r), + ) + elif isinstance(r, RawMessageStartEvent): + # Capture initial usage metadata from message_start + if hasattr(r.message, "usage") and r.message.usage: + usage_metadata = { + "input_tokens": r.message.usage.input_tokens, + "output_tokens": r.message.usage.output_tokens, + } + elif isinstance(r, RawMessageDeltaEvent): + # Update usage metadata and capture stop_reason from message_delta + if hasattr(r, "usage") and r.usage: + usage_metadata = { + "input_tokens": r.usage.input_tokens, + "output_tokens": r.usage.output_tokens, + } + if hasattr(r, "delta") and hasattr(r.delta, "stop_reason"): + stop_reason = r.delta.stop_reason + + # Yield a final chunk with updated metadata including stop_reason + yield AnthropicChatResponse( + message=ChatMessage( + role=role, + blocks=content, + additional_kwargs={ + "usage": usage_metadata if usage_metadata else None, + "stop_reason": stop_reason, + }, ), citations=cur_citations, delta="", raw=dict(r), ) + elif isinstance(r, RawMessageStopEvent): + # Final event - no additional data to capture + pass return gen() diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml index fb5fdd0fed..e25fbd9e93 100644 --- a/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml @@ -27,7 +27,7 @@ dev = [ [project] name = "llama-index-llms-anthropic" -version = "0.10.1" +version = "0.10.2" description = "llama-index llms anthropic integration" authors = [{name = "Your Name", email = "you@example.com"}] requires-python = ">=3.9,<4.0" diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/tests/test_llms_anthropic.py b/llama-index-integrations/llms/llama-index-llms-anthropic/tests/test_llms_anthropic.py index 96c7288e3a..60f099cdf7 100644 --- a/llama-index-integrations/llms/llama-index-llms-anthropic/tests/test_llms_anthropic.py +++ b/llama-index-integrations/llms/llama-index-llms-anthropic/tests/test_llms_anthropic.py @@ -646,3 +646,259 @@ def test_prepare_chat_with_tools_caching_unsupported_model(caplog): # Check that warning was logged assert "does not support prompt caching" in caplog.text assert "claude-2.1" in caplog.text + + +def test_stream_chat_usage_and_stop_reason_mock(): + """ + Mock test for streaming usage metadata and stop_reason - no API key required. + + This test verifies that stream_chat properly captures and yields: + - usage metadata (input_tokens, output_tokens) from RawMessageDeltaEvent + - stop_reason from RawMessageDeltaEvent + + Related to issue #20194. + """ + from unittest.mock import MagicMock + from anthropic.types import TextDelta, Usage + + # Create mock events that simulate Anthropic streaming response + mock_text_delta = MagicMock(spec=TextDelta) + mock_text_delta.text = "Hello" + mock_text_delta.type = "text_delta" + + mock_content_delta_event = MagicMock() + mock_content_delta_event.delta = mock_text_delta + mock_content_delta_event.index = 0 + + mock_content_stop_event = MagicMock() + mock_content_stop_event.index = 0 + + # Create mock RawMessageDeltaEvent with usage and stop_reason + mock_usage = MagicMock(spec=Usage) + mock_usage.input_tokens = 15 + mock_usage.output_tokens = 8 + + mock_delta = MagicMock() + mock_delta.stop_reason = "end_turn" + + mock_message_delta_event = MagicMock() + mock_message_delta_event.usage = mock_usage + mock_message_delta_event.delta = mock_delta + + # Create mock streaming response generator + def mock_stream_generator(): + from anthropic.types import ( + RawContentBlockDeltaEvent, + ContentBlockStopEvent, + RawMessageDeltaEvent, + ) + + # Simulate streaming events + yield MagicMock(spec=RawContentBlockDeltaEvent, delta=mock_text_delta, index=0) + yield MagicMock(spec=ContentBlockStopEvent, index=0) + yield MagicMock( + spec=RawMessageDeltaEvent, + usage=mock_usage, + delta=mock_delta, + ) + + # Create Anthropic LLM and mock its client + llm = Anthropic(model="claude-3-5-sonnet-latest") + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_stream_generator() + llm._client = mock_client + + # Test stream_chat + messages = [ChatMessage(role="user", content="Test message")] + stream_resp = llm.stream_chat(messages) + + # Collect all chunks + chunks = list(stream_resp) + + # Verify we got responses + assert len(chunks) > 0, "Should yield at least one chunk" + last_chunk = chunks[-1] + assert isinstance(last_chunk, AnthropicChatResponse) + + # Verify usage metadata was captured + usage = last_chunk.message.additional_kwargs.get("usage") + assert usage is not None, ( + "Usage metadata should be captured from RawMessageDeltaEvent" + ) + assert usage["input_tokens"] == 15 + assert usage["output_tokens"] == 8 + + # Verify stop_reason was captured + stop_reason = last_chunk.message.additional_kwargs.get("stop_reason") + assert stop_reason is not None, ( + "stop_reason should be captured from RawMessageDeltaEvent" + ) + assert stop_reason == "end_turn" + + +@pytest.mark.asyncio +async def test_astream_chat_usage_and_stop_reason_mock(): + """ + Mock test for async streaming usage metadata and stop_reason - no API key required. + + Async version of test_stream_chat_usage_and_stop_reason_mock. + Related to issue #20194. + """ + from unittest.mock import MagicMock, AsyncMock + from anthropic.types import TextDelta, Usage + + # Create mock events + mock_text_delta = MagicMock(spec=TextDelta) + mock_text_delta.text = "Hello async" + mock_text_delta.type = "text_delta" + + mock_usage = MagicMock(spec=Usage) + mock_usage.input_tokens = 20 + mock_usage.output_tokens = 12 + + mock_delta = MagicMock() + mock_delta.stop_reason = "max_tokens" + + # Create async mock streaming response generator + async def mock_async_stream_generator(): + from anthropic.types import ( + RawContentBlockDeltaEvent, + ContentBlockStopEvent, + RawMessageDeltaEvent, + ) + + yield MagicMock(spec=RawContentBlockDeltaEvent, delta=mock_text_delta, index=0) + yield MagicMock(spec=ContentBlockStopEvent, index=0) + yield MagicMock( + spec=RawMessageDeltaEvent, + usage=mock_usage, + delta=mock_delta, + ) + + # Create Anthropic LLM and mock its async client + llm = Anthropic(model="claude-3-5-sonnet-latest") + mock_async_client = AsyncMock() + # For async client, the create method should be an AsyncMock that returns the generator + mock_async_client.messages.create = AsyncMock( + return_value=mock_async_stream_generator() + ) + llm._aclient = mock_async_client + + # Test astream_chat + messages = [ChatMessage(role="user", content="Test async message")] + stream_resp = await llm.astream_chat(messages) + + # Collect all chunks + chunks = [] + async for chunk in stream_resp: + chunks.append(chunk) + + # Verify we got responses + assert len(chunks) > 0, "Should yield at least one chunk" + last_chunk = chunks[-1] + assert isinstance(last_chunk, AnthropicChatResponse) + + # Verify usage metadata was captured + usage = last_chunk.message.additional_kwargs.get("usage") + assert usage is not None, "Usage metadata should be captured in async streaming" + assert usage["input_tokens"] == 20 + assert usage["output_tokens"] == 12 + + # Verify stop_reason was captured + stop_reason = last_chunk.message.additional_kwargs.get("stop_reason") + assert stop_reason is not None, "stop_reason should be captured in async streaming" + assert stop_reason == "max_tokens" + + +@pytest.mark.skipif( + os.getenv("ANTHROPIC_API_KEY") is None, + reason="Anthropic API key not available to test streaming metadata", +) +def test_stream_chat_usage_and_stop_reason(): + """ + Test that streaming captures usage metadata and stop_reason from RawMessageDeltaEvent. + + This addresses issue #20194 - Anthropic RawMessageDeltaEvent support. + The streaming API should capture: + - input_tokens and output_tokens from usage metadata + - stop_reason (e.g., 'end_turn', 'max_tokens') to understand why streaming stopped + """ + llm = Anthropic(model="claude-3-5-sonnet-latest") + messages = [ + ChatMessage(role="user", content="Say hello in 3 words"), + ] + + # Stream the response + stream_resp = llm.stream_chat(messages) + last_chunk = None + for chunk in stream_resp: + last_chunk = chunk + + # Verify we got a response + assert last_chunk is not None + assert isinstance(last_chunk, AnthropicChatResponse) + + # Check that usage metadata was captured + usage = last_chunk.message.additional_kwargs.get("usage") + assert usage is not None, ( + "Usage metadata should be captured from RawMessageDeltaEvent" + ) + assert "input_tokens" in usage, "Usage should include input_tokens" + assert "output_tokens" in usage, "Usage should include output_tokens" + assert isinstance(usage["input_tokens"], int) + assert isinstance(usage["output_tokens"], int) + assert usage["input_tokens"] > 0, "Should have processed input tokens" + assert usage["output_tokens"] > 0, "Should have generated output tokens" + + # Check that stop_reason was captured + stop_reason = last_chunk.message.additional_kwargs.get("stop_reason") + assert stop_reason is not None, ( + "stop_reason should be captured from RawMessageDeltaEvent" + ) + # Typical stop reasons: "end_turn", "max_tokens", "stop_sequence", "tool_use" + assert isinstance(stop_reason, str) + print(f"Stop reason: {stop_reason}") + print(f"Usage: {usage}") + + +@pytest.mark.skipif( + os.getenv("ANTHROPIC_API_KEY") is None, + reason="Anthropic API key not available to test async streaming metadata", +) +@pytest.mark.asyncio +async def test_astream_chat_usage_and_stop_reason(): + """ + Test that async streaming captures usage metadata and stop_reason. + + Async version of the streaming metadata test for issue #20194. + """ + llm = Anthropic(model="claude-3-5-sonnet-latest") + messages = [ + ChatMessage(role="user", content="Count to 5"), + ] + + # Stream the response asynchronously + stream_resp = await llm.astream_chat(messages) + last_chunk = None + async for chunk in stream_resp: + last_chunk = chunk + + # Verify we got a response + assert last_chunk is not None + assert isinstance(last_chunk, AnthropicChatResponse) + + # Check that usage metadata was captured + usage = last_chunk.message.additional_kwargs.get("usage") + assert usage is not None, "Usage metadata should be captured in async streaming" + assert "input_tokens" in usage + assert "output_tokens" in usage + assert isinstance(usage["input_tokens"], int) + assert isinstance(usage["output_tokens"], int) + assert usage["output_tokens"] > 0 + + # Check that stop_reason was captured + stop_reason = last_chunk.message.additional_kwargs.get("stop_reason") + assert stop_reason is not None, "stop_reason should be captured in async streaming" + assert isinstance(stop_reason, str) + print(f"Async - Stop reason: {stop_reason}") + print(f"Async - Usage: {usage}")