Skip to content

Commit

Permalink
auto tool_choice
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Jan 13, 2025
1 parent 02d459b commit d5f07be
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
4 changes: 3 additions & 1 deletion llama_demos/llama_demos/chatllama_tools_demo_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def send_prompt(self) -> None:
]

self.get_logger().info(f"\nPrompt: {messages[0].content}")
llm_tools = self.chat.bind_tools([get_inhabitants, get_curr_temperature])
llm_tools = self.chat.bind_tools(
[get_inhabitants, get_curr_temperature], tool_choice="any"
)

self.initial_time = time.time()
all_tools_res = llm_tools.invoke(messages)
Expand Down
11 changes: 6 additions & 5 deletions llama_ros/llama_ros/langchain/chat_llama_ros.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@

DEFAULT_TEMPLATE = """{% if tools_grammar %}
{{- '<|im_start|>assistant\n' }}
{{- 'You are an assistant. Output in JSON format. The key "tool_calls" is a list of tools in the format: {name, arguments}. Available tools are:' }}
{{- 'You are an assistant. Output in JSON format. The key "tool_calls" is a list of tools in the format {name, arguments}. Available tools are:' }}
{% for tool in tools_grammar %}
{% if not loop.last %}
{{- tool }}
Expand Down Expand Up @@ -401,9 +401,8 @@ def bind_tools(
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["all", "one", "any"], bool]
] = "any",
only_tool_calling: bool = True,
Union[dict, str, Literal["auto", "all", "one", "any"], bool]
] = "auto",
method: Literal[
"function_calling", "json_schema", "json_mode"
] = "function_calling",
Expand All @@ -412,7 +411,7 @@ def bind_tools(

formatted_tools = [convert_to_openai_tool(tool)["function"] for tool in tools]
tool_names = [ft["name"] for ft in formatted_tools]
valid_choices = ["all", "one", "any"]
valid_choices = ["auto", "all", "one", "any"]

is_valid_choice = tool_choice in valid_choices
chosen_tool = [f for f in formatted_tools if f["name"] == tool_choice]
Expand Down Expand Up @@ -471,6 +470,8 @@ def bind_tools(
]

else:
only_tool_calling = tool_choice != "auto"
tool_choice = "any" if not only_tool_calling else tool_choice
tool_calls["properties"]["tool_calls"]["items"][f"{tool_choice}Of"] = []

for tool in formatted_tools:
Expand Down

0 comments on commit d5f07be

Please sign in to comment.