genai: Fix multiple tool calls in a single AIMessage (#671)
pedrito87 authored Jan 7, 2025
1 parent 1f6d2d7 commit 79047b2
Showing 2 changed files with 131 additions and 66 deletions.
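
Before the diff, here is a minimal sketch (not part of this commit) of the conversation shape the fix targets: one AIMessage carrying two tool calls, followed by one ToolMessage per call. Previously, each ToolMessage was turned into its own user turn and its function name was taken from the first tool call of the preceding AIMessage, which broke histories with more than one call per AI turn. The sketch assumes langchain-core and langchain-google-genai are installed and calls the private _parse_chat_history helper exercised by the updated unit test below; the tool name, arguments, and ids are illustrative.

# Sketch only: a history with two tool calls in a single AIMessage, mirroring the new test.
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

from langchain_google_genai.chat_models import _parse_chat_history  # private helper patched here

messages = [
    HumanMessage(content="What is 2*2? And 2*3?"),
    AIMessage(
        content="",
        tool_calls=[
            {"name": "calculator", "args": {"arg1": "2", "arg2": "2", "op": "*"}, "id": "1"},
            {"name": "calculator", "args": {"arg1": "2", "arg2": "3", "op": "*"}, "id": "2"},
        ],
    ),
    ToolMessage(name="calculator", content='{"result": 4}', tool_call_id="1"),
    ToolMessage(name="calculator", content='{"result": 6}', tool_call_id="2"),
]

# After this fix, the AI turn becomes one model-role Content holding both FunctionCall
# parts, and the two tool results are merged into a single user-role Content holding
# two FunctionResponse parts, so both results travel with the AI turn that requested them.
system_instruction, history = _parse_chat_history(messages)
assert system_instruction is None
assert len(history) == 3  # user question, model tool calls, user tool responses
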
116 changes: 59 additions & 57 deletions libs/genai/langchain_google_genai/chat_models.py
@@ -301,6 +301,49 @@ def _convert_to_parts(
     return parts
 
 
+def _convert_tool_message_to_part(message: ToolMessage | FunctionMessage) -> Part:
+    """Converts a tool or function message to a google part."""
+    name = message.name
+    response: Any
+    if not isinstance(message.content, str):
+        response = message.content
+    else:
+        try:
+            response = json.loads(message.content)
+        except json.JSONDecodeError:
+            response = message.content  # leave as str representation
+    part = Part(
+        function_response=FunctionResponse(
+            name=name,
+            response=(
+                {"output": response} if not isinstance(response, dict) else response
+            ),
+        )
+    )
+    return part
+
+
+def _get_ai_message_tool_messages_parts(
+    tool_messages: Sequence[ToolMessage], ai_message: AIMessage
+) -> list[Part]:
+    """
+    Finds relevant tool messages for the AI message and converts them to a single
+    list of Parts.
+    """
+    # We are interested only in the tool messages that are part of the AI message
+    tool_calls_ids = [tool_call["id"] for tool_call in ai_message.tool_calls]
+    parts = []
+    for i, message in enumerate(tool_messages):
+        if not tool_calls_ids:
+            break
+        if message.tool_call_id in tool_calls_ids:
+            # remove the id from the list, so that we do not iterate over it again
+            tool_calls_ids.remove(message.tool_call_id)
+            part = _convert_tool_message_to_part(message)
+            parts.append(part)
+    return parts
+
+
 def _parse_chat_history(
     input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
 ) -> Tuple[Optional[Content], List[Content]]:
@@ -310,22 +353,34 @@ def _parse_chat_history(
         warnings.warn("Convert_system_message_to_human will be deprecated!")
 
     system_instruction: Optional[Content] = None
-    for i, message in enumerate(input_messages):
+    messages_without_tool_messages = [
+        message for message in input_messages if not isinstance(message, ToolMessage)
+    ]
+    tool_messages = [
+        message for message in input_messages if isinstance(message, ToolMessage)
+    ]
+    for i, message in enumerate(messages_without_tool_messages):
         if i == 0 and isinstance(message, SystemMessage):
             system_instruction = Content(parts=_convert_to_parts(message.content))
             continue
         elif isinstance(message, AIMessage):
            role = "model"
            if message.tool_calls:
-                parts = []
+                ai_message_parts = []
                 for tool_call in message.tool_calls:
                     function_call = FunctionCall(
                         {
                             "name": tool_call["name"],
                             "args": tool_call["args"],
                         }
                     )
-                    parts.append(Part(function_call=function_call))
+                    ai_message_parts.append(Part(function_call=function_call))
+                tool_messages_parts = _get_ai_message_tool_messages_parts(
+                    tool_messages=tool_messages, ai_message=message
+                )
+                messages.append(Content(role=role, parts=ai_message_parts))
+                messages.append(Content(role="user", parts=tool_messages_parts))
+                continue
             elif raw_function_call := message.additional_kwargs.get("function_call"):
                 function_call = FunctionCall(
                     {
@@ -344,60 +399,7 @@ def _parse_chat_history(
                 system_instruction = None
         elif isinstance(message, FunctionMessage):
             role = "user"
-            response: Any
-            if not isinstance(message.content, str):
-                response = message.content
-            else:
-                try:
-                    response = json.loads(message.content)
-                except json.JSONDecodeError:
-                    response = message.content  # leave as str representation
-            parts = [
-                Part(
-                    function_response=FunctionResponse(
-                        name=message.name,
-                        response=(
-                            {"output": response}
-                            if not isinstance(response, dict)
-                            else response
-                        ),
-                    )
-                )
-            ]
-        elif isinstance(message, ToolMessage):
-            role = "user"
-            prev_message: Optional[BaseMessage] = (
-                input_messages[i - 1] if i > 0 else None
-            )
-            if (
-                prev_message
-                and isinstance(prev_message, AIMessage)
-                and prev_message.tool_calls
-            ):
-                # message.name can be null for ToolMessage
-                name: str = prev_message.tool_calls[0]["name"]
-            else:
-                name = message.name  # type: ignore
-            tool_response: Any
-            if not isinstance(message.content, str):
-                tool_response = message.content
-            else:
-                try:
-                    tool_response = json.loads(message.content)
-                except json.JSONDecodeError:
-                    tool_response = message.content  # leave as str representation
-            parts = [
-                Part(
-                    function_response=FunctionResponse(
-                        name=name,
-                        response=(
-                            {"output": tool_response}
-                            if not isinstance(tool_response, dict)
-                            else tool_response
-                        ),
-                    )
-                )
-            ]
+            parts = [_convert_tool_message_to_part(message)]
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
81 changes: 72 additions & 9 deletions libs/genai/tests/unit_tests/test_chat_models.py
@@ -133,26 +133,37 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
     function_name = "calculator"
     function_call_1 = {
         "name": function_name,
-        "arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "+"}),
+        "args": {"arg1": "2", "arg2": "2", "op": "+"},
+        "id": "0",
     }
     function_answer1 = json.dumps({"result": 4})
     function_call_2 = {
         "name": function_name,
         "arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "*"}),
     }
     function_answer2 = json.dumps({"result": 4})
+    function_call_3 = {
+        "name": function_name,
+        "args": {"arg1": "2", "arg2": "2", "op": "*"},
+        "id": "1",
+    }
+    function_answer_3 = json.dumps({"result": 4})
+    function_call_4 = {
+        "name": function_name,
+        "args": {"arg1": "2", "arg2": "3", "op": "*"},
+        "id": "2",
+    }
+    function_answer_4 = json.dumps({"result": 6})
     text_answer1 = "They are same"
 
     system_message = SystemMessage(content=system_input)
     message1 = HumanMessage(content=text_question1)
     message2 = AIMessage(
         content="",
-        additional_kwargs={
-            "function_call": function_call_1,
-        },
+        tool_calls=[function_call_1],
     )
     message3 = ToolMessage(
-        name="calculator", content=function_answer1, tool_call_id="1"
+        name="calculator", content=function_answer1, tool_call_id="0"
     )
     message4 = AIMessage(
         content="",
@@ -161,7 +172,14 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
         },
     )
     message5 = FunctionMessage(name="calculator", content=function_answer2)
-    message6 = AIMessage(content=text_answer1)
+    message6 = AIMessage(content="", tool_calls=[function_call_3, function_call_4])
+    message7 = ToolMessage(
+        name="calculator", content=function_answer_3, tool_call_id="1"
+    )
+    message8 = ToolMessage(
+        name="calculator", content=function_answer_4, tool_call_id="2"
+    )
+    message9 = AIMessage(content=text_answer1)
     messages = [
         system_message,
         message1,
@@ -170,11 +188,14 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
         message4,
         message5,
         message6,
+        message7,
+        message8,
+        message9,
     ]
     system_instruction, history = _parse_chat_history(
         messages, convert_system_message_to_human=convert_system_message_to_human
     )
-    assert len(history) == 6
+    assert len(history) == 8
     if convert_system_message_to_human:
         assert history[0] == glm.Content(
             role="user",
@@ -191,7 +212,7 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
                 function_call=glm.FunctionCall(
                     {
                         "name": "calculator",
-                        "args": json.loads(function_call_1["arguments"]),
+                        "args": function_call_1["args"],
                     }
                 )
             )
@@ -236,7 +257,49 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
             )
         ],
     )
-    assert history[5] == glm.Content(role="model", parts=[glm.Part(text=text_answer1)])
+    assert history[5] == glm.Content(
+        role="model",
+        parts=[
+            glm.Part(
+                function_call=glm.FunctionCall(
+                    {
+                        "name": "calculator",
+                        "args": function_call_3["args"],
+                    }
+                )
+            ),
+            glm.Part(
+                function_call=glm.FunctionCall(
+                    {
+                        "name": "calculator",
+                        "args": function_call_4["args"],
+                    }
+                )
+            ),
+        ],
+    )
+    assert history[6] == glm.Content(
+        role="user",
+        parts=[
+            glm.Part(
+                function_response=glm.FunctionResponse(
+                    {
+                        "name": "calculator",
+                        "response": {"result": 4},
+                    }
+                )
+            ),
+            glm.Part(
+                function_response=glm.FunctionResponse(
+                    {
+                        "name": "calculator",
+                        "response": {"result": 6},
+                    }
+                )
+            ),
+        ],
+    )
+    assert history[7] == glm.Content(role="model", parts=[glm.Part(text=text_answer1)])
     if convert_system_message_to_human:
         assert system_instruction is None
     else:
