diff --git a/environments/tool_test/tool_test.py b/environments/tool_test/tool_test.py index 5bd38331f..9b9afca1b 100644 --- a/environments/tool_test/tool_test.py +++ b/environments/tool_test/tool_test.py @@ -3,9 +3,11 @@ from datasets import Dataset import verifiers as vf +from verifiers.utils.tool_registry import register_tool # dummy tools for sanity checking parallel tool calls +@register_tool("tool-test", "tool_A") async def tool_A(x: int) -> int: """ Tool for adding 1 to an integer. @@ -19,6 +21,7 @@ async def tool_A(x: int) -> int: return x + 1 +@register_tool("tool-test", "tool_B") async def tool_B(x: str) -> str: """ Tool for concatenating a string with "2". @@ -32,6 +35,7 @@ async def tool_B(x: str) -> str: return x + "2" +@register_tool("tool-test", "tool_C") async def tool_C(x: float) -> float: """ Tool for adding 3.0 to a float. @@ -45,6 +49,7 @@ async def tool_C(x: float) -> float: return x + 3.0 +@register_tool("tool-test", "tool_D") async def tool_D(x: bool) -> bool: """ Tool for negating a boolean. @@ -58,8 +63,7 @@ async def tool_D(x: bool) -> bool: return not x -tool_list = [tool_A, tool_B, tool_C, tool_D] -tool_name_list = [tool.__name__ for tool in tool_list] +DEFAULT_TOOL_LIST = [tool_A, tool_B, tool_C, tool_D] def tool_call_reward_func(completion, info): @@ -76,17 +80,42 @@ def tool_call_reward_func(completion, info): def load_environment( - num_train_examples: int = 1000, num_eval_examples: int = 100 + num_train_examples: int = 1000, + num_eval_examples: int = 100, + tools: list | None = None, ) -> vf.ToolEnv: """ Loads tool-test environment. 
""" + # Use provided tools or fall back to default + if tools is None: + tools = DEFAULT_TOOL_LIST + + # Extract tool names from ACTUAL tools being used (not hardcoded list) + actual_tool_names = [tool.__name__ for tool in tools] + + # Handle empty tools case + if not actual_tool_names: + # Create empty datasets when no tools available + dataset = Dataset.from_list([]) + eval_dataset = Dataset.from_list([]) + rubric = vf.Rubric(funcs=[tool_call_reward_func]) + vf_env = vf.ToolEnv( + dataset=dataset, + eval_dataset=eval_dataset, + rubric=rubric, + tools=tools, + max_turns=1, + ) + return vf_env + train_rows = [] eval_rows = [] for i in range(num_train_examples + num_eval_examples): + # Sample from actual available tools only tool_names = random.sample( - tool_name_list, random.randint(1, len(tool_name_list)) + actual_tool_names, random.randint(1, len(actual_tool_names)) ) prompt = [ { @@ -107,7 +136,7 @@ def load_environment( dataset=dataset, eval_dataset=eval_dataset, rubric=rubric, - tools=tool_list, + tools=tools, max_turns=1, ) return vf_env diff --git a/packages/verifiers-rl/verifiers_rl/scripts/train.py b/packages/verifiers-rl/verifiers_rl/scripts/train.py index 5c8e0be21..d5f3255b7 100644 --- a/packages/verifiers-rl/verifiers_rl/scripts/train.py +++ b/packages/verifiers-rl/verifiers_rl/scripts/train.py @@ -1,4 +1,5 @@ import argparse +import logging from pathlib import Path try: @@ -9,6 +10,8 @@ import verifiers as vf from verifiers_rl.rl.trainer import RLConfig, RLTrainer +logger = logging.getLogger("verifiers_rl.scripts.train") + def main() -> None: parser = argparse.ArgumentParser() @@ -27,9 +30,33 @@ def main() -> None: config = tomllib.load(f) model = config["model"] - env_id = config["env"]["id"] - env_args = config["env"].get("args", {}) - env = vf.load_environment(env_id=env_id, **env_args) + + # Handle both [[env]] array syntax (configs/rl/*.toml) and [env] dict syntax (configs/local/vf-rl/*.toml) + env_section = config["env"] + if 
@pytest.fixture(autouse=True)
def clear_registry():
    """Reset the global tool registry before and after every test.

    The registry is process-wide state, so tools registered by one test would
    otherwise be visible to the next. The reset goes through the registry's
    public ``clear_registry()`` helper, which takes the registry lock, rather
    than mutating the private ``_tool_registry`` dict directly.
    """
    # Imported as a module so the public helper does not clash with this
    # fixture, which has the same name.
    from verifiers.utils import tool_registry

    tool_registry.clear_registry()
    yield
    tool_registry.clear_registry()
def test_load_environment_with_callable_tools(clear_registry):
    """Callables handed to ``load_environment`` must reach the environment as-is."""

    async def direct_tool_a(x: int) -> int:
        return x + 1

    async def direct_tool_b(x: str) -> str:
        return x + "suffix"

    supplied = [direct_tool_a, direct_tool_b]

    # No registry lookup is involved here: the list already holds callables.
    env = load_environment("tool-test", tools=supplied)

    assert hasattr(env, "tools")
    assert len(env.tools) == 2
    for tool in supplied:
        assert tool in env.tools
def test_environment_with_other_args(clear_registry):
    """Check that ``tools`` can be combined with other environment arguments.

    A registered tool name is passed together with ``num_train_examples`` and
    ``num_eval_examples``. The call must succeed and the environment must
    expose exactly the one requested tool.
    """

    @register_tool("tool-test", "custom_tool")
    async def custom_tool() -> str:
        return "custom"

    env = load_environment(
        "tool-test",
        tools=["custom_tool"],
        num_train_examples=50,
        num_eval_examples=10,
    )

    assert hasattr(env, "tools")
    # Exactly one tool was requested. ">= 1" would also pass if the
    # environment's default tools leaked in next to the requested one.
    assert len(env.tools) == 1
    assert custom_tool in env.tools
    # NOTE(review): the dataset sizes implied by num_train_examples /
    # num_eval_examples are not asserted. The accessor for the train/eval
    # datasets is not visible from this test; confirm it and add the check.
def test_multiple_environments():
    """Check that the same tool name in two environments stays isolated.

    Two different functions are registered under the key ``"shared_name"``,
    one per environment. Each lookup must return the function registered for
    that environment only.
    """

    @register_tool("env-a", "shared_name")
    async def env_a_tool(x: int) -> int:
        return x + 1

    @register_tool("env-b", "shared_name")
    async def env_b_tool(x: int) -> int:
        return x + 2

    tool_from_a = get_tool("env-a", "shared_name")
    tool_from_b = get_tool("env-b", "shared_name")

    # The registry stores the function object itself, so identity must hold.
    assert tool_from_a is env_a_tool
    assert tool_from_b is env_b_tool
    assert tool_from_a is not tool_from_b

    # register_tool returns the function unchanged, so __name__ is still the
    # name from the ``def`` statement, not the key it was registered under.
    assert tool_from_a.__name__ == "env_a_tool"
    assert tool_from_b.__name__ == "env_b_tool"
def test_error_get_tools_partial_match():
    """A batch lookup raises ``KeyError`` when any requested name is unregistered."""

    async def tool_a(x: int) -> int:
        return x + 1

    # Same effect as decorating ``tool_a`` with @register_tool(...).
    register_tool("test-env", "tool_a")(tool_a)

    requested = ["tool_a", "nonexistent_tool"]
    with pytest.raises(KeyError, match="not found"):
        get_tools("test-env", requested)
Environment: logger = logging.getLogger("verifiers.utils.env_utils") logger.info(f"Loading environment: {env_id}") + # Phase 1: Pre-process tools parameter + tool_names_to_resolve = None + + if tools is not None: + # Check if tools is actually a list (not a string or other type) + if not isinstance(tools, list): + raise TypeError( + f"tools must be a list, got {type(tools).__name__}. " + f"If passing tool names, use tools=['tool_name'] not tools='tool_name'" + ) + + if tools: + # Check for mixed types (both Callable and str in same list) + first_is_callable = callable(tools[0]) + first_is_str = isinstance(tools[0], str) + + if not first_is_callable and not first_is_str: + raise TypeError( + f"tools must be list of Callable or list of str, " + f"got list containing {type(tools[0]).__name__}" + ) + + # Verify all tools are same type AND are valid types (callable or str) + for i, tool in enumerate(tools): + is_callable = callable(tool) + is_str = isinstance(tool, str) + + # Check if element is neither callable nor string + if not is_callable and not is_str: + raise TypeError( + f"tools list elements must be Callable or str, " + f"got {type(tool).__name__} at index {i}" + ) + + # Check for mixed types + if is_callable and first_is_str: + raise TypeError( + f"tools must be all Callable or all str, got mixed types " + f"(tool[0] is {type(tools[0]).__name__}, tool[{i}] is {type(tool).__name__})" + ) + if is_str and first_is_callable: + raise TypeError( + f"tools must be all Callable or all str, got mixed types " + f"(tool[0] is {type(tools[0]).__name__}, tool[{i}] is {type(tool).__name__})" + ) + + if first_is_str: + # String list: store for later resolution AFTER module import + tool_names_to_resolve = cast(list[str], tools) + logger.info(f"Will resolve tools after import: {tool_names_to_resolve}") + else: + # Callable list: pass through immediately + env_args["tools"] = tools + logger.info(f"Using callable tools directly: {len(tools)} tools") + else: + # Empty list: 
explicitly set to no tools + env_args["tools"] = [] + logger.info("Using empty tool list") + + # Phase 2: Import environment module FIRST + # This triggers @register_tool decorators, populating the registry module_name = env_id.replace("-", "_").split("/")[-1] try: module = importlib.import_module(module_name) @@ -24,6 +89,20 @@ def load_environment(env_id: str, **env_args) -> Environment: ) env_load_func: Callable[..., Environment] = getattr(module, "load_environment") + + # Phase 3: NOW resolve string tools (registry is populated!) + if tool_names_to_resolve is not None: + from verifiers.utils.tool_registry import get_tools + + logger.info(f"Resolving tools from registry: {tool_names_to_resolve}") + try: + tools = get_tools(env_id, tool_names_to_resolve) + env_args["tools"] = tools + logger.info(f"Successfully resolved {len(tools)} tools") + except KeyError: + # Re-raise KeyError to preserve original error type + # The error message from get_tools() is already descriptive + raise sig = inspect.signature(env_load_func) defaults_info = [] for param_name, param in sig.parameters.items(): @@ -81,6 +160,10 @@ def load_environment(env_id: str, **env_args) -> Environment: raise ValueError( f"Could not import '{env_id}' environment. Ensure the package for the '{env_id}' environment is installed." ) from e + except KeyError: + # KeyError from tool resolution should propagate as-is + # The error message from get_tools() is already descriptive + raise except Exception as e: logger.error( f"Failed to load environment {env_id} with args {env_args}: {str(e)}" diff --git a/verifiers/utils/tool_registry.py b/verifiers/utils/tool_registry.py new file mode 100644 index 000000000..bfc9e269b --- /dev/null +++ b/verifiers/utils/tool_registry.py @@ -0,0 +1,234 @@ +""" +Tool Registry for Environment-Specific Tool Management + +This module provides a centralized registry system for managing tools across different +environments. 
"""
Tool Registry for Environment-Specific Tool Management

This module provides a centralized registry for managing tools across
environments. Tools are registered under an environment ID and a tool name and
can be retrieved by that pair.

Environment IDs are matched exactly first. If there is no exact match, the ID
is compared by its last ``/``-separated segment with ``_`` folded to ``-``, so
``"tool-test"``, ``"tool_test"`` and ``"org/tool-test"`` resolve to the same
registered environment.

Core Functions:
    register_tool: Decorator to register tools in the registry
    get_tool: Retrieve a single tool by environment ID and tool name
    get_tools: Retrieve multiple tools by environment ID and tool names
    validate_tools: Validate that all specified tools are registered
    list_tools: List all available tools for an environment
    list_environments: List all environment IDs with registered tools
    clear_registry: Remove every registration (mainly for tests)

Example:
    @register_tool("tool-test", "tool_A")
    async def tool_A(x: int) -> int:
        return x + 1

    tools = get_tools("tool-test", ["tool_A"])
"""

import logging
import threading
from collections import defaultdict
from typing import Callable, Optional

# Global tool registry: {env_id: {tool_name: tool_function}}
_tool_registry: dict[str, dict[str, Callable]] = defaultdict(dict)

# Re-entrant lock held around every read and write of _tool_registry
_registry_lock = threading.RLock()

logger = logging.getLogger("verifiers.utils.tool_registry")


def _canonical_env_id(env_id: str) -> str:
    """Return the comparison form of an env ID: last path segment, "_" as "-"."""
    return env_id.split("/")[-1].replace("_", "-")


def _resolve_env_key(env_id: str) -> Optional[str]:
    """Return the registry key that matches ``env_id``, or None.

    The caller must hold ``_registry_lock``. Membership tests are used rather
    than indexing so the defaultdict never grows an empty entry on lookup.
    """
    if env_id in _tool_registry:
        return env_id
    wanted = _canonical_env_id(env_id)
    for key in _tool_registry:
        if _canonical_env_id(key) == wanted:
            return key
    return None


def _tools_for_env(
    env_id: str, error: type[Exception] = KeyError
) -> dict[str, Callable]:
    """Return the tool mapping for ``env_id`` or raise ``error`` if unknown.

    The caller must hold ``_registry_lock``.
    """
    key = _resolve_env_key(env_id)
    if key is None:
        raise error(
            f"Environment '{env_id}' not found in tool registry. "
            f"Available environments: {sorted(_tool_registry)}"
        )
    return _tool_registry[key]


def _describe(tool_func: Callable) -> str:
    """Return ``module.name`` for logging; tolerate callables without __name__."""
    module = getattr(tool_func, "__module__", "<unknown>")
    name = getattr(tool_func, "__name__", None) or repr(tool_func)
    return f"{module}.{name}"


def register_tool(env_id: str, tool_name: str):
    """
    Decorator to register a tool function in the global registry.

    Registering a second, different callable under the same
    ``(env_id, tool_name)`` replaces the first one and logs a warning.

    Args:
        env_id: Environment identifier (e.g., "tool-test", "wiki-search")
        tool_name: Name to register the tool under (typically function.__name__)

    Returns:
        Decorator function that registers the tool and returns it unchanged

    Example:
        @register_tool("tool-test", "my_tool")
        async def my_tool(x: int) -> int:
            return x + 1
    """

    def decorator(tool_func: Callable) -> Callable:
        with _registry_lock:
            previous = _tool_registry[env_id].get(tool_name)
            _tool_registry[env_id][tool_name] = tool_func

        if previous is not None and previous is not tool_func:
            logger.warning(
                "Tool '%s' for environment '%s' re-registered: %s replaces %s",
                tool_name,
                env_id,
                _describe(tool_func),
                _describe(previous),
            )
        # _describe() guards against callables that have no __name__
        # (e.g. functools.partial), which must still be registrable.
        logger.debug(
            "Registered tool '%s' for environment '%s': %s",
            tool_name,
            env_id,
            _describe(tool_func),
        )
        return tool_func

    return decorator


def get_tool(env_id: str, tool_name: str) -> Callable:
    """
    Retrieve a single tool function from the registry.

    Args:
        env_id: Environment identifier
        tool_name: Name of the tool to retrieve

    Returns:
        The tool function

    Raises:
        KeyError: If environment or tool not found

    Example:
        tool = get_tool("tool-test", "tool_A")
    """
    with _registry_lock:
        env_tools = _tools_for_env(env_id)
        if tool_name not in env_tools:
            raise KeyError(
                f"Tool '{tool_name}' not found for environment '{env_id}'. "
                f"Available tools: {sorted(env_tools)}"
            )
        return env_tools[tool_name]


def get_tools(env_id: str, tool_names: list[str]) -> list[Callable]:
    """
    Retrieve multiple tool functions from the registry.

    Args:
        env_id: Environment identifier
        tool_names: List of tool names to retrieve

    Returns:
        List of tool functions in the same order as tool_names

    Raises:
        KeyError: If the environment is unknown or any tool is not found
            (the message lists every missing name and the available tools)

    Example:
        tools = get_tools("tool-test", ["tool_A", "tool_B", "tool_C"])
    """
    with _registry_lock:
        env_tools = _tools_for_env(env_id)

        # Collect every missing name first so one error reports them all.
        missing_tools = [name for name in tool_names if name not in env_tools]
        if missing_tools:
            raise KeyError(
                f"Tools {missing_tools} not found for environment '{env_id}'. "
                f"Available tools: {sorted(env_tools)}"
            )

        return [env_tools[name] for name in tool_names]


def validate_tools(env_id: str, tool_names: list[str]) -> None:
    """
    Validate that all specified tools are registered for the environment.

    Unlike the lookup functions, this reports problems as ValueError.

    Args:
        env_id: Environment identifier
        tool_names: List of tool names to validate

    Raises:
        ValueError: If the environment is unknown or any tool is not registered

    Example:
        try:
            validate_tools("tool-test", ["tool_A", "tool_B"])
        except ValueError as e:
            print(f"Invalid tools: {e}")
    """
    with _registry_lock:
        env_tools = _tools_for_env(env_id, ValueError)

        missing_tools = [name for name in tool_names if name not in env_tools]
        if missing_tools:
            raise ValueError(
                f"Unregistered tools found: {missing_tools}. "
                f"Available tools for '{env_id}': {sorted(env_tools)}"
            )


def list_tools(env_id: str) -> list[str]:
    """
    List all available tool names for a specific environment.

    Args:
        env_id: Environment identifier

    Returns:
        Sorted list of tool names registered for this environment

    Raises:
        KeyError: If environment not found in registry

    Example:
        tools = list_tools("tool-test")
        # Returns: ["tool_A", "tool_B", "tool_C", "tool_D"]
    """
    with _registry_lock:
        return sorted(_tools_for_env(env_id))


def list_environments() -> list[str]:
    """
    List all environment IDs that have registered tools.

    IDs are returned exactly as they were passed to ``register_tool``.

    Returns:
        Sorted list of environment IDs

    Example:
        envs = list_environments()
        # Returns: ["tool-test", "wiki-search", ...]
    """
    with _registry_lock:
        return sorted(_tool_registry)


def clear_registry() -> None:
    """
    Clear all tools from the registry.

    This is primarily useful for testing to ensure a clean state between tests.

    Example:
        clear_registry()
    """
    with _registry_lock:
        _tool_registry.clear()
    logger.debug("Tool registry cleared")