Skip to content

Commit

Permalink
tools: prioritizes explicitly passed parameters over automatically ge…
Browse files Browse the repository at this point in the history
…nerated ones

Key changes:
1. We now handle custom parameters first, before any description
processing
2. Added a safety check `if param.name not in
parameters.get("properties", {})` to ensure we only try to add
descriptions for parameters that exist in the schema3. The rest of the
function (return value handling, description length check, etc.) remains
unchanged. This way:
- Custom parameters are respected
- Parameter descriptions from annotations are still added when available
- The function remains backward compatible with all existing tests
- We avoid trying to add descriptions to parameters that don't exist in
the schema

Signed-off-by: Teo <teocns@gmail.com>
  • Loading branch information
teocns committed Dec 7, 2024
1 parent f259fa8 commit 73f1e18
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 13 deletions.
24 changes: 16 additions & 8 deletions src/controlflow/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,27 @@ def from_function(
):
name = name or fn.__name__
description = description or fn.__doc__ or ""

signature = inspect.signature(fn)
try:
parameters = TypeAdapter(fn).json_schema()
except PydanticSchemaGenerationError:
raise ValueError(
f'Could not generate a schema for tool "{name}". '
"Tool functions must have type hints that are compatible with Pydantic."
)

# If parameters are provided in kwargs, use those instead of generating them
if "parameters" in kwargs:
parameters = kwargs.pop("parameters") # Custom parameters are respected
else:
try:
parameters = TypeAdapter(fn).json_schema()
except PydanticSchemaGenerationError:
raise ValueError(
f'Could not generate a schema for tool "{name}". '
"Tool functions must have type hints that are compatible with Pydantic."
)

# load parameter descriptions
if include_param_descriptions:
for param in signature.parameters.values():
# ensure we only try to add descriptions for parameters that exist in the schema
if param.name not in parameters.get("properties", {}):
continue

# handle Annotated type hints
if typing.get_origin(param.annotation) is Annotated:
param_description = " ".join(
Expand Down
77 changes: 72 additions & 5 deletions tests/tools/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
import controlflow
from controlflow.agents.agent import Agent
from controlflow.llm.messages import ToolMessage
from controlflow.tools.tools import (
Tool,
handle_tool_call,
tool,
)
from controlflow.tools.tools import Tool, handle_tool_call, tool


@pytest.mark.parametrize("style", ["decorator", "class"])
Expand Down Expand Up @@ -170,6 +166,77 @@ def add(a: int, b: float) -> float:
elif style == "decorator":
tool(add)

def test_custom_parameters(self, style):
"""Test that custom parameters override generated ones."""

def add(a: int, b: float):
return a + b

custom_params = {
"type": "object",
"properties": {
"x": {"type": "number", "description": "Custom parameter"},
"y": {"type": "string"},
},
"required": ["x"],
}

if style == "class":
tool_obj = Tool.from_function(add, parameters=custom_params)
elif style == "decorator":
tool_obj = tool(add, parameters=custom_params)

assert tool_obj.parameters == custom_params
assert "a" not in tool_obj.parameters["properties"]
assert "b" not in tool_obj.parameters["properties"]
assert (
tool_obj.parameters["properties"]["x"]["description"] == "Custom parameter"
)

def test_custom_parameters_with_annotations(self, style):
"""Test that annotations still work with custom parameters if param names match."""

def process(x: Annotated[float, "The x value"], y: str):
return x

custom_params = {
"type": "object",
"properties": {"x": {"type": "number"}, "y": {"type": "string"}},
"required": ["x"],
}

if style == "class":
tool_obj = Tool.from_function(process, parameters=custom_params)
elif style == "decorator":
tool_obj = tool(process, parameters=custom_params)

assert tool_obj.parameters["properties"]["x"]["description"] == "The x value"
assert "description" not in tool_obj.parameters["properties"]["y"]

def test_custom_parameters_ignore_descriptions(self, style):
"""Test that include_param_descriptions=False works with custom parameters."""

def process(x: Annotated[float, "The x value"], y: str):
return x

custom_params = {
"type": "object",
"properties": {"x": {"type": "number"}, "y": {"type": "string"}},
"required": ["x"],
}

if style == "class":
tool_obj = Tool.from_function(
process, parameters=custom_params, include_param_descriptions=False
)
elif style == "decorator":
tool_obj = tool(
process, parameters=custom_params, include_param_descriptions=False
)

assert "description" not in tool_obj.parameters["properties"]["x"]
assert "description" not in tool_obj.parameters["properties"]["y"]


class TestToolFunctions:
def test_non_serializable_return_value(self):
Expand Down

0 comments on commit 73f1e18

Please sign in to comment.