From e8736fb988460b7d070885b6aea793313c2bb539 Mon Sep 17 00:00:00 2001 From: gecBurton Date: Thu, 9 Jan 2025 16:24:21 +0000 Subject: [PATCH] wip --- .../redbox_app/redbox_core/consumers.py | 4 +- django_app/tests/test_consumers.py | 5 - redbox-core/redbox/app.py | 1 - redbox-core/redbox/graph/nodes/processes.py | 12 +- redbox-core/redbox/graph/root.py | 3 - redbox-core/redbox/models/chain.py | 2 - redbox-core/redbox/retriever/__init__.py | 3 - redbox-core/redbox/retriever/retrievers.py | 32 ---- redbox-core/redbox/test/data.py | 10 +- redbox-core/tests/conftest.py | 60 ------- redbox-core/tests/graph/test_app.py | 11 +- redbox-core/tests/graph/test_patterns.py | 43 +++-- redbox-core/tests/retriever/__init__.py | 0 redbox-core/tests/retriever/data.py | 158 ------------------ redbox-core/tests/retriever/test_retriever.py | 95 ----------- 15 files changed, 37 insertions(+), 402 deletions(-) delete mode 100644 redbox-core/redbox/retriever/__init__.py delete mode 100644 redbox-core/redbox/retriever/retrievers.py delete mode 100644 redbox-core/tests/retriever/__init__.py delete mode 100644 redbox-core/tests/retriever/data.py delete mode 100644 redbox-core/tests/retriever/test_retriever.py diff --git a/django_app/redbox_app/redbox_core/consumers.py b/django_app/redbox_app/redbox_core/consumers.py index 0a3f03896..9aadd1152 100644 --- a/django_app/redbox_app/redbox_core/consumers.py +++ b/django_app/redbox_app/redbox_core/consumers.py @@ -22,7 +22,6 @@ RequestMetadata, metadata_reducer, ) -from redbox.retriever import DjangoFileRetriever from redbox_app.redbox_core import error_messages from redbox_app.redbox_core.models import AISettings as AISettingsModel from redbox_app.redbox_core.models import ( @@ -132,8 +131,7 @@ async def llm_conversation( for message in message_history[:-1] ], ai_settings=ai_settings, - permitted_s3_keys=[f.unique_name async for f in permitted_files], - selected_files_total_tokens=sum(f.metadata.get("token_count", 0) async for f in permitted_files) + selected_files_total_tokens=sum(f.metadata.get("token_count", 0) async for f in permitted_files), ), ) diff --git a/django_app/tests/test_consumers.py b/django_app/tests/test_consumers.py index d87c48b5a..f952e265a 100644 --- a/django_app/tests/test_consumers.py +++ b/django_app/tests/test_consumers.py @@ -449,10 +449,6 @@ async def test_chat_consumer_redbox_state( selected_file_uuids: Sequence[str] = [str(f.id) for f in selected_files] selected_file_keys: Sequence[str] = [f.unique_name for f in selected_files] - permitted_file_keys: Sequence[str] = [ - f.unique_name async for f in File.objects.filter(user=alice, status=File.Status.complete) - ] - assert selected_file_keys != permitted_file_keys await communicator.send_json_to( { @@ -482,7 +478,6 @@ async def test_chat_consumer_redbox_state( {"role": "ai", "text": "A second answer."}, ], ai_settings=ai_settings, - permitted_s3_keys=permitted_file_keys, ) redbox_state = mock_run.call_args.args[0] # pulls out the args that redbox.run was called with diff --git a/redbox-core/redbox/app.py b/redbox-core/redbox/app.py index e7e512b50..006e4fac9 100644 --- a/redbox-core/redbox/app.py +++ b/redbox-core/redbox/app.py @@ -1,7 +1,6 @@ from logging import getLogger from typing import Literal -from langchain_core.vectorstores import VectorStoreRetriever from redbox.graph.root import ( get_chat_with_documents_graph, diff --git a/redbox-core/redbox/graph/nodes/processes.py b/redbox-core/redbox/graph/nodes/processes.py index 1fd9e44df..b3107e469 100644 --- a/redbox-core/redbox/graph/nodes/processes.py +++ b/redbox-core/redbox/graph/nodes/processes.py @@ -1,21 +1,17 @@ import logging import re -from collections.abc import Callable from typing import Any from langchain.schema import StrOutputParser -from langchain_core.documents import Document from langchain_core.messages import AIMessage, HumanMessage -from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel +from langchain_core.runnables import Runnable, RunnableLambda from langchain_core.tools import StructuredTool -from langchain_core.vectorstores import VectorStoreRetriever from redbox.chains.components import get_chat_llm from redbox.chains.runnables import CannedChatLLM, build_llm_chain from redbox.models import ChatRoute -from redbox.models.chain import DocumentState, PromptSet, RedboxState, RequestMetadata -from redbox.models.graph import ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG -from redbox.transform import flatten_document_state +from redbox.models.chain import DocumentState, PromptSet, RedboxState +from redbox.models.graph import ROUTE_NAME_TAG log = logging.getLogger(__name__) re_keyword_pattern = re.compile(r"@(\w+)") @@ -26,8 +22,6 @@ ## Core patterns - - def build_chat_pattern( prompt_set: PromptSet, tools: list[StructuredTool] | None = None, diff --git a/redbox-core/redbox/graph/root.py b/redbox-core/redbox/graph/root.py index 31f255148..4774cc847 100644 --- a/redbox-core/redbox/graph/root.py +++ b/redbox-core/redbox/graph/root.py @@ -1,4 +1,3 @@ -from langchain_core.vectorstores import VectorStoreRetriever from langgraph.graph import END, START, StateGraph from langgraph.graph.graph import CompiledGraph @@ -89,8 +88,6 @@ def get_chat_with_documents_graph( return builder.compile(debug=debug) - - # Root graph def get_root_graph( debug: bool = False, diff --git a/redbox-core/redbox/models/chain.py b/redbox-core/redbox/models/chain.py index 6b66c922d..b1fd7736a 100644 --- a/redbox-core/redbox/models/chain.py +++ b/redbox-core/redbox/models/chain.py @@ -142,11 +142,9 @@ class RedboxQuery(BaseModel): user_uuid: UUID = Field(description="User the chain in executing for") chat_history: list[ChainChatMessage] = Field(description="All previous messages in chat (excluding question)") ai_settings: AISettings = Field(description="User request AI settings", default_factory=AISettings) - permitted_s3_keys: list[str] = Field(description="List of permitted files for response", default_factory=list) selected_files_total_tokens: int = 0 - class LLMCallMetadata(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) llm_model_name: str diff --git a/redbox-core/redbox/retriever/__init__.py b/redbox-core/redbox/retriever/__init__.py deleted file mode 100644 index e0b8042d9..000000000 --- a/redbox-core/redbox/retriever/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .retrievers import DjangoFileRetriever - -__all__ = ["DjangoFileRetriever"] diff --git a/redbox-core/redbox/retriever/retrievers.py b/redbox-core/redbox/retriever/retrievers.py deleted file mode 100644 index 6259d5830..000000000 --- a/redbox-core/redbox/retriever/retrievers.py +++ /dev/null @@ -1,32 +0,0 @@ -from logging import getLogger - -from langchain_core.callbacks import CallbackManagerForRetrieverRun -from langchain_core.documents import Document -from langchain_core.retrievers import BaseRetriever - -from redbox.models.chain import RedboxState -from typing import Any - -logger = getLogger(__name__) - - -class DjangoFileRetriever(BaseRetriever): - file_manager: Any = None - - def _get_relevant_documents( - self, query: RedboxState, *, run_manager: CallbackManagerForRetrieverRun - ) -> list[Document]: - selected_files = set(query.request.s3_keys) - permitted_files = set(query.request.permitted_s3_keys) - - if not selected_files <= permitted_files: - logger.warning( - "User has selected files they aren't permitted to access: \n" - f"{", ".join(selected_files - permitted_files)}" - ) - - file_names = list(selected_files & permitted_files) - - files = self.file_manager.filter(original_file__in=file_names, text__isnull=False, metadata__isnull=False) - - return [Document(page_content=file.text, metadata=file.metadata) for file in files] diff --git a/redbox-core/redbox/test/data.py b/redbox-core/redbox/test/data.py index 10937472b..8747457ef 100644 --- a/redbox-core/redbox/test/data.py +++ b/redbox-core/redbox/test/data.py @@ -115,14 +115,10 @@ def __init__( self.test_id = test_id def get_docs_matching_query(self) -> list[Document]: - return [ - doc - for doc in self.docs - if doc.metadata["uri"] in set(self.query.s3_keys) & set(self.query.permitted_s3_keys) - ] + return [doc for doc in self.docs if doc.metadata["uri"] in set(self.query.s3_keys) & set()] def get_all_permitted_docs(self) -> list[Document]: - return [doc for doc in self.docs if doc.metadata["uri"] in set(self.query.permitted_s3_keys)] + return [doc for doc in self.docs if doc.metadata["uri"] in set()] def generate_test_cases(query: RedboxQuery, test_data: list[RedboxTestData], test_id: str) -> list[RedboxChatTestCase]: @@ -141,8 +137,6 @@ async def _aget_relevant_documents(self, query: str) -> list[Document]: return self.docs - - class GenericFakeChatModelWithTools(GenericFakeChatModel): """A thin wrapper to GenericFakeChatModel that allows tool binding.""" diff --git a/redbox-core/tests/conftest.py b/redbox-core/tests/conftest.py index fc35540d1..db8dac604 100644 --- a/redbox-core/tests/conftest.py +++ b/redbox-core/tests/conftest.py @@ -1,19 +1,13 @@ -from collections.abc import Generator from typing import TYPE_CHECKING import pytest import tiktoken -from _pytest.fixtures import FixtureRequest from botocore.exceptions import ClientError from elasticsearch import Elasticsearch from langchain_core.embeddings.fake import FakeEmbeddings -from langchain_elasticsearch import ElasticsearchStore from tiktoken.core import Encoding from redbox.models.settings import Settings -from redbox.retriever import DjangoFileRetriever -from redbox.test.data import RedboxChatTestCase -from tests.retriever.data import ALL_CHUNKS_RETRIEVER_CASES, METADATA_RETRIEVER_CASES, PARAMETERISED_RETRIEVER_CASES if TYPE_CHECKING: from mypy_boto3_s3.client import S3Client @@ -69,57 +63,3 @@ def es_index(env: Settings) -> str: @pytest.fixture(scope="session") def es_client(env: Settings) -> Elasticsearch: return env.elasticsearch_client() - - -@pytest.fixture(scope="session") -def es_vector_store( - es_client: Elasticsearch, es_index: str, embedding_model: FakeEmbeddings, env: Settings -) -> ElasticsearchStore: - return ElasticsearchStore( - index_name=es_index, - es_connection=es_client, - query_field="text", - vector_query_field=env.embedding_document_field_name, - embedding=embedding_model, - ) - - -class FakeFile: - def __init__(self, text, metadata): - self.text = text - self.metadata = metadata - - -class FileManager: - def __init__(self, docs): - self.docs = docs - - def filter(self, original_file__in, text__isnull, metadata__isnull): - for doc in self.docs: - yield FakeFile(doc.page_content, doc.metadata) - - -# -----# -# Data # -# -----# - - -@pytest.fixture(params=ALL_CHUNKS_RETRIEVER_CASES) -def stored_file_all_chunks(request: FixtureRequest) -> Generator[RedboxChatTestCase, None, None]: - test_case: RedboxChatTestCase = request.param - retriever = DjangoFileRetriever(file_manager=FileManager(test_case.docs)) - yield test_case, retriever - - -@pytest.fixture(params=PARAMETERISED_RETRIEVER_CASES) -def stored_file_parameterised(request: FixtureRequest) -> Generator[RedboxChatTestCase, None, None]: - test_case: RedboxChatTestCase = request.param - retriever = DjangoFileRetriever(file_manager=FileManager(test_case.docs)) - yield test_case, retriever - - -@pytest.fixture(params=METADATA_RETRIEVER_CASES) -def stored_file_metadata(request: FixtureRequest) -> Generator[RedboxChatTestCase, None, None]: - test_case: RedboxChatTestCase = request.param - retriever = DjangoFileRetriever(file_manager=FileManager(test_case.docs)) - yield test_case, retriever diff --git a/redbox-core/tests/graph/test_app.py b/redbox-core/tests/graph/test_app.py index 03b5b3f59..791a46acc 100644 --- a/redbox-core/tests/graph/test_app.py +++ b/redbox-core/tests/graph/test_app.py @@ -35,9 +35,7 @@ def assert_number_of_events(num_of_events: int): test_case for generated_cases in [ generate_test_cases( - query=RedboxQuery( - question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[] - ), + query=RedboxQuery(question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[]), test_data=[ RedboxTestData( number_of_docs=0, @@ -61,9 +59,7 @@ def assert_number_of_events(num_of_events: int): test_id="Basic Chat", ), generate_test_cases( - query=RedboxQuery( - question="What is AI?", s3_keys=["s3_key"], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[] - ), + query=RedboxQuery(question="What is AI?", s3_keys=["s3_key"], user_uuid=uuid4(), chat_history=[]), test_data=[ RedboxTestData( number_of_docs=1, @@ -127,7 +123,6 @@ def assert_number_of_events(num_of_events: int): s3_keys=["s3_key"], user_uuid=uuid4(), chat_history=[], - permitted_s3_keys=["s3_key"], ), test_data=[ RedboxTestData( @@ -145,7 +140,6 @@ def assert_number_of_events(num_of_events: int): s3_keys=[], user_uuid=uuid4(), chat_history=[], - permitted_s3_keys=[], ), test_data=[ RedboxTestData( @@ -163,7 +157,6 @@ def assert_number_of_events(num_of_events: int): s3_keys=["s3_key"], user_uuid=uuid4(), chat_history=[], - permitted_s3_keys=["s3_key"], ), test_data=[ RedboxTestData( diff --git a/redbox-core/tests/graph/test_patterns.py b/redbox-core/tests/graph/test_patterns.py index 5ebc1e2b5..07ac47547 100644 --- a/redbox-core/tests/graph/test_patterns.py +++ b/redbox-core/tests/graph/test_patterns.py @@ -30,7 +30,7 @@ LANGGRAPH_DEBUG = True CHAT_PROMPT_TEST_CASES = generate_test_cases( - query=RedboxQuery(question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[]), + query=RedboxQuery(question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[]), test_data=[ RedboxTestData( number_of_docs=0, @@ -62,7 +62,7 @@ def test_build_chat_prompt_from_messages_runnable(test_case: RedboxChatTestCase, BUILD_LLM_TEST_CASES = generate_test_cases( - query=RedboxQuery(question="What is AI?", file_uuids=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[]), + query=RedboxQuery(question="What is AI?", file_uuids=[], user_uuid=uuid4(), chat_history=[]), test_data=[ RedboxTestData( number_of_docs=2, @@ -107,7 +107,7 @@ def test_build_llm_chain(test_case: RedboxChatTestCase): CHAT_TEST_CASES = generate_test_cases( - query=RedboxQuery(question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[]), + query=RedboxQuery(question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[]), test_data=[ RedboxTestData( number_of_docs=0, @@ -139,7 +139,12 @@ def test_build_chat_pattern(test_case: RedboxChatTestCase, mocker: MockerFixture SET_ROUTE_TEST_CASES = generate_test_cases( - query=RedboxQuery(question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[]), + query=RedboxQuery( + question="What is AI?", + s3_keys=[], + user_uuid=uuid4(), + chat_history=[], + ), test_data=[ RedboxTestData( number_of_docs=0, @@ -178,7 +183,6 @@ def test_build_set_route_pattern(test_case: RedboxChatTestCase): s3_keys=["s3_key_1", "s3_key_2"], user_uuid=uuid4(), chat_history=[], - permitted_s3_keys=["s3_key_1", "s3_key_2"], ), test_data=[ RedboxTestData( @@ -204,15 +208,12 @@ def test_build_set_route_pattern(test_case: RedboxChatTestCase): ) - - STUFF_TEST_CASES = generate_test_cases( query=RedboxQuery( question="What is AI?", s3_keys=["s3_key_1", "s3_key_2"], user_uuid=uuid4(), chat_history=[], - permitted_s3_keys=["s3_key_1", "s3_key_2"], ), test_data=[ RedboxTestData( @@ -257,7 +258,6 @@ def test_build_stuff_pattern(test_case: RedboxChatTestCase, mocker: MockerFixtur s3_keys=["s3_key_1", "s3_key_2"], user_uuid=uuid4(), chat_history=[], - permitted_s3_keys=["s3_key_1", "s3_key_2"], ), test_data=[ RedboxTestData( @@ -276,7 +276,10 @@ def test_build_passthrough_pattern(): passthrough = build_passthrough_pattern() state = RedboxState( request=RedboxQuery( - question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[] + question="What is AI?", + s3_keys=[], + user_uuid=uuid4(), + chat_history=[], ), ) @@ -291,7 +294,10 @@ def test_build_set_text_pattern(): set_text = build_set_text_pattern(text="An hendy hap ychabbe ychent.") state = RedboxState( request=RedboxQuery( - question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[] + question="What is AI?", + s3_keys=[], + user_uuid=uuid4(), + chat_history=[], ), ) @@ -305,7 +311,10 @@ def test_empty_process(): """Tests the empty process doesn't touch the state whatsoever.""" state = RedboxState( request=RedboxQuery( - question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[] + question="What is AI?", + s3_keys=[], + user_uuid=uuid4(), + chat_history=[], ), documents=structure_documents_by_file_name([doc for doc in generate_docs(s3_key="s3_key")]), messages=[HumanMessage(content="Foo")], @@ -327,7 +336,10 @@ def test_empty_process(): CLEAR_DOC_TEST_CASES = [ RedboxState( request=RedboxQuery( - question="What is AI?", file_uuids=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[] + question="What is AI?", + file_uuids=[], + user_uuid=uuid4(), + chat_history=[], ), documents=structure_documents_by_file_name([doc for doc in generate_docs(s3_key="s3_key")]), messages=[HumanMessage(content="Foo")], @@ -335,7 +347,10 @@ def test_empty_process(): ), RedboxState( request=RedboxQuery( - question="What is AI?", file_uuids=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[] + question="What is AI?", + file_uuids=[], + user_uuid=uuid4(), + chat_history=[], ), documents={}, messages=[HumanMessage(content="Foo")], diff --git a/redbox-core/tests/retriever/__init__.py b/redbox-core/tests/retriever/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/redbox-core/tests/retriever/data.py b/redbox-core/tests/retriever/data.py deleted file mode 100644 index 7ecaa7252..000000000 --- a/redbox-core/tests/retriever/data.py +++ /dev/null @@ -1,158 +0,0 @@ -from uuid import uuid4 - -from redbox.models.chain import RedboxQuery -from redbox.models.file import ChunkResolution -from redbox.test.data import RedboxTestData, generate_test_cases - -ALL_CHUNKS_RETRIEVER_CASES = [ - test_case - for generator in [ - generate_test_cases( - query=RedboxQuery( - question="Irrelevant Question", - s3_keys=["s3_key"], - user_uuid=uuid4(), - chat_history=[], - permitted_s3_keys=["s3_key"], - ), - test_data=[ - RedboxTestData(number_of_docs=8, tokens_in_all_docs=8000, chunk_resolution=ChunkResolution.largest) - ], - test_id="Successful Path", - ), - generate_test_cases( - query=RedboxQuery( - question="Irrelevant Question", - s3_keys=["s3_key"], - user_uuid=uuid4(), - chat_history=[], - permitted_s3_keys=[], - ), - test_data=[ - RedboxTestData(number_of_docs=8, tokens_in_all_docs=8000, chunk_resolution=ChunkResolution.largest) - ], - test_id="No permitted S3 keys", - ), - generate_test_cases( - query=RedboxQuery( - question="Irrelevant Question", - s3_keys=[], - user_uuid=uuid4(), - chat_history=[], - permitted_s3_keys=["s3_key"], - ), - test_data=[ - RedboxTestData( - number_of_docs=8, - tokens_in_all_docs=8_000, - chunk_resolution=ChunkResolution.largest, - s3_keys=["s3_key"], - ) - ], - test_id="Empty keys but permitted", - ), - ] - for test_case in generator -] - -PARAMETERISED_RETRIEVER_CASES = [ - test_case - for generator in [ - generate_test_cases( - query=RedboxQuery( - question="Irrelevant Question", - s3_keys=["s3_key"], - user_uuid=uuid4(), - chat_history=[], - permitted_s3_keys=["s3_key"], - ), - test_data=[ - RedboxTestData(number_of_docs=8, tokens_in_all_docs=8000, chunk_resolution=ChunkResolution.normal) - ], - test_id="Successful Path", - ), - generate_test_cases( - query=RedboxQuery( - question="Irrelevant Question", - s3_keys=["s3_key"], - user_uuid=uuid4(), - chat_history=[], - permitted_s3_keys=[], - ), - test_data=[ - RedboxTestData(number_of_docs=8, tokens_in_all_docs=8000, chunk_resolution=ChunkResolution.normal) - ], - test_id="No permitted S3 keys", - ), - generate_test_cases( - query=RedboxQuery( - question="Irrelevant Question", - s3_keys=[], - user_uuid=uuid4(), - chat_history=[], - permitted_s3_keys=["s3_key"], - ), - test_data=[ - RedboxTestData( - number_of_docs=8, - tokens_in_all_docs=8_000, - chunk_resolution=ChunkResolution.normal, - s3_keys=["s3_key"], - ) - ], - test_id="Empty keys but permitted", - ), - ] - for test_case in generator -] - -METADATA_RETRIEVER_CASES = [ - test_case - for generator in [ - generate_test_cases( - query=RedboxQuery( - question="Irrelevant Question", - s3_keys=["s3_key"], - user_uuid=uuid4(), - chat_history=[], - permitted_s3_keys=["s3_key"], - ), - test_data=[ - RedboxTestData(number_of_docs=8, tokens_in_all_docs=8000, chunk_resolution=ChunkResolution.largest) - ], - test_id="Successful Path", - ), - generate_test_cases( - query=RedboxQuery( - question="Irrelevant Question", - s3_keys=["s3_key"], - user_uuid=uuid4(), - chat_history=[], - permitted_s3_keys=[], - ), - test_data=[ - RedboxTestData(number_of_docs=8, tokens_in_all_docs=8000, chunk_resolution=ChunkResolution.largest) - ], - test_id="No permitted S3 keys", - ), - generate_test_cases( - query=RedboxQuery( - question="Irrelevant Question", - s3_keys=[], - user_uuid=uuid4(), - chat_history=[], - permitted_s3_keys=["s3_key"], - ), - test_data=[ - RedboxTestData( - number_of_docs=8, - tokens_in_all_docs=8_000, - chunk_resolution=ChunkResolution.largest, - s3_keys=["s3_key"], - ) - ], - test_id="Empty keys but permitted", - ), - ] - for test_case in generator -] diff --git a/redbox-core/tests/retriever/test_retriever.py b/redbox-core/tests/retriever/test_retriever.py deleted file mode 100644 index fb2d2065b..000000000 --- a/redbox-core/tests/retriever/test_retriever.py +++ /dev/null @@ -1,95 +0,0 @@ -from redbox.models.chain import RedboxState -from redbox.test.data import RedboxChatTestCase - -TEST_CHAIN_PARAMETERS = ( - { - "rag_k": 1, - "rag_num_candidates": 100, - "match_boost": 1, - "knn_boost": 2, - "similarity_threshold": 0, - "elbow_filter_enabled": True, - "rag_gauss_scale_size": 3, - "rag_gauss_scale_decay": 0.5, - "rag_gauss_scale_min": 1.1, - "rag_gauss_scale_max": 2.0, - }, - { - "rag_k": 2, - "rag_num_candidates": 100, - "match_boost": 1, - "knn_boost": 2, - "similarity_threshold": 0, - "elbow_filter_enabled": False, - "rag_gauss_scale_size": 1, - "rag_gauss_scale_decay": 0.1, - "rag_gauss_scale_min": 1.0, - "rag_gauss_scale_max": 1.0, - }, -) - - -def test_retriever(stored_file_all_chunks: RedboxChatTestCase): - """ - Given a RedboxState, asserts: - - * If documents are selected and there's permission to get them - * The length of the result is equal to the total stored chunks - * The result page content is identical to all possible correct - page content - * The result contains exactly file_names the user selected - * The result contains a subset of file_names from permitted S3 keys - * If documents are selected and there's no permission to get them - * The length of the result is zero - * If documents aren't selected and there's permission to get them - * The length of the result is zero - * If documents aren't selected and there's no permission to get them - * The length of the result is zero - """ - stored_file_all_chunks, retriever = stored_file_all_chunks - - result = retriever.invoke(RedboxState(request=stored_file_all_chunks.query)) - correct = stored_file_all_chunks.get_docs_matching_query() - - selected = bool(stored_file_all_chunks.query.s3_keys) - permission = bool(stored_file_all_chunks.query.permitted_s3_keys) - - if selected and permission: - assert len(result) == len(correct) - assert {c.page_content for c in result} == {c.page_content for c in correct} - assert {c.metadata["uri"] for c in result} == set(stored_file_all_chunks.query.s3_keys) - assert {c.metadata["uri"] for c in result} <= set(stored_file_all_chunks.query.permitted_s3_keys) - else: - len(result) == 0 - - -def test_metadata_retriever(stored_file_metadata: RedboxChatTestCase): - """ - Given a RedboxState, asserts: - - * If documents are selected and there's permission to get them - * The length of the result is equal to the total stored chunks - * The result contains exactly file_names the user selected - * The result contains a subset of file_names from permitted S3 keys - * If documents are selected and there's no permission to get them - * The length of the result is zero - * If documents aren't selected and there's permission to get them - * The length of the result is zero - * If documents aren't selected and there's no permission to get them - * The length of the result is zero - """ - - stored_file_metadata, retriever = stored_file_metadata - - result = retriever.invoke(RedboxState(request=stored_file_metadata.query)) - correct = stored_file_metadata.get_docs_matching_query() - - selected = bool(stored_file_metadata.query.s3_keys) - permission = bool(stored_file_metadata.query.permitted_s3_keys) - - if selected and permission: - assert len(result) == len(correct) - assert {c.metadata["uri"] for c in result} == set(stored_file_metadata.query.s3_keys) - assert {c.metadata["uri"] for c in result} <= set(stored_file_metadata.query.permitted_s3_keys) - else: - len(result) == 0