vertexai: Add ChatAnthropicVertex Text Caching Support (#672)
lgesuellip authored Jan 7, 2025
1 parent 1755705 · commit 1f6d2d7
Showing 5 changed files with 478 additions and 27 deletions.
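With this commit, prompt caching can be requested by attaching an Anthropic cache_control block to a message, either via additional_kwargs or inside a structured content block. A minimal usage sketch, mirroring the integration tests added below (the PROJECT_ID environment variable, location, and model name are simply the values those tests use, not requirements):

    import os

    from langchain_core.messages import HumanMessage, SystemMessage

    from langchain_google_vertexai.model_garden import ChatAnthropicVertex

    model = ChatAnthropicVertex(project=os.environ["PROJECT_ID"], location="us-central1")

    # Mark the system prompt as cacheable with an ephemeral cache-control block.
    context = SystemMessage(
        content="You are my personal assistant. Be helpful and concise.",
        additional_kwargs={"cache_control": {"type": "ephemeral"}},
    )
    response = model.invoke(
        [context, HumanMessage(content="Hello!")],
        model_name="claude-3-5-sonnet-v2@20241022",
    )
    # Cache accounting is surfaced on the response's usage metadata.
    print(response.usage_metadata)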
58 changes: 40 additions & 18 deletions libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
@@ -59,14 +59,39 @@ def _format_image(image_url: str) -> Dict:
     }
 
 
-def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
-    role = _message_type_lookups[message.type]
+def _get_cache_control(message: BaseMessage) -> Optional[Dict[str, Any]]:
+    """Extract cache control from message's additional_kwargs or content block."""
+    return (
+        message.additional_kwargs.get("cache_control")
+        if isinstance(message.additional_kwargs, dict)
+        else None
+    )
+
+
+def _format_text_content(text: str) -> Dict[str, Union[str, Dict[str, Any]]]:
+    """Format text content."""
+    content: Dict[str, Union[str, Dict[str, Any]]] = {"type": "text", "text": text}
+    return content
+
+
+def _format_message_anthropic(message: Union[HumanMessage, AIMessage, SystemMessage]):
+    """Format a message for Anthropic API.
+
+    Args:
+        message: The message to format. Can be HumanMessage, AIMessage, or SystemMessage.
+
+    Returns:
+        A dictionary with the formatted message, or None if the message is empty.
+    """  # noqa: E501
     content: List[Dict[str, Any]] = []
 
     if isinstance(message.content, str):
         if not message.content.strip():
             return None
-        content.append({"type": "text", "text": message.content})
+        message_dict = _format_text_content(message.content)
+        if cache_control := _get_cache_control(message):
+            message_dict["cache_control"] = cache_control
+        content.append(message_dict)
     elif isinstance(message.content, list):
         for block in message.content:
             if isinstance(block, str):
@@ -75,9 +100,8 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
                 # https://github.com/anthropics/anthropic-sdk-python/issues/461
                 if not block.strip():
                     continue
-                content.append({"type": "text", "text": block})
-
-            if isinstance(block, dict):
+                content.append(_format_text_content(block))
+            elif isinstance(block, dict):
                 if "type" not in block:
                     raise ValueError("Dict content block must have a type key")
 
@@ -113,46 +137,44 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
                 if not is_unique:
                     continue
 
             # all other block types
             content.append(block)
     else:
         raise ValueError("Message should be a str, list of str or list of dicts")
 
     # adding all tool calls
     if isinstance(message, AIMessage) and message.tool_calls:
         for tc in message.tool_calls:
             tu = cast(Dict[str, Any], _lc_tool_call_to_anthropic_tool_use_block(tc))
             content.append(tu)
 
-    return {"role": role, "content": content}
+    if message.type == "system":
+        return content
+    else:
+        return {"role": _message_type_lookups[message.type], "content": content}
 
 
 def _format_messages_anthropic(
     messages: List[BaseMessage],
-) -> Tuple[Optional[str], List[Dict]]:
+) -> Tuple[Optional[Dict[str, Any]], List[Dict]]:
     """Formats messages for anthropic."""
-    system_message: Optional[str] = None
+    system_messages: Optional[Dict[str, Any]] = None
     formatted_messages: List[Dict] = []
 
     merged_messages = _merge_messages(messages)
     for i, message in enumerate(merged_messages):
         if message.type == "system":
             if i != 0:
                 raise ValueError("System message must be at beginning of message list.")
-            if not isinstance(message.content, str):
-                raise ValueError(
-                    "System message must be a string, "
-                    f"instead was: {type(message.content)}"
-                )
-            system_message = message.content
+            fm = _format_message_anthropic(message)
+            if fm:
+                system_messages = fm
             continue
 
         fm = _format_message_anthropic(message)
         if not fm:
             continue
         formatted_messages.append(fm)
 
-    return system_message, formatted_messages
+    return system_messages, formatted_messages
 
 
 class AnthropicTool(TypedDict):
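The observable effect of these helpers: a system message now formats to a list of content blocks (carrying any cache_control marker) instead of a bare string. A sketch of the shapes the new _format_messages_anthropic produces, using inputs like those in the tests below (the helper is private to the package, so this is illustration rather than public API):

    from langchain_core.messages import HumanMessage, SystemMessage

    from langchain_google_vertexai._anthropic_utils import _format_messages_anthropic

    system, messages = _format_messages_anthropic(
        [
            SystemMessage(
                content="You are my personal assistant.",
                additional_kwargs={"cache_control": {"type": "ephemeral"}},
            ),
            HumanMessage(content="Hello!"),
        ]
    )
    # system   -> [{"type": "text", "text": "You are my personal assistant.",
    #               "cache_control": {"type": "ephemeral"}}]
    # messages -> [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]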
20 changes: 15 additions & 5 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
@@ -32,6 +32,7 @@
     AIMessage,
     BaseMessage,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
@@ -61,6 +62,13 @@
 from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon
 
 
+class CacheUsageMetadata(UsageMetadata):
+    cache_creation_input_tokens: Optional[int]
+    """The number of input tokens used to create the cache entry."""
+    cache_read_input_tokens: Optional[int]
+    """The number of input tokens read from the cache."""
+
+
 class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
     """Large language models served from Vertex AI Model Garden."""
 
@@ -225,11 +233,13 @@ def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
         else:
             msg = AIMessage(content=content)
         # Collect token usage
-        msg.usage_metadata = {
-            "input_tokens": data.usage.input_tokens,
-            "output_tokens": data.usage.output_tokens,
-            "total_tokens": data.usage.input_tokens + data.usage.output_tokens,
-        }
+        msg.usage_metadata = CacheUsageMetadata(
+            input_tokens=data.usage.input_tokens,
+            output_tokens=data.usage.output_tokens,
+            total_tokens=data.usage.input_tokens + data.usage.output_tokens,
+            cache_creation_input_tokens=data.usage.cache_creation_input_tokens,
+            cache_read_input_tokens=data.usage.cache_read_input_tokens,
+        )
         return ChatResult(
             generations=[ChatGeneration(message=msg)],
             llm_output=llm_output,
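Because UsageMetadata is a TypedDict, the two new counters read like ordinary dict keys on any AIMessage the model returns. A small sketch, reusing the response from the usage example at the top of this commit (note that Anthropic only caches prompts above a minimum token length, so very short prompts may report zero for both fields):

    usage = response.usage_metadata or {}
    print(usage.get("cache_creation_input_tokens"))  # populated when the cache entry is written
    print(usage.get("cache_read_input_tokens"))      # populated when a later call reuses it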
147 changes: 147 additions & 0 deletions libs/vertexai/tests/integration_tests/test_anthropic_cache.py
@@ -0,0 +1,147 @@
"""Integration tests for Anthropic cache control functionality."""
import os
from typing import Dict, List, Union

import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate

from langchain_google_vertexai.model_garden import ChatAnthropicVertex


@pytest.mark.extended
def test_anthropic_system_cache() -> None:
"""Test chat with system message having cache control."""
project = os.environ["PROJECT_ID"]
location = "us-central1"
model = ChatAnthropicVertex(
project=project,
location=location,
)

context = SystemMessage(
content="You are my personal assistant. Be helpful and concise.",
additional_kwargs={"cache_control": {"type": "ephemeral"}},
)
message = HumanMessage(content="Hello! What can you do for me?")

response = model.invoke(
[context, message], model_name="claude-3-5-sonnet-v2@20241022"
)
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert "usage_metadata" in response.additional_kwargs
assert "cache_creation_input_tokens" in response.additional_kwargs["usage_metadata"]


@pytest.mark.extended
def test_anthropic_mixed_cache() -> None:
"""Test chat with different cache control types."""
project = os.environ["PROJECT_ID"]
location = "us-central1"
model = ChatAnthropicVertex(
project=project,
location=location,
)

context = SystemMessage(
content=[
{
"type": "text",
"text": "You are my personal assistant.",
"cache_control": {"type": "ephemeral"},
}
]
)
message = HumanMessage(
content=[
{
"type": "text",
"text": "What's your name and what can you help me with?",
"cache_control": {"type": "ephemeral"},
}
]
)

response = model.invoke(
[context, message], model_name="claude-3-5-sonnet-v2@20241022"
)
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert "usage_metadata" in response.additional_kwargs


@pytest.mark.extended
def test_anthropic_conversation_cache() -> None:
"""Test chat conversation with cache control."""
project = os.environ["PROJECT_ID"]
location = "us-central1"
model = ChatAnthropicVertex(
project=project,
location=location,
)

context = SystemMessage(
content="You are my personal assistant. My name is Peter.",
additional_kwargs={"cache_control": {"type": "ephemeral"}},
)
messages = [
context,
HumanMessage(
content=[
{
"type": "text",
"text": "What's my name?",
"cache_control": {"type": "ephemeral"},
}
]
),
AIMessage(content="Your name is Peter."),
HumanMessage(
content=[
{
"type": "text",
"text": "Can you repeat my name?",
"cache_control": {"type": "ephemeral"},
}
]
),
]

response = model.invoke(messages, model_name="claude-3-5-sonnet-v2@20241022")
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert "peter" in response.content.lower() # Should remember the name


@pytest.mark.extended
def test_anthropic_chat_template_cache() -> None:
"""Test chat template with structured content and cache control."""
project = os.environ["PROJECT_ID"]
location = "us-central1"
model = ChatAnthropicVertex(
project=project,
location=location,
)

content: List[Union[Dict[str, Union[str, Dict[str, str]]], str]] = [
{
"text": "You are a helpful assistant. Be concise and clear.",
"type": "text",
"cache_control": {"type": "ephemeral"},
}
]

prompt = ChatPromptTemplate.from_messages(
[SystemMessage(content=content), ("human", "{input}")]
)

chain = prompt | model

response = chain.invoke(
{"input": "What's the capital of France?"},
)

assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert "Paris" in response.content