Skip to content

Commit

Permalink
Support hybrid vector retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
trducng committed May 26, 2024
1 parent ebf1315 commit 646c116
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 12 deletions.
50 changes: 42 additions & 8 deletions libs/kotaemon/kotaemon/indices/vectorindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ class VectorRetrieval(BaseRetrieval):
doc_store: Optional[BaseDocumentStore] = None
embedding: BaseEmbeddings
rerankers: Sequence[BaseReranking] = []
top_k: int = 1
top_k: int = 5
retrieval_mode: str = "hybrid" # vector, text, hybrid

def run(
self, text: str | Document, top_k: Optional[int] = None, **kwargs
Expand All @@ -101,13 +102,46 @@ def run(
"retrieve the documents"
)

emb: list[float] = self.embedding(text)[0].embedding
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k, **kwargs)
docs = self.doc_store.get(ids)
result = [
RetrievedDocument(**doc.to_dict(), score=score)
for doc, score in zip(docs, scores)
]
result: list[RetrievedDocument] = []
# TODO: should declare scope directly in the run params
scope = kwargs.pop("scope", None)
emb: list[float]

if self.retrieval_mode == "vector":
emb = self.embedding(text)[0].embedding
_, scores, ids = self.vector_store.query(
embedding=emb, top_k=top_k, **kwargs
)
docs = self.doc_store.get(ids)
result = [
RetrievedDocument(**doc.to_dict(), score=score)
for doc, score in zip(docs, scores)
]
elif self.retrieval_mode == "text":
query = text.text if isinstance(text, Document) else text
docs = self.doc_store.query(query, top_k=top_k, doc_ids=scope)
result = [RetrievedDocument(**doc.to_dict(), score=-1.0) for doc in docs]
elif self.retrieval_mode == "hybrid":
# similartiy search section
emb = self.embedding(text)[0].embedding
_, vs_scores, vs_ids = self.vector_store.query(
embedding=emb, top_k=top_k, **kwargs
)
vs_docs = self.doc_store.get(vs_ids)

# full-text search section
query = text.text if isinstance(text, Document) else text
docs = self.doc_store.query(query, top_k=top_k, doc_ids=scope)
result = [
RetrievedDocument(**doc.to_dict(), score=-1.0)
for doc in docs
if doc not in vs_ids
]
result += [
RetrievedDocument(**doc.to_dict(), score=score)
for doc, score in zip(vs_docs, vs_scores)
]

# use additional reranker to re-order the document list
if self.rerankers:
for reranker in self.rerankers:
Expand Down
7 changes: 7 additions & 0 deletions libs/kotaemon/kotaemon/storages/docstores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def count(self) -> int:
"""Count number of documents"""
...

@abstractmethod
def query(
self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
) -> List[Document]:
"""Search document store using search query"""
...

@abstractmethod
def delete(self, ids: Union[List[str], str]):
"""Delete document by id"""
Expand Down
5 changes: 4 additions & 1 deletion libs/kotaemon/kotaemon/storages/docstores/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def add(
"_id": doc_id,
}
requests.append(request)
self.es_bulk(self.client, requests)

success, failed = self.es_bulk(self.client, requests)
print("Added/Updated documents to index", success)
print("Failed documents to index", failed)

if refresh_indices:
self.client.indices.refresh(index=self.index_name)
Expand Down
6 changes: 6 additions & 0 deletions libs/kotaemon/kotaemon/storages/docstores/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def load(self, path: Union[str, Path]):
# Also, for portability, use SQLAlchemy for document store.
self._store = {key: Document.from_dict(value) for key, value in store.items()}

def query(
self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
) -> List[Document]:
"""Perform full-text search on document store"""
return []

def __persist_flow__(self):
return {}

Expand Down
8 changes: 6 additions & 2 deletions libs/ktem/ktem/index/file/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,15 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
get_extra_table: bool = False
mmr: bool = False
top_k: int = 5
retrieval_mode: str = "hybrid"

@Node.auto(depends_on=["embedding", "VS", "DS"])
def vector_retrieval(self) -> VectorRetrieval:
return VectorRetrieval(
embedding=self.embedding,
vector_store=self.VS,
doc_store=self.DS,
retrieval_mode=self.retrieval_mode, # type: ignore
)

def run(
Expand All @@ -105,7 +107,7 @@ def run(
logger.info(f"Skip retrieval because of no selected files: {self}")
return []

retrieval_kwargs = {}
retrieval_kwargs: dict = {}
with Session(engine) as session:
stmt = select(self.Index).where(
self.Index.relation_type == "vector",
Expand All @@ -114,6 +116,7 @@ def run(
results = session.execute(stmt)
vs_ids = [r[0].target_id for r in results.all()]

retrieval_kwargs["scope"] = vs_ids
retrieval_kwargs["filters"] = MetadataFilters(
filters=[
MetadataFilter(
Expand Down Expand Up @@ -200,7 +203,7 @@ def get_user_settings(cls) -> dict:
},
"retrieval_mode": {
"name": "Retrieval mode",
"value": "vector",
"value": "hybrid",
"choices": ["vector", "text", "hybrid"],
"component": "dropdown",
},
Expand Down Expand Up @@ -241,6 +244,7 @@ def get_pipeline(cls, user_settings, index_settings, selected):
"embedding", embedding_models_manager.get_default_name()
)
],
retrieval_mode=user_settings["retrieval_mode"],
)
if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore
Expand Down
4 changes: 3 additions & 1 deletion libs/ktem/ktem/reasoning/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,9 @@ def retrieve(
query = message
print(f"Rewritten query: {query}")
if not query:
return [], []
# TODO: previously return [], [] because we think this message as something
# like "Hello", "I need help"...
query = message

docs, doc_ids = [], []
for retriever in self.retrievers:
Expand Down

0 comments on commit 646c116

Please sign in to comment.