Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 270 additions & 0 deletions tests/test_openai_chat_completions_token_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
from typing import Any, cast

import pytest

from verifiers.clients.openai_chat_completions_client import OpenAIChatCompletionsClient
from verifiers.clients.openai_chat_completions_token_client import (
OpenAIChatCompletionsTokenClient,
)
from verifiers.types import State


class _NoopClient:
base_url = "http://localhost:8000/v1"

def with_options(self, **kwargs): # noqa: ANN003
return self


class _RecordingClient(_NoopClient):
    """Fake transport that records every ``post`` call and echoes the request."""

    def __init__(self) -> None:
        self.calls: list[dict[str, Any]] = []

    async def post(self, path: str, body: dict[str, Any], cast_to: type) -> Any:
        # Capture the full request so tests can assert on route and payload.
        record = dict(path=path, body=body, cast_to=cast_to)
        self.calls.append(record)
        return dict(ok=True, path=path, body=body)


class _PromptIdTestClient(OpenAIChatCompletionsTokenClient):
    """Token client whose ``tokenize`` returns canned IDs keyed on input shape.

    The branches mirror the calls ``get_prompt_ids`` is expected to make:
    a bare string (last assistant text), the matched message prefix with
    ``add_generation_prompt`` disabled, and finally the full prompt.
    """

    def __init__(self, full_prompt_ids: list[int]) -> None:
        super().__init__(_NoopClient())
        # IDs returned for the full-prompt tokenize call.
        self._full_prompt_ids = full_prompt_ids

    async def to_native_prompt(self, messages):  # type: ignore[override]
        # Identity conversion; no template metadata.
        return cast(Any, messages), {}

    async def tokenize(  # type: ignore[override]
        self,
        messages,
        tools,
        model,
        # None sentinel instead of a shared mutable ``{}`` default
        # (mutable default arguments are shared across calls).
        extra_kwargs: dict | None = None,
        **kwargs,
    ) -> list[int]:
        extra_kwargs = {} if extra_kwargs is None else extra_kwargs
        if isinstance(messages, str):
            assert messages == "World!"
            return [777]

        if messages == [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "World!"},
        ]:
            assert extra_kwargs == {"add_generation_prompt": False}
            return [1, 777, 999]

        return self._full_prompt_ids


class _NoTokenizeClient(OpenAIChatCompletionsTokenClient):
    """Token client that fails the test if ``tokenize`` is ever invoked.

    Used to assert that ``get_prompt_ids`` bails out before any tokenization
    when no trajectory step is a prefix of the prompt.
    """

    def __init__(self) -> None:
        super().__init__(_NoopClient())

    async def to_native_prompt(self, messages):  # type: ignore[override]
        # Identity conversion; no template metadata.
        return cast(Any, messages), {}

    async def tokenize(  # type: ignore[override]
        self,
        messages,
        tools,
        model,
        # None sentinel instead of a shared mutable ``{}`` default
        # (mutable default arguments are shared across calls).
        extra_kwargs: dict | None = None,
        **kwargs,
    ) -> list[int]:
        raise AssertionError("tokenize should not be called without a prefix match")


def _make_step(
prompt: list[dict[str, str]],
completion: list[dict[str, str]],
prompt_ids: list[int],
completion_ids: list[int],
) -> dict[str, Any]:
return {
"prompt": prompt,
"completion": completion,
"tokens": {
"prompt_ids": prompt_ids,
"completion_ids": completion_ids,
},
}


@pytest.mark.asyncio
async def test_get_prompt_ids_uses_largest_message_prefix_match():
    """The step with the longest message-prefix match supplies the token IDs."""
    expected_ids = [1, 2, 3, 4, 999, 5]
    client = _PromptIdTestClient(full_prompt_ids=expected_ids)

    short_step = _make_step(
        prompt=[{"role": "user", "content": "u1"}],
        completion=[{"role": "assistant", "content": "a1"}],
        prompt_ids=[1],
        completion_ids=[2],
    )
    long_step = _make_step(
        prompt=[
            {"role": "user", "content": "u1"},
            {"role": "assistant", "content": "a1"},
            {"role": "user", "content": "u2"},
        ],
        completion=[{"role": "assistant", "content": "a2"}],
        prompt_ids=[1, 2, 3],
        completion_ids=[4],
    )
    state = cast(
        State,
        {"model": "test-model", "trajectory": [short_step, long_step]},
    )
    prompt_messages = cast(
        Any,
        [
            {"role": "user", "content": "u1"},
            {"role": "assistant", "content": "a1"},
            {"role": "user", "content": "u2"},
            {"role": "assistant", "content": "a2"},
            {"role": "user", "content": "u3"},
        ],
    )

    prompt_ids = await client.get_prompt_ids(state, prompt_messages, oai_tools=None)

    assert prompt_ids == expected_ids


@pytest.mark.asyncio
async def test_get_prompt_ids_returns_none_when_no_prefix_match():
    """When no trajectory step prefixes the prompt, get_prompt_ids is None."""
    client = _NoTokenizeClient()
    stale_step = _make_step(
        prompt=[{"role": "user", "content": "old"}],
        completion=[{"role": "assistant", "content": "reply"}],
        prompt_ids=[1],
        completion_ids=[2],
    )
    state = cast(State, {"model": "test-model", "trajectory": [stale_step]})
    fresh_prompt = cast(Any, [{"role": "user", "content": "new"}])

    result = await client.get_prompt_ids(state, fresh_prompt, oai_tools=None)

    assert result is None


@pytest.mark.asyncio
async def test_get_native_response_falls_back_to_super_when_no_prefix_match(
    monkeypatch: pytest.MonkeyPatch,
):
    """Without prompt IDs, get_native_response defers to the base client."""
    client = OpenAIChatCompletionsTokenClient(_NoopClient())
    sentinel = {"source": "super"}
    recorded: list[dict[str, Any]] = []

    async def fake_get_prompt_ids(self, state, prompt_messages, oai_tools):  # noqa: ANN001
        # Simulate "no trajectory step prefixes the prompt".
        return None

    async def fake_super_get_native_response(  # noqa: ANN001
        self,
        prompt,
        model,
        sampling_args,
        tools=None,
        **kwargs,
    ):
        recorded.append(
            {
                "prompt": prompt,
                "model": model,
                "sampling_args": sampling_args,
                "tools": tools,
            }
        )
        return sentinel

    monkeypatch.setattr(
        OpenAIChatCompletionsTokenClient, "get_prompt_ids", fake_get_prompt_ids
    )
    monkeypatch.setattr(
        OpenAIChatCompletionsClient,
        "get_native_response",
        fake_super_get_native_response,
    )

    step = _make_step(
        prompt=[{"role": "user", "content": "u1"}],
        completion=[{"role": "assistant", "content": "a1"}],
        prompt_ids=[1],
        completion_ids=[2],
    )
    state = cast(State, {"model": "test-model", "trajectory": [step]})
    prompt = cast(Any, [{"role": "user", "content": "u2"}])

    response = await client.get_native_response(
        prompt=prompt,
        model="test-model",
        sampling_args={},
        tools=None,
        state=state,
    )

    # The super() implementation was hit exactly once with the raw prompt.
    assert response is sentinel
    assert len(recorded) == 1
    assert recorded[0]["prompt"] == prompt


@pytest.mark.asyncio
async def test_get_native_response_uses_token_route_when_prompt_ids_available(
    monkeypatch: pytest.MonkeyPatch,
):
    """With prompt IDs resolved, the request goes to the token endpoint."""
    transport = _RecordingClient()
    client = OpenAIChatCompletionsTokenClient(transport)

    async def fake_get_prompt_ids(self, state, prompt_messages, oai_tools):  # noqa: ANN001
        # Simulate a successful prefix match producing stitched token IDs.
        return [10, 20]

    monkeypatch.setattr(
        OpenAIChatCompletionsTokenClient, "get_prompt_ids", fake_get_prompt_ids
    )

    step = _make_step(
        prompt=[{"role": "user", "content": "u1"}],
        completion=[{"role": "assistant", "content": "a1"}],
        prompt_ids=[1],
        completion_ids=[2],
    )
    state = cast(State, {"model": "test-model", "trajectory": [step]})
    prompt = cast(Any, [{"role": "user", "content": "u2"}])

    response = await client.get_native_response(
        prompt=prompt,
        model="test-model",
        sampling_args={},
        tools=None,
        state=state,
    )

    assert response["ok"] is True
    assert len(transport.calls) == 1
    call = transport.calls[0]
    assert call["path"] == "/chat/completions/tokens"
    assert call["body"]["tokens"] == [10, 20]
65 changes: 58 additions & 7 deletions verifiers/clients/openai_chat_completions_token_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Optional, cast
from typing import Any, Optional, cast

from openai import AsyncOpenAI, BaseModel
from openai.types.chat import ChatCompletion
Expand Down Expand Up @@ -85,6 +85,10 @@ def normalize_sampling_args(sampling_args: SamplingArgs):
prompt, model, sampling_args, tools
)
prompt_ids = await self.get_prompt_ids(state, prompt, tools)
if prompt_ids is None:
return await super().get_native_response(
prompt, model, sampling_args, tools
)
extra_body = sampling_args.pop("extra_body", {})
body = dict(
model=model,
Expand All @@ -106,18 +110,65 @@ async def get_prompt_ids(
state: State,
prompt_messages: OpenAIChatMessages,
oai_tools: list[OpenAITool] | None,
) -> list[int]:
) -> list[int] | None:
"""
Build prompt_ids (token prompt) corresponding to prompt_messages. We assume
that this method is called *before* making the model response from
prompt_messages, i.e. the previous turn's prompt and completion do not yet
include the environment response and next turn's model response.

Returns None when no trajectory step has a message-level prefix match with
prompt_messages.
"""
prev_turn_tokens = state["trajectory"][-1]["tokens"]
assert prev_turn_tokens is not None
prev_turn_prompt_ids = prev_turn_tokens["prompt_ids"]
prev_turn_completion_ids = prev_turn_tokens["completion_ids"]
prev_turn_ids = prev_turn_prompt_ids + prev_turn_completion_ids

def normalize_for_comparison(value: Any) -> Any:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make this a general message util? It seems useful in other places too. Also, I vaguely remember we already have a similar util, but I might be wrong.

if hasattr(value, "model_dump"):
return normalize_for_comparison(value.model_dump())
if isinstance(value, Mapping):
return {
str(key): normalize_for_comparison(val)
for key, val in value.items()
}
if isinstance(value, list):
return [normalize_for_comparison(item) for item in value]
return value

async def find_largest_prefix_match_tokens() -> list[int] | None:
"""Scan trajectory backwards for the step whose messages form the longest
prefix of prompt_messages. Returns that step's token IDs, or None."""
normalized_prompt_messages = normalize_for_comparison(prompt_messages)
best_prefix_len = -1
best_step_tokens = None
for step in reversed(state["trajectory"]):
step_tokens = step["tokens"]
if step_tokens is None:
continue
step_messages = cast(Any, [*step["prompt"], *step["completion"]])
step_prompt_messages, _ = await self.to_native_prompt(step_messages)
normalized_step_messages = normalize_for_comparison(
step_prompt_messages
)
prefix_len = len(normalized_step_messages)
if prefix_len <= 0:
continue
if prefix_len <= best_prefix_len:
continue
if prefix_len > len(normalized_prompt_messages):
continue
if normalized_prompt_messages[:prefix_len] != normalized_step_messages:
continue
best_prefix_len = prefix_len
best_step_tokens = step_tokens
if best_prefix_len == len(normalized_prompt_messages):
break

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per-turn backward scan may be expensive

Medium Severity

get_prompt_ids() now walks backward over the entire state["trajectory"] and calls to_native_prompt() per step until it finds the best prefix, which can add significant overhead on long trajectories and slow every generation turn.

Fix in Cursor Fix in Web

if best_step_tokens is None:
return None
return best_step_tokens["prompt_ids"] + best_step_tokens["completion_ids"]

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefix match can miss equivalent messages

Medium Severity

get_prompt_ids’s new message-level prefix matcher compares normalized message objects for strict equality, which can differ across representations (e.g., to_native_prompt emitting {"content": None} while incoming prompt_messages omits content, or other default/extra fields). This can produce false “no prefix match”, disabling the token route unexpectedly.

Fix in Cursor Fix in Web

prev_turn_ids = await find_largest_prefix_match_tokens()
if prev_turn_ids is None:
return None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefix match ignores tool-dependent tokenization

Medium Severity

find_largest_prefix_match_tokens() selects a trajectory step using only a message-level prefix comparison, but the stitched prev_turn_ids are later combined with full_ids produced by tokenize(..., tools=oai_tools). If the effective tool set differs from when the matched step’s tokens were produced, prev_turn_ids may not align with full_ids, yielding incorrect env_response_ids and an invalid prompt for /chat/completions/tokens.

Fix in Cursor Fix in Web


def compute_suffix_ids(lst: list[int], value: int) -> list[int]:
"""Returns all tokens after the last occurrence of `value` in `lst`, if any."""
Expand Down
Loading