Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions environments/tool_test/tool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from datasets import Dataset

import verifiers as vf
from verifiers.utils.tool_registry import register_tool


# dummy tools for sanity checking parallel tool calls
@register_tool("tool-test", "tool_A")
async def tool_A(x: int) -> int:
"""
Tool for adding 1 to an integer.
Expand All @@ -19,6 +21,7 @@ async def tool_A(x: int) -> int:
return x + 1


@register_tool("tool-test", "tool_B")
async def tool_B(x: str) -> str:
"""
Tool for concatenating a string with "2".
Expand All @@ -32,6 +35,7 @@ async def tool_B(x: str) -> str:
return x + "2"


@register_tool("tool-test", "tool_C")
async def tool_C(x: float) -> float:
"""
Tool for adding 3.0 to a float.
Expand All @@ -45,6 +49,7 @@ async def tool_C(x: float) -> float:
return x + 3.0


@register_tool("tool-test", "tool_D")
async def tool_D(x: bool) -> bool:
"""
Tool for negating a boolean.
Expand All @@ -58,8 +63,7 @@ async def tool_D(x: bool) -> bool:
return not x


tool_list = [tool_A, tool_B, tool_C, tool_D]
tool_name_list = [tool.__name__ for tool in tool_list]
DEFAULT_TOOL_LIST = [tool_A, tool_B, tool_C, tool_D]


def tool_call_reward_func(completion, info):
Expand All @@ -76,17 +80,42 @@ def tool_call_reward_func(completion, info):


def load_environment(
num_train_examples: int = 1000, num_eval_examples: int = 100
num_train_examples: int = 1000,
num_eval_examples: int = 100,
tools: list | None = None,
) -> vf.ToolEnv:
"""
Loads tool-test environment.
"""

# Use provided tools or fall back to default
if tools is None:
tools = DEFAULT_TOOL_LIST

# Extract tool names from ACTUAL tools being used (not hardcoded list)
actual_tool_names = [tool.__name__ for tool in tools]

# Handle empty tools case
if not actual_tool_names:
# Create empty datasets when no tools available
dataset = Dataset.from_list([])
eval_dataset = Dataset.from_list([])
rubric = vf.Rubric(funcs=[tool_call_reward_func])
vf_env = vf.ToolEnv(
dataset=dataset,
eval_dataset=eval_dataset,
rubric=rubric,
tools=tools,
max_turns=1,
)
return vf_env

train_rows = []
eval_rows = []
for i in range(num_train_examples + num_eval_examples):
# Sample from actual available tools only
tool_names = random.sample(
tool_name_list, random.randint(1, len(tool_name_list))
actual_tool_names, random.randint(1, len(actual_tool_names))
)
prompt = [
{
Expand All @@ -107,7 +136,7 @@ def load_environment(
dataset=dataset,
eval_dataset=eval_dataset,
rubric=rubric,
tools=tool_list,
tools=tools,
max_turns=1,
)
return vf_env
33 changes: 30 additions & 3 deletions packages/verifiers-rl/verifiers_rl/scripts/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import logging
from pathlib import Path

try:
Expand All @@ -9,6 +10,8 @@
import verifiers as vf
from verifiers_rl.rl.trainer import RLConfig, RLTrainer

logger = logging.getLogger("verifiers_rl.scripts.train")


def main() -> None:
parser = argparse.ArgumentParser()
Expand All @@ -27,9 +30,33 @@ def main() -> None:
config = tomllib.load(f)

model = config["model"]
env_id = config["env"]["id"]
env_args = config["env"].get("args", {})
env = vf.load_environment(env_id=env_id, **env_args)

# Handle both [[env]] array syntax (configs/rl/*.toml) and [env] dict syntax (configs/local/vf-rl/*.toml)
env_section = config["env"]
if isinstance(env_section, list):
# [[env]] array - use first environment
env_config = env_section[0]
if len(env_section) > 1:
logger.warning(f"Multiple environments in config, using first: {env_config['id']}")
else:
# [env] dict - single environment
env_config = env_section

env_id = env_config["id"]
env_args = env_config.get("args", {})

# Extract tools from config (will be resolved by env_utils.py after environment import)
if "tools" in env_config:
tool_names = env_config["tools"]
if not isinstance(tool_names, list):
raise ValueError(
f"env.tools must be list of tool names, got {type(tool_names).__name__}"
)
tools = tool_names # Pass as-is, let env_utils.py resolve after import
else:
tools = None

env = vf.load_environment(env_id=env_id, tools=tools, **env_args)
rl_config = RLConfig(**config["trainer"].get("args", {}))
trainer = RLTrainer(model=model, env=env, args=rl_config)
trainer.train()
Expand Down
183 changes: 183 additions & 0 deletions tests/test_env_utils_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"""
Integration tests for env_utils tool resolution functionality

Tests cover:
- Loading environment with string tools (resolved via registry)
- Loading environment with callable tools (passed through)
- Loading environment with no tools (backward compatibility)
- Error handling for mixed tool types
"""

import pytest

from verifiers.utils.env_utils import load_environment
from verifiers.utils.tool_registry import register_tool


@pytest.fixture(autouse=True)
def clear_registry():
    """Reset the global tool registry before and after every test."""
    from verifiers.utils import tool_registry

    tool_registry._tool_registry.clear()  # start each test from an empty registry
    yield
    tool_registry._tool_registry.clear()  # leave nothing behind for later tests


def test_load_environment_with_string_tools(clear_registry):
    """String tool names must be resolved into callables via the registry."""

    @register_tool("tool-test", "test_tool_a")
    async def test_tool_a(x: int) -> int:
        return x + 1

    @register_tool("tool-test", "test_tool_b")
    async def test_tool_b(x: str) -> str:
        return x + "suffix"

    # Request the tools by registered name rather than by object.
    env = load_environment("tool-test", tools=["test_tool_a", "test_tool_b"])

    # The environment should expose exactly the two resolved callables.
    assert hasattr(env, "tools")
    assert len(env.tools) == 2
    for resolved in (test_tool_a, test_tool_b):
        assert resolved in env.tools


def test_load_environment_with_callable_tools(clear_registry):
    """Callable tools must be handed to the environment unchanged."""

    async def direct_tool_a(x: int) -> int:
        return x + 1

    async def direct_tool_b(x: str) -> str:
        return x + "suffix"

    supplied = [direct_tool_a, direct_tool_b]
    env = load_environment("tool-test", tools=supplied)

    # Pass-through: the same function objects come back on the environment.
    assert hasattr(env, "tools")
    assert len(env.tools) == 2
    for fn in supplied:
        assert fn in env.tools


def test_load_environment_no_tools(clear_registry):
    """Omitting the tools argument keeps the backward-compatible defaults."""
    env = load_environment("tool-test")

    assert hasattr(env, "tools")
    # The tool-test environment ships four built-in tools by default.
    assert len(env.tools) == 4


def test_load_environment_empty_tool_list(clear_registry):
    """An explicit empty tool list yields an environment with zero tools."""
    env = load_environment("tool-test", tools=[])

    # Empty list is honored rather than falling back to the defaults.
    assert hasattr(env, "tools")
    assert len(env.tools) == 0


def test_mixed_tool_types_error(clear_registry):
    """A tools list mixing callables and strings must raise TypeError."""

    async def my_tool(x: int) -> int:
        return x + 1

    @register_tool("tool-test", "registered_tool")
    async def registered_tool() -> str:
        return "registered"

    # Both orderings of the mixed list are rejected the same way.
    for mixed in ([my_tool, "registered_tool"], ["registered_tool", my_tool]):
        with pytest.raises(TypeError, match="tools must be all Callable or all str"):
            load_environment("tool-test", tools=mixed)


def test_invalid_tool_name_in_registry(clear_registry):
    """Requesting a name that was never registered must raise KeyError."""

    # Register one real tool so the environment's registry entry exists.
    @register_tool("tool-test", "valid_tool")
    async def valid_tool(x: int) -> int:
        return x + 1

    # Mixing a known name with an unknown one still fails on the unknown.
    with pytest.raises(KeyError, match=r"Tools \['nonexistent_tool'\] not found"):
        load_environment("tool-test", tools=["valid_tool", "nonexistent_tool"])


def test_invalid_tool_type_error(clear_registry):
    """Entries that are neither str nor Callable must raise TypeError."""
    bad_tools = [123, 456]  # ints are not valid tool specifications
    with pytest.raises(TypeError, match="tools must be list of Callable or list of str"):
        load_environment("tool-test", tools=bad_tools)


def test_environment_with_other_args(clear_registry):
    """The tools parameter must compose with other environment kwargs."""

    @register_tool("tool-test", "custom_tool")
    async def custom_tool() -> str:
        return "custom"

    # Combine tool selection with dataset-sizing arguments in one call.
    env = load_environment(
        "tool-test",
        tools=["custom_tool"],
        num_train_examples=50,
        num_eval_examples=10,
    )

    # The registered tool made it through alongside the other kwargs.
    assert hasattr(env, "tools")
    assert custom_tool in env.tools
    assert len(env.tools) >= 1
    # num_train_examples/num_eval_examples sizing is exercised but not
    # asserted here; it depends on the tool-test env implementation.


def test_single_string_tool(clear_registry):
    """A one-element list of string names resolves to the registered tool."""

    @register_tool("tool-test", "single_tool")
    async def single_tool(x: int) -> int:
        return x * 2

    env = load_environment("tool-test", tools=["single_tool"])

    assert hasattr(env, "tools") and single_tool in env.tools


def test_single_callable_tool(clear_registry):
    """A one-element list holding a callable is passed straight through."""

    async def my_tool() -> str:
        return "result"

    env = load_environment("tool-test", tools=[my_tool])

    assert hasattr(env, "tools") and my_tool in env.tools
Loading