Skip to content

Commit

Permalink
Adjust test
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrito87 committed Jan 3, 2025
1 parent 5dc8939 commit e40ac90
Showing 1 changed file with 72 additions and 9 deletions.
81 changes: 72 additions & 9 deletions libs/genai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,26 +133,37 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
function_name = "calculator"
function_call_1 = {
"name": function_name,
"arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "+"}),
"args": {"arg1": "2", "arg2": "2", "op": "+"},
"id": "0",
}
function_answer1 = json.dumps({"result": 4})
function_call_2 = {
"name": function_name,
"arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "*"}),
}
function_answer2 = json.dumps({"result": 4})
function_call_3 = {
"name": function_name,
"args": {"arg1": "2", "arg2": "2", "op": "*"},
"id": "1",
}
function_answer_3 = json.dumps({"result": 4})
function_call_4 = {
"name": function_name,
"args": {"arg1": "2", "arg2": "3", "op": "*"},
"id": "2",
}
function_answer_4 = json.dumps({"result": 6})
text_answer1 = "They are same"

system_message = SystemMessage(content=system_input)
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(
content="",
additional_kwargs={
"function_call": function_call_1,
},
tool_calls=[function_call_1],
)
message3 = ToolMessage(
name="calculator", content=function_answer1, tool_call_id="1"
name="calculator", content=function_answer1, tool_call_id="0"
)
message4 = AIMessage(
content="",
Expand All @@ -161,7 +172,14 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
},
)
message5 = FunctionMessage(name="calculator", content=function_answer2)
message6 = AIMessage(content=text_answer1)
message6 = AIMessage(content="", tool_calls=[function_call_3, function_call_4])
message7 = ToolMessage(
name="calculator", content=function_answer_3, tool_call_id="1"
)
message8 = ToolMessage(
name="calculator", content=function_answer_4, tool_call_id="2"
)
message9 = AIMessage(content=text_answer1)
messages = [
system_message,
message1,
Expand All @@ -170,11 +188,14 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
message4,
message5,
message6,
message7,
message8,
message9,
]
system_instruction, history = _parse_chat_history(
messages, convert_system_message_to_human=convert_system_message_to_human
)
assert len(history) == 6
assert len(history) == 8
if convert_system_message_to_human:
assert history[0] == glm.Content(
role="user",
Expand All @@ -191,7 +212,7 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
function_call=glm.FunctionCall(
{
"name": "calculator",
"args": json.loads(function_call_1["arguments"]),
"args": function_call_1["args"],
}
)
)
Expand Down Expand Up @@ -236,7 +257,49 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
)
],
)
assert history[5] == glm.Content(role="model", parts=[glm.Part(text=text_answer1)])
assert history[5] == glm.Content(
role="model",
parts=[
glm.Part(
function_call=glm.FunctionCall(
{
"name": "calculator",
"args": function_call_3["args"],
}
)
),
glm.Part(
function_call=glm.FunctionCall(
{
"name": "calculator",
"args": function_call_4["args"],
}
)
),
],
)
assert history[6] == glm.Content(
role="user",
parts=[
glm.Part(
function_response=glm.FunctionResponse(
{
"name": "calculator",
"response": {"result": 4},
}
)
),
glm.Part(
function_response=glm.FunctionResponse(
{
"name": "calculator",
"response": {"result": 6},
}
)
),
],
)
assert history[7] == glm.Content(role="model", parts=[glm.Part(text=text_answer1)])
if convert_system_message_to_human:
assert system_instruction is None
else:
Expand Down

0 comments on commit e40ac90

Please sign in to comment.