
reformatting (#1081)
gecBurton authored Oct 7, 2024
1 parent 808c46b commit a7990cc
Showing 7 changed files with 41 additions and 113 deletions.
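The hunks below mostly collapse signatures, call sites, and comprehensions that were wrapped for a shorter line limit into single lines of up to roughly 120 characters, and they drop a few unused imports (from ast import Pass, OutputParserException, HTTPError, and the bare redbox imports in the tests). As a rough sketch of tooling that would enforce this style, assuming ruff with a 120-character limit (neither the tool nor the limit is named in this commit):

[tool.ruff]
# hypothetical pyproject.toml excerpt: the line length and rule selection are assumptions
line-length = 120

[tool.ruff.lint]
# F401 flags unused imports such as "from ast import Pass"
select = ["E", "F", "I"]

Running the formatter and lint autofix with such a configuration would produce single-line rewrites like the ones in this diff wherever the joined line fits within the limit.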
12 changes: 3 additions & 9 deletions redbox-core/redbox/chains/ingest.py
@@ -25,17 +25,11 @@ def log_chunks(chunks: list[Document]):
     return chunks
 
 
-def document_loader(
-    document_loader: UnstructuredChunkLoader, s3_client: S3Client, env: Settings
-) -> Runnable:
+def document_loader(document_loader: UnstructuredChunkLoader, s3_client: S3Client, env: Settings) -> Runnable:
     @chain
     def wrapped(file_name: str) -> Iterator[Document]:
-        file_bytes = s3_client.get_object(Bucket=env.bucket_name, Key=file_name)[
-            "Body"
-        ].read()
-        return document_loader.lazy_load(
-            file_name=file_name, file_bytes=BytesIO(file_bytes)
-        )
+        file_bytes = s3_client.get_object(Bucket=env.bucket_name, Key=file_name)["Body"].read()
+        return document_loader.lazy_load(file_name=file_name, file_bytes=BytesIO(file_bytes))
 
     return wrapped
 
6 changes: 4 additions & 2 deletions redbox-core/redbox/graph/nodes/sends.py
@@ -15,7 +15,8 @@ def build_document_group_send(target: str) -> Callable[[RedboxState], list[Send]
 
     def _group_send(state: RedboxState) -> list[Send]:
         group_send_states: list[RedboxState] = [
-            _copy_state(state,
+            _copy_state(
+                state,
                 documents={document_group_key: document_group},
             )
             for document_group_key, document_group in state["documents"].items()
@@ -30,7 +31,8 @@ def build_document_chunk_send(target: str) -> Callable[[RedboxState], list[Send]
 
     def _chunk_send(state: RedboxState) -> list[Send]:
         chunk_send_states: list[RedboxState] = [
-            _copy_state(state,
+            _copy_state(
+                state,
                 documents={document_group_key: {document_key: document}},
             )
             for document_group_key, document_group in state["documents"].items()
6 changes: 3 additions & 3 deletions redbox-core/redbox/loader/ingester.py
@@ -81,9 +81,9 @@ def ingest_file(file_name: str) -> str | None:
     )
 
     try:
-        new_ids = RunnableParallel(
-            {"normal": chunk_ingest_chain, "largest": large_chunk_ingest_chain}
-        ).invoke(file_name)
+        new_ids = RunnableParallel({"normal": chunk_ingest_chain, "largest": large_chunk_ingest_chain}).invoke(
+            file_name
+        )
         logging.info(
             "File: %s %s chunks ingested",
             file_name,
37 changes: 9 additions & 28 deletions redbox-core/redbox/loader/loaders.py
@@ -1,6 +1,5 @@
 import json
 import logging
-from ast import Pass
 from collections.abc import Iterator
 from datetime import UTC, datetime
 from io import BytesIO
@@ -9,10 +8,8 @@
 import requests
 import tiktoken
 from langchain_core.documents import Document
-from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.prompts import ChatPromptTemplate
-from requests.exceptions import HTTPError
 
 from redbox.chains.components import get_chat_llm
 from redbox.models.chain import AISettings
@@ -31,9 +28,7 @@
 
 
 class MetadataLoader:
-    def __init__(
-        self, env: Settings, s3_client: S3Client, file_name: str, metadata: dict = None
-    ):
+    def __init__(self, env: Settings, s3_client: S3Client, file_name: str, metadata: dict = None):
         self.env = env
         self.s3_client = s3_client
         self.llm = get_chat_llm(self.env, AISettings())
@@ -53,9 +48,7 @@ def get_first_n_tokens(self, chunks: list[dict], n: int) -> str:
             tokens += chunk["text"]
         return tokens
 
-    def get_doc_metadata(
-        self, chunks: list[dict], n: int, ignore: list[str] = None
-    ) -> dict[str, Any]:
+    def get_doc_metadata(self, chunks: list[dict], n: int, ignore: list[str] = None) -> dict[str, Any]:
         """
         Use the first n chunks to get metadata using unstructured.
         Metadata keys in the ignore list will be excluded from the result.
@@ -64,9 +57,7 @@ def get_doc_metadata(
         for i, chunk in enumerate(chunks):
             if i > n:
                 return metadata
-            metadata = self.merge_unstructured_metadata(
-                metadata, chunk["metadata"], ignore
-            )
+            metadata = self.merge_unstructured_metadata(metadata, chunk["metadata"], ignore)
         return metadata
 
     @staticmethod
@@ -87,9 +78,7 @@ def merge_unstructured_metadata(x: dict, y: dict, ignore: set[str] = None) -> di
 
             if key in x and key in y:
                 if isinstance(x[key], list) or isinstance(y[key], list):
-                    combined[key] = list(
-                        set(x[key] + (y[key] if isinstance(y[key], list) else [y[key]]))
-                    )
+                    combined[key] = list(set(x[key] + (y[key] if isinstance(y[key], list) else [y[key]])))
                 else:
                     combined[key] = [x[key], y[key]]
             elif key in x:
@@ -100,17 +89,13 @@ def merge_unstructured_metadata(x: dict, y: dict, ignore: set[str] = None) -> di
         return combined
 
     def _get_file_bytes(self, s3_client: S3Client, file_name: str) -> BytesIO:
-        return s3_client.get_object(Bucket=self.env.bucket_name, Key=file_name)[
-            "Body"
-        ].read()
+        return s3_client.get_object(Bucket=self.env.bucket_name, Key=file_name)["Body"].read()
 
     def _chunking(self) -> Any:
         """
         Chunking data using local unstructured
         """
-        file_bytes = self._get_file_bytes(
-            s3_client=self.s3_client, file_name=self.file_name
-        )
+        file_bytes = self._get_file_bytes(s3_client=self.s3_client, file_name=self.file_name)
         url = f"http://{self.env.unstructured_host}:8000/general/v0/general"
         files = {
             "files": (self.file_name, file_bytes),
@@ -161,9 +146,7 @@ def extract_metadata(self):
             # missing keys
             self.metadata = self.default_metadata
 
-    def create_file_metadata(
-        self, page_content: str, metadata: dict[str, Any]
-    ) -> dict[str, Any]:
+    def create_file_metadata(self, page_content: str, metadata: dict[str, Any]) -> dict[str, Any]:
         """Uses a sample of the document and any extracted metadata to generate further metadata."""
         metadata_chain = (
             ChatPromptTemplate.from_messages(
@@ -187,12 +170,10 @@ def create_file_metadata(
         )
 
         try:
-            return metadata_chain.invoke(
-                {"page_content": page_content, "metadata": metadata}
-            )
+            return metadata_chain.invoke({"page_content": page_content, "metadata": metadata})
         except ConnectionError as e:
             logger.warning(f"Retrying due to HTTPError {e}")
-        except json.JSONDecodeError as e:
+        except json.JSONDecodeError:
             # replace with fail safe metadata
             return None
         except Exception as e:
16 changes: 4 additions & 12 deletions redbox-core/redbox/models/settings.py
@@ -110,12 +110,8 @@ class Settings(BaseSettings):
     worker_ingest_largest_chunk_size: int = 300_000
     worker_ingest_largest_chunk_overlap: int = 0
 
-    response_no_doc_available: str = (
-        "No available data for selected files. They may need to be removed and added again"
-    )
-    response_max_content_exceeded: str = (
-        "Max content exceeded. Try smaller or fewer documents"
-    )
+    response_no_doc_available: str = "No available data for selected files. They may need to be removed and added again"
+    response_max_content_exceeded: str = "Max content exceeded. Try smaller or fewer documents"
 
     object_store: str = "minio"
 
@@ -124,9 +120,7 @@ class Settings(BaseSettings):
 
     unstructured_host: str = "unstructured"
 
-    model_config = SettingsConfigDict(
-        env_file=".env", env_nested_delimiter="__", extra="allow", frozen=True
-    )
+    model_config = SettingsConfigDict(env_file=".env", env_nested_delimiter="__", extra="allow", frozen=True)
 
     ## Prompts
     metadata_prompt: tuple = (
@@ -155,9 +149,7 @@ def elasticsearch_client(self) -> Elasticsearch:
                 basic_auth=(self.elastic.user, self.elastic.password),
             )
         else:
-            client = Elasticsearch(
-                cloud_id=self.elastic.cloud_id, api_key=self.elastic.api_key
-            )
+            client = Elasticsearch(cloud_id=self.elastic.cloud_id, api_key=self.elastic.api_key)
 
         return client.options(request_timeout=30, retry_on_timeout=True, max_retries=3)
 
21 changes: 5 additions & 16 deletions redbox-core/redbox/test/data.py
@@ -118,24 +118,16 @@ def get_docs_matching_query(self) -> list[Document]:
         return [
             doc
             for doc in self.docs
-            if doc.metadata["file_name"]
-            in set(self.query.s3_keys) & set(self.query.permitted_s3_keys)
+            if doc.metadata["file_name"] in set(self.query.s3_keys) & set(self.query.permitted_s3_keys)
         ]
 
     def get_all_permitted_docs(self) -> list[Document]:
-        return [
-            doc
-            for doc in self.docs
-            if doc.metadata["file_name"] in set(self.query.permitted_s3_keys)
-        ]
+        return [doc for doc in self.docs if doc.metadata["file_name"] in set(self.query.permitted_s3_keys)]
 
 
-def generate_test_cases(
-    query: RedboxQuery, test_data: list[RedboxTestData], test_id: str
-) -> list[RedboxChatTestCase]:
+def generate_test_cases(query: RedboxQuery, test_data: list[RedboxTestData], test_id: str) -> list[RedboxChatTestCase]:
     return [
-        RedboxChatTestCase(test_id=f"{test_id}-{i}", query=query, test_data=data)
-        for i, data in enumerate(test_data)
+        RedboxChatTestCase(test_id=f"{test_id}-{i}", query=query, test_data=data) for i, data in enumerate(test_data)
     ]
 
 
@@ -158,8 +150,5 @@ def mock_parameterised_retriever(docs: list[Document]) -> FakeRetriever:
 
 
 def mock_metadata_retriever(docs: list[Document]) -> FakeRetriever:
-    metadata_only_docs = [
-        Document(page_content="", metadata={**doc.metadata, "embedding": None})
-        for doc in docs
-    ]
+    metadata_only_docs = [Document(page_content="", metadata={**doc.metadata, "embedding": None}) for doc in docs]
     return FakeRetriever(docs=metadata_only_docs)
56 changes: 13 additions & 43 deletions redbox-core/tests/test_ingest.py
@@ -11,9 +11,6 @@
 from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
 from langchain_elasticsearch import ElasticsearchStore
 
-import redbox
-import redbox.loader
-import redbox.loader.loaders
 from redbox.chains.ingest import document_loader, ingest_from_loader
 from redbox.loader import ingester
 from redbox.loader.ingester import ingest_file
@@ -48,9 +45,7 @@ def file_to_s3(filename: str, s3_client: S3Client, env: Settings) -> str:
     return file_name
 
 
-def make_file_query(
-    file_name: str, resolution: ChunkResolution | None = None
-) -> dict[str, Any]:
+def make_file_query(file_name: str, resolution: ChunkResolution | None = None) -> dict[str, Any]:
     query_filter = build_query_filter(
         selected_files=[file_name],
         permitted_files=[file_name],
@@ -109,12 +104,11 @@ def test_extract_metadata(
     metadata.extract_metadata()
 
     if test_case in ["missing_key"]:
-
         assert metadata.metadata.get("name") == ""
         assert metadata.metadata.get("description") == ""
         assert metadata.metadata.get("keywords") == ""
     else:
-        assert not metadata.metadata is None
+        assert metadata.metadata is not None
         llm_json = json.loads(llm_response[0])
         assert metadata.metadata.get("name") == llm_json.get("name")
         assert metadata.metadata.get("description") == llm_json.get("description")
@@ -153,9 +147,7 @@ def test_document_loader(
 
     mock_llm_response = mock_llm.return_value
     mock_llm_response.status_code = 200
-    mock_llm_response.return_value = GenericFakeChatModel(
-        messages=iter([json.dumps(fake_llm_response())])
-    )
+    mock_llm_response.return_value = GenericFakeChatModel(messages=iter([json.dumps(fake_llm_response())]))
 
     # Upload file
     file = file_to_s3("html/example.html", s3_client, env)
@@ -182,9 +174,7 @@ def test_document_loader(
     llm_response = fake_llm_response()
     assert chuck.metadata["name"] == llm_response["name"]
     assert chuck.metadata["description"] == llm_response["description"]
-    assert chuck.metadata["keywords"] == coerce_to_string_list(
-        llm_response["keywords"]
-    )
+    assert chuck.metadata["keywords"] == coerce_to_string_list(llm_response["keywords"])
 
 
 @patch("redbox.loader.loaders.get_chat_llm")
@@ -231,9 +221,7 @@ def test_ingest_from_loader(
 
     mock_llm_response = mock_llm.return_value
     mock_llm_response.status_code = 200
-    mock_llm_response.return_value = GenericFakeChatModel(
-        messages=iter([json.dumps(fake_llm_response())])
-    )
+    mock_llm_response.return_value = GenericFakeChatModel(messages=iter([json.dumps(fake_llm_response())]))
 
     # Upload file and call
     file_name = file_to_s3(filename="html/example.html", s3_client=s3_client, env=env)
@@ -253,20 +241,14 @@ def test_ingest_from_loader(
     # Mock embeddings
     monkeypatch.setattr(ingester, "get_embeddings", lambda _: FakeEmbeddings(size=3072))
 
-    ingest_chain = ingest_from_loader(
-        loader=loader, s3_client=s3_client, vectorstore=es_vector_store, env=env
-    )
+    ingest_chain = ingest_from_loader(loader=loader, s3_client=s3_client, vectorstore=es_vector_store, env=env)
 
     _ = ingest_chain.invoke(file_name)
 
     # Test it's written to Elastic
     file_query = make_file_query(file_name=file_name, resolution=resolution)
 
-    chunks = list(
-        scan(
-            client=es_client, index=f"{env.elastic_root_index}-chunk", query=file_query
-        )
-    )
+    chunks = list(scan(client=es_client, index=f"{env.elastic_root_index}-chunk", query=file_query))
     assert len(chunks) > 0
 
     def get_metadata(chunk: dict) -> dict:
@@ -278,9 +260,7 @@ def get_metadata(chunk: dict) -> dict:
         metadata = get_metadata(chunk)
         assert metadata["name"] == fake_llm_response()["name"]
         assert metadata["description"] == fake_llm_response()["description"]
-        assert metadata["keywords"] == coerce_to_string_list(
-            fake_llm_response()["keywords"]
-        )
+        assert metadata["keywords"] == coerce_to_string_list(fake_llm_response()["keywords"])
 
     if has_embeddings:
         embeddings = chunks[0]["_source"].get("embedding")
@@ -348,9 +328,7 @@ def test_ingest_file(
     # Mock llm
     mock_llm_response = mock_llm.return_value
     mock_llm_response.status_code = 200
-    mock_llm_response.return_value = GenericFakeChatModel(
-        messages=iter([json.dumps(fake_llm_response())])
-    )
+    mock_llm_response.return_value = GenericFakeChatModel(messages=iter([json.dumps(fake_llm_response())]))
 
     res = ingest_file(filename)
 
@@ -380,24 +358,16 @@ def get_metadata(chunk: dict) -> dict:
         llm_response = fake_llm_response()
         assert metadata["name"] == llm_response["name"]
         assert metadata["description"] == llm_response["description"]
-        assert metadata["keywords"] == coerce_to_string_list(
-            llm_response["keywords"]
-        )
+        assert metadata["keywords"] == coerce_to_string_list(llm_response["keywords"])
 
     def get_chunk_resolution(chunk: dict) -> str:
         return chunk["_source"]["metadata"]["chunk_resolution"]
 
-    normal_resolution = [
-        chunk for chunk in chunks if get_chunk_resolution(chunk) == "normal"
-    ]
-    largest_resolution = [
-        chunk for chunk in chunks if get_chunk_resolution(chunk) == "largest"
-    ]
+    normal_resolution = [chunk for chunk in chunks if get_chunk_resolution(chunk) == "normal"]
+    largest_resolution = [chunk for chunk in chunks if get_chunk_resolution(chunk) == "largest"]
 
     assert len(normal_resolution) > 0
     assert len(largest_resolution) > 0
 
     # Teardown
-    es_client.delete_by_query(
-        index=f"{env.elastic_root_index}-chunk", body=file_query
-    )
+    es_client.delete_by_query(index=f"{env.elastic_root_index}-chunk", body=file_query)
 