From 625118d87c67ac9d594f9e1bd160d0727ca1abe5 Mon Sep 17 00:00:00 2001 From: pedrito87 Date: Wed, 8 Jan 2025 12:08:23 +0100 Subject: [PATCH] Fix integration test with multiple parts --- .../integration_tests/test_chat_models.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/libs/genai/tests/integration_tests/test_chat_models.py b/libs/genai/tests/integration_tests/test_chat_models.py index 3a15b9d4..b93e4411 100644 --- a/libs/genai/tests/integration_tests/test_chat_models.py +++ b/libs/genai/tests/integration_tests/test_chat_models.py @@ -299,7 +299,6 @@ def test_safety_settings_gemini() -> None: assert len(out2.content) > 0 -@pytest.mark.xfail(reason="on the model's side") def test_chat_function_calling_with_multiple_parts() -> None: @tool def search( @@ -342,19 +341,22 @@ def search( assert len(response.tool_calls) > 0 tool_call = response.tool_calls[0] assert tool_call["name"] == "search" + tool_messages = [] + for tool_call in response.tool_calls: + tool_response = search.run(tool_call["args"]) + tool_message = ToolMessage( + name="search", + content=json.dumps(tool_response), + tool_call_id=tool_call["id"], + ) + tool_messages.append(tool_message) + assert len(tool_messages) > 0 + assert len(response.tool_calls) == len(tool_messages) - tool_response = search("sparrow") - tool_message = ToolMessage( - name="search", - content=json.dumps(tool_response), - tool_call_id="0", - ) - - result = llm_with_search.invoke([request, response, tool_message]) + result = llm_with_search.invoke([request, response, *tool_messages]) assert isinstance(result, AIMessage) assert "brown" in result.content - assert len(result.tool_calls) > 0 def _check_tool_calls(response: BaseMessage, expected_name: str) -> None: