Metadata & Applying Filters #103

Merged 5 commits on Apr 30, 2024
1 change: 1 addition & 0 deletions models/file.py
@@ -37,6 +37,7 @@ def suffix(self) -> str:
 class File(BaseModel):
     url: str
     name: str | None = None
+    metadata: dict | None = None
 
     @property
     def type(self) -> FileType | None:
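
As a quick orientation for the new field, a minimal sketch of attaching file-level metadata at ingest time (the metadata keys here are hypothetical examples):

    from models.file import File

    # Any dict attached here travels with every chunk generated from this file.
    file = File(
        url="https://example.com/handbook.pdf",
        name="handbook.pdf",
        metadata={"source": "handbook.pdf", "department": "hr"},
    )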
8 changes: 6 additions & 2 deletions models/query.py
@@ -1,10 +1,13 @@
-from typing import List, Optional
-
 from pydantic import BaseModel
+from typing import List, Optional, Union
 
 from models.document import BaseDocumentChunk
 from models.ingest import EncoderConfig
 from models.vector_database import VectorDatabase
+from qdrant_client.http.models import Filter as QdrantFilter
 
 
+Filter = Union[QdrantFilter, dict]
+
+
 class RequestPayload(BaseModel):
@@ -15,6 +18,7 @@ class RequestPayload(BaseModel):
     session_id: Optional[str] = None
     interpreter_mode: Optional[bool] = False
     exclude_fields: List[str] = None
+    filter: Optional[Filter] = None
 
 
 class ResponseData(BaseModel):
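
The `Filter` alias admits either a backend-specific dict or a typed Qdrant filter. A hedged sketch of both shapes (field names hypothetical; other required `RequestPayload` fields omitted):

    from qdrant_client.http.models import FieldCondition, Filter as QdrantFilter, MatchValue

    # Plain dict, interpreted by whichever vector store handles the query:
    dict_filter = {"source": "handbook.pdf"}

    # Typed Qdrant filter, matching chunks whose payload "source" equals the value:
    typed_filter = QdrantFilter(
        must=[FieldCondition(key="source", match=MatchValue(value="handbook.pdf"))]
    )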
6 changes: 5 additions & 1 deletion service/embedding.py
@@ -156,6 +156,7 @@ async def generate_chunks(
     ) -> List[BaseDocumentChunk]:
         doc_chunks = []
         for file in tqdm(self.files, desc="Generating chunks"):
+            file_metadata = file.metadata or {}
             logger.info(f"Splitting method: {config.splitter.name}")
             try:
                 chunks = []
@@ -168,7 +169,10 @@
                     chunk_data = {
                         "content": element.get("text"),
                         "metadata": self._sanitize_metadata(
-                            element.get("metadata")
+                            {
+                                **file_metadata,
+                                **element.get("metadata"),
+                            }
                         ),
                     }
                     chunks.append(chunk_data)
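
Spread order matters here: `**file_metadata` comes first, so an element-level key overwrites a file-level key on collision, and `element.get("metadata")` must return a mapping (a None value would raise a TypeError in the spread). A minimal illustration of the merge, with hypothetical keys:

    file_metadata = {"source": "handbook.pdf", "department": "hr"}
    element_metadata = {"source": "handbook.pdf#page=3", "page_number": 3}

    merged = {**file_metadata, **element_metadata}
    # {'source': 'handbook.pdf#page=3', 'department': 'hr', 'page_number': 3}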
4 changes: 3 additions & 1 deletion service/router.py
@@ -40,7 +40,9 @@ def create_route_layer() -> RouteLayer:
 async def get_documents(
     *, vector_service: BaseVectorDatabase, payload: RequestPayload
 ) -> list[BaseDocumentChunk]:
-    chunks = await vector_service.query(input=payload.input, top_k=5)
+    chunks = await vector_service.query(
+        input=payload.input, filter=payload.filter, top_k=5
+    )
     # filter out documents with empty content
     chunks = [chunk for chunk in chunks if chunk.content.strip()]
     if not len(chunks):
4 changes: 3 additions & 1 deletion vectordbs/astra.py
@@ -5,6 +5,7 @@
 from tqdm import tqdm
 
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from vectordbs.base import BaseVectorDatabase
 
 
@@ -54,12 +55,13 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
         for i in range(0, len(documents), 5):
             self.collection.insert_many(documents=documents[i : i + 5])
 
-    async def query(self, input: str, top_k: int = 4) -> List:
+    async def query(self, input: str, filter: Filter = None, top_k: int = 4) -> List:
         vectors = await self._generate_vectors(input=input)
         results = self.collection.vector_find(
             vector=vectors[0],
             limit=top_k,
             fields={"text", "page_number", "source", "document_id"},
+            filter=filter,
         )
         return [
             BaseDocumentChunk(
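
For this backend the filter is forwarded verbatim to astrapy's `vector_find`, which takes a Data API filter document, so the simplest dict form appears to be a plain equality match (keys hypothetical):

    # Passed straight through as collection.vector_find(..., filter=...)
    astra_filter = {"source": "handbook.pdf"}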
5 changes: 4 additions & 1 deletion vectordbs/base.py
@@ -7,6 +7,7 @@
 
 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from utils.logger import logger
 
 
@@ -24,7 +25,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]):
         pass
 
     @abstractmethod
-    async def query(self, input: str, top_k: int = 25) -> List[BaseDocumentChunk]:
+    async def query(
+        self, input: str, filter: Filter, top_k: int = 25
+    ) -> List[BaseDocumentChunk]:
         pass
 
     @abstractmethod
6 changes: 5 additions & 1 deletion vectordbs/pgvector.py
@@ -7,6 +7,7 @@
 from qdrant_client.http import models as rest
 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from vectordbs.base import BaseVectorDatabase
 
 MAX_QUERY_TOP_K = 5
@@ -58,14 +59,17 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
         self.collection.upsert(records)
         self.collection.create_index()
 
-    async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List:
+    async def query(
+        self, input: str, filter: Filter = None, top_k: int = MAX_QUERY_TOP_K
+    ) -> List:
         vectors = await self._generate_vectors(input=input)
 
         results = self.collection.query(
             data=vectors[0],
             limit=top_k,
             include_metadata=True,
             include_value=False,
+            filters=filter.model_dump() if filter else {},
         )
 
         chunks = []
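
One caveat: `filter.model_dump()` assumes the incoming filter is a pydantic model (the `QdrantFilter` branch of the alias); a plain dict has no `model_dump` method and would raise an AttributeError on this path. A hedged sketch of a small guard (a hypothetical helper, not part of this PR) that accepts both shapes:

    def to_filter_dict(filter) -> dict:
        # Plain dicts pass through untouched; pydantic models are serialized.
        if filter is None:
            return {}
        if isinstance(filter, dict):
            return filter
        return filter.model_dump()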
8 changes: 7 additions & 1 deletion vectordbs/pinecone.py
@@ -6,6 +6,7 @@
 
 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from utils.logger import logger
 from vectordbs.base import BaseVectorDatabase
 
@@ -52,7 +53,11 @@ async def upsert(self, chunks: List[BaseDocumentChunk], batch_size: int = 100):
             raise
 
     async def query(
-        self, input: str, top_k: int = 25, include_metadata: bool = True
+        self,
+        input: str,
+        filter: Filter = None,
+        top_k: int = 25,
+        include_metadata: bool = True,
     ) -> list[BaseDocumentChunk]:
         if self.index is None:
             raise ValueError(f"Pinecone index {self.index_name} is not initialized.")
@@ -61,6 +66,7 @@ async def query(
             vector=query_vectors[0],
             top_k=top_k,
             include_metadata=include_metadata,
+            filter=filter,
         )
         chunks = []
         if results.get("matches"):
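
Pinecone's SDK expects its Mongo-style operator syntax in the `filter` argument, so a dict filter for this backend would typically look like the following (values hypothetical):

    # Matches vectors whose metadata field "source" equals "handbook.pdf"
    pinecone_filter = {"source": {"$eq": "handbook.pdf"}}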
6 changes: 5 additions & 1 deletion vectordbs/qdrant.py
@@ -7,6 +7,7 @@
 
 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from vectordbs.base import BaseVectorDatabase
 
 MAX_QUERY_TOP_K = 5
@@ -69,11 +70,14 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
 
         self.client.upsert(collection_name=self.index_name, wait=True, points=points)
 
-    async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List:
+    async def query(
+        self, input: str, filter: Filter, top_k: int = MAX_QUERY_TOP_K
+    ) -> List:
         vectors = await self._generate_vectors(input=input)
         search_result = self.client.search(
             collection_name=self.index_name,
             query_vector=("content", vectors[0]),
+            query_filter=filter,
             limit=top_k,
             with_payload=True,
         )
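
Since `query_filter` is handed directly to `qdrant_client`, this backend pairs naturally with the typed variant of the alias. A sketch (payload key and value hypothetical):

    from qdrant_client.http.models import FieldCondition, Filter, MatchValue

    query_filter = Filter(
        must=[FieldCondition(key="document_id", match=MatchValue(value="doc_123"))]
    )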
7 changes: 6 additions & 1 deletion vectordbs/weaviate.py
@@ -10,6 +10,8 @@
 from utils.logger import logger
 from vectordbs.base import BaseVectorDatabase
 
+from models.query import Filter
+
 
 class WeaviateService(BaseVectorDatabase):
     def __init__(
@@ -72,7 +74,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
             batch.add_data_object(**vector_data)
         batch.flush()
 
-    async def query(self, input: str, top_k: int = 25) -> list[BaseDocumentChunk]:
+    async def query(
+        self, input: str, filter: Filter = {}, top_k: int = 25
+    ) -> list[BaseDocumentChunk]:
         vectors = await self._generate_vectors(input=input)
         vector = {"vector": vectors[0]}
 
@@ -84,6 +88,7 @@ async def query(self, input: str, top_k: int = 25) -> list[BaseDocumentChunk]:
             )
             .with_near_vector(vector)
             .with_limit(top_k)
+            .with_where(filter)
             .do()
         )
         if "data" not in response:
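
The Weaviate v3 client used here expects its own where-filter shape, so a dict filter for this backend would follow the `path`/`operator` convention (values hypothetical):

    weaviate_where = {
        "path": ["source"],
        "operator": "Equal",
        "valueText": "handbook.pdf",
    }

Separately, `filter: Filter = {}` relies on a mutable default argument, which is shared across calls; a `None` default resolved to `{}` inside the body is the more conventional Python idiom.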