-
Notifications
You must be signed in to change notification settings - Fork 504
Add full Anthropic router replay token handling #928
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2de6356
03d4de6
752ee4f
016232a
34bd2ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,12 @@ | ||
| import base64 | ||
| import functools | ||
| import json | ||
| import time | ||
| from collections.abc import Mapping | ||
| from collections.abc import Iterable, Mapping | ||
| from typing import Any, cast | ||
|
|
||
| import numpy as np | ||
|
|
||
| from anthropic import ( | ||
| AsyncAnthropic, | ||
| AuthenticationError, | ||
|
|
@@ -38,6 +41,7 @@ | |
| Messages, | ||
| Response, | ||
| ResponseMessage, | ||
| ResponseTokens, | ||
| SamplingArgs, | ||
| SystemMessage, | ||
| TextMessage, | ||
|
|
@@ -49,6 +53,10 @@ | |
| ) | ||
| from verifiers.utils.client_utils import setup_anthropic_client | ||
|
|
||
| # Default output-token limit used when callers omit max_tokens. | ||
| # Anthropic /v1/messages requires max_tokens on every request. | ||
| ANTHROPIC_MAX_TOKENS: int = 32768 | ||
|
|
||
|
|
||
| def _handle_anthropic_overlong_prompt(func): | ||
| """Decorator to handle overlong prompt errors from the Anthropic API.""" | ||
|
|
@@ -87,6 +95,12 @@ class AnthropicMessagesClient( | |
| """Wrapper for Messages API via AsyncAnthropic client.""" | ||
|
|
||
| def setup_client(self, config: ClientConfig) -> AsyncAnthropic: | ||
| # Log the default and remind that max_tokens is required for Anthropic. | ||
| self.logger.info( | ||
| "Anthropic client initialized. max_tokens is required on every request; " | ||
| "defaulting to ANTHROPIC_MAX_TOKENS=%d when not provided.", | ||
| ANTHROPIC_MAX_TOKENS, | ||
| ) | ||
| return setup_anthropic_client(config) | ||
|
|
||
| async def close(self) -> None: | ||
|
|
@@ -345,13 +359,37 @@ def normalize_sampling_args(sampling_args: SamplingArgs) -> dict: | |
| max_tokens = sampling_args.pop("max_tokens", None) | ||
| sampling_args.pop("n", None) | ||
| sampling_args.pop("stop", None) | ||
| if max_tokens is None: | ||
| self.logger.warning( | ||
| "max_tokens is not set but Anthropic /v1/messages endpoint requires it, falling back to max_tokens=4096" | ||
| extra_body = sampling_args.pop("extra_body", {}) | ||
| if not isinstance(extra_body, Mapping): | ||
| raise TypeError( | ||
| "sampling_args['extra_body'] must be a mapping when provided" | ||
| ) | ||
| max_tokens = 4096 | ||
| if max_tokens is None: | ||
| # Anthropic /v1/messages requires max_tokens to be set in every request. | ||
| max_tokens = ANTHROPIC_MAX_TOKENS | ||
| sampling_args["max_tokens"] = max_tokens | ||
|
|
||
| # Anthropic SDK validates top-level request fields. | ||
| # Forward unknown model args through extra_body | ||
| # so backend-specific payloads (e.g. routed_experts) can be passed via | ||
| # sampling_args without custom provider branching. | ||
| known_anthropic_args = { | ||
| "max_tokens", | ||
| "metadata", | ||
| "service_tier", | ||
| "stop_sequences", | ||
| "temperature", | ||
| "thinking", | ||
| "top_k", | ||
| "top_p", | ||
| } | ||
| extra_body_dict: dict[str, Any] = dict(extra_body) | ||
| for key in list(sampling_args.keys()): | ||
| if key not in known_anthropic_args: | ||
| extra_body_dict[key] = sampling_args.pop(key) | ||
| if extra_body_dict: | ||
| sampling_args["extra_body"] = extra_body_dict | ||
|
|
||
| return {k: v for k, v in sampling_args.items() if v is not None} | ||
|
|
||
| # Remove internal framework keys not recognized by the Anthropic SDK | ||
|
|
@@ -440,6 +478,81 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason: | |
| case _: | ||
| return None | ||
|
|
||
| def parse_completion_logprobs(logprobs: Any) -> list[float] | None: | ||
| if isinstance(logprobs, Mapping): | ||
| content = logprobs.get("content") | ||
| else: | ||
| content = getattr(logprobs, "content", None) | ||
| if content is None: | ||
| return None | ||
| if isinstance(content, Mapping): | ||
| content_items: Iterable[Any] = [content] | ||
| elif isinstance(content, list): | ||
| content_items = content | ||
| elif isinstance(content, Iterable) and not isinstance( | ||
| content, (str, bytes) | ||
| ): | ||
| content_items = list(content) | ||
| else: | ||
| return None | ||
| values: list[float] = [] | ||
| for token in content_items: | ||
| if isinstance(token, Mapping): | ||
| value = token.get("logprob") | ||
| else: | ||
| value = getattr(token, "logprob", None) | ||
| if not isinstance(value, (float, int)): | ||
| return None | ||
| values.append(float(value)) | ||
| return values | ||
|
|
||
def parse_tokens(response: AnthropicMessage) -> ResponseTokens | None:
    """Build ResponseTokens from a router-extended Anthropic message.

    Returns None unless the response carries well-formed replay fields:
    integer ``prompt_token_ids`` and ``token_ids`` lists plus per-token
    ``logprobs``. An optional ``routed_experts`` payload (b85-encoded
    int32 buffer with a shape) is decoded when present; a malformed
    payload is dropped rather than allowed to crash response parsing.
    """
    prompt_ids = getattr(response, "prompt_token_ids", None)
    completion_ids = getattr(response, "token_ids", None)
    if not isinstance(prompt_ids, list) or not isinstance(completion_ids, list):
        return None
    if not all(isinstance(token_id, int) for token_id in prompt_ids):
        return None
    if not all(isinstance(token_id, int) for token_id in completion_ids):
        return None

    completion_logprobs = parse_completion_logprobs(
        getattr(response, "logprobs", None)
    )
    if completion_logprobs is None:
        return None

    # NOTE(review): this routed-experts decoding is duplicated across two
    # client files; consider extracting a shared helper (low severity).
    has_routed_experts = (
        isinstance(
            routed_experts := getattr(response, "routed_experts", None), dict
        )
        and "data" in routed_experts
        and "shape" in routed_experts
    )
    if has_routed_experts:
        routed_experts = cast(dict[str, Any], routed_experts)
        try:
            # b85-encoded int32 buffer reshaped to the advertised shape.
            routed_experts = cast(
                list[list[list[int]]],
                (
                    np.frombuffer(
                        base64.b85decode(routed_experts["data"]), dtype=np.int32
                    )
                    .reshape(routed_experts["shape"])
                    .tolist()
                ),
            )
        except (ValueError, TypeError):
            # Malformed payloads (corrupt base85, wrong buffer size for the
            # shape, non-str data) must not discard an otherwise-valid
            # response; drop only the experts, matching the defensive
            # return-None style of the checks above.
            routed_experts = None
    else:
        routed_experts = None

    return ResponseTokens(
        prompt_ids=prompt_ids,
        prompt_mask=[0] * len(prompt_ids),
        completion_ids=completion_ids,
        completion_mask=[1] * len(completion_ids),
        completion_logprobs=completion_logprobs,
        routed_experts=routed_experts,
    )
|
|
||
| content, reasoning_content, tool_calls, thinking_blocks = parse_content( | ||
| response.content | ||
| ) | ||
|
|
@@ -465,6 +578,6 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason: | |
| tool_calls=tool_calls or None, | ||
| finish_reason=parse_finish_reason(response), | ||
| is_truncated=is_truncated, | ||
| tokens=None, | ||
| tokens=parse_tokens(response), | ||
| ), | ||
| ) | ||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Malformed `routed_experts` crashes entire response parsing
Medium Severity
The `parse_tokens` function gracefully returns `None` when `prompt_token_ids`, `token_ids`, or `logprobs` are missing or invalid, but the `routed_experts` decoding block (`base64.b85decode` + `np.frombuffer` + `.reshape`) has no try-except. If the server returns a `routed_experts` dict with valid `data` and `shape` keys but malformed content (e.g. corrupt base85 or a shape mismatch), an unhandled `ValueError` propagates out of `from_native_response`, causing the entire response — including valid text content — to be lost as a `ModelError`. Wrapping the decode in a try-except and falling back to `routed_experts = None` would be consistent with the rest of the function's defensive design.