diff --git a/tests/test_client_multimodal_types.py b/tests/test_client_multimodal_types.py index d51c38262..c2b715a9a 100644 --- a/tests/test_client_multimodal_types.py +++ b/tests/test_client_multimodal_types.py @@ -1,4 +1,7 @@ +import base64 + import pytest +import numpy as np from types import SimpleNamespace from verifiers.clients.openai_chat_completions_client import OpenAIChatCompletionsClient @@ -237,3 +240,115 @@ async def test_anthropic_tool_call_round_trips_thinking_blocks(): {"type": "thinking", "thinking": "hidden chain", "signature": "sig_1"}, {"type": "tool_use", "id": "call_1", "name": "lookup", "input": {"q": "x"}}, ] + + +class _CaptureAnthropicMessages: + def __init__(self) -> None: + self.last_kwargs: dict | None = None + + async def create(self, **kwargs): + self.last_kwargs = kwargs + return SimpleNamespace() + + +class _CaptureAnthropicClient: + def __init__(self) -> None: + self.messages = _CaptureAnthropicMessages() + + +@pytest.mark.asyncio +async def test_anthropic_get_native_response_forwards_router_replay_with_extra_body(): + pytest.importorskip("anthropic") + from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient + + native_client = _CaptureAnthropicClient() + client = AnthropicMessagesClient(native_client) + + await client.get_native_response( + prompt=[{"role": "user", "content": "hello"}], + model="claude-test", + sampling_args={ + "max_tokens": 32, + "temperature": 0.2, + "extra_body": {"seed": 7}, + "routed_experts": [[[1]]], + }, + ) + + sent = native_client.messages.last_kwargs + assert sent is not None + assert sent["temperature"] == 0.2 + assert sent["extra_body"] == {"seed": 7, "routed_experts": [[[1]]]} + assert "routed_experts" not in sent + + +@pytest.mark.asyncio +async def test_anthropic_get_native_response_defaults_max_tokens_when_missing(): + pytest.importorskip("anthropic") + from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient + + native_client = _CaptureAnthropicClient() + client = AnthropicMessagesClient(native_client) + + await client.get_native_response( + prompt=[{"role": "user", "content": "hello"}], + model="claude-test", + sampling_args={"temperature": 0.2}, + ) + + sent = native_client.messages.last_kwargs + assert sent is not None + assert sent["max_tokens"] == 32768 + assert sent["temperature"] == 0.2 + + +@pytest.mark.asyncio +async def test_anthropic_from_native_response_extracts_tokens_and_router_replay(): + pytest.importorskip("anthropic") + from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient + + client = AnthropicMessagesClient(object()) + routed = np.array([[[11, 12]], [[21, 22]]], dtype=np.int32) + native_response = SimpleNamespace( + id="msg_tokens", + model="claude-haiku-4-5", + stop_reason="end_turn", + usage=SimpleNamespace(input_tokens=3, output_tokens=2), + content=[SimpleNamespace(type="text", text="ok")], + prompt_token_ids=[1, 2, 3], + token_ids=[4, 5], + logprobs={"content": [{"logprob": -0.1}, {"logprob": -0.2}]}, + routed_experts={ + "data": base64.b85encode(routed.tobytes()).decode("utf-8"), + "shape": list(routed.shape), + }, + ) + + response = await client.from_native_response(native_response) + + assert response.message.tokens is not None + assert response.message.tokens.prompt_ids == [1, 2, 3] + assert response.message.tokens.completion_ids == [4, 5] + assert response.message.tokens.completion_logprobs == [-0.1, -0.2] + assert response.message.tokens.routed_experts == routed.tolist() + + +@pytest.mark.asyncio +async def test_anthropic_from_native_response_requires_logprobs_for_tokens(): + pytest.importorskip("anthropic") + from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient + + client = AnthropicMessagesClient(object()) + native_response = SimpleNamespace( + id="msg_tokens_missing", + model="claude-haiku-4-5", + stop_reason="end_turn", + usage=SimpleNamespace(input_tokens=2, output_tokens=1), + content=[SimpleNamespace(type="text", text="ok")], + prompt_token_ids=[1, 2], + token_ids=[3], + logprobs=None, + ) + + response = await client.from_native_response(native_response) + assert response.message.tokens is None diff --git a/verifiers/clients/anthropic_messages_client.py b/verifiers/clients/anthropic_messages_client.py index 9e80b63b7..392901318 100644 --- a/verifiers/clients/anthropic_messages_client.py +++ b/verifiers/clients/anthropic_messages_client.py @@ -1,9 +1,12 @@ +import base64 import functools import json import time -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from typing import Any, cast +import numpy as np + from anthropic import ( AsyncAnthropic, AuthenticationError, @@ -38,6 +41,7 @@ Messages, Response, ResponseMessage, + ResponseTokens, SamplingArgs, SystemMessage, TextMessage, @@ -49,6 +53,10 @@ ) from verifiers.utils.client_utils import setup_anthropic_client +# Default output-token limit used when callers omit max_tokens. +# Anthropic /v1/messages requires max_tokens on every request. +ANTHROPIC_MAX_TOKENS: int = 32768 + def _handle_anthropic_overlong_prompt(func): """Decorator to handle overlong prompt errors from the Anthropic API.""" @@ -87,6 +95,12 @@ class AnthropicMessagesClient( """Wrapper for Messages API via AsyncAnthropic client.""" def setup_client(self, config: ClientConfig) -> AsyncAnthropic: + # Log the default and remind that max_tokens is required for Anthropic. + self.logger.info( + "Anthropic client initialized. max_tokens is required on every request; " + "defaulting to ANTHROPIC_MAX_TOKENS=%d when not provided.", + ANTHROPIC_MAX_TOKENS, + ) return setup_anthropic_client(config) async def close(self) -> None: @@ -345,13 +359,37 @@ def normalize_sampling_args(sampling_args: SamplingArgs) -> dict: max_tokens = sampling_args.pop("max_tokens", None) sampling_args.pop("n", None) sampling_args.pop("stop", None) - if max_tokens is None: - self.logger.warning( - "max_tokens is not set but Anthropic /v1/messages endpoint requires it, falling back to max_tokens=4096" + extra_body = sampling_args.pop("extra_body", {}) + if not isinstance(extra_body, Mapping): + raise TypeError( + "sampling_args['extra_body'] must be a mapping when provided" ) - max_tokens = 4096 + if max_tokens is None: + # Anthropic /v1/messages requires max_tokens to be set in every request. + max_tokens = ANTHROPIC_MAX_TOKENS sampling_args["max_tokens"] = max_tokens + # Anthropic SDK validates top-level request fields. + # Forward unknown model args through extra_body + # so backend-specific payloads (e.g. routed_experts) can be passed via + # sampling_args without custom provider branching. + known_anthropic_args = { + "max_tokens", + "metadata", + "service_tier", + "stop_sequences", + "temperature", + "thinking", + "top_k", + "top_p", + } + extra_body_dict: dict[str, Any] = dict(extra_body) + for key in list(sampling_args.keys()): + if key not in known_anthropic_args: + extra_body_dict[key] = sampling_args.pop(key) + if extra_body_dict: + sampling_args["extra_body"] = extra_body_dict + return {k: v for k, v in sampling_args.items() if v is not None} # Remove internal framework keys not recognized by the Anthropic SDK @@ -440,6 +478,81 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason: case _: return None + def parse_completion_logprobs(logprobs: Any) -> list[float] | None: + if isinstance(logprobs, Mapping): + content = logprobs.get("content") + else: + content = getattr(logprobs, "content", None) + if content is None: + return None + if isinstance(content, Mapping): + content_items: Iterable[Any] = [content] + elif isinstance(content, list): + content_items = content + elif isinstance(content, Iterable) and not isinstance( + content, (str, bytes) + ): + content_items = list(content) + else: + return None + values: list[float] = [] + for token in content_items: + if isinstance(token, Mapping): + value = token.get("logprob") + else: + value = getattr(token, "logprob", None) + if not isinstance(value, (float, int)): + return None + values.append(float(value)) + return values + + def parse_tokens(response: AnthropicMessage) -> ResponseTokens | None: + prompt_ids = getattr(response, "prompt_token_ids", None) + completion_ids = getattr(response, "token_ids", None) + if not isinstance(prompt_ids, list) or not isinstance(completion_ids, list): + return None + if not all(isinstance(token_id, int) for token_id in prompt_ids): + return None + if not all(isinstance(token_id, int) for token_id in completion_ids): + return None + + completion_logprobs = parse_completion_logprobs( + getattr(response, "logprobs", None) + ) + if completion_logprobs is None: + return None + + has_routed_experts = ( + isinstance( + routed_experts := getattr(response, "routed_experts", None), dict + ) + and "data" in routed_experts + and "shape" in routed_experts + ) + if has_routed_experts: + routed_experts = cast(dict[str, Any], routed_experts) + routed_experts = cast( + list[list[list[int]]], + ( + np.frombuffer( + base64.b85decode(routed_experts["data"]), dtype=np.int32 + ) + .reshape(routed_experts["shape"]) + .tolist() + ), + ) + else: + routed_experts = None + + return ResponseTokens( + prompt_ids=prompt_ids, + prompt_mask=[0] * len(prompt_ids), + completion_ids=completion_ids, + completion_mask=[1] * len(completion_ids), + completion_logprobs=completion_logprobs, + routed_experts=routed_experts, + ) + content, reasoning_content, tool_calls, thinking_blocks = parse_content( response.content ) @@ -465,6 +578,6 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason: tool_calls=tool_calls or None, finish_reason=parse_finish_reason(response), is_truncated=is_truncated, - tokens=None, + tokens=parse_tokens(response), ), )