From be2c8338b6046c768aa274be7a9bafdcddedb5ab Mon Sep 17 00:00:00 2001 From: Ismail Pelaseyed Date: Tue, 16 Jan 2024 17:11:02 -0800 Subject: [PATCH] Add support for astra (#11) --- requirements.txt | 6 ++++- service/vector_database.py | 50 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 781db2ef..59741684 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,11 +2,14 @@ aiohttp==3.9.1 aiosignal==1.3.1 annotated-types==0.6.0 anyio==4.2.0 +astrapy==0.7.0 attrs==23.2.0 Authlib==1.3.0 backoff==2.2.1 beautifulsoup4==4.12.2 black==23.12.1 +cassandra-driver==3.29.0 +cassio==0.1.4 certifi==2023.11.17 cffi==1.16.0 charset-normalizer==3.3.2 @@ -22,6 +25,7 @@ fastavro==1.9.3 filelock==3.13.1 frozenlist==1.4.1 fsspec==2023.12.2 +geomet==0.2.1.post1 greenlet==3.0.3 grpcio==1.60.0 grpcio-tools==1.60.0 @@ -30,7 +34,7 @@ h2==4.1.0 hpack==4.0.0 httpcore==1.0.2 httptools==0.6.1 -httpx==0.26.0 +httpx==0.25.2 huggingface-hub==0.20.2 hyperframe==6.0.1 idna==3.6 diff --git a/service/vector_database.py b/service/vector_database.py index 4fe81b03..74f17a94 100644 --- a/service/vector_database.py +++ b/service/vector_database.py @@ -7,6 +7,7 @@ from qdrant_client import QdrantClient from qdrant_client.http import models as rest from pinecone import Pinecone, ServerlessSpec +from astrapy.db import AstraDB from models.vector_database import VectorDatabase @@ -242,6 +243,54 @@ async def query(self, input: str, top_k: int = 4) -> List: return result["data"]["Get"][self.index_name.capitalize()] +class AstraService(VectorService): + def __init__(self, index_name: str, dimension: int, credentials: dict): + super().__init__( + index_name=index_name, dimension=dimension, credentials=credentials + ) + self.client = AstraDB( + token=credentials["api_key"], + api_endpoint=credentials["host"], + ) + collections = self.client.get_collections() + if self.index_name not in collections["status"]["collections"]: + self.collection = self.client.create_collection( + dimension=dimension, collection_name=index_name + ) + self.collection = self.client.collection(collection_name=self.index_name) + + async def convert_to_rerank_format(self, chunks: List) -> List: + docs = [ + { + "content": chunk.get("text"), + "page_label": chunk.get("page_label"), + "file_url": chunk.get("file_url"), + } + for chunk in chunks + ] + return docs + + async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]) -> None: + documents = [ + { + "_id": _embedding[0], + "text": _embedding[2]["content"], + "$vector": _embedding[1], + **_embedding[2], + } + for _embedding in embeddings + ] + for i in range(0, len(documents), 5): + self.collection.insert_many(documents=documents[i : i + 5]) + + async def query(self, input: str, top_k: int = 4) -> List: + vectors = await self._generate_vectors(input=input) + results = self.collection.vector_find( + vector=vectors, limit=top_k, fields={"text", "page_label", "file_url"} + ) + return results + + def get_vector_service( index_name: str, credentials: VectorDatabase, dimension: int = 1024 ) -> Type[VectorService]: @@ -249,6 +298,7 @@ def get_vector_service( "pinecone": PineconeVectorService, "qdrant": QdrantService, "weaviate": WeaviateService, + "astra": AstraService, # Add other providers here # e.g "weaviate": WeaviateVectorService, }