Skip to content

Commit

Permalink
refactor: make embedder a dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
winstxnhdw committed Sep 22, 2024
1 parent 771815c commit 8e7e17e
Show file tree
Hide file tree
Showing 19 changed files with 71 additions and 32 deletions.
12 changes: 8 additions & 4 deletions server/api/debug/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from server.config import Config
from server.databases.redis.wrapper import RedisAsync
from server.dependencies import redis_client
from server.features.embeddings import Embedding
from server.dependencies import embedder, redis_client
from server.features.embeddings import Embedder
from server.schemas.v1 import Query


Expand All @@ -19,7 +19,10 @@ class RedisController(Controller):
"""

path = '/redis'
dependencies = {'redis': Provide(redis_client)}
dependencies = {
'redis': Provide(redis_client),
'embedder': Provide(embedder),
}

@delete()
async def delete_index(self, redis: Annotated[RedisAsync, Dependency()], recreate: bool = False) -> None:
Expand All @@ -42,6 +45,7 @@ async def delete_index(self, redis: Annotated[RedisAsync, Dependency()], recreat
async def search(
self,
redis: Annotated[RedisAsync, Dependency()],
embedder: Annotated[Embedder, Dependency()],
chat_id: str,
data: Query,
search_size: Annotated[int, Parameter(gt=0)] = 1,
Expand All @@ -51,4 +55,4 @@ async def search(
-------
an endpoint for searching the Redis vector database
"""
return await redis.search(chat_id, Embedding().encode_query(data.query), search_size)
return await redis.search(chat_id, embedder.encode_query(data.query), search_size)
16 changes: 9 additions & 7 deletions server/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

from server.databases.redis.features import store_chunks
from server.databases.redis.wrapper import RedisAsync
from server.dependencies.redis import redis_client
from server.dependencies import embedder, redis_client
from server.features.chunking import SentenceSplitter, chunk_document
from server.features.embeddings import Embedding
from server.features.embeddings import Embedder
from server.features.extraction import extract_documents_from_pdfs
from server.features.question_answering import question_answering
from server.schemas.v1 import Answer, Chat, Files, Query
Expand All @@ -27,7 +27,10 @@ class ChatController(Controller):
"""

path = '/chats'
dependencies = {'redis': Provide(redis_client)}
dependencies = {
'redis': Provide(redis_client),
'embedder': Provide(embedder),
}

@get()
async def create_chat(self) -> Chat:
Expand Down Expand Up @@ -75,6 +78,7 @@ async def upload_files(
self,
state: AppState,
redis: Annotated[RedisAsync, Dependency()],
embedder: Annotated[Embedder, Dependency()],
chat_id: str,
data: Annotated[list[UploadFile], Body(media_type=RequestEncodingType.MULTI_PART)],
) -> Files:
Expand All @@ -83,7 +87,6 @@ async def upload_files(
-------
an endpoint for uploading files to a chat
"""
embedder = Embedding()
text_splitter = SentenceSplitter(state.chat.tokeniser, chunk_size=128, chunk_overlap=0)
responses = []

Expand All @@ -109,6 +112,7 @@ async def query(
self,
state: AppState,
redis: Annotated[RedisAsync, Dependency()],
embedder: Annotated[Embedder, Dependency()],
chat_id: str,
data: Query,
search_size: Annotated[int, Parameter(ge=0)] = 0,
Expand All @@ -119,9 +123,7 @@ async def query(
-------
the `/query` route provides an endpoint for performning retrieval-augmented generation
"""
context = (
'' if not search_size else await redis.search(chat_id, Embedding().encode_query(data.query), search_size)
)
context = '' if not search_size else await redis.search(chat_id, embedder.encode_query(data.query), search_size)

message_history = await redis.get_messages(chat_id)
messages = await question_answering(data.query, context, message_history, state.chat.query)
Expand Down
1 change: 1 addition & 0 deletions server/dependencies/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from server.dependencies.embedder import embedder as embedder
from server.dependencies.redis import redis_client as redis_client
22 changes: 22 additions & 0 deletions server/dependencies/embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Iterator

from server.features.embeddings import Embedder


def embedder() -> Iterator[Embedder]:
"""
Summary
-------
load the embeddings model
Returns
-------
embedding (Embedding): the embeddings model
"""
embedder = Embedder()

try:
yield embedder

finally:
del embedder
2 changes: 1 addition & 1 deletion server/features/chat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from server.config import Config
from server.features.chat.types import Message
from server.helpers import huggingface_download
from server.utils import huggingface_download


class ChatModel:
Expand Down
2 changes: 1 addition & 1 deletion server/features/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from server.features.embeddings.embedding import Embedding as Embedding
from server.features.embeddings.embedding import Embedder as Embedder
8 changes: 4 additions & 4 deletions server/features/embeddings/embedding.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer
from torch import device

from server.features.embeddings.flag_embedding import FlagEmbedding
from server.utils import huggingface_download


class Embedding(SentenceTransformer):
class Embedder(SentenceTransformer):
"""
Summary
-------
Expand All @@ -20,12 +20,12 @@ class Embedding(SentenceTransformer):
encode a sentence for searching relevant passages
"""

def __init__(self, *, force_download: bool = False):
def __init__(self):
model_name = 'bge-base-en-v1.5'
super().__init__(f'BAAI/{model_name}')
self.cached_device = super().device # type: ignore

model_path = snapshot_download(f'winstxnhdw/{model_name}-ct2', local_files_only=not force_download)
model_path = huggingface_download(f'winstxnhdw/{model_name}-ct2')
self[0] = FlagEmbedding(self[0], model_path, 'auto')

@property
Expand Down
1 change: 0 additions & 1 deletion server/helpers/__init__.py

This file was deleted.

3 changes: 0 additions & 3 deletions server/helpers/network/__init__.py

This file was deleted.

1 change: 1 addition & 0 deletions server/lifespans/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ async def chat_model(app: Litestar) -> AsyncIterator[None]:

try:
yield

finally:
del app.state.chat
1 change: 1 addition & 0 deletions server/lifespans/create_redis_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ async def create_redis_index(app: Litestar) -> AsyncIterator[None]:

try:
yield

finally:
pass
5 changes: 3 additions & 2 deletions server/lifespans/download_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from litestar import Litestar

from server.helpers import huggingface_download
from server.utils import huggingface_download


@asynccontextmanager
async def download_embeddings(app: Litestar) -> AsyncIterator[None]:
async def download_embeddings(_: Litestar) -> AsyncIterator[None]:
"""
Summary
-------
Expand All @@ -21,5 +21,6 @@ async def download_embeddings(app: Litestar) -> AsyncIterator[None]:

try:
yield

finally:
pass
1 change: 1 addition & 0 deletions server/lifespans/download_nltk.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ async def download_nltk(app: Litestar) -> AsyncIterator[None]:

try:
yield

finally:
pass
1 change: 1 addition & 0 deletions server/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from server.utils.network import huggingface_download as huggingface_download
3 changes: 3 additions & 0 deletions server/utils/network/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from server.utils.network.huggingface_download import (
huggingface_download as huggingface_download,
)
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from huggingface_hub import snapshot_download

from server.helpers.network.has_internet_access import has_internet_access
from server.utils.network.has_internet_access import has_internet_access


def huggingface_download(repository: str) -> str:
Expand Down
20 changes: 13 additions & 7 deletions tests/test_embedding.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
# pylint: disable=missing-function-docstring,redefined-outer-name

from typing import Literal
from typing import Iterable, Literal

from numpy import array_equal
from pytest import fixture

from server.features.embeddings import Embedding
from server.features.embeddings import Embedder

type Text = Literal['Hello world!']


@fixture()
def embedding():
yield Embedding(force_download=True)
def embedding() -> Iterable[Embedder]:
embedder = Embedder()

try:
yield embedder

finally:
del embedder


@fixture()
def text():
yield 'Hello world!'


def test_encodings(embedding: Embedding, text: Text):
def test_encodings(embedding: Embedder, text: Text):
assert array_equal(embedding.encode_query(text), embedding.encode_normalise(text)) is False


def test_encode_query(embedding: Embedding, text: Text):
def test_encode_query(embedding: Embedder, text: Text):
assert len(embedding.encode_query(text)) > 0


def test_encode_normalise(embedding: Embedding, text: Text):
def test_encode_normalise(embedding: Embedder, text: Text):
assert len(embedding.encode_normalise(text)) > 0
2 changes: 1 addition & 1 deletion tests/test_has_internet_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pytest import mark

from server.helpers.network.has_internet_access import has_internet_access
from server.utils.network.has_internet_access import has_internet_access


@mark.parametrize(
Expand Down

0 comments on commit 8e7e17e

Please sign in to comment.