diff --git a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
index fd3721d2..af2b10cb 100644
--- a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
+++ b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
@@ -59,14 +59,39 @@ def _format_image(image_url: str) -> Dict:
     }
 
 
-def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
-    role = _message_type_lookups[message.type]
+def _get_cache_control(message: BaseMessage) -> Optional[Dict[str, Any]]:
+    """Extract cache control from a message's additional_kwargs, if present."""
+    return (
+        message.additional_kwargs.get("cache_control")
+        if isinstance(message.additional_kwargs, dict)
+        else None
+    )
+
+
+def _format_text_content(text: str) -> Dict[str, Union[str, Dict[str, Any]]]:
+    """Format a string as an Anthropic text content block."""
+    content: Dict[str, Union[str, Dict[str, Any]]] = {"type": "text", "text": text}
+    return content
+
+
+def _format_message_anthropic(message: Union[HumanMessage, AIMessage, SystemMessage]):
+    """Format a message for the Anthropic API.
+
+    Args:
+        message: The message to format. Can be HumanMessage, AIMessage, or SystemMessage.
+
+    Returns:
+        A formatted dict, a list of content blocks for system messages, or None if empty.
+    """  # noqa: E501
     content: List[Dict[str, Any]] = []
 
     if isinstance(message.content, str):
         if not message.content.strip():
             return None
-        content.append({"type": "text", "text": message.content})
+        message_dict = _format_text_content(message.content)
+        if cache_control := _get_cache_control(message):
+            message_dict["cache_control"] = cache_control
+        content.append(message_dict)
     elif isinstance(message.content, list):
         for block in message.content:
             if isinstance(block, str):
@@ -75,9 +100,8 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
 
                 # https://github.com/anthropics/anthropic-sdk-python/issues/461
                 if not block.strip():
                     continue
-                content.append({"type": "text", "text": block})
-
-            if isinstance(block, dict):
+                content.append(_format_text_content(block))
+            elif isinstance(block, dict):
                 if "type" not in block:
                     raise ValueError("Dict content block must have a type key")
@@ -113,25 +137,26 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
                     if not is_unique:
                         continue
 
-                # all other block types
                 content.append(block)
     else:
         raise ValueError("Message should be a str, list of str or list of dicts")
 
-    # adding all tool calls
    if isinstance(message, AIMessage) and message.tool_calls:
         for tc in message.tool_calls:
             tu = cast(Dict[str, Any], _lc_tool_call_to_anthropic_tool_use_block(tc))
             content.append(tu)
 
-    return {"role": role, "content": content}
+    if message.type == "system":
+        return content
+    else:
+        return {"role": _message_type_lookups[message.type], "content": content}
 
 
 def _format_messages_anthropic(
     messages: List[BaseMessage],
-) -> Tuple[Optional[str], List[Dict]]:
+) -> Tuple[Optional[List[Dict[str, Any]]], List[Dict]]:
     """Formats messages for anthropic."""
-    system_message: Optional[str] = None
+    system_messages: Optional[List[Dict[str, Any]]] = None
     formatted_messages: List[Dict] = []
 
     merged_messages = _merge_messages(messages)
@@ -139,12 +164,9 @@
         if message.type == "system":
             if i != 0:
                 raise ValueError("System message must be at beginning of message list.")
-            if not isinstance(message.content, str):
-                raise ValueError(
-                    "System message must be a string, "
-                    f"instead was: {type(message.content)}"
-                )
-            system_message = message.content
+            fm = _format_message_anthropic(message)
+            if fm:
+                system_messages = fm
             continue
 
         fm = _format_message_anthropic(message)
@@ -152,7 +174,7 @@
             continue
         formatted_messages.append(fm)
 
-    return system_message, formatted_messages
+    return system_messages, formatted_messages
 
 
 class AnthropicTool(TypedDict):
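Taken together, these changes swap the system-prompt contract from a plain string to a list of content blocks, so `cache_control` annotations survive formatting. A minimal sketch of the new contract — illustrative only, not part of the patch, mirroring the unit tests further down:

```python
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_google_vertexai._anthropic_utils import _format_messages_anthropic

messages = [
    SystemMessage(
        content="You are my personal assistant.",
        additional_kwargs={"cache_control": {"type": "ephemeral"}},
    ),
    HumanMessage(content="Hello"),
]
system_blocks, formatted = _format_messages_anthropic(messages)

# The system prompt is now a list of content blocks rather than a bare
# string, so the cache_control annotation is preserved:
assert system_blocks == [
    {
        "type": "text",
        "text": "You are my personal assistant.",
        "cache_control": {"type": "ephemeral"},
    }
]
# Non-system messages keep the familiar {role, content} envelope:
assert formatted == [
    {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
]
```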
diff --git a/libs/vertexai/langchain_google_vertexai/model_garden.py b/libs/vertexai/langchain_google_vertexai/model_garden.py
index 9070ac05..fd7938d6 100644
--- a/libs/vertexai/langchain_google_vertexai/model_garden.py
+++ b/libs/vertexai/langchain_google_vertexai/model_garden.py
@@ -32,6 +32,7 @@
     AIMessage,
     BaseMessage,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
@@ -61,6 +62,13 @@
 from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon
 
 
+class CacheUsageMetadata(UsageMetadata):
+    cache_creation_input_tokens: Optional[int]
+    """The number of input tokens used to create the cache entry."""
+    cache_read_input_tokens: Optional[int]
+    """The number of input tokens read from the cache."""
+
+
 class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
     """Large language models served from Vertex AI Model Garden."""
 
@@ -225,11 +233,13 @@ def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
         else:
             msg = AIMessage(content=content)
         # Collect token usage
-        msg.usage_metadata = {
-            "input_tokens": data.usage.input_tokens,
-            "output_tokens": data.usage.output_tokens,
-            "total_tokens": data.usage.input_tokens + data.usage.output_tokens,
-        }
+        msg.usage_metadata = CacheUsageMetadata(
+            input_tokens=data.usage.input_tokens,
+            output_tokens=data.usage.output_tokens,
+            total_tokens=data.usage.input_tokens + data.usage.output_tokens,
+            cache_creation_input_tokens=data.usage.cache_creation_input_tokens,
+            cache_read_input_tokens=data.usage.cache_read_input_tokens,
+        )
         return ChatResult(
             generations=[ChatGeneration(message=msg)],
             llm_output=llm_output,
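`CacheUsageMetadata` extends LangChain's `UsageMetadata` TypedDict, so the two Anthropic cache counters ride along on the message's ordinary `usage_metadata`. A hedged consumption sketch — the project ID below is a placeholder, and the counters stay at zero or absent until a request actually creates or reads a cache entry:

```python
from langchain_google_vertexai.model_garden import ChatAnthropicVertex

# "my-project" is a placeholder; any reachable Vertex AI project works.
llm = ChatAnthropicVertex(project="my-project", location="us-central1")
response = llm.invoke("Hello!", model_name="claude-3-5-sonnet-v2@20241022")

usage = response.usage_metadata or {}
print(usage["input_tokens"], usage["output_tokens"], usage["total_tokens"])
# The Anthropic-specific counters surfaced by CacheUsageMetadata:
print(usage.get("cache_creation_input_tokens"))
print(usage.get("cache_read_input_tokens"))
```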
What can you do for me?") + + response = model.invoke( + [context, message], model_name="claude-3-5-sonnet-v2@20241022" + ) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "usage_metadata" in response.additional_kwargs + assert "cache_creation_input_tokens" in response.additional_kwargs["usage_metadata"] + + +@pytest.mark.extended +def test_anthropic_mixed_cache() -> None: + """Test chat with different cache control types.""" + project = os.environ["PROJECT_ID"] + location = "us-central1" + model = ChatAnthropicVertex( + project=project, + location=location, + ) + + context = SystemMessage( + content=[ + { + "type": "text", + "text": "You are my personal assistant.", + "cache_control": {"type": "ephemeral"}, + } + ] + ) + message = HumanMessage( + content=[ + { + "type": "text", + "text": "What's your name and what can you help me with?", + "cache_control": {"type": "ephemeral"}, + } + ] + ) + + response = model.invoke( + [context, message], model_name="claude-3-5-sonnet-v2@20241022" + ) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "usage_metadata" in response.additional_kwargs + + +@pytest.mark.extended +def test_anthropic_conversation_cache() -> None: + """Test chat conversation with cache control.""" + project = os.environ["PROJECT_ID"] + location = "us-central1" + model = ChatAnthropicVertex( + project=project, + location=location, + ) + + context = SystemMessage( + content="You are my personal assistant. My name is Peter.", + additional_kwargs={"cache_control": {"type": "ephemeral"}}, + ) + messages = [ + context, + HumanMessage( + content=[ + { + "type": "text", + "text": "What's my name?", + "cache_control": {"type": "ephemeral"}, + } + ] + ), + AIMessage(content="Your name is Peter."), + HumanMessage( + content=[ + { + "type": "text", + "text": "Can you repeat my name?", + "cache_control": {"type": "ephemeral"}, + } + ] + ), + ] + + response = model.invoke(messages, model_name="claude-3-5-sonnet-v2@20241022") + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "peter" in response.content.lower() # Should remember the name + + +@pytest.mark.extended +def test_anthropic_chat_template_cache() -> None: + """Test chat template with structured content and cache control.""" + project = os.environ["PROJECT_ID"] + location = "us-central1" + model = ChatAnthropicVertex( + project=project, + location=location, + ) + + content: List[Union[Dict[str, Union[str, Dict[str, str]]], str]] = [ + { + "text": "You are a helpful assistant. 
Be concise and clear.", + "type": "text", + "cache_control": {"type": "ephemeral"}, + } + ] + + prompt = ChatPromptTemplate.from_messages( + [SystemMessage(content=content), ("human", "{input}")] + ) + + chain = prompt | model + + response = chain.invoke( + {"input": "What's the capital of France?"}, + ) + + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "Paris" in response.content diff --git a/libs/vertexai/tests/unit_tests/test_anthropic_utils.py b/libs/vertexai/tests/unit_tests/test_anthropic_utils.py index 7f83bfb3..5e493823 100644 --- a/libs/vertexai/tests/unit_tests/test_anthropic_utils.py +++ b/libs/vertexai/tests/unit_tests/test_anthropic_utils.py @@ -1,3 +1,5 @@ +"""Unit tests for _anthropic_utils.py.""" + import pytest from langchain_core.messages import ( AIMessage, @@ -7,7 +9,260 @@ ) from langchain_core.messages.tool import tool_call as create_tool_call -from langchain_google_vertexai.model_garden import _format_messages_anthropic +from langchain_google_vertexai._anthropic_utils import ( + _format_message_anthropic, + _format_messages_anthropic, +) + + +def test_format_message_anthropic_with_cache_control_in_kwargs(): + """Test formatting a message with cache control in additional_kwargs.""" + message = HumanMessage( + content="Hello", additional_kwargs={"cache_control": {"type": "semantic"}} + ) + result = _format_message_anthropic(message) + assert result == { + "role": "user", + "content": [ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}} + ], + } + + +def test_format_message_anthropic_with_cache_control_in_block(): + """Test formatting a message with cache control in content block.""" + message = HumanMessage( + content=[ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}} + ] + ) + result = _format_message_anthropic(message) + assert result == { + "role": "user", + "content": [ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}} + ], + } + + +def test_format_message_anthropic_with_mixed_blocks(): + """Test formatting a message with mixed blocks, some with cache control.""" + message = HumanMessage( + content=[ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}}, + {"type": "text", "text": "World"}, + "Plain text", + ] + ) + result = _format_message_anthropic(message) + assert result == { + "role": "user", + "content": [ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}}, + {"type": "text", "text": "World"}, + {"type": "text", "text": "Plain text"}, + ], + } + + +def test_format_messages_anthropic_with_system_cache_control(): + """Test formatting messages with system message having cache control.""" + messages = [ + SystemMessage( + content="System message", + additional_kwargs={"cache_control": {"type": "ephemeral"}}, + ), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + { + "type": "text", + "text": "System message", + "cache_control": {"type": "ephemeral"}, + } + ] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def test_format_message_anthropic_system(): + """Test formatting a system message.""" + message = SystemMessage( + content="System message", + additional_kwargs={"cache_control": {"type": "ephemeral"}}, + ) + result = _format_message_anthropic(message) + assert result == [ + { + "type": "text", + "text": "System message", + "cache_control": 
{"type": "ephemeral"}, + } + ] + + +def test_format_message_anthropic_system_list(): + """Test formatting a system message with list content.""" + message = SystemMessage( + content=[ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + ) + result = _format_message_anthropic(message) + assert result == [ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + + +def test_format_messages_anthropic_with_system_string(): + """Test formatting messages with system message as string.""" + messages = [ + SystemMessage(content="System message"), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [{"type": "text", "text": "System message"}] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def test_format_messages_anthropic_with_system_list(): + """Test formatting messages with system message as a list.""" + messages = [ + SystemMessage( + content=[ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + ), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def test_format_messages_anthropic_with_system_mixed_list(): + """Test formatting messages with system message as a mixed list.""" + messages = [ + SystemMessage( + content=[ + "Plain system rule", + { + "type": "text", + "text": "Formatted system rule", + "cache_control": {"type": "ephemeral"}, + }, + ] + ), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + {"type": "text", "text": "Plain system rule"}, + { + "type": "text", + "text": "Formatted system rule", + "cache_control": {"type": "ephemeral"}, + }, + ] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def test_format_messages_anthropic_with_mixed_messages(): + """Test formatting a conversation with various message types and cache controls.""" + messages = [ + SystemMessage( + content=[ + { + "type": "text", + "text": "System message", + "cache_control": {"type": "ephemeral"}, + } + ] + ), + HumanMessage( + content=[ + { + "type": "text", + "text": "Human message", + "cache_control": {"type": "semantic"}, + } + ] + ), + AIMessage( + content="AI response", + additional_kwargs={"cache_control": {"type": "semantic"}}, + ), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + { + "type": "text", + "text": "System message", + "cache_control": {"type": "ephemeral"}, + } + ] + + assert formatted_messages == [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Human message", + "cache_control": {"type": "semantic"}, + } + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "AI response", + "cache_control": {"type": "semantic"}, + } + ], + }, + ] 
 
 
 @pytest.mark.parametrize(
@@ -113,7 +368,7 @@
                     content="Mike age is 30",
                 ),
             ],
-            "test1",
+            [{"type": "text", "text": "test1"}],
             [
                 {
                     "role": "assistant",
@@ -473,6 +728,7 @@
 def test_format_messages_anthropic(
     source_history, expected_sm, expected_history
 ) -> None:
+    """Test the original format_messages_anthropic functionality."""
     sm, result_history = _format_messages_anthropic(source_history)
 
     for result, expected in zip(result_history, expected_history):
diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py
index 06567e71..81e9eebf 100644
--- a/libs/vertexai/tests/unit_tests/test_chat_models.py
+++ b/libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -1077,6 +1077,8 @@ def test_anthropic_format_output() -> None:
     class Usage:
         input_tokens: int
         output_tokens: int
+        cache_creation_input_tokens: Optional[int]
+        cache_read_input_tokens: Optional[int]
 
     @dataclass
     class Message:
@@ -1092,13 +1094,25 @@ def model_dump(self):
                 ],
                 "model": "baz",
                 "role": "assistant",
-                "usage": Usage(input_tokens=2, output_tokens=1),
+                "usage": Usage(
+                    input_tokens=2,
+                    output_tokens=1,
+                    cache_creation_input_tokens=1,
+                    cache_read_input_tokens=1,
+                ),
                 "type": "message",
             }
 
         usage: Usage
 
-    test_msg = Message(usage=Usage(input_tokens=2, output_tokens=1))
+    test_msg = Message(
+        usage=Usage(
+            input_tokens=2,
+            output_tokens=1,
+            cache_creation_input_tokens=1,
+            cache_read_input_tokens=1,
+        )
+    )
 
     model = ChatAnthropicVertex(project="test-project", location="test-location")
     result = model._format_output(test_msg)
@@ -1113,4 +1127,6 @@ def model_dump(self):
         "input_tokens": 2,
         "output_tokens": 1,
         "total_tokens": 3,
+        "cache_creation_input_tokens": 1,
+        "cache_read_input_tokens": 1,
     }