Metadata & Applying Filters (#103)
* feat: add ability to add metadata field

* feat: add applying filters on queries for qdrant

* feat: applying filters for other providers

* fix: pydantic error

* chore: fix linting
elisalimli committed Apr 30, 2024
1 parent 32bde47 commit 8116e7b
Showing 10 changed files with 45 additions and 10 deletions.
1 change: 1 addition & 0 deletions models/file.py
@@ -37,6 +37,7 @@ def suffix(self) -> str:
 class File(BaseModel):
     url: str
     name: str | None = None
+    metadata: dict | None = None

     @property
     def type(self) -> FileType | None:
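The new field lets callers attach arbitrary key-value metadata to a file at ingest time. A minimal sketch of its use (the URL and metadata values are invented for illustration):

from models.file import File

file = File(
    url="https://example.com/report.pdf",  # hypothetical source URL
    name="report.pdf",
    # free-form dict; merged into each chunk's metadata downstream
    metadata={"department": "finance", "year": 2024},
)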
8 changes: 6 additions & 2 deletions models/query.py
@@ -1,10 +1,13 @@
-from typing import List, Optional
+from typing import List, Optional, Union

 from pydantic import BaseModel

 from models.document import BaseDocumentChunk
 from models.ingest import EncoderConfig
 from models.vector_database import VectorDatabase
+from qdrant_client.http.models import Filter as QdrantFilter
+
+
+Filter = Union[QdrantFilter, dict]


 class RequestPayload(BaseModel):
@@ -15,6 +18,7 @@ class RequestPayload(BaseModel):
     session_id: Optional[str] = None
     interpreter_mode: Optional[bool] = False
     exclude_fields: List[str] = None
+    filter: Optional[Filter] = None


 class ResponseData(BaseModel):
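Filter is a union of Qdrant's typed filter model and a plain dict, so one request field can serve every provider. A sketch of both forms (the payload key "department" is an assumption for illustration):

from qdrant_client.http.models import FieldCondition, MatchValue
from qdrant_client.http.models import Filter as QdrantFilter

# typed form, consumed natively by the Qdrant provider
typed_filter = QdrantFilter(
    must=[FieldCondition(key="department", match=MatchValue(value="finance"))]
)

# plain-dict form, for providers that accept JSON-style filters
dict_filter = {"department": "finance"}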
6 changes: 5 additions & 1 deletion service/embedding.py
@@ -156,6 +156,7 @@ async def generate_chunks(
     ) -> List[BaseDocumentChunk]:
         doc_chunks = []
         for file in tqdm(self.files, desc="Generating chunks"):
+            file_metadata = file.metadata or {}
             logger.info(f"Splitting method: {config.splitter.name}")
             try:
                 chunks = []
@@ -168,7 +169,10 @@
                     chunk_data = {
                         "content": element.get("text"),
                         "metadata": self._sanitize_metadata(
-                            element.get("metadata")
+                            {
+                                **file_metadata,
+                                **element.get("metadata"),
+                            }
                         ),
                     }
                     chunks.append(chunk_data)
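Because the element metadata is unpacked after the file metadata, element-level keys win on collision. A quick illustration of that merge order (values invented):

file_metadata = {"source": "upload", "year": 2024}
element_metadata = {"source": "parser", "page": 3}
merged = {**file_metadata, **element_metadata}
print(merged)  # {'source': 'parser', 'year': 2024, 'page': 3}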
4 changes: 3 additions & 1 deletion service/router.py
@@ -40,7 +40,9 @@ def create_route_layer() -> RouteLayer:
 async def get_documents(
     *, vector_service: BaseVectorDatabase, payload: RequestPayload
 ) -> list[BaseDocumentChunk]:
-    chunks = await vector_service.query(input=payload.input, top_k=5)
+    chunks = await vector_service.query(
+        input=payload.input, filter=payload.filter, top_k=5
+    )
     # filter out documents with empty content
     chunks = [chunk for chunk in chunks if chunk.content.strip()]
     if not len(chunks):
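End to end, a query request can now carry a filter alongside the input text. A hypothetical request body (the exact endpoint and any fields beyond input and filter are assumptions):

request_body = {
    "input": "What were Q1 revenues?",
    # JSON form of a Qdrant-style filter; the payload key is illustrative
    "filter": {"must": [{"key": "department", "match": {"value": "finance"}}]},
}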
4 changes: 3 additions & 1 deletion vectordbs/astra.py
@@ -5,6 +5,7 @@
 from tqdm import tqdm

 from models.document import BaseDocumentChunk
+from models.query import Filter
 from vectordbs.base import BaseVectorDatabase

@@ -54,12 +55,13 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
         for i in range(0, len(documents), 5):
             self.collection.insert_many(documents=documents[i : i + 5])

-    async def query(self, input: str, top_k: int = 4) -> List:
+    async def query(self, input: str, filter: Filter = None, top_k: int = 4) -> List:
         vectors = await self._generate_vectors(input=input)
         results = self.collection.vector_find(
             vector=vectors[0],
             limit=top_k,
             fields={"text", "page_number", "source", "document_id"},
+            filter=filter,
         )
         return [
             BaseDocumentChunk(
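Astra's vector_find receives the filter as a JSON-style dict here, so the plain-dict side of the Filter union passes through unchanged. An illustrative value (whether this field is filterable in a given collection is an assumption):

# limit matches to a single source document; "document_id" mirrors the projected fields above
astra_filter = {"document_id": "doc_123"}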
5 changes: 4 additions & 1 deletion vectordbs/base.py
@@ -7,6 +7,7 @@

 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from utils.logger import logger

@@ -24,7 +25,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]):
         pass

     @abstractmethod
-    async def query(self, input: str, top_k: int = 25) -> List[BaseDocumentChunk]:
+    async def query(
+        self, input: str, filter: Filter, top_k: int = 25
+    ) -> List[BaseDocumentChunk]:
         pass

     @abstractmethod
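A minimal sketch of a provider conforming to the new abstract signature (the class is invented, and the other abstract methods are omitted, so this is illustrative rather than instantiable):

from typing import List

from models.document import BaseDocumentChunk
from models.query import Filter
from vectordbs.base import BaseVectorDatabase


class InMemoryVectorDatabase(BaseVectorDatabase):
    async def query(
        self, input: str, filter: Filter = None, top_k: int = 25
    ) -> List[BaseDocumentChunk]:
        # a real provider embeds `input`, translates `filter` into its
        # native query syntax, and returns the top_k matching chunks
        return []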
6 changes: 5 additions & 1 deletion vectordbs/pgvector.py
@@ -7,6 +7,7 @@
 from qdrant_client.http import models as rest
 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from vectordbs.base import BaseVectorDatabase

 MAX_QUERY_TOP_K = 5

@@ -58,14 +59,17 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
         self.collection.upsert(records)
         self.collection.create_index()

-    async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List:
+    async def query(
+        self, input: str, filter: Filter = None, top_k: int = MAX_QUERY_TOP_K
+    ) -> List:
         vectors = await self._generate_vectors(input=input)

         results = self.collection.query(
             data=vectors[0],
             limit=top_k,
             include_metadata=True,
             include_value=False,
+            filters=filter.model_dump() if filter else {},
         )

         chunks = []
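The typed side of the Filter union is a pydantic model, so model_dump() serializes it into the dict this client expects; note that a plain-dict filter has no model_dump() and would need to bypass this call. A sketch of the conversion (payload key invented):

from qdrant_client.http.models import FieldCondition, MatchValue
from qdrant_client.http.models import Filter as QdrantFilter

f = QdrantFilter(
    must=[FieldCondition(key="department", match=MatchValue(value="finance"))]
)
filters = f.model_dump() if f else {}  # mirrors the expression in the diff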
8 changes: 7 additions & 1 deletion vectordbs/pinecone.py
@@ -6,6 +6,7 @@

 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from utils.logger import logger
 from vectordbs.base import BaseVectorDatabase

@@ -52,7 +53,11 @@ async def upsert(self, chunks: List[BaseDocumentChunk], batch_size: int = 100):
             raise

     async def query(
-        self, input: str, top_k: int = 25, include_metadata: bool = True
+        self,
+        input: str,
+        filter: Filter = None,
+        top_k: int = 25,
+        include_metadata: bool = True,
     ) -> list[BaseDocumentChunk]:
         if self.index is None:
             raise ValueError(f"Pinecone index {self.index_name} is not initialized.")
@@ -61,6 +66,7 @@
             vector=query_vectors[0],
             top_k=top_k,
             include_metadata=include_metadata,
+            filter=filter,
         )
         chunks = []
         if results.get("matches"):
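Pinecone takes Mongo-style metadata filters as plain dicts, matching the dict side of the Filter union. An illustrative filter (key names invented):

# match chunks whose metadata has department == "finance" and year >= 2024
pinecone_filter = {
    "$and": [
        {"department": {"$eq": "finance"}},
        {"year": {"$gte": 2024}},
    ]
}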
6 changes: 5 additions & 1 deletion vectordbs/qdrant.py
@@ -7,6 +7,7 @@

 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from vectordbs.base import BaseVectorDatabase

 MAX_QUERY_TOP_K = 5

@@ -69,11 +70,14 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:

         self.client.upsert(collection_name=self.index_name, wait=True, points=points)

-    async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List:
+    async def query(
+        self, input: str, filter: Filter, top_k: int = MAX_QUERY_TOP_K
+    ) -> List:
         vectors = await self._generate_vectors(input=input)
         search_result = self.client.search(
             collection_name=self.index_name,
             query_vector=("content", vectors[0]),
+            query_filter=filter,
             limit=top_k,
             with_payload=True,
         )
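Unlike the other providers, this signature leaves filter without a default, and Qdrant's client expects its typed Filter (or None) for query_filter. A sketch of the equivalent standalone call (URL, collection name, vector size, and payload key are invented):

from qdrant_client import QdrantClient
from qdrant_client.http.models import FieldCondition, MatchValue
from qdrant_client.http.models import Filter as QdrantFilter

client = QdrantClient(url="http://localhost:6333")  # hypothetical deployment
hits = client.search(
    collection_name="my_index",
    query_vector=("content", [0.1] * 384),  # named vector, as in the code above
    query_filter=QdrantFilter(
        must=[FieldCondition(key="document_id", match=MatchValue(value="doc_123"))]
    ),
    limit=5,
)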
7 changes: 6 additions & 1 deletion vectordbs/weaviate.py
@@ -10,6 +10,8 @@
 from utils.logger import logger
 from vectordbs.base import BaseVectorDatabase

+from models.query import Filter
+

 class WeaviateService(BaseVectorDatabase):
     def __init__(

@@ -72,7 +74,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
             batch.add_data_object(**vector_data)
         batch.flush()

-    async def query(self, input: str, top_k: int = 25) -> list[BaseDocumentChunk]:
+    async def query(
+        self, input: str, filter: Filter = {}, top_k: int = 25
+    ) -> list[BaseDocumentChunk]:
         vectors = await self._generate_vectors(input=input)
         vector = {"vector": vectors[0]}

@@ -84,6 +88,7 @@ async def query(self, input: str, top_k: int = 25) -> list[BaseDocumentChunk]:
             )
             .with_near_vector(vector)
             .with_limit(top_k)
+            .with_where(filter)
             .do()
         )
         if "data" not in response:
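Weaviate's v3 client takes a GraphQL-style where dict, so the dict side of the union is the natural fit for with_where. An illustrative where clause (the property name is invented):

weaviate_where = {
    "path": ["source"],
    "operator": "Equal",
    "valueText": "report.pdf",
}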
