Skip to content

Commit

Permalink
Update function calling parsing
Browse files Browse the repository at this point in the history
Make binging configurable
  • Loading branch information
saridormi committed Apr 8, 2024
1 parent 1f328da commit aa47b2d
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
3 changes: 2 additions & 1 deletion planning_library/components/agent_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def create_agent(
)
parser = ParserRegistry.get_parser(parser_name)

llm_with_tools = llm.bind(tools=[parser.prepare_tool(tool) for tool in tools])
llm_with_tools = parser.prepare_llm(llm=llm, tools=tools)

runnable: Runnable = (
RunnableLambda(parser.format_inputs)
| prompt
Expand Down
10 changes: 7 additions & 3 deletions planning_library/function_calling_parsers/base_parser.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod
from typing import Any, List, Tuple
from typing import List, Tuple, Sequence
from langchain_core.language_models import BaseChatModel
from langchain.agents.agent import AgentOutputParser, MultiActionAgentOutputParser
from langchain_core.agents import AgentAction
from langchain_core.tools import BaseTool
from langchain_core.runnables import Runnable
from langchain_core.messages import BaseMessage
from typing_extensions import TypedDict

Expand All @@ -19,10 +21,12 @@ class BaseFunctionCallingParser(ABC):
name: str

@abstractmethod
def format_inputs(self, inputs: AgentInputs) -> ProcessedAgentInputs: ...
def prepare_llm(
self, llm: BaseChatModel, tools: Sequence[BaseTool]
) -> Runnable: ...

@abstractmethod
def prepare_tool(self, tool: BaseTool) -> Any: ...
def format_inputs(self, inputs: AgentInputs) -> ProcessedAgentInputs: ...


class BaseFunctionCallingSingleActionParser(BaseFunctionCallingParser, ABC):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, List, Tuple
from typing import List, Tuple, Sequence
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain_core.messages import BaseMessage
from langchain_core.agents import AgentAction
from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function

Expand Down Expand Up @@ -36,5 +38,5 @@ def format_inputs(
),
}

def prepare_tool(self, tool: BaseTool) -> Any:
return convert_to_openai_function(tool)
def prepare_llm(self, llm: BaseChatModel, tools: Sequence[BaseTool]) -> Runnable:
return llm.bind(functions=[convert_to_openai_function(tool) for tool in tools])
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Tuple
from typing import List, Tuple, Sequence
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages,
Expand All @@ -7,7 +7,8 @@
from langchain_core.agents import AgentAction
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import Runnable
from planning_library.function_calling_parsers.base_parser import (
BaseFunctionCallingMultiActionParser,
AgentInputs,
Expand Down Expand Up @@ -38,5 +39,5 @@ def format_inputs(
),
}

def prepare_tool(self, tool: BaseTool) -> Any:
return convert_to_openai_tool(tool)
def prepare_llm(self, llm: BaseChatModel, tools: Sequence[BaseTool]) -> Runnable:
return llm.bind(tools=[convert_to_openai_tool(tool) for tool in tools])

0 comments on commit aa47b2d

Please sign in to comment.