diff --git a/docs/evaluation.md b/docs/evaluation.md index 7b3a2e69e..970ba0866 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -95,7 +95,25 @@ api_client_type = "anthropic_messages" Each endpoint entry supports an optional `api_client_type` field to select the client implementation (defaults to `"openai_chat_completions"`). Use `"anthropic_messages"` for Anthropic models when calling the Anthropic API directly. -To define equivalent replicas, add multiple `[[endpoint]]` entries with the same `endpoint_id`. +To define equivalent replicas, add multiple `[[endpoint]]` entries with the same `endpoint_id`. You can optionally set `max_concurrent` on each variant to limit how many requests it handles simultaneously: + +```toml +[[endpoint]] +endpoint_id = "my-model" +model = "my-org/my-model" +url = "https://fast-host.example.com/v1" +key = "API_KEY" +max_concurrent = 64 + +[[endpoint]] +endpoint_id = "my-model" +model = "my-org/my-model" +url = "https://slow-host.example.com/v1" +key = "API_KEY" +max_concurrent = 16 +``` + +When variants have `max_concurrent` set, the evaluator uses least-loaded dispatch: each request is routed to the variant with the most available capacity. Per-variant concurrency is all-or-nothing — either every variant in an endpoint group sets `max_concurrent`, or none does. When no variant has `max_concurrent`, requests are distributed round-robin and the global `--max-concurrent` flag controls concurrency. Then use the alias directly: @@ -145,6 +163,8 @@ Multiple rollouts per example enable metrics like pass@k and help measure varian By default, scoring runs interleaved with generation. Use `--no-interleave-scoring` to score all rollouts after generation completes. +When per-variant `max_concurrent` limits are configured in the endpoint registry, the endpoint dispatcher manages concurrency globally across all variants and the `--max-concurrent` flag is ignored. 
+ The `--max-retries` flag enables automatic retry with exponential backoff when rollouts fail due to transient infrastructure errors (e.g., sandbox timeouts, API failures). ### Output and Saving diff --git a/docs/reference.md b/docs/reference.md index b0af77fcc..bb5197bb6 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -701,9 +701,10 @@ class EndpointClientConfig(BaseModel): max_keepalive_connections: int = 28000 max_retries: int = 10 extra_headers: dict[str, str] = {} + max_concurrent: int | None = None ``` -Leaf endpoint configuration used inside `ClientConfig.endpoint_configs`. Has the same fields as `ClientConfig` except `endpoint_configs` itself, preventing recursive nesting. +Leaf endpoint configuration used inside `ClientConfig.endpoint_configs`. Has the same fields as `ClientConfig` except `endpoint_configs` itself, preventing recursive nesting. The optional `max_concurrent` field limits how many concurrent requests this variant handles; see [Per-Variant Concurrency](evaluation.md#concurrency). ### EvalConfig @@ -733,11 +734,17 @@ class EvalConfig(BaseModel): ### Endpoint ```python -Endpoint = TypedDict("Endpoint", {"key": str, "url": str, "model": str}) +class Endpoint(TypedDict): + key: str + url: str + model: str + api_client_type: NotRequired[ClientType] + max_concurrent: NotRequired[int] + Endpoints = dict[str, list[Endpoint]] ``` -`Endpoints` maps an endpoint id to one or more endpoint variants. A single variant is represented as a one-item list. +`Endpoints` maps an endpoint id to one or more endpoint variants. A single variant is represented as a one-item list. The `max_concurrent` field enables per-variant concurrency limiting with least-loaded dispatch; if set on any variant in a group, it must be set on all. 
--- diff --git a/skills/evaluate-environments/SKILL.md b/skills/evaluate-environments/SKILL.md index a31455281..913f11e15 100644 --- a/skills/evaluate-environments/SKILL.md +++ b/skills/evaluate-environments/SKILL.md @@ -43,6 +43,22 @@ model = "qwen/qwen3-32b-instruct" url = "https://api.pinference.ai/api/v1" key = "PRIME_API_KEY" ``` +7. For multi-host setups, set `max_concurrent` on every variant to enable least-loaded dispatch: +```toml +[[endpoint]] +endpoint_id = "qwen3-235b" +model = "qwen/qwen3-235b" +url = "https://fast-host.example/v1" +key = "API_KEY" +max_concurrent = 64 + +[[endpoint]] +endpoint_id = "qwen3-235b" +model = "qwen/qwen3-235b" +url = "https://slow-host.example/v1" +key = "API_KEY" +max_concurrent = 16 +``` ## Publish Gate Before Large Runs 1. After smoke tests pass and results look stable, proactively suggest pushing the environment to Hub before large eval sweeps or RL work. diff --git a/tests/test_elastic_pool.py b/tests/test_elastic_pool.py new file mode 100644 index 000000000..4a6680ce1 --- /dev/null +++ b/tests/test_elastic_pool.py @@ -0,0 +1,287 @@ +"""Tests for update_variants() and ElasticEndpointPool.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from verifiers.types import ClientConfig, EndpointClientConfig +from verifiers.utils.async_utils import EndpointSlot, LeastLoadedDispatcher +from verifiers.utils.elastic import ElasticEndpointPool + + +def _make_config(url: str = "https://a.example/v1") -> ClientConfig: + return ClientConfig(api_base_url=url) + + +def _make_slot(url: str, max_concurrent: int = 4) -> EndpointSlot: + return EndpointSlot(config=_make_config(url), max_concurrent=max_concurrent) + + +# --------------------------------------------------------------------------- +# update_variants tests +# --------------------------------------------------------------------------- + + +class TestUpdateVariants: + @pytest.mark.asyncio + async def test_adds_new(self): + 
slot_a = _make_slot("https://a.example/v1") + dispatcher = LeastLoadedDispatcher([slot_a]) + + slot_b = _make_slot("https://b.example/v1") + added, removed = await dispatcher.update_variants( + [_make_slot("https://a.example/v1"), slot_b] + ) + + assert added == 1 + assert removed == 0 + + @pytest.mark.asyncio + async def test_removes(self): + slot_a = _make_slot("https://a.example/v1") + slot_b = _make_slot("https://b.example/v1") + dispatcher = LeastLoadedDispatcher([slot_a, slot_b]) + + added, removed = await dispatcher.update_variants( + [_make_slot("https://a.example/v1")] + ) + + assert added == 0 + assert removed == 1 + + @pytest.mark.asyncio + async def test_preserves_active(self): + slot_a = _make_slot("https://a.example/v1", max_concurrent=4) + dispatcher = LeastLoadedDispatcher([slot_a]) + + # Simulate in-flight request + async with dispatcher.acquire() as got: + assert got.active == 1 + + # Update with same URL — should preserve the slot object and active count + added, removed = await dispatcher.update_variants( + [_make_slot("https://a.example/v1", max_concurrent=8)] + ) + assert added == 0 + assert removed == 0 + # Active count preserved on the original object + assert got.active == 1 + # max_concurrent updated + assert got.max_concurrent == 8 + + @pytest.mark.asyncio + async def test_wakes_waiters(self): + slot_a = _make_slot("https://a.example/v1", max_concurrent=1) + dispatcher = LeastLoadedDispatcher([slot_a]) + + acquired = asyncio.Event() + unblocked = asyncio.Event() + + async def holder(): + async with dispatcher.acquire(): + acquired.set() + # Hold slot while waiter tries to acquire + await unblocked.wait() + + async def waiter(): + await acquired.wait() + # This blocks because slot_a is full, but adding slot_b should unblock it + async with dispatcher.acquire() as got: + assert got.config.api_base_url == "https://b.example/v1" + unblocked.set() + + holder_task = asyncio.create_task(holder()) + waiter_task = asyncio.create_task(waiter()) + + 
await acquired.wait() + await asyncio.sleep(0.05) # let waiter block + + # Add a new endpoint — should wake the waiter + await dispatcher.update_variants( + [ + _make_slot("https://a.example/v1", max_concurrent=1), + _make_slot("https://b.example/v1", max_concurrent=1), + ] + ) + + await asyncio.wait_for(waiter_task, timeout=2.0) + await holder_task + + @pytest.mark.asyncio + async def test_rejects_empty(self): + dispatcher = LeastLoadedDispatcher([_make_slot("https://a.example/v1")]) + with pytest.raises(ValueError, match="at least one variant"): + await dispatcher.update_variants([]) + + +# --------------------------------------------------------------------------- +# ElasticEndpointPool tests +# --------------------------------------------------------------------------- + + +def _write_endpoints_toml(path: Path, entries: list[dict]) -> None: + """Write a minimal endpoints.toml file.""" + lines: list[str] = [] + for entry in entries: + lines.append("[[endpoint]]") + for k, v in entry.items(): + if isinstance(v, int): + lines.append(f'{k} = {v}') + else: + lines.append(f'{k} = "{v}"') + lines.append("") + path.write_text("\n".join(lines)) + + +class TestElasticEndpointPool: + @pytest.mark.asyncio + async def test_reload_updates_dispatcher(self, tmp_path: Path): + toml_file = tmp_path / "endpoints.toml" + _write_endpoints_toml( + toml_file, + [ + { + "endpoint_id": "my-ep", + "url": "https://a.example/v1", + "key": "KEY", + "model": "m1", + "max_concurrent": 4, + }, + ], + ) + + # Build initial dispatcher + slot = _make_slot("https://a.example/v1", max_concurrent=4) + dispatcher = LeastLoadedDispatcher([slot]) + + base_config = ClientConfig( + api_key_var="KEY", + api_base_url="https://a.example/v1", + endpoint_configs=[ + EndpointClientConfig( + api_key_var="KEY", + api_base_url="https://a.example/v1", + max_concurrent=4, + ) + ], + ) + + pool = ElasticEndpointPool( + dispatcher=dispatcher, + endpoints_path=str(toml_file), + endpoint_id="my-ep", + 
poll_interval=1.0, + base_client_config=base_config, + ) + + # Now update the file with a second endpoint + _write_endpoints_toml( + toml_file, + [ + { + "endpoint_id": "my-ep", + "url": "https://a.example/v1", + "key": "KEY", + "model": "m1", + "max_concurrent": 4, + }, + { + "endpoint_id": "my-ep", + "url": "https://b.example/v1", + "key": "KEY", + "model": "m1", + "max_concurrent": 8, + }, + ], + ) + + await pool._reload() + + # Dispatcher should now have 2 variants + assert len(dispatcher._variants) == 2 + + @pytest.mark.asyncio + async def test_reload_failure_keeps_previous(self, tmp_path: Path): + toml_file = tmp_path / "endpoints.toml" + _write_endpoints_toml( + toml_file, + [ + { + "endpoint_id": "my-ep", + "url": "https://a.example/v1", + "key": "KEY", + "model": "m1", + "max_concurrent": 4, + }, + ], + ) + + slot = _make_slot("https://a.example/v1", max_concurrent=4) + dispatcher = LeastLoadedDispatcher([slot]) + + base_config = ClientConfig( + api_key_var="KEY", + api_base_url="https://a.example/v1", + endpoint_configs=[ + EndpointClientConfig( + api_key_var="KEY", + api_base_url="https://a.example/v1", + max_concurrent=4, + ) + ], + ) + + pool = ElasticEndpointPool( + dispatcher=dispatcher, + endpoints_path=str(tmp_path / "nonexistent.toml"), + endpoint_id="my-ep", + poll_interval=1.0, + base_client_config=base_config, + ) + + # _reload should not raise — it logs a warning and keeps previous + await pool._reload() + assert len(dispatcher._variants) == 1 + + @pytest.mark.asyncio + async def test_start_stop(self, tmp_path: Path): + toml_file = tmp_path / "endpoints.toml" + _write_endpoints_toml( + toml_file, + [ + { + "endpoint_id": "my-ep", + "url": "https://a.example/v1", + "key": "KEY", + "model": "m1", + "max_concurrent": 4, + }, + ], + ) + + slot = _make_slot("https://a.example/v1", max_concurrent=4) + dispatcher = LeastLoadedDispatcher([slot]) + + base_config = ClientConfig( + api_key_var="KEY", + api_base_url="https://a.example/v1", + ) + + pool = 
ElasticEndpointPool( + dispatcher=dispatcher, + endpoints_path=str(toml_file), + endpoint_id="my-ep", + poll_interval=0.05, + base_client_config=base_config, + ) + + assert pool._task is None + pool.start() + assert pool._task is not None + assert not pool._task.done() + + await pool.stop() + assert pool._task is None diff --git a/tests/test_endpoint_dispatcher.py b/tests/test_endpoint_dispatcher.py new file mode 100644 index 000000000..d8baf949c --- /dev/null +++ b/tests/test_endpoint_dispatcher.py @@ -0,0 +1,142 @@ +import asyncio + +import pytest + +from verifiers.types import ClientConfig +from verifiers.utils.async_utils import ( + EndpointSlot, + LeastLoadedDispatcher, +) + + +def _make_config(url: str = "https://a.example/v1") -> ClientConfig: + return ClientConfig(api_base_url=url) + + +class TestEndpointSlot: + def test_available_reflects_capacity(self): + slot = EndpointSlot(config=_make_config(), max_concurrent=10) + assert slot.available == 10 + slot.active = 3 + assert slot.available == 7 + + +class TestLeastLoadedDispatcher: + @pytest.mark.asyncio + async def test_least_loaded_picks_emptier_variant(self): + slot_a = EndpointSlot( + config=_make_config("https://a.example/v1"), max_concurrent=4 + ) + slot_b = EndpointSlot( + config=_make_config("https://b.example/v1"), max_concurrent=4 + ) + dispatcher = LeastLoadedDispatcher([slot_a, slot_b]) + + # Acquire one on each — first should go to whichever has more available + async with dispatcher.acquire() as got1: + assert got1 in (slot_a, slot_b) + # got1 now has 1 active, the other has 0 + other = slot_b if got1 is slot_a else slot_a + async with dispatcher.acquire() as got2: + # Should pick the one with more available (the other) + assert got2 is other + + @pytest.mark.asyncio + async def test_blocks_when_all_full_then_unblocks(self): + slot = EndpointSlot(config=_make_config(), max_concurrent=1) + dispatcher = LeastLoadedDispatcher([slot]) + + acquired = asyncio.Event() + released = asyncio.Event() + + 
async def holder(): + async with dispatcher.acquire(): + acquired.set() + await released.wait() + + async def waiter(): + await acquired.wait() + # This should block until holder releases + async with dispatcher.acquire() as got: + assert got is slot + + holder_task = asyncio.create_task(holder()) + waiter_task = asyncio.create_task(waiter()) + + # Give tasks time to start + await asyncio.sleep(0.05) + assert acquired.is_set() + assert not waiter_task.done() + + # Release the holder + released.set() + await asyncio.wait_for(waiter_task, timeout=2.0) + await holder_task + + @pytest.mark.asyncio + async def test_count_parameter_consumes_correct_slots(self): + slot = EndpointSlot(config=_make_config(), max_concurrent=4) + dispatcher = LeastLoadedDispatcher([slot]) + + async with dispatcher.acquire(count=3) as got: + assert got is slot + assert slot.active == 3 + assert slot.available == 1 + + assert slot.active == 0 + assert slot.available == 4 + + @pytest.mark.asyncio + async def test_releases_on_exception(self): + slot = EndpointSlot(config=_make_config(), max_concurrent=2) + dispatcher = LeastLoadedDispatcher([slot]) + + with pytest.raises(RuntimeError, match="boom"): + async with dispatcher.acquire(): + raise RuntimeError("boom") + + assert slot.active == 0 + + @pytest.mark.asyncio + async def test_oversize_count_raises(self): + """count > every variant's max_concurrent is a config error.""" + slot = EndpointSlot(config=_make_config(), max_concurrent=2) + dispatcher = LeastLoadedDispatcher([slot]) + + with pytest.raises(ValueError, match="exceeds the largest variant"): + async with dispatcher.acquire(count=5): + pass + + @pytest.mark.asyncio + async def test_cancellation_does_not_leak_capacity(self): + """Cancelled tasks must release their slots.""" + slot = EndpointSlot(config=_make_config(), max_concurrent=1) + dispatcher = LeastLoadedDispatcher([slot]) + + entered = asyncio.Event() + + async def hold_then_wait(): + async with dispatcher.acquire(): + 
entered.set() + await asyncio.sleep(999) # will be cancelled + + task = asyncio.create_task(hold_then_wait()) + await entered.wait() + assert slot.active == 1 + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + # Give the shielded notification task a chance to run + await asyncio.sleep(0) + + assert slot.active == 0, "capacity leaked after cancellation" + + # Verify the slot is actually usable again + async with dispatcher.acquire() as got: + assert got is slot + + def test_empty_variants_raises(self): + with pytest.raises(ValueError): + LeastLoadedDispatcher([]) diff --git a/tests/test_endpoint_registry.py b/tests/test_endpoint_registry.py index 01a0b2482..7e9e8449f 100644 --- a/tests/test_endpoint_registry.py +++ b/tests/test_endpoint_registry.py @@ -192,6 +192,81 @@ def test_load_endpoints_directory_prefers_toml_then_python(tmp_path: Path): assert set(endpoints.keys()) == {"from-py"} +def test_load_endpoints_toml_parses_max_concurrent(tmp_path: Path): + registry_path = tmp_path / "endpoints.toml" + registry_path.write_text( + "[[endpoint]]\n" + 'endpoint_id = "my-model"\n' + 'model = "my/model"\n' + 'url = "https://a.example/v1"\n' + 'key = "A_KEY"\n' + "max_concurrent = 16\n" + "\n" + "[[endpoint]]\n" + 'endpoint_id = "my-model"\n' + 'model = "my/model"\n' + 'url = "https://b.example/v1"\n' + 'key = "A_KEY"\n' + "max_concurrent = 32\n", + encoding="utf-8", + ) + + endpoints = load_endpoints(str(registry_path)) + + assert endpoints["my-model"][0]["max_concurrent"] == 16 + assert endpoints["my-model"][1]["max_concurrent"] == 32 + + +def test_load_endpoints_toml_max_concurrent_optional(tmp_path: Path): + registry_path = tmp_path / "endpoints.toml" + registry_path.write_text( + "[[endpoint]]\n" + 'endpoint_id = "my-model"\n' + 'model = "my/model"\n' + 'url = "https://a.example/v1"\n' + 'key = "A_KEY"\n', + encoding="utf-8", + ) + + endpoints = load_endpoints(str(registry_path)) + + assert "max_concurrent" not in endpoints["my-model"][0] 
+ + +def test_load_endpoints_python_parses_max_concurrent(tmp_path: Path): + registry_path = tmp_path / "endpoints.py" + registry_path.write_text( + "ENDPOINTS = {\n" + ' "my-model": [\n' + ' {"model": "m", "url": "https://a.example/v1", "key": "K", "max_concurrent": 8},\n' + " ]\n" + "}\n", + encoding="utf-8", + ) + + endpoints = load_endpoints(str(registry_path)) + + assert endpoints["my-model"][0]["max_concurrent"] == 8 + + +def test_load_endpoints_rejects_invalid_max_concurrent(tmp_path: Path): + for bad_value in [0, -1, '"not_an_int"']: + registry_path = tmp_path / "endpoints.toml" + registry_path.write_text( + "[[endpoint]]\n" + 'endpoint_id = "my-model"\n' + 'model = "my/model"\n' + 'url = "https://a.example/v1"\n' + 'key = "A_KEY"\n' + f"max_concurrent = {bad_value}\n", + encoding="utf-8", + ) + + endpoints = load_endpoints(str(registry_path)) + # Invalid values cause load to fail and return empty + assert endpoints == {} + + def test_qwen3_vl_endpoint_ids_map_to_vl_models(): endpoints = load_endpoints("./configs/endpoints.toml") diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index a1ccd3daf..e331664e7 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -65,6 +65,7 @@ Tool, ) from verifiers.utils.async_utils import ( + LeastLoadedDispatcher, maybe_retry, maybe_semaphore, with_sem, @@ -826,6 +827,7 @@ async def generate( on_start: StartCallback | None = None, on_progress: ProgressCallback | list[ProgressCallback] | None = None, on_log: LogCallback | None = None, + dispatcher: LeastLoadedDispatcher | None = None, ) -> GenerateOutputs: """ Generate rollouts for a set of inputs. @@ -835,6 +837,9 @@ async def generate( on_progress: Progress callback(s). None uses the default tqdm progress bar. A single callback replaces the default. A list of callbacks runs alongside the default. + dispatcher: Optional LeastLoadedDispatcher for per-variant concurrency. 
+ When provided, the dispatcher handles client selection and concurrency. + When None, falls back to semaphore + round-robin dispatch. """ from datasets import Dataset from tqdm import tqdm @@ -913,7 +918,6 @@ def default_on_progress(*a, **kw): elif isinstance(inputs, list): raw_inputs = inputs - # set up semaphores sem = await maybe_semaphore(max_concurrent) # set up sampling args @@ -1003,50 +1007,107 @@ def get_client_for_group() -> Client | ClientConfig: tasks: dict[asyncio.Task, int] = {} try: - # create tasks based on mode - if independent_scoring: - on_start(raw_inputs, filtered_inputs) - for i, rollout_input in enumerate(filtered_inputs): - task = asyncio.create_task( - with_sem( - sem, - self.run_rollout( - rollout_input, - get_client_for_group(), - model, - sampling_args, - max_retries=max_retries, - state_columns=state_columns, - ), - ), - ) - tasks[task] = i + if dispatcher is not None: + # Per-variant concurrency via LeastLoadedDispatcher. + # Retries are applied *outside* acquire() so each attempt + # re-acquires a slot — critical when a server is preempted + # and removed by the elastic pool. 
+ async def _dispatched_rollout( + rollout_input: RolloutInput, + ) -> RolloutOutput: + async with dispatcher.acquire(count=1) as slot: + return await self.run_rollout( + rollout_input, + slot.config, + model, + sampling_args, + max_retries=0, + state_columns=state_columns, + ) + + async def _dispatched_group( + group_input: list[RolloutInput], + ) -> list[RolloutOutput]: + async with dispatcher.acquire(count=len(group_input)) as slot: + return await self.run_group( + group_input, + slot.config, + model, + sampling_args, + max_retries=0, + state_columns=state_columns, + ) + + _retried_rollout = maybe_retry( + _dispatched_rollout, max_retries=max_retries + ) + _retried_group = maybe_retry( + _dispatched_group, max_retries=max_retries + ) + + if independent_scoring: + on_start(raw_inputs, filtered_inputs) + for i, rollout_input in enumerate(filtered_inputs): + task = asyncio.create_task( + _retried_rollout(rollout_input) + ) + tasks[task] = i + else: + group_inputs: dict[int, list[RolloutInput]] = defaultdict(list) + for rollout_input in filtered_inputs: + example_id = rollout_input["example_id"] + group_inputs[example_id].append(rollout_input) + filtered_group_inputs = list(group_inputs.values()) + on_start(raw_inputs, filtered_group_inputs) + + for i, group_input in enumerate(filtered_group_inputs): + task = asyncio.create_task(_retried_group(group_input)) + tasks[task] = i else: - group_inputs: dict[int, list[RolloutInput]] = defaultdict(list) - for rollout_input in filtered_inputs: - example_id = rollout_input["example_id"] - group_inputs[example_id].append(rollout_input) - filtered_group_inputs = list(group_inputs.values()) - on_start(raw_inputs, filtered_group_inputs) - - for i, group_input in enumerate(filtered_group_inputs): - # For grouped scoring, keep each group on one endpoint so - # rollouts in the same group can benefit from shared KV cache. 
- group_client = get_client_for_group() - task = asyncio.create_task( - with_sem( - sem, - self.run_group( - group_input, - group_client, - model, - sampling_args, - max_retries=max_retries, - state_columns=state_columns, + # Legacy path: semaphore + round-robin + if independent_scoring: + on_start(raw_inputs, filtered_inputs) + for i, rollout_input in enumerate(filtered_inputs): + task = asyncio.create_task( + with_sem( + sem, + self.run_rollout( + rollout_input, + get_client_for_group(), + model, + sampling_args, + max_retries=max_retries, + state_columns=state_columns, + ), + ), + ) + tasks[task] = i + else: + group_inputs: dict[int, list[RolloutInput]] = defaultdict(list) + for rollout_input in filtered_inputs: + example_id = rollout_input["example_id"] + group_inputs[example_id].append(rollout_input) + filtered_group_inputs = list(group_inputs.values()) + on_start(raw_inputs, filtered_group_inputs) + + for i, group_input in enumerate(filtered_group_inputs): + # For grouped scoring, keep each group on one endpoint so + # rollouts in the same group can benefit from shared KV cache. 
+ group_client = get_client_for_group() + task = asyncio.create_task( + with_sem( + sem, + self.run_group( + group_input, + group_client, + model, + sampling_args, + max_retries=max_retries, + state_columns=state_columns, + ), ), - ), - ) - tasks[task] = i + ) + tasks[task] = i for coro in asyncio.as_completed(tasks.keys()): result = await coro @@ -1159,6 +1220,7 @@ async def evaluate( on_start: StartCallback | None = None, on_progress: ProgressCallback | list[ProgressCallback] | None = None, on_log: LogCallback | None = None, + dispatcher: LeastLoadedDispatcher | None = None, **kwargs, ) -> GenerateOutputs: """ @@ -1186,6 +1248,7 @@ async def evaluate( on_start=on_start, on_progress=on_progress, on_log=on_log, + dispatcher=dispatcher, **kwargs, ) diff --git a/verifiers/scripts/eval.py b/verifiers/scripts/eval.py index f00580f24..6b6a972ee 100644 --- a/verifiers/scripts/eval.py +++ b/verifiers/scripts/eval.py @@ -421,6 +421,19 @@ def build_eval_config(raw: dict) -> EvalConfig: "Set endpoints_path to an endpoints.toml file." ) + # Elastic mode validation + raw_elastic = raw.get("elastic", False) + if raw_elastic: + if raw_endpoint_id is None: + raise ValueError( + "'elastic=true' requires 'endpoint_id' to be set." + ) + if resolved_endpoints_file is None or resolved_endpoints_file.suffix != ".toml": + raise ValueError( + "'elastic=true' requires a TOML endpoints file. " + "Set endpoints_path to an endpoints.toml file." 
+ ) + raw_model = raw_model_field if raw_model_field is not None else DEFAULT_MODEL endpoint_lookup_id = ( raw_endpoint_id if raw_endpoint_id is not None else raw_model @@ -547,11 +560,14 @@ def build_eval_config(raw: dict) -> EvalConfig: resolved_api_key_var = api_key_var endpoint_configs: list[EndpointClientConfig] = [] + has_variant_concurrency = endpoint_group is not None and any( + ep.get("max_concurrent") is not None for ep in endpoint_group + ) if ( endpoint_group is not None and not api_base_url_override and raw_provider is None - and len(endpoint_group) > 1 + and (len(endpoint_group) > 1 or has_variant_concurrency) ): endpoint_configs = [ EndpointClientConfig( @@ -560,6 +576,7 @@ def build_eval_config(raw: dict) -> EvalConfig: ), api_base_url=endpoint["url"], extra_headers=merged_headers, + max_concurrent=endpoint.get("max_concurrent"), ) for endpoint in endpoint_group ] @@ -620,6 +637,9 @@ def build_eval_config(raw: dict) -> EvalConfig: rollouts_per_example=rollouts_per_example, max_concurrent=raw.get("max_concurrent", DEFAULT_MAX_CONCURRENT), max_retries=raw.get("max_retries", 0), + elastic=raw.get("elastic", False), + elastic_poll_interval=raw.get("elastic_poll_interval", 10.0), + endpoints_path=str(endpoints_path), verbose=raw.get("verbose", False), debug=raw.get("debug", False), state_columns=raw.get("state_columns", []), diff --git a/verifiers/types.py b/verifiers/types.py index 793bae542..c8dd71f95 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -23,9 +23,9 @@ from verifiers.errors import Error if sys.version_info < (3, 12): - from typing_extensions import NotRequired, TypedDict + from typing_extensions import TypedDict else: - from typing import NotRequired, TypedDict + from typing import TypedDict # Client / message type literals ClientType = Literal[ @@ -386,15 +386,17 @@ class RolloutScores(TypedDict): metrics: dict[str, list[float]] -Endpoint = TypedDict( - "Endpoint", - { - "key": str, - "url": str, - "model": str, - 
"api_client_type": NotRequired[ClientType], - }, -) +class _EndpointRequired(TypedDict): + key: str + url: str + model: str + + +class Endpoint(_EndpointRequired, total=False): + api_client_type: ClientType + max_concurrent: int + + Endpoints = dict[str, list[Endpoint]] @@ -465,6 +467,7 @@ class EndpointClientConfig(BaseModel): max_keepalive_connections: int = 28000 max_retries: int = 10 extra_headers: dict[str, str] = Field(default_factory=dict) + max_concurrent: int | None = None ClientConfig.model_rebuild() @@ -488,6 +491,10 @@ class EvalConfig(BaseModel): independent_scoring: bool = False extra_env_kwargs: dict = {} max_retries: int = 0 + # elastic endpoint pool + elastic: bool = False + elastic_poll_interval: float = 10.0 + endpoints_path: str = "" # logging verbose: bool = False debug: bool = False diff --git a/verifiers/utils/async_utils.py b/verifiers/utils/async_utils.py index e8bf95c22..180b21ad0 100644 --- a/verifiers/utils/async_utils.py +++ b/verifiers/utils/async_utils.py @@ -1,9 +1,21 @@ +from __future__ import annotations + import asyncio import inspect import logging from collections.abc import Coroutine +from contextlib import asynccontextmanager +from dataclasses import dataclass, field from time import perf_counter -from typing import Any, AsyncContextManager, Callable, Optional, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + AsyncContextManager, + AsyncIterator, + Callable, + Optional, + TypeVar, +) import tenacity as tc @@ -11,6 +23,10 @@ from verifiers.utils.error_utils import ErrorChain from verifiers.utils.logging_utils import print_time +if TYPE_CHECKING: + from verifiers.types import ClientConfig + + logger = logging.getLogger(__name__) T = TypeVar("T") @@ -41,6 +57,120 @@ async def __aexit__(self, exc_type, exc_value, traceback): return False +@dataclass +class EndpointSlot: + """Tracks one variant's client config and concurrency capacity.""" + + config: ClientConfig + max_concurrent: int + active: int = field(default=0, 
init=False) + + @property + def available(self) -> int: + return self.max_concurrent - self.active + + +class LeastLoadedDispatcher: + """Least-loaded dispatch with asyncio.Condition for blocking. + + Shared across all evals hitting the same endpoint_id so that + per-variant concurrency limits are respected globally. + """ + + def __init__(self, variants: list[EndpointSlot]) -> None: + if not variants: + raise ValueError("LeastLoadedDispatcher requires at least one variant") + self._variants = variants + self._condition = asyncio.Condition() + + async def _notify(self) -> None: + """Wake all waiters under the condition lock.""" + async with self._condition: + self._condition.notify_all() + + @asynccontextmanager + async def acquire(self, count: int = 1) -> AsyncIterator[EndpointSlot]: + """Acquire a slot on the least-loaded variant that can fit *count* concurrent items. + + Raises ValueError if count exceeds every variant's max_concurrent, + since allowing it would defeat the configured concurrency limit. + """ + largest_cap = max(v.max_concurrent for v in self._variants) + if count > largest_cap: + raise ValueError( + f"Group size {count} exceeds the largest variant's " + f"max_concurrent ({largest_cap}). Each group must fit on a " + f"single variant. Increase max_concurrent or reduce " + f"rollouts_per_example." + ) + variant: EndpointSlot | None = None + async with self._condition: + while True: + # Find variant with most available capacity that can fit count + best: EndpointSlot | None = None + for v in self._variants: + if v.available >= count and ( + best is None or v.available > best.available + ): + best = v + if best is not None: + variant = best + variant.active += count + break + + await self._condition.wait() + + try: + yield variant + finally: + # Decrement synchronously — safe in asyncio's cooperative model + # since no other task can interleave between await points. 
+ variant.active -= count + # Shield notification so waiters are woken even if our task + # is cancelled (the shielded inner task keeps running). + await asyncio.shield(self._notify()) + + async def update_variants( + self, new_variants: list[EndpointSlot] + ) -> tuple[int, int]: + """Replace the variant list, preserving in-flight slots. + + Endpoints are keyed by ``api_base_url``. Kept endpoints preserve + their existing :class:`EndpointSlot` (retaining the ``active`` + count); ``max_concurrent`` is updated from the new slot. New + endpoints get fresh slots and removed endpoints are dropped. + + In-flight requests on removed endpoints continue normally — they + hold their own reference to the old slot object. + + Returns ``(added_count, removed_count)``. + """ + if not new_variants: + raise ValueError("update_variants requires at least one variant") + + async with self._condition: + old_by_url = {v.config.api_base_url: v for v in self._variants} + new_by_url = {v.config.api_base_url: v for v in new_variants} + + merged: list[EndpointSlot] = [] + added = 0 + for url, new_slot in new_by_url.items(): + old_slot = old_by_url.get(url) + if old_slot is not None: + # Preserve in-flight count, update capacity + old_slot.max_concurrent = new_slot.max_concurrent + merged.append(old_slot) + else: + merged.append(new_slot) + added += 1 + + removed = len(old_by_url) - (len(new_by_url) - added) + self._variants = merged + self._condition.notify_all() + + return added, removed + + async def maybe_semaphore( limit: Optional[int] = None, ) -> AsyncContextManager: diff --git a/verifiers/utils/elastic.py b/verifiers/utils/elastic.py new file mode 100644 index 000000000..bd651e5c4 --- /dev/null +++ b/verifiers/utils/elastic.py @@ -0,0 +1,134 @@ +"""Elastic endpoint pool — background polling loop that reloads endpoints.toml.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import TYPE_CHECKING + +from verifiers.types import EndpointClientConfig 
+from verifiers.utils.async_utils import EndpointSlot, LeastLoadedDispatcher +from verifiers.utils.client_utils import resolve_client_configs +from verifiers.utils.eval_utils import load_endpoints + +if TYPE_CHECKING: + from verifiers.types import ClientConfig + +logger = logging.getLogger(__name__) + + +class ElasticEndpointPool: + """Periodically re-reads an endpoints file and updates a dispatcher. + + The pool runs a background ``asyncio.Task`` that polls the endpoints + file at a fixed interval. On each tick it rebuilds the + :class:`EndpointSlot` list and calls + :meth:`LeastLoadedDispatcher.update_variants` so that new endpoints + are picked up and removed endpoints are drained. + """ + + def __init__( + self, + dispatcher: LeastLoadedDispatcher, + endpoints_path: str, + endpoint_id: str, + poll_interval: float, + base_client_config: ClientConfig, + ) -> None: + self._dispatcher = dispatcher + self._endpoints_path = endpoints_path + self._endpoint_id = endpoint_id + self._poll_interval = poll_interval + self._base_client_config = base_client_config + self._task: asyncio.Task | None = None + + def start(self) -> None: + """Start the background polling loop.""" + if self._task is not None: + return + self._task = asyncio.create_task(self._poll_loop()) + + async def stop(self) -> None: + """Cancel the polling task and wait for it to finish.""" + if self._task is None: + return + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + self._task = None + + async def _poll_loop(self) -> None: + """Sleep → reload, repeat until cancelled.""" + while True: + await asyncio.sleep(self._poll_interval) + try: + await self._reload() + except Exception: + logger.warning( + "Elastic pool reload failed; keeping previous endpoints", + exc_info=True, + ) + + async def _reload(self) -> None: + """Load endpoints file and push updated variants to the dispatcher.""" + endpoints = load_endpoints(self._endpoints_path) + + endpoint_group = 
endpoints.get(self._endpoint_id) + if endpoint_group is None: + logger.warning( + "Elastic pool: endpoint_id %r not found in %s; skipping update", + self._endpoint_id, + self._endpoints_path, + ) + return + + # Check that all variants have max_concurrent set + missing = [ + i + for i, ep in enumerate(endpoint_group) + if ep.get("max_concurrent") is None + ] + if missing: + logger.warning( + "Elastic pool: endpoint_id %r has variants without max_concurrent " + "(indices %s); skipping update", + self._endpoint_id, + missing, + ) + return + + # Build EndpointClientConfig list (same pattern as eval.py) + endpoint_configs = [ + EndpointClientConfig( + api_key_var=ep["key"], + api_base_url=ep["url"], + max_concurrent=ep.get("max_concurrent"), + ) + for ep in endpoint_group + ] + + # Create a temporary ClientConfig with the new endpoint_configs + # so resolve_client_configs can merge parent fields. + temp_config = self._base_client_config.model_copy( + update={"endpoint_configs": endpoint_configs} + ) + resolved = resolve_client_configs(temp_config) + + slots = [ + EndpointSlot( + config=cfg, + max_concurrent=ep.max_concurrent, + ) + for cfg, ep in zip(resolved, endpoint_configs) + ] + + added, removed = await self._dispatcher.update_variants(slots) + if added or removed: + logger.info( + "Elastic pool updated endpoint_id %r: +%d -%d endpoints", + self._endpoint_id, + added, + removed, + ) diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index e30f60ff6..d46425fd8 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -31,7 +31,12 @@ RolloutOutput, StartCallback, ) -from verifiers.utils.async_utils import EventLoopLagMonitor +from verifiers.utils.async_utils import ( + EndpointSlot, + EventLoopLagMonitor, + LeastLoadedDispatcher, +) +from verifiers.utils.client_utils import resolve_client_configs from verifiers.utils.import_utils import load_toml from verifiers.utils.logging_utils import print_prompt_completions_sample, 
print_time from verifiers.utils.path_utils import get_eval_results_path @@ -67,6 +72,14 @@ def _coerce_endpoint(raw_endpoint: object, source: str) -> Endpoint: endpoint = Endpoint(model=model, url=url, key=key) + # Parse optional max_concurrent + max_concurrent = raw_endpoint_dict.get("max_concurrent") + if max_concurrent is not None: + if isinstance(max_concurrent, bool) or not isinstance(max_concurrent, int) or max_concurrent <= 0: + raise ValueError(f"'max_concurrent' must be a positive integer in {source}") + endpoint["max_concurrent"] = max_concurrent + + # Parse optional api_client_type if "client_type" in raw_endpoint_dict: raise ValueError( f"Field 'client_type' is no longer supported in {source}. " @@ -327,6 +340,9 @@ def load_toml_config(path: Path) -> list[dict]: "max_concurrent", "independent_scoring", "max_retries", + # elastic endpoint pool + "elastic", + "elastic_poll_interval", # logging "verbose", "debug", @@ -567,12 +583,84 @@ def quiet_datasets(): enable_progress_bar() +def _build_dispatchers( + evals: list[EvalConfig], +) -> tuple[dict[str | None, LeastLoadedDispatcher], list]: + """Build per-endpoint dispatchers (and optional elastic pools) from eval configs. + + Groups evals by ``endpoint_id`` and, for each unique id where + variants have ``max_concurrent`` set, creates a + :class:`LeastLoadedDispatcher` with one :class:`EndpointSlot` per + variant. + + Per-variant concurrency is all-or-nothing: if any variant in an + endpoint group sets ``max_concurrent``, every variant must. + Endpoint groups without ``max_concurrent`` use the default + semaphore + round-robin path in ``environment.generate()``. + + When an eval has ``elastic=True``, an :class:`ElasticEndpointPool` + is created to poll the endpoints file and update the dispatcher. + + Returns ``(dispatchers, pools)`` — a mapping from ``endpoint_id`` + (or ``None``) to the dispatcher and a list of elastic pools. 
+ """ + from verifiers.utils.elastic import ElasticEndpointPool + + dispatchers: dict[str | None, LeastLoadedDispatcher] = {} + pools: list[ElasticEndpointPool] = [] + + # Collect unique endpoint_ids, take the first config as representative + seen: dict[str | None, EvalConfig] = {} + for ec in evals: + if ec.endpoint_id not in seen: + seen[ec.endpoint_id] = ec + + for endpoint_id, ec in seen.items(): + if not ec.client_config.endpoint_configs: + continue + + endpoint_cfgs = ec.client_config.endpoint_configs + has_concurrency = [ep.max_concurrent is not None for ep in endpoint_cfgs] + + if any(has_concurrency): + if not all(has_concurrency): + missing = [i for i, has in enumerate(has_concurrency) if not has] + raise ValueError( + f"Endpoint '{endpoint_id}': max_concurrent is set on some variants " + f"but missing on variant(s) {missing}. Either set max_concurrent on " + f"all variants or remove it entirely to use the global concurrency limit." + ) + resolved = resolve_client_configs(ec.client_config) + slots = [ + EndpointSlot( + config=cfg, + max_concurrent=ep.max_concurrent, + ) + for cfg, ep in zip(resolved, endpoint_cfgs) + ] + dispatcher = LeastLoadedDispatcher(slots) + dispatchers[endpoint_id] = dispatcher + + if ec.elastic and endpoint_id is not None: + pool = ElasticEndpointPool( + dispatcher=dispatcher, + endpoints_path=ec.endpoints_path, + endpoint_id=endpoint_id, + poll_interval=ec.elastic_poll_interval, + base_client_config=ec.client_config, + ) + pools.append(pool) + + return dispatchers, pools + + async def run_evaluation( config: EvalConfig, on_start: StartCallback | None = None, on_log_file: Callable[[Path], None] | None = None, on_progress: ProgressCallback | list[ProgressCallback] | None = None, on_log: LogCallback | None = None, + dispatcher: LeastLoadedDispatcher | None = None, ) -> GenerateOutputs: # load environment vf_env = vf.load_environment(env_id=config.env_id, **config.env_args) @@ -607,21 +695,33 @@ async def run_evaluation( 
f"Configuration: num_examples={config.num_examples}, rollouts_per_example={config.rollouts_per_example}, max_concurrent={config.max_concurrent}" ) - effective_group_max_concurrent = config.max_concurrent - if ( - not config.independent_scoring - and config.max_concurrent > 0 - and config.rollouts_per_example > 1 - ): - # Grouped scoring applies the semaphore at group level. Convert - # rollout-level concurrency to group-level slots. - effective_group_max_concurrent = math.ceil( - config.max_concurrent / config.rollouts_per_example - ) - if config.num_examples > 0: - effective_group_max_concurrent = min( - effective_group_max_concurrent, config.num_examples + # Compute effective concurrency for generate(). + if dispatcher is not None: + # LeastLoadedDispatcher handles concurrency per-variant; + # disable the generate-level semaphore. + if config.max_concurrent > 0: + logger.debug( + "Endpoint-level dispatcher active; ignoring eval-level max_concurrent=%d", + config.max_concurrent, + ) + effective_max_concurrent = -1 + else: + # Legacy semaphore + round-robin path. The semaphore limits + # concurrent groups, so convert rollout-level concurrency to + # group-level slots. 
+ effective_max_concurrent = config.max_concurrent + if ( + not config.independent_scoring + and config.max_concurrent > 0 + and config.rollouts_per_example > 1 + ): + effective_max_concurrent = math.ceil( + config.max_concurrent / config.rollouts_per_example ) + if config.num_examples > 0: + effective_max_concurrent = min( + effective_max_concurrent, config.num_examples + ) outputs = await vf_env.evaluate( client=config.client_config, @@ -629,7 +729,7 @@ async def run_evaluation( sampling_args=config.sampling_args, num_examples=config.num_examples, rollouts_per_example=config.rollouts_per_example, - max_concurrent=effective_group_max_concurrent, + max_concurrent=effective_max_concurrent, results_path=results_path, state_columns=config.state_columns, save_results=config.save_results, @@ -640,6 +740,7 @@ async def run_evaluation( on_start=on_start, on_progress=on_progress, on_log=on_log, + dispatcher=dispatcher, ) finally: await vf_env.stop_server() @@ -652,6 +753,8 @@ async def run_evaluations(config: EvalRunConfig) -> None: event_loop_lag_monitor = EventLoopLagMonitor() event_loop_lag_monitor.run_in_background() + dispatchers, pools = _build_dispatchers(config.evals) + on_progress: list[ProgressCallback] | None = None if config.heartbeat_url is not None: from verifiers.utils.heartbeat import Heartbeat @@ -659,13 +762,25 @@ async def run_evaluations(config: EvalRunConfig) -> None: heart = Heartbeat(config.heartbeat_url) on_progress = [lambda *_a, **_kw: asyncio.create_task(heart.beat())] + for pool in pools: + pool.start() + start_time = time.time() - all_results = await asyncio.gather( - *[ - run_evaluation(eval_config, on_progress=on_progress) - for eval_config in config.evals - ] - ) + try: + all_results = await asyncio.gather( + *[ + run_evaluation( + eval_config, + on_progress=on_progress, + dispatcher=dispatchers.get(eval_config.endpoint_id), + ) + for eval_config in config.evals + ] + ) + finally: + for pool in pools: + await pool.stop() + end_time = 
time.time() if config.heartbeat_url is not None: @@ -705,6 +820,8 @@ async def run_evaluations_tui(config: EvalRunConfig, tui_mode: bool = True) -> N await run_evaluations(config) return + dispatchers, pools = _build_dispatchers(config.evals) + heart = None if config.heartbeat_url is not None: from verifiers.utils.heartbeat import Heartbeat @@ -714,7 +831,9 @@ async def run_evaluations_tui(config: EvalRunConfig, tui_mode: bool = True) -> N display = EvalDisplay(config.evals, screen=tui_mode) async def run_with_progress( - env_config: EvalConfig, env_idx: int + env_config: EvalConfig, + env_idx: int, + dispatcher: LeastLoadedDispatcher | None = None, ) -> GenerateOutputs: """Run a single evaluation with display progress updates.""" @@ -766,6 +885,7 @@ def register_log_file(log_file: Path) -> None: on_progress=on_progress, on_log=on_log, on_log_file=register_log_file, + dispatcher=dispatcher, ) # get save path if results were saved @@ -790,13 +910,20 @@ async def refresh_loop() -> None: display.refresh() await asyncio.sleep(1) + for pool in pools: + pool.start() + try: async with display: refresh_task = asyncio.create_task(refresh_loop()) try: await asyncio.gather( *[ - run_with_progress(env_config, idx) + run_with_progress( + env_config, + idx, + dispatcher=dispatchers.get(env_config.endpoint_id), + ) for idx, env_config in enumerate(config.evals) ], return_exceptions=True, @@ -813,6 +940,8 @@ async def refresh_loop() -> None: except KeyboardInterrupt: pass # exit on interrupt finally: + for pool in pools: + await pool.stop() if heart is not None: await heart.close()