From a66b8769ef4ce28d503301c0c45c779b0846297c Mon Sep 17 00:00:00 2001 From: George Burton Date: Thu, 5 Sep 2024 19:03:55 +0100 Subject: [PATCH] tests passing --- .../redbox_app/redbox_core/consumers.py | 40 +-- django_app/tests/test_consumers.py | 241 +++++++++++------- 2 files changed, 168 insertions(+), 113 deletions(-) diff --git a/django_app/redbox_app/redbox_core/consumers.py b/django_app/redbox_app/redbox_core/consumers.py index efe9f6350..39216f2b3 100644 --- a/django_app/redbox_app/redbox_core/consumers.py +++ b/django_app/redbox_app/redbox_core/consumers.py @@ -87,15 +87,15 @@ async def llm_conversation(self, selected_files: Sequence[File], session: Chat, ), ) - await self.redbox.run( - state, - response_tokens_callback=self.handle_text, - route_name_callback=self.handle_route, - documents_callback=self.handle_documents, - metadata_tokens_callback=self.handle_metadata, - ) - try: + await self.redbox.run( + state, + response_tokens_callback=self.handle_text, + route_name_callback=self.handle_route, + documents_callback=self.handle_documents, + metadata_tokens_callback=self.handle_metadata, + ) + message = await self.save_message( session, "".join(self.full_reply), @@ -193,25 +193,25 @@ def get_ai_settings(user: User) -> AISettings: fields=[field.name for field in user.ai_settings._meta.fields if field.name != "label"], # noqa: SLF001 ) - async def handle_text(self, response: ClientResponse) -> str: - await self.send_to_client("text", response.data) - self.full_reply.append(response.data) + async def handle_text(self, response: str) -> str: + await self.send_to_client("text", response) + self.full_reply.append(response) - async def handle_route(self, response: ClientResponse) -> str: - await self.send_to_client("route", response.data) - self.routes.append(response.data) + async def handle_route(self, response: str) -> str: + await self.send_to_client("route", response) + self.route = response - async def handle_metadata(self, response: ClientResponse): - for model, token_count in response.data.input_tokens.items(): + async def handle_metadata(self, response: MetadataDetail): + for model, token_count in response.input_tokens.items(): self.metadata.input_tokens[model] = self.metadata.input_tokens.get(model, 0) + token_count - for model, token_count in response.data.output_tokens.items(): + for model, token_count in response.output_tokens.items(): self.metadata.output_tokens[model] = self.metadata.output_tokens.get(model, 0) + token_count - async def handle_documents(self, response: ClientResponse) -> Sequence[tuple[File, SourceDocument]]: - s3_keys = [doc.s3_key for doc in response.data] + async def handle_documents(self, response: list[SourceDocument]) -> Sequence[tuple[File, SourceDocument]]: + s3_keys = [doc.s3_key for doc in response] files = File.objects.filter(original_file__in=s3_keys) async for file in files: await self.send_to_client("source", {"url": str(file.url), "original_file_name": file.original_file_name}) for file in files: - self.citations.append((file, [doc for doc in response.data if doc.s3_key == file.unique_name])) + self.citations.append((file, [doc for doc in response if doc.s3_key == file.unique_name])) diff --git a/django_app/tests/test_consumers.py b/django_app/tests/test_consumers.py index b2c3c4e09..06b006031 100644 --- a/django_app/tests/test_consumers.py +++ b/django_app/tests/test_consumers.py @@ -10,9 +10,13 @@ from channels.db import database_sync_to_async from channels.testing import WebsocketCommunicator from django.db.models import Model +from langchain_core.language_models import BaseChatModel +from pydantic import BaseModel from websockets import WebSocketClientProtocol from websockets.legacy.client import Connect +from redbox.graph.root import FINAL_RESPONSE_TAG, ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG +from redbox.models.chat import MetadataDetail, SourceDocument from redbox_app.redbox_core import error_messages from redbox_app.redbox_core.consumers import ChatConsumer from redbox_app.redbox_core.models import Chat, ChatMessage, ChatMessageTokenUse, ChatRoleEnum, File, User @@ -38,7 +42,7 @@ async def test_chat_consumer_with_new_session(alice: User, uploaded_file: File, # Given # When - with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect): + with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_connect): communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice connected, _ = await communicator.connect() @@ -85,7 +89,7 @@ async def test_chat_consumer_staff_user(staff_user: User, mocked_connect: Connec # Given # When - with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect): + with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_connect): communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = staff_user connected, _ = await communicator.connect() @@ -118,7 +122,7 @@ async def test_chat_consumer_with_existing_session(alice: User, chat: Chat, mock # Given # When - with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect): + with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_connect): communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice connected, _ = await communicator.connect() @@ -144,7 +148,7 @@ async def test_chat_consumer_with_naughty_question(alice: User, uploaded_file: F # Given # When - with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect): + with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_connect): communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice connected, _ = await communicator.connect() @@ -185,7 +189,7 @@ async def test_chat_consumer_with_naughty_citation( # Given # When - with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_naughty_citation): + with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_connect_with_naughty_citation): communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice connected, _ = await communicator.connect() @@ -248,7 +252,7 @@ async def test_chat_consumer_with_selected_files( selected_files: Sequence[File] = several_files[2:] # When - with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_several_files): + with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_connect_with_several_files): communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice connected, _ = await communicator.connect() @@ -306,7 +310,7 @@ async def test_chat_consumer_with_connection_error(alice: User, mocked_breaking_ # Given # When - with patch("redbox_app.redbox_core.consumers.connect", new=mocked_breaking_connect): + with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_breaking_connect): communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice connected, _ = await communicator.connect() @@ -328,7 +332,9 @@ async def test_chat_consumer_with_explicit_unhandled_error( # Given # When - with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_explicit_unhandled_error): + with patch( + "redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_connect_with_explicit_unhandled_error + ): communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice connected, _ = await communicator.connect() @@ -355,7 +361,9 @@ async def test_chat_consumer_with_rate_limited_error(alice: User, mocked_connect # Given # When - with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_rate_limited_error): + with patch( + "redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_connect_with_rate_limited_error + ): communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice connected, _ = await communicator.connect() @@ -384,7 +392,10 @@ async def test_chat_consumer_with_explicit_no_document_selected_error( # Given # When - with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_explicit_no_document_selected_error): + with patch( + "redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", + new=mocked_connect_with_explicit_no_document_selected_error, + ): communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice connected, _ = await communicator.connect() @@ -407,7 +418,10 @@ async def test_chat_consumer_with_explicit_no_document_selected_error( async def test_chat_consumer_get_ai_settings( alice: User, mocked_connect_with_explicit_no_document_selected_error: Connect ): - with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_explicit_no_document_selected_error): + with patch( + "redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", + new=mocked_connect_with_explicit_no_document_selected_error, + ): communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice connected, _ = await communicator.connect() @@ -434,111 +448,152 @@ def get_chat_messages(user: User) -> Sequence[ChatMessage]: ) +class Token(BaseModel): + content: str + + +class CannedGraphLLM(BaseChatModel): + responses: list[dict] + + def _generate(self, *_args, **_kwargs): + for _ in self.responses: + yield + + def _llm_type(self): + return "go away" + + def _convert_input(self, prompt): + if isinstance(prompt, dict): + prompt = prompt["request"].question + return super()._convert_input(prompt) + + async def astream_events(self, *_args, **_kwargs): + for response in self.responses: + yield response + + @pytest.fixture() def mocked_connect(uploaded_file: File) -> Connect: - mocked_websocket = AsyncMock(spec=WebSocketClientProtocol, name="mocked_websocket") - mocked_connect = MagicMock(spec=Connect, name="mocked_connect") - mocked_connect.return_value.__aenter__.return_value = mocked_websocket - mocked_websocket.__aiter__.return_value = [ - json.dumps({"resource_type": "text", "data": "Good afternoon, "}), - json.dumps({"resource_type": "text", "data": "Mr. Amor."}), - json.dumps({"resource_type": "route_name", "data": "gratitude"}), - json.dumps( - { - "resource_type": "documents", - "data": [{"s3_key": uploaded_file.unique_name, "page_content": "Good afternoon Mr Amor"}], - } - ), - json.dumps( - { - "resource_type": "documents", - "data": [ - {"s3_key": uploaded_file.unique_name, "page_content": "Good afternoon Mr Amor"}, - { - "s3_key": uploaded_file.unique_name, - "page_content": "Good afternoon Mr Amor", - "page_numbers": [34, 35], - }, - ], - } - ), - json.dumps( - { - "resource_type": "metadata", - "data": {"input_tokens": {"gpt-4o": 123}, "output_tokens": {"gpt-4o": 1000}}, - } - ), - json.dumps({"resource_type": "end"}), + responses = [ + { + "event": "on_chat_model_stream", + "tags": [FINAL_RESPONSE_TAG], + "data": {"chunk": Token(content="Good afternoon, ")}, + }, + {"event": "on_chat_model_stream", "tags": [FINAL_RESPONSE_TAG], "data": {"chunk": Token(content="Mr. Amor.")}}, + {"event": "on_chain_end", "tags": [ROUTE_NAME_TAG], "data": {"output": "gratitude"}}, + { + "event": "on_retriever_end", + "tags": [SOURCE_DOCUMENTS_TAG], + "data": { + "output": [SourceDocument(s3_key=uploaded_file.unique_name, page_content="Good afternoon Mr Amor")] + }, + }, + { + "event": "on_retriever_end", + "tags": [SOURCE_DOCUMENTS_TAG], + "data": { + "output": [ + SourceDocument(s3_key=uploaded_file.unique_name, page_content="Good afternoon Mr Amor"), + SourceDocument( + s3_key=uploaded_file.unique_name, page_content="Good afternoon Mr Amor", page_numbers=[34, 35] + ), + ] + }, + }, + { + "event": "on_custom_event", + "name": "on_metadata_generation", + "data": MetadataDetail(input_tokens={"gpt-4o": 123}, output_tokens={"gpt-4o": 1000}), + }, ] - return mocked_connect + + return CannedGraphLLM(responses=responses) @pytest.fixture() -def mocked_connect_with_naughty_citation(uploaded_file: File) -> Connect: - mocked_websocket = AsyncMock(spec=WebSocketClientProtocol, name="mocked_websocket") - mocked_connect = MagicMock(spec=Connect, name="mocked_connect") - mocked_connect.return_value.__aenter__.return_value = mocked_websocket - mocked_websocket.__aiter__.return_value = [ - json.dumps({"resource_type": "text", "data": "Good afternoon, Mr. Amor."}), - json.dumps({"resource_type": "route_name", "data": "gratitude"}), - json.dumps( - { - "resource_type": "documents", - "data": [ - {"s3_key": uploaded_file.unique_name, "page_content": "Good afternoon Mr Amor"}, - {"s3_key": uploaded_file.unique_name, "page_content": "I shouldn't send a \x00"}, - ], - } - ), - json.dumps({"resource_type": "end"}), +def mocked_connect_with_naughty_citation(uploaded_file: File) -> CannedGraphLLM: + responses = [ + { + "event": "on_chat_model_stream", + "tags": [FINAL_RESPONSE_TAG], + "data": {"chunk": Token(content="Good afternoon, Mr. Amor.")}, + }, + {"event": "on_chain_end", "tags": [ROUTE_NAME_TAG], "data": {"output": "gratitude"}}, + { + "event": "on_retriever_end", + "tags": [SOURCE_DOCUMENTS_TAG], + "data": { + "output": [ + SourceDocument(s3_key=uploaded_file.unique_name, page_content="Good afternoon Mr Amor"), + SourceDocument(s3_key=uploaded_file.unique_name, page_content="I shouldn't send a \x00"), + ] + }, + }, ] - return mocked_connect + + return CannedGraphLLM(responses=responses) @pytest.fixture() def mocked_breaking_connect() -> Connect: - mocked_websocket = AsyncMock(spec=WebSocketClientProtocol, name="mocked_websocket") - mocked_connect = MagicMock(spec=Connect, name="mocked_connect") - mocked_connect.return_value.__aenter__.return_value = mocked_websocket - mocked_websocket.__aiter__.side_effect = CancelledError() - return mocked_connect + mocked_graph = MagicMock(name="mocked_graph") + mocked_graph.astream_events.side_effect = CancelledError() + return mocked_graph @pytest.fixture() -def mocked_connect_with_explicit_unhandled_error() -> Connect: - mocked_websocket = AsyncMock(spec=WebSocketClientProtocol, name="mocked_websocket") - mocked_connect = MagicMock(spec=Connect, name="mocked_connect") - mocked_connect.return_value.__aenter__.return_value = mocked_websocket - mocked_websocket.__aiter__.return_value = [ - json.dumps({"resource_type": "text", "data": "Good afternoon, "}), - json.dumps({"resource_type": "error", "data": {"code": "unknown", "message": "Oh dear."}}), +def mocked_connect_with_explicit_unhandled_error() -> CannedGraphLLM: + responses = [ + { + "event": "on_chat_model_stream", + "tags": [FINAL_RESPONSE_TAG], + "data": {"chunk": Token(content="Good afternoon, ")}, + }, + { + "event": "on_chat_model_stream", + "tags": [FINAL_RESPONSE_TAG], + "data": {"chunk": Token(content=error_messages.CORE_ERROR_MESSAGE)}, + }, ] - return mocked_connect + + return CannedGraphLLM(responses=responses) @pytest.fixture() -def mocked_connect_with_rate_limited_error() -> Connect: - mocked_websocket = AsyncMock(spec=WebSocketClientProtocol, name="mocked_websocket") - mocked_connect = MagicMock(spec=Connect, name="mocked_connect") - mocked_connect.return_value.__aenter__.return_value = mocked_websocket - mocked_websocket.__aiter__.return_value = [ - json.dumps({"resource_type": "text", "data": "Good afternoon, "}), - json.dumps( - {"resource_type": "error", "data": {"code": "rate-limit", "message": "HTTP/1.1 429 Too Many Requests"}} - ), +def mocked_connect_with_rate_limited_error() -> CannedGraphLLM: + responses = [ + { + "event": "on_chat_model_stream", + "tags": [FINAL_RESPONSE_TAG], + "data": {"chunk": Token(content="Good afternoon, ")}, + }, + { + "event": "on_chat_model_stream", + "tags": [FINAL_RESPONSE_TAG], + "data": {"chunk": Token(content=error_messages.RATE_LIMITED)}, + }, ] - return mocked_connect + + return CannedGraphLLM(responses=responses) @pytest.fixture() -def mocked_connect_with_explicit_no_document_selected_error() -> Connect: - mocked_websocket = AsyncMock(spec=WebSocketClientProtocol, name="mocked_websocket") - mocked_connect = MagicMock(spec=Connect, name="mocked_connect") - mocked_connect.return_value.__aenter__.return_value = mocked_websocket - mocked_websocket.__aiter__.return_value = [ - json.dumps({"resource_type": "error", "data": {"code": "no-document-selected", "message": "whatever"}}), +def mocked_connect_with_explicit_no_document_selected_error() -> CannedGraphLLM: + responses = [ + { + "event": "on_chat_model_stream", + "tags": [FINAL_RESPONSE_TAG], + "data": {"chunk": Token(content="Good afternoon, ")}, + }, + { + "event": "on_chat_model_stream", + "tags": [FINAL_RESPONSE_TAG], + "data": {"chunk": Token(content=error_messages.SELECT_DOCUMENT)}, + }, ] - return mocked_connect + + return CannedGraphLLM(responses=responses) @pytest.fixture()