Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions tests/test_rate_limit_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Tests for rate limit error handling and retry mechanism."""

import httpx
import pytest
from openai import RateLimitError as OpenAIRateLimitError

from verifiers.errors import RateLimitError as VFRateLimitError
from verifiers.types import EvalConfig
from verifiers.utils.async_utils import maybe_retry


def _make_rate_limit_error() -> OpenAIRateLimitError:
    """Build an OpenAI ``RateLimitError`` wrapping a synthetic HTTP 429 response.

    NOTE(review): currently unused by the tests in this module; kept as a
    fixture helper for tests exercising the raw OpenAI exception path.
    """
    request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
    payload = {"error": {"message": "Too many requests", "type": "rate_limit_error"}}
    response = httpx.Response(status_code=429, request=request, json=payload)
    return OpenAIRateLimitError("Rate limit exceeded", response=response, body=None)


@pytest.mark.asyncio
async def test_rate_limit_error_retries_with_config():
    """A rollout whose state carries a rate-limit error is retried until it succeeds."""
    attempts = 0

    async def flaky():
        nonlocal attempts
        attempts += 1
        # Fail twice with a rate-limit error, then succeed on the third call.
        if attempts >= 3:
            return {"result": "success"}
        return {"error": VFRateLimitError("Rate limited")}

    outcome = await maybe_retry(flaky, max_retries=3, initial=0.01)()

    assert attempts == 3
    assert outcome["result"] == "success"


@pytest.mark.asyncio
async def test_rate_limit_error_exhaustion_returns_error():
    """When every attempt fails, the final error is surfaced in the returned state."""

    async def never_succeeds():
        return {"error": VFRateLimitError("Always rate limited")}

    outcome = await maybe_retry(never_succeeds, max_retries=2, initial=0.01)()

    assert "error" in outcome
    assert isinstance(outcome["error"], VFRateLimitError)


@pytest.mark.asyncio
async def test_no_retry_when_max_retries_zero():
    """With max_retries=0 the wrapped function runs exactly once."""
    attempts = 0

    async def always_rate_limited():
        nonlocal attempts
        attempts += 1
        return {"error": VFRateLimitError("Rate limited")}

    outcome = await maybe_retry(always_rate_limited, max_retries=0)()

    # Exactly one invocation: retry is disabled.
    assert attempts == 1
    assert "error" in outcome


@pytest.mark.asyncio
async def test_jitter_configuration():
    """maybe_retry accepts both jitter=True and jitter=False without raising."""

    async def always_rate_limited():
        return {"error": VFRateLimitError("Rate limited")}

    # Exercise both jitter settings; neither should raise, and both should
    # surface the error after retries are exhausted.
    for use_jitter in (True, False):
        wrapped = maybe_retry(
            always_rate_limited, max_retries=1, initial=0.01, jitter=use_jitter
        )
        outcome = await wrapped()
        assert "error" in outcome


@pytest.mark.asyncio
async def test_multiple_error_types_in_retry():
    """Any error class listed in error_types triggers a retry."""
    from verifiers.errors import InfraError

    attempts = 0

    async def mixed_failures():
        nonlocal attempts
        attempts += 1
        # First call: rate-limit error; second: infra error; then success.
        if attempts == 1:
            return {"error": VFRateLimitError("Rate limited")}
        if attempts == 2:
            return {"error": InfraError("Infra error")}
        return {"result": "success"}

    retrying = maybe_retry(
        mixed_failures,
        max_retries=3,
        initial=0.01,
        error_types=(VFRateLimitError, InfraError),
    )
    outcome = await retrying()

    assert outcome["result"] == "success"
    assert attempts == 3


@pytest.mark.asyncio
async def test_retry_configuration_values_are_used():
    """EvalConfig stores explicitly supplied retry timing parameters."""
    from verifiers.types import ClientConfig

    retry_overrides = {
        "retry_base_delay": 2.0,
        "retry_max_backoff": 30.0,
        "retry_jitter": False,
    }
    config = EvalConfig(
        env_id="test_env",
        env_args={},
        env_dir_path="/tmp/test",
        model="gpt-4",
        client_config=ClientConfig(api_key_var="TEST_KEY"),
        sampling_args={},
        num_examples=1,
        rollouts_per_example=1,
        max_concurrent=1,
        **retry_overrides,
    )

    assert config.retry_base_delay == 2.0
    assert config.retry_max_backoff == 30.0
    assert config.retry_jitter is False


@pytest.mark.asyncio
async def test_retry_configuration_defaults():
    """EvalConfig retry-timing defaults line up with maybe_retry's defaults."""
    from verifiers.types import ClientConfig

    config = EvalConfig(
        env_id="test_env",
        env_args={},
        env_dir_path="/tmp/test",
        model="gpt-4",
        client_config=ClientConfig(api_key_var="TEST_KEY"),
        sampling_args={},
        num_examples=1,
        rollouts_per_example=1,
        max_concurrent=1,
    )

    # Defaults must mirror maybe_retry so omitting them changes nothing.
    assert (config.retry_base_delay, config.retry_max_backoff) == (1.0, 60.0)
    assert config.retry_jitter is True
5 changes: 4 additions & 1 deletion verifiers/clients/anthropic_messages_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)

from verifiers.clients.client import Client
from verifiers.errors import OverlongPromptError
from verifiers.errors import OverlongPromptError, RateLimitError as VFRateLimitError
from verifiers.types import (
AssistantMessage,
ClientConfig,
Expand Down Expand Up @@ -60,6 +60,9 @@ async def wrapper(*args, **kwargs):
except (AuthenticationError, PermissionDeniedError):
raise
except BadRequestError as e:
# Check for HTTP 429 rate limit
if hasattr(e, "response") and hasattr(e.response, "status_code") and e.response.status_code == 429:
raise VFRateLimitError(f"Anthropic rate limit: {e}") from e
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anthropic rate limit check is unreachable dead code

High Severity

The Anthropic SDK raises anthropic.RateLimitError (not BadRequestError) for HTTP 429 responses. The new check for e.response.status_code == 429 inside except BadRequestError is unreachable — a BadRequestError will never have a 429 status code. Anthropic rate limit errors are not imported or caught, so they fall through to the generic except Exception in client.py and become a ModelError, bypassing the retry mechanism entirely. The OpenAI client correctly imports and catches a separate RateLimitError exception.

Fix in Cursor Fix in Web

error_text = e.message.lower()
context_length_phrases = [
"prompt is too long",
Expand Down
4 changes: 4 additions & 0 deletions verifiers/clients/openai_chat_completions_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AuthenticationError,
BadRequestError,
PermissionDeniedError,
RateLimitError as OpenAIRateLimitError,
)
from openai.types.chat import (
ChatCompletion,
Expand Down Expand Up @@ -40,6 +41,7 @@
EmptyModelResponseError,
InvalidModelResponseError,
OverlongPromptError,
RateLimitError as VFRateLimitError,
)
from verifiers.types import (
AssistantMessage,
Expand Down Expand Up @@ -71,6 +73,8 @@ async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)
except (AuthenticationError, PermissionDeniedError):
raise
except OpenAIRateLimitError as e:
raise VFRateLimitError(f"OpenAI rate limit: {e}") from e
except BadRequestError as e:
error_text = e.response.text.lower()
context_length_phrases = [
Expand Down
20 changes: 19 additions & 1 deletion verifiers/envs/env_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,25 @@ async def run_rollout( # type: ignore[override]
model: str,
sampling_args: SamplingArgs,
max_retries: int = 0,
retry_base_delay: float = 1.0,
retry_max_backoff: float = 60.0,
retry_jitter: bool = True,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
) -> vf.RolloutOutput:
env = self.get_env_for_task(input["task"])
env_client = env_client or env.env_client or self.env_client
return await env.run_rollout(
input, client, model, sampling_args, max_retries, state_columns, env_client
input,
client,
model,
sampling_args,
max_retries,
retry_base_delay,
retry_max_backoff,
retry_jitter,
state_columns,
env_client,
)

@final
Expand All @@ -293,6 +305,9 @@ async def run_group( # type: ignore[override]
model: str,
sampling_args: SamplingArgs,
max_retries: int = 0,
retry_base_delay: float = 1.0,
retry_max_backoff: float = 60.0,
retry_jitter: bool = True,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
) -> list[vf.RolloutOutput]:
Expand All @@ -304,6 +319,9 @@ async def run_group( # type: ignore[override]
model,
sampling_args,
max_retries,
retry_base_delay,
retry_max_backoff,
retry_jitter,
state_columns,
env_client,
)
Expand Down
43 changes: 41 additions & 2 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,9 @@ async def run_rollout(
model: str,
sampling_args: SamplingArgs,
max_retries: int = 0,
retry_base_delay: float = 1.0,
retry_max_backoff: float = 60.0,
retry_jitter: bool = True,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
) -> RolloutOutput:
Expand All @@ -722,6 +725,9 @@ async def run_rollout(
model,
sampling_args,
max_retries,
retry_base_delay,
retry_max_backoff,
retry_jitter,
state_columns,
)

Expand All @@ -742,7 +748,13 @@ async def run_rollout_attempt() -> State:

return state

state = await maybe_retry(run_rollout_attempt, max_retries=max_retries)()
state = await maybe_retry(
run_rollout_attempt,
max_retries=max_retries,
initial=retry_base_delay,
max_wait=retry_max_backoff,
jitter=retry_jitter,
)()
output = state_to_output(state, state_columns or [])
return output

Expand All @@ -754,6 +766,9 @@ async def run_group(
model: str,
sampling_args: SamplingArgs,
max_retries: int = 0,
retry_base_delay: float = 1.0,
retry_max_backoff: float = 60.0,
retry_jitter: bool = True,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
**kwargs,
Expand All @@ -776,6 +791,9 @@ async def run_group(
model,
sampling_args,
max_retries,
retry_base_delay,
retry_max_backoff,
retry_jitter,
state_columns,
)

Expand All @@ -799,7 +817,13 @@ async def run_group_attempt() -> list[State]:
await self.rubric.dummy_score_group(group_states)
return group_states

group_states = await maybe_retry(run_group_attempt, max_retries=max_retries)()
group_states = await maybe_retry(
run_group_attempt,
max_retries=max_retries,
initial=retry_base_delay,
max_wait=retry_max_backoff,
jitter=retry_jitter,
)()
outputs = [
state_to_output(state, state_columns or []) for state in group_states
]
Expand All @@ -819,6 +843,9 @@ async def generate(
hf_hub_dataset_name: str | None = None,
independent_scoring: bool = False,
max_retries: int = 0,
retry_base_delay: float = 1.0,
retry_max_backoff: float = 60.0,
retry_jitter: bool = True,
on_start: StartCallback | None = None,
on_progress: ProgressCallback | list[ProgressCallback] | None = None,
on_log: LogCallback | None = None,
Expand Down Expand Up @@ -1012,6 +1039,9 @@ def get_client_for_group() -> Client | ClientConfig:
model,
sampling_args,
max_retries=max_retries,
retry_base_delay=retry_base_delay,
retry_max_backoff=retry_max_backoff,
retry_jitter=retry_jitter,
state_columns=state_columns,
),
),
Expand All @@ -1038,6 +1068,9 @@ def get_client_for_group() -> Client | ClientConfig:
model,
sampling_args,
max_retries=max_retries,
retry_base_delay=retry_base_delay,
retry_max_backoff=retry_max_backoff,
retry_jitter=retry_jitter,
state_columns=state_columns,
),
),
Expand Down Expand Up @@ -1152,6 +1185,9 @@ async def evaluate(
hf_hub_dataset_name: str | None = None,
independent_scoring: bool = False,
max_retries: int = 0,
retry_base_delay: float = 1.0,
retry_max_backoff: float = 60.0,
retry_jitter: bool = True,
on_start: StartCallback | None = None,
on_progress: ProgressCallback | list[ProgressCallback] | None = None,
on_log: LogCallback | None = None,
Expand Down Expand Up @@ -1179,6 +1215,9 @@ async def evaluate(
hf_hub_dataset_name=hf_hub_dataset_name,
independent_scoring=independent_scoring,
max_retries=max_retries,
retry_base_delay=retry_base_delay,
retry_max_backoff=retry_max_backoff,
retry_jitter=retry_jitter,
on_start=on_start,
on_progress=on_progress,
on_log=on_log,
Expand Down
6 changes: 6 additions & 0 deletions verifiers/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class InfraError(Error):
pass


class RateLimitError(Error):
    """Used to catch rate limit errors (HTTP 429, Too Many Requests).

    Client wrappers translate provider-specific rate-limit exceptions
    (e.g. ``openai.RateLimitError``) into this type so the shared retry
    machinery can recognize and retry them uniformly.
    """

    pass


class SandboxError(InfraError):
"""Used to catch errors while interacting with sandboxes."""

Expand Down
Loading