
Commit 9f21f0f

fix(langchain): anthropic cache token count (#2414)
1 parent 7087022 commit 9f21f0f

File tree

5 files changed: +380 −40 lines changed


python/instrumentation/openinference-instrumentation-langchain/pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ dependencies = [
 
 [project.optional-dependencies]
 instruments = [
-  "langchain_core >= 0.2.43",
+  "langchain_core >= 0.3.9",
 ]
 test = [
   "langchain_core == 0.3.50",
@@ -55,7 +55,7 @@ test = [
   "vcrpy>=6.0.1",
 ]
 type-check = [
-  "langchain_core == 0.2.43",
+  "langchain_core == 0.3.9",
 ]
 
 [project.entry-points.opentelemetry_instrumentor]
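The langchain_core lower bound moves from 0.2.43 to 0.3.9, presumably so that the UsageMetadata type imported by the tracer below is guaranteed to carry the input_token_details / output_token_details fields. A minimal sketch of that shape, with illustrative numbers not taken from this commit:

from langchain_core.messages.ai import UsageMetadata

# Illustrative payload only; the "cache_creation" and "cache_read" keys under
# input_token_details are what the tracer maps to cache-write / cache-read
# prompt details.
usage: UsageMetadata = {
    "input_tokens": 22,
    "output_tokens": 5,
    "total_tokens": 38,
    "input_token_details": {"cache_creation": 2, "cache_read": 9},
}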

python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/_tracer.py

Lines changed: 129 additions & 37 deletions
@@ -25,6 +25,7 @@
     Optional,
     Sequence,
     Tuple,
+    TypedDict,
     TypeVar,
     Union,
     cast,
@@ -35,6 +36,7 @@
 
 import wrapt  # type: ignore
 from langchain_core.messages import BaseMessage
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.tracers import BaseTracer, LangChainTracer
 from langchain_core.tracers.schemas import Run
 from opentelemetry import context as context_api
@@ -43,6 +45,7 @@
 from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
 from opentelemetry.trace import Span
 from opentelemetry.util.types import AttributeValue
+from typing_extensions import NotRequired, TypeGuard
 from wrapt import ObjectProxy
 
 from openinference.instrumentation import get_attributes_from_context, safe_json_dumps
@@ -832,6 +835,126 @@ def _model_name(
     return
 
 
+class _RawAnthropicUsageWithCacheReadOrWrite(TypedDict):
+    # https://github.com/anthropics/anthropic-sdk-python/blob/2e2f663104c8926434088828c08fbdf202d6d6fd/src/anthropic/types/usage.py#L13
+    input_tokens: int
+    output_tokens: int
+    cache_read_input_tokens: NotRequired[int]
+    cache_creation_input_tokens: NotRequired[int]
+
+
+def _is_raw_anthropic_usage_with_cache_read_or_write(
+    obj: Mapping[str, Any],
+) -> TypeGuard[_RawAnthropicUsageWithCacheReadOrWrite]:
+    return (
+        "input_tokens" in obj
+        and "output_tokens" in obj
+        and isinstance(obj["input_tokens"], int)
+        and isinstance(obj["output_tokens"], int)
+        and (
+            ("cache_read_input_tokens" in obj and isinstance(obj["cache_read_input_tokens"], int))
+            or (
+                "cache_creation_input_tokens" in obj
+                and isinstance(obj["cache_creation_input_tokens"], int)
+            )
+        )
+    )
+
+
+def _token_counts_from_raw_anthropic_usage_with_cache_read_or_write(
+    obj: _RawAnthropicUsageWithCacheReadOrWrite,
+) -> Iterator[Tuple[str, int]]:
+    input_tokens = obj["input_tokens"]
+    output_tokens = obj["output_tokens"]
+
+    cache_creation_input_tokens = 0
+    cache_read_input_tokens = 0
+
+    if "cache_creation_input_tokens" in obj:
+        cache_creation_input_tokens = obj["cache_creation_input_tokens"]
+    if "cache_read_input_tokens" in obj:
+        cache_read_input_tokens = obj["cache_read_input_tokens"]
+
+    prompt_tokens = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
+    completion_tokens = output_tokens
+
+    yield LLM_TOKEN_COUNT_PROMPT, prompt_tokens
+    yield LLM_TOKEN_COUNT_COMPLETION, completion_tokens
+
+    if cache_creation_input_tokens:
+        yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE, cache_creation_input_tokens
+    if cache_read_input_tokens:
+        yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, cache_read_input_tokens
+
+
+def _is_lc_usage_metadata(obj: Mapping[str, Any]) -> TypeGuard[UsageMetadata]:
+    return (
+        "input_tokens" in obj
+        and "output_tokens" in obj
+        and "total_tokens" in obj
+        and isinstance(obj["input_tokens"], int)
+        and isinstance(obj["output_tokens"], int)
+        and isinstance(obj["total_tokens"], int)
+    )
+
+
+def _token_counts_from_lc_usage_metadata(obj: UsageMetadata) -> Iterator[Tuple[str, int]]:
+    input_tokens = obj["input_tokens"]
+    output_tokens = obj["output_tokens"]
+    total_tokens = obj["total_tokens"]
+
+    input_audio = 0
+    input_cache_creation = 0
+    input_cache_read = 0
+    output_audio = 0
+    output_reasoning = 0
+
+    if "input_token_details" in obj:
+        input_token_details = obj["input_token_details"]
+        if "audio" in input_token_details:
+            input_audio = input_token_details["audio"]
+        if "cache_creation" in input_token_details:
+            input_cache_creation = input_token_details["cache_creation"]
+        if "cache_read" in input_token_details:
+            input_cache_read = input_token_details["cache_read"]
+
+    if "output_token_details" in obj:
+        output_token_details = obj["output_token_details"]
+        if "audio" in output_token_details:
+            output_audio = output_token_details["audio"]
+        if "reasoning" in output_token_details:
+            output_reasoning = output_token_details["reasoning"]
+
+    prompt_tokens = input_tokens
+    completion_tokens = output_tokens
+
+    # heuristic adjustment for Bedrock Anthropic models with cache read or write
+    # https://github.com/Arize-ai/openinference/issues/2381
+    if input_cache := input_cache_creation + input_cache_read:
+        if total_tokens == input_tokens + output_tokens + input_cache:
+            # for Bedrock Converse
+            prompt_tokens += input_cache
+        elif input_tokens < input_cache:
+            # for Bedrock InvokeModel
+            prompt_tokens += input_cache
+            total_tokens += input_cache
+
+    yield LLM_TOKEN_COUNT_PROMPT, prompt_tokens
+    yield LLM_TOKEN_COUNT_COMPLETION, completion_tokens
+    yield LLM_TOKEN_COUNT_TOTAL, total_tokens
+
+    if input_audio:
+        yield LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO, input_audio
+    if input_cache_creation:
+        yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE, input_cache_creation
+    if input_cache_read:
+        yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, input_cache_read
+    if output_audio:
+        yield LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO, output_audio
+    if output_reasoning:
+        yield LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING, output_reasoning
+
+
 @stop_on_exception
 def _token_counts(outputs: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, int]]:
     """Yields token count information if present."""
@@ -843,26 +966,23 @@ def _token_counts(outputs: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, i
         )
     ):
         return
+    keys: Sequence[str]
     for attribute_name, keys in [
         (
             LLM_TOKEN_COUNT_PROMPT,
             (
                 "prompt_tokens",
-                "input_tokens",  # Anthropic-specific key
                 "prompt_token_count",  # Gemini-specific key - https://ai.google.dev/gemini-api/docs/tokens?lang=python
             ),
         ),
         (
             LLM_TOKEN_COUNT_COMPLETION,
             (
                 "completion_tokens",
-                "output_tokens",  # Anthropic-specific key
                 "candidates_token_count",  # Gemini-specific key
             ),
         ),
         (LLM_TOKEN_COUNT_TOTAL, ("total_tokens", "total_token_count")),  # Gemini-specific key
-        (LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, ("cache_read_input_tokens",)),  # Antrhopic
-        (LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE, ("cache_creation_input_tokens",)),  # Antrhopic
     ]:
         if (token_count := _get_first_value(token_usage, keys)) is not None:
             yield attribute_name, token_count
@@ -895,39 +1015,11 @@ def _token_counts(outputs: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, i
             yield attribute_name, token_count
 
     # maps langchain_core.messages.ai.UsageMetadata object
-    for attribute_name, details_key_or_none, keys in [
-        (LLM_TOKEN_COUNT_PROMPT, None, ("input_tokens",)),
-        (LLM_TOKEN_COUNT_COMPLETION, None, ("output_tokens",)),
-        (
-            LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO,
-            "input_token_details",
-            ("audio",),
-        ),
-        (
-            LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE,
-            "input_token_details",
-            ("cache_creation",),
-        ),
-        (
-            LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ,
-            "input_token_details",
-            ("cache_read",),
-        ),
-        (
-            LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO,
-            "output_token_details",
-            ("audio",),
-        ),
-        (
-            LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING,
-            "output_token_details",
-            ("reasoning",),
-        ),
-    ]:
-        details = token_usage.get(details_key_or_none) if details_key_or_none else token_usage
-        if details is not None:
-            if (token_count := _get_first_value(details, keys)) is not None:
-                yield attribute_name, token_count
+    if _is_lc_usage_metadata(token_usage):
+        yield from _token_counts_from_lc_usage_metadata(token_usage)
+
+    if _is_raw_anthropic_usage_with_cache_read_or_write(token_usage):
+        yield from _token_counts_from_raw_anthropic_usage_with_cache_read_or_write(token_usage)
 
 
 def _parse_token_usage_for_vertexai(
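Worked example of the new raw-Anthropic path (a sketch: the helpers and attribute constants from _tracer.py above are assumed to be in scope, and the numbers are chosen to mirror the updated test expectations below rather than copied from the recorded cassette):

usage = {
    "input_tokens": 22,
    "output_tokens": 5,
    "cache_creation_input_tokens": 2,
    "cache_read_input_tokens": 9,
}
assert _is_raw_anthropic_usage_with_cache_read_or_write(usage)
attrs = dict(_token_counts_from_raw_anthropic_usage_with_cache_read_or_write(usage))
assert attrs[LLM_TOKEN_COUNT_PROMPT] == 22 + 2 + 9  # 33: cache tokens now count toward the prompt
assert attrs[LLM_TOKEN_COUNT_COMPLETION] == 5
assert attrs[LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE] == 2
assert attrs[LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ] == 9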

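The Bedrock heuristic in _token_counts_from_lc_usage_metadata treats two UsageMetadata shapes differently; a sketch with hypothetical numbers (not taken from this commit's fixtures), again assuming the module's helpers and attribute constants are in scope:

# Converse-style payload: total_tokens already includes the cached tokens
# (27 == 10 + 5 + 2 + 10), so only the prompt count absorbs them.
converse = {
    "input_tokens": 10,
    "output_tokens": 5,
    "total_tokens": 27,
    "input_token_details": {"cache_creation": 2, "cache_read": 10},
}

# InvokeModel-style payload: input_tokens (3) is smaller than the cached
# tokens (10), so both the prompt and the total are adjusted upward.
invoke_model = {
    "input_tokens": 3,
    "output_tokens": 5,
    "total_tokens": 8,
    "input_token_details": {"cache_read": 10},
}

for payload, expected_prompt, expected_total in [(converse, 22, 27), (invoke_model, 13, 18)]:
    attrs = dict(_token_counts_from_lc_usage_metadata(payload))
    assert attrs[LLM_TOKEN_COUNT_PROMPT] == expected_prompt
    assert attrs[LLM_TOKEN_COUNT_TOTAL] == expected_total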
python/instrumentation/openinference-instrumentation-langchain/tests/test_instrumentor.py

Lines changed: 1 addition & 1 deletion
@@ -590,7 +590,7 @@ def test_anthropic_token_counts(
     span = spans[0]
     llm_attributes = dict(span.attributes or {})
     assert llm_attributes.pop(OPENINFERENCE_SPAN_KIND, None) == LLM.value
-    assert llm_attributes.pop(LLM_TOKEN_COUNT_PROMPT, None) == 22
+    assert llm_attributes.pop(LLM_TOKEN_COUNT_PROMPT, None) == 33
     assert llm_attributes.pop(LLM_TOKEN_COUNT_COMPLETION, None) == 5
     assert llm_attributes.pop(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE) == 2
     assert llm_attributes.pop(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ) == 9
