From 18e295265f62f41545f56d90ad448a27c80108d5 Mon Sep 17 00:00:00 2001 From: d42me Date: Mon, 9 Feb 2026 21:26:03 -0800 Subject: [PATCH 1/3] Add base64 image support for results. --- tests/conftest.py | 2 + tests/test_environment_extra.py | 26 ++++ tests/test_eval_cli.py | 116 ++++++++++++++++- tests/test_message_utils_audio.py | 37 ++++++ verifiers/envs/env_group.py | 16 ++- verifiers/envs/environment.py | 40 +++++- verifiers/scripts/eval.py | 22 ++++ verifiers/types.py | 3 + verifiers/utils/eval_utils.py | 14 +++ verifiers/utils/message_utils.py | 165 +++++++++++++++++++++---- verifiers/utils/save_utils.py | 47 +++++-- verifiers/workers/client/env_client.py | 8 ++ verifiers/workers/server/env_server.py | 4 + verifiers/workers/types.py | 4 + 14 files changed, 470 insertions(+), 34 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d44e6f164..2b96d4f2e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -559,6 +559,7 @@ def _make_metadata( state_columns: list[str] = ["foo"], path_to_save: Path = Path("test.jsonl"), tools: list[Tool] | None = None, + save_image_mode: str = "base64", ) -> GenerateMetadata: if version_info is None: version_info = { @@ -584,6 +585,7 @@ def _make_metadata( state_columns=state_columns, path_to_save=path_to_save, tools=tools, + save_image_mode=save_image_mode, ) return _make_metadata diff --git a/tests/test_environment_extra.py b/tests/test_environment_extra.py index 789e477df..3de58e4d5 100644 --- a/tests/test_environment_extra.py +++ b/tests/test_environment_extra.py @@ -329,6 +329,8 @@ async def run_group( sampling_args, max_retries, state_columns, + image_mode="placeholder", + max_image_base64_chars=None, ): assert isinstance(client_config, ClientConfig) self.client_urls_per_group.append(str(client_config.api_base_url)) @@ -424,6 +426,8 @@ async def run_group( sampling_args, max_retries, state_columns, + image_mode="placeholder", + max_image_base64_chars=None, ): assert isinstance(client_config, ClientConfig) self.client_url = str(client_config.api_base_url) @@ -483,6 +487,8 @@ async def run_rollout( sampling_args, max_retries, state_columns, + image_mode="placeholder", + max_image_base64_chars=None, ): assert isinstance(client_config, ClientConfig) self.client_url = str(client_config.api_base_url) @@ -635,6 +641,26 @@ def model_dump(self, **kwargs): assert isinstance(sanitized[0]["tool_calls"][0], str) +def test_sanitize_tool_calls_preserves_serialized_strings_and_extra_fields(): + serialized_tool_call = ( + '{"id":"x","type":"function","function":{"name":"echo","arguments":"{}"}}' + ) + msgs = [ + { + "role": "assistant", + "content": "", + "tool_calls": [serialized_tool_call], + "images": [{"media_type": "image/png", "base64": "QUJDRA=="}], + "custom_field": "kept", + } + ] + + sanitized = sanitize_tool_calls(msgs) + assert sanitized[0]["tool_calls"][0] == serialized_tool_call + assert sanitized[0]["images"] == [{"media_type": "image/png", "base64": "QUJDRA=="}] + assert sanitized[0]["custom_field"] == "kept" + + def test_make_dataset_basic_without_tools(make_metadata, make_output): results = GenerateOutputs(outputs=[make_output()], metadata=make_metadata()) ds = build_dataset(results) diff --git a/tests/test_eval_cli.py b/tests/test_eval_cli.py index 42a62e2b7..43a4ac3c8 100644 --- a/tests/test_eval_cli.py +++ b/tests/test_eval_cli.py @@ -1,4 +1,5 @@ import argparse +import json import os import tempfile import time @@ -11,7 +12,7 @@ import verifiers.utils.eval_utils from verifiers.types import GenerateOutputs from verifiers.utils.eval_utils import load_toml_config -from verifiers.utils.save_utils import states_to_outputs +from verifiers.utils.save_utils import save_metadata, save_outputs, states_to_outputs @pytest.fixture @@ -21,6 +22,7 @@ def _run_cli( overrides, capture_all_configs: bool = False, endpoints: dict | None = None, + run_evaluation_impl=None, ): """Run CLI with mocked arguments and capture config(s). @@ -49,6 +51,7 @@ def _run_cli( "no_interleave_scoring": False, "state_columns": [], "save_results": False, + "save_image_mode": "base64", "resume": None, "save_every": -1, "save_to_hf_hub": False, @@ -73,6 +76,12 @@ def _run_cli( monkeypatch.setattr(vf_eval, "load_endpoints", lambda *_: endpoints or {}) async def fake_run_evaluation(config, **kwargs): + if run_evaluation_impl is not None: + result = await run_evaluation_impl(config, **kwargs) + captured["sampling_args"] = dict(config.sampling_args) + captured["configs"].append(config) + return result + captured["sampling_args"] = dict(config.sampling_args) captured["configs"].append(config) _make_metadata = make_metadata @@ -858,3 +867,108 @@ def test_cli_toml_resume_false_disables_global_resume(monkeypatch, run_cli): assert configs[0].resume_path is None assert configs[1].env_id == "env-b" assert configs[1].resume_path is None + + +def test_cli_save_dataset_with_base64_images( + monkeypatch, run_cli, make_metadata, make_state, tmp_path: Path +): + saved_results_path: Path | None = None + + async def fake_run_evaluation(config, **kwargs): + nonlocal saved_results_path + state = make_state( + prompt=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "question"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,QUJDRA=="}, + }, + ], + } + ], + completion=[{"role": "assistant", "content": "ok"}], + reward=1.0, + ) + + outputs = states_to_outputs( + [state], + image_mode=config.save_image_mode, + max_image_base64_chars=config.max_image_base64_chars, + ) + saved_results_path = tmp_path / "results" + metadata = make_metadata( + env_id=config.env_id, + model=config.model, + sampling_args=config.sampling_args, + num_examples=config.num_examples, + rollouts_per_example=config.rollouts_per_example, + path_to_save=saved_results_path, + save_image_mode=config.save_image_mode, + ) + if config.save_results: + save_outputs(outputs, saved_results_path) + save_metadata(metadata, saved_results_path) + return GenerateOutputs(outputs=outputs, metadata=metadata) + + run_cli( + monkeypatch, + { + "save_results": True, + "save_image_mode": "base64", + "debug": True, + }, + run_evaluation_impl=fake_run_evaluation, + ) + + assert saved_results_path is not None + results_file = saved_results_path / "results.jsonl" + assert results_file.exists() + row = json.loads(results_file.read_text(encoding="utf-8").splitlines()[0]) + assert row["prompt"][0]["content"] == "question\n\n[image]" + assert row["prompt"][0]["images"][0]["media_type"] == "image/png" + assert row["prompt"][0]["images"][0]["base64"] == "QUJDRA==" + + +def test_cli_save_dataset_base64_limit_enforced( + monkeypatch, run_cli, make_metadata, make_state +): + monkeypatch.setattr(vf_eval, "MAX_IMAGE_BASE64_CHARS", 4) + + async def fake_run_evaluation(config, **kwargs): + state = make_state( + prompt=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "question"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,QUJDRA=="}, + }, + ], + } + ], + completion=[{"role": "assistant", "content": "ok"}], + reward=1.0, + ) + outputs = states_to_outputs( + [state], + image_mode=config.save_image_mode, + max_image_base64_chars=config.max_image_base64_chars, + ) + metadata = make_metadata(save_image_mode=config.save_image_mode) + return GenerateOutputs(outputs=outputs, metadata=metadata) + + with pytest.raises(ValueError, match="exceeds max_image_base64_chars"): + run_cli( + monkeypatch, + { + "save_results": True, + "save_image_mode": "base64", + "debug": True, + }, + run_evaluation_impl=fake_run_evaluation, + ) diff --git a/tests/test_message_utils_audio.py b/tests/test_message_utils_audio.py index 00282f7c4..7f373eb79 100644 --- a/tests/test_message_utils_audio.py +++ b/tests/test_message_utils_audio.py @@ -1,5 +1,6 @@ # tests/test_message_utils_audio.py from verifiers.utils.message_utils import ( + ImageMode, message_to_printable, messages_to_printable, ) @@ -108,3 +109,39 @@ def format_prompt(example): "type": "image_url", "image_url": {"url": "data:image/png;base64,abc123"}, } + + +def test_message_to_printable_base64_mode_extracts_images(): + msg = { + "role": "user", + "content": [ + {"type": "text", "text": "question"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,QUJDRA=="}, + }, + ], + } + + out = message_to_printable(msg, image_mode=ImageMode.BASE64) + assert out["content"] == "question\n\n[image]" + assert out["images"][0]["media_type"] == "image/png" + assert out["images"][0]["base64"] == "QUJDRA==" + assert out["images"][0]["base64_chars"] == 8 + + +def test_message_to_printable_base64_mode_enforces_limit(): + msg = { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,QUJDRA=="}, + } + ], + } + + import pytest + + with pytest.raises(ValueError, match="exceeds max_image_base64_chars"): + message_to_printable(msg, image_mode="base64", max_image_base64_chars=4) diff --git a/verifiers/envs/env_group.py b/verifiers/envs/env_group.py index bf17df5ba..a80f99a6b 100644 --- a/verifiers/envs/env_group.py +++ b/verifiers/envs/env_group.py @@ -277,12 +277,22 @@ async def run_rollout( # type: ignore[override] sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + image_mode: str = "base64", + max_image_base64_chars: int | None = None, env_client: EnvClient | None = None, ) -> vf.RolloutOutput: env = self.get_env_for_task(input["task"]) env_client = env_client or env.env_client or self.env_client return await env.run_rollout( - input, client, model, sampling_args, max_retries, state_columns, env_client + input, + client, + model, + sampling_args, + max_retries, + state_columns, + image_mode, + max_image_base64_chars, + env_client, ) @final @@ -294,6 +304,8 @@ async def run_group( # type: ignore[override] sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + image_mode: str = "base64", + max_image_base64_chars: int | None = None, env_client: EnvClient | None = None, ) -> list[vf.RolloutOutput]: env = self.get_env_for_task(group_inputs[0]["task"]) @@ -305,6 +317,8 @@ async def run_group( # type: ignore[override] sampling_args, max_retries, state_columns, + image_mode, + max_image_base64_chars, env_client, ) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 32859d29f..6e54376d7 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -70,7 +70,7 @@ with_sem, ) from verifiers.utils.error_utils import ErrorChain -from verifiers.utils.message_utils import normalize_messages +from verifiers.utils.message_utils import ImageMode, normalize_messages from verifiers.utils.save_utils import ( GenerateOutputsBuilder, load_outputs, @@ -702,6 +702,8 @@ async def run_rollout( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + image_mode: str = ImageMode.BASE64.value, + max_image_base64_chars: int | None = None, env_client: EnvClient | None = None, ) -> RolloutOutput: """Generate and, optionally, score a rollout.""" @@ -723,6 +725,8 @@ async def run_rollout( sampling_args, max_retries, state_columns, + image_mode, + max_image_base64_chars, ) resolved_client = resolve_client(client) @@ -743,7 +747,12 @@ async def run_rollout_attempt() -> State: return state state = await maybe_retry(run_rollout_attempt, max_retries=max_retries)() - output = state_to_output(state, state_columns or []) + output = state_to_output( + state, + state_columns or [], + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, + ) return output @final @@ -755,6 +764,8 @@ async def run_group( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + image_mode: str = ImageMode.BASE64.value, + max_image_base64_chars: int | None = None, env_client: EnvClient | None = None, **kwargs, ) -> list[RolloutOutput]: @@ -777,6 +788,8 @@ async def run_group( sampling_args, max_retries, state_columns, + image_mode, + max_image_base64_chars, ) resolved_client = resolve_client(client) @@ -801,7 +814,13 @@ async def run_group_attempt() -> list[State]: group_states = await maybe_retry(run_group_attempt, max_retries=max_retries)() outputs = [ - state_to_output(state, state_columns or []) for state in group_states + state_to_output( + state, + state_columns or [], + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, + ) + for state in group_states ] return outputs @@ -815,6 +834,8 @@ async def generate( results_path: Path | None = None, state_columns: list[str] | None = None, save_results: bool = False, + image_mode: str = ImageMode.BASE64.value, + max_image_base64_chars: int | None = None, push_to_hf_hub: bool = False, hf_hub_dataset_name: str | None = None, independent_scoring: bool = False, @@ -932,6 +953,7 @@ def default_on_progress(*a, **kw): state_columns=state_columns, sampling_args=sampling_args, results_path=results_path, + save_image_mode=image_mode, ) single_client: Client | None = None @@ -1013,6 +1035,8 @@ def get_client_for_group() -> Client | ClientConfig: sampling_args, max_retries=max_retries, state_columns=state_columns, + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, ), ), ) @@ -1039,6 +1063,8 @@ def get_client_for_group() -> Client | ClientConfig: sampling_args, max_retries=max_retries, state_columns=state_columns, + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, ), ), ) @@ -1148,6 +1174,8 @@ async def evaluate( results_path: Path | None = None, state_columns: list[str] | None = None, save_results: bool = False, + image_mode: str = ImageMode.BASE64.value, + max_image_base64_chars: int | None = None, push_to_hf_hub: bool = False, hf_hub_dataset_name: str | None = None, independent_scoring: bool = False, @@ -1175,6 +1203,8 @@ async def evaluate( results_path=results_path, state_columns=state_columns, save_results=save_results, + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, push_to_hf_hub=push_to_hf_hub, hf_hub_dataset_name=hf_hub_dataset_name, independent_scoring=independent_scoring, @@ -1196,6 +1226,8 @@ def evaluate_sync( results_path: Path | None = None, state_columns: list[str] | None = None, save_results: bool = False, + image_mode: str = ImageMode.BASE64.value, + max_image_base64_chars: int | None = None, push_to_hf_hub: bool = False, hf_hub_dataset_name: str | None = None, independent_scoring: bool = False, @@ -1214,6 +1246,8 @@ def evaluate_sync( results_path=results_path, state_columns=state_columns, save_results=save_results, + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, push_to_hf_hub=push_to_hf_hub, hf_hub_dataset_name=hf_hub_dataset_name, independent_scoring=independent_scoring, diff --git a/verifiers/scripts/eval.py b/verifiers/scripts/eval.py index ddfa3b8e4..016fbe196 100644 --- a/verifiers/scripts/eval.py +++ b/verifiers/scripts/eval.py @@ -34,6 +34,7 @@ ) from verifiers.utils.import_utils import load_toml from verifiers.utils.install_utils import check_hub_env_installed +from verifiers.utils.message_utils import ImageMode, coerce_image_mode logger = logging.getLogger(__name__) @@ -46,6 +47,7 @@ DEFAULT_API_KEY_VAR = "PRIME_API_KEY" DEFAULT_API_BASE_URL = "https://api.pinference.ai/api/v1" DEFAULT_CLIENT_TYPE = "openai_chat_completions" +MAX_IMAGE_BASE64_CHARS = 10_000_000 # ~7.5 MB decoded; fail-fast guard for base64 mode def get_env_eval_defaults(env_id: str) -> dict[str, Any]: @@ -232,6 +234,17 @@ def main(): action="store_true", help="Save results to disk", ) + parser.add_argument( + "--save-image-mode", + type=str, + choices=[mode.value for mode in ImageMode], + default=ImageMode.BASE64.value, + help=( + "How to serialize image content into saved results. " + "'placeholder' writes [image], 'base64' stores extracted base64 payloads " + "under each message's images field." + ), + ) parser.add_argument( "--resume", "-R", @@ -320,6 +333,13 @@ def main(): def build_eval_config(raw: dict) -> EvalConfig: """Build EvalConfig from a raw config dict.""" env_id = raw["env_id"] + image_mode = coerce_image_mode( + raw.get("save_image_mode", ImageMode.BASE64.value), + arg_name="save_image_mode", + ) + max_image_base64_chars = ( + MAX_IMAGE_BASE64_CHARS if image_mode == ImageMode.BASE64 else None + ) # Resolve num_examples and rollouts_per_example with env defaults env_defaults = get_env_eval_defaults(env_id) @@ -556,6 +576,8 @@ def build_eval_config(raw: dict) -> EvalConfig: debug=raw.get("debug", False), state_columns=raw.get("state_columns", []), save_results=raw.get("save_results", False), + save_image_mode=image_mode.value, + max_image_base64_chars=max_image_base64_chars, resume_path=resume_path, independent_scoring=raw.get("independent_scoring", False), save_to_hf_hub=raw.get("save_to_hf_hub", False), diff --git a/verifiers/types.py b/verifiers/types.py index 793bae542..aeac1fd86 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -363,6 +363,7 @@ class GenerateMetadata(TypedDict): state_columns: list[str] path_to_save: Path tools: list[Tool] | None + save_image_mode: str class GenerateOutputs(TypedDict): @@ -494,6 +495,8 @@ class EvalConfig(BaseModel): # saving state_columns: list[str] | None = None save_results: bool = False + save_image_mode: str = "base64" + max_image_base64_chars: int | None = None resume_path: Path | None = None save_to_hf_hub: bool = False hf_hub_dataset_name: str | None = None diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 9e84b8e6a..2103052fd 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -34,6 +34,7 @@ from verifiers.utils.async_utils import EventLoopLagMonitor from verifiers.utils.import_utils import load_toml from verifiers.utils.logging_utils import print_prompt_completions_sample, print_time +from verifiers.utils.message_utils import ImageMode from verifiers.utils.path_utils import get_eval_results_path logger = logging.getLogger(__name__) @@ -332,6 +333,7 @@ def load_toml_config(path: Path) -> list[dict]: # saving "state_columns", "save_results", + "save_image_mode", "resume", "resume_path", "save_to_hf_hub", @@ -605,6 +607,16 @@ async def run_evaluation( logger.debug( f"Configuration: num_examples={config.num_examples}, rollouts_per_example={config.rollouts_per_example}, max_concurrent={config.max_concurrent}" ) + save_image_mode = ( + config.save_image_mode + if config.save_results + else ImageMode.PLACEHOLDER.value + ) + max_image_base64_chars = ( + config.max_image_base64_chars + if config.save_results and save_image_mode == ImageMode.BASE64.value + else None + ) effective_group_max_concurrent = config.max_concurrent if ( @@ -632,6 +644,8 @@ async def run_evaluation( results_path=results_path, state_columns=config.state_columns, save_results=config.save_results, + image_mode=save_image_mode, + max_image_base64_chars=max_image_base64_chars, push_to_hf_hub=config.save_to_hf_hub, hf_hub_dataset_name=config.hf_hub_dataset_name, independent_scoring=config.independent_scoring, diff --git a/verifiers/utils/message_utils.py b/verifiers/utils/message_utils.py index da8994f6f..07f796674 100644 --- a/verifiers/utils/message_utils.py +++ b/verifiers/utils/message_utils.py @@ -1,5 +1,8 @@ +import base64 +import binascii import json from collections.abc import Mapping +from enum import Enum from typing import Any, cast from rich.text import Text @@ -19,6 +22,26 @@ ) +class ImageMode(str, Enum): + PLACEHOLDER = "placeholder" + BASE64 = "base64" + + +def coerce_image_mode( + image_mode: str | ImageMode, *, arg_name: str = "image_mode" +) -> ImageMode: + """Convert a string to ImageMode with a helpful error.""" + if isinstance(image_mode, ImageMode): + return image_mode + try: + return ImageMode(image_mode) + except ValueError as exc: + valid_modes = "', '".join(mode.value for mode in ImageMode) + raise ValueError( + f"Invalid {arg_name}: {image_mode}. Expected one of '{valid_modes}'." + ) from exc + + def from_raw_content_part(part: dict[str, Any]) -> ContentPart: """Convert a raw content-part dict to a typed content part when possible.""" part_type = part.get("type") @@ -146,21 +169,80 @@ def concat_messages(messages_list: list[Messages]) -> Messages: return result -def message_to_printable(message: Any) -> Any: - """ - Removes image_url objects from message content. - Replaces audio parts with a short placeholder to keep logs readable. +def _extract_data_uri_base64(url: str) -> tuple[str, str]: + if not url.startswith("data:"): + raise ValueError( + f"Image URLs must be data URIs when image_mode='base64'. Got: {url[:64]}" + ) + if "," not in url: + raise ValueError("Invalid data URI: missing comma separator") + header, payload = url.split(",", 1) + if ";base64" not in header: + raise ValueError("Data URI must include ';base64' when image_mode='base64'") + media_type = header.removeprefix("data:").split(";", 1)[0] + if not media_type.startswith("image/"): + raise ValueError(f"Expected image/* media type in data URI, got: {media_type}") + if payload == "": + raise ValueError("Data URI payload is empty") + try: + base64.b64decode(payload, validate=True) + except binascii.Error as exc: + raise ValueError("Data URI payload is not valid base64") from exc + return media_type, payload + + +def _extract_image_payload( + part: dict[str, Any], max_image_base64_chars: int | None +) -> dict[str, str | int]: + image_url_obj = part.get("image_url") + if isinstance(image_url_obj, dict): + url = image_url_obj.get("url") + else: + url = getattr(image_url_obj, "url", None) + if not isinstance(url, str): + raise ValueError("image_url content block must contain a string URL") + + media_type, payload = _extract_data_uri_base64(url) + payload_size = len(payload) + if max_image_base64_chars is not None and payload_size > max_image_base64_chars: + raise ValueError( + f"Image base64 payload exceeds max_image_base64_chars: {payload_size} > {max_image_base64_chars}" + ) + return { + "media_type": media_type, + "base64": payload, + "base64_chars": payload_size, + } + + +def message_to_printable( + message: Any, + image_mode: str | ImageMode = ImageMode.PLACEHOLDER, + max_image_base64_chars: int | None = None, +) -> Any: + """Convert message content into log/save-friendly text placeholders. + + - text parts are preserved + - input_audio parts are rendered as [audio] + - image_url parts are rendered as [image] + - in base64 mode, extracted image payloads are emitted under `images` """ + image_mode = coerce_image_mode(image_mode) + if isinstance(message, dict): role = message.get("role") content = message.get("content") reasoning_content = message.get("reasoning_content") tool_calls = message.get("tool_calls") + if isinstance(content, list): chunks: list[str] = [] + images: list[dict[str, str | int]] = [] for part in content: if not isinstance(part, dict): + chunks.append(str(part)) continue + part_type = part.get("type") if part_type == "text": text = part.get("text") @@ -170,15 +252,23 @@ def message_to_printable(message: Any) -> Any: chunks.append("[audio]") elif part_type == "image_url": chunks.append("[image]") + if image_mode == ImageMode.BASE64: + images.append( + _extract_image_payload(part, max_image_base64_chars) + ) + printable: dict[str, Any] = { "role": role, - "content": " ".join(chunks).strip(), + "content": "\n\n".join(chunks).strip(), } + if images: + printable["images"] = images if isinstance(reasoning_content, str): printable["reasoning_content"] = reasoning_content if tool_calls is not None: printable["tool_calls"] = tool_calls return printable + return message content = getattr(message, "content", None) @@ -188,20 +278,55 @@ def message_to_printable(message: Any) -> Any: if hasattr(message, "model_dump") else {"content": content} ) - printable = message_to_printable(raw) + printable = message_to_printable( + raw, + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, + ) if hasattr(message, "model_copy"): return message.model_copy(update={"content": printable.get("content", "")}) return printable return message -def messages_to_printable(messages: Any) -> Any: - """ - Removes image_url objects from messages. - """ +def messages_to_printable( + messages: Any, + image_mode: str | ImageMode = ImageMode.PLACEHOLDER, + max_image_base64_chars: int | None = None, +) -> Any: + """Convert messages to printable/saveable form.""" if isinstance(messages, str): return messages - return [message_to_printable(m) for m in messages or []] + return [ + message_to_printable( + m, + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, + ) + for m in messages or [] + ] + + +def strip_nones_from_content(messages: list[Any]) -> list[Any]: + """Strip None-valued keys from content parts (HF map schema normalization helper).""" + result: list[Any] = [] + for msg in messages: + if not isinstance(msg, dict): + result.append(msg) + continue + content = msg.get("content") + if isinstance(content, list): + new_msg = dict(msg) + new_msg["content"] = [ + {k: v for k, v in c.items() if v is not None} + if isinstance(c, dict) + else c + for c in content + ] + result.append(new_msg) + else: + result.append(msg) + return result # --- Legacy utilities (still used by save_utils, trainer, logging) --- @@ -272,7 +397,7 @@ def sanitize_tool_calls(messages: Messages): """Sanitize tool calls from messages for serialization. Used by save_utils and trainer to convert tool call objects to JSON strings. - Works with both Pydantic messages and legacy dicts. + Works with both Pydantic message objects and legacy dicts. """ if not isinstance(messages, list): return messages @@ -289,29 +414,27 @@ def sanitize_tool_calls(messages: Messages): if tool_calls: tool_calls_json = [] for tc in tool_calls: + if isinstance(tc, str): + tool_calls_json.append(tc) + continue if isinstance(tc, dict): tc_dict = tc - elif isinstance(tc, str): - tc_dict = json.loads(tc) else: model_dump = getattr(tc, "model_dump", None) assert model_dump is not None tc_dict = model_dump(exclude_none=True) tool_calls_json.append(json.dumps(tc_dict)) if isinstance(m, dict): - new_m = { - "role": m["role"], - "content": m.get("content", ""), - "tool_calls": tool_calls_json, - } + new_m = dict(m) + new_m["tool_calls"] = tool_calls_json else: new_m = { "role": m.role, "content": m.content or "", "tool_calls": tool_calls_json, } - if isinstance(reasoning_content, str): - new_m["reasoning_content"] = reasoning_content + if isinstance(reasoning_content, str): + new_m["reasoning_content"] = reasoning_content sanitized_messages.append(new_m) else: sanitized_messages.append(m) diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index bc7f8257c..aa68b9163 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -23,7 +23,11 @@ Tool, ) from verifiers.utils.error_utils import ErrorChain -from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls +from verifiers.utils.message_utils import ( + ImageMode, + messages_to_printable, + sanitize_tool_calls, +) from verifiers.utils.path_utils import get_results_path from verifiers.utils.usage_utils import ( StateUsageTracker, @@ -137,7 +141,10 @@ def get_hf_hub_dataset_name(outputs: GenerateOutputs) -> str: def state_to_output( - state: State, state_columns: list[str] | None = None + state: State, + state_columns: list[str] | None = None, + image_mode: str | ImageMode = ImageMode.PLACEHOLDER, + max_image_base64_chars: int | None = None, ) -> RolloutOutput: """Convert a State to a serializable RolloutOutput. @@ -195,12 +202,22 @@ def state_to_output( # sanitize messages (handle None for error cases) prompt = state.get("prompt") if prompt is not None: - output_prompt = sanitize_tool_calls(messages_to_printable(prompt)) - output["prompt"] = output_prompt + output["prompt"] = sanitize_tool_calls( + messages_to_printable( + prompt, + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, + ) + ) completion = state.get("completion") if completion is not None: - output_completion = sanitize_tool_calls(messages_to_printable(completion)) - output["completion"] = output_completion + output["completion"] = sanitize_tool_calls( + messages_to_printable( + completion, + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, + ) + ) # use repr for error if state.get("error") is not None: error_chain = ErrorChain(state.get("error")) @@ -234,10 +251,21 @@ def state_to_output( def states_to_outputs( - states: list[State], state_columns: list[str] | None = None + states: list[State], + state_columns: list[str] | None = None, + image_mode: str | ImageMode = ImageMode.PLACEHOLDER, + max_image_base64_chars: int | None = None, ) -> list[RolloutOutput]: """Convert a list of States to serializable RolloutOutputs.""" - return [state_to_output(state, state_columns) for state in states] + return [ + state_to_output( + state, + state_columns, + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, + ) + for state in states + ] class GenerateOutputsBuilder: @@ -254,6 +282,7 @@ def __init__( state_columns: list[str] | None, sampling_args: SamplingArgs, results_path: Path | None, + save_image_mode: str = ImageMode.BASE64.value, ): self.env_id = env_id self.env_args = env_args @@ -264,6 +293,7 @@ def __init__( self.state_columns = state_columns or [] self.sampling_args = sampling_args self.results_path = results_path or get_results_path(env_id, model) + self.save_image_mode = save_image_mode self.start_time = time.time() self.base_url = self._compute_base_url(self.client) self.version_info = get_version_info(env_id=env_id) @@ -376,6 +406,7 @@ def tools_key(tools: list[Tool] | None) -> str: state_columns=self.state_columns, path_to_save=self.results_path, tools=tools, + save_image_mode=self.save_image_mode, ) def build_outputs(self, sort_by_example_id: bool = False) -> list[RolloutOutput]: diff --git a/verifiers/workers/client/env_client.py b/verifiers/workers/client/env_client.py index 3a0e6c806..8621d9247 100644 --- a/verifiers/workers/client/env_client.py +++ b/verifiers/workers/client/env_client.py @@ -50,6 +50,8 @@ async def run_rollout( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + image_mode: str = "base64", + max_image_base64_chars: int | None = None, ) -> RolloutOutput: resolved_client_config = resolve_client_config(client_config) request = RunRolloutRequest( @@ -59,6 +61,8 @@ async def run_rollout( sampling_args=sampling_args, max_retries=max_retries, state_columns=state_columns, + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, ) response = await self.handle_run_rollout_request(request, timeout=None) assert response.output is not None @@ -72,6 +76,8 @@ async def run_group( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + image_mode: str = "base64", + max_image_base64_chars: int | None = None, ) -> list[RolloutOutput]: resolved_client_config = resolve_client_config(client_config) request = RunGroupRequest( @@ -81,6 +87,8 @@ async def run_group( sampling_args=sampling_args, max_retries=max_retries, state_columns=state_columns, + image_mode=image_mode, + max_image_base64_chars=max_image_base64_chars, ) response = await self.handle_run_group_request(request, timeout=None) assert response.outputs is not None diff --git a/verifiers/workers/server/env_server.py b/verifiers/workers/server/env_server.py index f0af19a45..6f8da2999 100644 --- a/verifiers/workers/server/env_server.py +++ b/verifiers/workers/server/env_server.py @@ -122,6 +122,8 @@ async def handle_run_rollout( sampling_args=request.sampling_args, max_retries=request.max_retries, state_columns=request.state_columns, + image_mode=request.image_mode, + max_image_base64_chars=request.max_image_base64_chars, ) return RunRolloutResponse(output=output) @@ -134,6 +136,8 @@ async def handle_run_group(self, request: RunGroupRequest) -> RunGroupResponse: sampling_args=request.sampling_args, max_retries=request.max_retries, state_columns=request.state_columns, + image_mode=request.image_mode, + max_image_base64_chars=request.max_image_base64_chars, ) return RunGroupResponse(outputs=outputs) diff --git a/verifiers/workers/types.py b/verifiers/workers/types.py index 25ddbbc94..e76a8007a 100644 --- a/verifiers/workers/types.py +++ b/verifiers/workers/types.py @@ -49,6 +49,8 @@ class RunRolloutRequest(BaseRequest): sampling_args: SamplingArgs max_retries: int state_columns: list[str] | None + image_mode: str = "base64" + max_image_base64_chars: int | None = None class RunRolloutResponse(BaseResponse): @@ -66,6 +68,8 @@ class RunGroupRequest(BaseRequest): sampling_args: SamplingArgs max_retries: int state_columns: list[str] | None + image_mode: str = "base64" + max_image_base64_chars: int | None = None class RunGroupResponse(BaseResponse): From 5120c9260e899dab8dc4424404fbe35c6044594f Mon Sep 17 00:00:00 2001 From: d42me Date: Wed, 18 Feb 2026 22:05:27 -0800 Subject: [PATCH 2/3] Update rollout handling to use validate env level default.s --- tests/test_environment_extra.py | 70 +++++++++++++++++++++++++++++++-- verifiers/envs/env_group.py | 22 ++++++++++- verifiers/envs/environment.py | 66 ++++++++++++++++++++++++++++--- verifiers/utils/eval_utils.py | 34 ++++++++-------- 4 files changed, 165 insertions(+), 27 deletions(-) diff --git a/tests/test_environment_extra.py b/tests/test_environment_extra.py index 3de58e4d5..9debc4bba 100644 --- a/tests/test_environment_extra.py +++ b/tests/test_environment_extra.py @@ -313,6 +313,70 @@ async def test_generate_inside_running_loop(mock_client, make_dummy_env, make_in assert states[0].get("completion") is not None +@pytest.mark.asyncio +async def test_generate_uses_env_image_mode_setting_for_https_image_urls( + mock_openai_client, make_dummy_env, make_input +): + env = make_dummy_env(mock_openai_client) + env.set_kwargs(image_mode="placeholder") + image_prompt: vf.Messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe this image"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/sample.png"}, + }, + ], + } + ] + + outputs = await env.generate( + [make_input(example_id=0, prompt=image_prompt)], + client=mock_openai_client, + model="test-model", + ) + + prompt = outputs["outputs"][0]["prompt"] + assert isinstance(prompt, list) + assert prompt[0]["content"] == "describe this image\n\n[image]" + assert "images" not in prompt[0] + + +@pytest.mark.asyncio +async def test_generate_explicit_image_mode_overrides_env_setting( + mock_openai_client, make_dummy_env, make_input +): + env = make_dummy_env(mock_openai_client) + env.set_kwargs(image_mode="placeholder") + image_prompt: vf.Messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe this image"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,QUJDRA=="}, + }, + ], + } + ] + + outputs = await env.generate( + [make_input(example_id=0, prompt=image_prompt)], + client=mock_openai_client, + model="test-model", + image_mode="base64", + ) + + prompt = outputs["outputs"][0]["prompt"] + assert isinstance(prompt, list) + assert prompt[0]["content"] == "describe this image\n\n[image]" + assert prompt[0]["images"][0]["media_type"] == "image/png" + assert prompt[0]["images"][0]["base64"] == "QUJDRA==" + + @pytest.mark.asyncio async def test_generate_grouped_scoring_distributes_per_group( mock_client, make_dummy_env, make_input @@ -329,7 +393,7 @@ async def run_group( sampling_args, max_retries, state_columns, - image_mode="placeholder", + image_mode="base64", max_image_base64_chars=None, ): assert isinstance(client_config, ClientConfig) @@ -426,7 +490,7 @@ async def run_group( sampling_args, max_retries, state_columns, - image_mode="placeholder", + image_mode="base64", max_image_base64_chars=None, ): assert isinstance(client_config, ClientConfig) @@ -487,7 +551,7 @@ async def run_rollout( sampling_args, max_retries, state_columns, - image_mode="placeholder", + image_mode="base64", max_image_base64_chars=None, ): assert isinstance(client_config, ClientConfig) diff --git a/verifiers/envs/env_group.py b/verifiers/envs/env_group.py index a80f99a6b..97c86a8a5 100644 --- a/verifiers/envs/env_group.py +++ b/verifiers/envs/env_group.py @@ -277,7 +277,7 @@ async def run_rollout( # type: ignore[override] sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, - image_mode: str = "base64", + image_mode: str | None = None, max_image_base64_chars: int | None = None, env_client: EnvClient | None = None, ) -> vf.RolloutOutput: @@ -304,7 +304,7 @@ async def run_group( # type: ignore[override] sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, - image_mode: str = "base64", + image_mode: str | None = None, max_image_base64_chars: int | None = None, env_client: EnvClient | None = None, ) -> list[vf.RolloutOutput]: @@ -342,6 +342,24 @@ def set_max_seq_len(self, max_seq_len: int | None) -> None: for env in self.envs: env.set_max_seq_len(max_seq_len) + def set_image_mode(self, image_mode: str) -> None: + """Set image output serialization mode for this group and all sub-environments.""" + super().set_image_mode(image_mode) + for env in self.envs: + env.set_image_mode(image_mode) + + def set_max_image_base64_chars(self, max_image_base64_chars: int | None) -> None: + """Set max image base64 payload chars for this group and all sub-environments.""" + super().set_max_image_base64_chars(max_image_base64_chars) + for env in self.envs: + env.set_max_image_base64_chars(max_image_base64_chars) + + def set_interleaved_rollouts(self, interleaved_rollouts: bool) -> None: + """Set the interleaved_rollouts flag for this environment group and all sub-environments.""" + self.interleaved_rollouts = interleaved_rollouts + for env in self.envs: + env.set_interleaved_rollouts(interleaved_rollouts) + def set_score_rollouts(self, score_rollouts: bool) -> None: """Set the score_rollouts flag for this environment group and all sub-environments.""" self.score_rollouts = score_rollouts diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 6e54376d7..34acca74c 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -70,7 +70,11 @@ with_sem, ) from verifiers.utils.error_utils import ErrorChain -from verifiers.utils.message_utils import ImageMode, normalize_messages +from verifiers.utils.message_utils import ( + ImageMode, + coerce_image_mode, + normalize_messages, +) from verifiers.utils.save_utils import ( GenerateOutputsBuilder, load_outputs, @@ -143,6 +147,8 @@ def __init__( self.env_args = env_args or {} self.max_seq_len = max_seq_len self.map_kwargs = map_kwargs + self.image_mode = ImageMode.BASE64.value + self.max_image_base64_chars: int | None = None self.set_score_rollouts(score_rollouts) @@ -683,6 +689,27 @@ async def _render_timing(self, state: State): state["timing"]["generation_ms"] = (end_time - start_time) * 1000 state["timing"]["total_ms"] = (end_time - start_time) * 1000 + def _resolve_output_image_options( + self, + image_mode: str | None, + max_image_base64_chars: int | None, + ) -> tuple[str, int | None]: + resolved_image_mode = self.image_mode if image_mode is None else image_mode + resolved_image_mode = coerce_image_mode( + resolved_image_mode, arg_name="image_mode" + ).value + + resolved_max_image_base64_chars = max_image_base64_chars + if max_image_base64_chars is None: + resolved_max_image_base64_chars = self.max_image_base64_chars + if ( + resolved_max_image_base64_chars is not None + and resolved_max_image_base64_chars < 0 + ): + raise ValueError("max_image_base64_chars must be >= 0 or None") + + return resolved_image_mode, resolved_max_image_base64_chars + @final async def is_completed(self, state: State, **kwargs) -> bool: """Check all stop conditions. Sets state.is_completed=True if any condition is met.""" @@ -702,11 +729,14 @@ async def run_rollout( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, - image_mode: str = ImageMode.BASE64.value, + image_mode: str | None = None, max_image_base64_chars: int | None = None, env_client: EnvClient | None = None, ) -> RolloutOutput: """Generate and, optionally, score a rollout.""" + image_mode, max_image_base64_chars = self._resolve_output_image_options( + image_mode, max_image_base64_chars + ) resolved_client_config: ClientConfig | None = None if isinstance(client, ClientConfig): @@ -764,12 +794,15 @@ async def run_group( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, - image_mode: str = ImageMode.BASE64.value, + image_mode: str | None = None, max_image_base64_chars: int | None = None, env_client: EnvClient | None = None, **kwargs, ) -> list[RolloutOutput]: """Generate and, optionally, score one group.""" + image_mode, max_image_base64_chars = self._resolve_output_image_options( + image_mode, max_image_base64_chars + ) resolved_client_config: ClientConfig | None = None if isinstance(client, ClientConfig): @@ -834,7 +867,7 @@ async def generate( results_path: Path | None = None, state_columns: list[str] | None = None, save_results: bool = False, - image_mode: str = ImageMode.BASE64.value, + image_mode: str | None = None, max_image_base64_chars: int | None = None, push_to_hf_hub: bool = False, hf_hub_dataset_name: str | None = None, @@ -938,6 +971,9 @@ def default_on_progress(*a, **kw): if sampling_args is not None: default_sampling_args.update(sampling_args) sampling_args = default_sampling_args + image_mode, max_image_base64_chars = self._resolve_output_image_options( + image_mode, max_image_base64_chars + ) # initialize outputs builder total_rollouts = len(raw_inputs) @@ -1174,7 +1210,7 @@ async def evaluate( results_path: Path | None = None, state_columns: list[str] | None = None, save_results: bool = False, - image_mode: str = ImageMode.BASE64.value, + image_mode: str | None = None, max_image_base64_chars: int | None = None, push_to_hf_hub: bool = False, hf_hub_dataset_name: str | None = None, @@ -1226,7 +1262,7 @@ def evaluate_sync( results_path: Path | None = None, state_columns: list[str] | None = None, save_results: bool = False, - image_mode: str = ImageMode.BASE64.value, + image_mode: str | None = None, max_image_base64_chars: int | None = None, push_to_hf_hub: bool = False, hf_hub_dataset_name: str | None = None, @@ -1287,6 +1323,24 @@ def set_score_rollouts(self, score_rollouts: bool) -> None: """Set the score rollouts flag for this environment.""" self.score_rollouts = score_rollouts + def set_image_mode(self, image_mode: str | ImageMode) -> None: + """Set how image content is serialized in rollout outputs.""" + self.image_mode = coerce_image_mode(image_mode, arg_name="image_mode").value + + def set_max_image_base64_chars(self, max_image_base64_chars: int | None) -> None: + """Set the max allowed image base64 payload length in output serialization.""" + if max_image_base64_chars is not None and max_image_base64_chars < 0: + raise ValueError("max_image_base64_chars must be >= 0 or None") + self.max_image_base64_chars = max_image_base64_chars + + def set_interleaved_rollouts(self, interleaved_rollouts: bool) -> None: + """Set the interleaved rollouts flag for this environment.""" + self.interleaved_rollouts = interleaved_rollouts + if self.interleaved_rollouts: + self.logger.warning( + f"{self.__class__.__name__} is configured to use interleaved rollouts. All model responses after the first turn will be pre-tokenized before being sent to the model. Currently, this is a hand-crafted feature for PRIME-RL's vLLM server extension." + ) + async def start_server( self, address: str | None = None, diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 2103052fd..64133e3a9 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -578,24 +578,38 @@ async def run_evaluation( # load environment vf_env = vf.load_environment(env_id=config.env_id, **config.env_args) - # set extra environment kwargs + save_image_mode = ( + config.save_image_mode if config.save_results else ImageMode.PLACEHOLDER.value + ) + max_image_base64_chars = ( + config.max_image_base64_chars + if config.save_results and save_image_mode == ImageMode.BASE64.value + else None + ) + runtime_env_kwargs = { + **config.extra_env_kwargs, + "image_mode": save_image_mode, + "max_image_base64_chars": max_image_base64_chars, + } + + # set runtime environment kwargs once (local + server env) if config.extra_env_kwargs: logger.info(f"Setting extra environment kwargs: {config.extra_env_kwargs}") - vf_env.set_kwargs(**config.extra_env_kwargs) + vf_env.set_kwargs(**runtime_env_kwargs) results_path = config.resume_path or get_eval_results_path(config) try: if config.debug: await vf_env.start_server( - extra_env_kwargs=config.extra_env_kwargs, + extra_env_kwargs=runtime_env_kwargs, log_level=get_log_level(config.verbose), ) else: log_file = results_path / "eval.log" log_file.parent.mkdir(parents=True, exist_ok=True) await vf_env.start_server( - extra_env_kwargs=config.extra_env_kwargs, + extra_env_kwargs=runtime_env_kwargs, log_level="CRITICAL", # disable console logging log_file=str(log_file), log_file_level=get_log_level(config.verbose), @@ -607,16 +621,6 @@ async def run_evaluation( logger.debug( f"Configuration: num_examples={config.num_examples}, rollouts_per_example={config.rollouts_per_example}, max_concurrent={config.max_concurrent}" ) - save_image_mode = ( - config.save_image_mode - if config.save_results - else ImageMode.PLACEHOLDER.value - ) - max_image_base64_chars = ( - config.max_image_base64_chars - if config.save_results and save_image_mode == ImageMode.BASE64.value - else None - ) effective_group_max_concurrent = config.max_concurrent if ( @@ -644,8 +648,6 @@ async def run_evaluation( results_path=results_path, state_columns=config.state_columns, save_results=config.save_results, - image_mode=save_image_mode, - max_image_base64_chars=max_image_base64_chars, push_to_hf_hub=config.save_to_hf_hub, hf_hub_dataset_name=config.hf_hub_dataset_name, independent_scoring=config.independent_scoring, From d6bfbdb3ad3541acf8e688309a34367942ab22b5 Mon Sep 17 00:00:00 2001 From: d42me Date: Wed, 18 Feb 2026 22:18:05 -0800 Subject: [PATCH 3/3] Fix type-checking for asyncio.to_thread rmtree call --- verifiers/envs/experimental/rlm_env.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/verifiers/envs/experimental/rlm_env.py b/verifiers/envs/experimental/rlm_env.py index 1e1022594..3bf246225 100644 --- a/verifiers/envs/experimental/rlm_env.py +++ b/verifiers/envs/experimental/rlm_env.py @@ -1434,7 +1434,10 @@ async def cleanup(self, state: State) -> None: if session.sandbox_id: await self.delete_sandbox(session.sandbox_id) - await asyncio.to_thread(shutil.rmtree, session.local_rollout_dir, True) + await asyncio.to_thread( + lambda path: shutil.rmtree(path, ignore_errors=True), + session.local_rollout_dir, + ) async def teardown(self) -> None: if self._sessions: