From 51b05c791fb592d7357eb3b8c8aa41ae1b27c768 Mon Sep 17 00:00:00 2001 From: hallerite Date: Wed, 11 Feb 2026 03:07:16 +0000 Subject: [PATCH 1/9] add per-variant endpoint concurrency with least-loaded dispatch --- tests/test_endpoint_dispatcher.py | 167 ++++++++++++++++++++++++++++++ tests/test_endpoint_registry.py | 75 ++++++++++++++ verifiers/envs/environment.py | 141 +++++++++++++++++-------- verifiers/scripts/eval.py | 6 +- verifiers/types.py | 12 ++- verifiers/utils/async_utils.py | 104 ++++++++++++++++++- verifiers/utils/eval_utils.py | 128 +++++++++++++++++++---- 7 files changed, 568 insertions(+), 65 deletions(-) create mode 100644 tests/test_endpoint_dispatcher.py diff --git a/tests/test_endpoint_dispatcher.py b/tests/test_endpoint_dispatcher.py new file mode 100644 index 000000000..1cfad8765 --- /dev/null +++ b/tests/test_endpoint_dispatcher.py @@ -0,0 +1,167 @@ +import asyncio +import sys + +import pytest + +from verifiers.types import ClientConfig +from verifiers.utils.async_utils import ( + EndpointDispatcher, + EndpointVariantSlot, + NullEndpointDispatcher, +) + + +def _make_config(url: str = "https://a.example/v1") -> ClientConfig: + return ClientConfig(api_base_url=url) + + +class TestEndpointVariantSlot: + def test_available_reflects_capacity(self): + slot = EndpointVariantSlot(config=_make_config(), max_concurrent=10) + assert slot.available == 10 + slot.active = 3 + assert slot.available == 7 + + +class TestEndpointDispatcher: + @pytest.mark.asyncio + async def test_least_loaded_picks_emptier_variant(self): + slot_a = EndpointVariantSlot( + config=_make_config("https://a.example/v1"), max_concurrent=4 + ) + slot_b = EndpointVariantSlot( + config=_make_config("https://b.example/v1"), max_concurrent=4 + ) + dispatcher = EndpointDispatcher([slot_a, slot_b]) + + # Acquire one on each — first should go to whichever has more available + async with dispatcher.acquire() as got1: + assert got1 in (slot_a, slot_b) + # got1 now has 1 active, the other has 0 + other = slot_b if got1 is slot_a else slot_a + async with dispatcher.acquire() as got2: + # Should pick the one with more available (the other) + assert got2 is other + + @pytest.mark.asyncio + async def test_blocks_when_all_full_then_unblocks(self): + slot = EndpointVariantSlot(config=_make_config(), max_concurrent=1) + dispatcher = EndpointDispatcher([slot]) + + acquired = asyncio.Event() + released = asyncio.Event() + + async def holder(): + async with dispatcher.acquire(): + acquired.set() + await released.wait() + + async def waiter(): + await acquired.wait() + # This should block until holder releases + async with dispatcher.acquire() as got: + assert got is slot + + holder_task = asyncio.create_task(holder()) + waiter_task = asyncio.create_task(waiter()) + + # Give tasks time to start + await asyncio.sleep(0.05) + assert acquired.is_set() + assert not waiter_task.done() + + # Release the holder + released.set() + await asyncio.wait_for(waiter_task, timeout=2.0) + await holder_task + + @pytest.mark.asyncio + async def test_count_parameter_consumes_correct_slots(self): + slot = EndpointVariantSlot(config=_make_config(), max_concurrent=4) + dispatcher = EndpointDispatcher([slot]) + + async with dispatcher.acquire(count=3) as got: + assert got is slot + assert slot.active == 3 + assert slot.available == 1 + + assert slot.active == 0 + assert slot.available == 4 + + @pytest.mark.asyncio + async def test_releases_on_exception(self): + slot = EndpointVariantSlot(config=_make_config(), max_concurrent=2) + dispatcher = EndpointDispatcher([slot]) + + with pytest.raises(RuntimeError, match="boom"): + async with dispatcher.acquire(): + raise RuntimeError("boom") + + assert slot.active == 0 + + @pytest.mark.asyncio + async def test_oversize_count_waits_for_full_idle(self): + """When count exceeds every variant's max_concurrent, wait for idle.""" + slot = EndpointVariantSlot(config=_make_config(), max_concurrent=2) + dispatcher = EndpointDispatcher([slot]) + + released = asyncio.Event() + + async def holder(): + async with dispatcher.acquire(count=1): + await released.wait() + + holder_task = asyncio.create_task(holder()) + await asyncio.sleep(0.05) + + # count=5 exceeds max_concurrent=2 — must wait for full idle + async def oversize(): + async with dispatcher.acquire(count=5) as got: + assert got is slot + + oversize_task = asyncio.create_task(oversize()) + await asyncio.sleep(0.05) + assert not oversize_task.done() + + released.set() + await asyncio.wait_for(oversize_task, timeout=2.0) + await holder_task + + def test_empty_variants_raises(self): + with pytest.raises(ValueError): + EndpointDispatcher([]) + + +class TestNullEndpointDispatcher: + @pytest.mark.asyncio + async def test_round_robins_without_blocking(self): + configs = [ + _make_config("https://a.example/v1"), + _make_config("https://b.example/v1"), + _make_config("https://c.example/v1"), + ] + dispatcher = NullEndpointDispatcher(configs) + + urls = [] + for _ in range(6): + async with dispatcher.acquire() as slot: + urls.append(slot.config.api_base_url) + + assert urls == [ + "https://a.example/v1", + "https://b.example/v1", + "https://c.example/v1", + "https://a.example/v1", + "https://b.example/v1", + "https://c.example/v1", + ] + + @pytest.mark.asyncio + async def test_slot_has_maxsize_capacity(self): + dispatcher = NullEndpointDispatcher([_make_config()]) + async with dispatcher.acquire() as slot: + assert slot.max_concurrent == sys.maxsize + + def test_empty_configs_raises(self): + with pytest.raises(ValueError): + NullEndpointDispatcher([]) diff --git a/tests/test_endpoint_registry.py b/tests/test_endpoint_registry.py index 9c6e44513..70b5b32cd 100644 --- a/tests/test_endpoint_registry.py +++ b/tests/test_endpoint_registry.py @@ -164,6 +164,81 @@ def test_load_endpoints_directory_prefers_toml_then_python(tmp_path: Path): assert set(endpoints.keys()) == {"from-py"} +def test_load_endpoints_toml_parses_max_concurrent(tmp_path: Path): + registry_path = tmp_path / "endpoints.toml" + registry_path.write_text( + "[[endpoint]]\n" + 'endpoint_id = "my-model"\n' + 'model = "my/model"\n' + 'url = "https://a.example/v1"\n' + 'key = "A_KEY"\n' + "max_concurrent = 16\n" + "\n" + "[[endpoint]]\n" + 'endpoint_id = "my-model"\n' + 'model = "my/model"\n' + 'url = "https://b.example/v1"\n' + 'key = "A_KEY"\n' + "max_concurrent = 32\n", + encoding="utf-8", + ) + + endpoints = load_endpoints(str(registry_path)) + + assert endpoints["my-model"][0]["max_concurrent"] == 16 + assert endpoints["my-model"][1]["max_concurrent"] == 32 + + +def test_load_endpoints_toml_max_concurrent_optional(tmp_path: Path): + registry_path = tmp_path / "endpoints.toml" + registry_path.write_text( + "[[endpoint]]\n" + 'endpoint_id = "my-model"\n' + 'model = "my/model"\n' + 'url = "https://a.example/v1"\n' + 'key = "A_KEY"\n', + encoding="utf-8", + ) + + endpoints = load_endpoints(str(registry_path)) + + assert "max_concurrent" not in endpoints["my-model"][0] + + +def test_load_endpoints_python_parses_max_concurrent(tmp_path: Path): + registry_path = tmp_path / "endpoints.py" + registry_path.write_text( + "ENDPOINTS = {\n" + ' "my-model": [\n' + ' {"model": "m", "url": "https://a.example/v1", "key": "K", "max_concurrent": 8},\n' + " ]\n" + "}\n", + encoding="utf-8", + ) + + endpoints = load_endpoints(str(registry_path)) + + assert endpoints["my-model"][0]["max_concurrent"] == 8 + + +def test_load_endpoints_rejects_invalid_max_concurrent(tmp_path: Path): + for bad_value in [0, -1, '"not_an_int"']: + registry_path = tmp_path / "endpoints.toml" + registry_path.write_text( + "[[endpoint]]\n" + 'endpoint_id = "my-model"\n' + 'model = "my/model"\n' + 'url = "https://a.example/v1"\n' + 'key = "A_KEY"\n' + f"max_concurrent = {bad_value}\n", + encoding="utf-8", + ) + + endpoints = load_endpoints(str(registry_path)) + # Invalid values cause load to fail and return empty + assert endpoints == {} + + def test_qwen3_vl_endpoint_ids_map_to_vl_models(): endpoints = load_endpoints("./configs/endpoints.toml") diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index d7658f7ed..fb66a06f8 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -71,6 +71,8 @@ TokenUsage, ) from verifiers.utils.async_utils import ( + EndpointDispatcher, + NullEndpointDispatcher, maybe_retry, maybe_semaphore, with_sem, @@ -977,6 +979,7 @@ async def generate( on_start: StartCallback | None = None, on_progress: ProgressCallback | None = None, on_log: LogCallback | None = None, + dispatcher: EndpointDispatcher | NullEndpointDispatcher | None = None, ) -> GenerateOutputs: """ Generate rollouts for a set of inputs. @@ -1143,52 +1146,100 @@ def get_client_for_group() -> AsyncOpenAI | ClientConfig: tasks: dict[asyncio.Task, int] = {} try: - # create tasks based on mode - if independent_scoring: - on_start(raw_inputs, filtered_inputs) - for i, rollout_input in enumerate(filtered_inputs): - task = asyncio.create_task( - with_sem( - sem, - self.run_rollout( - rollout_input, - get_client_for_group(), - model, - sampling_args, - max_retries=max_retries, - state_columns=state_columns, - ), - ), - ) - tasks[task] = i + if dispatcher is not None: + # ----- dispatcher-based path ----- + async def _dispatched_rollout( + rollout_input: RolloutInput, + ) -> RolloutOutput: + async with dispatcher.acquire(count=1) as slot: + return await self.run_rollout( + rollout_input, + slot.config, + model, + sampling_args, + max_retries=max_retries, + state_columns=state_columns, + ) + + async def _dispatched_group( + group_input: list[RolloutInput], + ) -> list[RolloutOutput]: + async with dispatcher.acquire(count=len(group_input)) as slot: + return await self.run_group( + group_input, + slot.config, + model, + sampling_args, + max_retries=max_retries, + state_columns=state_columns, + ) + + if independent_scoring: + on_start(raw_inputs, filtered_inputs) + for i, rollout_input in enumerate(filtered_inputs): + task = asyncio.create_task( + _dispatched_rollout(rollout_input) + ) + tasks[task] = i + else: + group_inputs: dict[int, list[RolloutInput]] = defaultdict(list) + for rollout_input in filtered_inputs: + example_id = rollout_input["example_id"] + group_inputs[example_id].append(rollout_input) + filtered_group_inputs = list(group_inputs.values()) + on_start(raw_inputs, filtered_group_inputs) + + for i, group_input in enumerate(filtered_group_inputs): + task = asyncio.create_task(_dispatched_group(group_input)) + tasks[task] = i else: - group_inputs: dict[int, list[RolloutInput]] = defaultdict(list) - for rollout_input in filtered_inputs: - example_id = rollout_input["example_id"] - group_inputs[example_id].append(rollout_input) - filtered_group_inputs = list(group_inputs.values()) - on_start(raw_inputs, filtered_group_inputs) - - for i, group_input in enumerate(filtered_group_inputs): - # For grouped scoring, keep each group on one endpoint so - # rollouts in the same group can benefit from shared KV cache. - group_client: AsyncOpenAI | ClientConfig = ( - get_client_for_group() - ) - task = asyncio.create_task( - with_sem( - sem, - self.run_group( - group_input, - group_client, - model, - sampling_args, - max_retries=max_retries, - state_columns=state_columns, + # ----- legacy path (semaphore + round-robin) ----- + # create tasks based on mode + if independent_scoring: + on_start(raw_inputs, filtered_inputs) + for i, rollout_input in enumerate(filtered_inputs): + task = asyncio.create_task( + with_sem( + sem, + self.run_rollout( + rollout_input, + get_client_for_group(), + model, + sampling_args, + max_retries=max_retries, + state_columns=state_columns, + ), + ), + ) + tasks[task] = i + else: + group_inputs: dict[int, list[RolloutInput]] = defaultdict(list) + for rollout_input in filtered_inputs: + example_id = rollout_input["example_id"] + group_inputs[example_id].append(rollout_input) + filtered_group_inputs = list(group_inputs.values()) + on_start(raw_inputs, filtered_group_inputs) + + for i, group_input in enumerate(filtered_group_inputs): + # For grouped scoring, keep each group on one endpoint so + # rollouts in the same group can benefit from shared KV cache. + group_client: AsyncOpenAI | ClientConfig = ( + get_client_for_group() + ) + task = asyncio.create_task( + with_sem( + sem, + self.run_group( + group_input, + group_client, + model, + sampling_args, + max_retries=max_retries, + state_columns=state_columns, + ), ), - ), - ) - tasks[task] = i + ) + tasks[task] = i for coro in asyncio.as_completed(tasks.keys()): result = await coro @@ -1301,6 +1352,7 @@ async def evaluate( on_start: StartCallback | None = None, on_progress: ProgressCallback | None = None, on_log: LogCallback | None = None, + dispatcher: EndpointDispatcher | NullEndpointDispatcher | None = None, **kwargs, ) -> GenerateOutputs: """ @@ -1323,6 +1375,7 @@ async def evaluate( on_start=on_start, on_progress=on_progress, on_log=on_log, + dispatcher=dispatcher, **kwargs, ) diff --git a/verifiers/scripts/eval.py b/verifiers/scripts/eval.py index 3dce167c7..f6219d591 100644 --- a/verifiers/scripts/eval.py +++ b/verifiers/scripts/eval.py @@ -449,10 +449,13 @@ def build_eval_config(raw: dict) -> EvalConfig: resolved_api_key_var = api_key_var endpoint_configs: list[EndpointClientConfig] = [] + has_variant_concurrency = endpoint_group is not None and any( + ep.get("max_concurrent") is not None for ep in endpoint_group + ) if ( endpoint_group is not None and not api_base_url_override - and len(endpoint_group) > 1 + and (len(endpoint_group) > 1 or has_variant_concurrency) ): endpoint_configs = [ EndpointClientConfig( @@ -461,6 +464,7 @@ def build_eval_config(raw: dict) -> EvalConfig: ), api_base_url=endpoint["url"], extra_headers=merged_headers, + max_concurrent=endpoint.get("max_concurrent"), ) for endpoint in endpoint_group ] diff --git a/verifiers/types.py b/verifiers/types.py index 52a905634..4041d270f 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -254,7 +254,16 @@ class RolloutScores(TypedDict): metrics: dict[str, list[float]] -Endpoint = TypedDict("Endpoint", {"key": str, "url": str, "model": str}) +class _EndpointRequired(TypedDict): + key: str + url: str + model: str + + +class Endpoint(_EndpointRequired, total=False): + max_concurrent: int + + Endpoints = dict[str, list[Endpoint]] @@ -324,6 +333,7 @@ class EndpointClientConfig(BaseModel): max_keepalive_connections: int = 28000 max_retries: int = 10 extra_headers: dict[str, str] = Field(default_factory=dict) + max_concurrent: int | None = None ClientConfig.model_rebuild() diff --git a/verifiers/utils/async_utils.py b/verifiers/utils/async_utils.py index 9291298bf..5519489d9 100644 --- a/verifiers/utils/async_utils.py +++ b/verifiers/utils/async_utils.py @@ -1,9 +1,22 @@ +from __future__ import annotations + import asyncio import inspect import logging +import sys from collections.abc import Coroutine +from contextlib import asynccontextmanager +from dataclasses import dataclass, field from time import perf_counter -from typing import Any, AsyncContextManager, Callable, Optional, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + AsyncContextManager, + AsyncIterator, + Callable, + Optional, + TypeVar, +) import tenacity as tc @@ -11,6 +24,9 @@ from verifiers.utils.error_utils import ErrorChain from verifiers.utils.logging_utils import print_time +if TYPE_CHECKING: + from verifiers.types import ClientConfig + logger = logging.getLogger(__name__) T = TypeVar("T") @@ -41,6 +57,92 @@ async def __aexit__(self, exc_type, exc_value, traceback): return False +@dataclass +class EndpointVariantSlot: + """Tracks one variant's client config and concurrency capacity.""" + + config: ClientConfig + max_concurrent: int = sys.maxsize + active: int = field(default=0, init=False) + + @property + def available(self) -> int: + return self.max_concurrent - self.active + + +class EndpointDispatcher: + """Least-loaded dispatch with asyncio.Condition for blocking. + + Shared across all evals hitting the same endpoint_id so that + per-variant concurrency limits are respected globally. + """ + + def __init__(self, variants: list[EndpointVariantSlot]) -> None: + if not variants: + raise ValueError("EndpointDispatcher requires at least one variant") + self._variants = variants + self._condition = asyncio.Condition() + + @asynccontextmanager + async def acquire(self, count: int = 1) -> AsyncIterator[EndpointVariantSlot]: + """Acquire a slot on the least-loaded variant that can fit *count* concurrent items.""" + variant: EndpointVariantSlot | None = None + async with self._condition: + while True: + # Find variant with most available capacity that can fit count + best: EndpointVariantSlot | None = None + for v in self._variants: + if v.available >= count and ( + best is None or v.available > best.available + ): + best = v + if best is not None: + variant = best + variant.active += count + break + + # Edge case: count exceeds every variant's max_concurrent. + # Wait for the largest variant to be fully idle, then allow through. + largest = max(self._variants, key=lambda v: v.max_concurrent) + if count > largest.max_concurrent and largest.active == 0: + variant = largest + variant.active += count + break + + await self._condition.wait() + + try: + yield variant + finally: + async with self._condition: + variant.active -= count + self._condition.notify_all() + + +class NullEndpointDispatcher: + """Backward-compatible round-robin dispatcher (no concurrency gating). + + Same ``acquire(count)`` interface as :class:`EndpointDispatcher` but + cycles through configs without blocking. + """ + + def __init__(self, configs: list[ClientConfig]) -> None: + if not configs: + raise ValueError("NullEndpointDispatcher requires at least one config") + self._configs = configs + self._idx = 0 + self._lock = asyncio.Lock() + + @asynccontextmanager + async def acquire(self, count: int = 1) -> AsyncIterator[EndpointVariantSlot]: + """Round-robin pick without blocking.""" + async with self._lock: + config = self._configs[self._idx % len(self._configs)] + self._idx += 1 + slot = EndpointVariantSlot(config=config, max_concurrent=sys.maxsize) + yield slot + + async def maybe_semaphore( limit: Optional[int] = None, ) -> AsyncContextManager: diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 07e488a77..c875d6207 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -5,6 +5,7 @@ import logging import math import os +import sys import time from collections import Counter, defaultdict from collections.abc import Mapping @@ -34,7 +35,13 @@ RolloutOutput, StartCallback, ) -from verifiers.utils.async_utils import EventLoopLagMonitor +from verifiers.utils.async_utils import ( + EndpointDispatcher, + EndpointVariantSlot, + EventLoopLagMonitor, + NullEndpointDispatcher, +) +from verifiers.utils.client_utils import resolve_client_configs from verifiers.utils.logging_utils import print_prompt_completions_sample, print_time from verifiers.utils.path_utils import get_eval_results_path @@ -67,6 +74,11 @@ def _coerce_endpoint(raw_endpoint: object, source: str) -> Endpoint: f"Fields 'model', 'url', and 'key' must all be strings in {source}" ) + max_concurrent = raw_endpoint_dict.get("max_concurrent") + if max_concurrent is not None: + if not isinstance(max_concurrent, int) or max_concurrent <= 0: + raise ValueError(f"'max_concurrent' must be a positive integer in {source}") + return Endpoint(model=model, url=url, key=key, max_concurrent=max_concurrent) return Endpoint(model=model, url=url, key=key) @@ -533,12 +545,65 @@ def quiet_datasets(): enable_progress_bar() +def _build_dispatchers( + evals: list[EvalConfig], +) -> dict[str | None, EndpointDispatcher | NullEndpointDispatcher]: + """Build per-endpoint dispatchers from eval configs. + + Groups evals by ``endpoint_id`` and, for each unique id that carries + ``endpoint_configs`` on its ``client_config``: + + * If **any** variant has ``max_concurrent`` set, create an + :class:`EndpointDispatcher` with one :class:`EndpointVariantSlot` per + variant (variants without ``max_concurrent`` get ``sys.maxsize`` + capacity). + * Otherwise, create a :class:`NullEndpointDispatcher` for plain + round-robin. + + Returns a mapping from ``endpoint_id`` (or ``None``) to the dispatcher. + """ + dispatchers: dict[str | None, EndpointDispatcher | NullEndpointDispatcher] = {} + + # Collect unique endpoint_ids, take the first config as representative + seen: dict[str | None, EvalConfig] = {} + for ec in evals: + if ec.endpoint_id not in seen: + seen[ec.endpoint_id] = ec + + for endpoint_id, ec in seen.items(): + if not ec.client_config.endpoint_configs: + continue + + resolved = resolve_client_configs(ec.client_config) + endpoint_cfgs = ec.client_config.endpoint_configs + has_any_concurrency = any(ep.max_concurrent is not None for ep in endpoint_cfgs) + + if has_any_concurrency: + slots = [ + EndpointVariantSlot( + config=cfg, + max_concurrent=( + ep.max_concurrent + if ep.max_concurrent is not None + else sys.maxsize + ), + ) + for cfg, ep in zip(resolved, endpoint_cfgs) + ] + dispatchers[endpoint_id] = EndpointDispatcher(slots) + else: + dispatchers[endpoint_id] = NullEndpointDispatcher(resolved) + + return dispatchers + + async def run_evaluation( config: EvalConfig, on_start: StartCallback | None = None, on_log_file: Callable[[Path], None] | None = None, on_progress: ProgressCallback | None = None, on_log: LogCallback | None = None, + dispatcher: EndpointDispatcher | NullEndpointDispatcher | None = None, ) -> GenerateOutputs: # load environment vf_env = vf.load_environment(env_id=config.env_id, **config.env_args) @@ -573,21 +638,30 @@ async def run_evaluation( f"Configuration: num_examples={config.num_examples}, rollouts_per_example={config.rollouts_per_example}, max_concurrent={config.max_concurrent}" ) - effective_group_max_concurrent = config.max_concurrent - if ( - not config.independent_scoring - and config.max_concurrent > 0 - and config.rollouts_per_example > 1 - ): - # Grouped scoring applies the semaphore at group level. Convert - # rollout-level concurrency to group-level slots. - effective_group_max_concurrent = math.ceil( - config.max_concurrent / config.rollouts_per_example - ) - if config.num_examples > 0: - effective_group_max_concurrent = min( - effective_group_max_concurrent, config.num_examples + if dispatcher is not None: + # Dispatcher handles concurrency — skip per-eval semaphore + if config.max_concurrent > 0: + logger.debug( + "Endpoint-level dispatcher active; ignoring eval-level max_concurrent=%d", + config.max_concurrent, + ) + effective_group_max_concurrent = -1 # disable semaphore + else: + effective_group_max_concurrent = config.max_concurrent + if ( + not config.independent_scoring + and config.max_concurrent > 0 + and config.rollouts_per_example > 1 + ): + # Grouped scoring applies the semaphore at group level. Convert + # rollout-level concurrency to group-level slots. + effective_group_max_concurrent = math.ceil( + config.max_concurrent / config.rollouts_per_example ) + if config.num_examples > 0: + effective_group_max_concurrent = min( + effective_group_max_concurrent, config.num_examples + ) outputs = await vf_env.evaluate( client=config.client_config, @@ -606,6 +680,7 @@ async def run_evaluation( on_start=on_start, on_progress=on_progress, on_log=on_log, + dispatcher=dispatcher, ) finally: await vf_env.stop_server() @@ -618,9 +693,17 @@ async def run_evaluations(config: EvalRunConfig) -> None: event_loop_lag_monitor = EventLoopLagMonitor() event_loop_lag_monitor.run_in_background() + dispatchers = _build_dispatchers(config.evals) + start_time = time.time() all_results = await asyncio.gather( - *[run_evaluation(eval_config) for eval_config in config.evals] + *[ + run_evaluation( + eval_config, + dispatcher=dispatchers.get(eval_config.endpoint_id), + ) + for eval_config in config.evals + ] ) end_time = time.time() event_loop_lags = event_loop_lag_monitor.get_lags() @@ -657,10 +740,14 @@ async def run_evaluations_tui(config: EvalRunConfig, tui_mode: bool = True) -> N await run_evaluations(config) return + dispatchers = _build_dispatchers(config.evals) + display = EvalDisplay(config.evals, screen=tui_mode) async def run_with_progress( - env_config: EvalConfig, env_idx: int + env_config: EvalConfig, + env_idx: int, + dispatcher: EndpointDispatcher | NullEndpointDispatcher | None = None, ) -> GenerateOutputs: """Run a single evaluation with display progress updates.""" @@ -708,6 +795,7 @@ def register_log_file(log_file: Path) -> None: on_progress=on_progress, on_log=on_log, on_log_file=register_log_file, + dispatcher=dispatcher, ) # get save path if results were saved @@ -738,7 +826,11 @@ async def refresh_loop() -> None: try: await asyncio.gather( *[ - run_with_progress(env_config, idx) + run_with_progress( + env_config, + idx, + dispatcher=dispatchers.get(env_config.endpoint_id), + ) for idx, env_config in enumerate(config.evals) ], return_exceptions=True, From 2338043f5e35b33d3534e94db75d221e97141e32 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 23 Feb 2026 18:04:15 +0000 Subject: [PATCH 2/9] run ruff --- verifiers/envs/environment.py | 4 +--- verifiers/types.py | 5 ++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 051e5d32f..2ddb70a86 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -1083,9 +1083,7 @@ async def _dispatched_group( for i, group_input in enumerate(filtered_group_inputs): # For grouped scoring, keep each group on one endpoint so # rollouts in the same group can benefit from shared KV cache. - group_client: Client | ClientConfig = ( - get_client_for_group() - ) + group_client: Client | ClientConfig = get_client_for_group() task = asyncio.create_task( with_sem( sem, diff --git a/verifiers/types.py b/verifiers/types.py index e76647882..b81484709 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -23,9 +23,9 @@ from verifiers.errors import Error if sys.version_info < (3, 12): - from typing_extensions import NotRequired, TypedDict + from typing_extensions import TypedDict else: - from typing import NotRequired, TypedDict + from typing import TypedDict # Client / message type literals ClientType = Literal[ @@ -397,7 +397,6 @@ class Endpoint(_EndpointRequired, total=False): max_concurrent: int - Endpoints = dict[str, list[Endpoint]] From 58dcaa5f94bf44a0ed2dd47ceb2e91cc5ef77574 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 23 Feb 2026 18:23:35 +0000 Subject: [PATCH 3/9] update docs --- docs/evaluation.md | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/docs/evaluation.md b/docs/evaluation.md index 7b3a2e69e..63caaae2f 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -95,7 +95,25 @@ api_client_type = "anthropic_messages" Each endpoint entry supports an optional `api_client_type` field to select the client implementation (defaults to `"openai_chat_completions"`). Use `"anthropic_messages"` for Anthropic models when calling the Anthropic API directly. -To define equivalent replicas, add multiple `[[endpoint]]` entries with the same `endpoint_id`. +To define equivalent replicas, add multiple `[[endpoint]]` entries with the same `endpoint_id`. You can optionally set `max_concurrent` on each variant to limit how many requests it handles simultaneously: + +```toml +[[endpoint]] +endpoint_id = "my-model" +model = "my-org/my-model" +url = "https://fast-host.example.com/v1" +key = "API_KEY" +max_concurrent = 64 + +[[endpoint]] +endpoint_id = "my-model" +model = "my-org/my-model" +url = "https://slow-host.example.com/v1" +key = "API_KEY" +max_concurrent = 16 +``` + +When any variant has `max_concurrent` set, the evaluator uses least-loaded dispatch: each request is routed to the variant with the most available capacity. Variants without `max_concurrent` have no limit. When no variant has `max_concurrent`, requests are distributed round-robin. Then use the alias directly: @@ -145,6 +163,8 @@ Multiple rollouts per example enable metrics like pass@k and help measure varian By default, scoring runs interleaved with generation. Use `--no-interleave-scoring` to score all rollouts after generation completes. +When per-variant `max_concurrent` limits are configured in the endpoint registry, the endpoint dispatcher manages concurrency globally across all variants and the `--max-concurrent` flag is ignored. + The `--max-retries` flag enables automatic retry with exponential backoff when rollouts fail due to transient infrastructure errors (e.g., sandbox timeouts, API failures). ### Output and Saving From 05e7f0bb02ee1f54b0a5b08fe6b49db22b393b50 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 23 Feb 2026 22:51:27 +0000 Subject: [PATCH 4/9] clean up dispatch: rename, remove NullEndpointDispatcher, enforce all-or-nothing concurrency --- docs/reference.md | 13 ++++-- tests/test_endpoint_dispatcher.py | 71 +++++++--------------------- verifiers/envs/environment.py | 22 +++++---- verifiers/utils/async_utils.py | 42 ++++------------- verifiers/utils/eval_utils.py | 78 ++++++++++++++++--------------- 5 files changed, 89 insertions(+), 137 deletions(-) diff --git a/docs/reference.md b/docs/reference.md index b0af77fcc..24661fe36 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -701,9 +701,10 @@ class EndpointClientConfig(BaseModel): max_keepalive_connections: int = 28000 max_retries: int = 10 extra_headers: dict[str, str] = {} + max_concurrent: int | None = None ``` -Leaf endpoint configuration used inside `ClientConfig.endpoint_configs`. Has the same fields as `ClientConfig` except `endpoint_configs` itself, preventing recursive nesting. +Leaf endpoint configuration used inside `ClientConfig.endpoint_configs`. Has the same fields as `ClientConfig` except `endpoint_configs` itself, preventing recursive nesting. The optional `max_concurrent` field limits how many concurrent requests this variant handles; see [Per-Variant Concurrency](evaluation.md#concurrency). ### EvalConfig @@ -733,11 +734,17 @@ class EvalConfig(BaseModel): ### Endpoint ```python -Endpoint = TypedDict("Endpoint", {"key": str, "url": str, "model": str}) +class Endpoint(TypedDict, total=False): + key: str # required + url: str # required + model: str # required + api_client_type: ClientType + max_concurrent: int + Endpoints = dict[str, list[Endpoint]] ``` -`Endpoints` maps an endpoint id to one or more endpoint variants. A single variant is represented as a one-item list. +`Endpoints` maps an endpoint id to one or more endpoint variants. A single variant is represented as a one-item list. The optional `max_concurrent` field enables per-variant concurrency limiting with least-loaded dispatch. --- diff --git a/tests/test_endpoint_dispatcher.py b/tests/test_endpoint_dispatcher.py index 1cfad8765..45cae7d72 100644 --- a/tests/test_endpoint_dispatcher.py +++ b/tests/test_endpoint_dispatcher.py @@ -1,13 +1,11 @@ import asyncio -import sys import pytest from verifiers.types import ClientConfig from verifiers.utils.async_utils import ( - EndpointDispatcher, - EndpointVariantSlot, - NullEndpointDispatcher, + EndpointSlot, + LeastLoadedDispatcher, ) @@ -15,24 +13,24 @@ def _make_config(url: str = "https://a.example/v1") -> ClientConfig: return ClientConfig(api_base_url=url) -class TestEndpointVariantSlot: +class TestEndpointSlot: def test_available_reflects_capacity(self): - slot = EndpointVariantSlot(config=_make_config(), max_concurrent=10) + slot = EndpointSlot(config=_make_config(), max_concurrent=10) assert slot.available == 10 slot.active = 3 assert slot.available == 7 -class TestEndpointDispatcher: +class TestLeastLoadedDispatcher: @pytest.mark.asyncio async def test_least_loaded_picks_emptier_variant(self): - slot_a = EndpointVariantSlot( + slot_a = EndpointSlot( config=_make_config("https://a.example/v1"), max_concurrent=4 ) - slot_b = EndpointVariantSlot( + slot_b = EndpointSlot( config=_make_config("https://b.example/v1"), max_concurrent=4 ) - dispatcher = EndpointDispatcher([slot_a, slot_b]) + dispatcher = LeastLoadedDispatcher([slot_a, slot_b]) # Acquire one on each — first should go to whichever has more available async with dispatcher.acquire() as got1: @@ -45,8 +43,8 @@ async def test_least_loaded_picks_emptier_variant(self): @pytest.mark.asyncio async def test_blocks_when_all_full_then_unblocks(self): - slot = EndpointVariantSlot(config=_make_config(), max_concurrent=1) - dispatcher = EndpointDispatcher([slot]) + slot = EndpointSlot(config=_make_config(), max_concurrent=1) + dispatcher = LeastLoadedDispatcher([slot]) acquired = asyncio.Event() released = asyncio.Event() @@ -77,8 +75,8 @@ async def waiter(): @pytest.mark.asyncio async def test_count_parameter_consumes_correct_slots(self): - slot = EndpointVariantSlot(config=_make_config(), max_concurrent=4) - dispatcher = EndpointDispatcher([slot]) + slot = EndpointSlot(config=_make_config(), max_concurrent=4) + dispatcher = LeastLoadedDispatcher([slot]) async with dispatcher.acquire(count=3) as got: assert got is slot @@ -90,8 +88,8 @@ async def test_count_parameter_consumes_correct_slots(self): @pytest.mark.asyncio async def test_releases_on_exception(self): - slot = EndpointVariantSlot(config=_make_config(), max_concurrent=2) - dispatcher = EndpointDispatcher([slot]) + slot = EndpointSlot(config=_make_config(), max_concurrent=2) + dispatcher = LeastLoadedDispatcher([slot]) with pytest.raises(RuntimeError, match="boom"): async with dispatcher.acquire(): @@ -102,8 +100,8 @@ async def test_releases_on_exception(self): @pytest.mark.asyncio async def test_oversize_count_waits_for_full_idle(self): """When count exceeds every variant's max_concurrent, wait for idle.""" - slot = EndpointVariantSlot(config=_make_config(), max_concurrent=2) - dispatcher = EndpointDispatcher([slot]) + slot = EndpointSlot(config=_make_config(), max_concurrent=2) + dispatcher = LeastLoadedDispatcher([slot]) released = asyncio.Event() @@ -129,39 +127,4 @@ async def oversize(): def test_empty_variants_raises(self): with pytest.raises(ValueError): - EndpointDispatcher([]) - - -class TestNullEndpointDispatcher: - @pytest.mark.asyncio - async def test_round_robins_without_blocking(self): - configs = [ - _make_config("https://a.example/v1"), - _make_config("https://b.example/v1"), - _make_config("https://c.example/v1"), - ] - dispatcher = NullEndpointDispatcher(configs) - - urls = [] - for _ in range(6): - async with dispatcher.acquire() as slot: - urls.append(slot.config.api_base_url) - - assert urls == [ - "https://a.example/v1", - "https://b.example/v1", - "https://c.example/v1", - "https://a.example/v1", - "https://b.example/v1", - "https://c.example/v1", - ] - - @pytest.mark.asyncio - async def test_slot_has_maxsize_capacity(self): - dispatcher = NullEndpointDispatcher([_make_config()]) - async with dispatcher.acquire() as slot: - assert slot.max_concurrent == sys.maxsize - - def test_empty_configs_raises(self): - with pytest.raises(ValueError): - NullEndpointDispatcher([]) + LeastLoadedDispatcher([]) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 2ddb70a86..2fafa02a4 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -65,8 +65,7 @@ Tool, ) from verifiers.utils.async_utils import ( - EndpointDispatcher, - NullEndpointDispatcher, + LeastLoadedDispatcher, maybe_retry, maybe_semaphore, with_sem, @@ -828,7 +827,7 @@ async def generate( on_start: StartCallback | None = None, on_progress: ProgressCallback | list[ProgressCallback] | None = None, on_log: LogCallback | None = None, - dispatcher: EndpointDispatcher | NullEndpointDispatcher | None = None, + dispatcher: LeastLoadedDispatcher | None = None, ) -> GenerateOutputs: """ Generate rollouts for a set of inputs. @@ -838,6 +837,9 @@ async def generate( on_progress: Progress callback(s). None uses the default tqdm progress bar. A single callback replaces the default. A list of callbacks runs alongside the default. + dispatcher: Optional LeastLoadedDispatcher for per-variant concurrency. + When provided, the dispatcher handles client selection and concurrency. + When None, falls back to semaphore + round-robin dispatch. """ from datasets import Dataset from tqdm import tqdm @@ -916,7 +918,6 @@ def default_on_progress(*a, **kw): elif isinstance(inputs, list): raw_inputs = inputs - # set up semaphores sem = await maybe_semaphore(max_concurrent) # set up sampling args @@ -1007,7 +1008,7 @@ def get_client_for_group() -> Client | ClientConfig: tasks: dict[asyncio.Task, int] = {} try: if dispatcher is not None: - # ----- dispatcher-based path ----- + # Per-variant concurrency via LeastLoadedDispatcher async def _dispatched_rollout( rollout_input: RolloutInput, ) -> RolloutOutput: @@ -1050,11 +1051,12 @@ async def _dispatched_group( on_start(raw_inputs, filtered_group_inputs) for i, group_input in enumerate(filtered_group_inputs): - task = asyncio.create_task(_dispatched_group(group_input)) + task = asyncio.create_task( + _dispatched_group(group_input) + ) tasks[task] = i else: - # ----- legacy path (semaphore + round-robin) ----- - # create tasks based on mode + # Legacy path: semaphore + round-robin if independent_scoring: on_start(raw_inputs, filtered_inputs) for i, rollout_input in enumerate(filtered_inputs): @@ -1083,7 +1085,7 @@ async def _dispatched_group( for i, group_input in enumerate(filtered_group_inputs): # For grouped scoring, keep each group on one endpoint so # rollouts in the same group can benefit from shared KV cache. - group_client: Client | ClientConfig = get_client_for_group() + group_client = get_client_for_group() task = asyncio.create_task( with_sem( sem, @@ -1210,7 +1212,7 @@ async def evaluate( on_start: StartCallback | None = None, on_progress: ProgressCallback | list[ProgressCallback] | None = None, on_log: LogCallback | None = None, - dispatcher: EndpointDispatcher | NullEndpointDispatcher | None = None, + dispatcher: LeastLoadedDispatcher | None = None, **kwargs, ) -> GenerateOutputs: """ diff --git a/verifiers/utils/async_utils.py b/verifiers/utils/async_utils.py index 87ce393ff..414f1f300 100644 --- a/verifiers/utils/async_utils.py +++ b/verifiers/utils/async_utils.py @@ -3,7 +3,6 @@ import asyncio import inspect import logging -import sys from collections.abc import Coroutine from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -27,6 +26,7 @@ if TYPE_CHECKING: from verifiers.types import ClientConfig + logger = logging.getLogger(__name__) T = TypeVar("T") @@ -58,11 +58,11 @@ async def __aexit__(self, exc_type, exc_value, traceback): @dataclass -class EndpointVariantSlot: +class EndpointSlot: """Tracks one variant's client config and concurrency capacity.""" config: ClientConfig - max_concurrent: int = sys.maxsize + max_concurrent: int active: int = field(default=0, init=False) @property @@ -70,27 +70,27 @@ def available(self) -> int: return self.max_concurrent - self.active -class EndpointDispatcher: +class LeastLoadedDispatcher: """Least-loaded dispatch with asyncio.Condition for blocking. Shared across all evals hitting the same endpoint_id so that per-variant concurrency limits are respected globally. """ - def __init__(self, variants: list[EndpointVariantSlot]) -> None: + def __init__(self, variants: list[EndpointSlot]) -> None: if not variants: - raise ValueError("EndpointDispatcher requires at least one variant") + raise ValueError("LeastLoadedDispatcher requires at least one variant") self._variants = variants self._condition = asyncio.Condition() @asynccontextmanager - async def acquire(self, count: int = 1) -> AsyncIterator[EndpointVariantSlot]: + async def acquire(self, count: int = 1) -> AsyncIterator[EndpointSlot]: """Acquire a slot on the least-loaded variant that can fit *count* concurrent items.""" - variant: EndpointVariantSlot | None = None + variant: EndpointSlot | None = None async with self._condition: while True: # Find variant with most available capacity that can fit count - best: EndpointVariantSlot | None = None + best: EndpointSlot | None = None for v in self._variants: if v.available >= count and ( best is None or v.available > best.available @@ -119,30 +119,6 @@ async def acquire(self, count: int = 1) -> AsyncIterator[EndpointVariantSlot]: self._condition.notify_all() -class NullEndpointDispatcher: - """Backward-compatible round-robin dispatcher (no concurrency gating). - - Same ``acquire(count)`` interface as :class:`EndpointDispatcher` but - cycles through configs without blocking. - """ - - def __init__(self, configs: list[ClientConfig]) -> None: - if not configs: - raise ValueError("NullEndpointDispatcher requires at least one config") - self._configs = configs - self._idx = 0 - self._lock = asyncio.Lock() - - @asynccontextmanager - async def acquire(self, count: int = 1) -> AsyncIterator[EndpointVariantSlot]: - """Round-robin pick without blocking.""" - async with self._lock: - config = self._configs[self._idx % len(self._configs)] - self._idx += 1 - slot = EndpointVariantSlot(config=config, max_concurrent=sys.maxsize) - yield slot - - async def maybe_semaphore( limit: Optional[int] = None, ) -> AsyncContextManager: diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index b52adbb99..d32bbd627 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -5,7 +5,6 @@ import logging import math import os -import sys import time from collections import Counter, defaultdict from collections.abc import Mapping @@ -33,10 +32,9 @@ StartCallback, ) from verifiers.utils.async_utils import ( - EndpointDispatcher, - EndpointVariantSlot, + EndpointSlot, EventLoopLagMonitor, - NullEndpointDispatcher, + LeastLoadedDispatcher, ) from verifiers.utils.client_utils import resolve_client_configs from verifiers.utils.import_utils import load_toml @@ -584,22 +582,22 @@ def quiet_datasets(): def _build_dispatchers( evals: list[EvalConfig], -) -> dict[str | None, EndpointDispatcher | NullEndpointDispatcher]: +) -> dict[str | None, LeastLoadedDispatcher]: """Build per-endpoint dispatchers from eval configs. - Groups evals by ``endpoint_id`` and, for each unique id that carries - ``endpoint_configs`` on its ``client_config``: + Groups evals by ``endpoint_id`` and, for each unique id where + variants have ``max_concurrent`` set, creates a + :class:`LeastLoadedDispatcher` with one :class:`EndpointSlot` per + variant. - * If **any** variant has ``max_concurrent`` set, create an - :class:`EndpointDispatcher` with one :class:`EndpointVariantSlot` per - variant (variants without ``max_concurrent`` get ``sys.maxsize`` - capacity). - * Otherwise, create a :class:`NullEndpointDispatcher` for plain - round-robin. + Per-variant concurrency is all-or-nothing: if any variant in an + endpoint group sets ``max_concurrent``, every variant must. + Endpoint groups without ``max_concurrent`` use the default + semaphore + round-robin path in ``environment.generate()``. Returns a mapping from ``endpoint_id`` (or ``None``) to the dispatcher. """ - dispatchers: dict[str | None, EndpointDispatcher | NullEndpointDispatcher] = {} + dispatchers: dict[str | None, LeastLoadedDispatcher] = {} # Collect unique endpoint_ids, take the first config as representative seen: dict[str | None, EvalConfig] = {} @@ -611,25 +609,28 @@ def _build_dispatchers( if not ec.client_config.endpoint_configs: continue - resolved = resolve_client_configs(ec.client_config) endpoint_cfgs = ec.client_config.endpoint_configs - has_any_concurrency = any(ep.max_concurrent is not None for ep in endpoint_cfgs) + has_concurrency = [ep.max_concurrent is not None for ep in endpoint_cfgs] - if has_any_concurrency: + if any(has_concurrency): + if not all(has_concurrency): + missing = [ + i for i, has in enumerate(has_concurrency) if not has + ] + raise ValueError( + f"Endpoint '{endpoint_id}': max_concurrent is set on some variants " + f"but missing on variant(s) {missing}. Either set max_concurrent on " + f"all variants or remove it entirely to use the global concurrency limit." + ) + resolved = resolve_client_configs(ec.client_config) slots = [ - EndpointVariantSlot( + EndpointSlot( config=cfg, - max_concurrent=( - ep.max_concurrent - if ep.max_concurrent is not None - else sys.maxsize - ), + max_concurrent=ep.max_concurrent, ) for cfg, ep in zip(resolved, endpoint_cfgs) ] - dispatchers[endpoint_id] = EndpointDispatcher(slots) - else: - dispatchers[endpoint_id] = NullEndpointDispatcher(resolved) + dispatchers[endpoint_id] = LeastLoadedDispatcher(slots) return dispatchers @@ -640,7 +641,7 @@ async def run_evaluation( on_log_file: Callable[[Path], None] | None = None, on_progress: ProgressCallback | list[ProgressCallback] | None = None, on_log: LogCallback | None = None, - dispatcher: EndpointDispatcher | NullEndpointDispatcher | None = None, + dispatcher: LeastLoadedDispatcher | None = None, ) -> GenerateOutputs: # load environment vf_env = vf.load_environment(env_id=config.env_id, **config.env_args) @@ -675,29 +676,32 @@ async def run_evaluation( f"Configuration: num_examples={config.num_examples}, rollouts_per_example={config.rollouts_per_example}, max_concurrent={config.max_concurrent}" ) + # Compute effective concurrency for generate(). if dispatcher is not None: - # Dispatcher handles concurrency — skip per-eval semaphore + # LeastLoadedDispatcher handles concurrency per-variant; + # disable the generate-level semaphore. if config.max_concurrent > 0: logger.debug( "Endpoint-level dispatcher active; ignoring eval-level max_concurrent=%d", config.max_concurrent, ) - effective_group_max_concurrent = -1 # disable semaphore + effective_max_concurrent = -1 else: - effective_group_max_concurrent = config.max_concurrent + # Legacy semaphore + round-robin path. The semaphore limits + # concurrent groups, so convert rollout-level concurrency to + # group-level slots. + effective_max_concurrent = config.max_concurrent if ( not config.independent_scoring and config.max_concurrent > 0 and config.rollouts_per_example > 1 ): - # Grouped scoring applies the semaphore at group level. Convert - # rollout-level concurrency to group-level slots. - effective_group_max_concurrent = math.ceil( + effective_max_concurrent = math.ceil( config.max_concurrent / config.rollouts_per_example ) if config.num_examples > 0: - effective_group_max_concurrent = min( - effective_group_max_concurrent, config.num_examples + effective_max_concurrent = min( + effective_max_concurrent, config.num_examples ) outputs = await vf_env.evaluate( @@ -706,7 +710,7 @@ async def run_evaluation( sampling_args=config.sampling_args, num_examples=config.num_examples, rollouts_per_example=config.rollouts_per_example, - max_concurrent=effective_group_max_concurrent, + max_concurrent=effective_max_concurrent, results_path=results_path, state_columns=config.state_columns, save_results=config.save_results, @@ -802,7 +806,7 @@ async def run_evaluations_tui(config: EvalRunConfig, tui_mode: bool = True) -> N async def run_with_progress( env_config: EvalConfig, env_idx: int, - dispatcher: EndpointDispatcher | NullEndpointDispatcher | None = None, + dispatcher: LeastLoadedDispatcher | None = None, ) -> GenerateOutputs: """Run a single evaluation with display progress updates.""" From 06c9ec00efbfb7e7f3b4a9b530ccd80d71dd094d Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 23 Feb 2026 22:53:13 +0000 Subject: [PATCH 5/9] fix ruff --- verifiers/envs/environment.py | 4 +--- verifiers/utils/eval_utils.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 2fafa02a4..aada2c61e 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -1051,9 +1051,7 @@ async def _dispatched_group( on_start(raw_inputs, filtered_group_inputs) for i, group_input in enumerate(filtered_group_inputs): - task = asyncio.create_task( - _dispatched_group(group_input) - ) + task = asyncio.create_task(_dispatched_group(group_input)) tasks[task] = i else: # Legacy path: semaphore + round-robin diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index d32bbd627..40a2ecdbf 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -614,9 +614,7 @@ def _build_dispatchers( if any(has_concurrency): if not all(has_concurrency): - missing = [ - i for i, has in enumerate(has_concurrency) if not has - ] + missing = [i for i, has in enumerate(has_concurrency) if not has] raise ValueError( f"Endpoint '{endpoint_id}': max_concurrent is set on some variants " f"but missing on variant(s) {missing}. Either set max_concurrent on " From 042aea7d5da78c0e1e081c30c994087eda1273af Mon Sep 17 00:00:00 2001 From: hallerite Date: Tue, 24 Feb 2026 13:04:38 +0000 Subject: [PATCH 6/9] fix leak & update docs --- docs/evaluation.md | 2 +- docs/reference.md | 2 +- tests/test_endpoint_dispatcher.py | 30 ++++++++++++++++++++++++++++++ verifiers/utils/async_utils.py | 14 +++++++++++--- 4 files changed, 43 insertions(+), 5 deletions(-) diff --git a/docs/evaluation.md b/docs/evaluation.md index 63caaae2f..970ba0866 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -113,7 +113,7 @@ key = "API_KEY" max_concurrent = 16 ``` -When any variant has `max_concurrent` set, the evaluator uses least-loaded dispatch: each request is routed to the variant with the most available capacity. Variants without `max_concurrent` have no limit. When no variant has `max_concurrent`, requests are distributed round-robin. +When variants have `max_concurrent` set, the evaluator uses least-loaded dispatch: each request is routed to the variant with the most available capacity. Per-variant concurrency is all-or-nothing — either every variant in an endpoint group sets `max_concurrent`, or none does. When no variant has `max_concurrent`, requests are distributed round-robin and the global `--max-concurrent` flag controls concurrency. Then use the alias directly: diff --git a/docs/reference.md b/docs/reference.md index 24661fe36..bb5197bb6 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -744,7 +744,7 @@ class Endpoint(TypedDict, total=False): Endpoints = dict[str, list[Endpoint]] ``` -`Endpoints` maps an endpoint id to one or more endpoint variants. A single variant is represented as a one-item list. The optional `max_concurrent` field enables per-variant concurrency limiting with least-loaded dispatch. +`Endpoints` maps an endpoint id to one or more endpoint variants. A single variant is represented as a one-item list. The `max_concurrent` field enables per-variant concurrency limiting with least-loaded dispatch; if set on any variant in a group, it must be set on all. --- diff --git a/tests/test_endpoint_dispatcher.py b/tests/test_endpoint_dispatcher.py index 45cae7d72..6770caf11 100644 --- a/tests/test_endpoint_dispatcher.py +++ b/tests/test_endpoint_dispatcher.py @@ -125,6 +125,36 @@ async def oversize(): await asyncio.wait_for(oversize_task, timeout=2.0) await holder_task + @pytest.mark.asyncio + async def test_cancellation_does_not_leak_capacity(self): + """Cancelled tasks must release their slots.""" + slot = EndpointSlot(config=_make_config(), max_concurrent=1) + dispatcher = LeastLoadedDispatcher([slot]) + + entered = asyncio.Event() + + async def hold_then_wait(): + async with dispatcher.acquire(): + entered.set() + await asyncio.sleep(999) # will be cancelled + + task = asyncio.create_task(hold_then_wait()) + await entered.wait() + assert slot.active == 1 + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + # Give the shielded notification task a chance to run + await asyncio.sleep(0) + + assert slot.active == 0, "capacity leaked after cancellation" + + # Verify the slot is actually usable again + async with dispatcher.acquire() as got: + assert got is slot + def test_empty_variants_raises(self): with pytest.raises(ValueError): LeastLoadedDispatcher([]) diff --git a/verifiers/utils/async_utils.py b/verifiers/utils/async_utils.py index 414f1f300..93c4edbad 100644 --- a/verifiers/utils/async_utils.py +++ b/verifiers/utils/async_utils.py @@ -83,6 +83,11 @@ def __init__(self, variants: list[EndpointSlot]) -> None: self._variants = variants self._condition = asyncio.Condition() + async def _notify(self) -> None: + """Wake all waiters under the condition lock.""" + async with self._condition: + self._condition.notify_all() + @asynccontextmanager async def acquire(self, count: int = 1) -> AsyncIterator[EndpointSlot]: """Acquire a slot on the least-loaded variant that can fit *count* concurrent items.""" @@ -114,9 +119,12 @@ async def acquire(self, count: int = 1) -> AsyncIterator[EndpointSlot]: try: yield variant finally: - async with self._condition: - variant.active -= count - self._condition.notify_all() + # Decrement synchronously — safe in asyncio's cooperative model + # since no other task can interleave between await points. + variant.active -= count + # Shield notification so waiters are woken even if our task + # is cancelled (the shielded inner task keeps running). + await asyncio.shield(self._notify()) async def maybe_semaphore( From 47e8607aba37be4a044d98d5cb7edd9454b1015a Mon Sep 17 00:00:00 2001 From: hallerite Date: Tue, 24 Feb 2026 13:10:44 +0000 Subject: [PATCH 7/9] fix group too big for endpoint --- tests/test_endpoint_dispatcher.py | 28 +++++----------------------- verifiers/utils/async_utils.py | 22 +++++++++++++--------- 2 files changed, 18 insertions(+), 32 deletions(-) diff --git a/tests/test_endpoint_dispatcher.py b/tests/test_endpoint_dispatcher.py index 6770caf11..d8baf949c 100644 --- a/tests/test_endpoint_dispatcher.py +++ b/tests/test_endpoint_dispatcher.py @@ -98,32 +98,14 @@ async def test_releases_on_exception(self): assert slot.active == 0 @pytest.mark.asyncio - async def test_oversize_count_waits_for_full_idle(self): - """When count exceeds every variant's max_concurrent, wait for idle.""" + async def test_oversize_count_raises(self): + """count > every variant's max_concurrent is a config error.""" slot = EndpointSlot(config=_make_config(), max_concurrent=2) dispatcher = LeastLoadedDispatcher([slot]) - released = asyncio.Event() - - async def holder(): - async with dispatcher.acquire(count=1): - await released.wait() - - holder_task = asyncio.create_task(holder()) - await asyncio.sleep(0.05) - - # count=5 exceeds max_concurrent=2 — must wait for full idle - async def oversize(): - async with dispatcher.acquire(count=5) as got: - assert got is slot - - oversize_task = asyncio.create_task(oversize()) - await asyncio.sleep(0.05) - assert not oversize_task.done() - - released.set() - await asyncio.wait_for(oversize_task, timeout=2.0) - await holder_task + with pytest.raises(ValueError, match="exceeds the largest variant"): + async with dispatcher.acquire(count=5): + pass @pytest.mark.asyncio async def test_cancellation_does_not_leak_capacity(self): diff --git a/verifiers/utils/async_utils.py b/verifiers/utils/async_utils.py index 93c4edbad..e1ce92527 100644 --- a/verifiers/utils/async_utils.py +++ b/verifiers/utils/async_utils.py @@ -90,7 +90,19 @@ async def _notify(self) -> None: @asynccontextmanager async def acquire(self, count: int = 1) -> AsyncIterator[EndpointSlot]: - """Acquire a slot on the least-loaded variant that can fit *count* concurrent items.""" + """Acquire a slot on the least-loaded variant that can fit *count* concurrent items. + + Raises ValueError if count exceeds every variant's max_concurrent, + since allowing it would defeat the configured concurrency limit. + """ + largest_cap = max(v.max_concurrent for v in self._variants) + if count > largest_cap: + raise ValueError( + f"Group size {count} exceeds the largest variant's " + f"max_concurrent ({largest_cap}). Each group must fit on a " + f"single variant. Increase max_concurrent or reduce " + f"rollouts_per_example." + ) variant: EndpointSlot | None = None async with self._condition: while True: @@ -106,14 +118,6 @@ async def acquire(self, count: int = 1) -> AsyncIterator[EndpointSlot]: variant.active += count break - # Edge case: count exceeds every variant's max_concurrent. - # Wait for the largest variant to be fully idle, then allow through. - largest = max(self._variants, key=lambda v: v.max_concurrent) - if count > largest.max_concurrent and largest.active == 0: - variant = largest - variant.active += count - break - await self._condition.wait() try: From a95ddd522bd6ff083b7c69011479df523e93dfc5 Mon Sep 17 00:00:00 2001 From: hallerite Date: Tue, 24 Feb 2026 13:21:56 +0000 Subject: [PATCH 8/9] update skill & fix small issue --- skills/evaluate-environments/SKILL.md | 16 ++++++++++++++++ verifiers/utils/eval_utils.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/skills/evaluate-environments/SKILL.md b/skills/evaluate-environments/SKILL.md index a31455281..913f11e15 100644 --- a/skills/evaluate-environments/SKILL.md +++ b/skills/evaluate-environments/SKILL.md @@ -43,6 +43,22 @@ model = "qwen/qwen3-32b-instruct" url = "https://api.pinference.ai/api/v1" key = "PRIME_API_KEY" ``` +7. For multi-host setups, set `max_concurrent` on every variant to enable least-loaded dispatch: +```toml +[[endpoint]] +endpoint_id = "qwen3-235b" +model = "qwen/qwen3-235b" +url = "https://fast-host.example/v1" +key = "API_KEY" +max_concurrent = 64 + +[[endpoint]] +endpoint_id = "qwen3-235b" +model = "qwen/qwen3-235b" +url = "https://slow-host.example/v1" +key = "API_KEY" +max_concurrent = 16 +``` ## Publish Gate Before Large Runs 1. After smoke tests pass and results look stable, proactively suggest pushing the environment to Hub before large eval sweeps or RL work. diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 40a2ecdbf..e60532641 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -75,7 +75,7 @@ def _coerce_endpoint(raw_endpoint: object, source: str) -> Endpoint: # Parse optional max_concurrent max_concurrent = raw_endpoint_dict.get("max_concurrent") if max_concurrent is not None: - if not isinstance(max_concurrent, int) or max_concurrent <= 0: + if isinstance(max_concurrent, bool) or not isinstance(max_concurrent, int) or max_concurrent <= 0: raise ValueError(f"'max_concurrent' must be a positive integer in {source}") endpoint["max_concurrent"] = max_concurrent From fbf26b249251d7b1877b5f2ff9796c2bfc83549c Mon Sep 17 00:00:00 2001 From: hallerite Date: Tue, 24 Feb 2026 18:55:58 +0000 Subject: [PATCH 9/9] elastic endpoint pool: re-read endpoints.toml mid-run --- tests/test_elastic_pool.py | 287 +++++++++++++++++++++++++++++++++ verifiers/envs/environment.py | 20 ++- verifiers/scripts/eval.py | 16 ++ verifiers/types.py | 4 + verifiers/utils/async_utils.py | 40 +++++ verifiers/utils/elastic.py | 134 +++++++++++++++ verifiers/utils/eval_utils.py | 68 ++++++-- 7 files changed, 547 insertions(+), 22 deletions(-) create mode 100644 tests/test_elastic_pool.py create mode 100644 verifiers/utils/elastic.py diff --git a/tests/test_elastic_pool.py b/tests/test_elastic_pool.py new file mode 100644 index 000000000..4a6680ce1 --- /dev/null +++ b/tests/test_elastic_pool.py @@ -0,0 +1,287 @@ +"""Tests for update_variants() and ElasticEndpointPool.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from verifiers.types import ClientConfig, EndpointClientConfig +from verifiers.utils.async_utils import EndpointSlot, LeastLoadedDispatcher +from verifiers.utils.elastic import ElasticEndpointPool + + +def _make_config(url: str = "https://a.example/v1") -> ClientConfig: + return ClientConfig(api_base_url=url) + + +def _make_slot(url: str, max_concurrent: int = 4) -> EndpointSlot: + return EndpointSlot(config=_make_config(url), max_concurrent=max_concurrent) + + +# --------------------------------------------------------------------------- +# update_variants tests +# --------------------------------------------------------------------------- + + +class TestUpdateVariants: + @pytest.mark.asyncio + async def test_adds_new(self): + slot_a = _make_slot("https://a.example/v1") + dispatcher = LeastLoadedDispatcher([slot_a]) + + slot_b = _make_slot("https://b.example/v1") + added, removed = await dispatcher.update_variants( + [_make_slot("https://a.example/v1"), slot_b] + ) + + assert added == 1 + assert removed == 0 + + @pytest.mark.asyncio + async def test_removes(self): + slot_a = _make_slot("https://a.example/v1") + slot_b = _make_slot("https://b.example/v1") + dispatcher = LeastLoadedDispatcher([slot_a, slot_b]) + + added, removed = await dispatcher.update_variants( + [_make_slot("https://a.example/v1")] + ) + + assert added == 0 + assert removed == 1 + + @pytest.mark.asyncio + async def test_preserves_active(self): + slot_a = _make_slot("https://a.example/v1", max_concurrent=4) + dispatcher = LeastLoadedDispatcher([slot_a]) + + # Simulate in-flight request + async with dispatcher.acquire() as got: + assert got.active == 1 + + # Update with same URL — should preserve the slot object and active count + added, removed = await dispatcher.update_variants( + [_make_slot("https://a.example/v1", max_concurrent=8)] + ) + assert added == 0 + assert removed == 0 + # Active count preserved on the original object + assert got.active == 1 + # max_concurrent updated + assert got.max_concurrent == 8 + + @pytest.mark.asyncio + async def test_wakes_waiters(self): + slot_a = _make_slot("https://a.example/v1", max_concurrent=1) + dispatcher = LeastLoadedDispatcher([slot_a]) + + acquired = asyncio.Event() + unblocked = asyncio.Event() + + async def holder(): + async with dispatcher.acquire(): + acquired.set() + # Hold slot while waiter tries to acquire + await unblocked.wait() + + async def waiter(): + await acquired.wait() + # This blocks because slot_a is full, but adding slot_b should unblock it + async with dispatcher.acquire() as got: + assert got.config.api_base_url == "https://b.example/v1" + unblocked.set() + + holder_task = asyncio.create_task(holder()) + waiter_task = asyncio.create_task(waiter()) + + await acquired.wait() + await asyncio.sleep(0.05) # let waiter block + + # Add a new endpoint — should wake the waiter + await dispatcher.update_variants( + [ + _make_slot("https://a.example/v1", max_concurrent=1), + _make_slot("https://b.example/v1", max_concurrent=1), + ] + ) + + await asyncio.wait_for(waiter_task, timeout=2.0) + await holder_task + + @pytest.mark.asyncio + async def test_rejects_empty(self): + dispatcher = LeastLoadedDispatcher([_make_slot("https://a.example/v1")]) + with pytest.raises(ValueError, match="at least one variant"): + await dispatcher.update_variants([]) + + +# --------------------------------------------------------------------------- +# ElasticEndpointPool tests +# --------------------------------------------------------------------------- + + +def _write_endpoints_toml(path: Path, entries: list[dict]) -> None: + """Write a minimal endpoints.toml file.""" + lines: list[str] = [] + for entry in entries: + lines.append("[[endpoint]]") + for k, v in entry.items(): + if isinstance(v, int): + lines.append(f'{k} = {v}') + else: + lines.append(f'{k} = "{v}"') + lines.append("") + path.write_text("\n".join(lines)) + + +class TestElasticEndpointPool: + @pytest.mark.asyncio + async def test_reload_updates_dispatcher(self, tmp_path: Path): + toml_file = tmp_path / "endpoints.toml" + _write_endpoints_toml( + toml_file, + [ + { + "endpoint_id": "my-ep", + "url": "https://a.example/v1", + "key": "KEY", + "model": "m1", + "max_concurrent": 4, + }, + ], + ) + + # Build initial dispatcher + slot = _make_slot("https://a.example/v1", max_concurrent=4) + dispatcher = LeastLoadedDispatcher([slot]) + + base_config = ClientConfig( + api_key_var="KEY", + api_base_url="https://a.example/v1", + endpoint_configs=[ + EndpointClientConfig( + api_key_var="KEY", + api_base_url="https://a.example/v1", + max_concurrent=4, + ) + ], + ) + + pool = ElasticEndpointPool( + dispatcher=dispatcher, + endpoints_path=str(toml_file), + endpoint_id="my-ep", + poll_interval=1.0, + base_client_config=base_config, + ) + + # Now update the file with a second endpoint + _write_endpoints_toml( + toml_file, + [ + { + "endpoint_id": "my-ep", + "url": "https://a.example/v1", + "key": "KEY", + "model": "m1", + "max_concurrent": 4, + }, + { + "endpoint_id": "my-ep", + "url": "https://b.example/v1", + "key": "KEY", + "model": "m1", + "max_concurrent": 8, + }, + ], + ) + + await pool._reload() + + # Dispatcher should now have 2 variants + assert len(dispatcher._variants) == 2 + + @pytest.mark.asyncio + async def test_reload_failure_keeps_previous(self, tmp_path: Path): + toml_file = tmp_path / "endpoints.toml" + _write_endpoints_toml( + toml_file, + [ + { + "endpoint_id": "my-ep", + "url": "https://a.example/v1", + "key": "KEY", + "model": "m1", + "max_concurrent": 4, + }, + ], + ) + + slot = _make_slot("https://a.example/v1", max_concurrent=4) + dispatcher = LeastLoadedDispatcher([slot]) + + base_config = ClientConfig( + api_key_var="KEY", + api_base_url="https://a.example/v1", + endpoint_configs=[ + EndpointClientConfig( + api_key_var="KEY", + api_base_url="https://a.example/v1", + max_concurrent=4, + ) + ], + ) + + pool = ElasticEndpointPool( + dispatcher=dispatcher, + endpoints_path=str(tmp_path / "nonexistent.toml"), + endpoint_id="my-ep", + poll_interval=1.0, + base_client_config=base_config, + ) + + # _reload should not raise — it logs a warning and keeps previous + await pool._reload() + assert len(dispatcher._variants) == 1 + + @pytest.mark.asyncio + async def test_start_stop(self, tmp_path: Path): + toml_file = tmp_path / "endpoints.toml" + _write_endpoints_toml( + toml_file, + [ + { + "endpoint_id": "my-ep", + "url": "https://a.example/v1", + "key": "KEY", + "model": "m1", + "max_concurrent": 4, + }, + ], + ) + + slot = _make_slot("https://a.example/v1", max_concurrent=4) + dispatcher = LeastLoadedDispatcher([slot]) + + base_config = ClientConfig( + api_key_var="KEY", + api_base_url="https://a.example/v1", + ) + + pool = ElasticEndpointPool( + dispatcher=dispatcher, + endpoints_path=str(toml_file), + endpoint_id="my-ep", + poll_interval=0.05, + base_client_config=base_config, + ) + + assert pool._task is None + pool.start() + assert pool._task is not None + assert not pool._task.done() + + await pool.stop() + assert pool._task is None diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index aada2c61e..e331664e7 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -1008,7 +1008,10 @@ def get_client_for_group() -> Client | ClientConfig: tasks: dict[asyncio.Task, int] = {} try: if dispatcher is not None: - # Per-variant concurrency via LeastLoadedDispatcher + # Per-variant concurrency via LeastLoadedDispatcher. + # Retries are applied *outside* acquire() so each attempt + # re-acquires a slot — critical when a server is preempted + # and removed by the elastic pool. async def _dispatched_rollout( rollout_input: RolloutInput, ) -> RolloutOutput: @@ -1018,7 +1021,7 @@ async def _dispatched_rollout( slot.config, model, sampling_args, - max_retries=max_retries, + max_retries=0, state_columns=state_columns, ) @@ -1031,15 +1034,22 @@ async def _dispatched_group( slot.config, model, sampling_args, - max_retries=max_retries, + max_retries=0, state_columns=state_columns, ) + _retried_rollout = maybe_retry( + _dispatched_rollout, max_retries=max_retries + ) + _retried_group = maybe_retry( + _dispatched_group, max_retries=max_retries + ) + if independent_scoring: on_start(raw_inputs, filtered_inputs) for i, rollout_input in enumerate(filtered_inputs): task = asyncio.create_task( - _dispatched_rollout(rollout_input) + _retried_rollout(rollout_input) ) tasks[task] = i else: @@ -1051,7 +1061,7 @@ async def _dispatched_group( on_start(raw_inputs, filtered_group_inputs) for i, group_input in enumerate(filtered_group_inputs): - task = asyncio.create_task(_dispatched_group(group_input)) + task = asyncio.create_task(_retried_group(group_input)) tasks[task] = i else: # Legacy path: semaphore + round-robin diff --git a/verifiers/scripts/eval.py b/verifiers/scripts/eval.py index 6ea3455a8..6b6a972ee 100644 --- a/verifiers/scripts/eval.py +++ b/verifiers/scripts/eval.py @@ -421,6 +421,19 @@ def build_eval_config(raw: dict) -> EvalConfig: "Set endpoints_path to an endpoints.toml file." ) + # Elastic mode validation + raw_elastic = raw.get("elastic", False) + if raw_elastic: + if raw_endpoint_id is None: + raise ValueError( + "'elastic=true' requires 'endpoint_id' to be set." + ) + if resolved_endpoints_file is None or resolved_endpoints_file.suffix != ".toml": + raise ValueError( + "'elastic=true' requires a TOML endpoints file. " + "Set endpoints_path to an endpoints.toml file." + ) + raw_model = raw_model_field if raw_model_field is not None else DEFAULT_MODEL endpoint_lookup_id = ( raw_endpoint_id if raw_endpoint_id is not None else raw_model @@ -624,6 +637,9 @@ def build_eval_config(raw: dict) -> EvalConfig: rollouts_per_example=rollouts_per_example, max_concurrent=raw.get("max_concurrent", DEFAULT_MAX_CONCURRENT), max_retries=raw.get("max_retries", 0), + elastic=raw.get("elastic", False), + elastic_poll_interval=raw.get("elastic_poll_interval", 10.0), + endpoints_path=str(endpoints_path), verbose=raw.get("verbose", False), debug=raw.get("debug", False), state_columns=raw.get("state_columns", []), diff --git a/verifiers/types.py b/verifiers/types.py index b81484709..c8dd71f95 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -491,6 +491,10 @@ class EvalConfig(BaseModel): independent_scoring: bool = False extra_env_kwargs: dict = {} max_retries: int = 0 + # elastic endpoint pool + elastic: bool = False + elastic_poll_interval: float = 10.0 + endpoints_path: str = "" # logging verbose: bool = False debug: bool = False diff --git a/verifiers/utils/async_utils.py b/verifiers/utils/async_utils.py index e1ce92527..180b21ad0 100644 --- a/verifiers/utils/async_utils.py +++ b/verifiers/utils/async_utils.py @@ -130,6 +130,46 @@ async def acquire(self, count: int = 1) -> AsyncIterator[EndpointSlot]: # is cancelled (the shielded inner task keeps running). await asyncio.shield(self._notify()) + async def update_variants( + self, new_variants: list[EndpointSlot] + ) -> tuple[int, int]: + """Replace the variant list, preserving in-flight slots. + + Endpoints are keyed by ``api_base_url``. Kept endpoints preserve + their existing :class:`EndpointSlot` (retaining the ``active`` + count); ``max_concurrent`` is updated from the new slot. New + endpoints get fresh slots and removed endpoints are dropped. + + In-flight requests on removed endpoints continue normally — they + hold their own reference to the old slot object. + + Returns ``(added_count, removed_count)``. + """ + if not new_variants: + raise ValueError("update_variants requires at least one variant") + + async with self._condition: + old_by_url = {v.config.api_base_url: v for v in self._variants} + new_by_url = {v.config.api_base_url: v for v in new_variants} + + merged: list[EndpointSlot] = [] + added = 0 + for url, new_slot in new_by_url.items(): + old_slot = old_by_url.get(url) + if old_slot is not None: + # Preserve in-flight count, update capacity + old_slot.max_concurrent = new_slot.max_concurrent + merged.append(old_slot) + else: + merged.append(new_slot) + added += 1 + + removed = len(old_by_url) - (len(new_by_url) - added) + self._variants = merged + self._condition.notify_all() + + return added, removed + async def maybe_semaphore( limit: Optional[int] = None, diff --git a/verifiers/utils/elastic.py b/verifiers/utils/elastic.py new file mode 100644 index 000000000..bd651e5c4 --- /dev/null +++ b/verifiers/utils/elastic.py @@ -0,0 +1,134 @@ +"""Elastic endpoint pool — background polling loop that reloads endpoints.toml.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import TYPE_CHECKING + +from verifiers.types import EndpointClientConfig +from verifiers.utils.async_utils import EndpointSlot, LeastLoadedDispatcher +from verifiers.utils.client_utils import resolve_client_configs +from verifiers.utils.eval_utils import load_endpoints + +if TYPE_CHECKING: + from verifiers.types import ClientConfig + +logger = logging.getLogger(__name__) + + +class ElasticEndpointPool: + """Periodically re-reads an endpoints file and updates a dispatcher. + + The pool runs a background ``asyncio.Task`` that polls the endpoints + file at a fixed interval. On each tick it rebuilds the + :class:`EndpointSlot` list and calls + :meth:`LeastLoadedDispatcher.update_variants` so that new endpoints + are picked up and removed endpoints are drained. + """ + + def __init__( + self, + dispatcher: LeastLoadedDispatcher, + endpoints_path: str, + endpoint_id: str, + poll_interval: float, + base_client_config: ClientConfig, + ) -> None: + self._dispatcher = dispatcher + self._endpoints_path = endpoints_path + self._endpoint_id = endpoint_id + self._poll_interval = poll_interval + self._base_client_config = base_client_config + self._task: asyncio.Task | None = None + + def start(self) -> None: + """Start the background polling loop.""" + if self._task is not None: + return + self._task = asyncio.create_task(self._poll_loop()) + + async def stop(self) -> None: + """Cancel the polling task and wait for it to finish.""" + if self._task is None: + return + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + self._task = None + + async def _poll_loop(self) -> None: + """Sleep → reload, repeat until cancelled.""" + while True: + await asyncio.sleep(self._poll_interval) + try: + await self._reload() + except Exception: + logger.warning( + "Elastic pool reload failed; keeping previous endpoints", + exc_info=True, + ) + + async def _reload(self) -> None: + """Load endpoints file and push updated variants to the dispatcher.""" + endpoints = load_endpoints(self._endpoints_path) + + endpoint_group = endpoints.get(self._endpoint_id) + if endpoint_group is None: + logger.warning( + "Elastic pool: endpoint_id %r not found in %s; skipping update", + self._endpoint_id, + self._endpoints_path, + ) + return + + # Check that all variants have max_concurrent set + missing = [ + i + for i, ep in enumerate(endpoint_group) + if ep.get("max_concurrent") is None + ] + if missing: + logger.warning( + "Elastic pool: endpoint_id %r has variants without max_concurrent " + "(indices %s); skipping update", + self._endpoint_id, + missing, + ) + return + + # Build EndpointClientConfig list (same pattern as eval.py) + endpoint_configs = [ + EndpointClientConfig( + api_key_var=ep["key"], + api_base_url=ep["url"], + max_concurrent=ep.get("max_concurrent"), + ) + for ep in endpoint_group + ] + + # Create a temporary ClientConfig with the new endpoint_configs + # so resolve_client_configs can merge parent fields. + temp_config = self._base_client_config.model_copy( + update={"endpoint_configs": endpoint_configs} + ) + resolved = resolve_client_configs(temp_config) + + slots = [ + EndpointSlot( + config=cfg, + max_concurrent=ep.max_concurrent, + ) + for cfg, ep in zip(resolved, endpoint_configs) + ] + + added, removed = await self._dispatcher.update_variants(slots) + if added or removed: + logger.info( + "Elastic pool updated endpoint_id %r: +%d -%d endpoints", + self._endpoint_id, + added, + removed, + ) diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index e60532641..d46425fd8 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -340,6 +340,9 @@ def load_toml_config(path: Path) -> list[dict]: "max_concurrent", "independent_scoring", "max_retries", + # elastic endpoint pool + "elastic", + "elastic_poll_interval", # logging "verbose", "debug", @@ -582,8 +585,8 @@ def quiet_datasets(): def _build_dispatchers( evals: list[EvalConfig], -) -> dict[str | None, LeastLoadedDispatcher]: - """Build per-endpoint dispatchers from eval configs. +) -> tuple[dict[str | None, LeastLoadedDispatcher], list]: + """Build per-endpoint dispatchers (and optional elastic pools) from eval configs. Groups evals by ``endpoint_id`` and, for each unique id where variants have ``max_concurrent`` set, creates a @@ -595,9 +598,16 @@ def _build_dispatchers( Endpoint groups without ``max_concurrent`` use the default semaphore + round-robin path in ``environment.generate()``. - Returns a mapping from ``endpoint_id`` (or ``None``) to the dispatcher. + When an eval has ``elastic=True``, an :class:`ElasticEndpointPool` + is created to poll the endpoints file and update the dispatcher. + + Returns ``(dispatchers, pools)`` — a mapping from ``endpoint_id`` + (or ``None``) to the dispatcher and a list of elastic pools. """ + from verifiers.utils.elastic import ElasticEndpointPool + dispatchers: dict[str | None, LeastLoadedDispatcher] = {} + pools: list[ElasticEndpointPool] = [] # Collect unique endpoint_ids, take the first config as representative seen: dict[str | None, EvalConfig] = {} @@ -628,9 +638,20 @@ def _build_dispatchers( ) for cfg, ep in zip(resolved, endpoint_cfgs) ] - dispatchers[endpoint_id] = LeastLoadedDispatcher(slots) + dispatcher = LeastLoadedDispatcher(slots) + dispatchers[endpoint_id] = dispatcher + + if ec.elastic and endpoint_id is not None: + pool = ElasticEndpointPool( + dispatcher=dispatcher, + endpoints_path=ec.endpoints_path, + endpoint_id=endpoint_id, + poll_interval=ec.elastic_poll_interval, + base_client_config=ec.client_config, + ) + pools.append(pool) - return dispatchers + return dispatchers, pools async def run_evaluation( @@ -732,7 +753,7 @@ async def run_evaluations(config: EvalRunConfig) -> None: event_loop_lag_monitor = EventLoopLagMonitor() event_loop_lag_monitor.run_in_background() - dispatchers = _build_dispatchers(config.evals) + dispatchers, pools = _build_dispatchers(config.evals) on_progress: list[ProgressCallback] | None = None if config.heartbeat_url is not None: @@ -741,17 +762,25 @@ async def run_evaluations(config: EvalRunConfig) -> None: heart = Heartbeat(config.heartbeat_url) on_progress = [lambda *_a, **_kw: asyncio.create_task(heart.beat())] + for pool in pools: + pool.start() + start_time = time.time() - all_results = await asyncio.gather( - *[ - run_evaluation( - eval_config, - on_progress=on_progress, - dispatcher=dispatchers.get(eval_config.endpoint_id), - ) - for eval_config in config.evals - ] - ) + try: + all_results = await asyncio.gather( + *[ + run_evaluation( + eval_config, + on_progress=on_progress, + dispatcher=dispatchers.get(eval_config.endpoint_id), + ) + for eval_config in config.evals + ] + ) + finally: + for pool in pools: + await pool.stop() + end_time = time.time() if config.heartbeat_url is not None: @@ -791,7 +820,7 @@ async def run_evaluations_tui(config: EvalRunConfig, tui_mode: bool = True) -> N await run_evaluations(config) return - dispatchers = _build_dispatchers(config.evals) + dispatchers, pools = _build_dispatchers(config.evals) heart = None if config.heartbeat_url is not None: @@ -881,6 +910,9 @@ async def refresh_loop() -> None: display.refresh() await asyncio.sleep(1) + for pool in pools: + pool.start() + try: async with display: refresh_task = asyncio.create_task(refresh_loop()) @@ -908,6 +940,8 @@ async def refresh_loop() -> None: except KeyboardInterrupt: pass # exit on interrupt finally: + for pool in pools: + await pool.stop() if heart is not None: await heart.close()