diff --git a/src/controlflow/tools/tools.py b/src/controlflow/tools/tools.py index cce08fc4..ade8069d 100644 --- a/src/controlflow/tools/tools.py +++ b/src/controlflow/tools/tools.py @@ -118,19 +118,27 @@ def from_function( ): name = name or fn.__name__ description = description or fn.__doc__ or "" - signature = inspect.signature(fn) - try: - parameters = TypeAdapter(fn).json_schema() - except PydanticSchemaGenerationError: - raise ValueError( - f'Could not generate a schema for tool "{name}". ' - "Tool functions must have type hints that are compatible with Pydantic." - ) + + # If parameters are provided in kwargs, use those instead of generating them + if "parameters" in kwargs: + parameters = kwargs.pop("parameters") # Custom parameters are respected + else: + try: + parameters = TypeAdapter(fn).json_schema() + except PydanticSchemaGenerationError: + raise ValueError( + f'Could not generate a schema for tool "{name}". ' + "Tool functions must have type hints that are compatible with Pydantic." + ) # load parameter descriptions if include_param_descriptions: for param in signature.parameters.values(): + # ensure we only try to add descriptions for parameters that exist in the schema + if param.name not in parameters.get("properties", {}): + continue + # handle Annotated type hints if typing.get_origin(param.annotation) is Annotated: param_description = " ".join( diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py index fc5efb38..31ca236c 100644 --- a/tests/tools/test_tools.py +++ b/tests/tools/test_tools.py @@ -7,11 +7,7 @@ import controlflow from controlflow.agents.agent import Agent from controlflow.llm.messages import ToolMessage -from controlflow.tools.tools import ( - Tool, - handle_tool_call, - tool, -) +from controlflow.tools.tools import Tool, handle_tool_call, tool @pytest.mark.parametrize("style", ["decorator", "class"]) @@ -170,6 +166,77 @@ def add(a: int, b: float) -> float: elif style == "decorator": tool(add) + def test_custom_parameters(self, style): + """Test that custom parameters override generated ones.""" + + def add(a: int, b: float): + return a + b + + custom_params = { + "type": "object", + "properties": { + "x": {"type": "number", "description": "Custom parameter"}, + "y": {"type": "string"}, + }, + "required": ["x"], + } + + if style == "class": + tool_obj = Tool.from_function(add, parameters=custom_params) + elif style == "decorator": + tool_obj = tool(add, parameters=custom_params) + + assert tool_obj.parameters == custom_params + assert "a" not in tool_obj.parameters["properties"] + assert "b" not in tool_obj.parameters["properties"] + assert ( + tool_obj.parameters["properties"]["x"]["description"] == "Custom parameter" + ) + + def test_custom_parameters_with_annotations(self, style): + """Test that annotations still work with custom parameters if param names match.""" + + def process(x: Annotated[float, "The x value"], y: str): + return x + + custom_params = { + "type": "object", + "properties": {"x": {"type": "number"}, "y": {"type": "string"}}, + "required": ["x"], + } + + if style == "class": + tool_obj = Tool.from_function(process, parameters=custom_params) + elif style == "decorator": + tool_obj = tool(process, parameters=custom_params) + + assert tool_obj.parameters["properties"]["x"]["description"] == "The x value" + assert "description" not in tool_obj.parameters["properties"]["y"] + + def test_custom_parameters_ignore_descriptions(self, style): + """Test that include_param_descriptions=False works with custom parameters.""" + + def process(x: Annotated[float, "The x value"], y: str): + return x + + custom_params = { + "type": "object", + "properties": {"x": {"type": "number"}, "y": {"type": "string"}}, + "required": ["x"], + } + + if style == "class": + tool_obj = Tool.from_function( + process, parameters=custom_params, include_param_descriptions=False + ) + elif style == "decorator": + tool_obj = tool( + process, parameters=custom_params, include_param_descriptions=False + ) + + assert "description" not in tool_obj.parameters["properties"]["x"] + assert "description" not in tool_obj.parameters["properties"]["y"] + class TestToolFunctions: def test_non_serializable_return_value(self):