diff --git a/models/file.py b/models/file.py index 201da54b..a70df15b 100644 --- a/models/file.py +++ b/models/file.py @@ -37,6 +37,7 @@ def suffix(self) -> str: class File(BaseModel): url: str name: str | None = None + metadata: dict | None = None @property def type(self) -> FileType | None: diff --git a/models/query.py b/models/query.py index 9289ffbc..f4fab5cf 100644 --- a/models/query.py +++ b/models/query.py @@ -1,10 +1,13 @@ -from typing import List, Optional - from pydantic import BaseModel +from typing import List, Optional, Union from models.document import BaseDocumentChunk from models.ingest import EncoderConfig from models.vector_database import VectorDatabase +from qdrant_client.http.models import Filter as QdrantFilter + + +Filter = Union[QdrantFilter, dict] class RequestPayload(BaseModel): @@ -15,6 +18,7 @@ class RequestPayload(BaseModel): session_id: Optional[str] = None interpreter_mode: Optional[bool] = False exclude_fields: List[str] = None + filter: Optional[Filter] = None class ResponseData(BaseModel): diff --git a/service/embedding.py b/service/embedding.py index 17b38761..50dc6b07 100644 --- a/service/embedding.py +++ b/service/embedding.py @@ -156,6 +156,7 @@ async def generate_chunks( ) -> List[BaseDocumentChunk]: doc_chunks = [] for file in tqdm(self.files, desc="Generating chunks"): + file_metadata = file.metadata or {} logger.info(f"Splitting method: {config.splitter.name}") try: chunks = [] @@ -168,7 +169,10 @@ async def generate_chunks( chunk_data = { "content": element.get("text"), "metadata": self._sanitize_metadata( - element.get("metadata") + { + **file_metadata, + **element.get("metadata"), + } ), } chunks.append(chunk_data) diff --git a/service/router.py b/service/router.py index 9b840fce..dad43acc 100644 --- a/service/router.py +++ b/service/router.py @@ -40,7 +40,9 @@ def create_route_layer() -> RouteLayer: async def get_documents( *, vector_service: BaseVectorDatabase, payload: RequestPayload ) -> list[BaseDocumentChunk]: - chunks = await vector_service.query(input=payload.input, top_k=5) + chunks = await vector_service.query( + input=payload.input, filter=payload.filter, top_k=5 + ) # filter out documents with empty content chunks = [chunk for chunk in chunks if chunk.content.strip()] if not len(chunks): diff --git a/vectordbs/astra.py b/vectordbs/astra.py index 78cb3200..7285619c 100644 --- a/vectordbs/astra.py +++ b/vectordbs/astra.py @@ -5,6 +5,7 @@ from tqdm import tqdm from models.document import BaseDocumentChunk +from models.query import Filter from vectordbs.base import BaseVectorDatabase @@ -54,12 +55,13 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None: for i in range(0, len(documents), 5): self.collection.insert_many(documents=documents[i : i + 5]) - async def query(self, input: str, top_k: int = 4) -> List: + async def query(self, input: str, filter: Filter = None, top_k: int = 4) -> List: vectors = await self._generate_vectors(input=input) results = self.collection.vector_find( vector=vectors[0], limit=top_k, fields={"text", "page_number", "source", "document_id"}, + filter=filter, ) return [ BaseDocumentChunk( diff --git a/vectordbs/base.py b/vectordbs/base.py index 0f3203c2..b4a7b920 100644 --- a/vectordbs/base.py +++ b/vectordbs/base.py @@ -7,6 +7,7 @@ from models.delete import DeleteResponse from models.document import BaseDocumentChunk +from models.query import Filter from utils.logger import logger @@ -24,7 +25,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]): pass @abstractmethod - async def query(self, input: str, top_k: int = 25) -> List[BaseDocumentChunk]: + async def query( + self, input: str, filter: Filter, top_k: int = 25 + ) -> List[BaseDocumentChunk]: pass @abstractmethod diff --git a/vectordbs/pgvector.py b/vectordbs/pgvector.py index e6b185ec..dd94ddbe 100644 --- a/vectordbs/pgvector.py +++ b/vectordbs/pgvector.py @@ -7,6 +7,7 @@ from qdrant_client.http import models as rest from models.delete import DeleteResponse from models.document import BaseDocumentChunk +from models.query import Filter from vectordbs.base import BaseVectorDatabase MAX_QUERY_TOP_K = 5 @@ -58,7 +59,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None: self.collection.upsert(records) self.collection.create_index() - async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List: + async def query( + self, input: str, filter: Filter = None, top_k: int = MAX_QUERY_TOP_K + ) -> List: vectors = await self._generate_vectors(input=input) results = self.collection.query( @@ -66,6 +69,7 @@ async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List: limit=top_k, include_metadata=True, include_value=False, + filters=filter.model_dump() if filter else {}, ) chunks = [] diff --git a/vectordbs/pinecone.py b/vectordbs/pinecone.py index 5fc429dd..7eea1bae 100644 --- a/vectordbs/pinecone.py +++ b/vectordbs/pinecone.py @@ -6,6 +6,7 @@ from models.delete import DeleteResponse from models.document import BaseDocumentChunk +from models.query import Filter from utils.logger import logger from vectordbs.base import BaseVectorDatabase @@ -52,7 +53,11 @@ async def upsert(self, chunks: List[BaseDocumentChunk], batch_size: int = 100): raise async def query( - self, input: str, top_k: int = 25, include_metadata: bool = True + self, + input: str, + filter: Filter = None, + top_k: int = 25, + include_metadata: bool = True, ) -> list[BaseDocumentChunk]: if self.index is None: raise ValueError(f"Pinecone index {self.index_name} is not initialized.") @@ -61,6 +66,7 @@ async def query( vector=query_vectors[0], top_k=top_k, include_metadata=include_metadata, + filter=filter, ) chunks = [] if results.get("matches"): diff --git a/vectordbs/qdrant.py b/vectordbs/qdrant.py index 4c667468..ed4f6f90 100644 --- a/vectordbs/qdrant.py +++ b/vectordbs/qdrant.py @@ -7,6 +7,7 @@ from models.delete import DeleteResponse from models.document import BaseDocumentChunk +from models.query import Filter from vectordbs.base import BaseVectorDatabase MAX_QUERY_TOP_K = 5 @@ -69,11 +70,14 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None: self.client.upsert(collection_name=self.index_name, wait=True, points=points) - async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List: + async def query( + self, input: str, filter: Filter, top_k: int = MAX_QUERY_TOP_K + ) -> List: vectors = await self._generate_vectors(input=input) search_result = self.client.search( collection_name=self.index_name, query_vector=("content", vectors[0]), + query_filter=filter, limit=top_k, with_payload=True, ) diff --git a/vectordbs/weaviate.py b/vectordbs/weaviate.py index 9cda042f..57ce7365 100644 --- a/vectordbs/weaviate.py +++ b/vectordbs/weaviate.py @@ -10,6 +10,8 @@ from utils.logger import logger from vectordbs.base import BaseVectorDatabase +from models.query import Filter + class WeaviateService(BaseVectorDatabase): def __init__( @@ -72,7 +74,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None: batch.add_data_object(**vector_data) batch.flush() - async def query(self, input: str, top_k: int = 25) -> list[BaseDocumentChunk]: + async def query( + self, input: str, filter: Filter = {}, top_k: int = 25 + ) -> list[BaseDocumentChunk]: vectors = await self._generate_vectors(input=input) vector = {"vector": vectors[0]} @@ -84,6 +88,7 @@ async def query(self, input: str, top_k: int = 25) -> list[BaseDocumentChunk]: ) .with_near_vector(vector) .with_limit(top_k) + .with_where(filter) .do() ) if "data" not in response: