From 72e7155d61cf4291535b1f499b9e3211c786082b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Clgesuellip=E2=80=9D?= <“lgesuellipinto@uade.edu.ar”>
Date: Fri, 3 Jan 2025 14:05:38 -0300
Subject: [PATCH 1/6] Add Anthropic cache functionality

---
 .../_anthropic_utils.py                       |  61 ++--
 .../langchain_google_vertexai/model_garden.py |  20 +-
 .../integration_tests/test_anthropic_cache.py | 138 ++++++++++
 .../tests/unit_tests/test_anthropic_utils.py  | 260 +++++++++++++++++-
 .../tests/unit_tests/test_chat_models.py      |  20 +-
 5 files changed, 472 insertions(+), 27 deletions(-)
 create mode 100644 libs/vertexai/tests/integration_tests/test_anthropic_cache.py

diff --git a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
index fd3721d2..d32a7005 100644
--- a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
+++ b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
@@ -59,14 +59,42 @@ def _format_image(image_url: str) -> Dict:
     }
 
 
-def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
-    role = _message_type_lookups[message.type]
+def _get_cache_control(message: BaseMessage) -> Optional[Dict[str, Any]]:
+    """Extract cache control from a message's additional_kwargs, if present."""
+    return (
+        message.additional_kwargs.get("cache_control")
+        if isinstance(message.additional_kwargs, dict)
+        else None
+    )
+
+
+def _format_text_content(
+    text: str, cache_control: Optional[Dict[str, Any]] = None
+) -> Dict[str, Union[str, Dict[str, Any]]]:
+    """Format text content with optional cache control."""
+    content: Dict[str, Union[str, Dict[str, Any]]] = {"type": "text", "text": text}
+    if cache_control:
+        content["cache_control"] = cache_control
+    return content
+
+
+def _format_message_anthropic(message: Union[HumanMessage, AIMessage, SystemMessage]):
+    """Format a message for the Anthropic API.
+
+    Args:
+        message: The message to format. Can be HumanMessage, AIMessage, or SystemMessage.
+
+    Returns:
+        A dict with role and content for user/assistant messages, a list of content blocks for system messages, or None if the message is empty.
+ """ # noqa: E501 content: List[Dict[str, Any]] = [] if isinstance(message.content, str): if not message.content.strip(): return None - content.append({"type": "text", "text": message.content}) + content.append( + _format_text_content(message.content, _get_cache_control(message)) + ) elif isinstance(message.content, list): for block in message.content: if isinstance(block, str): @@ -75,9 +103,8 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage]): # https://github.com/anthropics/anthropic-sdk-python/issues/461 if not block.strip(): continue - content.append({"type": "text", "text": block}) - - if isinstance(block, dict): + content.append(_format_text_content(block)) + elif isinstance(block, dict): if "type" not in block: raise ValueError("Dict content block must have a type key") @@ -113,25 +140,26 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage]): if not is_unique: continue - # all other block types content.append(block) else: raise ValueError("Message should be a str, list of str or list of dicts") - # adding all tool calls if isinstance(message, AIMessage) and message.tool_calls: for tc in message.tool_calls: tu = cast(Dict[str, Any], _lc_tool_call_to_anthropic_tool_use_block(tc)) content.append(tu) - return {"role": role, "content": content} + if message.type == "system": + return content + else: + return {"role": _message_type_lookups[message.type], "content": content} def _format_messages_anthropic( messages: List[BaseMessage], -) -> Tuple[Optional[str], List[Dict]]: +) -> Tuple[Optional[Dict[str, Any]], List[Dict]]: """Formats messages for anthropic.""" - system_message: Optional[str] = None + system_messages: Optional[Dict[str, Any]] = None formatted_messages: List[Dict] = [] merged_messages = _merge_messages(messages) @@ -139,12 +167,9 @@ def _format_messages_anthropic( if message.type == "system": if i != 0: raise ValueError("System message must be at beginning of message list.") - if not isinstance(message.content, str): - raise ValueError( - "System message must be a string, " - f"instead was: {type(message.content)}" - ) - system_message = message.content + fm = _format_message_anthropic(message) + if fm: + system_messages = fm continue fm = _format_message_anthropic(message) @@ -152,7 +177,7 @@ def _format_messages_anthropic( continue formatted_messages.append(fm) - return system_message, formatted_messages + return system_messages, formatted_messages class AnthropicTool(TypedDict): diff --git a/libs/vertexai/langchain_google_vertexai/model_garden.py b/libs/vertexai/langchain_google_vertexai/model_garden.py index 9070ac05..fd7938d6 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden.py +++ b/libs/vertexai/langchain_google_vertexai/model_garden.py @@ -32,6 +32,7 @@ AIMessage, BaseMessage, ) +from langchain_core.messages.ai import UsageMetadata from langchain_core.outputs import ( ChatGeneration, ChatGenerationChunk, @@ -61,6 +62,13 @@ from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon +class CacheUsageMetadata(UsageMetadata): + cache_creation_input_tokens: Optional[int] + """The number of input tokens used to create the cache entry.""" + cache_read_input_tokens: Optional[int] + """The number of input tokens read from the cache.""" + + class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM): """Large language models served from Vertex AI Model Garden.""" @@ -225,11 +233,13 @@ def _format_output(self, data: Any, **kwargs: Any) -> ChatResult: else: msg = AIMessage(content=content) 
# Collect token usage - msg.usage_metadata = { - "input_tokens": data.usage.input_tokens, - "output_tokens": data.usage.output_tokens, - "total_tokens": data.usage.input_tokens + data.usage.output_tokens, - } + msg.usage_metadata = CacheUsageMetadata( + input_tokens=data.usage.input_tokens, + output_tokens=data.usage.output_tokens, + total_tokens=data.usage.input_tokens + data.usage.output_tokens, + cache_creation_input_tokens=data.usage.cache_creation_input_tokens, + cache_read_input_tokens=data.usage.cache_read_input_tokens, + ) return ChatResult( generations=[ChatGeneration(message=msg)], llm_output=llm_output, diff --git a/libs/vertexai/tests/integration_tests/test_anthropic_cache.py b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py new file mode 100644 index 00000000..c423f645 --- /dev/null +++ b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py @@ -0,0 +1,138 @@ +"""Integration tests for Anthropic cache control functionality.""" + +import os + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_core.prompts import ChatPromptTemplate + +from langchain_google_vertexai.model_garden import ChatAnthropicVertex + + +def test_anthropic_system_cache() -> None: + """Test chat with system message having cache control.""" + project = os.environ["PROJECT_ID"] + location = "us-central1" + model = ChatAnthropicVertex( + project=project, + location=location, + ) + + context = SystemMessage( + content="You are my personal assistant. Be helpful and concise.", + additional_kwargs={"cache_control": {"type": "ephemeral"}}, + ) + message = HumanMessage(content="Hello! What can you do for me?") + + response = model.invoke([context, message], model_name="claude-3-sonnet@latest") + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "usage_metadata" in response.additional_kwargs + assert "cache_creation_input_tokens" in response.additional_kwargs["usage_metadata"] + + +def test_anthropic_mixed_cache() -> None: + """Test chat with different cache control types.""" + project = os.environ["PROJECT_ID"] + location = "us-central1" + model = ChatAnthropicVertex( + project=project, + location=location, + ) + + context = SystemMessage( + content=[ + { + "type": "text", + "text": "You are my personal assistant.", + "cache_control": {"type": "ephemeral"}, + } + ] + ) + message = HumanMessage( + content=[ + { + "type": "text", + "text": "What's your name and what can you help me with?", + "cache_control": {"type": "semantic"}, + } + ] + ) + + response = model.invoke([context, message], model_name="claude-3-sonnet@latest") + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "usage_metadata" in response.additional_kwargs + + +def test_anthropic_conversation_cache() -> None: + """Test chat conversation with cache control.""" + project = os.environ["PROJECT_ID"] + location = "us-central1" + model = ChatAnthropicVertex( + project=project, + location=location, + ) + + context = SystemMessage( + content="You are my personal assistant. 
My name is Peter.", + additional_kwargs={"cache_control": {"type": "ephemeral"}}, + ) + messages = [ + context, + HumanMessage( + content=[ + { + "type": "text", + "text": "What's my name?", + "cache_control": {"type": "semantic"}, + } + ] + ), + AIMessage(content="Your name is Peter."), + HumanMessage( + content=[ + { + "type": "text", + "text": "Can you repeat my name?", + "cache_control": {"type": "semantic"}, + } + ] + ), + ] + + response = model.invoke(messages, model_name="claude-3-sonnet@latest") + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "peter" in response.content.lower() # Should remember the name + + +def test_anthropic_chat_template_cache() -> None: + """Test chat template with structured content and cache control.""" + project = os.environ["PROJECT_ID"] + location = "us-central1" + model = ChatAnthropicVertex( + project=project, + location=location, + ) + + content: list[dict[str, str | dict[str, str]] | str] = [ + { + "text": "You are a helpful assistant. Be concise and clear.", + "type": "text", + "cache_control": {"type": "ephemeral"}, + } + ] + + prompt = ChatPromptTemplate.from_messages( + [SystemMessage(content=content), ("human", "{input}")] + ) + + chain = prompt | model + + response = chain.invoke( + {"input": "What's the capital of France?"}, + ) + + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "Paris" in response.content diff --git a/libs/vertexai/tests/unit_tests/test_anthropic_utils.py b/libs/vertexai/tests/unit_tests/test_anthropic_utils.py index 7f83bfb3..5e493823 100644 --- a/libs/vertexai/tests/unit_tests/test_anthropic_utils.py +++ b/libs/vertexai/tests/unit_tests/test_anthropic_utils.py @@ -1,3 +1,5 @@ +"""Unit tests for _anthropic_utils.py.""" + import pytest from langchain_core.messages import ( AIMessage, @@ -7,7 +9,260 @@ ) from langchain_core.messages.tool import tool_call as create_tool_call -from langchain_google_vertexai.model_garden import _format_messages_anthropic +from langchain_google_vertexai._anthropic_utils import ( + _format_message_anthropic, + _format_messages_anthropic, +) + + +def test_format_message_anthropic_with_cache_control_in_kwargs(): + """Test formatting a message with cache control in additional_kwargs.""" + message = HumanMessage( + content="Hello", additional_kwargs={"cache_control": {"type": "semantic"}} + ) + result = _format_message_anthropic(message) + assert result == { + "role": "user", + "content": [ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}} + ], + } + + +def test_format_message_anthropic_with_cache_control_in_block(): + """Test formatting a message with cache control in content block.""" + message = HumanMessage( + content=[ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}} + ] + ) + result = _format_message_anthropic(message) + assert result == { + "role": "user", + "content": [ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}} + ], + } + + +def test_format_message_anthropic_with_mixed_blocks(): + """Test formatting a message with mixed blocks, some with cache control.""" + message = HumanMessage( + content=[ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}}, + {"type": "text", "text": "World"}, + "Plain text", + ] + ) + result = _format_message_anthropic(message) + assert result == { + "role": "user", + "content": [ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}}, + {"type": "text", "text": 
"World"}, + {"type": "text", "text": "Plain text"}, + ], + } + + +def test_format_messages_anthropic_with_system_cache_control(): + """Test formatting messages with system message having cache control.""" + messages = [ + SystemMessage( + content="System message", + additional_kwargs={"cache_control": {"type": "ephemeral"}}, + ), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + { + "type": "text", + "text": "System message", + "cache_control": {"type": "ephemeral"}, + } + ] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def test_format_message_anthropic_system(): + """Test formatting a system message.""" + message = SystemMessage( + content="System message", + additional_kwargs={"cache_control": {"type": "ephemeral"}}, + ) + result = _format_message_anthropic(message) + assert result == [ + { + "type": "text", + "text": "System message", + "cache_control": {"type": "ephemeral"}, + } + ] + + +def test_format_message_anthropic_system_list(): + """Test formatting a system message with list content.""" + message = SystemMessage( + content=[ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + ) + result = _format_message_anthropic(message) + assert result == [ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + + +def test_format_messages_anthropic_with_system_string(): + """Test formatting messages with system message as string.""" + messages = [ + SystemMessage(content="System message"), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [{"type": "text", "text": "System message"}] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def test_format_messages_anthropic_with_system_list(): + """Test formatting messages with system message as a list.""" + messages = [ + SystemMessage( + content=[ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + ), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def test_format_messages_anthropic_with_system_mixed_list(): + """Test formatting messages with system message as a mixed list.""" + messages = [ + SystemMessage( + content=[ + "Plain system rule", + { + "type": "text", + "text": "Formatted system rule", + "cache_control": {"type": "ephemeral"}, + }, + ] + ), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + {"type": "text", "text": "Plain system rule"}, + { + "type": "text", + "text": "Formatted system rule", + "cache_control": {"type": "ephemeral"}, + }, + ] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def 
test_format_messages_anthropic_with_mixed_messages(): + """Test formatting a conversation with various message types and cache controls.""" + messages = [ + SystemMessage( + content=[ + { + "type": "text", + "text": "System message", + "cache_control": {"type": "ephemeral"}, + } + ] + ), + HumanMessage( + content=[ + { + "type": "text", + "text": "Human message", + "cache_control": {"type": "semantic"}, + } + ] + ), + AIMessage( + content="AI response", + additional_kwargs={"cache_control": {"type": "semantic"}}, + ), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + { + "type": "text", + "text": "System message", + "cache_control": {"type": "ephemeral"}, + } + ] + + assert formatted_messages == [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Human message", + "cache_control": {"type": "semantic"}, + } + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "AI response", + "cache_control": {"type": "semantic"}, + } + ], + }, + ] @pytest.mark.parametrize( @@ -113,7 +368,7 @@ content="Mike age is 30", ), ], - "test1", + [{"type": "text", "text": "test1"}], [ { "role": "assistant", @@ -473,6 +728,7 @@ def test_format_messages_anthropic( source_history, expected_sm, expected_history ) -> None: + """Test the original format_messages_anthropic functionality.""" sm, result_history = _format_messages_anthropic(source_history) for result, expected in zip(result_history, expected_history): diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py index 06567e71..81e9eebf 100644 --- a/libs/vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/vertexai/tests/unit_tests/test_chat_models.py @@ -1077,6 +1077,8 @@ def test_anthropic_format_output() -> None: class Usage: input_tokens: int output_tokens: int + cache_creation_input_tokens: Optional[int] + cache_read_input_tokens: Optional[int] @dataclass class Message: @@ -1092,13 +1094,25 @@ def model_dump(self): ], "model": "baz", "role": "assistant", - "usage": Usage(input_tokens=2, output_tokens=1), + "usage": Usage( + input_tokens=2, + output_tokens=1, + cache_creation_input_tokens=1, + cache_read_input_tokens=1, + ), "type": "message", } usage: Usage - test_msg = Message(usage=Usage(input_tokens=2, output_tokens=1)) + test_msg = Message( + usage=Usage( + input_tokens=2, + output_tokens=1, + cache_creation_input_tokens=1, + cache_read_input_tokens=1, + ) + ) model = ChatAnthropicVertex(project="test-project", location="test-location") result = model._format_output(test_msg) @@ -1113,4 +1127,6 @@ def model_dump(self): "input_tokens": 2, "output_tokens": 1, "total_tokens": 3, + "cache_creation_input_tokens": 1, + "cache_read_input_tokens": 1, } From 8cab93d03bf0338d1fe40d2ec2c708843552b4d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Clgesuellip=E2=80=9D?= <“lgesuellipinto@uade.edu.ar”> Date: Fri, 3 Jan 2025 15:12:21 -0300 Subject: [PATCH 2/6] Fix model name --- .../tests/integration_tests/test_anthropic_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/vertexai/tests/integration_tests/test_anthropic_cache.py b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py index c423f645..c334a9e8 100644 --- a/libs/vertexai/tests/integration_tests/test_anthropic_cache.py +++ b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py @@ -23,7 +23,7 @@ def test_anthropic_system_cache() -> None: ) message = HumanMessage(content="Hello! 
What can you do for me?") - response = model.invoke([context, message], model_name="claude-3-sonnet@latest") + response = model.invoke([context, message], model_name="claude-3-sonnet@20240229") assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert "usage_metadata" in response.additional_kwargs @@ -58,7 +58,7 @@ def test_anthropic_mixed_cache() -> None: ] ) - response = model.invoke([context, message], model_name="claude-3-sonnet@latest") + response = model.invoke([context, message], model_name="claude-3-sonnet@20240229") assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert "usage_metadata" in response.additional_kwargs @@ -100,7 +100,7 @@ def test_anthropic_conversation_cache() -> None: ), ] - response = model.invoke(messages, model_name="claude-3-sonnet@latest") + response = model.invoke(messages, model_name="claude-3-sonnet@20240229") assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert "peter" in response.content.lower() # Should remember the name From 43d9a7f2c140c158809aeaa9b1a0aa1e3c08da7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Clgesuellip=E2=80=9D?= <“lgesuellipinto@uade.edu.ar”> Date: Fri, 3 Jan 2025 16:58:32 -0300 Subject: [PATCH 3/6] Fix model name --- .../tests/integration_tests/test_anthropic_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/vertexai/tests/integration_tests/test_anthropic_cache.py b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py index c334a9e8..0c8906d9 100644 --- a/libs/vertexai/tests/integration_tests/test_anthropic_cache.py +++ b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py @@ -23,7 +23,7 @@ def test_anthropic_system_cache() -> None: ) message = HumanMessage(content="Hello! 
What can you do for me?") - response = model.invoke([context, message], model_name="claude-3-sonnet@20240229") + response = model.invoke([context, message], model_name="claude-3-5-sonnet-latest") assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert "usage_metadata" in response.additional_kwargs @@ -58,7 +58,7 @@ def test_anthropic_mixed_cache() -> None: ] ) - response = model.invoke([context, message], model_name="claude-3-sonnet@20240229") + response = model.invoke([context, message], model_name="claude-3-5-sonnet-latest") assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert "usage_metadata" in response.additional_kwargs @@ -100,7 +100,7 @@ def test_anthropic_conversation_cache() -> None: ), ] - response = model.invoke(messages, model_name="claude-3-sonnet@20240229") + response = model.invoke(messages, model_name="claude-3-5-sonnet-latest") assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert "peter" in response.content.lower() # Should remember the name From 1e921ae0c5d5f052fc9c25fb32bb4ae998f89e5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Clgesuellip=E2=80=9D?= <“lgesuellipinto@uade.edu.ar”> Date: Fri, 3 Jan 2025 17:50:31 -0300 Subject: [PATCH 4/6] Update model names in integration tests for Anthropic cache functionality --- .../integration_tests/test_anthropic_cache.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/libs/vertexai/tests/integration_tests/test_anthropic_cache.py b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py index 0c8906d9..860b2dee 100644 --- a/libs/vertexai/tests/integration_tests/test_anthropic_cache.py +++ b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py @@ -1,13 +1,14 @@ """Integration tests for Anthropic cache control functionality.""" - import os +import pytest from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.prompts import ChatPromptTemplate from langchain_google_vertexai.model_garden import ChatAnthropicVertex +@pytest.mark.extended def test_anthropic_system_cache() -> None: """Test chat with system message having cache control.""" project = os.environ["PROJECT_ID"] @@ -23,13 +24,16 @@ def test_anthropic_system_cache() -> None: ) message = HumanMessage(content="Hello! 
What can you do for me?") - response = model.invoke([context, message], model_name="claude-3-5-sonnet-latest") + response = model.invoke( + [context, message], model_name="claude-3-5-sonnet-v2@20241022" + ) assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert "usage_metadata" in response.additional_kwargs assert "cache_creation_input_tokens" in response.additional_kwargs["usage_metadata"] +@pytest.mark.extended def test_anthropic_mixed_cache() -> None: """Test chat with different cache control types.""" project = os.environ["PROJECT_ID"] @@ -53,17 +57,20 @@ def test_anthropic_mixed_cache() -> None: { "type": "text", "text": "What's your name and what can you help me with?", - "cache_control": {"type": "semantic"}, + "cache_control": {"type": "ephemeral"}, } ] ) - response = model.invoke([context, message], model_name="claude-3-5-sonnet-latest") + response = model.invoke( + [context, message], model_name="claude-3-5-sonnet-v2@20241022" + ) assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert "usage_metadata" in response.additional_kwargs +@pytest.mark.extended def test_anthropic_conversation_cache() -> None: """Test chat conversation with cache control.""" project = os.environ["PROJECT_ID"] @@ -84,7 +91,7 @@ def test_anthropic_conversation_cache() -> None: { "type": "text", "text": "What's my name?", - "cache_control": {"type": "semantic"}, + "cache_control": {"type": "ephemeral"}, } ] ), @@ -94,18 +101,19 @@ def test_anthropic_conversation_cache() -> None: { "type": "text", "text": "Can you repeat my name?", - "cache_control": {"type": "semantic"}, + "cache_control": {"type": "ephemeral"}, } ] ), ] - response = model.invoke(messages, model_name="claude-3-5-sonnet-latest") + response = model.invoke(messages, model_name="claude-3-5-sonnet-v2@20241022") assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert "peter" in response.content.lower() # Should remember the name +@pytest.mark.extended def test_anthropic_chat_template_cache() -> None: """Test chat template with structured content and cache control.""" project = os.environ["PROJECT_ID"] From 4e5521362177d1fbf82ca1d47f8ac4be32d48ed3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Clgesuellip=E2=80=9D?= <“lgesuellipinto@uade.edu.ar”> Date: Sun, 5 Jan 2025 23:35:03 -0300 Subject: [PATCH 5/6] Fix typing --- libs/vertexai/tests/integration_tests/test_anthropic_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/vertexai/tests/integration_tests/test_anthropic_cache.py b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py index 860b2dee..9399331b 100644 --- a/libs/vertexai/tests/integration_tests/test_anthropic_cache.py +++ b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py @@ -1,5 +1,6 @@ """Integration tests for Anthropic cache control functionality.""" import os +from typing import Dict, List, Union import pytest from langchain_core.messages import AIMessage, HumanMessage, SystemMessage @@ -123,7 +124,7 @@ def test_anthropic_chat_template_cache() -> None: location=location, ) - content: list[dict[str, str | dict[str, str]] | str] = [ + content: List[Union[Dict[str, Union[str, Dict[str, str]]], str]] = [ { "text": "You are a helpful assistant. 
Be concise and clear.", "type": "text", From 7f809764f689cfa02e23a66d1279232de57097cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Clgesuellip=E2=80=9D?= <“lgesuellipinto@uade.edu.ar”> Date: Mon, 6 Jan 2025 17:31:55 -0300 Subject: [PATCH 6/6] Code more readable --- .../langchain_google_vertexai/_anthropic_utils.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py index d32a7005..af2b10cb 100644 --- a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py +++ b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py @@ -68,13 +68,9 @@ def _get_cache_control(message: BaseMessage) -> Optional[Dict[str, Any]]: ) -def _format_text_content( - text: str, cache_control: Optional[Dict[str, Any]] = None -) -> Dict[str, Union[str, Dict[str, Any]]]: - """Format text content with optional cache control.""" +def _format_text_content(text: str) -> Dict[str, Union[str, Dict[str, Any]]]: + """Format text content.""" content: Dict[str, Union[str, Dict[str, Any]]] = {"type": "text", "text": text} - if cache_control: - content["cache_control"] = cache_control return content @@ -92,9 +88,10 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage, SystemMess if isinstance(message.content, str): if not message.content.strip(): return None - content.append( - _format_text_content(message.content, _get_cache_control(message)) - ) + message_dict = _format_text_content(message.content) + if cache_control := _get_cache_control(message): + message_dict["cache_control"] = cache_control + content.append(message_dict) elif isinstance(message.content, list): for block in message.content: if isinstance(block, str):