Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow file selector to be disabled #36

Merged
merged 3 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 37 additions & 33 deletions libs/ktem/ktem/index/file/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,58 +67,64 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
documents
get_extra_table: if True, for each retrieved document, the pipeline will look
for surrounding tables (e.g. within the page)
top_k: number of documents to retrieve
mmr: whether to use mmr to re-rank the documents
disabled: whether the pipeline is disabled
"""

vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
reranker: BaseReranking
get_extra_table: bool = False
mmr: bool = False
top_k: int = 5

def run(
self,
text: str,
top_k: int = 5,
mmr: bool = False,
doc_ids: Optional[list[str]] = None,
*args,
**kwargs,
) -> list[RetrievedDocument]:
"""Retrieve document excerpts similar to the text

Args:
text: the text to retrieve similar documents
top_k: number of documents to retrieve
mmr: whether to use mmr to re-rank the documents
doc_ids: list of document ids to constraint the retrieval
"""
if not doc_ids:
logger.info(f"Skip retrieval because of no selected files: {self}")
return []

Index = self._Index

kwargs = {}
if doc_ids:
with Session(engine) as session:
stmt = select(Index).where(
Index.relation_type == "vector",
Index.source_id.in_(doc_ids), # type: ignore
)
results = session.execute(stmt)
vs_ids = [r[0].target_id for r in results.all()]

kwargs["filters"] = MetadataFilters(
filters=[
MetadataFilter(
key="doc_id",
value=vs_id,
operator=FilterOperator.EQ,
)
for vs_id in vs_ids
],
condition=FilterCondition.OR,
vrkwargs = {}
lone17 marked this conversation as resolved.
Show resolved Hide resolved
with Session(engine) as session:
stmt = select(Index).where(
Index.relation_type == "vector",
Index.source_id.in_(doc_ids), # type: ignore
)
results = session.execute(stmt)
vs_ids = [r[0].target_id for r in results.all()]

vrkwargs["filters"] = MetadataFilters(
filters=[
MetadataFilter(
key="doc_id",
value=vs_id,
operator=FilterOperator.EQ,
)
for vs_id in vs_ids
],
condition=FilterCondition.OR,
)

if mmr:
if self.mmr:
# TODO: double check that llama-index MMR works correctly
kwargs["mode"] = VectorStoreQueryMode.MMR
kwargs["mmr_threshold"] = 0.5
vrkwargs["mode"] = VectorStoreQueryMode.MMR
vrkwargs["mmr_threshold"] = 0.5

# rerank
docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
docs = self.vector_retrieval(text=text, top_k=self.top_k, **vrkwargs)
if docs and self.get_from_path("reranker"):
docs = self.reranker(docs, query=text)

Expand Down Expand Up @@ -221,18 +227,16 @@ def get_pipeline(cls, user_settings, index_settings, selected):
retriever = cls(
get_extra_table=user_settings["prioritize_table"],
reranker=user_settings["reranking_llm"],
top_k=user_settings["num_retrieval"],
mmr=user_settings["mmr"],
)
if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore

retriever.vector_retrieval.embedding = embedding_models_manager[
index_settings.get("embedding", embedding_models_manager.get_default_name())
]
kwargs = {
".top_k": int(user_settings["num_retrieval"]),
".mmr": user_settings["mmr"],
".doc_ids": selected,
}
kwargs = {".doc_ids": selected}
retriever.set_run(kwargs, temp=True)
return retriever

Expand Down
43 changes: 39 additions & 4 deletions libs/ktem/ktem/index/file/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,20 +512,55 @@ def __init__(self, app, index):
self._index = index
self.on_building_ui()

def default(self):
return "disabled", []

def on_building_ui(self):
default_mode, default_selector = self.default()

self.mode = gr.Radio(
value=default_mode,
choices=[
("Disabled", "disabled"),
("Search All", "all"),
("Select", "select"),
],
container=False,
)
self.selector = gr.Dropdown(
label="Files",
choices=[],
choices=default_selector,
multiselect=True,
container=False,
interactive=True,
visible=False,
)

def on_register_events(self):
self.mode.change(
fn=lambda mode: gr.update(visible=mode == "select"),
inputs=[self.mode],
outputs=[self.selector],
)

def as_gradio_component(self):
return self.selector
return [self.mode, self.selector]

def get_selected_ids(self, components):
mode, selected = components[0], components[1]
if mode == "disabled":
return []
elif mode == "select":
return selected

file_ids = []
with Session(engine) as session:
statement = select(self._index._resources["Source"].id)
results = session.execute(statement).all()
for (id,) in results:
file_ids.append(id)

def get_selected_ids(self, selected):
return selected
return file_ids

def load_files(self, selected_files):
options = []
Expand Down
2 changes: 2 additions & 0 deletions libs/ktem/ktem/pages/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ def on_building_ui(self):
len(self._indices_input) + len(gr_index),
)
)
index.default_selector = index_ui.default()
self._indices_input.extend(gr_index)
else:
index.selector = len(self._indices_input)
index.default_selector = index_ui.default()
self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui)

Expand Down
4 changes: 2 additions & 2 deletions libs/ktem/ktem/pages/chat/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,9 @@ def select_conv(self, conversation_id):
if index.selector is None:
continue
if isinstance(index.selector, int):
indices.append(selected.get(str(index.id), []))
indices.append(selected.get(str(index.id), index.default_selector))
if isinstance(index.selector, tuple):
indices.extend(selected.get(str(index.id), [[]] * len(index.selector)))
indices.extend(selected.get(str(index.id), index.default_selector))

return id_, id_, name, chats, info_panel, state, *indices

Expand Down
Loading