From 8e7e17e8606c0d1f07b25964ab8ad337ccf36716 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Mon, 23 Sep 2024 00:31:06 +0100 Subject: [PATCH] refactor: make embedder a dependency --- server/api/debug/redis.py | 12 ++++++---- server/api/v1/chat.py | 16 ++++++++------ server/dependencies/__init__.py | 1 + server/dependencies/embedder.py | 22 +++++++++++++++++++ server/features/chat/model.py | 2 +- server/features/embeddings/__init__.py | 2 +- server/features/embeddings/embedding.py | 8 +++---- server/helpers/__init__.py | 1 - server/helpers/network/__init__.py | 3 --- server/lifespans/chat_model.py | 1 + server/lifespans/create_redis_index.py | 1 + server/lifespans/download_embeddings.py | 5 +++-- server/lifespans/download_nltk.py | 1 + server/utils/__init__.py | 1 + server/utils/network/__init__.py | 3 +++ .../network/has_internet_access.py | 0 .../network/huggingface_download.py | 2 +- tests/test_embedding.py | 20 +++++++++++------ tests/test_has_internet_access.py | 2 +- 19 files changed, 71 insertions(+), 32 deletions(-) create mode 100644 server/dependencies/embedder.py delete mode 100644 server/helpers/__init__.py delete mode 100644 server/helpers/network/__init__.py create mode 100644 server/utils/__init__.py create mode 100644 server/utils/network/__init__.py rename server/{helpers => utils}/network/has_internet_access.py (100%) rename server/{helpers => utils}/network/huggingface_download.py (86%) diff --git a/server/api/debug/redis.py b/server/api/debug/redis.py index 241c0c8..2107d19 100644 --- a/server/api/debug/redis.py +++ b/server/api/debug/redis.py @@ -6,8 +6,8 @@ from server.config import Config from server.databases.redis.wrapper import RedisAsync -from server.dependencies import redis_client -from server.features.embeddings import Embedding +from server.dependencies import embedder, redis_client +from server.features.embeddings import Embedder from server.schemas.v1 import Query @@ -19,7 +19,10 @@ class RedisController(Controller): """ path = '/redis' - dependencies = {'redis': Provide(redis_client)} + dependencies = { + 'redis': Provide(redis_client), + 'embedder': Provide(embedder), + } @delete() async def delete_index(self, redis: Annotated[RedisAsync, Dependency()], recreate: bool = False) -> None: @@ -42,6 +45,7 @@ async def delete_index(self, redis: Annotated[RedisAsync, Dependency()], recreat async def search( self, redis: Annotated[RedisAsync, Dependency()], + embedder: Annotated[Embedder, Dependency()], chat_id: str, data: Query, search_size: Annotated[int, Parameter(gt=0)] = 1, @@ -51,4 +55,4 @@ async def search( ------- an endpoint for searching the Redis vector database """ - return await redis.search(chat_id, Embedding().encode_query(data.query), search_size) + return await redis.search(chat_id, embedder.encode_query(data.query), search_size) diff --git a/server/api/v1/chat.py b/server/api/v1/chat.py index 7b9288d..bd299ba 100644 --- a/server/api/v1/chat.py +++ b/server/api/v1/chat.py @@ -10,9 +10,9 @@ from server.databases.redis.features import store_chunks from server.databases.redis.wrapper import RedisAsync -from server.dependencies.redis import redis_client +from server.dependencies import embedder, redis_client from server.features.chunking import SentenceSplitter, chunk_document -from server.features.embeddings import Embedding +from server.features.embeddings import Embedder from server.features.extraction import extract_documents_from_pdfs from server.features.question_answering import question_answering from server.schemas.v1 import Answer, Chat, 
Files, Query @@ -27,7 +27,10 @@ class ChatController(Controller): """ path = '/chats' - dependencies = {'redis': Provide(redis_client)} + dependencies = { + 'redis': Provide(redis_client), + 'embedder': Provide(embedder), + } @get() async def create_chat(self) -> Chat: @@ -75,6 +78,7 @@ async def upload_files( self, state: AppState, redis: Annotated[RedisAsync, Dependency()], + embedder: Annotated[Embedder, Dependency()], chat_id: str, data: Annotated[list[UploadFile], Body(media_type=RequestEncodingType.MULTI_PART)], ) -> Files: @@ -83,7 +87,6 @@ async def upload_files( ------- an endpoint for uploading files to a chat """ - embedder = Embedding() text_splitter = SentenceSplitter(state.chat.tokeniser, chunk_size=128, chunk_overlap=0) responses = [] @@ -109,6 +112,7 @@ async def query( self, state: AppState, redis: Annotated[RedisAsync, Dependency()], + embedder: Annotated[Embedder, Dependency()], chat_id: str, data: Query, search_size: Annotated[int, Parameter(ge=0)] = 0, @@ -119,9 +123,7 @@ async def query( ------- the `/query` route provides an endpoint for performning retrieval-augmented generation """ - context = ( - '' if not search_size else await redis.search(chat_id, Embedding().encode_query(data.query), search_size) - ) + context = '' if not search_size else await redis.search(chat_id, embedder.encode_query(data.query), search_size) message_history = await redis.get_messages(chat_id) messages = await question_answering(data.query, context, message_history, state.chat.query) diff --git a/server/dependencies/__init__.py b/server/dependencies/__init__.py index d0351b2..417b725 100644 --- a/server/dependencies/__init__.py +++ b/server/dependencies/__init__.py @@ -1 +1,2 @@ +from server.dependencies.embedder import embedder as embedder from server.dependencies.redis import redis_client as redis_client diff --git a/server/dependencies/embedder.py b/server/dependencies/embedder.py new file mode 100644 index 0000000..1f2c239 --- /dev/null +++ b/server/dependencies/embedder.py @@ -0,0 +1,22 @@ +from typing import Iterator + +from server.features.embeddings import Embedder + + +def embedder() -> Iterator[Embedder]: + """ + Summary + ------- + load the embeddings model + + Returns + ------- + embedding (Embedding): the embeddings model + """ + embedder = Embedder() + + try: + yield embedder + + finally: + del embedder diff --git a/server/features/chat/model.py b/server/features/chat/model.py index c067de8..7fede68 100644 --- a/server/features/chat/model.py +++ b/server/features/chat/model.py @@ -3,7 +3,7 @@ from server.config import Config from server.features.chat.types import Message -from server.helpers import huggingface_download +from server.utils import huggingface_download class ChatModel: diff --git a/server/features/embeddings/__init__.py b/server/features/embeddings/__init__.py index 4b936d4..9c4d990 100644 --- a/server/features/embeddings/__init__.py +++ b/server/features/embeddings/__init__.py @@ -1 +1 @@ -from server.features.embeddings.embedding import Embedding as Embedding +from server.features.embeddings.embedding import Embedder as Embedder diff --git a/server/features/embeddings/embedding.py b/server/features/embeddings/embedding.py index 16849b5..ea4658f 100644 --- a/server/features/embeddings/embedding.py +++ b/server/features/embeddings/embedding.py @@ -1,11 +1,11 @@ -from huggingface_hub import snapshot_download from sentence_transformers import SentenceTransformer from torch import device from server.features.embeddings.flag_embedding import FlagEmbedding +from 
server.utils import huggingface_download -class Embedding(SentenceTransformer): +class Embedder(SentenceTransformer): """ Summary ------- @@ -20,12 +20,12 @@ class Embedding(SentenceTransformer): encode a sentence for searching relevant passages """ - def __init__(self, *, force_download: bool = False): + def __init__(self): model_name = 'bge-base-en-v1.5' super().__init__(f'BAAI/{model_name}') self.cached_device = super().device # type: ignore - model_path = snapshot_download(f'winstxnhdw/{model_name}-ct2', local_files_only=not force_download) + model_path = huggingface_download(f'winstxnhdw/{model_name}-ct2') self[0] = FlagEmbedding(self[0], model_path, 'auto') @property diff --git a/server/helpers/__init__.py b/server/helpers/__init__.py deleted file mode 100644 index b1fb832..0000000 --- a/server/helpers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from server.helpers.network import huggingface_download as huggingface_download diff --git a/server/helpers/network/__init__.py b/server/helpers/network/__init__.py deleted file mode 100644 index 960f665..0000000 --- a/server/helpers/network/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from server.helpers.network.huggingface_download import ( - huggingface_download as huggingface_download, -) diff --git a/server/lifespans/chat_model.py b/server/lifespans/chat_model.py index 0acd47b..6575cd9 100644 --- a/server/lifespans/chat_model.py +++ b/server/lifespans/chat_model.py @@ -21,5 +21,6 @@ async def chat_model(app: Litestar) -> AsyncIterator[None]: try: yield + finally: del app.state.chat diff --git a/server/lifespans/create_redis_index.py b/server/lifespans/create_redis_index.py index a725df2..ef03b34 100644 --- a/server/lifespans/create_redis_index.py +++ b/server/lifespans/create_redis_index.py @@ -29,5 +29,6 @@ async def create_redis_index(app: Litestar) -> AsyncIterator[None]: try: yield + finally: pass diff --git a/server/lifespans/download_embeddings.py b/server/lifespans/download_embeddings.py index db0aad5..63293f6 100644 --- a/server/lifespans/download_embeddings.py +++ b/server/lifespans/download_embeddings.py @@ -3,11 +3,11 @@ from litestar import Litestar -from server.helpers import huggingface_download +from server.utils import huggingface_download @asynccontextmanager -async def download_embeddings(app: Litestar) -> AsyncIterator[None]: +async def download_embeddings(_: Litestar) -> AsyncIterator[None]: """ Summary ------- @@ -21,5 +21,6 @@ async def download_embeddings(app: Litestar) -> AsyncIterator[None]: try: yield + finally: pass diff --git a/server/lifespans/download_nltk.py b/server/lifespans/download_nltk.py index dd7a004..3e13901 100644 --- a/server/lifespans/download_nltk.py +++ b/server/lifespans/download_nltk.py @@ -40,5 +40,6 @@ async def download_nltk(app: Litestar) -> AsyncIterator[None]: try: yield + finally: pass diff --git a/server/utils/__init__.py b/server/utils/__init__.py new file mode 100644 index 0000000..b0bdaec --- /dev/null +++ b/server/utils/__init__.py @@ -0,0 +1 @@ +from server.utils.network import huggingface_download as huggingface_download diff --git a/server/utils/network/__init__.py b/server/utils/network/__init__.py new file mode 100644 index 0000000..2d5eaaa --- /dev/null +++ b/server/utils/network/__init__.py @@ -0,0 +1,3 @@ +from server.utils.network.huggingface_download import ( + huggingface_download as huggingface_download, +) diff --git a/server/helpers/network/has_internet_access.py b/server/utils/network/has_internet_access.py similarity index 100% rename from 
server/helpers/network/has_internet_access.py rename to server/utils/network/has_internet_access.py diff --git a/server/helpers/network/huggingface_download.py b/server/utils/network/huggingface_download.py similarity index 86% rename from server/helpers/network/huggingface_download.py rename to server/utils/network/huggingface_download.py index 1616cb6..7b52185 100644 --- a/server/helpers/network/huggingface_download.py +++ b/server/utils/network/huggingface_download.py @@ -1,6 +1,6 @@ from huggingface_hub import snapshot_download -from server.helpers.network.has_internet_access import has_internet_access +from server.utils.network.has_internet_access import has_internet_access def huggingface_download(repository: str) -> str: diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 765a0ef..a2004f2 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -1,18 +1,24 @@ # pylint: disable=missing-function-docstring,redefined-outer-name -from typing import Literal +from typing import Iterable, Literal from numpy import array_equal from pytest import fixture -from server.features.embeddings import Embedding +from server.features.embeddings import Embedder type Text = Literal['Hello world!'] @fixture() -def embedding(): - yield Embedding(force_download=True) +def embedding() -> Iterable[Embedder]: + embedder = Embedder() + + try: + yield embedder + + finally: + del embedder @fixture() @@ -20,13 +26,13 @@ def text(): yield 'Hello world!' -def test_encodings(embedding: Embedding, text: Text): +def test_encodings(embedding: Embedder, text: Text): assert array_equal(embedding.encode_query(text), embedding.encode_normalise(text)) is False -def test_encode_query(embedding: Embedding, text: Text): +def test_encode_query(embedding: Embedder, text: Text): assert len(embedding.encode_query(text)) > 0 -def test_encode_normalise(embedding: Embedding, text: Text): +def test_encode_normalise(embedding: Embedder, text: Text): assert len(embedding.encode_normalise(text)) > 0 diff --git a/tests/test_has_internet_access.py b/tests/test_has_internet_access.py index a7baa28..d6c7604 100644 --- a/tests/test_has_internet_access.py +++ b/tests/test_has_internet_access.py @@ -2,7 +2,7 @@ from pytest import mark -from server.helpers.network.has_internet_access import has_internet_access +from server.utils.network.has_internet_access import has_internet_access @mark.parametrize(
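
For reviewers unfamiliar with Litestar's dependency injection, the wiring this patch adopts looks roughly like the sketch below: a generator provider registered with `Provide` and injected through an `Annotated[..., Dependency()]` parameter, so handlers no longer construct `Embedder` themselves. The `Embedder` stub, the `/search` route, and the app setup are illustrative placeholders, not code from this repository; only the provider/injection pattern mirrors the patch.

from typing import Annotated, Iterator

from litestar import Controller, Litestar, get
from litestar.di import Provide
from litestar.params import Dependency


class Embedder:
    """Hypothetical stand-in for server.features.embeddings.Embedder."""

    def encode_query(self, query: str) -> list[float]:
        return [float(len(query))]  # placeholder value, not a real embedding


def embedder() -> Iterator[Embedder]:
    """
    Generator-based provider: the model is constructed when the dependency is
    resolved and released in the `finally` block once it is no longer needed.
    """
    instance = Embedder()

    try:
        yield instance

    finally:
        del instance


class SearchController(Controller):
    path = '/search'

    # register the provider once; every handler on this controller can now declare
    # an `embedder` parameter instead of instantiating the model itself — depending
    # on the Litestar version, Provide(embedder, sync_to_thread=False) may be
    # preferred to silence the sync-callable warning
    dependencies = {'embedder': Provide(embedder)}

    @get()
    async def search(self, embedder: Annotated[Embedder, Dependency()], query: str) -> list[float]:
        # the handler receives the already-initialised Embedder yielded by the provider
        return embedder.encode_query(query)


app = Litestar(route_handlers=[SearchController])

Yielding from the provider keeps construction and teardown in one place, which is why the handlers in `chat.py` and `redis.py` above only ever see an already-initialised instance and no longer call `Embedding()` inline.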