Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
gecBurton committed Jan 9, 2025
1 parent 3755cb4 commit e8736fb
Show file tree
Hide file tree
Showing 15 changed files with 37 additions and 402 deletions.
4 changes: 1 addition & 3 deletions django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
RequestMetadata,
metadata_reducer,
)
from redbox.retriever import DjangoFileRetriever
from redbox_app.redbox_core import error_messages
from redbox_app.redbox_core.models import AISettings as AISettingsModel
from redbox_app.redbox_core.models import (
Expand Down Expand Up @@ -132,8 +131,7 @@ async def llm_conversation(
for message in message_history[:-1]
],
ai_settings=ai_settings,
permitted_s3_keys=[f.unique_name async for f in permitted_files],
selected_files_total_tokens=sum(f.metadata.get("token_count", 0) async for f in permitted_files)
selected_files_total_tokens=sum(f.metadata.get("token_count", 0) async for f in permitted_files),
),
)

Expand Down
5 changes: 0 additions & 5 deletions django_app/tests/test_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,6 @@ async def test_chat_consumer_redbox_state(

selected_file_uuids: Sequence[str] = [str(f.id) for f in selected_files]
selected_file_keys: Sequence[str] = [f.unique_name for f in selected_files]
permitted_file_keys: Sequence[str] = [
f.unique_name async for f in File.objects.filter(user=alice, status=File.Status.complete)
]
assert selected_file_keys != permitted_file_keys

await communicator.send_json_to(
{
Expand Down Expand Up @@ -482,7 +478,6 @@ async def test_chat_consumer_redbox_state(
{"role": "ai", "text": "A second answer."},
],
ai_settings=ai_settings,
permitted_s3_keys=permitted_file_keys,
)
redbox_state = mock_run.call_args.args[0] # pulls out the args that redbox.run was called with

Expand Down
1 change: 0 additions & 1 deletion redbox-core/redbox/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from logging import getLogger
from typing import Literal

from langchain_core.vectorstores import VectorStoreRetriever

from redbox.graph.root import (
get_chat_with_documents_graph,
Expand Down
12 changes: 3 additions & 9 deletions redbox-core/redbox/graph/nodes/processes.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
import logging
import re
from collections.abc import Callable
from typing import Any

from langchain.schema import StrOutputParser
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.tools import StructuredTool
from langchain_core.vectorstores import VectorStoreRetriever

from redbox.chains.components import get_chat_llm
from redbox.chains.runnables import CannedChatLLM, build_llm_chain
from redbox.models import ChatRoute
from redbox.models.chain import DocumentState, PromptSet, RedboxState, RequestMetadata
from redbox.models.graph import ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG
from redbox.transform import flatten_document_state
from redbox.models.chain import DocumentState, PromptSet, RedboxState
from redbox.models.graph import ROUTE_NAME_TAG

log = logging.getLogger(__name__)
re_keyword_pattern = re.compile(r"@(\w+)")
Expand All @@ -26,8 +22,6 @@
## Core patterns




def build_chat_pattern(
prompt_set: PromptSet,
tools: list[StructuredTool] | None = None,
Expand Down
3 changes: 0 additions & 3 deletions redbox-core/redbox/graph/root.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from langchain_core.vectorstores import VectorStoreRetriever
from langgraph.graph import END, START, StateGraph
from langgraph.graph.graph import CompiledGraph

Expand Down Expand Up @@ -89,8 +88,6 @@ def get_chat_with_documents_graph(
return builder.compile(debug=debug)




# Root graph
def get_root_graph(
debug: bool = False,
Expand Down
2 changes: 0 additions & 2 deletions redbox-core/redbox/models/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,9 @@ class RedboxQuery(BaseModel):
user_uuid: UUID = Field(description="User the chain in executing for")
chat_history: list[ChainChatMessage] = Field(description="All previous messages in chat (excluding question)")
ai_settings: AISettings = Field(description="User request AI settings", default_factory=AISettings)
permitted_s3_keys: list[str] = Field(description="List of permitted files for response", default_factory=list)
selected_files_total_tokens: int = 0



class LLMCallMetadata(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
llm_model_name: str
Expand Down
3 changes: 0 additions & 3 deletions redbox-core/redbox/retriever/__init__.py

This file was deleted.

32 changes: 0 additions & 32 deletions redbox-core/redbox/retriever/retrievers.py

This file was deleted.

10 changes: 2 additions & 8 deletions redbox-core/redbox/test/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,10 @@ def __init__(
self.test_id = test_id

def get_docs_matching_query(self) -> list[Document]:
    """Return the test documents selected by the query.

    A document is selected when its ``metadata["uri"]`` is one of
    ``self.query.s3_keys``.  The query no longer carries a separate list of
    permitted keys, so the selected keys are the only filter; intersecting
    them with an empty set (as the previous expression did) made this method
    return an empty list for every input.

    Returns:
        The matching documents, in the order they appear in ``self.docs``.
    """
    # Build the lookup set once instead of once per document.
    selected_keys = set(self.query.s3_keys)
    return [doc for doc in self.docs if doc.metadata["uri"] in selected_keys]

def get_all_permitted_docs(self) -> list[Document]:
    """Return every document held by this test case.

    The query no longer exposes a list of permitted keys, so there is nothing
    left to filter on here; testing membership in an empty set (as the
    previous expression did) made this method return an empty list for every
    input.

    NOTE(review): this assumes that, with the permission list gone, every
    document counts as permitted - confirm against the callers.

    Returns:
        A new list containing all of ``self.docs`` in their original order.
    """
    # Copy so callers cannot mutate the test case's own list through the result.
    return list(self.docs)


def generate_test_cases(query: RedboxQuery, test_data: list[RedboxTestData], test_id: str) -> list[RedboxChatTestCase]:
Expand All @@ -141,8 +137,6 @@ async def _aget_relevant_documents(self, query: str) -> list[Document]:
return self.docs




class GenericFakeChatModelWithTools(GenericFakeChatModel):
"""A thin wrapper to GenericFakeChatModel that allows tool binding."""

Expand Down
60 changes: 0 additions & 60 deletions redbox-core/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from collections.abc import Generator
from typing import TYPE_CHECKING

import pytest
import tiktoken
from _pytest.fixtures import FixtureRequest
from botocore.exceptions import ClientError
from elasticsearch import Elasticsearch
from langchain_core.embeddings.fake import FakeEmbeddings
from langchain_elasticsearch import ElasticsearchStore
from tiktoken.core import Encoding

from redbox.models.settings import Settings
from redbox.retriever import DjangoFileRetriever
from redbox.test.data import RedboxChatTestCase
from tests.retriever.data import ALL_CHUNKS_RETRIEVER_CASES, METADATA_RETRIEVER_CASES, PARAMETERISED_RETRIEVER_CASES

if TYPE_CHECKING:
from mypy_boto3_s3.client import S3Client
Expand Down Expand Up @@ -69,57 +63,3 @@ def es_index(env: Settings) -> str:
@pytest.fixture(scope="session")
def es_client(env: Settings) -> Elasticsearch:
    """Session-scoped fixture returning the client built by ``env.elasticsearch_client()``."""
    client = env.elasticsearch_client()
    return client


@pytest.fixture(scope="session")
def es_vector_store(
    es_client: Elasticsearch, es_index: str, embedding_model: FakeEmbeddings, env: Settings
) -> ElasticsearchStore:
    """Session-scoped ``ElasticsearchStore`` wired to the shared client, index and embedding model.

    Text queries go against the ``"text"`` field; the vector field name is
    read from ``env.embedding_document_field_name``.
    """
    store_options = {
        "index_name": es_index,
        "es_connection": es_client,
        "query_field": "text",
        "vector_query_field": env.embedding_document_field_name,
        "embedding": embedding_model,
    }
    return ElasticsearchStore(**store_options)


class FakeFile:
    """Minimal stand-in for a stored file: it only carries ``text`` and ``metadata``."""

    def __init__(self, text, metadata):
        # Keep both constructor arguments, unchanged, as public attributes.
        self.text, self.metadata = text, metadata


class FileManager:
    """Test double passed as ``file_manager``: serves a fixed list of documents as ``FakeFile`` objects."""

    def __init__(self, docs):
        self.docs = docs

    def filter(self, original_file__in, text__isnull, metadata__isnull):
        """Lazily yield one ``FakeFile`` per stored document.

        The three filter arguments are accepted but never consulted: every
        document is yielded regardless of their values.
        """
        yield from (FakeFile(item.page_content, item.metadata) for item in self.docs)


# -----#
# Data #
# -----#


@pytest.fixture(params=ALL_CHUNKS_RETRIEVER_CASES)
def stored_file_all_chunks(
    request: FixtureRequest,
) -> Generator[tuple[RedboxChatTestCase, DjangoFileRetriever], None, None]:
    """Yield one ``ALL_CHUNKS_RETRIEVER_CASES`` entry together with a retriever over its documents.

    The fixture is parametrised, so ``request.param`` is the current test
    case.  The retriever is a ``DjangoFileRetriever`` whose file manager is a
    ``FileManager`` wrapping that case's ``docs``.

    Yields:
        A ``(test_case, retriever)`` tuple - the annotation previously claimed
        only the test case was yielded.
    """
    test_case: RedboxChatTestCase = request.param
    retriever = DjangoFileRetriever(file_manager=FileManager(test_case.docs))
    yield test_case, retriever


@pytest.fixture(params=PARAMETERISED_RETRIEVER_CASES)
def stored_file_parameterised(
    request: FixtureRequest,
) -> Generator[tuple[RedboxChatTestCase, DjangoFileRetriever], None, None]:
    """Yield one ``PARAMETERISED_RETRIEVER_CASES`` entry together with a retriever over its documents.

    The fixture is parametrised, so ``request.param`` is the current test
    case.  The retriever is a ``DjangoFileRetriever`` whose file manager is a
    ``FileManager`` wrapping that case's ``docs``.

    Yields:
        A ``(test_case, retriever)`` tuple - the annotation previously claimed
        only the test case was yielded.
    """
    test_case: RedboxChatTestCase = request.param
    retriever = DjangoFileRetriever(file_manager=FileManager(test_case.docs))
    yield test_case, retriever


@pytest.fixture(params=METADATA_RETRIEVER_CASES)
def stored_file_metadata(
    request: FixtureRequest,
) -> Generator[tuple[RedboxChatTestCase, DjangoFileRetriever], None, None]:
    """Yield one ``METADATA_RETRIEVER_CASES`` entry together with a retriever over its documents.

    The fixture is parametrised, so ``request.param`` is the current test
    case.  The retriever is a ``DjangoFileRetriever`` whose file manager is a
    ``FileManager`` wrapping that case's ``docs``.

    Yields:
        A ``(test_case, retriever)`` tuple - the annotation previously claimed
        only the test case was yielded.
    """
    test_case: RedboxChatTestCase = request.param
    retriever = DjangoFileRetriever(file_manager=FileManager(test_case.docs))
    yield test_case, retriever
11 changes: 2 additions & 9 deletions redbox-core/tests/graph/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def assert_number_of_events(num_of_events: int):
test_case
for generated_cases in [
generate_test_cases(
query=RedboxQuery(
question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[]
),
query=RedboxQuery(question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[]),
test_data=[
RedboxTestData(
number_of_docs=0,
Expand All @@ -61,9 +59,7 @@ def assert_number_of_events(num_of_events: int):
test_id="Basic Chat",
),
generate_test_cases(
query=RedboxQuery(
question="What is AI?", s3_keys=["s3_key"], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[]
),
query=RedboxQuery(question="What is AI?", s3_keys=["s3_key"], user_uuid=uuid4(), chat_history=[]),
test_data=[
RedboxTestData(
number_of_docs=1,
Expand Down Expand Up @@ -127,7 +123,6 @@ def assert_number_of_events(num_of_events: int):
s3_keys=["s3_key"],
user_uuid=uuid4(),
chat_history=[],
permitted_s3_keys=["s3_key"],
),
test_data=[
RedboxTestData(
Expand All @@ -145,7 +140,6 @@ def assert_number_of_events(num_of_events: int):
s3_keys=[],
user_uuid=uuid4(),
chat_history=[],
permitted_s3_keys=[],
),
test_data=[
RedboxTestData(
Expand All @@ -163,7 +157,6 @@ def assert_number_of_events(num_of_events: int):
s3_keys=["s3_key"],
user_uuid=uuid4(),
chat_history=[],
permitted_s3_keys=["s3_key"],
),
test_data=[
RedboxTestData(
Expand Down
Loading

0 comments on commit e8736fb

Please sign in to comment.