diff --git a/api/query.py b/api/query.py index 82402312..6e999e05 100644 --- a/api/query.py +++ b/api/query.py @@ -11,6 +11,6 @@ async def query(payload: RequestPayload): index_name=payload.index_name, credentials=payload.vector_database ) chunks = await vector_service.query(input=payload.input, top_k=4) - documents = await vector_service.convert_to_dict(chunks=chunks) + documents = await vector_service.convert_to_rerank_format(chunks=chunks) results = await vector_service.rerank(query=payload.input, documents=documents) return {"success": True, "data": results} diff --git a/service/vector_database.py b/service/vector_database.py index 642c1b62..4fe81b03 100644 --- a/service/vector_database.py +++ b/service/vector_database.py @@ -1,3 +1,5 @@ +import weaviate + from abc import ABC, abstractmethod from typing import Any, List, Type from decouple import config @@ -24,9 +26,21 @@ async def query(): pass @abstractmethod - async def convert_to_dict(): + async def convert_to_rerank_format(): pass + async def _generate_vectors(sefl, input: str): + vectors = [] + embedding_object = embedding( + model="huggingface/intfloat/multilingual-e5-large", + input=input, + api_key=config("HUGGINGFACE_API_KEY"), + ) + for vector in embedding_object.data: + if vector["object"] == "embedding": + vectors.append(vector["embedding"]) + return vectors + async def rerank(self, query: str, documents: list, top_n: int = 4): from cohere import Client @@ -63,7 +77,7 @@ def __init__(self, index_name: str, dimension: int, credentials: dict): ) self.index = pinecone.Index(name=self.index_name) - async def convert_to_dict(self, chunks: List): + async def convert_to_rerank_format(self, chunks: List): docs = [ { "content": chunk.get("metadata")["content"], @@ -117,7 +131,7 @@ def __init__(self, index_name: str, dimension: int, credentials: dict): ), ) - async def convert_to_dict(self, chunks: List[rest.PointStruct]): + async def convert_to_rerank_format(self, chunks: List[rest.PointStruct]): docs = [ { "content": chunk.payload.get("content"), @@ -128,7 +142,7 @@ async def convert_to_dict(self, chunks: List[rest.PointStruct]): ] return docs - async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]): + async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]) -> None: points = [] for _embedding in embeddings: @@ -141,12 +155,8 @@ async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]): ) self.client.upsert(collection_name=self.index_name, wait=True, points=points) - collection_vector_count = self.client.get_collection( - collection_name=self.index_name - ).vectors_count - print(f"Vector count in collection: {collection_vector_count}") - async def query(self, input: str, top_k: int): + async def query(self, input: str, top_k: int) -> List: vectors = [] embedding_object = embedding( model="huggingface/intfloat/multilingual-e5-large", @@ -173,12 +183,72 @@ async def query(self, input: str, top_k: int): return search_result +class WeaviateService(VectorService): + def __init__(self, index_name: str, dimension: int, credentials: dict): + super().__init__( + index_name=index_name, dimension=dimension, credentials=credentials + ) + self.client = weaviate.Client( + url=credentials["host"], + auth_client_secret=weaviate.AuthApiKey(api_key=credentials["api_key"]), + ) + schema = { + "class": self.index_name, + "properties": [ + { + "name": "text", + "dataType": ["text"], + } + ], + } + if not self.client.schema.exists(self.index_name): + self.client.schema.create_class(schema) + + async def convert_to_rerank_format(self, chunks: List) -> List: + docs = [ + { + "content": chunk.get("text"), + "page_label": chunk.get("page_label"), + "file_url": chunk.get("file_url"), + } + for chunk in chunks + ] + return docs + + async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]) -> None: + with self.client.batch as batch: + for _embedding in embeddings: + params = { + "uuid": _embedding[0], + "data_object": {"text": _embedding[2]["content"], **_embedding[2]}, + "class_name": self.index_name, + "vector": _embedding[1], + } + batch.add_data_object(**params) + batch.flush() + + async def query(self, input: str, top_k: int = 4) -> List: + vectors = await self._generate_vectors(input=input) + vector = {"vector": vectors} + result = ( + self.client.query.get( + self.index_name.capitalize(), + ["text", "file_url", "page_label"], + ) + .with_near_vector(vector) + .with_limit(top_k) + .do() + ) + return result["data"]["Get"][self.index_name.capitalize()] + + def get_vector_service( index_name: str, credentials: VectorDatabase, dimension: int = 1024 ) -> Type[VectorService]: services = { "pinecone": PineconeVectorService, "qdrant": QdrantService, + "weaviate": WeaviateService, # Add other providers here # e.g "weaviate": WeaviateVectorService, }