From aa47b2d04772d0d452ec257abcc8341387234894 Mon Sep 17 00:00:00 2001 From: Alexandra Eliseeva Date: Mon, 8 Apr 2024 22:45:47 +0200 Subject: [PATCH] Update function calling parsing Make binging configurable --- planning_library/components/agent_component.py | 3 ++- .../function_calling_parsers/base_parser.py | 10 +++++++--- .../openai_functions_parser.py | 8 +++++--- .../function_calling_parsers/openai_tools_parser.py | 9 +++++---- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/planning_library/components/agent_component.py b/planning_library/components/agent_component.py index 8881155..71c727a 100644 --- a/planning_library/components/agent_component.py +++ b/planning_library/components/agent_component.py @@ -40,7 +40,8 @@ def create_agent( ) parser = ParserRegistry.get_parser(parser_name) - llm_with_tools = llm.bind(tools=[parser.prepare_tool(tool) for tool in tools]) + llm_with_tools = parser.prepare_llm(llm=llm, tools=tools) + runnable: Runnable = ( RunnableLambda(parser.format_inputs) | prompt diff --git a/planning_library/function_calling_parsers/base_parser.py b/planning_library/function_calling_parsers/base_parser.py index 52fcf38..94451ab 100644 --- a/planning_library/function_calling_parsers/base_parser.py +++ b/planning_library/function_calling_parsers/base_parser.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod -from typing import Any, List, Tuple +from typing import List, Tuple, Sequence +from langchain_core.language_models import BaseChatModel from langchain.agents.agent import AgentOutputParser, MultiActionAgentOutputParser from langchain_core.agents import AgentAction from langchain_core.tools import BaseTool +from langchain_core.runnables import Runnable from langchain_core.messages import BaseMessage from typing_extensions import TypedDict @@ -19,10 +21,12 @@ class BaseFunctionCallingParser(ABC): name: str @abstractmethod - def format_inputs(self, inputs: AgentInputs) -> ProcessedAgentInputs: ... + def prepare_llm( + self, llm: BaseChatModel, tools: Sequence[BaseTool] + ) -> Runnable: ... @abstractmethod - def prepare_tool(self, tool: BaseTool) -> Any: ... + def format_inputs(self, inputs: AgentInputs) -> ProcessedAgentInputs: ... class BaseFunctionCallingSingleActionParser(BaseFunctionCallingParser, ABC): diff --git a/planning_library/function_calling_parsers/openai_functions_parser.py b/planning_library/function_calling_parsers/openai_functions_parser.py index 69ce030..5004dbd 100644 --- a/planning_library/function_calling_parsers/openai_functions_parser.py +++ b/planning_library/function_calling_parsers/openai_functions_parser.py @@ -1,8 +1,10 @@ -from typing import Any, List, Tuple +from typing import List, Tuple, Sequence from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser from langchain.agents.format_scratchpad import format_to_openai_function_messages from langchain_core.messages import BaseMessage from langchain_core.agents import AgentAction +from langchain_core.language_models import BaseChatModel +from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_function @@ -36,5 +38,5 @@ def format_inputs( ), } - def prepare_tool(self, tool: BaseTool) -> Any: - return convert_to_openai_function(tool) + def prepare_llm(self, llm: BaseChatModel, tools: Sequence[BaseTool]) -> Runnable: + return llm.bind(functions=[convert_to_openai_function(tool) for tool in tools]) diff --git a/planning_library/function_calling_parsers/openai_tools_parser.py b/planning_library/function_calling_parsers/openai_tools_parser.py index 10c69ff..c62adcd 100644 --- a/planning_library/function_calling_parsers/openai_tools_parser.py +++ b/planning_library/function_calling_parsers/openai_tools_parser.py @@ -1,4 +1,4 @@ -from typing import Any, List, Tuple +from typing import List, Tuple, Sequence from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser from langchain.agents.format_scratchpad.openai_tools import ( format_to_openai_tool_messages, @@ -7,7 +7,8 @@ from langchain_core.agents import AgentAction from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_tool - +from langchain_core.language_models import BaseChatModel +from langchain_core.runnables import Runnable from planning_library.function_calling_parsers.base_parser import ( BaseFunctionCallingMultiActionParser, AgentInputs, @@ -38,5 +39,5 @@ def format_inputs( ), } - def prepare_tool(self, tool: BaseTool) -> Any: - return convert_to_openai_tool(tool) + def prepare_llm(self, llm: BaseChatModel, tools: Sequence[BaseTool]) -> Runnable: + return llm.bind(tools=[convert_to_openai_tool(tool) for tool in tools])