From 06fc5a38ecfc25241285ba43c42715447e684ccc Mon Sep 17 00:00:00 2001 From: Philip Nuzhnyi Date: Sat, 31 Aug 2024 16:06:44 +0100 Subject: [PATCH] See if refusal attribute exists in ChatCompletionMessage before referencing it (#962) Co-authored-by: Ivan Leo --- instructor/function_calls.py | 9 ++- tests/test_function_calls.py | 108 ++++++++++++++++++++++++++++++++++- 2 files changed, 112 insertions(+), 5 deletions(-) diff --git a/instructor/function_calls.py b/instructor/function_calls.py index e0bff5bbc..9e9600b27 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -309,12 +309,15 @@ def parse_tools( strict: Optional[bool] = None, ) -> BaseModel: message = completion.choices[0].message + # this field seems to be missing when using instructor with some other tools (e.g. litellm) + # trying to fix this by adding a check + if hasattr(message, "refusal"): + assert ( + message.refusal is None + ), f"Unable to generate a response due to {message.refusal}" assert ( len(message.tool_calls or []) == 1 ), "Instructor does not support multiple tool calls, use List[Model] instead." - assert ( - message.refusal is None - ), f"Unable to generate a response due to {message.refusal}" tool_call = message.tool_calls[0] # type: ignore assert ( tool_call.function.name == cls.openai_schema["name"] # type: ignore[index] diff --git a/tests/test_function_calls.py b/tests/test_function_calls.py index 842b876c2..9c8299bee 100644 --- a/tests/test_function_calls.py +++ b/tests/test_function_calls.py @@ -2,7 +2,9 @@ import pytest from anthropic.types import Message, Usage -from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function from pydantic import BaseModel, ValidationError import instructor @@ -206,4 +208,106 @@ class Model(BaseModel): def test_mode_functions_deprecation_warning() -> None: from openai import OpenAI with pytest.warns(DeprecationWarning, match="The FUNCTIONS mode is deprecated and will be removed in future versions"): - _ = instructor.from_openai(OpenAI(), mode=instructor.Mode.FUNCTIONS) \ No newline at end of file + _ = instructor.from_openai(OpenAI(), mode=instructor.Mode.FUNCTIONS) + +def test_refusal_attribute(test_model: type[OpenAISchema]): + completion = ChatCompletion( + id="test_id", + created=1234567890, + model="gpt-3.5-turbo", + object="chat.completion", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + content="test_content", + refusal="test_refusal", + role="assistant", + tool_calls=[], + ), + finish_reason="stop", + logprobs=None, + ) + ], + ) + + try: + + test_model.from_response(completion, mode=instructor.Mode.TOOLS) + except Exception as e: + assert "Unable to generate a response due to test_refusal" in str(e) + + +def test_no_refusal_attribute(test_model: type[OpenAISchema]): + completion = ChatCompletion( + id="test_id", + created=1234567890, + model="gpt-3.5-turbo", + object="chat.completion", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + content="test_content", + refusal=None, + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCall( + id="test_id", + function=Function( + name="TestModel", + arguments='{"data": "test_data", "name": "TestModel"}', + ), + type="function", + ) + ], + ), + finish_reason="stop", + logprobs=None, + ) + ], + ) + + resp = test_model.from_response(completion, mode=instructor.Mode.TOOLS) + assert resp.data == "test_data" + assert resp.name == "TestModel" + + +def test_missing_refusal_attribute(test_model: type[OpenAISchema]): + message_without_refusal_attribute = ChatCompletionMessage( + content="test_content", + refusal="test_refusal", + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCall( + id="test_id", + function=Function( + name="TestModel", + arguments='{"data": "test_data", "name": "TestModel"}', + ), + type="function", + ) + ], + ) + + del message_without_refusal_attribute.refusal + assert not hasattr(message_without_refusal_attribute, "refusal") + + completion = ChatCompletion( + id="test_id", + created=1234567890, + model="gpt-3.5-turbo", + object="chat.completion", + choices=[ + Choice( + index=0, + message=message_without_refusal_attribute, + finish_reason="stop", + logprobs=None, + ) + ], + ) + + resp = test_model.from_response(completion, mode=instructor.Mode.TOOLS) + assert resp.data == "test_data" + assert resp.name == "TestModel"