diff --git a/python/instrumentation/openinference-instrumentation-langchain/pyproject.toml b/python/instrumentation/openinference-instrumentation-langchain/pyproject.toml index ddc9b3786..d069873ee 100644 --- a/python/instrumentation/openinference-instrumentation-langchain/pyproject.toml +++ b/python/instrumentation/openinference-instrumentation-langchain/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ [project.optional-dependencies] instruments = [ - "langchain_core >= 0.2.43", + "langchain_core >= 0.3.9", ] test = [ "langchain_core == 0.3.50", @@ -55,7 +55,7 @@ test = [ "vcrpy>=6.0.1", ] type-check = [ - "langchain_core == 0.2.43", + "langchain_core == 0.3.9", ] [project.entry-points.opentelemetry_instrumentor] diff --git a/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/_tracer.py b/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/_tracer.py index 3b72f3bf0..cdf220b9d 100644 --- a/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/_tracer.py +++ b/python/instrumentation/openinference-instrumentation-langchain/src/openinference/instrumentation/langchain/_tracer.py @@ -25,6 +25,7 @@ Optional, Sequence, Tuple, + TypedDict, TypeVar, Union, cast, @@ -35,6 +36,7 @@ import wrapt # type: ignore from langchain_core.messages import BaseMessage +from langchain_core.messages.ai import UsageMetadata from langchain_core.tracers import BaseTracer, LangChainTracer from langchain_core.tracers.schemas import Run from opentelemetry import context as context_api @@ -43,6 +45,7 @@ from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes from opentelemetry.trace import Span from opentelemetry.util.types import AttributeValue +from typing_extensions import NotRequired, TypeGuard from wrapt import ObjectProxy from openinference.instrumentation import get_attributes_from_context, safe_json_dumps @@ -832,6 +835,126 @@ def _model_name( return +class _RawAnthropicUsageWithCacheReadOrWrite(TypedDict): + # https://github.com/anthropics/anthropic-sdk-python/blob/2e2f663104c8926434088828c08fbdf202d6d6fd/src/anthropic/types/usage.py#L13 + input_tokens: int + output_tokens: int + cache_read_input_tokens: NotRequired[int] + cache_creation_input_tokens: NotRequired[int] + + +def _is_raw_anthropic_usage_with_cache_read_or_write( + obj: Mapping[str, Any], +) -> TypeGuard[_RawAnthropicUsageWithCacheReadOrWrite]: + return ( + "input_tokens" in obj + and "output_tokens" in obj + and isinstance(obj["input_tokens"], int) + and isinstance(obj["output_tokens"], int) + and ( + ("cache_read_input_tokens" in obj and isinstance(obj["cache_read_input_tokens"], int)) + or ( + "cache_creation_input_tokens" in obj + and isinstance(obj["cache_creation_input_tokens"], int) + ) + ) + ) + + +def _token_counts_from_raw_anthropic_usage_with_cache_read_or_write( + obj: _RawAnthropicUsageWithCacheReadOrWrite, +) -> Iterator[Tuple[str, int]]: + input_tokens = obj["input_tokens"] + output_tokens = obj["output_tokens"] + + cache_creation_input_tokens = 0 + cache_read_input_tokens = 0 + + if "cache_creation_input_tokens" in obj: + cache_creation_input_tokens = obj["cache_creation_input_tokens"] + if "cache_read_input_tokens" in obj: + cache_read_input_tokens = obj["cache_read_input_tokens"] + + prompt_tokens = input_tokens + cache_creation_input_tokens + cache_read_input_tokens + completion_tokens = output_tokens + + yield LLM_TOKEN_COUNT_PROMPT, 
prompt_tokens + yield LLM_TOKEN_COUNT_COMPLETION, completion_tokens + + if cache_creation_input_tokens: + yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE, cache_creation_input_tokens + if cache_read_input_tokens: + yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, cache_read_input_tokens + + +def _is_lc_usage_metadata(obj: Mapping[str, Any]) -> TypeGuard[UsageMetadata]: + return ( + "input_tokens" in obj + and "output_tokens" in obj + and "total_tokens" in obj + and isinstance(obj["input_tokens"], int) + and isinstance(obj["output_tokens"], int) + and isinstance(obj["total_tokens"], int) + ) + + +def _token_counts_from_lc_usage_metadata(obj: UsageMetadata) -> Iterator[Tuple[str, int]]: + input_tokens = obj["input_tokens"] + output_tokens = obj["output_tokens"] + total_tokens = obj["total_tokens"] + + input_audio = 0 + input_cache_creation = 0 + input_cache_read = 0 + output_audio = 0 + output_reasoning = 0 + + if "input_token_details" in obj: + input_token_details = obj["input_token_details"] + if "audio" in input_token_details: + input_audio = input_token_details["audio"] + if "cache_creation" in input_token_details: + input_cache_creation = input_token_details["cache_creation"] + if "cache_read" in input_token_details: + input_cache_read = input_token_details["cache_read"] + + if "output_token_details" in obj: + output_token_details = obj["output_token_details"] + if "audio" in output_token_details: + output_audio = output_token_details["audio"] + if "reasoning" in output_token_details: + output_reasoning = output_token_details["reasoning"] + + prompt_tokens = input_tokens + completion_tokens = output_tokens + + # heuristic adjustment for Bedrock Anthropic models with cache read or write + # https://github.com/Arize-ai/openinference/issues/2381 + if input_cache := input_cache_creation + input_cache_read: + if total_tokens == input_tokens + output_tokens + input_cache: + # for Bedrock Converse + prompt_tokens += input_cache + elif input_tokens < input_cache: + # for Bedrock InvokeModel + prompt_tokens += input_cache + total_tokens += input_cache + + yield LLM_TOKEN_COUNT_PROMPT, prompt_tokens + yield LLM_TOKEN_COUNT_COMPLETION, completion_tokens + yield LLM_TOKEN_COUNT_TOTAL, total_tokens + + if input_audio: + yield LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO, input_audio + if input_cache_creation: + yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE, input_cache_creation + if input_cache_read: + yield LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, input_cache_read + if output_audio: + yield LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO, output_audio + if output_reasoning: + yield LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING, output_reasoning + + @stop_on_exception def _token_counts(outputs: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, int]]: """Yields token count information if present.""" @@ -843,12 +966,12 @@ def _token_counts(outputs: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, i ) ): return + keys: Sequence[str] for attribute_name, keys in [ ( LLM_TOKEN_COUNT_PROMPT, ( "prompt_tokens", - "input_tokens", # Anthropic-specific key "prompt_token_count", # Gemini-specific key - https://ai.google.dev/gemini-api/docs/tokens?lang=python ), ), @@ -856,13 +979,10 @@ def _token_counts(outputs: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, i LLM_TOKEN_COUNT_COMPLETION, ( "completion_tokens", - "output_tokens", # Anthropic-specific key "candidates_token_count", # Gemini-specific key ), ), (LLM_TOKEN_COUNT_TOTAL, ("total_tokens", "total_token_count")), # Gemini-specific key - 
(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, ("cache_read_input_tokens",)), # Antrhopic - (LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE, ("cache_creation_input_tokens",)), # Antrhopic ]: if (token_count := _get_first_value(token_usage, keys)) is not None: yield attribute_name, token_count @@ -895,39 +1015,11 @@ def _token_counts(outputs: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, i yield attribute_name, token_count # maps langchain_core.messages.ai.UsageMetadata object - for attribute_name, details_key_or_none, keys in [ - (LLM_TOKEN_COUNT_PROMPT, None, ("input_tokens",)), - (LLM_TOKEN_COUNT_COMPLETION, None, ("output_tokens",)), - ( - LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO, - "input_token_details", - ("audio",), - ), - ( - LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE, - "input_token_details", - ("cache_creation",), - ), - ( - LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, - "input_token_details", - ("cache_read",), - ), - ( - LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO, - "output_token_details", - ("audio",), - ), - ( - LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING, - "output_token_details", - ("reasoning",), - ), - ]: - details = token_usage.get(details_key_or_none) if details_key_or_none else token_usage - if details is not None: - if (token_count := _get_first_value(details, keys)) is not None: - yield attribute_name, token_count + if _is_lc_usage_metadata(token_usage): + yield from _token_counts_from_lc_usage_metadata(token_usage) + + if _is_raw_anthropic_usage_with_cache_read_or_write(token_usage): + yield from _token_counts_from_raw_anthropic_usage_with_cache_read_or_write(token_usage) def _parse_token_usage_for_vertexai( diff --git a/python/instrumentation/openinference-instrumentation-langchain/tests/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-langchain/tests/test_instrumentor.py index d2da135a7..35c86668c 100644 --- a/python/instrumentation/openinference-instrumentation-langchain/tests/test_instrumentor.py +++ b/python/instrumentation/openinference-instrumentation-langchain/tests/test_instrumentor.py @@ -590,7 +590,7 @@ def test_anthropic_token_counts( span = spans[0] llm_attributes = dict(span.attributes or {}) assert llm_attributes.pop(OPENINFERENCE_SPAN_KIND, None) == LLM.value - assert llm_attributes.pop(LLM_TOKEN_COUNT_PROMPT, None) == 22 + assert llm_attributes.pop(LLM_TOKEN_COUNT_PROMPT, None) == 33 assert llm_attributes.pop(LLM_TOKEN_COUNT_COMPLETION, None) == 5 assert llm_attributes.pop(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE) == 2 assert llm_attributes.pop(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ) == 9 diff --git a/python/instrumentation/openinference-instrumentation-langchain/tests/test_token_counts.py b/python/instrumentation/openinference-instrumentation-langchain/tests/test_token_counts.py new file mode 100644 index 000000000..c99d3b67e --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-langchain/tests/test_token_counts.py @@ -0,0 +1,247 @@ +from typing import Any + +import pytest + +from openinference.instrumentation.langchain._tracer import ( + _is_lc_usage_metadata, + _is_raw_anthropic_usage_with_cache_read_or_write, + _token_counts_from_lc_usage_metadata, + _token_counts_from_raw_anthropic_usage_with_cache_read_or_write, +) +from openinference.semconv.trace import SpanAttributes + + +@pytest.mark.parametrize( + "usage_metadata,expected,is_valid", + [ + pytest.param( + { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30, + }, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 10, + 
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 20, + SpanAttributes.LLM_TOKEN_COUNT_TOTAL: 30, + }, + True, + id="basic", + ), + pytest.param( + { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 35, + "input_token_details": {"cache_creation": 3, "cache_read": 2}, + }, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 15, # 10 + 3 + 2 + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 20, + SpanAttributes.LLM_TOKEN_COUNT_TOTAL: 35, + SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE: 3, + SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ: 2, + }, + True, + id="bedrock_converse", + ), + pytest.param( + { + "input_tokens": 5, + "output_tokens": 10, + "total_tokens": 15, + "input_token_details": {"cache_creation": 20, "cache_read": 10}, + }, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 35, # 5 + 20 + 10 + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 10, + SpanAttributes.LLM_TOKEN_COUNT_TOTAL: 45, # adjusted + SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE: 20, + SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ: 10, + }, + True, + id="bedrock_invokemodel", + ), + pytest.param( + { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30, + "input_token_details": {"audio": 5}, + "output_token_details": {"reasoning": 3}, + }, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 10, + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 20, + SpanAttributes.LLM_TOKEN_COUNT_TOTAL: 30, + SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO: 5, + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING: 3, + }, + True, + id="non_cache_details", + ), + pytest.param( + { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + }, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 0, + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 0, + SpanAttributes.LLM_TOKEN_COUNT_TOTAL: 0, + }, + True, + id="zeros", + ), + pytest.param( + { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30, + "input_token_details": {"cache_creation": 0, "cache_read": 0}, + }, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 10, + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 20, + SpanAttributes.LLM_TOKEN_COUNT_TOTAL: 30, + }, + True, + id="zero_cache_no_details", + ), + pytest.param( + { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30, + "input_token_details": {}, + "output_token_details": {}, + }, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 10, + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 20, + SpanAttributes.LLM_TOKEN_COUNT_TOTAL: 30, + }, + True, + id="empty_details", + ), + pytest.param( + {"input_tokens": 10, "output_tokens": 20}, + {}, + False, + id="missing_total", + ), + pytest.param( + {"input_tokens": "10", "output_tokens": 20, "total_tokens": 30}, + {}, + False, + id="wrong_type", + ), + pytest.param( + {"output_tokens": 20, "total_tokens": 30}, + {}, + False, + id="missing_field", + ), + ], +) +def test_token_counts_from_lc_usage_metadata( + usage_metadata: dict[str, Any], expected: dict[str, int], is_valid: bool +) -> None: + """Test _token_counts_from_lc_usage_metadata with various inputs.""" + assert _is_lc_usage_metadata(usage_metadata) == is_valid + if _is_lc_usage_metadata(usage_metadata): + result = dict(_token_counts_from_lc_usage_metadata(usage_metadata)) + assert result == expected + + +@pytest.mark.parametrize( + "usage,expected,is_valid", + [ + pytest.param( + { + "input_tokens": 10, + "output_tokens": 20, + "cache_creation_input_tokens": 5, + "cache_read_input_tokens": 3, + }, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 18, # 10 + 5 + 3 + 
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 20, + SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE: 5, + SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ: 3, + }, + True, + id="both", + ), + pytest.param( + {"input_tokens": 15, "output_tokens": 25, "cache_creation_input_tokens": 8}, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 23, # 15 + 8 + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 25, + SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE: 8, + }, + True, + id="write", + ), + pytest.param( + {"input_tokens": 12, "output_tokens": 18, "cache_read_input_tokens": 6}, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 18, # 12 + 6 + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 18, + SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ: 6, + }, + True, + id="read", + ), + pytest.param( + {"input_tokens": 10, "output_tokens": 20, "cache_creation_input_tokens": 0}, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 10, + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 20, + }, + True, + id="zero_cache_write", + ), + pytest.param( + { + "input_tokens": 10, + "output_tokens": 20, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + }, + { + SpanAttributes.LLM_TOKEN_COUNT_PROMPT: 10, + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION: 20, + }, + True, + id="zero_both_cache", + ), + pytest.param( + {"input_tokens": 10, "output_tokens": 20}, + {}, + False, + id="no_cache", + ), + pytest.param( + {"input_tokens": "10", "output_tokens": 20, "cache_read_input_tokens": 5}, + {}, + False, + id="wrong_type", + ), + pytest.param( + {"output_tokens": 20, "cache_read_input_tokens": 5}, + {}, + False, + id="missing_field", + ), + ], +) +def test_token_counts_from_raw_anthropic_usage( + usage: dict[str, Any], expected: dict[str, int], is_valid: bool +) -> None: + """Test Anthropic usage with cache.""" + assert _is_raw_anthropic_usage_with_cache_read_or_write(usage) == is_valid + if _is_raw_anthropic_usage_with_cache_read_or_write(usage): + result = dict(_token_counts_from_raw_anthropic_usage_with_cache_read_or_write(usage)) + assert result == expected diff --git a/python/tox.ini b/python/tox.ini index ab56dd791..a0980b900 100644 --- a/python/tox.ini +++ b/python/tox.ini @@ -162,6 +162,7 @@ commands_pre = openllmetry-latest: uv pip install -U opentelemetry-instrumentation-openai openlit: uv pip install --reinstall {toxinidir}/instrumentation/openinference-instrumentation-openlit[test] openlit-latest: uv pip install -U openlit + uv pip list -v commands = ruff: ruff format .
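For reference, a minimal usage sketch of the two helpers introduced in _tracer.py above, assuming they remain importable from the private openinference.instrumentation.langchain._tracer module exactly as the new tests import them; the expected values mirror the bedrock_invokemodel and both test cases in test_token_counts.py.

# Usage sketch only; relies on the private module path shown in this diff.
from openinference.instrumentation.langchain._tracer import (
    _is_lc_usage_metadata,
    _is_raw_anthropic_usage_with_cache_read_or_write,
    _token_counts_from_lc_usage_metadata,
    _token_counts_from_raw_anthropic_usage_with_cache_read_or_write,
)

# LangChain UsageMetadata as seen for Bedrock InvokeModel, where cache tokens are
# not folded into input_tokens/total_tokens (the issue #2381 heuristic above applies).
lc_usage = {
    "input_tokens": 5,
    "output_tokens": 10,
    "total_tokens": 15,
    "input_token_details": {"cache_creation": 20, "cache_read": 10},
}
if _is_lc_usage_metadata(lc_usage):
    # Expected: prompt=35, completion=10, total=45, cache_write=20, cache_read=10
    print(dict(_token_counts_from_lc_usage_metadata(lc_usage)))

# Raw Anthropic usage payload with cache accounting (no total_tokens key).
raw_usage = {
    "input_tokens": 10,
    "output_tokens": 20,
    "cache_creation_input_tokens": 5,
    "cache_read_input_tokens": 3,
}
if _is_raw_anthropic_usage_with_cache_read_or_write(raw_usage):
    # Expected: prompt=18, completion=20, cache_write=5, cache_read=3
    print(dict(_token_counts_from_raw_anthropic_usage_with_cache_read_or_write(raw_usage)))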