Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Tool.from_component #159

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions haystack_experimental/dataclasses/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,24 @@

import inspect
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, get_args, get_origin

from haystack import logging
from haystack.core.component import Component
from haystack.lazy_imports import LazyImport
from haystack.utils import deserialize_callable, serialize_callable
from pydantic import create_model
from pydantic import TypeAdapter, create_model

from haystack_experimental.tools.component_schema import _create_tool_parameters_schema

with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
from jsonschema import Draft202012Validator
from jsonschema.exceptions import SchemaError


logger = logging.getLogger(__name__)


class ToolInvocationError(Exception):
"""
Exception raised when a Tool invocation fails.
Expand Down Expand Up @@ -198,6 +205,59 @@ def get_weather(

return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function)

@classmethod
def from_component(cls, component: Component, name: str, description: str) -> "Tool":
"""
Create a Tool instance from a Haystack component.

:param component: The Haystack component to be converted into a Tool.
:param name: Name for the tool.
:param description: Description of the tool.
:returns: The Tool created from the Component.
:raises ValueError: If the component is invalid or schema generation fails.
"""

if not isinstance(component, Component):
message = (
f"Object {component!r} is not a Haystack component. "
"Use this method to create a Tool only with Haystack component instances."
)
raise ValueError(message)

# Create the tools schema from the component run method parameters
tool_schema = _create_tool_parameters_schema(component)

def component_invoker(**kwargs):
"""
Invokes the component using keyword arguments provided by the LLM function calling/tool generated response.

:param kwargs: The keyword arguments to invoke the component with.
:returns: The result of the component invocation.
"""
converted_kwargs = {}
input_sockets = component.__haystack_input__._sockets_dict
for param_name, param_value in kwargs.items():
param_type = input_sockets[param_name].type

# Check if the type (or list element type) has from_dict
target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type
if hasattr(target_type, "from_dict"):
if isinstance(param_value, list):
param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)]
elif isinstance(param_value, dict):
param_value = target_type.from_dict(param_value)
else:
# Let TypeAdapter handle both single values and lists
type_adapter = TypeAdapter(param_type)
param_value = type_adapter.validate_python(param_value)

converted_kwargs[param_name] = param_value
logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}")
return component.run(**converted_kwargs)

# Return a new Tool instance with the component invoker as the function to be called
return Tool(name=name, description=description, parameters=tool_schema, function=component_invoker)


def _remove_title_from_schema(schema: Dict[str, Any]):
"""
Expand Down
3 changes: 3 additions & 0 deletions haystack_experimental/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
184 changes: 184 additions & 0 deletions haystack_experimental/tools/component_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import fields, is_dataclass
from inspect import getdoc
from typing import Any, Callable, Dict, Union, get_args, get_origin

from docstring_parser import parse
from haystack import logging
from haystack.core.component import Component

from haystack_experimental.util.utils import is_pydantic_v2_model

logger = logging.getLogger(__name__)


def _create_tool_parameters_schema(component: Component) -> Dict[str, Any]:
"""
Creates an OpenAI tools schema from a component's run method parameters.

:param component: The component to create the schema from.
:returns: OpenAI tools schema for the component's run method parameters.
"""
properties = {}
required = []

param_descriptions = _get_param_descriptions(component.run)

for input_name, socket in component.__haystack_input__._sockets_dict.items():
input_type = socket.type
description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

try:
property_schema = _create_property_schema(input_type, description)
except ValueError as e:
raise ValueError(f"Error processing input '{input_name}': {e}")

properties[input_name] = property_schema

# Use socket.is_mandatory to check if the input is required
if socket.is_mandatory:
required.append(input_name)

parameters_schema = {"type": "object", "properties": properties}

if required:
parameters_schema["required"] = required

return parameters_schema


def _get_param_descriptions(method: Callable) -> Dict[str, str]:
"""
Extracts parameter descriptions from the method's docstring using docstring_parser.

:param method: The method to extract parameter descriptions from.
:returns: A dictionary mapping parameter names to their descriptions.
"""
docstring = getdoc(method)
if not docstring:
return {}

parsed_doc = parse(docstring)
param_descriptions = {}
for param in parsed_doc.params:
if not param.description:
logger.warning(
"Missing description for parameter '%s'. Please add a description in the component's "
"run() method docstring using the format ':param %s: <description>'. "
"This description is used to generate the Tool and helps the LLM understand how to use this parameter.",
param.arg_name,
param.arg_name,
)
param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
return param_descriptions


def _is_nullable_type(python_type: Any) -> bool:
"""
Checks if the type is a Union with NoneType (i.e., Optional).

:param python_type: The Python type to check.
:returns: True if the type is a Union with NoneType, False otherwise.
"""
origin = get_origin(python_type)
if origin is Union:
return type(None) in get_args(python_type)
return False


def _create_list_schema(item_type: Any, description: str) -> Dict[str, Any]:
"""
Creates a schema for a list type.

:param item_type: The type of items in the list.
:param description: The description of the list.
:returns: A dictionary representing the list schema.
"""
items_schema = _create_property_schema(item_type, "")
items_schema.pop("description", None)
return {"type": "array", "description": description, "items": items_schema}


def _create_dataclass_schema(python_type: Any, description: str) -> Dict[str, Any]:
"""
Creates a schema for a dataclass.

:param python_type: The dataclass type.
:param description: The description of the dataclass.
:returns: A dictionary representing the dataclass schema.
"""
schema = {"type": "object", "description": description, "properties": {}}
cls = python_type if isinstance(python_type, type) else python_type.__class__
for field in fields(cls):
field_description = f"Field '{field.name}' of '{cls.__name__}'."
if isinstance(schema["properties"], dict):
schema["properties"][field.name] = _create_property_schema(field.type, field_description)
return schema


def _create_pydantic_schema(python_type: Any, description: str) -> Dict[str, Any]:
"""
Creates a schema for a Pydantic model.

:param python_type: The Pydantic model type.
:param description: The description of the model.
:returns: A dictionary representing the Pydantic model schema.
"""
schema = {"type": "object", "description": description, "properties": {}}
required_fields = []

for m_name, m_field in python_type.model_fields.items():
field_description = f"Field '{m_name}' of '{python_type.__name__}'."
if isinstance(schema["properties"], dict):
schema["properties"][m_name] = _create_property_schema(m_field.annotation, field_description)
if m_field.is_required():
required_fields.append(m_name)

if required_fields:
schema["required"] = required_fields
return schema


def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, Any]:
"""
Creates a schema for a basic Python type.

:param python_type: The Python type.
:param description: The description of the type.
:returns: A dictionary representing the basic type schema.
"""
type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"}
return {"type": type_mapping.get(python_type, "string"), "description": description}


def _create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]:
"""
Creates a property schema for a given Python type, recursively if necessary.

:param python_type: The Python type to create a property schema for.
:param description: The description of the property.
:param default: The default value of the property.
:returns: A dictionary representing the property schema.
"""
nullable = _is_nullable_type(python_type)
if nullable:
non_none_types = [t for t in get_args(python_type) if t is not type(None)]
python_type = non_none_types[0] if non_none_types else str

origin = get_origin(python_type)
if origin is list:
schema = _create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description)
elif is_dataclass(python_type):
schema = _create_dataclass_schema(python_type, description)
elif is_pydantic_v2_model(python_type):
schema = _create_pydantic_schema(python_type, description)
else:
schema = _create_basic_type_schema(python_type, description)

if default is not None:
schema["default"] = default

return schema
12 changes: 11 additions & 1 deletion haystack_experimental/util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import List, Union
from typing import Any, List, Union


def expand_page_range(page_range: List[Union[str, int]]) -> List[int]:
Expand Down Expand Up @@ -41,3 +41,13 @@ def expand_page_range(page_range: List[Union[str, int]]) -> List[int]:
raise ValueError("No valid page numbers or ranges found in the input list")

return expanded_page_range


def is_pydantic_v2_model(instance: Any) -> bool:
"""
Checks if the instance is a Pydantic v2 model.

:param instance: The instance to check.
:returns: True if the instance is a Pydantic v2 model, False otherwise.
"""
return hasattr(instance, "model_validate")
Loading
Loading