From 6083cb746dc753b91d9508e4902d0aa6b04771bc Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Fri, 26 Apr 2024 17:49:55 +0400 Subject: [PATCH 1/5] feat: add ability to add metadata field --- models/file.py | 1 + service/embedding.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/models/file.py b/models/file.py index 201da54b..a70df15b 100644 --- a/models/file.py +++ b/models/file.py @@ -37,6 +37,7 @@ def suffix(self) -> str: class File(BaseModel): url: str name: str | None = None + metadata: dict | None = None @property def type(self) -> FileType | None: diff --git a/service/embedding.py b/service/embedding.py index 17b38761..50dc6b07 100644 --- a/service/embedding.py +++ b/service/embedding.py @@ -156,6 +156,7 @@ async def generate_chunks( ) -> List[BaseDocumentChunk]: doc_chunks = [] for file in tqdm(self.files, desc="Generating chunks"): + file_metadata = file.metadata or {} logger.info(f"Splitting method: {config.splitter.name}") try: chunks = [] @@ -168,7 +169,10 @@ async def generate_chunks( chunk_data = { "content": element.get("text"), "metadata": self._sanitize_metadata( - element.get("metadata") + { + **file_metadata, + **element.get("metadata"), + } ), } chunks.append(chunk_data) From 865563388b9e6515dc38d462a35669b09fda1925 Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Fri, 26 Apr 2024 17:50:13 +0400 Subject: [PATCH 2/5] feat: add applying filters on queries for qdrant --- models/query.py | 6 ++++-- service/router.py | 4 +++- vectordbs/base.py | 5 ++++- vectordbs/qdrant.py | 6 +++++- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/models/query.py b/models/query.py index 9289ffbc..b935e66c 100644 --- a/models/query.py +++ b/models/query.py @@ -1,10 +1,10 @@ -from typing import List, Optional - from pydantic import BaseModel +from typing import List, Optional from models.document import BaseDocumentChunk from models.ingest import EncoderConfig from models.vector_database import VectorDatabase +from qdrant_client.http.models import Filter class RequestPayload(BaseModel): @@ -15,6 +15,8 @@ class RequestPayload(BaseModel): session_id: Optional[str] = None interpreter_mode: Optional[bool] = False exclude_fields: List[str] = None + # TODO: use our own Filter model + filter: Optional[Filter] = None class ResponseData(BaseModel): diff --git a/service/router.py b/service/router.py index 9b840fce..dad43acc 100644 --- a/service/router.py +++ b/service/router.py @@ -40,7 +40,9 @@ def create_route_layer() -> RouteLayer: async def get_documents( *, vector_service: BaseVectorDatabase, payload: RequestPayload ) -> list[BaseDocumentChunk]: - chunks = await vector_service.query(input=payload.input, top_k=5) + chunks = await vector_service.query( + input=payload.input, filter=payload.filter, top_k=5 + ) # filter out documents with empty content chunks = [chunk for chunk in chunks if chunk.content.strip()] if not len(chunks): diff --git a/vectordbs/base.py b/vectordbs/base.py index 0f3203c2..b4a7b920 100644 --- a/vectordbs/base.py +++ b/vectordbs/base.py @@ -7,6 +7,7 @@ from models.delete import DeleteResponse from models.document import BaseDocumentChunk +from models.query import Filter from utils.logger import logger @@ -24,7 +25,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]): pass @abstractmethod - async def query(self, input: str, top_k: int = 25) -> List[BaseDocumentChunk]: + async def query( + self, input: str, filter: Filter, top_k: int = 25 + ) -> List[BaseDocumentChunk]: pass @abstractmethod diff --git a/vectordbs/qdrant.py b/vectordbs/qdrant.py index 4c667468..ed4f6f90 100644 --- a/vectordbs/qdrant.py +++ b/vectordbs/qdrant.py @@ -7,6 +7,7 @@ from models.delete import DeleteResponse from models.document import BaseDocumentChunk +from models.query import Filter from vectordbs.base import BaseVectorDatabase MAX_QUERY_TOP_K = 5 @@ -69,11 +70,14 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None: self.client.upsert(collection_name=self.index_name, wait=True, points=points) - async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List: + async def query( + self, input: str, filter: Filter, top_k: int = MAX_QUERY_TOP_K + ) -> List: vectors = await self._generate_vectors(input=input) search_result = self.client.search( collection_name=self.index_name, query_vector=("content", vectors[0]), + query_filter=filter, limit=top_k, with_payload=True, ) From 371f17105d6fb2686ff89cf3f26d75742f176970 Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Fri, 26 Apr 2024 21:08:24 +0400 Subject: [PATCH 3/5] feat: applying filters for other providers --- models/query.py | 27 ++++++++++++++++++++++++--- vectordbs/astra.py | 4 +++- vectordbs/pgvector.py | 6 +++++- vectordbs/pinecone.py | 8 +++++++- vectordbs/weaviate.py | 7 ++++++- 5 files changed, 45 insertions(+), 7 deletions(-) diff --git a/models/query.py b/models/query.py index b935e66c..f2e7a0fe 100644 --- a/models/query.py +++ b/models/query.py @@ -1,10 +1,32 @@ from pydantic import BaseModel -from typing import List, Optional +from typing import List, Optional, Union, Any from models.document import BaseDocumentChunk from models.ingest import EncoderConfig from models.vector_database import VectorDatabase -from qdrant_client.http.models import Filter +from qdrant_client.http.models import Filter as QdrantFilter + + +class PineconeFilter(BaseModel): + __root__: dict[str, Union[str, float, int, bool, List, dict]] + + +class AstraFilter(BaseModel): + __root__: dict[str, Any] + + +class WeaviateFilter(BaseModel): + __root__: dict + + +class PgVectorFilter(BaseModel): + __root__: dict + + +class Filter(BaseModel): + __root__: Union[ + PineconeFilter, QdrantFilter, WeaviateFilter, AstraFilter, PgVectorFilter + ] class RequestPayload(BaseModel): @@ -15,7 +37,6 @@ class RequestPayload(BaseModel): session_id: Optional[str] = None interpreter_mode: Optional[bool] = False exclude_fields: List[str] = None - # TODO: use our own Filter model filter: Optional[Filter] = None diff --git a/vectordbs/astra.py b/vectordbs/astra.py index 78cb3200..7285619c 100644 --- a/vectordbs/astra.py +++ b/vectordbs/astra.py @@ -5,6 +5,7 @@ from tqdm import tqdm from models.document import BaseDocumentChunk +from models.query import Filter from vectordbs.base import BaseVectorDatabase @@ -54,12 +55,13 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None: for i in range(0, len(documents), 5): self.collection.insert_many(documents=documents[i : i + 5]) - async def query(self, input: str, top_k: int = 4) -> List: + async def query(self, input: str, filter: Filter = None, top_k: int = 4) -> List: vectors = await self._generate_vectors(input=input) results = self.collection.vector_find( vector=vectors[0], limit=top_k, fields={"text", "page_number", "source", "document_id"}, + filter=filter, ) return [ BaseDocumentChunk( diff --git a/vectordbs/pgvector.py b/vectordbs/pgvector.py index e6b185ec..dd94ddbe 100644 --- a/vectordbs/pgvector.py +++ b/vectordbs/pgvector.py @@ -7,6 +7,7 @@ from qdrant_client.http import models as rest from models.delete import DeleteResponse from models.document import BaseDocumentChunk +from models.query import Filter from vectordbs.base import BaseVectorDatabase MAX_QUERY_TOP_K = 5 @@ -58,7 +59,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None: self.collection.upsert(records) self.collection.create_index() - async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List: + async def query( + self, input: str, filter: Filter = None, top_k: int = MAX_QUERY_TOP_K + ) -> List: vectors = await self._generate_vectors(input=input) results = self.collection.query( @@ -66,6 +69,7 @@ async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List: limit=top_k, include_metadata=True, include_value=False, + filters=filter.model_dump() if filter else {}, ) chunks = [] diff --git a/vectordbs/pinecone.py b/vectordbs/pinecone.py index 5fc429dd..7eea1bae 100644 --- a/vectordbs/pinecone.py +++ b/vectordbs/pinecone.py @@ -6,6 +6,7 @@ from models.delete import DeleteResponse from models.document import BaseDocumentChunk +from models.query import Filter from utils.logger import logger from vectordbs.base import BaseVectorDatabase @@ -52,7 +53,11 @@ async def upsert(self, chunks: List[BaseDocumentChunk], batch_size: int = 100): raise async def query( - self, input: str, top_k: int = 25, include_metadata: bool = True + self, + input: str, + filter: Filter = None, + top_k: int = 25, + include_metadata: bool = True, ) -> list[BaseDocumentChunk]: if self.index is None: raise ValueError(f"Pinecone index {self.index_name} is not initialized.") @@ -61,6 +66,7 @@ async def query( vector=query_vectors[0], top_k=top_k, include_metadata=include_metadata, + filter=filter, ) chunks = [] if results.get("matches"): diff --git a/vectordbs/weaviate.py b/vectordbs/weaviate.py index 9cda042f..57ce7365 100644 --- a/vectordbs/weaviate.py +++ b/vectordbs/weaviate.py @@ -10,6 +10,8 @@ from utils.logger import logger from vectordbs.base import BaseVectorDatabase +from models.query import Filter + class WeaviateService(BaseVectorDatabase): def __init__( @@ -72,7 +74,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None: batch.add_data_object(**vector_data) batch.flush() - async def query(self, input: str, top_k: int = 25) -> list[BaseDocumentChunk]: + async def query( + self, input: str, filter: Filter = {}, top_k: int = 25 + ) -> list[BaseDocumentChunk]: vectors = await self._generate_vectors(input=input) vector = {"vector": vectors[0]} @@ -84,6 +88,7 @@ async def query(self, input: str, top_k: int = 25) -> list[BaseDocumentChunk]: ) .with_near_vector(vector) .with_limit(top_k) + .with_where(filter) .do() ) if "data" not in response: From 12d72f25fe1e62c96af28076c707b91d4b4f61f0 Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Fri, 26 Apr 2024 21:21:12 +0400 Subject: [PATCH 4/5] fix: pydantic error --- models/query.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/models/query.py b/models/query.py index f2e7a0fe..a5335221 100644 --- a/models/query.py +++ b/models/query.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, RootModel from typing import List, Optional, Union, Any from models.document import BaseDocumentChunk @@ -7,26 +7,7 @@ from qdrant_client.http.models import Filter as QdrantFilter -class PineconeFilter(BaseModel): - __root__: dict[str, Union[str, float, int, bool, List, dict]] - - -class AstraFilter(BaseModel): - __root__: dict[str, Any] - - -class WeaviateFilter(BaseModel): - __root__: dict - - -class PgVectorFilter(BaseModel): - __root__: dict - - -class Filter(BaseModel): - __root__: Union[ - PineconeFilter, QdrantFilter, WeaviateFilter, AstraFilter, PgVectorFilter - ] +Filter = Union[QdrantFilter, dict] class RequestPayload(BaseModel): From 730453a03f1c490d061129f87e68f93387abf230 Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Mon, 29 Apr 2024 11:55:39 +0400 Subject: [PATCH 5/5] chore: fix linting --- models/query.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/query.py b/models/query.py index a5335221..f4fab5cf 100644 --- a/models/query.py +++ b/models/query.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel, RootModel -from typing import List, Optional, Union, Any +from pydantic import BaseModel +from typing import List, Optional, Union from models.document import BaseDocumentChunk from models.ingest import EncoderConfig