vertexai: Add ChatAnthropicVertex Text Caching Support #672

Merged · 6 commits · Jan 7, 2025
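For context, here is a minimal sketch of the usage this PR enables, mirroring the integration tests at the bottom of the diff (the project, location, and model name are the illustrative values those tests use):

```python
# Sketch of the prompt-caching usage added by this PR; mirrors the
# integration tests below. PROJECT_ID, location, and model name are
# illustrative values taken from those tests.
import os

from langchain_core.messages import HumanMessage, SystemMessage

from langchain_google_vertexai.model_garden import ChatAnthropicVertex

model = ChatAnthropicVertex(
    project=os.environ["PROJECT_ID"],
    location="us-central1",
)

# Marking a message with cache_control asks Anthropic to cache this prefix.
context = SystemMessage(
    content="You are my personal assistant. Be helpful and concise.",
    additional_kwargs={"cache_control": {"type": "ephemeral"}},
)
response = model.invoke(
    [context, HumanMessage(content="Hello! What can you do for me?")],
    model_name="claude-3-5-sonnet-v2@20241022",
)
# Cache accounting surfaces on the response (see CacheUsageMetadata below).
print(response.usage_metadata)
```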
58 changes: 40 additions & 18 deletions libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
@@ -59,14 +59,39 @@ def _format_image(image_url: str) -> Dict:
     }
 
 
-def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
-    role = _message_type_lookups[message.type]
+def _get_cache_control(message: BaseMessage) -> Optional[Dict[str, Any]]:
+    """Extract cache control from message's additional_kwargs or content block."""
+    return (
+        message.additional_kwargs.get("cache_control")
+        if isinstance(message.additional_kwargs, dict)
+        else None
+    )
+
+
+def _format_text_content(text: str) -> Dict[str, Union[str, Dict[str, Any]]]:
+    """Format text content."""
+    content: Dict[str, Union[str, Dict[str, Any]]] = {"type": "text", "text": text}
+    return content
+
+
+def _format_message_anthropic(message: Union[HumanMessage, AIMessage, SystemMessage]):
+    """Format a message for Anthropic API.
+
+    Args:
+        message: The message to format. Can be HumanMessage, AIMessage, or SystemMessage.
+
+    Returns:
+        A dictionary with the formatted message, or None if the message is empty.
+    """  # noqa: E501
     content: List[Dict[str, Any]] = []
 
     if isinstance(message.content, str):
         if not message.content.strip():
             return None
-        content.append({"type": "text", "text": message.content})
+        message_dict = _format_text_content(message.content)
+        if cache_control := _get_cache_control(message):
+            message_dict["cache_control"] = cache_control
+        content.append(message_dict)
     elif isinstance(message.content, list):
         for block in message.content:
             if isinstance(block, str):
@@ -75,9 +100,8 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
                 # https://github.com/anthropics/anthropic-sdk-python/issues/461
                 if not block.strip():
                     continue
-                content.append({"type": "text", "text": block})
-
-            if isinstance(block, dict):
+                content.append(_format_text_content(block))
+            elif isinstance(block, dict):
                 if "type" not in block:
                     raise ValueError("Dict content block must have a type key")
 
@@ -113,46 +137,44 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
                 if not is_unique:
                     continue
 
             # all other block types
             content.append(block)
         else:
             raise ValueError("Message should be a str, list of str or list of dicts")
 
     # adding all tool calls
     if isinstance(message, AIMessage) and message.tool_calls:
         for tc in message.tool_calls:
             tu = cast(Dict[str, Any], _lc_tool_call_to_anthropic_tool_use_block(tc))
             content.append(tu)
 
-    return {"role": role, "content": content}
+    if message.type == "system":
+        return content
+    else:
+        return {"role": _message_type_lookups[message.type], "content": content}
 
 
 def _format_messages_anthropic(
     messages: List[BaseMessage],
-) -> Tuple[Optional[str], List[Dict]]:
+) -> Tuple[Optional[Dict[str, Any]], List[Dict]]:
     """Formats messages for anthropic."""
-    system_message: Optional[str] = None
+    system_messages: Optional[Dict[str, Any]] = None
     formatted_messages: List[Dict] = []
 
     merged_messages = _merge_messages(messages)
     for i, message in enumerate(merged_messages):
         if message.type == "system":
             if i != 0:
                 raise ValueError("System message must be at beginning of message list.")
-            if not isinstance(message.content, str):
-                raise ValueError(
-                    "System message must be a string, "
-                    f"instead was: {type(message.content)}"
-                )
-            system_message = message.content
+            fm = _format_message_anthropic(message)
+            if fm:
+                system_messages = fm
             continue
 
         fm = _format_message_anthropic(message)
         if not fm:
             continue
         formatted_messages.append(fm)
 
-    return system_message, formatted_messages
+    return system_messages, formatted_messages
 
 
 class AnthropicTool(TypedDict):
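To make the new control flow concrete, here is an illustrative sketch (not part of the diff) of what the reworked helpers return for a system message carrying `cache_control`; the expected values in the comments follow directly from the code above:

```python
# Illustrative only: exercises the private helper changed in this file.
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_google_vertexai._anthropic_utils import _format_messages_anthropic

system, formatted = _format_messages_anthropic(
    [
        SystemMessage(
            content="You are my personal assistant.",
            additional_kwargs={"cache_control": {"type": "ephemeral"}},
        ),
        HumanMessage(content="Hello!"),
    ]
)
# `system` is now a list of structured content blocks rather than a plain str:
#   [{"type": "text", "text": "You are my personal assistant.",
#     "cache_control": {"type": "ephemeral"}}]
# `formatted` keeps the usual role/content shape:
#   [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
```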
20 changes: 15 additions & 5 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
@@ -32,6 +32,7 @@
     AIMessage,
     BaseMessage,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
@@ -61,6 +62,13 @@
 from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon
 
 
+class CacheUsageMetadata(UsageMetadata):
+    cache_creation_input_tokens: Optional[int]
+    """The number of input tokens used to create the cache entry."""
+    cache_read_input_tokens: Optional[int]
+    """The number of input tokens read from the cache."""
+
+
 class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
     """Large language models served from Vertex AI Model Garden."""
@@ -225,11 +233,13 @@ def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
         else:
             msg = AIMessage(content=content)
         # Collect token usage
-        msg.usage_metadata = {
-            "input_tokens": data.usage.input_tokens,
-            "output_tokens": data.usage.output_tokens,
-            "total_tokens": data.usage.input_tokens + data.usage.output_tokens,
-        }
+        msg.usage_metadata = CacheUsageMetadata(
+            input_tokens=data.usage.input_tokens,
+            output_tokens=data.usage.output_tokens,
+            total_tokens=data.usage.input_tokens + data.usage.output_tokens,
+            cache_creation_input_tokens=data.usage.cache_creation_input_tokens,
+            cache_read_input_tokens=data.usage.cache_read_input_tokens,
+        )
         return ChatResult(
             generations=[ChatGeneration(message=msg)],
             llm_output=llm_output,
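Downstream callers can read the two new counters like any other usage fields; a hypothetical helper (not part of the PR) as a sketch:

```python
# Hypothetical helper, not part of the PR: CacheUsageMetadata is a TypedDict,
# so the new counters are plain dict entries on msg.usage_metadata.
from langchain_core.messages import AIMessage


def log_cache_usage(msg: AIMessage) -> None:
    """Print the token counters carried by a response message."""
    usage = msg.usage_metadata or {}
    print("input tokens:", usage.get("input_tokens"))
    print("cache write tokens:", usage.get("cache_creation_input_tokens"))
    print("cache read tokens:", usage.get("cache_read_input_tokens"))
```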
147 changes: 147 additions & 0 deletions libs/vertexai/tests/integration_tests/test_anthropic_cache.py
@@ -0,0 +1,147 @@
"""Integration tests for Anthropic cache control functionality."""
import os
from typing import Dict, List, Union

import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate

from langchain_google_vertexai.model_garden import ChatAnthropicVertex


@pytest.mark.extended
def test_anthropic_system_cache() -> None:
"""Test chat with system message having cache control."""
project = os.environ["PROJECT_ID"]
location = "us-central1"
model = ChatAnthropicVertex(
project=project,
location=location,
)

context = SystemMessage(
content="You are my personal assistant. Be helpful and concise.",
additional_kwargs={"cache_control": {"type": "ephemeral"}},
)
message = HumanMessage(content="Hello! What can you do for me?")

response = model.invoke(
[context, message], model_name="claude-3-5-sonnet-v2@20241022"
)
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert "usage_metadata" in response.additional_kwargs
assert "cache_creation_input_tokens" in response.additional_kwargs["usage_metadata"]


@pytest.mark.extended
def test_anthropic_mixed_cache() -> None:
"""Test chat with different cache control types."""
project = os.environ["PROJECT_ID"]
location = "us-central1"
model = ChatAnthropicVertex(
project=project,
location=location,
)

context = SystemMessage(
content=[
{
"type": "text",
"text": "You are my personal assistant.",
"cache_control": {"type": "ephemeral"},
}
]
)
message = HumanMessage(
content=[
{
"type": "text",
"text": "What's your name and what can you help me with?",
"cache_control": {"type": "ephemeral"},
}
]
)

response = model.invoke(
[context, message], model_name="claude-3-5-sonnet-v2@20241022"
)
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert "usage_metadata" in response.additional_kwargs


@pytest.mark.extended
def test_anthropic_conversation_cache() -> None:
"""Test chat conversation with cache control."""
project = os.environ["PROJECT_ID"]
location = "us-central1"
model = ChatAnthropicVertex(
project=project,
location=location,
)

context = SystemMessage(
content="You are my personal assistant. My name is Peter.",
additional_kwargs={"cache_control": {"type": "ephemeral"}},
)
messages = [
context,
HumanMessage(
content=[
{
"type": "text",
"text": "What's my name?",
"cache_control": {"type": "ephemeral"},
}
]
),
AIMessage(content="Your name is Peter."),
HumanMessage(
content=[
{
"type": "text",
"text": "Can you repeat my name?",
"cache_control": {"type": "ephemeral"},
}
]
),
]

response = model.invoke(messages, model_name="claude-3-5-sonnet-v2@20241022")
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert "peter" in response.content.lower() # Should remember the name


@pytest.mark.extended
def test_anthropic_chat_template_cache() -> None:
"""Test chat template with structured content and cache control."""
project = os.environ["PROJECT_ID"]
location = "us-central1"
model = ChatAnthropicVertex(
project=project,
location=location,
)

content: List[Union[Dict[str, Union[str, Dict[str, str]]], str]] = [
{
"text": "You are a helpful assistant. Be concise and clear.",
"type": "text",
"cache_control": {"type": "ephemeral"},
}
]

prompt = ChatPromptTemplate.from_messages(
[SystemMessage(content=content), ("human", "{input}")]
)

chain = prompt | model

response = chain.invoke(
{"input": "What's the capital of France?"},
)

assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert "Paris" in response.content