Skip to content

Commit

Permalink
Add support for weaviate db (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
homanp authored Jan 17, 2024
1 parent 5b95f92 commit fed2193
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 10 deletions.
2 changes: 1 addition & 1 deletion api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ async def query(payload: RequestPayload):
index_name=payload.index_name, credentials=payload.vector_database
)
chunks = await vector_service.query(input=payload.input, top_k=4)
documents = await vector_service.convert_to_dict(chunks=chunks)
documents = await vector_service.convert_to_rerank_format(chunks=chunks)
results = await vector_service.rerank(query=payload.input, documents=documents)
return {"success": True, "data": results}
88 changes: 79 additions & 9 deletions service/vector_database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import weaviate

from abc import ABC, abstractmethod
from typing import Any, List, Type
from decouple import config
Expand All @@ -24,9 +26,21 @@ async def query():
pass

@abstractmethod
async def convert_to_dict():
async def convert_to_rerank_format():
pass

async def _generate_vectors(self, input: str) -> List:
    """Embed ``input`` and return the embedding vectors that were produced.

    The text is sent to the ``huggingface/intfloat/multilingual-e5-large``
    model through ``embedding``, authenticated with the
    ``HUGGINGFACE_API_KEY`` configuration value.

    Args:
        input: Text to embed.

    Returns:
        A list with one entry per item of the response's ``data`` whose
        ``"object"`` field equals ``"embedding"``; each entry is that
        item's ``"embedding"`` value.
    """
    embedding_object = embedding(
        model="huggingface/intfloat/multilingual-e5-large",
        input=input,
        api_key=config("HUGGINGFACE_API_KEY"),
    )
    # Skip any response item that is not tagged as an embedding.
    return [
        item["embedding"]
        for item in embedding_object.data
        if item["object"] == "embedding"
    ]

async def rerank(self, query: str, documents: list, top_n: int = 4):
from cohere import Client

Expand Down Expand Up @@ -63,7 +77,7 @@ def __init__(self, index_name: str, dimension: int, credentials: dict):
)
self.index = pinecone.Index(name=self.index_name)

async def convert_to_dict(self, chunks: List):
async def convert_to_rerank_format(self, chunks: List):
docs = [
{
"content": chunk.get("metadata")["content"],
Expand Down Expand Up @@ -117,7 +131,7 @@ def __init__(self, index_name: str, dimension: int, credentials: dict):
),
)

async def convert_to_dict(self, chunks: List[rest.PointStruct]):
async def convert_to_rerank_format(self, chunks: List[rest.PointStruct]):
docs = [
{
"content": chunk.payload.get("content"),
Expand All @@ -128,7 +142,7 @@ async def convert_to_dict(self, chunks: List[rest.PointStruct]):
]
return docs

async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]):
async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]) -> None:
points = []

for _embedding in embeddings:
Expand All @@ -141,12 +155,8 @@ async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]):
)

self.client.upsert(collection_name=self.index_name, wait=True, points=points)
collection_vector_count = self.client.get_collection(
collection_name=self.index_name
).vectors_count
print(f"Vector count in collection: {collection_vector_count}")

async def query(self, input: str, top_k: int):
async def query(self, input: str, top_k: int) -> List:
vectors = []
embedding_object = embedding(
model="huggingface/intfloat/multilingual-e5-large",
Expand All @@ -173,12 +183,72 @@ async def query(self, input: str, top_k: int):
return search_result


class WeaviateService(VectorService):
    """Vector service backed by a Weaviate instance.

    Chunks are stored as objects of a single Weaviate class named after
    ``index_name``. Each object carries the chunk text in its ``text``
    property, plus whatever metadata was supplied at upsert time.
    """

    def __init__(self, index_name: str, dimension: int, credentials: dict):
        """Connect to Weaviate and make sure the target class exists.

        Args:
            index_name: Name of the index; used as the Weaviate class name.
            dimension: Embedding dimension, forwarded to the base class.
            credentials: Mapping that must provide ``"host"`` (the Weaviate
                URL) and ``"api_key"``.
        """
        super().__init__(
            index_name=index_name, dimension=dimension, credentials=credentials
        )
        self.client = weaviate.Client(
            url=credentials["host"],
            auth_client_secret=weaviate.AuthApiKey(api_key=credentials["api_key"]),
        )
        class_name = self._class_name
        schema = {
            "class": class_name,
            "properties": [
                {
                    "name": "text",
                    "dataType": ["text"],
                }
            ],
        }
        if not self.client.schema.exists(class_name):
            self.client.schema.create_class(schema)

    @property
    def _class_name(self) -> str:
        """Return ``index_name`` with only its first character upper-cased.

        Weaviate class names start with a capital letter and the rest of the
        name keeps its casing. ``str.capitalize()`` must not be used here: it
        lower-cases every character after the first, so ``"myIndex"`` would
        become ``"Myindex"`` and no longer match the stored class.
        """
        name = self.index_name
        return name[:1].upper() + name[1:]

    async def convert_to_rerank_format(self, chunks: List) -> List:
        """Map Weaviate result objects to the dicts expected by ``rerank``.

        Args:
            chunks: Result objects exposing ``.get`` for the ``text``,
                ``page_label`` and ``file_url`` properties.

        Returns:
            One dict per chunk with ``content``, ``page_label`` and
            ``file_url`` keys; missing properties come back as ``None``.
        """
        docs = [
            {
                "content": chunk.get("text"),
                "page_label": chunk.get("page_label"),
                "file_url": chunk.get("file_url"),
            }
            for chunk in chunks
        ]
        return docs

    async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]) -> None:
        """Store embeddings as Weaviate objects in one batch.

        Args:
            embeddings: Tuples of ``(uuid, vector, metadata)``. ``metadata``
                must contain a ``"content"`` entry, which is also stored
                under the ``text`` property.
        """
        class_name = self._class_name
        with self.client.batch as batch:
            for object_id, vector, metadata in embeddings:
                batch.add_data_object(
                    uuid=object_id,
                    # "text" comes first so that an explicit "text" entry in
                    # the metadata still takes precedence, as before.
                    data_object={"text": metadata["content"], **metadata},
                    class_name=class_name,
                    vector=vector,
                )
            # Send any objects still buffered before leaving the batch context.
            batch.flush()

    async def query(self, input: str, top_k: int = 4) -> List:
        """Return the ``top_k`` stored objects nearest to ``input``.

        Args:
            input: Query text; it is embedded with ``_generate_vectors``.
            top_k: Maximum number of objects to return.

        Returns:
            The list of matching objects from the Weaviate ``Get`` response,
            each holding the ``text``, ``file_url`` and ``page_label``
            properties. An empty list is returned when no embedding could be
            produced for ``input``.
        """
        vectors = await self._generate_vectors(input=input)
        if not vectors:
            return []
        # ``_generate_vectors`` returns a list of embeddings, whereas a
        # near-vector search takes a single flat vector: use the first one.
        # A flat vector is passed through unchanged.
        first = vectors[0]
        query_vector = first if isinstance(first, (list, tuple)) else vectors

        # The same normalised name must be used for the request and for the
        # lookup in the response.
        class_name = self._class_name
        result = (
            self.client.query.get(
                class_name,
                ["text", "file_url", "page_label"],
            )
            .with_near_vector({"vector": query_vector})
            .with_limit(top_k)
            .do()
        )
        return result["data"]["Get"][class_name]


def get_vector_service(
index_name: str, credentials: VectorDatabase, dimension: int = 1024
) -> Type[VectorService]:
services = {
"pinecone": PineconeVectorService,
"qdrant": QdrantService,
"weaviate": WeaviateService,
# Add other providers here
# e.g "weaviate": WeaviateVectorService,
}
Expand Down

0 comments on commit fed2193

Please sign in to comment.