Skip to content

Commit

Permalink
Fix cohere multimodal (#9)
Browse files Browse the repository at this point in the history
* update

* update

* update

* update

* update
  • Loading branch information
aniketmaurya authored May 9, 2024
1 parent de10a1a commit 37bde72
Show file tree
Hide file tree
Showing 10 changed files with 337 additions and 266 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ output = llm.chat_completion(messages)

if need_tool_use(output):
print("Using weather tool")
tool_results = llm.run_tool(output)
tool_results = llm.run_tools(output)
tool_results[0]["role"] = "assistant"

updated_messages = messages + tool_results
updated_messages = updated_messages + [
{"role": "user", "content": "Think step by step and answer my question based on the above context."}
]
]
output = llm.chat_completion(updated_messages)

print(output.choices[0].message.content)
Expand Down Expand Up @@ -99,7 +99,7 @@ messages = [
"content": f"Check this image {image_url} and suggest me a location where I can go in London which looks similar"}
]
output = llm.chat_completion(messages)
tool_results = llm.run_tool(output)
tool_results = llm.run_tools(output)

updated_messages = messages + tool_results
messages = updated_messages + [{"role": "user", "content": "please answer me, based on the tool results."}]
Expand Down
2 changes: 1 addition & 1 deletion examples/chatbot/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def llamacpp_chat(message, history):

output = llm.chat_completion(messages)
if need_tool_use(output):
tool_response = llm.run_tool(output)
tool_response = llm.run_tools(output)
updated_messages = messages + tool_response
messages = updated_messages + [
{"role": "user", "content": "please answer me, based on the tool results."}
Expand Down
222 changes: 113 additions & 109 deletions examples/cohere.ipynb

Large diffs are not rendered by default.

248 changes: 123 additions & 125 deletions examples/experiments.ipynb

Large diffs are not rendered by default.

37 changes: 31 additions & 6 deletions src/agents/llms/_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@

from langchain_cohere import ChatCohere


from typing import List, Optional, Any, Dict

from langchain_core.messages import AIMessage
from loguru import logger

import logging
from agents.specs import ChatCompletion, Choice, Usage, Message, ToolCall
from agents.tool_executor import ToolRegistry
from langchain_core.tools import StructuredTool
from llama_cpp import ChatCompletionRequestMessage

logger = logging.getLogger(__name__)


class CohereChatCompletion:
def __init__(self, **kwargs):
self.llm = ChatCohere()

self.tool_registry = ToolRegistry()

def bind_tools(self, tools: Optional[List[StructuredTool]] = None):
Expand Down Expand Up @@ -52,12 +51,38 @@ def _format_cohere_to_openai(self, output: AIMessage):
}
return ChatCompletion(**response)

# def _format_cohere_to_openai(self, output: NonStreamedChatResponse):
# _tool_calls = output.tool_calls
# tool_calls = []
# for tool in _tool_calls:
# tool = ToolCall(id=tool["id"], type=tool["type"], function=tool["function"])
# tool_calls.append(tool)
#
# message = Message(
# role="assistant", content=output.text, tool_calls=tool_calls
# )
# choices = Choice(index=0, logprobs=None, message=message, finish_reason=output.finish_reason)
# usage = Usage(
# prompt_tokens=output.meta.tokens.input_tokens,
# completion_tokens=output.meta.tokens.output_tokens,
# total_tokens=(output.meta.tokens.input_tokens + output.meta.tokens.output_tokens),
# )
# response = {
# "id": output.generation_id,
# "object": "",
# "created": int(time.time()),
# "model": "Cohere",
# "choices": [choices],
# "usage": usage,
# }
# return ChatCompletion(**response)

def chat_completion(
    self, messages: List[ChatCompletionRequestMessage], **kwargs
) -> ChatCompletion:
    """Send *messages* to the Cohere chat model and return the reply.

    The raw LangChain ``AIMessage`` returned by ``ChatCohere.invoke`` is
    converted into an OpenAI-style ``ChatCompletion`` so downstream
    tool handling can treat every backend uniformly.

    Args:
        messages: Conversation history in chat-completion message format.
        **kwargs: Extra options forwarded verbatim to ``ChatCohere.invoke``.

    Returns:
        ChatCompletion: OpenAI-compatible completion object.
    """
    output = self.llm.invoke(messages, **kwargs)
    logger.debug(output)
    return self._format_cohere_to_openai(output)

def run_tool(self, chat_completion: ChatCompletion) -> List[Dict[str, Any]]:
return self.tool_registry.call_tool(chat_completion)
def run_tools(self, chat_completion: ChatCompletion) -> List[Dict[str, Any]]:
    """Execute every tool call found in *chat_completion*.

    Thin wrapper around ``ToolRegistry.call_tools``; returns the list of
    tool-result message dicts it produces.
    """
    return self.tool_registry.call_tools(chat_completion)
5 changes: 3 additions & 2 deletions src/agents/llms/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def create_image_inspector(**kwargs):
chat_handler = MoondreamChatHandler.from_pretrained(
repo_id="vikhyatk/moondream2",
filename="*mmproj*",
verbose=False,
)

llm = Llama.from_pretrained(
Expand Down Expand Up @@ -78,5 +79,5 @@ def chat_completion(
logger.debug(output)
return ChatCompletion(**output)

def run_tool(self, chat_completion: ChatCompletion) -> List[Dict[str, Any]]:
return self.tool_registry.call_tool(chat_completion)
def run_tools(self, chat_completion: ChatCompletion) -> List[Dict[str, Any]]:
    """Execute every tool call found in *chat_completion*.

    Thin wrapper around ``ToolRegistry.call_tools``; returns the list of
    tool-result message dicts it produces.
    """
    return self.tool_registry.call_tools(chat_completion)
39 changes: 23 additions & 16 deletions src/agents/tool_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@


class ToolRegistry:
def __init__(self):
def __init__(self, tool_format="openai"):
self.tool_format = tool_format
self._tools: Dict[str, StructuredTool] = {}
self._formatted_tools: Dict[str, Any] = {}

Expand All @@ -67,7 +68,23 @@ def openai_tools(self) -> List[Dict[str, Any]]:

return result

def call_tool(self, output: Union[ChatCompletion, Dict]) -> List[Dict[str, str]]:
def call_tool(self, tool: ToolCall) -> Any:
    """Invoke one registered tool from a single ``ToolCall``.

    Looks the tool up by its function name, decodes the JSON-encoded
    argument payload, and returns whatever the tool produces.

    Raises:
        ValueError: If no tool is registered under the requested name.
    """
    name = tool.function.name
    registered = self.get(name)
    if not registered:
        raise ValueError(f"No function was found for {name}")

    args = json.loads(tool.function.arguments)
    logger.debug(f"Function {name} invoked with {args}")
    result = registered.invoke(args)
    logger.debug(f"Function {name}, responded with {result}")
    return result

def call_tools(self, output: Union[ChatCompletion, Dict]) -> List[Dict[str, str]]:
"""Call all tools from the ChatCompletion output and return the
result."""
if isinstance(output, dict):
output = ChatCompletion(**output)

Expand All @@ -77,21 +94,11 @@ def call_tool(self, output: Union[ChatCompletion, Dict]) -> List[Dict[str, str]]
messages = []
# https://platform.openai.com/docs/guides/function-calling
tool_calls = output.choices[0].message.tool_calls
for tool_call in tool_calls:
function_name = tool_call.function.name
function_to_call = self.get(function_name)

if not function_to_call:
raise ValueError(f"No function was found for {function_name}")

function_args = json.loads(tool_call.function.arguments)
logger.debug(f"Function {function_name} invoked with {function_args}")
function_response = function_to_call.invoke(function_args)
logger.debug(
f"Function {function_name}, responded with {function_response}"
)
for tool in tool_calls:
function_name = tool.function.name
function_response = self.call_tool(tool)
messages.append({
"tool_call_id": tool_call.id,
"tool_call_id": tool.id,
"role": "tool",
"name": function_name,
"content": function_response,
Expand Down
22 changes: 19 additions & 3 deletions src/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from langchain_community.utilities import WikipediaAPIWrapper
from loguru import logger

from agents.utils import llama_cpp_image_handler

wikipedia_api_wrapper = None
_image_inspector = None

Expand Down Expand Up @@ -80,17 +82,31 @@ def image_inspector(image_url_or_path: str) -> str:
"""
global _image_inspector
if _image_inspector is None:
logger.info(
"Loading image inspector for first time. This might take a while..."
)
_image_inspector = create_image_inspector()

response = _image_inspector.create_chat_completion(
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": image_url_or_path}},
{"type": "text", "text": "Describe this image in detail please."},
{
"type": "image_url",
"image_url": {
"url": llama_cpp_image_handler(image_url_or_path)
},
},
],
}
]
)
return response["choices"][0]["message"]["content"]
return json.dumps(
{
"image_url_or_path": image_url_or_path,
"image description": response["choices"][0]["message"]["content"],
},
indent=2,
)
20 changes: 20 additions & 0 deletions src/agents/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import base64
import os
import logging

logger = logging.getLogger(__name__)


def image_to_base64_data_uri(file_path):
    """Read an image file and return it as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension instead of being
    hard-coded to PNG; unknown or non-image extensions fall back to
    ``image/png`` (the previous behaviour), so existing PNG callers are
    unaffected while JPEG/GIF/WebP files now get a correct media type.

    Args:
        file_path: Path to the image on disk.

    Returns:
        str: A ``data:<mime>;base64,<payload>`` string.
    """
    import mimetypes  # local import: keeps the module's import block untouched

    mime, _ = mimetypes.guess_type(file_path)
    if not mime or not mime.startswith("image/"):
        mime = "image/png"  # previous hard-coded default, kept as fallback
    with open(file_path, "rb") as img_file:
        base64_data = base64.b64encode(img_file.read()).decode("utf-8")
    return f"data:{mime};base64,{base64_data}"


def llama_cpp_image_handler(image_url_or_path: str) -> str:
    """Normalise an image reference for llama.cpp multimodal chat.

    Remote URLs are passed through untouched; local files are embedded
    as base64 ``data:`` URIs, the form llama.cpp accepts for local
    images in ``image_url`` content parts.

    Args:
        image_url_or_path: An ``http(s)://`` URL or a filesystem path.

    Returns:
        str: The original URL, or a base64 data URI for a local file.

    Raises:
        ValueError: If a local path does not exist.
    """
    # Match full URL schemes, not the bare "http" prefix, so a local
    # file named e.g. "httpd.png" is not mistaken for a remote URL.
    if image_url_or_path.startswith(("http://", "https://")):
        return image_url_or_path
    if not os.path.exists(image_url_or_path):
        raise ValueError(f"Path {image_url_or_path} does not exist")
    return image_to_base64_data_uri(image_url_or_path)
2 changes: 1 addition & 1 deletion tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_registry():
tool_registry.register_tool(get_current_weather)
assert tool_registry.get("get_current_weather")

messages = tool_registry.call_tool(completion_data)
messages = tool_registry.call_tools(completion_data)
assert "FeelsLikeC" in messages[0]["content"]


Expand Down

0 comments on commit 37bde72

Please sign in to comment.