Skip to content

Commit

Permalink
LangChain Tools on Chat (#12)
Browse files Browse the repository at this point in the history
* ChatLlamaROS stream fix and demo updated

* ChatLlamaROS stream fix and demo updated

* Fix passing image as data

* Chat formatter from metadata + jinja2

* Langchain structured output and starting tools

* Building model grammar from bind_tools

* chat tools integration

* Adding a default template for tool calling

* ChatLlama loads tools

* Setting args for tools

* Fixing Chat Llama demo

* Black formatter

* Default chat template on code

* Black formatter

* README and Black formatter

* Black Formatter
  • Loading branch information
agonzc34 authored Jan 9, 2025
1 parent b037ad3 commit 2eab6d6
Show file tree
Hide file tree
Showing 9 changed files with 630 additions and 45 deletions.
84 changes: 82 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -734,8 +734,7 @@ rclpy.shutdown()

</details>

#### chat_llama_ros

#### chat_llama_ros (Chat + LVM)
<details>
<summary>Click to expand</summary>

Expand Down Expand Up @@ -778,6 +777,73 @@ rclpy.shutdown()

</details>

#### 🎉 ***NEW*** chat_llama_ros (Tools) 🎉

<details>
<summary>Click to expand</summary>

The current implementation of Tools allows executing tools without requiring a model trained for that task.

```python

import time

import rclpy
from rclpy.node import Node
from llama_ros.langchain import ChatLlamaROS
from langchain_core.messages import HumanMessage
from langchain.tools import tool
from random import randint

rclpy.init()

@tool
def get_inhabitants(city: str) -> int:
    """Get the current number of inhabitants of a city"""
    return randint(4_000_000, 8_000_000)


@tool
def get_curr_temperature(city: str) -> int:
"""Get the current temperature of a city"""
return randint(20, 30)

chat = ChatLlamaROS(temp=0.6, penalty_last_n=8, use_llama_template=True)

messages = [
HumanMessage(
"What is the current temperature in Madrid? And its inhabitants?"
)
]

llm_tools = chat.bind_tools(
    [get_inhabitants, get_curr_temperature], tool_choice='any'
)

all_tools_res = llm_tools.invoke(messages)
messages.append(all_tools_res)

for tool in all_tools_res.tool_calls:
selected_tool = {
"get_inhabitants": get_inhabitants, "get_curr_temperature": get_curr_temperature
}[tool['name']]

tool_msg = selected_tool.invoke(tool)

formatted_output = f"{tool['name']}({''.join(tool['args'].values())}) = {tool_msg.content}"

tool_msg.additional_kwargs = {'args': tool['args']}
messages.append(tool_msg)

res = chat.invoke(messages)

print(f"Response: {res.content}")

rclpy.shutdown()
```

</details>

## Demos

### LLM Demo
Expand Down Expand Up @@ -868,6 +934,20 @@ ros2 run llama_demos chatllama_demo_node

[ChatLlamaROS demo](https://github-production-user-asset-6210df.s3.amazonaws.com/55236157/363094669-c6de124a-4e91-4479-99b6-685fecb0ac20.webm?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240830%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240830T081232Z&X-Amz-Expires=300&X-Amz-Signature=f937758f4bcbaec7683e46ddb057fb642dc86a33cc8c736fca3b5ce2bf06ddac&X-Amz-SignedHeaders=host&actor_id=55236157&key_id=0&repo_id=622137360)

### Tools Demo

```shell
ros2 llama launch MiniCPM-2.6.yaml
```

```shell
ros2 run llama_demos chatllama_tools_node
```



[Tools ChatLlama](https://github.com/user-attachments/assets/b912ee29-1466-4d6a-888b-9a2d9c16ae1d)

#### Full Demo (LLM + chat template + RAG + Reranking + Stream)

```shell
Expand Down
6 changes: 6 additions & 0 deletions llama_demos/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,11 @@ install(PROGRAMS
RENAME chatllama_demo_node
)

install(PROGRAMS
llama_demos/chatllama_tools_node.py
DESTINATION lib/${PROJECT_NAME}
RENAME chatllama_tools_node
)

ament_python_install_package(${PROJECT_NAME})
ament_package()
1 change: 1 addition & 0 deletions llama_demos/llama_demos/chatllama_demo_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def send_prompt(self) -> None:
self.chat = ChatLlamaROS(
temp=0.2,
penalty_last_n=8,
use_gguf_template=False,
)

self.prompt = ChatPromptTemplate.from_messages(
Expand Down
117 changes: 117 additions & 0 deletions llama_demos/llama_demos/chatllama_tools_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#!/usr/bin/env python3

# MIT License

# Copyright (c) 2024 Alejandro González Cantón
# Copyright (c) 2024 Miguel Ángel González Santamarta

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


import time

import rclpy
from rclpy.node import Node
from llama_ros.langchain import ChatLlamaROS
from langchain_core.messages import HumanMessage
from langchain.tools import tool
from random import randint


@tool
def get_inhabitants(city: str) -> int:
    """Get the current number of inhabitants of a city"""
    # NOTE: the docstring doubles as the tool description shown to the LLM;
    # the original said "temperature", which made this tool indistinguishable
    # from get_curr_temperature. Returns a random stub value for the demo.
    return randint(4_000_000, 8_000_000)


@tool
def get_curr_temperature(city: str) -> int:
    """Get the current temperature of a city"""
    # Demo stub: a plausible random temperature instead of a real lookup.
    low, high = 20, 30
    return randint(low, high)


class ChatLlamaToolsDemoNode(Node):
    """Demo node driving ChatLlamaROS through a two-step tool-calling flow:
    the model first emits tool calls, the tools are executed locally, and the
    model is invoked again with the tool results to produce the final answer."""

    def __init__(self) -> None:
        super().__init__("chat_tools_demo_node")

        # Timestamps in seconds since the epoch; -1 means "not measured yet".
        self.initial_time = -1
        self.tools_time = -1
        self.eval_time = -1

    def send_prompt(self) -> None:
        """Send the demo prompt, run the requested tools, and log timings."""
        self.chat = ChatLlamaROS(temp=0.6, penalty_last_n=8, use_llama_template=True)

        messages = [
            HumanMessage(
                "What is the current temperature in Madrid? And its inhabitants?"
            )
        ]

        self.get_logger().info(f"\nPrompt: {messages[0].content}")

        # tool_choice="any" forces the model to call at least one tool.
        llm_tools = self.chat.bind_tools(
            [get_inhabitants, get_curr_temperature], tool_choice="any"
        )

        self.initial_time = time.time()
        all_tools_res = llm_tools.invoke(messages)
        self.tools_time = time.time()

        messages.append(all_tools_res)

        # Loop variable renamed from `tool` to `tool_call`: the original name
        # shadowed the imported `tool` decorator from langchain.tools.
        for tool_call in all_tools_res.tool_calls:
            selected_tool = {
                "get_inhabitants": get_inhabitants,
                "get_curr_temperature": get_curr_temperature,
            }[tool_call["name"]]

            tool_msg = selected_tool.invoke(tool_call)

            formatted_output = (
                f"{tool_call['name']}({''.join(tool_call['args'].values())})"
                f" = {tool_msg.content}"
            )
            self.get_logger().info(f"Calling tool: {formatted_output}")

            tool_msg.additional_kwargs = {"args": tool_call["args"]}
            messages.append(tool_msg)

        # Second model pass: let the LLM compose the answer from tool results.
        res = self.chat.invoke(messages)

        self.eval_time = time.time()

        self.get_logger().info(f"\nResponse: {res.content}")

        # Use ".2f" (fixed-point, two decimals) instead of the original ".2"
        # (two *significant digits*), which would print 12.34 s as "1.2e+01 s".
        time_generate_tools = self.tools_time - self.initial_time
        time_last_response = self.eval_time - self.tools_time
        self.get_logger().info(f"Time to generate tools: {time_generate_tools:.2f} s")
        self.get_logger().info(
            f"Time to generate last response: {time_last_response:.2f} s"
        )


def main():
    """Entry point: initialize ROS 2, run the demo once, then shut down."""
    rclpy.init()
    demo_node = ChatLlamaToolsDemoNode()
    demo_node.send_prompt()
    rclpy.shutdown()


if __name__ == "__main__":
    main()
4 changes: 4 additions & 0 deletions llama_ros/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,9 @@ install(TARGETS
DESTINATION lib/${PROJECT_NAME}
)

install(DIRECTORY
DESTINATION share/${PROJECT_NAME}
)

ament_python_install_package(${PROJECT_NAME})
ament_package()
Loading

0 comments on commit 2eab6d6

Please sign in to comment.