Add PGVector Support (#94)
* feat: add pgvector

* feat: delete summary index too
elisalimli committed Mar 28, 2024
1 parent ea5e041 commit bb1630b
Showing 7 changed files with 241 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -166,5 +166,5 @@ Super-Rag comes with a built in REST API powered by FastApi.
- Qdrant
- Weaviate
- Astra
- Supabase (coming soon)
- PGVector
- Chroma (coming soon)
8 changes: 8 additions & 0 deletions api/delete.py
@@ -3,6 +3,7 @@
from models.delete import RequestPayload, ResponsePayload
from vectordbs import get_vector_service
from vectordbs.base import BaseVectorDatabase
from utils.summarise import SUMMARY_SUFFIX

router = APIRouter()

@@ -16,8 +17,15 @@ async def delete(payload: RequestPayload):
encoder=encoder,
dimensions=payload.encoder.dimensions,
)
summary_vector_service: BaseVectorDatabase = get_vector_service(
index_name=f"{payload.index_name}{SUMMARY_SUFFIX}",
credentials=payload.vector_database,
encoder=encoder,
dimensions=payload.encoder.dimensions,
)

for file in payload.files:
data = await vector_service.delete(file_url=file.url)
await summary_vector_service.delete(file_url=file.url)

return ResponsePayload(success=True, data=data)
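The practical effect: deleting a file now purges its chunks from the companion summary index as well as the main one. The summary index name is simply the main index name with SUMMARY_SUFFIX appended, as in this small sketch (the concrete suffix string is defined in utils/summarise.py and not shown in this diff; the index name below is a placeholder):

```python
from utils.summarise import SUMMARY_SUFFIX

index_name = "my-docs"  # placeholder
summary_index_name = f"{index_name}{SUMMARY_SUFFIX}"  # same composition the endpoint uses above
```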
1 change: 1 addition & 0 deletions models/vector_database.py
@@ -9,6 +9,7 @@ class DatabaseType(Enum):
pinecone = "pinecone"
weaviate = "weaviate"
astra = "astra"
pgvector = "pgvector"


class VectorDatabase(BaseModel):
134 changes: 133 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -33,6 +33,7 @@ gunicorn = "^21.2.0"
unstructured-client = "^0.18.0"
unstructured = {extras = ["google-drive"], version = "^0.12.4"}
tiktoken = "^0.6.0"
vecs = "^0.4.3"

[tool.poetry.group.dev.dependencies]
termcolor = "^2.4.0"
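For context, vecs is Supabase's Python client for pgvector-backed collections in Postgres; the new PGVectorService below relies on it for collection management, upserts, and queries. A minimal connection sketch, assuming a local Postgres instance with the pgvector extension (connection string and collection name are placeholders):

```python
import vecs

# Placeholder connection string; the target database needs the pgvector extension.
client = vecs.create_client("postgresql://user:pass@localhost:5432/postgres")
collection = client.get_or_create_collection(name="my-docs", dimension=1536)
```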
2 changes: 2 additions & 0 deletions vectordbs/__init__.py
@@ -10,6 +10,7 @@
from vectordbs.pinecone import PineconeService
from vectordbs.qdrant import QdrantService
from vectordbs.weaviate import WeaviateService
from vectordbs.pgvector import PGVectorService

load_dotenv()

@@ -26,6 +27,7 @@ def get_vector_service(
"qdrant": QdrantService,
"weaviate": WeaviateService,
"astra": AstraService,
"pgvector": PGVectorService,
# Add other providers here
# e.g "weaviate": WeaviateVectorService,
}
95 changes: 95 additions & 0 deletions vectordbs/pgvector.py
@@ -0,0 +1,95 @@
from typing import List

import vecs
from semantic_router.encoders import BaseEncoder
from tqdm import tqdm

from qdrant_client.http import models as rest
from models.delete import DeleteResponse
from models.document import BaseDocumentChunk
from vectordbs.base import BaseVectorDatabase

MAX_QUERY_TOP_K = 5


class PGVectorService(BaseVectorDatabase):
def __init__(
self, index_name: str, dimension: int, credentials: dict, encoder: BaseEncoder
):
super().__init__(
index_name=index_name,
dimension=dimension,
credentials=credentials,
encoder=encoder,
)
client = vecs.create_client(connection_string=credentials["database_uri"])
self.collection = client.get_or_create_collection(
name=self.index_name,
dimension=dimension,
)

# TODO: remove this
async def convert_to_rerank_format(self, chunks: List[rest.PointStruct]):
docs = [
{
"content": chunk.payload.get("content"),
"page_label": chunk.payload.get("page_label"),
"file_url": chunk.payload.get("file_url"),
}
for chunk in chunks
]
return docs

async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
records = []
for chunk in tqdm(chunks, desc="Upserting to PGVector"):
records.append(
(
chunk.id,
chunk.dense_embedding,
{
"document_id": chunk.document_id,
"content": chunk.content,
"doc_url": chunk.doc_url,
**(chunk.metadata if chunk.metadata else {}),
},
)
)
self.collection.upsert(records)
        # Build the vector index so the newly upserted records can be queried efficiently
        self.collection.create_index()

async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List:
vectors = await self._generate_vectors(input=input)

results = self.collection.query(
data=vectors[0],
limit=top_k,
include_metadata=True,
include_value=False,
)

chunks = []

        # With include_metadata=True and include_value=False, vecs returns (id, metadata) pairs
        for id, metadata in results:

chunks.append(
BaseDocumentChunk(
id=id,
source_type=metadata.get("filetype"),
source=metadata.get("doc_url"),
document_id=metadata.get("document_id"),
content=metadata.get("content"),
doc_url=metadata.get("doc_url"),
page_number=metadata.get("page_number"),
metadata={**metadata},
)
)
return chunks

    async def delete(self, file_url: str) -> DeleteResponse:
        # vecs returns the ids of the deleted records, so len() gives the chunk count
        deleted = self.collection.delete(filters={"doc_url": {"$eq": file_url}})
        return DeleteResponse(num_of_deleted_chunks=len(deleted))
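For reference, a minimal usage sketch of the new service outside the API layer. The index name, connection string, and encoder choice are placeholders, and it assumes a reachable Postgres instance with the pgvector extension plus previously ingested chunks to query:

```python
import asyncio

from semantic_router.encoders import OpenAIEncoder  # any BaseEncoder implementation works
from vectordbs.pgvector import PGVectorService

# Placeholder values; dimension must match the encoder's output size.
service = PGVectorService(
    index_name="my-docs",
    dimension=1536,
    credentials={"database_uri": "postgresql://user:pass@localhost:5432/postgres"},
    encoder=OpenAIEncoder(),
)

# query() embeds the input via the encoder and returns BaseDocumentChunk objects.
chunks = asyncio.run(service.query(input="What does Super-Rag support?", top_k=3))
for chunk in chunks:
    print(chunk.content)
```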
