From 4766d90ffd7f96eb919d7c651a7ec9e92eb56444 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 25 Aug 2025 14:54:54 +0200 Subject: [PATCH 01/28] Update LLMInterface to restore LC compatibility --- examples/customize/llms/openai_llm.py | 22 ++++++- src/neo4j_graphrag/llm/base.py | 39 +++++++++---- src/neo4j_graphrag/llm/openai_llm.py | 83 ++++++++++----------------- src/neo4j_graphrag/llm/utils.py | 49 ++++++++++++++++ src/neo4j_graphrag/message_history.py | 3 + 5 files changed, 130 insertions(+), 66 deletions(-) create mode 100644 src/neo4j_graphrag/llm/utils.py diff --git a/examples/customize/llms/openai_llm.py b/examples/customize/llms/openai_llm.py index d4b38244e..501ccdb53 100644 --- a/examples/customize/llms/openai_llm.py +++ b/examples/customize/llms/openai_llm.py @@ -1,8 +1,28 @@ from neo4j_graphrag.llm import LLMResponse, OpenAILLM +from neo4j_graphrag.message_history import InMemoryMessageHistory +from neo4j_graphrag.types import LLMMessage # set api key here on in the OPENAI_API_KEY env var api_key = None +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + + llm = OpenAILLM(model_name="gpt-4o", api_key=api_key) -res: LLMResponse = llm.invoke("say something") +res: LLMResponse = llm.invoke( + # "say something", + # messages, + InMemoryMessageHistory( + messages=messages, + ) +) print(res.content) diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index ff7af1c70..28e284e95 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -27,7 +27,8 @@ from neo4j_graphrag.tool import Tool -from neo4j_graphrag.utils.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler +from .utils import legacy_inputs_to_message_history class LLMInterface(ABC): @@ -55,20 +56,27 @@ def __init__( else: self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER - @abstractmethod + @rate_limit_handler def invoke( self, - input: str, + input: Union[str, List[LLMMessage], MessageHistory], message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, + ) -> LLMResponse: + message_history = legacy_inputs_to_message_history( + input, message_history, system_instruction + ) + return self._invoke(message_history.messages) + + @abstractmethod + def _invoke( + self, + input: list[LLMMessage], ) -> LLMResponse: """Sends a text input to the LLM and retrieves a response. Args: - input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. + input (MessageHistory): Text sent to the LLM. Returns: LLMResponse: The response from the LLM. @@ -77,20 +85,27 @@ def invoke( LLMGenerationError: If anything goes wrong. 
""" - @abstractmethod + @async_rate_limit_handler async def ainvoke( self, - input: str, + input: Union[str, List[LLMMessage], MessageHistory], message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, + ) -> LLMResponse: + message_history = legacy_inputs_to_message_history( + input, message_history, system_instruction + ) + return await self._ainvoke(message_history.messages) + + @abstractmethod + async def _ainvoke( + self, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends a text input to the LLM and retrieves a response. Args: input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index afdf0234d..afe2fa3c4 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -28,26 +28,15 @@ cast, ) -from pydantic import ValidationError - from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage from ..exceptions import LLMGenerationError from .base import LLMInterface -from neo4j_graphrag.utils.rate_limit import ( - RateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, -) from .types import ( - BaseMessage, LLMResponse, - MessageList, ToolCall, ToolCallResponse, - SystemMessage, - UserMessage, ) from neo4j_graphrag.tool import Tool @@ -55,14 +44,17 @@ if TYPE_CHECKING: from openai.types.chat import ( ChatCompletionMessageParam, - ChatCompletionToolParam, - ) + ChatCompletionToolParam, ChatCompletionUserMessageParam, + ChatCompletionSystemMessageParam, ChatCompletionAssistantMessageParam, +) from openai import OpenAI, AsyncOpenAI + from neo4j_graphrag.utiles.rate_limit import RateLimitHandler else: ChatCompletionMessageParam = Any ChatCompletionToolParam = Any OpenAI = Any AsyncOpenAI = Any + RateLimitHandler = Any class BaseOpenAILLM(LLMInterface, abc.ABC): @@ -97,23 +89,26 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + messages: list[LLMMessage], ) -> Iterable[ChatCompletionMessageParam]: - messages = [] - if system_instruction: - messages.append(SystemMessage(content=system_instruction).model_dump()) - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], message_history)) - except ValidationError as e: - raise LLMGenerationError(e.errors()) from e - messages.extend(cast(Iterable[dict[str, Any]], message_history)) - messages.append(UserMessage(content=input).model_dump()) - return messages # type: ignore + chat_messages = [] + for m in messages: + message_type: ChatCompletionMessageParam + if m["role"] == "system": + message_type = ChatCompletionSystemMessageParam + elif m["role"] == "user": + message_type = ChatCompletionUserMessageParam + elif m["role"] == "assistant": + message_type = ChatCompletionAssistantMessageParam + else: + raise ValueError(f"Unknown message type: {m['role']}") + chat_messages.append( + message_type( + role=m["role"], + content=m["content"], + ) + ) + return chat_messages def 
_convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]: """Convert a Tool object to OpenAI's expected format. @@ -136,21 +131,15 @@ def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]: except AttributeError: raise LLMGenerationError(f"Tool {tool} is not a valid Tool object") - @rate_limit_handler - def invoke( + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Sends a text input to the OpenAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from OpenAI. @@ -159,10 +148,8 @@ def invoke( LLMGenerationError: If anything goes wrong. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages response = self.client.chat.completions.create( - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), model=self.model_name, **self.model_params, ) @@ -171,7 +158,6 @@ def invoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - @rate_limit_handler def invoke_with_tools( self, input: str, @@ -246,21 +232,15 @@ def invoke_with_tools( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from OpenAI. @@ -269,10 +249,8 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. 
""" try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages response = await self.async_client.chat.completions.create( - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), model=self.model_name, **self.model_params, ) @@ -281,7 +259,6 @@ async def ainvoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - @async_rate_limit_handler async def ainvoke_with_tools( self, input: str, diff --git a/src/neo4j_graphrag/llm/utils.py b/src/neo4j_graphrag/llm/utils.py new file mode 100644 index 000000000..42126c939 --- /dev/null +++ b/src/neo4j_graphrag/llm/utils.py @@ -0,0 +1,49 @@ +import warnings +from typing import Union, Optional + +from neo4j_graphrag.message_history import MessageHistory, InMemoryMessageHistory +from neo4j_graphrag.types import LLMMessage + + +def legacy_inputs_to_message_history( + input: Union[str, list[LLMMessage], MessageHistory], + message_history: Optional[Union[list[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, +) -> MessageHistory: + if message_history: + warnings.warn( + "Using message_history parameter is deprecated and will be removed in 2.0. Use a list of inputs or a MessageHistory instead.", + DeprecationWarning, + ) + if isinstance(message_history, MessageHistory): + history = message_history + else: # list[LLMMessage] + history = InMemoryMessageHistory(message_history) + else: + history = InMemoryMessageHistory() + if system_instruction is not None: + warnings.warn( + "Using system_instruction parameter is deprecated and will be removed in 2.0. Use a list of inputs or a MessageHistory instead.", + DeprecationWarning, + ) + if history.is_empty(): + history.add_message( + LLMMessage( + role="system", + content=system_instruction, + ), + ) + else: + warnings.warn( + "system_instruction provided but ignored as the message history is not empty", + RuntimeWarning, + ) + if isinstance(input, str): + history.add_message(LLMMessage(role="user", content=input)) + return history + if isinstance(input, list): + history.add_messages(input) + return history + # input is a MessageHistory instance + history.add_messages(input.messages) + return history diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index 59ba033d9..f4df4576f 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -74,6 +74,9 @@ class MessageHistory(ABC): @abstractmethod def messages(self) -> List[LLMMessage]: ... + def is_empty(self) -> bool: + return len(self.messages) == 0 + @abstractmethod def add_message(self, message: LLMMessage) -> None: ... 
From 014af4ef78f38ccc83bf04e4479cd4155a21e643 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 25 Aug 2025 16:11:24 +0200 Subject: [PATCH 02/28] Update AnthropicLLM --- examples/customize/llms/anthropic_llm.py | 18 +++++- src/neo4j_graphrag/llm/anthropic_llm.py | 72 ++++++++---------------- 2 files changed, 42 insertions(+), 48 deletions(-) diff --git a/examples/customize/llms/anthropic_llm.py b/examples/customize/llms/anthropic_llm.py index 85c4ad03a..dbd3f56fd 100644 --- a/examples/customize/llms/anthropic_llm.py +++ b/examples/customize/llms/anthropic_llm.py @@ -1,12 +1,28 @@ from neo4j_graphrag.llm import AnthropicLLM, LLMResponse +from neo4j_graphrag.types import LLMMessage # set api key here on in the ANTHROPIC_API_KEY env var api_key = None +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + + llm = AnthropicLLM( model_name="claude-3-opus-20240229", model_params={"max_tokens": 1000}, # max_tokens must be specified api_key=api_key, ) -res: LLMResponse = llm.invoke("say something") +res: LLMResponse = llm.invoke( + # "say something", + messages, +) print(res.content) diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index 21560d3f2..bcd2f0e59 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -13,24 +13,17 @@ # limitations under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Iterable, Optional -from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( - BaseMessage, LLMResponse, - MessageList, - UserMessage, ) -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage if TYPE_CHECKING: @@ -84,46 +77,39 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - ) -> Iterable[MessageParam]: - messages: list[dict[str, str]] = [] - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], message_history)) - except ValidationError as e: - raise LLMGenerationError(e.errors()) from e - messages.extend(cast(Iterable[dict[str, Any]], message_history)) - messages.append(UserMessage(content=input).model_dump()) - return messages # type: ignore - - @rate_limit_handler - def invoke( + input: list[LLMMessage], + ) -> tuple[str, Iterable[MessageParam]]: + messages: list[MessageParam] = [] + system_instruction = self.anthropic.NOT_GIVEN + for i in input: + if i["role"] == "system": + system_instruction = i["content"] + else: + messages.append( + self.anthropic.types.MessageParam( + role=i["role"], # type: ignore + content=i["content"], + ) + ) + return system_instruction, messages + + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Sends text to the LLM and returns a response. 
Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history) + system_instruction, messages = self.get_messages(input) response = self.client.messages.create( model=self.model_name, - system=system_instruction or self.anthropic.NOT_GIVEN, + system=system_instruction, messages=messages, **self.model_params, ) @@ -136,31 +122,23 @@ def invoke( except self.anthropic.APIError as e: raise LLMGenerationError(e) - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history) + system_instruction, messages = self.get_messages(input) response = await self.async_client.messages.create( model=self.model_name, - system=system_instruction or self.anthropic.NOT_GIVEN, + system=system_instruction, messages=messages, **self.model_params, ) From eb9c91cd73648be78299508e75ce8ce7781eb3dd Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 26 Aug 2025 15:13:38 +0200 Subject: [PATCH 03/28] Update MistralAILLM --- examples/README.md | 2 +- examples/customize/llms/mistalai_llm.py | 10 --- examples/customize/llms/mistralai_llm.py | 32 ++++++++++ src/neo4j_graphrag/llm/mistralai_llm.py | 78 +++++++++--------------- 4 files changed, 62 insertions(+), 60 deletions(-) delete mode 100644 examples/customize/llms/mistalai_llm.py create mode 100644 examples/customize/llms/mistralai_llm.py diff --git a/examples/README.md b/examples/README.md index 774739b32..b8033358d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -69,7 +69,7 @@ are listed in [the last section of this file](#customize). 
- [OpenAI (GPT)](./customize/llms/openai_llm.py) - [Azure OpenAI]() - [VertexAI (Gemini)](./customize/llms/vertexai_llm.py) -- [MistralAI](./customize/llms/mistalai_llm.py) +- [MistralAI](customize/llms/mistralai_llm.py) - [Cohere](./customize/llms/cohere_llm.py) - [Anthropic (Claude)](./customize/llms/anthropic_llm.py) - [Ollama](./customize/llms/ollama_llm.py) diff --git a/examples/customize/llms/mistalai_llm.py b/examples/customize/llms/mistalai_llm.py deleted file mode 100644 index b829baad4..000000000 --- a/examples/customize/llms/mistalai_llm.py +++ /dev/null @@ -1,10 +0,0 @@ -from neo4j_graphrag.llm import MistralAILLM - -# set api key here on in the MISTRAL_API_KEY env var -api_key = None - -llm = MistralAILLM( - model_name="mistral-small-latest", - api_key=api_key, -) -llm.invoke("say something") diff --git a/examples/customize/llms/mistralai_llm.py b/examples/customize/llms/mistralai_llm.py new file mode 100644 index 000000000..66db280b1 --- /dev/null +++ b/examples/customize/llms/mistralai_llm.py @@ -0,0 +1,32 @@ +from neo4j_graphrag.llm import MistralAILLM, LLMResponse +from neo4j_graphrag.message_history import InMemoryMessageHistory +from neo4j_graphrag.types import LLMMessage + +# set api key here on in the MISTRAL_API_KEY env var +api_key = None + + +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + + +llm = MistralAILLM( + model_name="mistral-small-latest", + api_key=api_key, +) +res: LLMResponse = llm.invoke( + # "say something", + # messages, + InMemoryMessageHistory( + messages=messages, + ) +) +print(res.content) diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index 3fa8663ae..ba49ab23d 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -15,33 +15,31 @@ from __future__ import annotations import os -from typing import Any, Iterable, List, Optional, Union, cast - -from pydantic import ValidationError +from typing import Any, Optional from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( - BaseMessage, LLMResponse, - MessageList, - SystemMessage, - UserMessage, ) -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage try: - from mistralai import Messages, Mistral + from mistralai import ( + Messages, + UserMessage, + AssistantMessage, + SystemMessage, + Mistral, + ) from mistralai.models.sdkerror import SDKError except ImportError: Mistral = None # type: ignore SDKError = None # type: ignore + Messages = Any class MistralAILLM(LLMInterface): @@ -75,38 +73,30 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> list[Messages]: - messages = [] - if system_instruction: - messages.append(SystemMessage(content=system_instruction).model_dump()) - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], message_history)) - except ValidationError as e: - raise 
LLMGenerationError(e.errors()) from e - messages.extend(cast(Iterable[dict[str, Any]], message_history)) - messages.append(UserMessage(content=input).model_dump()) - return cast(list[Messages], messages) - - @rate_limit_handler - def invoke( + messages: list[Messages] = [] + for m in input: + if m["role"] == "system": + messages.append(SystemMessage(content=m["content"])) + continue + if m["role"] == "user": + messages.append(UserMessage(content=m["content"])) + continue + if m["role"] == "assistant": + messages.append(AssistantMessage(content=m["content"])) + continue + return messages + + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Sends a text input to the Mistral chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from MistralAI. @@ -115,9 +105,7 @@ def invoke( LLMGenerationError: If anything goes wrong. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history, system_instruction) + messages = self.get_messages(input) response = self.client.chat.complete( model=self.model_name, messages=messages, @@ -132,21 +120,15 @@ def invoke( except SDKError as e: raise LLMGenerationError(e) - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends a text input to the MistralAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from MistralAI. @@ -155,9 +137,7 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. 
""" try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history, system_instruction) + messages = self.get_messages(input) response = await self.client.chat.complete_async( model=self.model_name, messages=messages, From dbc20900e0b7c32b3144ee7cff9bb9aeeef3aeb1 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 26 Aug 2025 20:11:10 +0200 Subject: [PATCH 04/28] Update OllamaLLM --- examples/customize/llms/ollama_llm.py | 20 ++++++++- src/neo4j_graphrag/llm/ollama_llm.py | 64 ++++++--------------------- 2 files changed, 31 insertions(+), 53 deletions(-) diff --git a/examples/customize/llms/ollama_llm.py b/examples/customize/llms/ollama_llm.py index dc42f7466..42d0ddaab 100644 --- a/examples/customize/llms/ollama_llm.py +++ b/examples/customize/llms/ollama_llm.py @@ -3,11 +3,27 @@ """ from neo4j_graphrag.llm import LLMResponse, OllamaLLM +from neo4j_graphrag.types import LLMMessage + +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + + llm = OllamaLLM( - model_name="", + model_name="orca-mini:latest", # model_params={"options": {"temperature": 0}, "format": "json"}, # host="...", # if using a remote server ) -res: LLMResponse = llm.invoke("What is the additive color model?") +res: LLMResponse = llm.invoke( + messages, +) print(res.content) diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 94541e033..a0189c64b 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -15,26 +15,15 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast - -from pydantic import ValidationError +from typing import TYPE_CHECKING, Any, Optional, Sequence from neo4j_graphrag.exceptions import LLMGenerationError -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage from .base import LLMInterface -from neo4j_graphrag.utils.rate_limit import ( - RateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, -) +from neo4j_graphrag.utils.rate_limit import RateLimitHandler from .types import ( - BaseMessage, LLMResponse, - MessageList, - SystemMessage, - UserMessage, ) if TYPE_CHECKING: @@ -80,48 +69,29 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> Sequence[Message]: - messages = [] - if system_instruction: - messages.append(SystemMessage(content=system_instruction).model_dump()) - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], message_history)) - except ValidationError as e: - raise LLMGenerationError(e.errors()) from e - messages.extend(cast(Iterable[dict[str, Any]], message_history)) - messages.append(UserMessage(content=input).model_dump()) - return messages # type: ignore + return [ + self.ollama.Message(**i) + for i in input + ] - @rate_limit_handler - def invoke( + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], 
) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages response = self.client.chat( model=self.model_name, - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), **self.model_params, ) content = response.message.content or "" @@ -129,21 +99,15 @@ def invoke( except self.ollama.ResponseError as e: raise LLMGenerationError(e) - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from OpenAI. @@ -152,11 +116,9 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages response = await self.async_client.chat( model=self.model_name, - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), options=self.model_params, ) content = response.message.content or "" From dcae75d41a28bb826e340c6b1faedc8c49a8aa3f Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 27 Aug 2025 09:56:52 +0200 Subject: [PATCH 05/28] Update CohereLLM --- examples/customize/llms/cohere_llm.py | 14 +++++- src/neo4j_graphrag/llm/cohere_llm.py | 69 ++++++++------------------- 2 files changed, 33 insertions(+), 50 deletions(-) diff --git a/examples/customize/llms/cohere_llm.py b/examples/customize/llms/cohere_llm.py index d631d3e41..daa3926ef 100644 --- a/examples/customize/llms/cohere_llm.py +++ b/examples/customize/llms/cohere_llm.py @@ -1,11 +1,23 @@ from neo4j_graphrag.llm import CohereLLM, LLMResponse +from neo4j_graphrag.types import LLMMessage # set api key here on in the CO_API_KEY env var api_key = None +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + llm = CohereLLM( model_name="command-r", api_key=api_key, ) -res: LLMResponse = llm.invoke("say something") +res: LLMResponse = llm.invoke(input=messages) print(res.content) diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index 2e3ca0cea..fa6de2d7c 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -14,25 +14,16 @@ # limitations under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast - -from pydantic import ValidationError +from typing import TYPE_CHECKING, Any, Optional from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( - BaseMessage, LLMResponse, - MessageList, - SystemMessage, - UserMessage, ) -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage if TYPE_CHECKING: @@ -84,46 +75,34 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> ChatMessages: - messages = [] - if system_instruction: - messages.append(SystemMessage(content=system_instruction).model_dump()) - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], message_history)) - except ValidationError as e: - raise LLMGenerationError(e.errors()) from e - messages.extend(cast(Iterable[dict[str, Any]], message_history)) - messages.append(UserMessage(content=input).model_dump()) - return messages # type: ignore - - @rate_limit_handler - def invoke( + messages: ChatMessages = [] + for i in input: + if i["role"] == "system": + messages.append(self.cohere.SystemChatMessageV2(content=i["content"])) + if i["role"] == "user": + messages.append(self.cohere.UserChatMessageV2(content=i["content"])) + if i["role"] == "assistant": + messages.append( + self.cohere.AssistantChatMessageV2(content=i["content"]) + ) + return messages + + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history, system_instruction) + messages = self.get_messages(input) res = self.client.chat( messages=messages, model=self.model_name, @@ -134,28 +113,20 @@ def invoke( content=res.message.content[0].text if res.message.content else "", ) - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. 
""" try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history, system_instruction) + messages = self.get_messages(input) res = await self.async_client.chat( messages=messages, model=self.model_name, From 569523774c4a6326b4d927bdedd22ac6d53cdb1d Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 27 Aug 2025 13:34:12 +0200 Subject: [PATCH 06/28] Mypy / ruff --- examples/customize/llms/custom_llm.py | 25 ++++++------------------- examples/customize/llms/ollama_llm.py | 1 - src/neo4j_graphrag/llm/anthropic_llm.py | 10 +++++----- src/neo4j_graphrag/llm/mistralai_llm.py | 2 +- src/neo4j_graphrag/llm/ollama_llm.py | 5 +---- src/neo4j_graphrag/llm/openai_llm.py | 14 ++++++++------ 6 files changed, 21 insertions(+), 36 deletions(-) diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py index 86b3cb993..554629d4a 100644 --- a/examples/customize/llms/custom_llm.py +++ b/examples/customize/llms/custom_llm.py @@ -1,6 +1,6 @@ import random import string -from typing import Any, Awaitable, Callable, List, Optional, TypeVar, Union +from typing import Any, Awaitable, Callable, Optional, TypeVar from neo4j_graphrag.llm import LLMInterface, LLMResponse from neo4j_graphrag.utils.rate_limit import ( @@ -8,7 +8,6 @@ # rate_limit_handler, # async_rate_limit_handler, ) -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage @@ -18,38 +17,26 @@ def __init__( ): super().__init__(model_name, **kwargs) - # Optional: Apply rate limit handling to synchronous invoke method - # @rate_limit_handler - def invoke( + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: content: str = ( self.model_name + ": " + "".join(random.choices(string.ascii_letters, k=30)) ) return LLMResponse(content=content) - # Optional: Apply rate limit handling to asynchronous ainvoke method - # @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: raise NotImplementedError() -llm = CustomLLM( - "" -) # if rate_limit_handler and async_rate_limit_handler decorators are used, the default rate limit handler will be applied automatically (retry with exponential backoff) +llm = CustomLLM("") res: LLMResponse = llm.invoke("text") print(res.content) -# If rate_limit_handler and async_rate_limit_handler decorators are used and you want to use a custom rate limit handler -# Type variables for function signatures used in rate limit handlers F = TypeVar("F", bound=Callable[..., Any]) AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) diff --git a/examples/customize/llms/ollama_llm.py b/examples/customize/llms/ollama_llm.py index 42d0ddaab..37dd1dbec 100644 --- a/examples/customize/llms/ollama_llm.py +++ b/examples/customize/llms/ollama_llm.py @@ -17,7 +17,6 @@ ] - llm = OllamaLLM( model_name="orca-mini:latest", # model_params={"options": {"temperature": 0}, "format": "json"}, diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index bcd2f0e59..8df1b464d 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -13,8 +13,7 @@ # limitations under the 
License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, Optional - +from typing import TYPE_CHECKING, Any, Iterable, Optional, Union from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface @@ -28,6 +27,7 @@ if TYPE_CHECKING: from anthropic.types.message_param import MessageParam + from anthropic import NotGiven class AnthropicLLM(LLMInterface): @@ -78,16 +78,16 @@ def __init__( def get_messages( self, input: list[LLMMessage], - ) -> tuple[str, Iterable[MessageParam]]: + ) -> tuple[Union[str, NotGiven], Iterable[MessageParam]]: messages: list[MessageParam] = [] - system_instruction = self.anthropic.NOT_GIVEN + system_instruction: Union[str, NotGiven] = self.anthropic.NOT_GIVEN for i in input: if i["role"] == "system": system_instruction = i["content"] else: messages.append( self.anthropic.types.MessageParam( - role=i["role"], # type: ignore + role=i["role"], content=i["content"], ) ) diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index ba49ab23d..591863aa6 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -39,7 +39,7 @@ except ImportError: Mistral = None # type: ignore SDKError = None # type: ignore - Messages = Any + Messages = None # type: ignore class MistralAILLM(LLMInterface): diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index a0189c64b..512db928d 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -71,10 +71,7 @@ def get_messages( self, input: list[LLMMessage], ) -> Sequence[Message]: - return [ - self.ollama.Message(**i) - for i in input - ] + return [self.ollama.Message(**i) for i in input] def _invoke( self, diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index afe2fa3c4..03b101779 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -25,7 +25,7 @@ Iterable, Sequence, Union, - cast, + cast, Type, ) from neo4j_graphrag.message_history import MessageHistory @@ -44,9 +44,11 @@ if TYPE_CHECKING: from openai.types.chat import ( ChatCompletionMessageParam, - ChatCompletionToolParam, ChatCompletionUserMessageParam, - ChatCompletionSystemMessageParam, ChatCompletionAssistantMessageParam, -) + ChatCompletionToolParam, + ChatCompletionUserMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionAssistantMessageParam, + ) from openai import OpenAI, AsyncOpenAI from neo4j_graphrag.utiles.rate_limit import RateLimitHandler else: @@ -93,7 +95,7 @@ def get_messages( ) -> Iterable[ChatCompletionMessageParam]: chat_messages = [] for m in messages: - message_type: ChatCompletionMessageParam + message_type: Type[ChatCompletionMessageParam] if m["role"] == "system": message_type = ChatCompletionSystemMessageParam elif m["role"] == "user": @@ -104,7 +106,7 @@ def get_messages( raise ValueError(f"Unknown message type: {m['role']}") chat_messages.append( message_type( - role=m["role"], + role=m["role"], # type: ignore content=m["content"], ) ) From 422fc117d3402f841538e3b350cf58831a02a389 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 28 Aug 2025 18:50:40 +0200 Subject: [PATCH 07/28] Update VertexAILLM --- examples/customize/llms/vertexai_llm.py | 17 ++++- src/neo4j_graphrag/llm/vertexai_llm.py | 90 +++++++++---------------- 2 files changed, 47 insertions(+), 60 deletions(-) diff --git a/examples/customize/llms/vertexai_llm.py 
b/examples/customize/llms/vertexai_llm.py index f43864935..34fc179ae 100644 --- a/examples/customize/llms/vertexai_llm.py +++ b/examples/customize/llms/vertexai_llm.py @@ -1,6 +1,20 @@ from neo4j_graphrag.llm import LLMResponse, VertexAILLM from vertexai.generative_models import GenerationConfig +from neo4j_graphrag.types import LLMMessage + +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + + generation_config = GenerationConfig(temperature=1.0) llm = VertexAILLM( model_name="gemini-2.0-flash-001", @@ -9,7 +23,6 @@ # vertexai.generative_models.GenerativeModel client ) res: LLMResponse = llm.invoke( - "say something", - system_instruction="You are living in 3000 where AI rules the world", + input=messages, ) print(res.content) diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index b9f1e40e8..ff3f2c70e 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -13,21 +13,16 @@ # limitations under the License. from __future__ import annotations -from typing import Any, List, Optional, Union, cast, Sequence +from typing import Any, List, Optional, Union, Sequence -from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( - BaseMessage, LLMResponse, - MessageList, ToolCall, ToolCallResponse, ) @@ -98,92 +93,73 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - ) -> list[Content]: + input: list[LLMMessage], + ) -> tuple[str | None, list[Content]]: messages = [] - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], message_history)) - except ValidationError as e: - raise LLMGenerationError(e.errors()) from e - - for message in message_history: - if message.get("role") == "user": - messages.append( - Content( - role="user", - parts=[Part.from_text(message.get("content", ""))], - ) + system_instruction = self.system_instruction + for message in input: + if message.get("role") == "system": + system_instruction = message.get("content") + continue + if message.get("role") == "user": + messages.append( + Content( + role="user", + parts=[Part.from_text(message.get("content", ""))], ) - elif message.get("role") == "assistant": - messages.append( - Content( - role="model", - parts=[Part.from_text(message.get("content", ""))], - ) + ) + continue + if message.get("role") == "assistant": + messages.append( + Content( + role="model", + parts=[Part.from_text(message.get("content", ""))], ) + ) + continue + return system_instruction, messages - messages.append(Content(role="user", parts=[Part.from_text(input)])) - return messages - - @rate_limit_handler - def invoke( + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. 
- message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ + system_instruction, messages = self.get_messages(input) model = self._get_model( system_instruction=system_instruction, ) try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - options = self._get_call_params(input, message_history, tools=None) + options = self._get_call_params(messages, tools=None) response = model.generate_content(**options) return self._parse_content_response(response) except ResponseValidationError as e: raise LLMGenerationError("Error calling VertexAILLM") from e - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages + system_instruction, messages = self.get_messages(input) model = self._get_model( system_instruction=system_instruction, ) - options = self._get_call_params(input, message_history, tools=None) + options = self._get_call_params(messages, tools=None) response = await model.generate_content_async(**options) return self._parse_content_response(response) except ResponseValidationError as e: @@ -222,8 +198,7 @@ def _get_model( def _get_call_params( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]], + messages: list[Content], tools: Optional[Sequence[Tool]], ) -> dict[str, Any]: options = dict(self.options) @@ -241,7 +216,6 @@ def _get_call_params( # no tools, remove tool_config if defined options.pop("tool_config", None) - messages = self.get_messages(input, message_history) options["contents"] = messages return options From b524b1ae2297f3360abd4974a5b8c171193476b1 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 28 Aug 2025 19:03:21 +0200 Subject: [PATCH 08/28] Update (a)invoke_with_tools methods in the same way --- src/neo4j_graphrag/llm/base.py | 17 ++++++++++ src/neo4j_graphrag/llm/openai_llm.py | 46 ++++++++------------------ src/neo4j_graphrag/llm/vertexai_llm.py | 43 +++++++++--------------- 3 files changed, 46 insertions(+), 60 deletions(-) diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 28e284e95..c0c36de51 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -114,6 +114,7 @@ async def _ainvoke( LLMGenerationError: If anything goes wrong. """ + @rate_limit_handler def invoke_with_tools( self, input: str, @@ -139,6 +140,14 @@ def invoke_with_tools( LLMGenerationError: If anything goes wrong. NotImplementedError: If the LLM provider does not support tool calling. 
""" + history = legacy_inputs_to_message_history( + input, message_history, system_instruction + ) + return self._invoke_with_tools(history.messages, tools) + + def _invoke_with_tools( + self, inputs: list[LLMMessage], tools: Sequence[Tool] + ) -> ToolCallResponse: raise NotImplementedError("This LLM provider does not support tool calling.") async def ainvoke_with_tools( @@ -166,4 +175,12 @@ async def ainvoke_with_tools( LLMGenerationError: If anything goes wrong. NotImplementedError: If the LLM provider does not support tool calling. """ + history = legacy_inputs_to_message_history( + input, message_history, system_instruction + ) + return await self._ainvoke_with_tools(history.messages, tools) + + async def _ainvoke_with_tools( + self, inputs: list[LLMMessage], tools: Sequence[Tool] + ) -> ToolCallResponse: raise NotImplementedError("This LLM provider does not support tool calling.") diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 03b101779..7d670d6a8 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -24,11 +24,10 @@ Optional, Iterable, Sequence, - Union, - cast, Type, + cast, + Type, ) -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage from ..exceptions import LLMGenerationError @@ -45,9 +44,6 @@ from openai.types.chat import ( ChatCompletionMessageParam, ChatCompletionToolParam, - ChatCompletionUserMessageParam, - ChatCompletionSystemMessageParam, - ChatCompletionAssistantMessageParam, ) from openai import OpenAI, AsyncOpenAI from neo4j_graphrag.utiles.rate_limit import RateLimitHandler @@ -97,11 +93,13 @@ def get_messages( for m in messages: message_type: Type[ChatCompletionMessageParam] if m["role"] == "system": - message_type = ChatCompletionSystemMessageParam + message_type = self.openai.types.chat.ChatCompletionSystemMessageParam elif m["role"] == "user": - message_type = ChatCompletionUserMessageParam + message_type = self.openai.types.chat.ChatCompletionUserMessageParam elif m["role"] == "assistant": - message_type = ChatCompletionAssistantMessageParam + message_type = ( + self.openai.types.chat.ChatCompletionAssistantMessageParam + ) else: raise ValueError(f"Unknown message type: {m['role']}") chat_messages.append( @@ -160,12 +158,10 @@ def _invoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - def invoke_with_tools( + def _invoke_with_tools( self, - input: str, - tools: Sequence[Tool], # Tools definition as a sequence of Tool objects - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], + tools: Sequence[Tool], ) -> ToolCallResponse: """Sends a text input to the OpenAI chat completion model with tool definitions and retrieves a tool call response. @@ -173,9 +169,6 @@ def invoke_with_tools( Args: input (str): Text sent to the LLM. tools (List[Tool]): List of Tools for the LLM to choose from. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: ToolCallResponse: The response from the LLM containing a tool call. @@ -184,9 +177,6 @@ def invoke_with_tools( LLMGenerationError: If anything goes wrong. 
""" try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - params = self.model_params.copy() if self.model_params else {} if "temperature" not in params: params["temperature"] = 0.0 @@ -198,7 +188,7 @@ def invoke_with_tools( openai_tools.append(cast(ChatCompletionToolParam, openai_format_tool)) response = self.client.chat.completions.create( - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), model=self.model_name, tools=openai_tools, tool_choice="auto", @@ -261,12 +251,10 @@ async def _ainvoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - async def ainvoke_with_tools( + async def _ainvoke_with_tools( self, - input: str, + input: list[LLMMessage], tools: Sequence[Tool], # Tools definition as a sequence of Tool objects - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, ) -> ToolCallResponse: """Asynchronously sends a text input to the OpenAI chat completion model with tool definitions and retrieves a tool call response. @@ -274,9 +262,6 @@ async def ainvoke_with_tools( Args: input (str): Text sent to the LLM. tools (List[Tool]): List of Tools for the LLM to choose from. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: ToolCallResponse: The response from the LLM containing a tool call. @@ -285,9 +270,6 @@ async def ainvoke_with_tools( LLMGenerationError: If anything goes wrong. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - params = self.model_params.copy() if "temperature" not in params: params["temperature"] = 0.0 @@ -299,7 +281,7 @@ async def ainvoke_with_tools( openai_tools.append(cast(ChatCompletionToolParam, openai_format_tool)) response = await self.async_client.chat.completions.create( - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), model=self.model_name, tools=openai_tools, tool_choice="auto", diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index ff3f2c70e..3b2e9ca87 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -13,7 +13,7 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, List, Optional, Union, Sequence +from typing import Any, Optional, Sequence from neo4j_graphrag.exceptions import LLMGenerationError @@ -26,7 +26,6 @@ ToolCall, ToolCallResponse, ) -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.tool import Tool from neo4j_graphrag.types import LLMMessage @@ -189,7 +188,6 @@ def _get_model( self, system_instruction: Optional[str] = None, ) -> GenerativeModel: - # system_message = [system_instruction] if system_instruction is not None else [] model = GenerativeModel( model_name=self.model_name, system_instruction=system_instruction, @@ -198,7 +196,7 @@ def _get_model( def _get_call_params( self, - messages: list[Content], + contents: list[Content], tools: Optional[Sequence[Tool]], ) -> dict[str, Any]: options = dict(self.options) @@ -215,31 +213,28 @@ def _get_call_params( else: # no tools, remove tool_config if defined options.pop("tool_config", None) - - options["contents"] = messages + options["contents"] = contents return options async def _acall_llm( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], tools: Optional[Sequence[Tool]] = None, ) -> GenerationResponse: - model = self._get_model(system_instruction=system_instruction) - options = self._get_call_params(input, message_history, tools) + system_instruction, contents = self.get_messages(input) + model = self._get_model(system_instruction) + options = self._get_call_params(contents, tools) response = await model.generate_content_async(**options) return response # type: ignore[no-any-return] def _call_llm( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], tools: Optional[Sequence[Tool]] = None, ) -> GenerationResponse: - model = self._get_model(system_instruction=system_instruction) - options = self._get_call_params(input, message_history, tools) + system_instruction, contents = self.get_messages(input) + model = self._get_model(system_instruction) + options = self._get_call_params(contents, tools) response = model.generate_content(**options) return response # type: ignore[no-any-return] @@ -261,32 +256,24 @@ def _parse_content_response(self, response: GenerationResponse) -> LLMResponse: content=response.text, ) - async def ainvoke_with_tools( + async def _ainvoke_with_tools( self, - input: str, + input: list[LLMMessage], tools: Sequence[Tool], - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, ) -> ToolCallResponse: response = await self._acall_llm( input, - message_history=message_history, - system_instruction=system_instruction, tools=tools, ) return self._parse_tool_response(response) - def invoke_with_tools( + def _invoke_with_tools( self, - input: str, + input: list[LLMMessage], tools: Sequence[Tool], - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, ) -> ToolCallResponse: response = self._call_llm( input, - message_history=message_history, - system_instruction=system_instruction, tools=tools, ) return self._parse_tool_response(response) From 0aa7ceef1e4d7e62d9835b1c34322053a28ff120 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 28 Aug 2025 19:21:53 +0200 Subject: [PATCH 09/28] Rename method and return directly list[LLMMessage] --- 
src/neo4j_graphrag/llm/base.py | 29 ++++++++------------- src/neo4j_graphrag/llm/utils.py | 46 +++++++++++++++++++-------------- 2 files changed, 38 insertions(+), 37 deletions(-) diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index c0c36de51..f9993afc6 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -23,12 +23,13 @@ from .types import LLMResponse, ToolCallResponse from neo4j_graphrag.utils.rate_limit import ( DEFAULT_RATE_LIMIT_HANDLER, + rate_limit_handler, + async_rate_limit_handler, ) from neo4j_graphrag.tool import Tool -from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler -from .utils import legacy_inputs_to_message_history +from .utils import legacy_inputs_to_messages class LLMInterface(ABC): @@ -63,10 +64,8 @@ def invoke( message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: - message_history = legacy_inputs_to_message_history( - input, message_history, system_instruction - ) - return self._invoke(message_history.messages) + messages = legacy_inputs_to_messages(input, message_history, system_instruction) + return self._invoke(messages) @abstractmethod def _invoke( @@ -92,10 +91,8 @@ async def ainvoke( message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: - message_history = legacy_inputs_to_message_history( - input, message_history, system_instruction - ) - return await self._ainvoke(message_history.messages) + messages = legacy_inputs_to_messages(input, message_history, system_instruction) + return await self._ainvoke(messages) @abstractmethod async def _ainvoke( @@ -140,10 +137,8 @@ def invoke_with_tools( LLMGenerationError: If anything goes wrong. NotImplementedError: If the LLM provider does not support tool calling. """ - history = legacy_inputs_to_message_history( - input, message_history, system_instruction - ) - return self._invoke_with_tools(history.messages, tools) + messages = legacy_inputs_to_messages(input, message_history, system_instruction) + return self._invoke_with_tools(messages, tools) def _invoke_with_tools( self, inputs: list[LLMMessage], tools: Sequence[Tool] @@ -175,10 +170,8 @@ async def ainvoke_with_tools( LLMGenerationError: If anything goes wrong. NotImplementedError: If the LLM provider does not support tool calling. 
""" - history = legacy_inputs_to_message_history( - input, message_history, system_instruction - ) - return await self._ainvoke_with_tools(history.messages, tools) + messages = legacy_inputs_to_messages(input, message_history, system_instruction) + return await self._ainvoke_with_tools(messages, tools) async def _ainvoke_with_tools( self, inputs: list[LLMMessage], tools: Sequence[Tool] diff --git a/src/neo4j_graphrag/llm/utils.py b/src/neo4j_graphrag/llm/utils.py index 42126c939..ec98d2644 100644 --- a/src/neo4j_graphrag/llm/utils.py +++ b/src/neo4j_graphrag/llm/utils.py @@ -1,49 +1,57 @@ import warnings from typing import Union, Optional -from neo4j_graphrag.message_history import MessageHistory, InMemoryMessageHistory +from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage -def legacy_inputs_to_message_history( +def system_instruction_from_messages(messages: list[LLMMessage]) -> str | None: + for message in messages: + if message["role"] == "system": + return message["content"] + return None + + +def legacy_inputs_to_messages( input: Union[str, list[LLMMessage], MessageHistory], message_history: Optional[Union[list[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, -) -> MessageHistory: +) -> list[LLMMessage]: if message_history: warnings.warn( "Using message_history parameter is deprecated and will be removed in 2.0. Use a list of inputs or a MessageHistory instead.", DeprecationWarning, ) if isinstance(message_history, MessageHistory): - history = message_history + messages = message_history.messages else: # list[LLMMessage] - history = InMemoryMessageHistory(message_history) + messages = [] else: - history = InMemoryMessageHistory() + messages = [] if system_instruction is not None: warnings.warn( "Using system_instruction parameter is deprecated and will be removed in 2.0. 
Use a list of inputs or a MessageHistory instead.", DeprecationWarning, ) - if history.is_empty(): - history.add_message( + if system_instruction_from_messages(messages) is not None: + warnings.warn( + "system_instruction provided but ignored as the message history already contains a system message", + RuntimeWarning, + ) + else: + messages.append( LLMMessage( role="system", content=system_instruction, ), ) - else: - warnings.warn( - "system_instruction provided but ignored as the message history is not empty", - RuntimeWarning, - ) + if isinstance(input, str): - history.add_message(LLMMessage(role="user", content=input)) - return history + messages.append(LLMMessage(role="user", content=input)) + return messages if isinstance(input, list): - history.add_messages(input) - return history + messages.extend(input) + return messages # input is a MessageHistory instance - history.add_messages(input.messages) - return history + messages.extend(input.messages) + return messages From 7537643d208ef36b8d74b38328496049de4ff184 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 28 Aug 2025 19:22:15 +0200 Subject: [PATCH 10/28] Update GraphRAG to restore full LC compatibility --- examples/README.md | 2 +- ...atiblity.py => langchain_compatibility.py} | 0 src/neo4j_graphrag/generation/graphrag.py | 19 ++++++++++++++----- 3 files changed, 15 insertions(+), 6 deletions(-) rename examples/customize/answer/{langchain_compatiblity.py => langchain_compatibility.py} (100%) diff --git a/examples/README.md b/examples/README.md index b8033358d..6cd0e758b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -142,7 +142,7 @@ are listed in [the last section of this file](#customize). ### Answer: GraphRAG -- [LangChain compatibility](./customize/answer/langchain_compatiblity.py) +- [LangChain compatibility](customize/answer/langchain_compatibility.py) - [Use a custom prompt](./customize/answer/custom_prompt.py) diff --git a/examples/customize/answer/langchain_compatiblity.py b/examples/customize/answer/langchain_compatibility.py similarity index 100% rename from examples/customize/answer/langchain_compatiblity.py rename to examples/customize/answer/langchain_compatibility.py diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 08f08a368..e79622dc3 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -27,6 +27,7 @@ from neo4j_graphrag.generation.prompts import RagTemplate from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel from neo4j_graphrag.llm import LLMInterface +from neo4j_graphrag.llm.utils import legacy_inputs_to_messages from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.retrievers.base import Retriever from neo4j_graphrag.types import LLMMessage, RetrieverResult @@ -145,12 +146,17 @@ def search( prompt = self.prompt_template.format( query_text=query_text, context=context, examples=validated_data.examples ) + + messages = legacy_inputs_to_messages( + prompt, + message_history=message_history, + system_instruction=self.prompt_template.system_instructions, + ) + logger.debug(f"RAG: retriever_result={prettify(retriever_result)}") logger.debug(f"RAG: prompt={prompt}") llm_response = self.llm.invoke( - prompt, - message_history, - system_instruction=self.prompt_template.system_instructions, + messages, ) answer = llm_response.content result: dict[str, Any] = {"answer": answer} @@ -168,9 +174,12 @@ def _build_query( summarization_prompt = 
self._chat_summary_prompt( message_history=message_history ) - summary = self.llm.invoke( - input=summarization_prompt, + messages = legacy_inputs_to_messages( + summarization_prompt, system_instruction=summary_system_message, + ) + summary = self.llm.invoke( + messages, ).content return self.conversation_prompt(summary=summary, current_query=query_text) return query_text From eeefa8a63fac1f0d0026eeab3cb978c0a27ba833 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 5 Sep 2025 15:40:52 +0200 Subject: [PATCH 11/28] Test for the utils functions --- src/neo4j_graphrag/llm/utils.py | 15 ++--- tests/unit/llm/test_utils.py | 107 ++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 11 deletions(-) create mode 100644 tests/unit/llm/test_utils.py diff --git a/src/neo4j_graphrag/llm/utils.py b/src/neo4j_graphrag/llm/utils.py index ec98d2644..bbeb13f68 100644 --- a/src/neo4j_graphrag/llm/utils.py +++ b/src/neo4j_graphrag/llm/utils.py @@ -18,28 +18,21 @@ def legacy_inputs_to_messages( system_instruction: Optional[str] = None, ) -> list[LLMMessage]: if message_history: - warnings.warn( - "Using message_history parameter is deprecated and will be removed in 2.0. Use a list of inputs or a MessageHistory instead.", - DeprecationWarning, - ) if isinstance(message_history, MessageHistory): messages = message_history.messages else: # list[LLMMessage] - messages = [] + messages = message_history else: messages = [] if system_instruction is not None: - warnings.warn( - "Using system_instruction parameter is deprecated and will be removed in 2.0. Use a list of inputs or a MessageHistory instead.", - DeprecationWarning, - ) if system_instruction_from_messages(messages) is not None: warnings.warn( "system_instruction provided but ignored as the message history already contains a system message", - RuntimeWarning, + UserWarning, ) else: - messages.append( + messages.insert( + 0, LLMMessage( role="system", content=system_instruction, diff --git a/tests/unit/llm/test_utils.py b/tests/unit/llm/test_utils.py new file mode 100644 index 000000000..d67404cea --- /dev/null +++ b/tests/unit/llm/test_utils.py @@ -0,0 +1,107 @@ +import pytest + +from neo4j_graphrag.llm.utils import system_instruction_from_messages, \ + legacy_inputs_to_messages +from neo4j_graphrag.message_history import InMemoryMessageHistory +from neo4j_graphrag.types import LLMMessage + + +def test_system_instruction_from_messages(): + messages = [ + LLMMessage(role="system", content="text"), + ] + assert system_instruction_from_messages(messages) == "text" + + messages = [] + assert system_instruction_from_messages(messages) is None + + messages = [ + LLMMessage(role="assistant", content="text"), + ] + assert system_instruction_from_messages(messages) is None + + +def test_legacy_inputs_to_messages_only_input_as_llm_message_list(): + messages = legacy_inputs_to_messages(input=[ + LLMMessage(role="user", content="text"), + ]) + assert messages == [ + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_only_input_as_message_history(): + messages = legacy_inputs_to_messages(input=InMemoryMessageHistory( + messages=[ + LLMMessage(role="user", content="text"), + ] + )) + assert messages == [ + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_only_input_as_str(): + messages = legacy_inputs_to_messages(input="text") + assert messages == [ + LLMMessage(role="user", content="text"), + ] + + +def 
test_legacy_inputs_to_messages_input_as_str_and_message_history_as_llm_message_list(): + messages = legacy_inputs_to_messages( + input="text", + message_history=[ + LLMMessage(role="assistant", content="How can I assist you today?"), + ] + ) + assert messages == [ + LLMMessage(role="assistant", content="How can I assist you today?"), + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_message_history(): + messages = legacy_inputs_to_messages( + input="text", + message_history=InMemoryMessageHistory(messages=[ + LLMMessage(role="assistant", content="How can I assist you today?"), + ]) + ) + assert messages == [ + LLMMessage(role="assistant", content="How can I assist you today?"), + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_with_explicit_system_instruction(): + messages = legacy_inputs_to_messages( + input="text", + message_history=[ + LLMMessage(role="assistant", content="How can I assist you today?"), + ], + system_instruction="You are a genius." + ) + assert messages == [ + LLMMessage(role="system", content="You are a genius."), + LLMMessage(role="assistant", content="How can I assist you today?"), + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_do_not_duplicate_system_instruction(): + with pytest.warns( + UserWarning, + match="system_instruction provided but ignored as the message history already contains a system message" + ): + messages = legacy_inputs_to_messages( + input="text", + message_history=[ + LLMMessage(role="system", content="You are super smart."), + ], + system_instruction="You are a genius." + ) + assert messages == [ + LLMMessage(role="system", content="You are super smart."), + LLMMessage(role="user", content="text"), + ] From 13bb7b75377c2cfb0bd142d9aac74a3dabc92a3c Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 11 Sep 2025 09:26:34 +0200 Subject: [PATCH 12/28] WIP: update tests --- src/neo4j_graphrag/llm/utils.py | 2 +- tests/unit/llm/test_base.py | 43 +++++++++++++++++++++++++++++++ tests/unit/llm/test_openai_llm.py | 24 +++++++---------- 3 files changed, 54 insertions(+), 15 deletions(-) create mode 100644 tests/unit/llm/test_base.py diff --git a/src/neo4j_graphrag/llm/utils.py b/src/neo4j_graphrag/llm/utils.py index bbeb13f68..5f5310505 100644 --- a/src/neo4j_graphrag/llm/utils.py +++ b/src/neo4j_graphrag/llm/utils.py @@ -21,7 +21,7 @@ def legacy_inputs_to_messages( if isinstance(message_history, MessageHistory): messages = message_history.messages else: # list[LLMMessage] - messages = message_history + messages = [LLMMessage(**m) for m in message_history] else: messages = [] if system_instruction is not None: diff --git a/tests/unit/llm/test_base.py b/tests/unit/llm/test_base.py new file mode 100644 index 000000000..6c540f1a3 --- /dev/null +++ b/tests/unit/llm/test_base.py @@ -0,0 +1,43 @@ +from typing import Type, Generator, Optional, Any +from unittest.mock import patch, Mock + +from joblib.testing import fixture + +from neo4j_graphrag.llm import LLMInterface +from neo4j_graphrag.types import LLMMessage + + +@fixture(scope="module") # type: ignore[misc] +def llm_interface() -> Generator[Type[LLMInterface], None, None]: + real_abstract_methods = LLMInterface.__abstractmethods__ + LLMInterface.__abstractmethods__ = frozenset() + + class CustomLLMInterface(LLMInterface): + pass + + yield CustomLLMInterface + + LLMInterface.__abstractmethods__ = real_abstract_methods + + 
+@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") +def test_base_llm_interface_invoke_with_input_as_str(mock_inputs: Mock, llm_interface: Type[LLMInterface]) -> None: + mock_inputs.return_value = [LLMMessage(role="user", content="return value of the legacy_inputs_to_messages function")] + llm = llm_interface(model_name="test") + message_history = [ + LLMMessage(**{"role": "user", "content": "When does the sun come up in the summer?"}), + LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), + ] + question = "What about next season?" + system_instruction = "You are a genius." + + with patch.object(llm, "_invoke") as mock_invoke: + llm.invoke(question, message_history, system_instruction) + mock_invoke.assert_called_once_with( + [LLMMessage(role="user", content="return value of the legacy_inputs_to_messages function")] + ) + mock_inputs.assert_called_once_with( + question, + message_history, + system_instruction, + ) diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 3c5ee1b9e..60b8c1e01 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -21,6 +21,7 @@ from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM from neo4j_graphrag.llm.types import ToolCallResponse from neo4j_graphrag.tool import Tool +from neo4j_graphrag.types import LLMMessage def get_mock_openai() -> MagicMock: @@ -50,7 +51,9 @@ def test_openai_llm_happy_path(mock_import: Mock) -> None: @patch("builtins.__import__") -def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None: +@patch("neo4j_graphrag.llm.openai_llm.legacy_inputs_to_messages") +def test_openai_llm_with_message_history_happy_path(mock_inputs: Mock, mock_import: Mock) -> None: + mock_inputs.return_value = [LLMMessage(role="user", content="text")] mock_openai = get_mock_openai() mock_import.return_value = mock_openai mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( @@ -63,18 +66,10 @@ def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None: ] question = "What about next season?" - res = llm.invoke(question, message_history) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "openai chat response" - message_history.append({"role": "user", "content": question}) - # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions - llm.client.chat.completions.create.assert_called_once() # type: ignore - # Check call arguments individually - call_args = llm.client.chat.completions.create.call_args[ # type: ignore - 1 - ] # Get the keyword arguments - assert call_args["messages"] == message_history - assert call_args["model"] == "gpt" + with patch.object(llm, "_invoke") as mock_invoke: + llm.invoke(question, message_history) # type: ignore + mock_invoke.assert_called_once_with([LLMMessage(role="user", content="text")]) + mock_inputs.assert_called_once_with(input=question, message_history=message_history) @patch("builtins.__import__") @@ -404,5 +399,6 @@ def test_azure_openai_llm_with_message_history_validation_error( question = "What about next season?" 
with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, message_history) # type: ignore + r = llm.invoke(question, message_history) # type: ignore + print(r) assert "Input should be a valid string" in str(exc_info.value) From 5ea2f3796b23b2382ca0f12b1e2c402e80bff0be Mon Sep 17 00:00:00 2001 From: estelle Date: Sun, 14 Sep 2025 11:49:42 +0200 Subject: [PATCH 13/28] Improve test coverage for utils and base modules --- src/neo4j_graphrag/llm/base.py | 17 +++++- src/neo4j_graphrag/llm/utils.py | 21 ++++++- tests/unit/llm/test_base.py | 103 ++++++++++++++++++++++++++++++-- tests/unit/llm/test_utils.py | 87 +++++++++++++++++++-------- 4 files changed, 195 insertions(+), 33 deletions(-) diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index f9993afc6..5534ff4ca 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -17,6 +17,8 @@ from abc import ABC, abstractmethod from typing import Any, List, Optional, Sequence, Union +from pydantic import ValidationError + from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage @@ -30,6 +32,7 @@ from neo4j_graphrag.tool import Tool from .utils import legacy_inputs_to_messages +from ..exceptions import LLMGenerationError class LLMInterface(ABC): @@ -64,7 +67,12 @@ def invoke( message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: - messages = legacy_inputs_to_messages(input, message_history, system_instruction) + try: + messages = legacy_inputs_to_messages( + input, message_history, system_instruction + ) + except ValidationError as e: + raise LLMGenerationError("Input validation failed") from e return self._invoke(messages) @abstractmethod @@ -137,7 +145,12 @@ def invoke_with_tools( LLMGenerationError: If anything goes wrong. NotImplementedError: If the LLM provider does not support tool calling. """ - messages = legacy_inputs_to_messages(input, message_history, system_instruction) + try: + messages = legacy_inputs_to_messages( + input, message_history, system_instruction + ) + except ValidationError as e: + raise LLMGenerationError("Input validation failed") from e return self._invoke_with_tools(messages, tools) def _invoke_with_tools( diff --git a/src/neo4j_graphrag/llm/utils.py b/src/neo4j_graphrag/llm/utils.py index 5f5310505..b61a880f4 100644 --- a/src/neo4j_graphrag/llm/utils.py +++ b/src/neo4j_graphrag/llm/utils.py @@ -1,6 +1,22 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import warnings from typing import Union, Optional +from pydantic import TypeAdapter + from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage @@ -12,6 +28,9 @@ def system_instruction_from_messages(messages: list[LLMMessage]) -> str | None: return None +llm_messages_adapter = TypeAdapter(list[LLMMessage]) + + def legacy_inputs_to_messages( input: Union[str, list[LLMMessage], MessageHistory], message_history: Optional[Union[list[LLMMessage], MessageHistory]] = None, @@ -21,7 +40,7 @@ def legacy_inputs_to_messages( if isinstance(message_history, MessageHistory): messages = message_history.messages else: # list[LLMMessage] - messages = [LLMMessage(**m) for m in message_history] + messages = llm_messages_adapter.validate_python(message_history) else: messages = [] if system_instruction is not None: diff --git a/tests/unit/llm/test_base.py b/tests/unit/llm/test_base.py index 6c540f1a3..9a927b193 100644 --- a/tests/unit/llm/test_base.py +++ b/tests/unit/llm/test_base.py @@ -1,8 +1,11 @@ -from typing import Type, Generator, Optional, Any +from typing import Type, Generator from unittest.mock import patch, Mock +import pytest from joblib.testing import fixture +from pydantic import ValidationError +from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm import LLMInterface from neo4j_graphrag.types import LLMMessage @@ -21,11 +24,20 @@ class CustomLLMInterface(LLMInterface): @patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") -def test_base_llm_interface_invoke_with_input_as_str(mock_inputs: Mock, llm_interface: Type[LLMInterface]) -> None: - mock_inputs.return_value = [LLMMessage(role="user", content="return value of the legacy_inputs_to_messages function")] +def test_base_llm_interface_invoke_with_input_as_str( + mock_inputs: Mock, llm_interface: Type[LLMInterface] +) -> None: + mock_inputs.return_value = [ + LLMMessage( + role="user", + content="return value of the legacy_inputs_to_messages function", + ) + ] llm = llm_interface(model_name="test") message_history = [ - LLMMessage(**{"role": "user", "content": "When does the sun come up in the summer?"}), + LLMMessage( + **{"role": "user", "content": "When does the sun come up in the summer?"} + ), LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), ] question = "What about next season?" @@ -34,10 +46,91 @@ def test_base_llm_interface_invoke_with_input_as_str(mock_inputs: Mock, llm_inte with patch.object(llm, "_invoke") as mock_invoke: llm.invoke(question, message_history, system_instruction) mock_invoke.assert_called_once_with( - [LLMMessage(role="user", content="return value of the legacy_inputs_to_messages function")] + [ + LLMMessage( + role="user", + content="return value of the legacy_inputs_to_messages function", + ) + ] ) mock_inputs.assert_called_once_with( question, message_history, system_instruction, ) + + +@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") +def test_base_llm_interface_invoke_with_invalid_inputs( + mock_inputs: Mock, llm_interface: Type[LLMInterface] +) -> None: + mock_inputs.side_effect = [ + ValidationError.from_exception_data("Invalid data", line_errors=[]) + ] + llm = llm_interface(model_name="test") + question = "What about next season?" 
+ + with pytest.raises(LLMGenerationError, match="Input validation failed"): + llm.invoke(question) + mock_inputs.assert_called_once_with( + question, + None, + None, + ) + + +@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") +def test_base_llm_interface_invoke_with_tools_with_input_as_str( + mock_inputs: Mock, llm_interface: Type[LLMInterface] +) -> None: + mock_inputs.return_value = [ + LLMMessage( + role="user", + content="return value of the legacy_inputs_to_messages function", + ) + ] + llm = llm_interface(model_name="test") + message_history = [ + LLMMessage( + **{"role": "user", "content": "When does the sun come up in the summer?"} + ), + LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), + ] + question = "What about next season?" + system_instruction = "You are a genius." + + with patch.object(llm, "_invoke_with_tools") as mock_invoke: + llm.invoke_with_tools(question, [], message_history, system_instruction) + mock_invoke.assert_called_once_with( + [ + LLMMessage( + role="user", + content="return value of the legacy_inputs_to_messages function", + ) + ], + [], # tools + ) + mock_inputs.assert_called_once_with( + question, + message_history, + system_instruction, + ) + + +@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") +def test_base_llm_interface_invoke_with_tools_with_invalid_inputs( + mock_inputs: Mock, llm_interface: Type[LLMInterface] +) -> None: + mock_inputs.side_effect = [ + ValidationError.from_exception_data("Invalid data", line_errors=[]) + ] + llm = llm_interface(model_name="test") + question = "What about next season?" + + with pytest.raises(LLMGenerationError, match="Input validation failed"): + llm.invoke_with_tools(question, []) + mock_inputs.assert_called_once_with( + question, + None, + None, + ) diff --git a/tests/unit/llm/test_utils.py b/tests/unit/llm/test_utils.py index d67404cea..6a969864d 100644 --- a/tests/unit/llm/test_utils.py +++ b/tests/unit/llm/test_utils.py @@ -1,12 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import pytest +from pydantic import ValidationError -from neo4j_graphrag.llm.utils import system_instruction_from_messages, \ - legacy_inputs_to_messages +from neo4j_graphrag.llm.utils import ( + system_instruction_from_messages, + legacy_inputs_to_messages, +) from neo4j_graphrag.message_history import InMemoryMessageHistory from neo4j_graphrag.types import LLMMessage -def test_system_instruction_from_messages(): +def test_system_instruction_from_messages() -> None: messages = [ LLMMessage(role="system", content="text"), ] @@ -21,39 +38,45 @@ def test_system_instruction_from_messages(): assert system_instruction_from_messages(messages) is None -def test_legacy_inputs_to_messages_only_input_as_llm_message_list(): - messages = legacy_inputs_to_messages(input=[ - LLMMessage(role="user", content="text"), - ]) +def test_legacy_inputs_to_messages_only_input_as_llm_message_list() -> None: + messages = legacy_inputs_to_messages( + input=[ + LLMMessage(role="user", content="text"), + ] + ) assert messages == [ LLMMessage(role="user", content="text"), ] -def test_legacy_inputs_to_messages_only_input_as_message_history(): - messages = legacy_inputs_to_messages(input=InMemoryMessageHistory( - messages=[ - LLMMessage(role="user", content="text"), - ] - )) +def test_legacy_inputs_to_messages_only_input_as_message_history() -> None: + messages = legacy_inputs_to_messages( + input=InMemoryMessageHistory( + messages=[ + LLMMessage(role="user", content="text"), + ] + ) + ) assert messages == [ LLMMessage(role="user", content="text"), ] -def test_legacy_inputs_to_messages_only_input_as_str(): +def test_legacy_inputs_to_messages_only_input_as_str() -> None: messages = legacy_inputs_to_messages(input="text") assert messages == [ LLMMessage(role="user", content="text"), ] -def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_llm_message_list(): +def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_llm_message_list() -> ( + None +): messages = legacy_inputs_to_messages( input="text", message_history=[ LLMMessage(role="assistant", content="How can I assist you today?"), - ] + ], ) assert messages == [ LLMMessage(role="assistant", content="How can I assist you today?"), @@ -61,12 +84,16 @@ def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_llm_messa ] -def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_message_history(): +def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_message_history() -> ( + None +): messages = legacy_inputs_to_messages( input="text", - message_history=InMemoryMessageHistory(messages=[ - LLMMessage(role="assistant", content="How can I assist you today?"), - ]) + message_history=InMemoryMessageHistory( + messages=[ + LLMMessage(role="assistant", content="How can I assist you today?"), + ] + ), ) assert messages == [ LLMMessage(role="assistant", content="How can I assist you today?"), @@ -74,13 +101,13 @@ def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_message_h ] -def test_legacy_inputs_to_messages_with_explicit_system_instruction(): +def test_legacy_inputs_to_messages_with_explicit_system_instruction() -> None: messages = legacy_inputs_to_messages( input="text", message_history=[ LLMMessage(role="assistant", content="How can I assist you today?"), ], - system_instruction="You are a genius." 
+ system_instruction="You are a genius.", ) assert messages == [ LLMMessage(role="system", content="You are a genius."), @@ -89,19 +116,29 @@ def test_legacy_inputs_to_messages_with_explicit_system_instruction(): ] -def test_legacy_inputs_to_messages_do_not_duplicate_system_instruction(): +def test_legacy_inputs_to_messages_do_not_duplicate_system_instruction() -> None: with pytest.warns( UserWarning, - match="system_instruction provided but ignored as the message history already contains a system message" + match="system_instruction provided but ignored as the message history already contains a system message", ): messages = legacy_inputs_to_messages( input="text", message_history=[ LLMMessage(role="system", content="You are super smart."), ], - system_instruction="You are a genius." + system_instruction="You are a genius.", ) assert messages == [ LLMMessage(role="system", content="You are super smart."), LLMMessage(role="user", content="text"), ] + + +def test_legacy_inputs_to_messages_wrong_type_in_message_list() -> None: + with pytest.raises(ValidationError, match="Input should be a valid string"): + legacy_inputs_to_messages( + input="text", + message_history=[ + {"role": "system", "content": 10}, # type: ignore + ], + ) From 9196764974ec3a354427eed5458290f74c3053fb Mon Sep 17 00:00:00 2001 From: estelle Date: Sun, 14 Sep 2025 13:36:25 +0200 Subject: [PATCH 14/28] Fix UT OpenAILLM --- src/neo4j_graphrag/llm/openai_llm.py | 2 +- tests/unit/llm/test_openai_llm.py | 275 ++------------------------- 2 files changed, 18 insertions(+), 259 deletions(-) diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 7d670d6a8..588352009 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -101,7 +101,7 @@ def get_messages( self.openai.types.chat.ChatCompletionAssistantMessageParam ) else: - raise ValueError(f"Unknown message type: {m['role']}") + raise ValueError(f"Unknown role: {m['role']}") chat_messages.append( message_type( role=m["role"], # type: ignore diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 60b8c1e01..2b2cb29f2 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -37,7 +37,7 @@ def test_openai_llm_missing_dependency(mock_import: Mock) -> None: @patch("builtins.__import__") -def test_openai_llm_happy_path(mock_import: Mock) -> None: +def test_openai_llm_happy_path_e2e(mock_import: Mock) -> None: mock_openai = get_mock_openai() mock_import.return_value = mock_openai mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( @@ -50,83 +50,31 @@ def test_openai_llm_happy_path(mock_import: Mock) -> None: assert res.content == "openai chat response" -@patch("builtins.__import__") -@patch("neo4j_graphrag.llm.openai_llm.legacy_inputs_to_messages") -def test_openai_llm_with_message_history_happy_path(mock_inputs: Mock, mock_import: Mock) -> None: - mock_inputs.return_value = [LLMMessage(role="user", content="text")] - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content="openai chat response"))], - ) +def test_openai_llm_get_messages() -> None: llm = OpenAILLM(api_key="my key", model_name="gpt") message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, + 
LLMMessage(**{"role": "system", "content": "do something"}), + LLMMessage( + **{"role": "user", "content": "When does the sun come up in the summer?"} + ), + LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), ] - question = "What about next season?" - - with patch.object(llm, "_invoke") as mock_invoke: - llm.invoke(question, message_history) # type: ignore - mock_invoke.assert_called_once_with([LLMMessage(role="user", content="text")]) - mock_inputs.assert_called_once_with(input=question, message_history=message_history) + messages = llm.get_messages(message_history) + assert isinstance(messages, list) + for actual, expected in zip(messages, message_history): + assert isinstance(actual, dict) + assert actual["role"] == expected["role"] + assert actual["content"] == expected["content"] -@patch("builtins.__import__") -def test_openai_llm_with_message_history_and_system_instruction( - mock_import: Mock, -) -> None: - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content="openai chat response"))], - ) - system_instruction = "You are a helpful assistent." - llm = OpenAILLM( - api_key="my key", - model_name="gpt", - ) - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "openai chat response" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions - llm.client.chat.completions.create.assert_called_once() # type: ignore - # Check call arguments individually - call_args = llm.client.chat.completions.create.call_args[ # type: ignore - 1 - ] # Get the keyword arguments - assert call_args["messages"] == messages - assert call_args["model"] == "gpt" - - assert llm.client.chat.completions.create.call_count == 1 # type: ignore - - -@patch("builtins.__import__") -def test_openai_llm_with_message_history_validation_error(mock_import: Mock) -> None: - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content="openai chat response"))], - ) +def test_openai_llm_get_messages_unknown_role() -> None: llm = OpenAILLM(api_key="my key", model_name="gpt") message_history = [ - {"role": "human", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, + LLMMessage(**{"role": "unknown role", "content": "Usually around 6am."}), ] - question = "What about next season?" 
- - with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, message_history) # type: ignore - assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value) + with pytest.raises(ValueError, match="Unknown role"): + llm.get_messages(message_history) @patch("builtins.__import__") @@ -171,130 +119,6 @@ def test_openai_llm_invoke_with_tools_happy_path( assert res.content == "openai tool response" -@patch("builtins.__import__") -@patch("json.loads") -def test_openai_llm_invoke_with_tools_with_message_history( - mock_json_loads: Mock, - mock_import: Mock, - test_tool: Tool, -) -> None: - # Set up json.loads to return a dictionary - mock_json_loads.return_value = {"param1": "value1"} - - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - - # Mock the tool call response - mock_function = MagicMock() - mock_function.name = "test_tool" - mock_function.arguments = '{"param1": "value1"}' - - mock_tool_call = MagicMock() - mock_tool_call.function = mock_function - - mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( - choices=[ - MagicMock( - message=MagicMock( - content="openai tool response", tool_calls=[mock_tool_call] - ) - ) - ], - ) - - llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = [test_tool] - - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - res = llm.invoke_with_tools(question, tools, message_history) # type: ignore - assert isinstance(res, ToolCallResponse) - assert len(res.tool_calls) == 1 - assert res.tool_calls[0].name == "test_tool" - assert res.tool_calls[0].arguments == {"param1": "value1"} - - # Verify the correct messages were passed - message_history.append({"role": "user", "content": question}) - # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions - llm.client.chat.completions.create.assert_called_once() # type: ignore - # Check call arguments individually - call_args = llm.client.chat.completions.create.call_args[ # type: ignore - 1 - ] # Get the keyword arguments - assert call_args["messages"] == message_history - assert call_args["model"] == "gpt" - # Check tools content rather than direct equality - assert len(call_args["tools"]) == 1 - assert call_args["tools"][0]["type"] == "function" - assert call_args["tools"][0]["function"]["name"] == "test_tool" - assert call_args["tools"][0]["function"]["description"] == "A test tool" - assert call_args["tool_choice"] == "auto" - assert call_args["temperature"] == 0.0 - - -@patch("builtins.__import__") -@patch("json.loads") -def test_openai_llm_invoke_with_tools_with_system_instruction( - mock_json_loads: Mock, - mock_import: Mock, - test_tool: Mock, -) -> None: - # Set up json.loads to return a dictionary - mock_json_loads.return_value = {"param1": "value1"} - - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - - # Mock the tool call response - mock_function = MagicMock() - mock_function.name = "test_tool" - mock_function.arguments = '{"param1": "value1"}' - - mock_tool_call = MagicMock() - mock_tool_call.function = mock_function - - mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( - choices=[ - MagicMock( - message=MagicMock( - content="openai tool response", tool_calls=[mock_tool_call] - ) - ) - ], - ) - - llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = 
[test_tool] - - system_instruction = "You are a helpful assistant." - - res = llm.invoke_with_tools("my text", tools, system_instruction=system_instruction) - assert isinstance(res, ToolCallResponse) - - # Verify system instruction was included - messages = [{"role": "system", "content": system_instruction}] - messages.append({"role": "user", "content": "my text"}) - # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions - llm.client.chat.completions.create.assert_called_once() # type: ignore - # Check call arguments individually - call_args = llm.client.chat.completions.create.call_args[ # type: ignore - 1 - ] # Get the keyword arguments - assert call_args["messages"] == messages - assert call_args["model"] == "gpt" - # Check tools content rather than direct equality - assert len(call_args["tools"]) == 1 - assert call_args["tools"][0]["type"] == "function" - assert call_args["tools"][0]["function"]["name"] == "test_tool" - assert call_args["tools"][0]["function"]["description"] == "A test tool" - assert call_args["tool_choice"] == "auto" - assert call_args["temperature"] == 0.0 - - @patch("builtins.__import__") def test_openai_llm_invoke_with_tools_error(mock_import: Mock, test_tool: Tool) -> None: mock_openai = get_mock_openai() @@ -337,68 +161,3 @@ def test_azure_openai_llm_happy_path(mock_import: Mock) -> None: res = llm.invoke("my text") assert isinstance(res, LLMResponse) assert res.content == "openai chat response" - - -@patch("builtins.__import__") -def test_azure_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None: - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = ( - MagicMock( - choices=[MagicMock(message=MagicMock(content="openai chat response"))], - ) - ) - llm = AzureOpenAILLM( - model_name="gpt", - azure_endpoint="https://test.openai.azure.com/", - api_key="my key", - api_version="version", - ) - - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - res = llm.invoke(question, message_history) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "openai chat response" - message_history.append({"role": "user", "content": question}) - # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions - llm.client.chat.completions.create.assert_called_once() # type: ignore - # Check call arguments individually - call_args = llm.client.chat.completions.create.call_args[ # type: ignore - 1 - ] # Get the keyword arguments - assert call_args["messages"] == message_history - assert call_args["model"] == "gpt" - - -@patch("builtins.__import__") -def test_azure_openai_llm_with_message_history_validation_error( - mock_import: Mock, -) -> None: - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = ( - MagicMock( - choices=[MagicMock(message=MagicMock(content="openai chat response"))], - ) - ) - llm = AzureOpenAILLM( - model_name="gpt", - azure_endpoint="https://test.openai.azure.com/", - api_key="my key", - api_version="version", - ) - - message_history = [ - {"role": "user", "content": 33}, - ] - question = "What about next season?" 
- - with pytest.raises(LLMGenerationError) as exc_info: - r = llm.invoke(question, message_history) # type: ignore - print(r) - assert "Input should be a valid string" in str(exc_info.value) From 8a94f1aeeade643a488195975e9c38b0ffd41d18 Mon Sep 17 00:00:00 2001 From: estelle Date: Sun, 14 Sep 2025 18:21:45 +0200 Subject: [PATCH 15/28] Update Ollama tests --- tests/unit/llm/test_ollama_llm.py | 112 +++--------------------------- 1 file changed, 9 insertions(+), 103 deletions(-) diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py index c1d3f9fdc..3621680a1 100644 --- a/tests/unit/llm/test_ollama_llm.py +++ b/tests/unit/llm/test_ollama_llm.py @@ -41,6 +41,7 @@ def test_ollama_llm_happy_path_deprecated_options(mock_import: Mock) -> None: mock_ollama.Client.return_value.chat.return_value = MagicMock( message=MagicMock(content="ollama chat response"), ) + mock_ollama.Message.return_value = {"role": "user", "content": "test"} model = "gpt" model_params = {"temperature": 0.3} with pytest.warns(DeprecationWarning) as record: @@ -59,11 +60,12 @@ def test_ollama_llm_happy_path_deprecated_options(mock_import: Mock) -> None: res = llm.invoke(question) assert isinstance(res, LLMResponse) assert res.content == "ollama chat response" - messages = [ - {"role": "user", "content": question}, - ] llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] - model=model, messages=messages, options={"temperature": 0.3} + model=model, + messages=[ + {"role": "user", "content": "test"} + ], + options={"temperature": 0.3} ) @@ -90,6 +92,7 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: mock_ollama.Client.return_value.chat.return_value = MagicMock( message=MagicMock(content="ollama chat response"), ) + mock_ollama.Message.return_value = {"role": "user", "content": "test"} model = "gpt" options = {"temperature": 0.3} model_params = {"options": options, "format": "json"} @@ -102,7 +105,7 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: assert isinstance(res, LLMResponse) assert res.content == "ollama chat response" messages = [ - {"role": "user", "content": question}, + {"role": "user", "content": "test"}, ] llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] model=model, @@ -112,102 +115,6 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: ) -@patch("builtins.__import__") -def test_ollama_invoke_with_system_instruction_happy_path(mock_import: Mock) -> None: - mock_ollama = get_mock_ollama() - mock_import.return_value = mock_ollama - mock_ollama.Client.return_value.chat.return_value = MagicMock( - message=MagicMock(content="ollama chat response"), - ) - model = "gpt" - options = {"temperature": 0.3} - model_params = {"options": options, "format": "json"} - llm = OllamaLLM( - model, - model_params=model_params, - ) - system_instruction = "You are a helpful assistant." - question = "What about next season?" 
- - response = llm.invoke(question, system_instruction=system_instruction) - assert response.content == "ollama chat response" - messages = [{"role": "system", "content": system_instruction}] - messages.append({"role": "user", "content": question}) - llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] - model=model, - messages=messages, - options=options, - format="json", - ) - - -@patch("builtins.__import__") -def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> None: - mock_ollama = get_mock_ollama() - mock_import.return_value = mock_ollama - mock_ollama.Client.return_value.chat.return_value = MagicMock( - message=MagicMock(content="ollama chat response"), - ) - model = "gpt" - options = {"temperature": 0.3} - model_params = {"options": options} - llm = OllamaLLM( - model, - model_params=model_params, - ) - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - response = llm.invoke(question, message_history) # type: ignore - assert response.content == "ollama chat response" - messages = [m for m in message_history] - messages.append({"role": "user", "content": question}) - llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] - model=model, messages=messages, options=options - ) - - -@patch("builtins.__import__") -def test_ollama_invoke_with_message_history_and_system_instruction( - mock_import: Mock, -) -> None: - mock_ollama = get_mock_ollama() - mock_import.return_value = mock_ollama - mock_ollama.Client.return_value.chat.return_value = MagicMock( - message=MagicMock(content="ollama chat response"), - ) - model = "gpt" - options = {"temperature": 0.3} - model_params = {"options": options} - system_instruction = "You are a helpful assistant." - llm = OllamaLLM( - model, - model_params=model_params, - ) - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - response = llm.invoke( - question, - message_history, # type: ignore - system_instruction=system_instruction, - ) - assert response.content == "ollama chat response" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] - model=model, messages=messages, options=options - ) - assert llm.client.chat.call_count == 1 # type: ignore - - @patch("builtins.__import__") def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock) -> None: mock_ollama = get_mock_ollama() @@ -228,9 +135,8 @@ def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock) ] question = "What about next season?" 
- with pytest.raises(LLMGenerationError) as exc_info: + with pytest.raises(LLMGenerationError, match="Input validation failed"): llm.invoke(question, message_history) # type: ignore - assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) @pytest.mark.asyncio From 7c22e732c9cfc3a01aafbae79c137b01846893c2 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 18 Sep 2025 08:26:44 +0200 Subject: [PATCH 16/28] Update Ollama/Anthropic --- tests/unit/llm/test_anthropic_llm.py | 159 ++++++++------------------- tests/unit/llm/test_ollama_llm.py | 6 +- 2 files changed, 50 insertions(+), 115 deletions(-) diff --git a/tests/unit/llm/test_anthropic_llm.py b/tests/unit/llm/test_anthropic_llm.py index 029d75778..326f027d1 100644 --- a/tests/unit/llm/test_anthropic_llm.py +++ b/tests/unit/llm/test_anthropic_llm.py @@ -19,9 +19,11 @@ import anthropic import pytest -from neo4j_graphrag.exceptions import LLMGenerationError +from anthropic import NOT_GIVEN, NotGiven + +from neo4j_graphrag.llm import LLMResponse from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM -from neo4j_graphrag.llm.types import LLMResponse +from neo4j_graphrag.types import LLMMessage @pytest.fixture @@ -40,132 +42,65 @@ def test_anthropic_llm_missing_dependency(mock_import: Mock) -> None: AnthropicLLM(model_name="claude-3-opus-20240229") -def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None: - mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( - content=[MagicMock(text="generated text")] - ) - model_params = {"temperature": 0.3} - llm = AnthropicLLM("claude-3-opus-20240229", model_params=model_params) - input_text = "may thy knife chip and shatter" - response = llm.invoke(input_text) - assert response.content == "generated text" - llm.client.messages.create.assert_called_once_with( # type: ignore - messages=[{"role": "user", "content": input_text}], - model="claude-3-opus-20240229", - system=anthropic.NOT_GIVEN, - **model_params, - ) - - -def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock) -> None: - mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( - content=[MagicMock(text="generated text")] - ) - model_params = {"temperature": 0.3} - llm = AnthropicLLM( - "claude-3-opus-20240229", - model_params=model_params, - ) +def test_anthropic_llm_get_messages_with_system_instructions() -> None: + llm = AnthropicLLM(api_key="my key", model_name="claude") message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, + LLMMessage(**{"role": "system", "content": "do something"}), + LLMMessage( + **{"role": "user", "content": "When does the sun come up in the summer?"} + ), + LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), ] - question = "What about next season?" 
- - response = llm.invoke(question, message_history) # type: ignore - assert response.content == "generated text" - message_history.append({"role": "user", "content": question}) - llm.client.messages.create.assert_called_once_with( # type: ignore[attr-defined] - messages=message_history, - model="claude-3-opus-20240229", - system=anthropic.NOT_GIVEN, - **model_params, - ) + system_instruction, messages = llm.get_messages(message_history) + assert isinstance(system_instruction, str) + assert system_instruction == "do something" + assert isinstance(messages, list) + assert len(messages) == 2 # exclude system instruction + for actual, expected in zip(messages, message_history[1:]): + assert isinstance(actual, dict) + assert actual["role"] == expected["role"] + assert actual["content"] == expected["content"] -def test_anthropic_invoke_with_system_instruction( - mock_anthropic: Mock, -) -> None: - mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( - content=[MagicMock(text="generated text")] - ) - model_params = {"temperature": 0.3} - system_instruction = "You are a helpful assistant." - llm = AnthropicLLM( - "claude-3-opus-20240229", - model_params=model_params, - ) - question = "When does it come up in the winter?" - response = llm.invoke(question, system_instruction=system_instruction) - assert isinstance(response, LLMResponse) - assert response.content == "generated text" - messages = [{"role": "user", "content": question}] - llm.client.messages.create.assert_called_with( # type: ignore[attr-defined] - model="claude-3-opus-20240229", - system=system_instruction, - messages=messages, - **model_params, - ) +def test_anthropic_llm_get_messages_without_system_instructions() -> None: + llm = AnthropicLLM(api_key="my key", model_name="claude") + message_history = [ + LLMMessage( + **{"role": "user", "content": "When does the sun come up in the summer?"} + ), + LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), + ] - assert llm.client.messages.create.call_count == 1 # type: ignore + system_instruction, messages = llm.get_messages(message_history) + assert isinstance(system_instruction, NotGiven) + assert system_instruction == NOT_GIVEN + assert isinstance(messages, list) + assert len(messages) == 2 + for actual, expected in zip(messages, message_history): + assert isinstance(actual, dict) + assert actual["role"] == expected["role"] + assert actual["content"] == expected["content"] -def test_anthropic_invoke_with_message_history_and_system_instruction( - mock_anthropic: Mock, -) -> None: +def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None: mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( content=[MagicMock(text="generated text")] ) + mock_anthropic.types.MessageParam.return_value = {"role": "user", "content": "hi"} model_params = {"temperature": 0.3} - system_instruction = "You are a helpful assistant." - llm = AnthropicLLM( - "claude-3-opus-20240229", - model_params=model_params, - ) - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - - question = "When does it come up in the winter?" 
- response = llm.invoke(question, message_history, system_instruction) # type: ignore + llm = AnthropicLLM("claude-3-opus-20240229", model_params=model_params) + input_text = "may thy knife chip and shatter" + response = llm.invoke(input_text) assert isinstance(response, LLMResponse) assert response.content == "generated text" - message_history.append({"role": "user", "content": question}) - llm.client.messages.create.assert_called_with( # type: ignore[attr-defined] + llm.client.messages.create.assert_called_once_with( # type: ignore + messages=[{"role": "user", "content": "hi"}], model="claude-3-opus-20240229", - system=system_instruction, - messages=message_history, + system=anthropic.NOT_GIVEN, **model_params, ) - assert llm.client.messages.create.call_count == 1 # type: ignore - - -def test_anthropic_invoke_with_message_history_validation_error( - mock_anthropic: Mock, -) -> None: - mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( - content=[MagicMock(text="generated text")] - ) - model_params = {"temperature": 0.3} - system_instruction = "You are a helpful assistant." - llm = AnthropicLLM( - "claude-3-opus-20240229", - model_params=model_params, - system_instruction=system_instruction, - ) - message_history = [ - {"role": "human", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, message_history) # type: ignore - assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value) - @pytest.mark.asyncio async def test_anthropic_ainvoke_happy_path(mock_anthropic: Mock) -> None: @@ -173,14 +108,16 @@ async def test_anthropic_ainvoke_happy_path(mock_anthropic: Mock) -> None: mock_response.content = [MagicMock(text="Return text")] mock_model = mock_anthropic.AsyncAnthropic.return_value mock_model.messages.create = AsyncMock(return_value=mock_response) + mock_anthropic.types.MessageParam.return_value = {"role": "user", "content": "hi"} model_params = {"temperature": 0.3} llm = AnthropicLLM("claude-3-opus-20240229", model_params) input_text = "may thy knife chip and shatter" response = await llm.ainvoke(input_text) + assert isinstance(response, LLMResponse) assert response.content == "Return text" llm.async_client.messages.create.assert_awaited_once_with( # type: ignore model="claude-3-opus-20240229", system=anthropic.NOT_GIVEN, - messages=[{"role": "user", "content": input_text}], + messages=[{"role": "user", "content": "hi"}], **model_params, ) diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py index 3621680a1..6ecb9fb13 100644 --- a/tests/unit/llm/test_ollama_llm.py +++ b/tests/unit/llm/test_ollama_llm.py @@ -62,10 +62,8 @@ def test_ollama_llm_happy_path_deprecated_options(mock_import: Mock) -> None: assert res.content == "ollama chat response" llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] model=model, - messages=[ - {"role": "user", "content": "test"} - ], - options={"temperature": 0.3} + messages=[{"role": "user", "content": "test"}], + options={"temperature": 0.3}, ) From 8e723353fc200772a23b8a52eca63135429b450c Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 18 Sep 2025 08:26:58 +0200 Subject: [PATCH 17/28] WIP update cohere --- tests/unit/llm/test_cohere_llm.py | 81 ++----------------------------- 1 file changed, 5 insertions(+), 76 deletions(-) diff --git a/tests/unit/llm/test_cohere_llm.py 
b/tests/unit/llm/test_cohere_llm.py index 10a02ec86..c3b43dbc4 100644 --- a/tests/unit/llm/test_cohere_llm.py +++ b/tests/unit/llm/test_cohere_llm.py @@ -41,86 +41,17 @@ def test_cohere_llm_happy_path(mock_cohere: Mock) -> None: chat_response_mock = MagicMock() chat_response_mock.message.content = [MagicMock(text="cohere response text")] mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock + mock_cohere.UserChatMessageV2.return_value = {"role": "user", "content": "test"} llm = CohereLLM(model_name="something") res = llm.invoke("my text") assert isinstance(res, LLMResponse) assert res.content == "cohere response text" - - -def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) -> None: - chat_response_mock = MagicMock() - chat_response_mock.message.content = [MagicMock(text="cohere response text")] - mock_cohere_client_chat = mock_cohere.ClientV2.return_value.chat - mock_cohere_client_chat.return_value = chat_response_mock - - system_instruction = "You are a helpful assistant." - llm = CohereLLM(model_name="something") - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "cohere response text" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - mock_cohere_client_chat.assert_called_once_with( - messages=messages, + mock_cohere.ClientV2.return_value.chat.assert_called_once_with( + messages=[{"role": "user", "content": "test"}], model="something", ) -def test_cohere_llm_invoke_with_message_history_and_system_instruction( - mock_cohere: Mock, -) -> None: - chat_response_mock = MagicMock() - chat_response_mock.message.content = [MagicMock(text="cohere response text")] - mock_cohere_client_chat = mock_cohere.ClientV2.return_value.chat - mock_cohere_client_chat.return_value = chat_response_mock - - system_instruction = "You are a helpful assistant." - llm = CohereLLM(model_name="gpt") - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "cohere response text" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - mock_cohere_client_chat.assert_called_once_with( - messages=messages, - model="gpt", - ) - - -def test_cohere_llm_invoke_with_message_history_validation_error( - mock_cohere: Mock, -) -> None: - chat_response_mock = MagicMock() - chat_response_mock.message.content = [MagicMock(text="cohere response text")] - mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock - - system_instruction = "You are a helpful assistant." - llm = CohereLLM(model_name="something", system_instruction=system_instruction) - message_history = [ - {"role": "robot", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" 
- - with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, message_history) # type: ignore - assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) - - @pytest.mark.asyncio async def test_cohere_llm_happy_path_async(mock_cohere: Mock) -> None: chat_response_mock = MagicMock( @@ -139,9 +70,8 @@ async def test_cohere_llm_happy_path_async(mock_cohere: Mock) -> None: def test_cohere_llm_failed(mock_cohere: Mock) -> None: mock_cohere.ClientV2.return_value.chat.side_effect = cohere.core.ApiError llm = CohereLLM(model_name="something") - with pytest.raises(LLMGenerationError) as excinfo: + with pytest.raises(LLMGenerationError, match="ApiError"): llm.invoke("my text") - assert "ApiError" in str(excinfo) @pytest.mark.asyncio @@ -149,6 +79,5 @@ async def test_cohere_llm_failed_async(mock_cohere: Mock) -> None: mock_cohere.AsyncClientV2.return_value.chat.side_effect = cohere.core.ApiError llm = CohereLLM(model_name="something") - with pytest.raises(LLMGenerationError) as excinfo: + with pytest.raises(LLMGenerationError, match="ApiError"): await llm.ainvoke("my text") - assert "ApiError" in str(excinfo) From a34c7606d8d6c9d2d8855f8dff911f43bab29850 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 18 Sep 2025 13:24:40 +0200 Subject: [PATCH 18/28] CHANGELOG.md --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c0b6c9b5..19f1f3999 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,8 +4,15 @@ ### Added +- Document node is now always created when running SimpleKGPipeline, even if `from_pdf=False`. +- Document metadata is exposed in SimpleKGPipeline run method. - Added automatic rate limiting with retry logic and exponential backoff for all Embedding providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely. +### Fixed + +- LangChain Chat models compatibility is now working again. 
+ + ## 1.10.0 ### Added From 7a4d4a06149ed056b6821d6f63b8ff6c4cba9ca8 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 30 Sep 2025 16:35:47 +0200 Subject: [PATCH 19/28] Ruff after rebase --- src/neo4j_graphrag/llm/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 5534ff4ca..f42f71b6f 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -27,6 +27,7 @@ DEFAULT_RATE_LIMIT_HANDLER, rate_limit_handler, async_rate_limit_handler, + RateLimitHandler, ) from neo4j_graphrag.tool import Tool From 78bc7c3990cac415ffbdb916af996b33d83f607b Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 1 Oct 2025 09:51:17 +0200 Subject: [PATCH 20/28] More fixes on cohere tests --- src/neo4j_graphrag/llm/cohere_llm.py | 4 ++-- tests/unit/llm/test_cohere_llm.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index fa6de2d7c..1bfb7d77f 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -108,7 +108,7 @@ def _invoke( model=self.model_name, ) except self.cohere_api_error as e: - raise LLMGenerationError(e) + raise LLMGenerationError("Error calling cohere") from e return LLMResponse( content=res.message.content[0].text if res.message.content else "", ) @@ -132,7 +132,7 @@ async def _ainvoke( model=self.model_name, ) except self.cohere_api_error as e: - raise LLMGenerationError(e) + raise LLMGenerationError("Error calling cohere") from e return LLMResponse( content=res.message.content[0].text if res.message.content else "", ) diff --git a/tests/unit/llm/test_cohere_llm.py b/tests/unit/llm/test_cohere_llm.py index c3b43dbc4..fb968d0c9 100644 --- a/tests/unit/llm/test_cohere_llm.py +++ b/tests/unit/llm/test_cohere_llm.py @@ -70,7 +70,7 @@ async def test_cohere_llm_happy_path_async(mock_cohere: Mock) -> None: def test_cohere_llm_failed(mock_cohere: Mock) -> None: mock_cohere.ClientV2.return_value.chat.side_effect = cohere.core.ApiError llm = CohereLLM(model_name="something") - with pytest.raises(LLMGenerationError, match="ApiError"): + with pytest.raises(LLMGenerationError, match="Error calling cohere"): llm.invoke("my text") @@ -79,5 +79,5 @@ async def test_cohere_llm_failed_async(mock_cohere: Mock) -> None: mock_cohere.AsyncClientV2.return_value.chat.side_effect = cohere.core.ApiError llm = CohereLLM(model_name="something") - with pytest.raises(LLMGenerationError, match="ApiError"): + with pytest.raises(LLMGenerationError, match="Error calling cohere"): await llm.ainvoke("my text") From dd7ab908706c3194d1f650e333d879a9e9c00eaa Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 1 Oct 2025 17:07:32 +0200 Subject: [PATCH 21/28] Add tests for retry behavior --- tests/unit/llm/test_base.py | 84 ++++++++++++++++++++++++++++++++++++- 1 file changed, 82 insertions(+), 2 deletions(-) diff --git a/tests/unit/llm/test_base.py b/tests/unit/llm/test_base.py index 9a927b193..dd5b5c86b 100644 --- a/tests/unit/llm/test_base.py +++ b/tests/unit/llm/test_base.py @@ -1,12 +1,20 @@ +"""The base LLMInterface is responsible +for formatting the inputs as a list of LLMMessage objects +and handling the rate limits. This is what is being tested +in this file. 
+""" + from typing import Type, Generator -from unittest.mock import patch, Mock +from unittest import mock +from unittest.mock import patch, Mock, call import pytest +import tenacity from joblib.testing import fixture from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError -from neo4j_graphrag.llm import LLMInterface +from neo4j_graphrag.llm import LLMInterface, LLMResponse from neo4j_graphrag.types import LLMMessage @@ -134,3 +142,75 @@ def test_base_llm_interface_invoke_with_tools_with_invalid_inputs( None, None, ) + + +@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") +def test_base_llm_interface_invoke_retry_ok( + mock_inputs: Mock, llm_interface: Type[LLMInterface] +): + mock_inputs.return_value = [ + LLMMessage( + role="user", + content="return value of the legacy_inputs_to_messages function", + ) + ] + llm = llm_interface(model_name="test") + question = "What about next season?" + + with mock.patch.object(llm, "_invoke") as mock_invoke_core: + mock_invoke_core.side_effect = [ + LLMGenerationError("rate limit"), + LLMResponse(content="all good"), + ] + res = llm.invoke(question, []) + assert res.content == "all good" + call_args = [ + { + "role": "user", + "content": "return value of the legacy_inputs_to_messages function", + } + ] + assert mock_invoke_core.call_count == 2 + mock_invoke_core.assert_has_calls( + [ + call(call_args), + call(call_args), + ] + ) + + +@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") +def test_base_llm_interface_invoke_retry_fail( + mock_inputs: Mock, llm_interface: Type[LLMInterface] +): + mock_inputs.return_value = [ + LLMMessage( + role="user", + content="return value of the legacy_inputs_to_messages function", + ) + ] + llm = llm_interface(model_name="test") + question = "What about next season?" + + with mock.patch.object(llm, "_invoke") as mock_invoke_core: + mock_invoke_core.side_effect = [ + LLMGenerationError("rate limit"), + LLMGenerationError("rate limit"), + LLMGenerationError("rate limit"), + ] + with pytest.raises(tenacity.RetryError): + llm.invoke(question, []) + call_args = [ + { + "role": "user", + "content": "return value of the legacy_inputs_to_messages function", + } + ] + assert mock_invoke_core.call_count == 3 + mock_invoke_core.assert_has_calls( + [ + call(call_args), + call(call_args), + call(call_args), + ] + ) From 526e3f319a9c65f0d5998d70c72eb4f39626a4f1 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 1 Oct 2025 17:10:39 +0200 Subject: [PATCH 22/28] Fix MistralAILLM tests --- tests/unit/llm/test_mistralai_llm.py | 91 ---------------------------- 1 file changed, 91 deletions(-) diff --git a/tests/unit/llm/test_mistralai_llm.py b/tests/unit/llm/test_mistralai_llm.py index 324798f2f..f22cfb3e6 100644 --- a/tests/unit/llm/test_mistralai_llm.py +++ b/tests/unit/llm/test_mistralai_llm.py @@ -46,97 +46,6 @@ def test_mistralai_llm_invoke(mock_mistral: Mock) -> None: assert res.content == "mistral response" -@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") -def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None: - mock_mistral_instance = mock_mistral.return_value - chat_response_mock = MagicMock() - chat_response_mock.choices = [ - MagicMock(message=MagicMock(content="mistral response")) - ] - mock_mistral_instance.chat.complete.return_value = chat_response_mock - model = "mistral-model" - system_instruction = "You are a helpful assistant." 
- - llm = MistralAILLM(model_name=model) - - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore - - assert isinstance(res, LLMResponse) - assert res.content == "mistral response" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.complete.assert_called_once_with( # type: ignore[attr-defined] - messages=messages, - model=model, - ) - - -@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") -def test_mistralai_llm_invoke_with_message_history_and_system_instruction( - mock_mistral: Mock, -) -> None: - mock_mistral_instance = mock_mistral.return_value - chat_response_mock = MagicMock() - chat_response_mock.choices = [ - MagicMock(message=MagicMock(content="mistral response")) - ] - mock_mistral_instance.chat.complete.return_value = chat_response_mock - model = "mistral-model" - system_instruction = "You are a helpful assistant." - llm = MistralAILLM(model_name=model) - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - # first invocation - initial instructions - res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "mistral response" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.complete.assert_called_once_with( # type: ignore[attr-defined] - messages=messages, - model=model, - ) - - assert llm.client.chat.complete.call_count == 1 # type: ignore - - -@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") -def test_mistralai_llm_invoke_with_message_history_validation_error( - mock_mistral: Mock, -) -> None: - mock_mistral_instance = mock_mistral.return_value - chat_response_mock = MagicMock() - chat_response_mock.choices = [ - MagicMock(message=MagicMock(content="mistral response")) - ] - mock_mistral_instance.chat.complete.return_value = chat_response_mock - model = "mistral-model" - system_instruction = "You are a helpful assistant." - - llm = MistralAILLM(model_name=model, system_instruction=system_instruction) - - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "monkey", "content": "Usually around 6am."}, - ] - question = "What about next season?" 
- - with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, message_history) # type: ignore - assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) - - @pytest.mark.asyncio @patch("neo4j_graphrag.llm.mistralai_llm.Mistral") async def test_mistralai_llm_ainvoke(mock_mistral: Mock) -> None: From d0e189c6f8c8bee6d77a836acce9f697b4937c4e Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 1 Oct 2025 17:25:04 +0200 Subject: [PATCH 23/28] Fix VertexAILLM tests --- tests/unit/llm/test_vertexai_llm.py | 100 ++++------------------------ 1 file changed, 14 insertions(+), 86 deletions(-) diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index 5d0e9b959..466a77b49 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -from typing import cast from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -59,79 +58,14 @@ def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None: assert content[0].parts[0].text == input_text -@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") -@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM.get_messages") -def test_vertexai_invoke_with_system_instruction( - mock_get_messages: MagicMock, - GenerativeModelMock: MagicMock, -) -> None: - system_instruction = "You are a helpful assistant." - model_name = "gemini-1.5-flash-001" - input_text = "may thy knife chip and shatter" - mock_response = Mock() - mock_response.text = "Return text" - mock_model = GenerativeModelMock.return_value - mock_model.generate_content.return_value = mock_response - - mock_get_messages.return_value = [{"text": "some text"}] - - model_params = {"temperature": 0.5} - llm = VertexAILLM(model_name, model_params) - - response = llm.invoke(input_text, system_instruction=system_instruction) - assert response.content == "Return text" - GenerativeModelMock.assert_called_once_with( - model_name=model_name, - system_instruction=system_instruction, - ) - mock_model.generate_content.assert_called_once_with( - contents=[{"text": "some text"}] - ) - - -@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") -def test_vertexai_invoke_with_message_history_and_system_instruction( - GenerativeModelMock: MagicMock, -) -> None: - system_instruction = "You are a helpful assistant." - model_name = "gemini-1.5-flash-001" - mock_response = Mock() - mock_response.text = "Return text" - mock_model = GenerativeModelMock.return_value - mock_model.generate_content.return_value = mock_response - model_params = {"temperature": 0.5} - llm = VertexAILLM(model_name, model_params) - - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" 
- - response = llm.invoke( - question, - message_history, # type: ignore - system_instruction=system_instruction, - ) - assert response.content == "Return text" - GenerativeModelMock.assert_called_once_with( - model_name=model_name, - system_instruction=system_instruction, - ) - last_call = mock_model.generate_content.call_args_list[0] - content = last_call.kwargs["contents"] - assert len(content) == 3 # question + 2 messages in history - - @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None: model_name = "gemini-1.5-flash-001" - question = "When does it set?" message_history: list[LLMMessage] = [ + {"role": "system", "content": "Answer to a 3yo kid"}, {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, {"role": "user", "content": "What about next season?"}, - {"role": "assistant", "content": "Around 8am."}, ] expected_response = [ Content( @@ -140,16 +74,15 @@ def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None: ), Content(role="model", parts=[Part.from_text("Usually around 6am.")]), Content(role="user", parts=[Part.from_text("What about next season?")]), - Content(role="model", parts=[Part.from_text("Around 8am.")]), - Content(role="user", parts=[Part.from_text("When does it set?")]), ] llm = VertexAILLM(model_name=model_name) - response = llm.get_messages(question, message_history) + system_instructions, messages = llm.get_messages(message_history) GenerativeModelMock.assert_not_called() - assert len(response) == len(expected_response) - for actual, expected in zip(response, expected_response): + assert system_instructions == "Answer to a 3yo kid" + assert len(messages) == len(expected_response) + for actual, expected in zip(messages, expected_response): assert actual.role == expected.role assert actual.parts[0].text == expected.parts[0].text @@ -164,9 +97,8 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock) ] llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction) - with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, cast(list[LLMMessage], message_history)) - assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) + with pytest.raises(LLMGenerationError, match="Input validation failed"): + llm.invoke(question, message_history) @pytest.mark.asyncio @@ -179,7 +111,7 @@ async def test_vertexai_ainvoke_happy_path( mock_response.text = "Return text" mock_model = GenerativeModelMock.return_value mock_model.generate_content_async = AsyncMock(return_value=mock_response) - mock_get_messages.return_value = [{"text": "Return text"}] + mock_get_messages.return_value = None, [{"text": "Return text"}] model_params = {"temperature": 0.5} llm = VertexAILLM("gemini-1.5-flash-001", model_params) input_text = "may thy knife chip and shatter" @@ -223,9 +155,7 @@ def test_vertexai_invoke_with_tools( res = llm.invoke_with_tools("my text", tools) mock_call_llm.assert_called_once_with( - "my text", - message_history=None, - system_instruction=None, + [{"role": "user", "content": "my text"}], tools=tools, ) mock_parse_tool.assert_called_once() @@ -244,11 +174,11 @@ def test_vertexai_call_llm_with_tools(mock_model: Mock, test_tool: Tool) -> None tools = [test_tool] with patch.object(llm, "_get_llm_tools", return_value=["my tools"]): - res = llm._call_llm("my text", tools=tools) + res = llm._call_llm([{"role": "user", "content": "my text"}], 
tools=tools) assert isinstance(res, GenerationResponse) mock_model.assert_called_once_with( - system_instruction=None, + None, ) calls = mock_generate_content.call_args_list assert len(calls) == 1 @@ -277,9 +207,7 @@ def test_vertexai_ainvoke_with_tools( res = llm.invoke_with_tools("my text", tools) mock_call_llm.assert_called_once_with( - "my text", - message_history=None, - system_instruction=None, + [{"role": "user", "content": "my text"}], tools=tools, ) mock_parse_tool.assert_called_once() @@ -301,8 +229,8 @@ async def test_vertexai_acall_llm_with_tools(mock_model: Mock, test_tool: Tool) llm = VertexAILLM(model_name="gemini") tools = [test_tool] - res = await llm._acall_llm("my text", tools=tools) + res = await llm._acall_llm([{"role": "user", "content": "my text"}], tools=tools) mock_model.assert_called_once_with( - system_instruction=None, + None, ) assert isinstance(res, GenerationResponse) From 8db779acbb5f6a0343068f17efd97109a31d11d2 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 1 Oct 2025 17:26:27 +0200 Subject: [PATCH 24/28] mypy --- tests/unit/llm/test_base.py | 4 ++-- tests/unit/llm/test_openai_llm.py | 2 +- tests/unit/llm/test_vertexai_llm.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/llm/test_base.py b/tests/unit/llm/test_base.py index dd5b5c86b..4eff7cb94 100644 --- a/tests/unit/llm/test_base.py +++ b/tests/unit/llm/test_base.py @@ -147,7 +147,7 @@ def test_base_llm_interface_invoke_with_tools_with_invalid_inputs( @patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") def test_base_llm_interface_invoke_retry_ok( mock_inputs: Mock, llm_interface: Type[LLMInterface] -): +) -> None: mock_inputs.return_value = [ LLMMessage( role="user", @@ -182,7 +182,7 @@ def test_base_llm_interface_invoke_retry_ok( @patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") def test_base_llm_interface_invoke_retry_fail( mock_inputs: Mock, llm_interface: Type[LLMInterface] -): +) -> None: mock_inputs.return_value = [ LLMMessage( role="user", diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 2b2cb29f2..39571989f 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -71,7 +71,7 @@ def test_openai_llm_get_messages() -> None: def test_openai_llm_get_messages_unknown_role() -> None: llm = OpenAILLM(api_key="my key", model_name="gpt") message_history = [ - LLMMessage(**{"role": "unknown role", "content": "Usually around 6am."}), + LLMMessage(**{"role": "unknown role", "content": "Usually around 6am."}), # type: ignore[typeddict-item] ] with pytest.raises(ValueError, match="Unknown role"): llm.get_messages(message_history) diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index 466a77b49..96c66f138 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -93,7 +93,7 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock) model_name = "gemini-1.5-flash-001" question = "hi!" 
message_history = [ - {"role": "model", "content": "hello!"}, + LLMMessage(**{"role": "model", "content": "hello!"}), ] llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction) From 59a55801f1a4946ac25a5b30aad8bec4e9766ef5 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 1 Oct 2025 17:28:34 +0200 Subject: [PATCH 25/28] Address comments --- examples/customize/llms/custom_llm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py index 554629d4a..eccd7beec 100644 --- a/examples/customize/llms/custom_llm.py +++ b/examples/customize/llms/custom_llm.py @@ -37,6 +37,8 @@ async def _ainvoke( res: LLMResponse = llm.invoke("text") print(res.content) +# If you want to use a custom rate limit handler +# Type variables for function signatures used in rate limit handlers F = TypeVar("F", bound=Callable[..., Any]) AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) From 9f6fff2d73a4ffe96b1f0952decfae2779d907bf Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 1 Oct 2025 17:42:49 +0200 Subject: [PATCH 26/28] Fix mypy again --- tests/unit/llm/test_vertexai_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index 96c66f138..effbfadf8 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -93,7 +93,7 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock) model_name = "gemini-1.5-flash-001" question = "hi!" message_history = [ - LLMMessage(**{"role": "model", "content": "hello!"}), + LLMMessage(**{"role": "model", "content": "hello!"}), # type: ignore[typeddict-item] ] llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction) From 394f9ebc98581e4a14ebd13f3c0cf3a5e3a15ad1 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 1 Oct 2025 17:54:27 +0200 Subject: [PATCH 27/28] Fix e2e --- tests/e2e/test_graphrag_e2e.py | 57 +++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/tests/e2e/test_graphrag_e2e.py b/tests/e2e/test_graphrag_e2e.py index 895a9adb0..d747aec08 100644 --- a/tests/e2e/test_graphrag_e2e.py +++ b/tests/e2e/test_graphrag_e2e.py @@ -60,7 +60,14 @@ def test_graphrag_happy_path( ) llm.invoke.assert_called_once_with( - """Context: + [ + { + "role": "system", + "content": "Answer the user question using the provided context.", + }, + { + "role": "user", + "content": """Context: @@ -72,8 +79,8 @@ def test_graphrag_happy_path( Answer: """, - None, - system_instruction="Answer the user question using the provided context.", + }, + ] ) assert isinstance(result, RagResultModel) assert result.answer == "some text" @@ -148,13 +155,21 @@ def test_graphrag_happy_path_with_neo4j_message_history( llm.invoke.assert_has_calls( [ call( - input=first_invocation_input, - system_instruction=first_invocation_system_instruction, + [ + {"role": "system", "content": first_invocation_system_instruction}, + {"role": "user", "content": first_invocation_input}, + ] ), call( - second_invocation, - message_history.messages, - system_instruction="Answer the user question using the provided context.", + [ + { + "role": "system", + "content": "Answer the user question using the provided context.", + }, + {"role": "user", "content": "initial question"}, + {"role": "assistant", "content": "answer to initial question"}, + {"role": "user", "content": second_invocation}, + ] ), ] ) @@ -190,7 +205,14 @@ def 
test_graphrag_happy_path_return_context( ) llm.invoke.assert_called_once_with( - """Context: + [ + { + "role": "system", + "content": "Answer the user question using the provided context.", + }, + { + "role": "user", + "content": """Context: @@ -202,8 +224,8 @@ def test_graphrag_happy_path_return_context( Answer: """, - None, - system_instruction="Answer the user question using the provided context.", + }, + ], ) assert isinstance(result, RagResultModel) assert result.answer == "some text" @@ -236,7 +258,14 @@ def test_graphrag_happy_path_examples( ) llm.invoke.assert_called_once_with( - """Context: + [ + { + "role": "system", + "content": "Answer the user question using the provided context.", + }, + { + "role": "user", + "content": """Context: @@ -248,8 +277,8 @@ def test_graphrag_happy_path_examples( Answer: """, - None, - system_instruction="Answer the user question using the provided context.", + }, + ] ) assert result.answer == "some text" From 1027c9a65e4ac3590d2b84e731f4354e341443a0 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 1 Oct 2025 18:35:54 +0200 Subject: [PATCH 28/28] Fix CI --- src/neo4j_graphrag/llm/utils.py | 1 + tests/unit/test_graphrag.py | 54 +++++++++++++++++++++++---------- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/src/neo4j_graphrag/llm/utils.py b/src/neo4j_graphrag/llm/utils.py index b61a880f4..5746ca91c 100644 --- a/src/neo4j_graphrag/llm/utils.py +++ b/src/neo4j_graphrag/llm/utils.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import warnings from typing import Union, Optional diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py index 925b48b78..b58d8a9e8 100644 --- a/tests/unit/test_graphrag.py +++ b/tests/unit/test_graphrag.py @@ -63,7 +63,14 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None: retriever_mock.search.assert_called_once_with(query_text="question", top_k=111) llm.invoke.assert_called_once_with( - """Context: + [ + { + "role": "system", + "content": "Answer the user question using the provided context.", + }, + { + "role": "user", + "content": """Context: item content 1 item content 2 @@ -75,8 +82,8 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None: Answer: """, - None, # message history - system_instruction="Answer the user question using the provided context.", + }, + ] ) assert isinstance(res, RagResultModel) @@ -142,13 +149,20 @@ def test_graphrag_happy_path_with_message_history( llm.invoke.assert_has_calls( [ call( - input=first_invocation_input, - system_instruction=first_invocation_system_instruction, + [ + {"role": "system", "content": first_invocation_system_instruction}, + {"role": "user", "content": first_invocation_input}, + ] ), call( - second_invocation, - message_history, - system_instruction="Answer the user question using the provided context.", + [ + { + "role": "system", + "content": "Answer the user question using the provided context.", + }, + *message_history, + {"role": "user", "content": second_invocation}, + ] ), ] ) @@ -218,13 +232,20 @@ def test_graphrag_happy_path_with_in_memory_message_history( llm.invoke.assert_has_calls( [ call( - input=first_invocation_input, - system_instruction=first_invocation_system_instruction, + [ + {"role": "system", "content": first_invocation_system_instruction}, + {"role": "user", "content": 
first_invocation_input}, + ] ), call( - second_invocation, - message_history.messages, - system_instruction="Answer the user question using the provided context.", + [ + { + "role": "system", + "content": "Answer the user question using the provided context.", + }, + *message_history.messages, + {"role": "user", "content": second_invocation}, + ] ), ] ) @@ -253,9 +274,10 @@ def test_graphrag_happy_path_custom_system_instruction( llm.invoke.assert_has_calls( [ call( - mock.ANY, - None, # no message history - system_instruction="Custom instruction", + [ + {"role": "system", "content": "Custom instruction"}, + {"role": "user", "content": mock.ANY}, + ] ), ] )
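
The patches above converge on a single LLMInterface contract: provider subclasses implement _invoke/_ainvoke over a normalised list[LLMMessage], while the public invoke keeps accepting the legacy (input, message_history, system_instruction) arguments and applies the rate-limit retry exercised in tests/unit/llm/test_base.py. The snippet below is a minimal sketch of that contract, not code from the patch series; the EchoLLM class, its model name and its echo behaviour are illustrative assumptions only.

from neo4j_graphrag.llm import LLMInterface, LLMResponse
from neo4j_graphrag.types import LLMMessage


class EchoLLM(LLMInterface):
    """Illustrative-only subclass: echoes the last message back."""

    def _invoke(self, input: list[LLMMessage]) -> LLMResponse:
        # By the time _invoke runs, the base invoke() has already merged the
        # question, message history and system instruction into one list of
        # role/content dicts, as asserted in the updated unit tests.
        return LLMResponse(content=input[-1]["content"])

    async def _ainvoke(self, input: list[LLMMessage]) -> LLMResponse:
        return self._invoke(input)


llm = EchoLLM(model_name="echo")
history: list[LLMMessage] = [
    {"role": "user", "content": "When does the sun come up in the summer?"},
    {"role": "assistant", "content": "Usually around 6am."},
]
# Legacy call style: a question plus a separate message history. The base
# class flattens both into one list before calling _invoke, and retries on
# LLMGenerationError as covered by the new retry tests.
res = llm.invoke("What about next season?", history)
print(res.content)

Under these assumptions the script should print "What about next season?", since the base class appends the new user turn to the supplied history before dispatching to _invoke.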