diff --git a/.gitignore b/.gitignore
index 2e3fb01ca..711a741e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -452,6 +452,7 @@ $RECYCLE.BIN/
.theflow/
# End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,vim,emacs,visualstudiocode,pycharm
+*.py[coid]
logs/
.gitsecret/keys/random_seed
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 21356cef6..3f68b56f3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -52,7 +52,12 @@ repos:
hooks:
- id: mypy
additional_dependencies:
- [types-PyYAML==6.0.12.11, "types-requests", "sqlmodel"]
+ [
+ types-PyYAML==6.0.12.11,
+ "types-requests",
+ "sqlmodel",
+ "types-Markdown",
+ ]
args: ["--check-untyped-defs", "--ignore-missing-imports"]
exclude: "^templates/"
- repo: https://github.com/codespell-project/codespell
diff --git a/libs/kotaemon/kotaemon/indices/ingests/files.py b/libs/kotaemon/kotaemon/indices/ingests/files.py
index ed00e5cb7..75f944e4d 100644
--- a/libs/kotaemon/kotaemon/indices/ingests/files.py
+++ b/libs/kotaemon/kotaemon/indices/ingests/files.py
@@ -7,6 +7,7 @@
from kotaemon.indices.extractors import BaseDocParser
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
from kotaemon.loaders import (
+ AdobeReader,
DirectoryReader,
MathpixPDFReader,
OCRReader,
@@ -41,7 +42,7 @@ class DocumentIngestor(BaseComponent):
The default file extractors are stored in `KH_DEFAULT_FILE_EXTRACTORS`
"""
- pdf_mode: str = "normal" # "normal", "mathpix", "ocr"
+ pdf_mode: str = "normal" # "normal", "mathpix", "ocr", "multimodal"
doc_parsers: list[BaseDocParser] = Param(default_callback=lambda _: [])
text_splitter: BaseSplitter = TokenSplitter.withx(
chunk_size=1024,
@@ -61,6 +62,8 @@ def _get_reader(self, input_files: list[str | Path]):
pass # use default loader of llama-index which is pypdf
elif self.pdf_mode == "ocr":
file_extractors[".pdf"] = OCRReader()
+ elif self.pdf_mode == "multimodal":
+ file_extractors[".pdf"] = AdobeReader()
else:
file_extractors[".pdf"] = MathpixPDFReader()
diff --git a/libs/kotaemon/kotaemon/indices/qa/citation.py b/libs/kotaemon/kotaemon/indices/qa/citation.py
index f1a53c797..3192a07fa 100644
--- a/libs/kotaemon/kotaemon/indices/qa/citation.py
+++ b/libs/kotaemon/kotaemon/indices/qa/citation.py
@@ -104,18 +104,16 @@ def invoke(self, context: str, question: str):
print("CitationPipeline: invoking LLM")
llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
print("CitationPipeline: finish invoking LLM")
+ if not llm_output.messages:
+ return None
+ function_output = llm_output.messages[0].additional_kwargs["function_call"][
+ "arguments"
+ ]
+ output = QuestionAnswer.parse_raw(function_output)
except Exception as e:
print(e)
return None
- if not llm_output.messages:
- return None
-
- function_output = llm_output.messages[0].additional_kwargs["function_call"][
- "arguments"
- ]
- output = QuestionAnswer.parse_raw(function_output)
-
return output
async def ainvoke(self, context: str, question: str):
diff --git a/libs/kotaemon/kotaemon/loaders/__init__.py b/libs/kotaemon/kotaemon/loaders/__init__.py
index d742b52f2..a59d71315 100644
--- a/libs/kotaemon/kotaemon/loaders/__init__.py
+++ b/libs/kotaemon/kotaemon/loaders/__init__.py
@@ -1,10 +1,11 @@
+from .adobe_loader import AdobeReader
from .base import AutoReader, BaseReader
from .composite_loader import DirectoryReader
from .docx_loader import DocxReader
from .excel_loader import PandasExcelReader
from .html_loader import HtmlReader
from .mathpix_loader import MathpixPDFReader
-from .ocr_loader import OCRReader
+from .ocr_loader import ImageReader, OCRReader
from .unstructured_loader import UnstructuredReader
__all__ = [
@@ -12,9 +13,11 @@
"BaseReader",
"PandasExcelReader",
"MathpixPDFReader",
+ "ImageReader",
"OCRReader",
"DirectoryReader",
"UnstructuredReader",
"DocxReader",
"HtmlReader",
+ "AdobeReader",
]
diff --git a/libs/kotaemon/kotaemon/loaders/adobe_loader.py b/libs/kotaemon/kotaemon/loaders/adobe_loader.py
new file mode 100644
index 000000000..09a802c37
--- /dev/null
+++ b/libs/kotaemon/kotaemon/loaders/adobe_loader.py
@@ -0,0 +1,186 @@
+import logging
+import os
+import re
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from decouple import config
+from llama_index.readers.base import BaseReader
+
+from kotaemon.base import Document
+
+logger = logging.getLogger(__name__)
+
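+# default Azure OpenAI GPT-4 Vision endpoint, assembled from the
+# AZURE_OPENAI_ENDPOINT and OPENAI_API_VERSION environment variables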
+DEFAULT_VLM_ENDPOINT = (
+ "{0}openai/deployments/{1}/chat/completions?api-version={2}".format(
+ config("AZURE_OPENAI_ENDPOINT", default=""),
+ "gpt-4-vision",
+ config("OPENAI_API_VERSION", default=""),
+ )
+)
+
+
+class AdobeReader(BaseReader):
+ """Read PDF using the Adobe's PDF Services.
+ Be able to extract text, table, and figure with high accuracy
+
+ Example:
+ ```python
+ >> from kotaemon.loaders import AdobeReader
+ >> reader = AdobeReader()
+ >> documents = reader.load_data("path/to/pdf")
+ ```
+ Args:
+ vlm_endpoint: URL to the Vision Language Model endpoint. If not provided,
+ will use the default `kotaemon.loaders.adobe_loader.DEFAULT_VLM_ENDPOINT`
+
+ max_figures_to_caption: an int deciding how many figures will be captioned.
+ The rest will be ignored (indexed without captions).
+ """
+
+ def __init__(
+ self,
+ vlm_endpoint: Optional[str] = None,
+ max_figures_to_caption: int = 100,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+ """Init params"""
+ super().__init__(*args)
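+ # Adobe Extract reports element paths such as "//Document/Table[2]";
+ # these regexes match table/figure elements at the end of a path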
+ self.table_regex = r"/Table(\[\d+\])?$"
+ self.figure_regex = r"/Figure(\[\d+\])?$"
+ self.vlm_endpoint = vlm_endpoint or DEFAULT_VLM_ENDPOINT
+ self.max_figures_to_caption = max_figures_to_caption
+
+ def load_data(
+ self, file: Path, extra_info: Optional[Dict] = None, **kwargs
+ ) -> List[Document]:
+ """Load data by calling to the Adobe's API
+
+ Args:
+ file (Path): Path to the PDF file
+
+ Returns:
+ List[Document]: list of documents extracted from the PDF file,
+ includes 3 types: text, table, and image
+
+ """
+ from .utils.adobe import (
+ generate_figure_captions,
+ load_json,
+ parse_figure_paths,
+ parse_table_paths,
+ request_adobe_service,
+ )
+
+ filename = file.name
+ filepath = str(Path(file).resolve())
+ output_path = request_adobe_service(file_path=str(file), output_path="")
+ results_path = os.path.join(output_path, "structuredData.json")
+
+ if not os.path.exists(results_path):
+ logger.exception("Fail to parse the document.")
+ return []
+
+ data = load_json(results_path)
+
+ texts = defaultdict(list)
+ tables = []
+ figures = []
+
+ elements = data["elements"]
+ for item_id, item in enumerate(elements):
+ page_number = item.get("Page", -1) + 1
+ item_path = item["Path"]
+ item_text = item.get("Text", "")
+
+ file_paths = [
+ Path(output_path) / path for path in item.get("filePaths", [])
+ ]
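+ # heuristic: use the preceding element's text as the title for a
+ # table/figure (note: wraps to the last element when item_id == 0)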
+ prev_item = elements[item_id - 1]
+ title = prev_item.get("Text", "")
+
+ if re.search(self.table_regex, item_path):
+ table_content = parse_table_paths(file_paths)
+ if not table_content:
+ continue
+ table_caption = (
+ table_content.replace("|", "").replace("---", "")
+ + f"\n(Table in Page {page_number}. {title})"
+ )
+ tables.append((page_number, table_content, table_caption))
+
+ elif re.search(self.figure_regex, item_path):
+ figure_caption = (
+ item_text + f"\n(Figure in Page {page_number}. {title})"
+ )
+ figure_content = parse_figure_paths(file_paths)
+ if not figure_content:
+ continue
+ figures.append([page_number, figure_content, figure_caption])
+
+ else:
+ if item_text and "Table" not in item_path and "Figure" not in item_path:
+ texts[page_number].append(item_text)
+
+ # get figure caption using GPT-4V
+ figure_captions = generate_figure_captions(
+ self.vlm_endpoint,
+ [item[1] for item in figures],
+ self.max_figures_to_caption,
+ )
+ for item, caption in zip(figures, figure_captions):
+ # update figure caption
+ item[2] += " " + caption
+
+ # Wrap elements with Document
+ documents = []
+
+ # join plain text elements
+ for page_number, txts in texts.items():
+ documents.append(
+ Document(
+ text="\n".join(txts),
+ metadata={
+ "page_label": page_number,
+ "file_name": filename,
+ "file_path": filepath,
+ },
+ )
+ )
+
+ # table elements
+ for page_number, table_content, table_caption in tables:
+ documents.append(
+ Document(
+ text=table_caption,
+ metadata={
+ "table_origin": table_content,
+ "type": "table",
+ "page_label": page_number,
+ "file_name": filename,
+ "file_path": filepath,
+ },
+ metadata_template="",
+ metadata_seperator="",
+ )
+ )
+
+ # figure elements
+ for page_number, figure_content, figure_caption in figures:
+ documents.append(
+ Document(
+ text=figure_caption,
+ metadata={
+ "image_origin": figure_content,
+ "type": "image",
+ "page_label": page_number,
+ "file_name": filename,
+ "file_path": filepath,
+ },
+ metadata_template="",
+ metadata_seperator="",
+ )
+ )
+ return documents
diff --git a/libs/kotaemon/kotaemon/loaders/ocr_loader.py b/libs/kotaemon/kotaemon/loaders/ocr_loader.py
index e68971768..bb1ac5dca 100644
--- a/libs/kotaemon/kotaemon/loaders/ocr_loader.py
+++ b/libs/kotaemon/kotaemon/loaders/ocr_loader.py
@@ -125,3 +125,70 @@ def load_data(
)
return documents
+
+
+class ImageReader(BaseReader):
+ """Read PDF using OCR, with high focus on table extraction
+
+ Example:
+ ```python
+ >> from kotaemon.loaders import ImageReader
+ >> reader = ImageReader()
+ >> documents = reader.load_data("path/to/pdf")
+ ```
+
+ Args:
+ endpoint: URL to FullOCR endpoint. If not provided, will look for the
+ environment variable `OCR_READER_ENDPOINT` or use the default
+ `kotaemon.loaders.ocr_loader.DEFAULT_OCR_ENDPOINT`
+ (http://127.0.0.1:8000/v2/ai/infer/)
+ """
+
+ def __init__(self, endpoint: Optional[str] = None):
+ """Init the OCR reader with OCR endpoint (FullOCR pipeline)"""
+ super().__init__()
+ self.ocr_endpoint = endpoint or os.getenv(
+ "OCR_READER_ENDPOINT", DEFAULT_OCR_ENDPOINT
+ )
+
+ def load_data(
+ self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
+ ) -> List[Document]:
+ """Load data using OCR reader
+
+ Args:
+ file_path (Path): Path to the input file
+
+ Returns:
+ List[Document]: list of documents extracted from the PDF file
+ """
+ file_path = Path(file_path).resolve()
+
+ with file_path.open("rb") as content:
+ files = {"input": content}
+ data = {"job_id": uuid4(), "table_only": False}
+
+ # call the API from FullOCR endpoint
+ if "response_content" in kwargs:
+ # overriding response content if specified
+ ocr_results = kwargs["response_content"]
+ else:
+ # call original API
+ resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data)
+ ocr_results = resp.json()["result"]
+
+ extra_info = extra_info or {}
+ result = []
+ for ocr_result in ocr_results:
+ result.append(
+ Document(
+ content=ocr_result["csv_string"],
+ metadata=extra_info,
+ )
+ )
+
+ return result
diff --git a/libs/kotaemon/kotaemon/loaders/utils/adobe.py b/libs/kotaemon/kotaemon/loaders/utils/adobe.py
new file mode 100644
index 000000000..f1adcd5a7
--- /dev/null
+++ b/libs/kotaemon/kotaemon/loaders/utils/adobe.py
@@ -0,0 +1,246 @@
+# need pip install pdfservices-sdk==2.3.0
+
+import base64
+import json
+import logging
+import os
+import tempfile
+import zipfile
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import List, Union
+
+import pandas as pd
+from decouple import config
+
+from kotaemon.loaders.utils.gpt4v import generate_gpt4v
+
+
+def request_adobe_service(file_path: str, output_path: str = "") -> str:
+ """Main function to call the adobe service, and unzip the results.
+ Args:
+ file_path (str): path to the pdf file
+ output_path (str): path to store the results
+
+ Returns:
+ output_path (str): path to the results
+
+ """
+ try:
+ from adobe.pdfservices.operation.auth.credentials import Credentials
+ from adobe.pdfservices.operation.exception.exceptions import (
+ SdkException,
+ ServiceApiException,
+ ServiceUsageException,
+ )
+ from adobe.pdfservices.operation.execution_context import ExecutionContext
+ from adobe.pdfservices.operation.io.file_ref import FileRef
+ from adobe.pdfservices.operation.pdfops.extract_pdf_operation import (
+ ExtractPDFOperation,
+ )
+ from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_element_type import ( # noqa: E501
+ ExtractElementType,
+ )
+ from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_pdf_options import ( # noqa: E501
+ ExtractPDFOptions,
+ )
+ from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_renditions_element_type import ( # noqa: E501
+ ExtractRenditionsElementType,
+ )
+ except ImportError:
+ raise ImportError(
+ "pdfservices-sdk is not installed. "
+ "Please install it by running `pip install pdfservices-sdk"
+ "@git+https://github.com/niallcm/pdfservices-python-sdk.git"
+ "@bump-and-unfreeze-requirements`"
+ )
+
+ if not output_path:
+ output_path = tempfile.mkdtemp()
+
+ try:
+ # Initial setup, create credentials instance.
+ credentials = (
+ Credentials.service_principal_credentials_builder()
+ .with_client_id(config("PDF_SERVICES_CLIENT_ID", default=""))
+ .with_client_secret(config("PDF_SERVICES_CLIENT_SECRET", default=""))
+ .build()
+ )
+
+ # Create an ExecutionContext using credentials
+ # and create a new operation instance.
+ execution_context = ExecutionContext.create(credentials)
+ extract_pdf_operation = ExtractPDFOperation.create_new()
+
+ # Set operation input from a source file.
+ source = FileRef.create_from_local_file(file_path)
+ extract_pdf_operation.set_input(source)
+
+ # Build ExtractPDF options and set them into the operation
+ extract_pdf_options: ExtractPDFOptions = (
+ ExtractPDFOptions.builder()
+ .with_elements_to_extract(
+ [ExtractElementType.TEXT, ExtractElementType.TABLES]
+ )
+ .with_elements_to_extract_renditions(
+ [
+ ExtractRenditionsElementType.TABLES,
+ ExtractRenditionsElementType.FIGURES,
+ ]
+ )
+ .build()
+ )
+ extract_pdf_operation.set_options(extract_pdf_options)
+
+ # Execute the operation.
+ result: FileRef = extract_pdf_operation.execute(execution_context)
+
+ # Save the result to the specified location.
+ zip_file_path = os.path.join(
+ output_path, "ExtractTextTableWithFigureTableRendition.zip"
+ )
+ result.save_as(zip_file_path)
+ # Open the ZIP file
+ with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
+ # Extract all contents to the destination folder
+ zip_ref.extractall(output_path)
+ except (ServiceApiException, ServiceUsageException, SdkException):
+ logging.exception("Exception encountered while executing operation")
+
+ return output_path
+
+
+def make_markdown_table(table_as_list: List[str]) -> str:
+ """
+ Convert table from python list representation to markdown format.
+ The input list consists of table rows; the first row is the header.
+
+ Args:
+ table_as_list: list of table rows
+ Example: [["Name", "Age", "Height"],
+ ["Jake", 20, 5'10],
+ ["Mary", 21, 5'7]]
+ Returns:
+ markdown representation of the table
+ """
+ markdown = "\n" + str("| ")
+
+ for e in table_as_list[0]:
+ to_add = " " + str(e) + str(" |")
+ markdown += to_add
+ markdown += "\n"
+
+ markdown += "| "
+ for i in range(len(table_as_list[0])):
+ markdown += str("--- | ")
+ markdown += "\n"
+
+ for entry in table_as_list[1:]:
+ markdown += str("| ")
+ for e in entry:
+ to_add = str(e) + str(" | ")
+ markdown += to_add
+ markdown += "\n"
+
+ return markdown + "\n"
+
+
+def load_json(input_path: Union[str, Path]) -> dict:
+ """Load json file"""
+ with open(input_path, "r") as fi:
+ data = json.load(fi)
+
+ return data
+
+
+def load_excel(input_path: Union[str, Path]) -> str:
+ """Load excel file and convert to markdown"""
+
+ df = pd.read_excel(input_path).fillna("")
+ # Convert dataframe to a list of rows
+ row_list = [df.columns.values.tolist()] + df.values.tolist()
+
+ for item_id, item in enumerate(row_list[0]):
+ if "Unnamed" in item:
+ row_list[0][item_id] = ""
+
+ for row in row_list:
+ for item_id, item in enumerate(row):
+ row[item_id] = str(item).replace("_x000D_", " ").replace("\n", " ").strip()
+
+ markdown_str = make_markdown_table(row_list)
+ return markdown_str
+
+
+def encode_image_base64(image_path: Union[str, Path]) -> Union[bytes, str]:
+ """Convert image to base64"""
+
+ with open(image_path, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode("utf-8")
+
+
+def parse_table_paths(file_paths: List[Path]) -> str:
+ """Read the table stored in an excel file given the file path"""
+
+ content = ""
+ for path in file_paths:
+ if path.suffix == ".xlsx":
+ content = load_excel(path)
+ break
+ return content
+
+
+def parse_figure_paths(file_paths: List[Path]) -> Union[bytes, str]:
+ """Read and convert an image to base64 given the image path"""
+
+ content = ""
+ for path in file_paths:
+ if path.suffix == ".png":
+ base64_image = encode_image_base64(path)
+ content = f"data:image/png;base64,{base64_image}" # type: ignore
+ break
+ return content
+
+
+def generate_single_figure_caption(vlm_endpoint: str, figure: str) -> str:
+ """Summarize a single figure using GPT-4V"""
+ if figure:
+ output = generate_gpt4v(
+ endpoint=vlm_endpoint,
+ prompt="Provide a short 2 sentence summary of this image?",
+ images=figure,
+ )
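+ # vision models often phrase refusals with "sorry"; treat such
+ # replies as an empty caption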
+ if "sorry" in output.lower():
+ output = ""
+ else:
+ output = ""
+ return output
+
+
+def generate_figure_captions(
+ vlm_endpoint: str, figures: List, max_figures_to_process: int
+) -> List:
+ """Summarize several figures using GPT-4V.
+ Args:
+ vlm_endpoint (str): endpoint to the vision language model service
+ figures (List): list of base64 images
+ max_figures_to_process (int): the maximum number of figures to summarize;
+ the rest are ignored.
+
+ Returns:
+ results (List[str]): list of all figure captions and empty strings for
+ ignored figures.
+ """
+ to_gen_figures = figures[:max_figures_to_process]
+ other_figures = figures[max_figures_to_process:]
+
+ with ThreadPoolExecutor() as executor:
+ futures = [
+ # pass the figure as a submit argument rather than closing over
+ # the loop variable, which would be late-bound
+ executor.submit(generate_single_figure_caption, vlm_endpoint, figure)
+ for figure in to_gen_figures
+ ]
+
+ results = [future.result() for future in futures]
+ return results + [""] * len(other_figures)
diff --git a/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py b/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py
new file mode 100644
index 000000000..1e219d660
--- /dev/null
+++ b/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py
@@ -0,0 +1,96 @@
+import json
+from typing import Any, List
+
+import requests
+from decouple import config
+
+
+def generate_gpt4v(
+ endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
+) -> str:
+ # Azure OpenAI API key
+ api_key = config("AZURE_OPENAI_API_KEY", default="")
+ headers = {"Content-Type": "application/json", "api-key": api_key}
+
+ if isinstance(images, str):
+ images = [images]
+
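+ # payload follows the OpenAI chat-completions vision format: one user
+ # message holding the text prompt plus one image_url part per image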
+ payload = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": prompt},
+ ]
+ + [
+ {
+ "type": "image_url",
+ "image_url": {"url": image},
+ }
+ for image in images
+ ],
+ }
+ ],
+ "max_tokens": max_tokens,
+ }
+
+ try:
+ response = requests.post(endpoint, headers=headers, json=payload)
+ output = response.json()
+ output = output["choices"][0]["message"]["content"]
+ except Exception:
+ output = ""
+ return output
+
+
+def stream_gpt4v(
+ endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
+) -> Any:
+ # Azure OpenAI API key
+ api_key = config("AZURE_OPENAI_API_KEY", default="")
+ headers = {"Content-Type": "application/json", "api-key": api_key}
+
+ if isinstance(images, str):
+ images = [images]
+
+ payload = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": prompt},
+ ]
+ + [
+ {
+ "type": "image_url",
+ "image_url": {"url": image},
+ }
+ for image in images
+ ],
+ }
+ ],
+ "max_tokens": max_tokens,
+ "stream": True,
+ }
+ try:
+ response = requests.post(endpoint, headers=headers, json=payload, stream=True)
+ assert response.status_code == 200, str(response.content)
+ output = ""
+ for line in response.iter_lines():
+ if line:
+ if line.startswith(b"\xef\xbb\xbf"):
+ line = line[9:]
+ else:
+ line = line[6:]
+ try:
+ if line == "[DONE]":
+ break
+ line = json.loads(line.decode("utf-8"))
+ except Exception:
+ break
+ if len(line["choices"]):
+ output += line["choices"][0]["delta"].get("content", "")
+ yield line["choices"][0]["delta"].get("content", "")
+ except Exception:
+ output = ""
+ return output
diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml
index e1e30280d..73c3e8ab7 100644
--- a/libs/kotaemon/pyproject.toml
+++ b/libs/kotaemon/pyproject.toml
@@ -60,6 +60,7 @@ adv = [
"cohere",
"elasticsearch",
"llama-cpp-python",
+ "pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements",
]
dev = [
"ipython",
@@ -69,6 +70,7 @@ dev = [
"flake8",
"sphinx",
"coverage",
+ "python-decouple"
]
all = ["kotaemon[adv,dev]"]
diff --git a/libs/kotaemon/tests/_test_multimodal_reader.py b/libs/kotaemon/tests/_test_multimodal_reader.py
new file mode 100644
index 000000000..b07786f03
--- /dev/null
+++ b/libs/kotaemon/tests/_test_multimodal_reader.py
@@ -0,0 +1,21 @@
+# TODO: This test is broken and should be rewritten
+from pathlib import Path
+
+from kotaemon.loaders import AdobeReader
+
+# from dotenv import load_dotenv
+
+
+input_file = Path(__file__).parent / "resources" / "multimodal.pdf"
+
+# load_dotenv()
+
+
+def test_adobe_reader():
+ reader = AdobeReader()
+ documents = reader.load_data(input_file)
+ table_docs = [doc for doc in documents if doc.metadata.get("type", "") == "table"]
+ assert len(table_docs) == 2
+
+ figure_docs = [doc for doc in documents if doc.metadata.get("type", "") == "image"]
+ assert len(figure_docs) == 2
diff --git a/libs/kotaemon/tests/resources/multimodal.pdf b/libs/kotaemon/tests/resources/multimodal.pdf
new file mode 100644
index 000000000..29c2bdc96
Binary files /dev/null and b/libs/kotaemon/tests/resources/multimodal.pdf differ
diff --git a/libs/ktem/flowsettings.py b/libs/ktem/flowsettings.py
index 268d63fab..2e26cff1a 100644
--- a/libs/ktem/flowsettings.py
+++ b/libs/ktem/flowsettings.py
@@ -123,6 +123,11 @@
KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
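+# endpoint of the Vision Language Model (Azure OpenAI GPT-4 Vision),
+# assembled from environment variables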
+KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
+ config("AZURE_OPENAI_ENDPOINT", default=""),
+ config("OPENAI_VISION_DEPLOYMENT_NAME", default="gpt-4-vision"),
+ config("OPENAI_API_VERSION", default=""),
+)
SETTINGS_APP = {
diff --git a/libs/ktem/ktem/app.py b/libs/ktem/ktem/app.py
index 0a39fa60b..64e8a9da0 100644
--- a/libs/ktem/ktem/app.py
+++ b/libs/ktem/ktem/app.py
@@ -17,6 +17,7 @@ class BaseApp:
The main application contains app-level information:
- setting state
+ - dynamic conversation state
- user id
Also contains registering methods for:
@@ -228,7 +229,9 @@ def on_register_events(self):
def _on_app_created(self):
"""Called when the app is created"""
- def as_gradio_component(self) -> Optional[gr.components.Component]:
+ def as_gradio_component(
+ self,
+ ) -> Optional[gr.components.Component | list[gr.components.Component]]:
"""Return the gradio components responsible for events
Note: in ideal scenario, this method shouldn't be necessary.
diff --git a/libs/ktem/ktem/index/base.py b/libs/ktem/ktem/index/base.py
index 50bdd9e44..518376264 100644
--- a/libs/ktem/ktem/index/base.py
+++ b/libs/ktem/ktem/index/base.py
@@ -1,6 +1,6 @@
import abc
import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from ktem.app import BasePage
@@ -57,7 +57,7 @@ def __init__(self, app, id, name, config):
self._app = app
self.id = id
self.name = name
- self._config = config # admin settings
+ self.config = config # admin settings
def on_create(self):
"""Create the index for the first time"""
@@ -121,7 +121,7 @@ def get_indexing_pipeline(self, settings: dict) -> "BaseComponent":
...
def get_retriever_pipelines(
- self, settings: dict, selected: Optional[list]
+ self, settings: dict, selected: Any = None
) -> list["BaseComponent"]:
"""Return the retriever pipelines to retrieve the entity from the index"""
return []
diff --git a/libs/ktem/ktem/index/file/base.py b/libs/ktem/ktem/index/file/base.py
index 5f8e6f4fa..4f28f51ac 100644
--- a/libs/ktem/ktem/index/file/base.py
+++ b/libs/ktem/ktem/index/file/base.py
@@ -127,3 +127,11 @@ def get_filestorage_path(self, rel_paths: str | list[str]) -> list[str]:
the absolute file storage path to the file
"""
raise NotImplementedError
+
+ def warning(self, msg):
+ """Log a warning message
+
+ Args:
+ msg: the message to log
+ """
+ print(msg)
diff --git a/libs/ktem/ktem/index/file/index.py b/libs/ktem/ktem/index/file/index.py
index ab1f35a3a..5fe395596 100644
--- a/libs/ktem/ktem/index/file/index.py
+++ b/libs/ktem/ktem/index/file/index.py
@@ -13,7 +13,6 @@
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
-from .ui import FileIndexPage, FileSelector
class FileIndex(BaseIndex):
@@ -77,9 +76,15 @@ def __init__(self, app, id: int, name: str, config: dict):
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
+ self._selector_ui_cls: Type
+ self._selector_ui: Any = None
+ self._index_ui_cls: Type
+ self._index_ui: Any = None
self._setup_indexing_cls()
self._setup_retriever_cls()
+ self._setup_file_index_ui_cls()
+ self._setup_file_selector_ui_cls()
self._default_settings: dict[str, dict] = {}
self._setting_mappings: dict[str, dict] = {}
@@ -91,14 +96,14 @@ def _setup_indexing_cls(self):
The indexing class is retrieved in the following order, stopping at the
first one found:
- - `FILE_INDEX_PIPELINE` in self._config
+ - `FILE_INDEX_PIPELINE` in self.config
- `FILE_INDEX_{id}_PIPELINE` in the flowsettings
- `FILE_INDEX_PIPELINE` in the flowsettings
- The default .pipelines.IndexDocumentPipeline
"""
- if "FILE_INDEX_PIPELINE" in self._config:
+ if "FILE_INDEX_PIPELINE" in self.config:
self._indexing_pipeline_cls = import_dotted_string(
- self._config["FILE_INDEX_PIPELINE"], safe=False
+ self.config["FILE_INDEX_PIPELINE"], safe=False
)
return
@@ -125,15 +130,15 @@ def _setup_retriever_cls(self):
The retriever classes are retrieved in the following order, stopping at the
first one found:
- - `FILE_INDEX_RETRIEVER_PIPELINES` in self._config
+ - `FILE_INDEX_RETRIEVER_PIPELINES` in self.config
- `FILE_INDEX_{id}_RETRIEVER_PIPELINES` in the flowsettings
- `FILE_INDEX_RETRIEVER_PIPELINES` in the flowsettings
- The default .pipelines.DocumentRetrievalPipeline
"""
- if "FILE_INDEX_RETRIEVER_PIPELINES" in self._config:
+ if "FILE_INDEX_RETRIEVER_PIPELINES" in self.config:
self._retriever_pipeline_cls = [
import_dotted_string(each, safe=False)
- for each in self._config["FILE_INDEX_RETRIEVER_PIPELINES"]
+ for each in self.config["FILE_INDEX_RETRIEVER_PIPELINES"]
]
return
@@ -157,6 +162,76 @@ def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [DocumentRetrievalPipeline]
+ def _setup_file_selector_ui_cls(self):
+ """Retrieve the file selector UI for the file index
+
+ The selector UI class is retrieved in the following order, stopping at the
+ first one found:
+ - `FILE_INDEX_SELECTOR_UI` in self.config
+ - `FILE_INDEX_{id}_SELECTOR_UI` in the flowsettings
+ - `FILE_INDEX_SELECTOR_UI` in the flowsettings
+ - The default .ui.FileSelector
+ """
+ if "FILE_INDEX_SELECTOR_UI" in self.config:
+ self._selector_ui_cls = import_dotted_string(
+ self.config["FILE_INDEX_SELECTOR_UI"], safe=False
+ )
+ return
+
+ if hasattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"):
+ self._selector_ui_cls = import_dotted_string(
+ getattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"),
+ safe=False,
+ )
+ return
+
+ if hasattr(flowsettings, "FILE_INDEX_SELECTOR_UI"):
+ self._selector_ui_cls = import_dotted_string(
+ getattr(flowsettings, "FILE_INDEX_SELECTOR_UI"), safe=False
+ )
+ return
+
+ from .ui import FileSelector
+
+ self._selector_ui_cls = FileSelector
+
+ def _setup_file_index_ui_cls(self):
+ """Retrieve the Index UI class
+
+ The index UI class is retrieved in the following order, stopping at the
+ first one found:
+ - `FILE_INDEX_UI` in self.config
+ - `FILE_INDEX_{id}_UI` in the flowsettings
+ - `FILE_INDEX_UI` in the flowsettings
+ - The default .ui.FileIndexPage
+ """
+ if "FILE_INDEX_UI" in self.config:
+ self._index_ui_cls = import_dotted_string(
+ self.config["FILE_INDEX_UI"], safe=False
+ )
+ return
+
+ if hasattr(flowsettings, f"FILE_INDEX_{self.id}_UI"):
+ self._index_ui_cls = import_dotted_string(
+ getattr(flowsettings, f"FILE_INDEX_{self.id}_UI"),
+ safe=False,
+ )
+ return
+
+ if hasattr(flowsettings, "FILE_INDEX_UI"):
+ self._index_ui_cls = import_dotted_string(
+ getattr(flowsettings, "FILE_INDEX_UI"), safe=False
+ )
+ return
+
+ from .ui import FileIndexPage
+
+ self._index_ui_cls = FileIndexPage
+
def on_create(self):
"""Create the index for the first time
@@ -165,6 +240,13 @@ def on_create(self):
2. Create the vectorstore
3. Create the docstore
"""
+ file_types_str = self.config.get(
+ "supported_file_types",
+ self.get_admin_settings()["supported_file_types"]["value"],
+ )
+ file_types = [each.strip() for each in file_types_str.split(",")]
+ self.config["supported_file_types"] = file_types
+
self._resources["Source"].metadata.create_all(engine) # type: ignore
self._resources["Index"].metadata.create_all(engine) # type: ignore
self._fs_path.mkdir(parents=True, exist_ok=True)
@@ -180,10 +262,14 @@ def on_delete(self):
shutil.rmtree(self._fs_path)
def get_selector_component_ui(self):
- return FileSelector(self._app, self)
+ if self._selector_ui is None:
+ self._selector_ui = self._selector_ui_cls(self._app, self)
+ return self._selector_ui
def get_index_page_ui(self):
- return FileIndexPage(self._app, self)
+ if self._index_ui is None:
+ self._index_ui = self._index_ui_cls(self._app, self)
+ return self._index_ui
def get_user_settings(self):
if self._default_settings:
@@ -210,7 +296,31 @@ def get_admin_settings(cls):
"value": embedding_default,
"component": "dropdown",
"choices": embedding_choices,
- }
+ },
+ "supported_file_types": {
+ "name": "Supported file types",
+ "value": (
+ "image, .pdf, .txt, .csv, .xlsx, .doc, .docx, .pptx, .html, .zip"
+ ),
+ "component": "text",
+ },
+ "max_file_size": {
+ "name": "Max file size (MB) - set 0 to disable",
+ "value": 1000,
+ "component": "number",
+ },
+ "max_number_of_files": {
+ "name": "Max number of files that can be indexed - set 0 to disable",
+ "value": 0,
+ "component": "number",
+ },
+ "max_number_of_text_length": {
+ "name": (
+ "Max amount of characters that can be indexed - set 0 to disable"
+ ),
+ "value": 0,
+ "component": "number",
+ },
}
def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing:
@@ -224,14 +334,15 @@ def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing:
else:
stripped_settings[key] = value
- obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self._config)
+ obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
obj.set_resources(resources=self._resources)
return obj
def get_retriever_pipelines(
- self, settings: dict, selected: Optional[list] = None
+ self, settings: dict, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
+ # retrieval settings
prefix = f"index.options.{self.id}."
stripped_settings = {}
for key, value in settings.items():
@@ -240,9 +351,12 @@ def get_retriever_pipelines(
else:
stripped_settings[key] = value
+ # transform selected id
+ selected_ids: Optional[list[str]] = self._selector_ui.get_selected_ids(selected)
+
retrievers = []
for cls in self._retriever_pipeline_cls:
- obj = cls.get_pipeline(stripped_settings, self._config, selected)
+ obj = cls.get_pipeline(stripped_settings, self.config, selected_ids)
if obj is None:
continue
obj.set_resources(self._resources)
diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py
index c6c7accd5..13036f33f 100644
--- a/libs/ktem/ktem/index/file/pipelines.py
+++ b/libs/ktem/ktem/index/file/pipelines.py
@@ -9,6 +9,7 @@
from pathlib import Path
from typing import Optional
+import gradio as gr
from ktem.components import embeddings, filestorage_path
from ktem.db.models import engine
from llama_index.vector_stores import (
@@ -18,7 +19,7 @@
MetadataFilters,
)
from llama_index.vector_stores.types import VectorStoreQueryMode
-from sqlalchemy import select
+from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from theflow.settings import settings
from theflow.utils.modules import import_dotted_string
@@ -279,6 +280,7 @@ def run(
to_index: list[str] = []
file_to_hash: dict[str, str] = {}
errors = []
+ to_update = []
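+ # files that already exist but are being reindexed; their old Source
+ # rows are deleted before the new ones are persisted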
for file_path in file_paths:
abs_path = str(Path(file_path).resolve())
@@ -291,16 +293,26 @@ def run(
statement = select(Source).where(Source.name == Path(abs_path).name)
item = session.execute(statement).first()
- if item and not reindex:
- errors.append(Path(abs_path).name)
- continue
+ if item:
+ if not reindex:
+ errors.append(Path(abs_path).name)
+ continue
+ else:
+ to_update.append(Path(abs_path).name)
to_index.append(abs_path)
if errors:
+ error_files = ", ".join(errors)
+ if len(error_files) > 100:
+ error_files = error_files[:80] + "..."
print(
- "Files already exist. Please rename/remove them or enable reindex.\n"
- f"{errors}"
+ "Skip these files already exist. Please rename/remove them or "
+ f"enable reindex:\n{errors}"
+ )
+ self.warning(
+ "Skip these files already exist. Please rename/remove them or "
+ f"enable reindex:\n{error_files}"
)
if not to_index:
@@ -310,9 +322,19 @@ def run(
for path in to_index:
shutil.copy(path, filestorage_path / file_to_hash[path])
- # prepare record info
+ # extract the file & prepare record info
file_to_source: dict = {}
+ extraction_errors = []
+ nodes = []
for file_path, file_hash in file_to_hash.items():
+ if str(Path(file_path).resolve()) not in to_index:
+ continue
+
+ extraction_result = self.file_ingestor(file_path)
+ if not extraction_result:
+ extraction_errors.append(Path(file_path).name)
+ continue
+ nodes.extend(extraction_result)
source = Source(
name=Path(file_path).name,
path=file_hash,
@@ -320,9 +342,23 @@ def run(
)
file_to_source[file_path] = source
- # extract the files
- nodes = self.file_ingestor(to_index)
- print("Extracted", len(to_index), "files into", len(nodes), "nodes")
+ if extraction_errors:
+ msg = "Failed to extract these files: {}".format(
+ ", ".join(extraction_errors)
+ )
+ print(msg)
+ self.warning(msg)
+
+ if not nodes:
+ return [], []
+
+ print(
+ "Extracted",
+ len(to_index) - len(extraction_errors),
+ "files into",
+ len(nodes),
+ "nodes",
+ )
# index the files
print("Indexing the files into vector store")
@@ -332,7 +368,11 @@ def run(
# persist to the index
print("Persisting the vector and the document into index")
file_ids = []
+ to_update = list(set(to_update))
with Session(engine) as session:
+ if to_update:
+ session.execute(delete(Source).where(Source.name.in_(to_update)))
+
for source in file_to_source.values():
session.add(source)
session.commit()
@@ -378,6 +418,7 @@ def get_user_settings(cls) -> dict:
("PDF text parser", "normal"),
("Mathpix", "mathpix"),
("Advanced ocr", "ocr"),
+ ("Multimodal parser", "multimodal"),
],
"component": "dropdown",
},
@@ -403,3 +444,6 @@ def set_resources(self, resources: dict):
super().set_resources(resources)
self.indexing_vector_pipeline.vector_store = self._VS
self.indexing_vector_pipeline.doc_store = self._DS
+
+ def warning(self, msg):
+ gr.Warning(msg)
diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py
index 9da2b4a14..11d491f03 100644
--- a/libs/ktem/ktem/index/file/ui.py
+++ b/libs/ktem/ktem/index/file/ui.py
@@ -1,29 +1,48 @@
import os
import tempfile
+from pathlib import Path
import gradio as gr
import pandas as pd
+from gradio.data_classes import FileData
+from gradio.utils import NamedString
from ktem.app import BasePage
from ktem.db.engine import engine
from sqlalchemy import select
from sqlalchemy.orm import Session
+class File(gr.File):
+ """Subclass from gr.File to maintain the original filename
+
+ The issue happens when user uploads file with name like: !@#$%%^&*().pdf
+ """
+
+ def _process_single_file(self, f: FileData) -> NamedString | bytes:
+ file_name = f.path
+ if self.type == "filepath":
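+ # gradio caches uploads under a sanitized name; rename the cached
+ # file back to its original name so downstream consumers see it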
+ if f.orig_name and Path(file_name).name != f.orig_name:
+ file_name = str(Path(file_name).parent / f.orig_name)
+ os.rename(f.path, file_name)
+ file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
+ file.name = file_name
+ return NamedString(file_name)
+ elif self.type == "binary":
+ with open(file_name, "rb") as file_data:
+ return file_data.read()
+ else:
+ raise ValueError(
+ "Unknown type: "
+ + str(self.type)
+ + ". Please choose from: 'filepath', 'binary'."
+ )
+
+
class DirectoryUpload(BasePage):
- def __init__(self, app):
- self._app = app
- self._supported_file_types = [
- "image",
- ".pdf",
- ".txt",
- ".csv",
- ".xlsx",
- ".doc",
- ".docx",
- ".pptx",
- ".html",
- ".zip",
- ]
+ def __init__(self, app, index):
+ super().__init__(app)
+ self._index = index
+ self._supported_file_types = self._index.config.get("supported_file_types", [])
self.on_building_ui()
def on_building_ui(self):
@@ -50,18 +69,7 @@ class FileIndexPage(BasePage):
def __init__(self, app, index):
super().__init__(app)
self._index = index
- self._supported_file_types = [
- "image",
- ".pdf",
- ".txt",
- ".csv",
- ".xlsx",
- ".doc",
- ".docx",
- ".pptx",
- ".html",
- ".zip",
- ]
+ self._supported_file_types = self._index.config.get("supported_file_types", [])
self.selected_panel_false = "Selected file: (please select above)"
self.selected_panel_true = "Selected file: {name}"
# TODO: on_building_ui is not correctly named if it's always called in
@@ -69,13 +77,32 @@ def __init__(self, app, index):
self.public_events = [f"onFileIndex{index.id}Changed"]
self.on_building_ui()
+ def upload_instruction(self) -> str:
+ msgs = []
+ if self._supported_file_types:
+ msgs.append(
+ f"- Supported file types: {', '.join(self._supported_file_types)}"
+ )
+
+ if max_file_size := self._index.config.get("max_file_size", 0):
+ msgs.append(f"- Maximum file size: {max_file_size} MB")
+
+ if max_number_of_files := self._index.config.get("max_number_of_files", 0):
+ msgs.append(f"- The index can have maximum {max_number_of_files} files")
+
+ if msgs:
+ return "\n".join(msgs)
+
+ return ""
+
def on_building_ui(self):
"""Build the UI of the app"""
- with gr.Accordion(label="File upload", open=False):
- gr.Markdown(
- f"Supported file types: {', '.join(self._supported_file_types)}",
- )
- self.files = gr.File(
+ with gr.Accordion(label="File upload", open=True) as self.upload:
+ msg = self.upload_instruction()
+ if msg:
+ gr.Markdown(msg)
+
+ self.files = File(
file_types=self._supported_file_types,
file_count="multiple",
container=False,
@@ -98,18 +125,20 @@ def on_building_ui(self):
interactive=False,
)
- with gr.Row():
+ with gr.Row() as self.selection_info:
self.selected_file_id = gr.State(value=None)
self.selected_panel = gr.Markdown(self.selected_panel_false)
self.deselect_button = gr.Button("Deselect", visible=False)
- with gr.Row():
+ with gr.Row() as self.tools:
with gr.Column():
self.view_button = gr.Button("View Text (WIP)")
with gr.Column():
self.delete_button = gr.Button("Delete")
with gr.Row():
- self.delete_yes = gr.Button("Confirm Delete", visible=False)
+ self.delete_yes = gr.Button(
+ "Confirm Delete", variant="primary", visible=False
+ )
self.delete_no = gr.Button("Cancel", visible=False)
def on_subscribe_public_events(self):
@@ -242,10 +271,12 @@ def on_register_events(self):
self._app.settings_state,
],
outputs=[self.file_output],
+ concurrency_limit=20,
).then(
fn=self.list_file,
inputs=None,
outputs=[self.file_list_state, self.file_list],
+ concurrency_limit=20,
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
onUploaded = onUploaded.then(**event)
@@ -274,6 +305,15 @@ def index_fn(self, files, reindex: bool, settings):
selected_files: the list of files already selected
settings: the settings of the app
"""
+ if not files:
+ gr.Info("No uploaded file")
+ return gr.update()
+
+ errors = self.validate(files)
+ if errors:
+ gr.Warning(", ".join(errors))
+ return gr.update()
+
gr.Info(f"Start indexing {len(files)} files...")
# get the pipeline
@@ -409,6 +449,35 @@ def interact_file_list(self, list_files, ev: gr.SelectData):
name=list_files["name"][ev.index[0]]
)
+ def validate(self, files: list[str]):
+ """Validate if the files are valid"""
+ paths = [Path(file) for file in files]
+ errors = []
+ if max_file_size := self._index.config.get("max_file_size", 0):
+ errors_max_size = []
+ for path in paths:
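+ # max_file_size is configured in MB while st_size is in bytes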
+ if path.stat().st_size > max_file_size * 1e6:
+ errors_max_size.append(path.name)
+ if errors_max_size:
+ str_errors = ", ".join(errors_max_size)
+ if len(str_errors) > 60:
+ str_errors = str_errors[:55] + "..."
+ errors.append(
+ f"Maximum file size ({max_file_size} MB) exceeded: {str_errors}"
+ )
+
+ if max_number_of_files := self._index.config.get("max_number_of_files", 0):
+ with Session(engine) as session:
+ current_num_files = session.query(
+ self._index._db_tables["Source"].id
+ ).count()
+ if len(paths) + current_num_files > max_number_of_files:
+ errors.append(
+ f"Maximum number of files ({max_number_of_files}) will be exceeded"
+ )
+
+ return errors
+
class FileSelector(BasePage):
"""File selector UI in the Chat page"""
@@ -430,6 +499,9 @@ def on_building_ui(self):
def as_gradio_component(self):
return self.selector
+ def get_selected_ids(self, selected):
+ return selected
+
def load_files(self, selected_files):
options = []
available_ids = []
diff --git a/libs/ktem/ktem/index/manager.py b/libs/ktem/ktem/index/manager.py
index 72c4f9948..af1c8d484 100644
--- a/libs/ktem/ktem/index/manager.py
+++ b/libs/ktem/ktem/index/manager.py
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import Optional, Type
from ktem.db.models import engine
from sqlmodel import Session, select
@@ -49,15 +49,19 @@ def build_index(self, name: str, config: dict, index_type: str, id=None):
Returns:
BaseIndex: the index object
"""
+ index_cls = import_dotted_string(index_type, safe=False)
+ index = index_cls(app=self._app, id=id, name=name, config=config)
+ index.on_create()
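+ # on_create() may normalize the config (e.g. FileIndex's
+ # supported_file_types), so it runs before the config is persisted below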
+
with Session(engine) as session:
- index_entry = Index(id=id, name=name, config=config, index_type=index_type)
+ index_entry = Index(
+ id=index.id, name=index.name, config=index.config, index_type=index_type
+ )
session.add(index_entry)
session.commit()
session.refresh(index_entry)
- index_cls = import_dotted_string(index_type, safe=False)
- index = index_cls(app=self._app, id=id, name=name, config=config)
- index.on_create()
+ index.id = index_entry.id
return index
@@ -77,7 +81,7 @@ def start_index(self, id: int, name: str, config: dict, index_type: str):
self._indices.append(index)
return index
- def exists(self, id: int) -> bool:
+ def exists(self, id: Optional[int] = None, name: Optional[str] = None) -> bool:
"""Check if the index exists
Args:
@@ -86,9 +90,19 @@ def exists(self, id: int) -> bool:
Returns:
bool: True if the index exists, False otherwise
"""
- with Session(engine) as session:
- index = session.get(Index, id)
- return index is not None
+ if id:
+ with Session(engine) as session:
+ index = session.get(Index, id)
+ return index is not None
+
+ if name:
+ with Session(engine) as session:
+ index = session.exec(
+ select(Index).where(Index.name == name)
+ ).one_or_none()
+ return index is not None
+
+ return False
def on_application_startup(self):
"""This method is called by the base application when the application starts
diff --git a/libs/ktem/ktem/main.py b/libs/ktem/ktem/main.py
index c375ed7ed..1d76d0498 100644
--- a/libs/ktem/ktem/main.py
+++ b/libs/ktem/ktem/main.py
@@ -27,7 +27,7 @@ def ui(self):
if self.f_user_management:
from ktem.pages.login import LoginPage
- with gr.Tab("Login", elem_id="login-tab") as self._tabs["login-tab"]:
+ with gr.Tab("Welcome", elem_id="login-tab") as self._tabs["login-tab"]:
self.login_page = LoginPage(self)
with gr.Tab(
@@ -62,6 +62,9 @@ def ui(self):
def on_subscribe_public_events(self):
if self.f_user_management:
+ from ktem.db.engine import engine
+ from ktem.db.models import User
+ from sqlmodel import Session, select
def signed_in_out(user_id):
if not user_id:
@@ -73,14 +76,31 @@ def signed_in_out(user_id):
)
for k in self._tabs.keys()
)
- return list(
- (
- gr.update(visible=True)
- if k != "login-tab"
- else gr.update(visible=False)
- )
- for k in self._tabs.keys()
- )
+
+ with Session(engine) as session:
+ user = session.exec(select(User).where(User.id == user_id)).first()
+ if user is None:
+ return list(
+ (
+ gr.update(visible=True)
+ if k == "login-tab"
+ else gr.update(visible=False)
+ )
+ for k in self._tabs.keys()
+ )
+
+ is_admin = user.admin
+
+ tabs_update = []
+ for k in self._tabs.keys():
+ if k == "login-tab":
+ tabs_update.append(gr.update(visible=False))
+ elif k == "admin-tab":
+ tabs_update.append(gr.update(visible=is_admin))
+ else:
+ tabs_update.append(gr.update(visible=True))
+
+ return tabs_update
self.subscribe_event(
name="onSignIn",
diff --git a/libs/ktem/ktem/pages/admin/user.py b/libs/ktem/ktem/pages/admin/user.py
index 519fb0fd7..6411b30d7 100644
--- a/libs/ktem/ktem/pages/admin/user.py
+++ b/libs/ktem/ktem/pages/admin/user.py
@@ -40,7 +40,7 @@ def validate_username(usn):
if len(usn) > 32:
errors.append("Username must be at most 32 characters long")
- if not usn.strip("_").isalnum():
+ if not usn.replace("_", "").isalnum():
errors.append(
"Username must contain only alphanumeric characters and underscores"
)
@@ -97,8 +97,6 @@ def validate_password(pwd, pwd_cnf):
class UserManagement(BasePage):
def __init__(self, app):
self._app = app
- self.selected_panel_false = "Selected user: (please select above)"
- self.selected_panel_true = "Selected user: {name}"
self.on_building_ui()
if hasattr(flowsettings, "KH_FEATURE_USER_MANAGEMENT_ADMIN") and hasattr(
@@ -126,7 +124,38 @@ def __init__(self, app):
gr.Info(f'User "{usn}" created successfully')
def on_building_ui(self):
- with gr.Accordion(label="Create user", open=False):
+ with gr.Tab(label="User list"):
+ self.state_user_list = gr.State(value=None)
+ self.user_list = gr.DataFrame(
+ headers=["id", "name", "admin"],
+ interactive=False,
+ )
+
+ with gr.Group(visible=False) as self._selected_panel:
+ self.selected_user_id = gr.Number(value=-1, visible=False)
+ self.usn_edit = gr.Textbox(label="Username")
+ with gr.Row():
+ self.pwd_edit = gr.Textbox(label="Change password", type="password")
+ self.pwd_cnf_edit = gr.Textbox(
+ label="Confirm change password",
+ type="password",
+ )
+ self.admin_edit = gr.Checkbox(label="Admin")
+
+ with gr.Row() as self._selected_panel_btn:
+ with gr.Column():
+ self.btn_edit_save = gr.Button("Save")
+ with gr.Column():
+ self.btn_delete = gr.Button("Delete")
+ with gr.Row():
+ self.btn_delete_yes = gr.Button(
+ "Confirm delete", variant="primary", visible=False
+ )
+ self.btn_delete_no = gr.Button("Cancel", visible=False)
+ with gr.Column():
+ self.btn_close = gr.Button("Close")
+
+ with gr.Tab(label="Create user"):
self.usn_new = gr.Textbox(label="Username", interactive=True)
self.pwd_new = gr.Textbox(
label="Password", type="password", interactive=True
@@ -139,52 +168,28 @@ def on_building_ui(self):
gr.Markdown(PASSWORD_RULE)
self.btn_new = gr.Button("Create user")
- gr.Markdown("## User list")
- self.btn_list_user = gr.Button("Refresh user list")
- self.state_user_list = gr.State(value=None)
- self.user_list = gr.DataFrame(
- headers=["id", "name", "admin"],
- interactive=False,
- )
-
- with gr.Row():
- self.selected_user_id = gr.State(value=None)
- self.selected_panel = gr.Markdown(self.selected_panel_false)
- self.deselect_button = gr.Button("Deselect", visible=False)
-
- with gr.Group():
- self.btn_delete = gr.Button("Delete user")
- with gr.Row():
- self.btn_delete_yes = gr.Button("Confirm", visible=False)
- self.btn_delete_no = gr.Button("Cancel", visible=False)
-
- gr.Markdown("## User details")
- self.usn_edit = gr.Textbox(label="Username")
- self.pwd_edit = gr.Textbox(label="Password", type="password")
- self.pwd_cnf_edit = gr.Textbox(label="Confirm password", type="password")
- self.admin_edit = gr.Checkbox(label="Admin")
- self.btn_edit_save = gr.Button("Save")
-
def on_register_events(self):
self.btn_new.click(
self.create_user,
inputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new],
- outputs=None,
- )
- self.btn_list_user.click(
- self.list_users, inputs=None, outputs=[self.state_user_list, self.user_list]
+ outputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new],
+ ).then(
+ self.list_users,
+ inputs=self._app.user_id,
+ outputs=[self.state_user_list, self.user_list],
)
self.user_list.select(
self.select_user,
inputs=self.user_list,
- outputs=[self.selected_user_id, self.selected_panel],
+ outputs=[self.selected_user_id],
show_progress="hidden",
)
- self.selected_panel.change(
+ self.selected_user_id.change(
self.on_selected_user_change,
inputs=[self.selected_user_id],
outputs=[
- self.deselect_button,
+ self._selected_panel,
+ self._selected_panel_btn,
# delete section
self.btn_delete,
self.btn_delete_yes,
@@ -197,12 +202,6 @@ def on_register_events(self):
],
show_progress="hidden",
)
- self.deselect_button.click(
- lambda: (None, self.selected_panel_false),
- inputs=None,
- outputs=[self.selected_user_id, self.selected_panel],
- show_progress="hidden",
- )
self.btn_delete.click(
self.on_btn_delete_click,
inputs=[self.selected_user_id],
@@ -211,9 +210,13 @@ def on_register_events(self):
)
self.btn_delete_yes.click(
self.delete_user,
- inputs=[self.selected_user_id],
- outputs=[self.selected_user_id, self.selected_panel],
+ inputs=[self._app.user_id, self.selected_user_id],
+ outputs=[self.selected_user_id],
show_progress="hidden",
+ ).then(
+ self.list_users,
+ inputs=self._app.user_id,
+ outputs=[self.state_user_list, self.user_list],
)
self.btn_delete_no.click(
lambda: (
@@ -234,21 +237,53 @@ def on_register_events(self):
self.pwd_cnf_edit,
self.admin_edit,
],
- outputs=None,
+ outputs=[self.pwd_edit, self.pwd_cnf_edit],
show_progress="hidden",
+ ).then(
+ self.list_users,
+ inputs=self._app.user_id,
+ outputs=[self.state_user_list, self.user_list],
+ )
+ self.btn_close.click(
+ lambda: -1,
+ outputs=[self.selected_user_id],
+ )
+
+ def on_subscribe_public_events(self):
+ self._app.subscribe_event(
+ name="onSignIn",
+ definition={
+ "fn": self.list_users,
+ "inputs": [self._app.user_id],
+ "outputs": [self.state_user_list, self.user_list],
+ },
+ )
+ self._app.subscribe_event(
+ name="onSignOut",
+ definition={
+ "fn": lambda: ("", "", "", None, None, -1),
+ "outputs": [
+ self.usn_new,
+ self.pwd_new,
+ self.pwd_cnf_new,
+ self.state_user_list,
+ self.user_list,
+ self.selected_user_id,
+ ],
+ },
)
def create_user(self, usn, pwd, pwd_cnf):
errors = validate_username(usn)
if errors:
gr.Warning(errors)
- return
+ return usn, pwd, pwd_cnf
errors = validate_password(pwd, pwd_cnf)
print(errors)
if errors:
gr.Warning(errors)
- return
+ return usn, pwd, pwd_cnf
with Session(engine) as session:
statement = select(User).where(User.username_lower == usn.lower())
@@ -265,8 +300,22 @@ def create_user(self, usn, pwd, pwd_cnf):
session.commit()
gr.Info(f'User "{usn}" created successfully')
- def list_users(self):
+ return "", "", ""
+
+ def list_users(self, user_id):
+ if user_id is None:
+ return [], pd.DataFrame.from_records(
+ [{"id": "-", "username": "-", "admin": "-"}]
+ )
+
with Session(engine) as session:
+ statement = select(User).where(User.id == user_id)
+ user = session.exec(statement).one()
+ if not user.admin:
+ return [], pd.DataFrame.from_records(
+ [{"id": "-", "username": "-", "admin": "-"}]
+ )
+
statement = select(User)
results = [
{"id": user.id, "username": user.username, "admin": user.admin}
@@ -284,18 +333,17 @@ def list_users(self):
def select_user(self, user_list, ev: gr.SelectData):
if ev.value == "-" and ev.index[0] == 0:
gr.Info("No user is loaded. Please refresh the user list")
- return None, self.selected_panel_false
+ return -1
if not ev.selected:
- return None, self.selected_panel_false
+ return -1
- return user_list["id"][ev.index[0]], self.selected_panel_true.format(
- name=user_list["username"][ev.index[0]]
- )
+ return user_list["id"][ev.index[0]]
def on_selected_user_change(self, selected_user_id):
- if selected_user_id is None:
- deselect_button = gr.update(visible=False)
+ if selected_user_id == -1:
+ _selected_panel = gr.update(visible=False)
+ _selected_panel_btn = gr.update(visible=False)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
@@ -304,7 +352,8 @@ def on_selected_user_change(self, selected_user_id):
pwd_cnf_edit = gr.update(value="")
admin_edit = gr.update(value=False)
else:
- deselect_button = gr.update(visible=True)
+ _selected_panel = gr.update(visible=True)
+ _selected_panel_btn = gr.update(visible=True)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
@@ -319,7 +368,8 @@ def on_selected_user_change(self, selected_user_id):
admin_edit = gr.update(value=user.admin)
return (
- deselect_button,
+ _selected_panel,
+ _selected_panel_btn,
btn_delete,
btn_delete_yes,
btn_delete_no,
@@ -344,17 +394,16 @@ def on_btn_delete_click(self, selected_user_id):
return btn_delete, btn_delete_yes, btn_delete_no
def save_user(self, selected_user_id, usn, pwd, pwd_cnf, admin):
- if usn:
- errors = validate_username(usn)
- if errors:
- gr.Warning(errors)
- return
+ errors = validate_username(usn)
+ if errors:
+ gr.Warning(errors)
+ return pwd, pwd_cnf
if pwd:
errors = validate_password(pwd, pwd_cnf)
if errors:
gr.Warning(errors)
- return
+ return pwd, pwd_cnf
with Session(engine) as session:
statement = select(User).where(User.id == int(selected_user_id))
@@ -367,11 +416,17 @@ def save_user(self, selected_user_id, usn, pwd, pwd_cnf, admin):
session.commit()
gr.Info(f'User "{usn}" updated successfully')
- def delete_user(self, selected_user_id):
+ return "", ""
+
+ def delete_user(self, current_user, selected_user_id):
+ if current_user == selected_user_id:
+ gr.Warning("You cannot delete yourself")
+ return selected_user_id
+
with Session(engine) as session:
statement = select(User).where(User.id == int(selected_user_id))
user = session.exec(statement).one()
session.delete(user)
session.commit()
gr.Info(f'User "{user.username}" deleted successfully')
- return None, self.selected_panel_false
+ return -1
diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py
index 6648c2fbc..a83bd837d 100644
--- a/libs/ktem/ktem/pages/chat/__init__.py
+++ b/libs/ktem/ktem/pages/chat/__init__.py
@@ -7,8 +7,11 @@
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
from sqlmodel import Session, select
+from theflow.settings import settings as flowsettings
from .chat_panel import ChatPanel
+from .chat_suggestion import ChatSuggestion
+from .common import STATE
from .control import ConversationControl
from .report import ReportIssue
@@ -21,27 +24,43 @@ def __init__(self, app):
def on_building_ui(self):
with gr.Row():
+ self.chat_state = gr.State(STATE)
with gr.Column(scale=1):
self.chat_control = ConversationControl(self._app)
+ if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
+ self.chat_suggestion = ChatSuggestion(self._app)
+
for index in self._app.index_manager.indices:
- index.selector = -1
+ index.selector = None
index_ui = index.get_selector_component_ui()
if not index_ui:
+ # the index doesn't have a selector UI component
continue
- index_ui.unrender()
+ index_ui.unrender() # will be re-rendered later inside the Accordion
with gr.Accordion(label=f"{index.name} Index", open=False):
index_ui.render()
gr_index = index_ui.as_gradio_component()
if gr_index:
- index.selector = len(self._indices_input)
- self._indices_input.append(gr_index)
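+ # a selector UI may expose one component or a list of them; record
+ # the position(s) of its inputs in self._indices_input for later routing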
+ if isinstance(gr_index, list):
+ index.selector = tuple(
+ range(
+ len(self._indices_input),
+ len(self._indices_input) + len(gr_index),
+ )
+ )
+ self._indices_input.extend(gr_index)
+ else:
+ index.selector = len(self._indices_input)
+ self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui)
self.report_issue = ReportIssue(self._app)
+
with gr.Column(scale=6):
self.chat_panel = ChatPanel(self._app)
+
with gr.Column(scale=3):
with gr.Accordion(label="Information panel", open=True):
self.info_panel = gr.HTML(elem_id="chat-info-panel")
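The selector bookkeeping above flattens every index's input components into the single `self._indices_input` list: an index with one component stores an `int` position, one with several stores a `tuple` of consecutive positions. A condensed, self-contained sketch of that assignment (component values are illustrative):

```python
indices_input: list = []

def register_selector(components):
    """Mirror of the selector assignment above: an int for one
    component, a tuple of positions for a list of components."""
    if isinstance(components, list):
        selector = tuple(
            range(len(indices_input), len(indices_input) + len(components))
        )
        indices_input.extend(components)
    else:
        selector = len(indices_input)
        indices_input.append(components)
    return selector

assert register_selector(["files", "mode"]) == (0, 1)
assert register_selector("chunks") == 2
```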
@@ -52,32 +71,77 @@ def on_register_events(self):
self.chat_panel.text_input.submit,
self.chat_panel.submit_btn.click,
],
- fn=self.chat_panel.submit_msg,
- inputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
- outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
+ fn=self.submit_msg,
+ inputs=[
+ self.chat_panel.text_input,
+ self.chat_panel.chatbot,
+ self._app.user_id,
+ self.chat_control.conversation_id,
+ self.chat_control.conversation_rn,
+ ],
+ outputs=[
+ self.chat_panel.text_input,
+ self.chat_panel.chatbot,
+ self.chat_control.conversation_id,
+ self.chat_control.conversation,
+ self.chat_control.conversation_rn,
+ ],
+ concurrency_limit=20,
show_progress="hidden",
- ).then(
+ ).success(
fn=self.chat_fn,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
self._app.settings_state,
+ self.chat_state,
]
+ self._indices_input,
outputs=[
- self.chat_panel.text_input,
self.chat_panel.chatbot,
self.info_panel,
+ self.chat_state,
],
+ concurrency_limit=20,
show_progress="minimal",
).then(
fn=self.update_data_source,
inputs=[
self.chat_control.conversation_id,
self.chat_panel.chatbot,
+ self.chat_state,
]
+ self._indices_input,
outputs=None,
+ concurrency_limit=20,
+ )
+
+ self.chat_panel.regen_btn.click(
+ fn=self.regen_fn,
+ inputs=[
+ self.chat_control.conversation_id,
+ self.chat_panel.chatbot,
+ self._app.settings_state,
+ self.chat_state,
+ ]
+ + self._indices_input,
+ outputs=[
+ self.chat_panel.chatbot,
+ self.info_panel,
+ self.chat_state,
+ ],
+ concurrency_limit=20,
+ show_progress="minimal",
+ ).then(
+ fn=self.update_data_source,
+ inputs=[
+ self.chat_control.conversation_id,
+ self.chat_panel.chatbot,
+ self.chat_state,
+ ]
+ + self._indices_input,
+ outputs=None,
+ concurrency_limit=20,
)
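Two details in the chain above are easy to miss: `submit_msg` raises `ValueError` on empty input, and the follow-up step is now attached with `.success(...)` instead of `.then(...)`. In Gradio, `.then` runs unconditionally after the previous event while `.success` runs only if the previous handler did not raise, so an empty message aborts the whole chat chain. A minimal standalone sketch of the same guard pattern:

```python
import gradio as gr

def guard(text: str) -> str:
    if not text:
        raise ValueError("Input is empty")  # aborts the .success() chain
    return text

with gr.Blocks() as demo:
    box = gr.Textbox()
    out = gr.Textbox()
    # .then() would run regardless; .success() runs only when guard()
    # returned without raising
    box.submit(guard, inputs=box, outputs=box).success(
        lambda t: t.upper(), inputs=box, outputs=out
    )
```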
self.chat_panel.chatbot.like(
@@ -86,7 +150,12 @@ def on_register_events(self):
outputs=None,
)
- self.chat_control.conversation.change(
+ self.chat_control.btn_new.click(
+ self.chat_control.new_conv,
+ inputs=self._app.user_id,
+ outputs=[self.chat_control.conversation_id, self.chat_control.conversation],
+ show_progress="hidden",
+ ).then(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
outputs=[
@@ -94,11 +163,71 @@ def on_register_events(self):
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
+ self.info_panel,
+ self.chat_state,
]
+ self._indices_input,
show_progress="hidden",
)
+ self.chat_control.btn_del.click(
+ lambda id: self.toggle_delete(id),
+ inputs=[self.chat_control.conversation_id],
+ outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
+ )
+ self.chat_control.btn_del_conf.click(
+ self.chat_control.delete_conv,
+ inputs=[self.chat_control.conversation_id, self._app.user_id],
+ outputs=[self.chat_control.conversation_id, self.chat_control.conversation],
+ show_progress="hidden",
+ ).then(
+ self.chat_control.select_conv,
+ inputs=[self.chat_control.conversation],
+ outputs=[
+ self.chat_control.conversation_id,
+ self.chat_control.conversation,
+ self.chat_control.conversation_rn,
+ self.chat_panel.chatbot,
+ self.info_panel,
+ ]
+ + self._indices_input,
+ show_progress="hidden",
+ ).then(
+ lambda: self.toggle_delete(""),
+ outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
+ )
+ self.chat_control.btn_del_cnl.click(
+ lambda: self.toggle_delete(""),
+ outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
+ )
+ self.chat_control.conversation_rn_btn.click(
+ self.chat_control.rename_conv,
+ inputs=[
+ self.chat_control.conversation_id,
+ self.chat_control.conversation_rn,
+ self._app.user_id,
+ ],
+ outputs=[self.chat_control.conversation, self.chat_control.conversation],
+ show_progress="hidden",
+ )
+
+ self.chat_control.conversation.select(
+ self.chat_control.select_conv,
+ inputs=[self.chat_control.conversation],
+ outputs=[
+ self.chat_control.conversation_id,
+ self.chat_control.conversation,
+ self.chat_control.conversation_rn,
+ self.chat_panel.chatbot,
+ self.info_panel,
+ ]
+ + self._indices_input,
+ show_progress="hidden",
+ ).then(
+ lambda: self.toggle_delete(""),
+ outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
+ )
+
self.report_issue.report_btn.click(
self.report_issue.report,
inputs=[
@@ -109,12 +238,79 @@ def on_register_events(self):
self.chat_panel.chatbot,
self._app.settings_state,
self._app.user_id,
+ self.info_panel,
+ self.chat_state,
]
+ self._indices_input,
outputs=None,
)
+ if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
+ self.chat_suggestion.example.select(
+ self.chat_suggestion.select_example,
+ outputs=[self.chat_panel.text_input],
+ show_progress="hidden",
+ )
+
+ def submit_msg(self, chat_input, chat_history, user_id, conv_id, conv_name):
+ """Submit a message to the chatbot"""
+ if not chat_input:
+ raise ValueError("Input is empty")
+
+ if not conv_id:
+ id_, update = self.chat_control.new_conv(user_id)
+ with Session(engine) as session:
+ statement = select(Conversation).where(Conversation.id == id_)
+ name = session.exec(statement).one().name
+ new_conv_id = id_
+ conv_update = update
+ new_conv_name = name
+ else:
+ new_conv_id = conv_id
+ conv_update = gr.update()
+ new_conv_name = conv_name
+
+ return (
+ "",
+ chat_history + [(chat_input, None)],
+ new_conv_id,
+ conv_update,
+ new_conv_name,
+ )
- def update_data_source(self, convo_id, messages, *selecteds):
+ def toggle_delete(self, conv_id):
+ if conv_id:
+ return gr.update(visible=False), gr.update(visible=True)
+ else:
+ return gr.update(visible=True), gr.update(visible=False)
+
+ def on_subscribe_public_events(self):
+ if self._app.f_user_management:
+ self._app.subscribe_event(
+ name="onSignIn",
+ definition={
+ "fn": self.chat_control.reload_conv,
+ "inputs": [self._app.user_id],
+ "outputs": [self.chat_control.conversation],
+ "show_progress": "hidden",
+ },
+ )
+
+ self._app.subscribe_event(
+ name="onSignOut",
+ definition={
+ "fn": lambda: self.chat_control.select_conv(""),
+ "outputs": [
+ self.chat_control.conversation_id,
+ self.chat_control.conversation,
+ self.chat_control.conversation_rn,
+ self.chat_panel.chatbot,
+ ]
+ + self._indices_input,
+ "show_progress": "hidden",
+ },
+ )
+
+ def update_data_source(self, convo_id, messages, state, *selecteds):
"""Update the data source"""
if not convo_id:
gr.Warning("No conversation selected")
@@ -122,8 +318,12 @@ def update_data_source(self, convo_id, messages, *selecteds):
selecteds_ = {}
for index in self._app.index_manager.indices:
- if index.selector != -1:
+ if index.selector is None:
+ continue
+ if isinstance(index.selector, int):
selecteds_[str(index.id)] = selecteds[index.selector]
+ else:
+ selecteds_[str(index.id)] = [selecteds[i] for i in index.selector]
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == convo_id)
@@ -133,6 +333,7 @@ def update_data_source(self, convo_id, messages, *selecteds):
result.data_source = {
"selected": selecteds_,
"messages": messages,
+ "state": state,
"likes": deepcopy(data_source.get("likes", [])),
}
session.add(result)
@@ -152,33 +353,45 @@ def is_liked(self, convo_id, liked: gr.LikeData):
session.add(result)
session.commit()
- def create_pipeline(self, settings: dict, *selecteds):
+ def create_pipeline(self, settings: dict, state: dict, *selecteds):
"""Create the pipeline from settings
Args:
settings: the settings of the app
+ state: the conversation state (application flags plus per-pipeline state)
selected: the list of file ids that will be served as context. If None, then
consider using all files
Returns:
- the pipeline objects
+ - the pipeline object
+ - the reasoning state passed into the pipeline
"""
+ reasoning_mode = settings["reasoning.use"]
+ reasoning_cls = reasonings[reasoning_mode]
+ reasoning_id = reasoning_cls.get_info()["id"]
+
# get retrievers
retrievers = []
for index in self._app.index_manager.indices:
index_selected = []
- if index.selector != -1:
+ if isinstance(index.selector, int):
index_selected = selecteds[index.selector]
+ if isinstance(index.selector, tuple):
+ for i in index.selector:
+ index_selected.append(selecteds[i])
iretrievers = index.get_retriever_pipelines(settings, index_selected)
retrievers += iretrievers
- reasoning_mode = settings["reasoning.use"]
- reasoning_cls = reasonings[reasoning_mode]
- pipeline = reasoning_cls.get_pipeline(settings, retrievers)
+ # prepare states
+ reasoning_state = {
+ "app": deepcopy(state["app"]),
+ "pipeline": deepcopy(state.get(reasoning_id, {})),
+ }
- return pipeline
+ pipeline = reasoning_cls.get_pipeline(settings, reasoning_state, retrievers)
- async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
+ return pipeline, reasoning_state
+
+ async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
"""Chat function"""
chat_input = chat_history[-1][0]
chat_history = chat_history[:-1]
@@ -186,7 +399,7 @@ async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
# construct the pipeline
- pipeline = self.create_pipeline(settings, *selecteds)
+ pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
pipeline.set_output_queue(queue)
asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
@@ -198,7 +411,8 @@ async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
try:
response = queue.get_nowait()
except Exception:
- yield "", chat_history + [(chat_input, text or "Thinking ...")], refs
+ state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+ yield chat_history + [(chat_input, text or "Thinking ...")], refs, state
continue
if response is None:
@@ -207,7 +421,11 @@ async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
break
if "output" in response:
- text += response["output"]
+ if response["output"] is None:
+ text = ""
+ else:
+ text += response["output"]
+
if "evidence" in response:
if response["evidence"] is None:
refs = ""
@@ -218,4 +436,25 @@ async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
print(f"Len refs: {len(refs)}")
len_ref = len(refs)
- yield "", chat_history + [(chat_input, text)], refs
+ state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+ yield chat_history + [(chat_input, text)], refs, state
+
+ async def regen_fn(
+ self, conversation_id, chat_history, settings, state, *selecteds
+ ):
+ """Regen function"""
+ if not chat_history:
+ gr.Warning("Empty chat")
+ yield chat_history, "", state
+ return
+
+ state["app"]["regen"] = True
+ async for chat, refs, state in self.chat_fn(
+ conversation_id, chat_history, settings, state, *selecteds
+ ):
+ new_state = deepcopy(state)
+ new_state["app"]["regen"] = False
+ yield chat, refs, new_state
+ else:
+ state["app"]["regen"] = False
+ yield chat_history, "", state
diff --git a/libs/ktem/ktem/pages/chat/chat_panel.py b/libs/ktem/ktem/pages/chat/chat_panel.py
index f4cfc5bbe..55b9258e9 100644
--- a/libs/ktem/ktem/pages/chat/chat_panel.py
+++ b/libs/ktem/ktem/pages/chat/chat_panel.py
@@ -19,6 +19,7 @@ def on_building_ui(self):
placeholder="Chat input", scale=15, container=False
)
self.submit_btn = gr.Button(value="Send", scale=1, min_width=10)
+ self.regen_btn = gr.Button(value="Regen", scale=1, min_width=10)
def submit_msg(self, chat_input, chat_history):
"""Submit a message to the chatbot"""
diff --git a/libs/ktem/ktem/pages/chat/chat_suggestion.py b/libs/ktem/ktem/pages/chat/chat_suggestion.py
new file mode 100644
index 000000000..23332c034
--- /dev/null
+++ b/libs/ktem/ktem/pages/chat/chat_suggestion.py
@@ -0,0 +1,26 @@
+import gradio as gr
+from ktem.app import BasePage
+from theflow.settings import settings as flowsettings
+
+
+class ChatSuggestion(BasePage):
+ def __init__(self, app):
+ self._app = app
+ self.on_building_ui()
+
+ def on_building_ui(self):
+ chat_samples = getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION_SAMPLES", [])
+ chat_samples = [[each] for each in chat_samples]
+ with gr.Accordion(label="Chat Suggestion", open=False) as self.accordion:
+ self.example = gr.DataFrame(
+ value=chat_samples,
+ headers=["Sample"],
+ interactive=False,
+ wrap=True,
+ )
+
+ def as_gradio_component(self):
+ return self.example
+
+ def select_example(self, ev: gr.SelectData):
+ return ev.value
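The suggestion box is gated on two `flowsettings` values; a plausible `flowsettings.py` fragment enabling it would look like this (the sample questions are illustrative):

```python
# flowsettings.py -- illustrative values, not part of this patch
KH_FEATURE_CHAT_SUGGESTION = True
KH_FEATURE_CHAT_SUGGESTION_SAMPLES = [
    "Summarize the uploaded document.",
    "What are the key findings?",
]
```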
diff --git a/libs/ktem/ktem/pages/chat/common.py b/libs/ktem/ktem/pages/chat/common.py
new file mode 100644
index 000000000..a2fc0dcee
--- /dev/null
+++ b/libs/ktem/ktem/pages/chat/common.py
@@ -0,0 +1,4 @@
+DEFAULT_APPLICATION_STATE = {"regen": False}
+STATE = {
+ "app": DEFAULT_APPLICATION_STATE,
+}
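`STATE` is only the per-conversation default; at runtime `chat_fn` adds one entry per reasoning pipeline, keyed by the pipeline id. Roughly (the `"simple"` key is an illustrative pipeline id):

```python
from copy import deepcopy

from ktem.pages.chat.common import STATE

state = deepcopy(STATE)  # {"app": {"regen": False}} for a new conversation
state["simple"] = {}     # per-pipeline scratch state, added lazily by chat_fn
```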
diff --git a/libs/ktem/ktem/pages/chat/control.py b/libs/ktem/ktem/pages/chat/control.py
index a0b256125..f2ed99bb1 100644
--- a/libs/ktem/ktem/pages/chat/control.py
+++ b/libs/ktem/ktem/pages/chat/control.py
@@ -5,9 +5,22 @@
from ktem.db.models import Conversation, engine
from sqlmodel import Session, select
+from .common import STATE
+
logger = logging.getLogger(__name__)
+def is_conv_name_valid(name):
+ """Check if the conversation name is valid"""
+ errors = []
+ if len(name) == 0:
+ errors.append("Name cannot be empty")
+ elif len(name) > 40:
+ errors.append("Name cannot be longer than 40 characters")
+
+ return "; ".join(errors)
+
+
class ConversationControl(BasePage):
"""Manage conversation"""
@@ -26,9 +39,17 @@ def on_building_ui(self):
interactive=True,
)
- with gr.Row():
- self.conversation_new_btn = gr.Button(value="New", min_width=10)
- self.conversation_del_btn = gr.Button(value="Delete", min_width=10)
+ with gr.Row() as self._new_delete:
+ self.btn_new = gr.Button(value="New", min_width=10)
+ self.btn_del = gr.Button(value="Delete", min_width=10)
+
+ with gr.Row(visible=False) as self._delete_confirm:
+ self.btn_del_conf = gr.Button(
+ value="Delete",
+ variant="primary",
+ min_width=10,
+ )
+ self.btn_del_cnl = gr.Button(value="Cancel", min_width=10)
with gr.Row():
self.conversation_rn = gr.Text(
@@ -50,48 +71,6 @@ def on_building_ui(self):
# outputs=[current_state],
# )
- def on_subscribe_public_events(self):
- if self._app.f_user_management:
- self._app.subscribe_event(
- name="onSignIn",
- definition={
- "fn": self.reload_conv,
- "inputs": [self._app.user_id],
- "outputs": [self.conversation],
- "show_progress": "hidden",
- },
- )
-
- self._app.subscribe_event(
- name="onSignOut",
- definition={
- "fn": self.reload_conv,
- "inputs": [self._app.user_id],
- "outputs": [self.conversation],
- "show_progress": "hidden",
- },
- )
-
- def on_register_events(self):
- self.conversation_new_btn.click(
- self.new_conv,
- inputs=self._app.user_id,
- outputs=[self.conversation_id, self.conversation],
- show_progress="hidden",
- )
- self.conversation_del_btn.click(
- self.delete_conv,
- inputs=[self.conversation_id, self._app.user_id],
- outputs=[self.conversation_id, self.conversation],
- show_progress="hidden",
- )
- self.conversation_rn_btn.click(
- self.rename_conv,
- inputs=[self.conversation_id, self.conversation_rn, self._app.user_id],
- outputs=[self.conversation, self.conversation],
- show_progress="hidden",
- )
-
def load_chat_history(self, user_id):
"""Reload chat history"""
options = []
@@ -110,7 +89,7 @@ def load_chat_history(self, user_id):
def reload_conv(self, user_id):
conv_list = self.load_chat_history(user_id)
if conv_list:
- return gr.update(value=conv_list[0][1], choices=conv_list)
+ return gr.update(value=None, choices=conv_list)
else:
return gr.update(value=None, choices=[])
@@ -131,10 +110,15 @@ def new_conv(self, user_id):
return id_, gr.update(value=id_, choices=history)
def delete_conv(self, conversation_id, user_id):
- """Create new chat"""
+ """Delete the selected conversation"""
+ if not conversation_id:
+ gr.Warning("No conversation selected.")
+ return None, gr.update()
+
if user_id is None:
gr.Warning("Please sign in first (Settings → User Settings)")
return None, gr.update()
+
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one()
@@ -159,27 +143,44 @@ def select_conv(self, conversation_id):
name = result.name
selected = result.data_source.get("selected", {})
chats = result.data_source.get("messages", [])
+ info_panel = ""
+ state = result.data_source.get("state", STATE)
except Exception as e:
logger.warning(e)
id_ = ""
name = ""
selected = {}
chats = []
+ info_panel = ""
+ state = STATE
indices = []
for index in self._app.index_manager.indices:
# assume that the index has selector
- if index.selector == -1:
+ if index.selector is None:
continue
- indices.append(selected.get(str(index.id), []))
+ if isinstance(index.selector, int):
+ indices.append(selected.get(str(index.id), []))
+ if isinstance(index.selector, tuple):
+ indices.extend(selected.get(str(index.id), [[]] * len(index.selector)))
- return id_, id_, name, chats, *indices
+ return id_, id_, name, chats, info_panel, state, *indices
def rename_conv(self, conversation_id, new_name, user_id):
"""Rename the conversation"""
if user_id is None:
gr.Warning("Please sign in first (Settings → User Settings)")
return gr.update(), ""
+
+ if not conversation_id:
+ gr.Warning("No conversation selected.")
+ return gr.update(), ""
+
+ errors = is_conv_name_valid(new_name)
+ if errors:
+ gr.Warning(errors)
+ return gr.update(), conversation_id
+
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one()
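Note how `select_conv` restores one value per flattened index input: an `int` selector contributes a single entry, a `tuple` selector contributes `len(selector)` entries (defaulting to empty lists). A toy walkthrough of that unpacking, with made-up ids and values:

```python
selected = {"1": ["file_a"], "2": [["file_b"], ["chunk_3"]]}
indices: list = []
for index_id, selector in [("1", 0), ("2", (1, 2))]:
    if isinstance(selector, int):
        indices.append(selected.get(index_id, []))
    else:  # tuple selector: one saved value per UI component
        indices.extend(selected.get(index_id, [[]] * len(selector)))
assert indices == [["file_a"], ["file_b"], ["chunk_3"]]
```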
diff --git a/libs/ktem/ktem/pages/chat/report.py b/libs/ktem/ktem/pages/chat/report.py
index 46d9e3c84..dfe030146 100644
--- a/libs/ktem/ktem/pages/chat/report.py
+++ b/libs/ktem/ktem/pages/chat/report.py
@@ -48,12 +48,19 @@ def report(
chat_history: list,
settings: dict,
user_id: Optional[int],
- *selecteds
+ info_panel: str,
+ chat_state: dict,
+ *selecteds,
):
selecteds_ = {}
for index in self._app.index_manager.indices:
- if index.selector != -1:
- selecteds_[str(index.id)] = selecteds[index.selector]
+ if index.selector is not None:
+ if isinstance(index.selector, int):
+ selecteds_[str(index.id)] = selecteds[index.selector]
+ elif isinstance(index.selector, tuple):
+ selecteds_[str(index.id)] = [selecteds[_] for _ in index.selector]
+ else:
+ print(f"Unknown selector type: {index.selector}")
with Session(engine) as session:
issue = IssueReport(
@@ -65,6 +72,8 @@ def report(
chat={
"conv_id": conv_id,
"chat_history": chat_history,
+ "info_panel": info_panel,
+ "chat_state": chat_state,
"selecteds": selecteds_,
},
settings=settings,
diff --git a/libs/ktem/ktem/pages/login.py b/libs/ktem/ktem/pages/login.py
index 6fe15d0c3..d5c57e5a4 100644
--- a/libs/ktem/ktem/pages/login.py
+++ b/libs/ktem/ktem/pages/login.py
@@ -31,11 +31,10 @@ def __init__(self, app):
self.on_building_ui()
def on_building_ui(self):
- gr.Markdown("Welcome to Kotaemon")
- self.usn = gr.Textbox(label="Username")
- self.pwd = gr.Textbox(label="Password", type="password")
- self.btn_login = gr.Button("Login")
- self._dummy = gr.State()
+ gr.Markdown("# Welcome to Kotaemon")
+ self.usn = gr.Textbox(label="Username", visible=False)
+ self.pwd = gr.Textbox(label="Password", type="password", visible=False)
+ self.btn_login = gr.Button("Login", visible=False)
def on_register_events(self):
onSignIn = gr.on(
@@ -45,24 +44,56 @@ def on_register_events(self):
outputs=[self._app.user_id, self.usn, self.pwd],
show_progress="hidden",
js=signin_js,
+ ).then(
+ self.toggle_login_visibility,
+ inputs=[self._app.user_id],
+ outputs=[self.usn, self.pwd, self.btn_login],
)
for event in self._app.get_event("onSignIn"):
onSignIn = onSignIn.success(**event)
+ def toggle_login_visibility(self, user_id):
+ return (
+ gr.update(visible=user_id is None),
+ gr.update(visible=user_id is None),
+ gr.update(visible=user_id is None),
+ )
+
def _on_app_created(self):
- self._app.app.load(
- None,
- inputs=None,
- outputs=[self.usn, self.pwd],
+ onSignIn = self._app.app.load(
+ self.login,
+ inputs=[self.usn, self.pwd],
+ outputs=[self._app.user_id, self.usn, self.pwd],
+ show_progress="hidden",
js=fetch_creds,
+ ).then(
+ self.toggle_login_visibility,
+ inputs=[self._app.user_id],
+ outputs=[self.usn, self.pwd, self.btn_login],
+ )
+ for event in self._app.get_event("onSignIn"):
+ onSignIn = onSignIn.success(**event)
+
+ def on_subscribe_public_events(self):
+ self._app.subscribe_event(
+ name="onSignOut",
+ definition={
+ "fn": self.toggle_login_visibility,
+ "inputs": [self._app.user_id],
+ "outputs": [self.usn, self.pwd, self.btn_login],
+ "show_progress": "hidden",
+ },
)
def login(self, usn, pwd):
+ if not usn or not pwd:
+ return None, usn, pwd
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
with Session(engine) as session:
stmt = select(User).where(
- User.username_lower == usn.lower(), User.password == hashed_password
+ User.username_lower == usn.lower().strip(),
+ User.password == hashed_password,
)
result = session.exec(stmt).all()
if result:
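The login check compares a SHA-256 digest of the submitted password against the stored one, and the new `.strip()` makes the username lookup tolerant of stray whitespace. A self-contained sketch of the comparison (the function name is illustrative):

```python
import hashlib

def check_credentials(usn: str, pwd: str, stored_lower: str, stored_hash: str) -> bool:
    # case-insensitive, whitespace-tolerant username; hashed password
    return (
        usn.lower().strip() == stored_lower
        and hashlib.sha256(pwd.encode()).hexdigest() == stored_hash
    )

assert check_credentials(
    " Admin ", "secret", "admin", hashlib.sha256(b"secret").hexdigest()
)
```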
diff --git a/libs/ktem/ktem/pages/settings.py b/libs/ktem/ktem/pages/settings.py
index 0fce2e8f7..20912cb44 100644
--- a/libs/ktem/ktem/pages/settings.py
+++ b/libs/ktem/ktem/pages/settings.py
@@ -164,9 +164,14 @@ def on_register_events(self):
show_progress="hidden",
)
onSignOutClick = self.signout.click(
- lambda: (None, "Current user: ___"),
+ lambda: (None, "Current user: ___", "", ""),
inputs=None,
- outputs=[self._user_id, self.current_name],
+ outputs=[
+ self._user_id,
+ self.current_name,
+ self.password_change,
+ self.password_change_confirm,
+ ],
show_progress="hidden",
js=signout_js,
).then(
@@ -192,8 +197,12 @@ def user_tab(self):
self.password_change_btn = gr.Button("Change password", interactive=True)
def change_password(self, user_id, password, password_confirm):
- if password != password_confirm:
- gr.Warning("Password does not match")
+ from ktem.pages.admin.user import validate_password
+
+ errors = validate_password(password, password_confirm)
+ if errors:
+ print(errors)
+ gr.Warning(errors)
return password, password_confirm
with Session(engine) as session:
diff --git a/libs/ktem/ktem/reasoning/base.py b/libs/ktem/ktem/reasoning/base.py
index 80cf01698..6d6e48648 100644
--- a/libs/ktem/ktem/reasoning/base.py
+++ b/libs/ktem/ktem/reasoning/base.py
@@ -34,12 +34,16 @@ def get_user_settings(cls) -> dict:
@classmethod
def get_pipeline(
- cls, user_settings: dict, retrievers: Optional[list["BaseComponent"]] = None
+ cls,
+ user_settings: dict,
+ state: dict,
+ retrievers: Optional[list["BaseComponent"]] = None,
) -> "BaseReasoning":
"""Get the reasoning pipeline for the app to execute
Args:
user_setting: user settings
+ state: conversation state
retrievers (list): List of retrievers
"""
return cls()
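With the new signature every reasoning implementation receives the conversation state alongside the user settings; an override follows this shape (the class name and `use_rewrite` attribute are illustrative, condensed from `FullQAPipeline.get_pipeline` later in this diff):

```python
from ktem.reasoning.base import BaseReasoning

class MyReasoning(BaseReasoning):
    use_rewrite: bool = False  # hypothetical per-pipeline flag

    @classmethod
    def get_pipeline(cls, user_settings: dict, state: dict, retrievers=None):
        pipeline = cls()
        # react to conversation state, e.g. the regen flag set by the UI
        pipeline.use_rewrite = state.get("app", {}).get("regen", False)
        return pipeline
```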
diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py
index 1970c3996..a5a895eeb 100644
--- a/libs/ktem/ktem/reasoning/simple.py
+++ b/libs/ktem/ktem/reasoning/simple.py
@@ -1,11 +1,12 @@
import asyncio
+import html
import logging
+import re
from collections import defaultdict
from functools import partial
import tiktoken
from ktem.llms.manager import llms
-from ktem.reasoning.base import BaseReasoning
from kotaemon.base import (
BaseComponent,
@@ -18,9 +19,17 @@
from kotaemon.indices.qa.citation import CitationPipeline
from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import ChatLLM, PromptTemplate
+from kotaemon.loaders.utils.gpt4v import stream_gpt4v
+
+from .base import BaseReasoning
logger = logging.getLogger(__name__)
+EVIDENCE_MODE_TEXT = 0
+EVIDENCE_MODE_TABLE = 1
+EVIDENCE_MODE_CHATBOT = 2
+EVIDENCE_MODE_FIGURE = 3
+
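These constants replace the bare `0`/`1`/`2` literals used previously; the answering pipeline dispatches its prompt template on them. A sketch of that dispatch (the helper name is illustrative; the real code inlines this in `AnswerWithContextPipeline.run`):

```python
def pick_qa_template(evidence_mode: int) -> str:
    if evidence_mode == EVIDENCE_MODE_TEXT:
        return DEFAULT_QA_TEXT_PROMPT
    if evidence_mode == EVIDENCE_MODE_TABLE:
        return DEFAULT_QA_TABLE_PROMPT
    if evidence_mode == EVIDENCE_MODE_FIGURE:
        return DEFAULT_QA_FIGURE_PROMPT
    return DEFAULT_QA_CHATBOT_PROMPT  # EVIDENCE_MODE_CHATBOT
```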
class PrepareEvidencePipeline(BaseComponent):
"""Prepare the evidence text from the list of retrieved documents
@@ -46,7 +55,7 @@ class PrepareEvidencePipeline(BaseComponent):
def run(self, docs: list[RetrievedDocument]) -> Document:
evidence = ""
table_found = 0
- evidence_mode = 0
+ evidence_mode = EVIDENCE_MODE_TEXT
for _id, retrieved_item in enumerate(docs):
retrieved_content = ""
@@ -55,7 +64,7 @@ def run(self, docs: list[RetrievedDocument]) -> Document:
if page:
source += f" (Page {page})"
if retrieved_item.metadata.get("type", "") == "table":
- evidence_mode = 1 # table
+ evidence_mode = EVIDENCE_MODE_TABLE
if table_found < 5:
retrieved_content = retrieved_item.metadata.get("table_origin", "")
if retrieved_content not in evidence:
@@ -66,13 +75,23 @@ def run(self, docs: list[RetrievedDocument]) -> Document:
+ "\n
"
)
elif retrieved_item.metadata.get("type", "") == "chatbot":
- evidence_mode = 2 # chatbot
+ evidence_mode = EVIDENCE_MODE_CHATBOT
retrieved_content = retrieved_item.metadata["window"]
evidence += (
f"
Chatbot scenario from {filename} (Row {page})\n"
+ retrieved_content
+ "\n
"
)
+ elif retrieved_item.metadata.get("type", "") == "image":
+ evidence_mode = EVIDENCE_MODE_FIGURE
+ retrieved_content = retrieved_item.metadata.get("image_origin", "")
+ retrieved_caption = html.escape(retrieved_item.get_content())
+ evidence += (
+ f"
Figure from {source}\n"
+ + f""
+ + "\n
"
+ )
else:
if "window" in retrieved_item.metadata:
retrieved_content = retrieved_item.metadata["window"]
@@ -90,12 +109,13 @@ def run(self, docs: list[RetrievedDocument]) -> Document:
print(retrieved_item.metadata)
print("Score", retrieved_item.metadata.get("relevance_score", None))
- # trim context by trim_len
- print("len (original)", len(evidence))
- if evidence:
- texts = self.trim_func([Document(text=evidence)])
- evidence = texts[0].text
- print("len (trimmed)", len(evidence))
+ if evidence_mode != EVIDENCE_MODE_FIGURE:
+ # trim context by trim_len
+ print("len (original)", len(evidence))
+ if evidence:
+ texts = self.trim_func([Document(text=evidence)])
+ evidence = texts[0].text
+ print("len (trimmed)", len(evidence))
print(f"PrepareEvidence with input {docs}\nOutput: {evidence}\n")
@@ -134,6 +154,25 @@ def run(self, docs: list[RetrievedDocument]) -> Document:
"Answer:"
)
+DEFAULT_QA_FIGURE_PROMPT = (
+ "Use the given context: texts, tables, and figures below to answer the question. "
+ "If you don't know the answer, just say that you don't know. "
+ "Give answer in {lang}.\n\n"
+ "Context: \n"
+ "{context}\n"
+ "Question: {question}\n"
+ "Answer: "
+)
+
+DEFAULT_REWRITE_PROMPT = (
+ "Given the following question, rephrase and expand it "
+ "to help you do better answering. Maintain all information "
+ "in the original question. Keep the question as concise as possible. "
+ "Give answer in {lang}\n"
+ "Original question: {question}\n"
+ "Rephrased question: "
+)
+
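Both prompts are plain `PromptTemplate` strings with `{lang}`, `{question}` (and, for QA, `{context}`) placeholders; filling one looks like:

```python
from kotaemon.llms import PromptTemplate

prompt = PromptTemplate(DEFAULT_REWRITE_PROMPT).populate(
    question="What does the ingestion pipeline do?",  # illustrative input
    lang="English",
)
```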
class AnswerWithContextPipeline(BaseComponent):
"""Answer the question based on the evidence
@@ -151,6 +190,7 @@ class AnswerWithContextPipeline(BaseComponent):
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
+ vlm_endpoint: str = ""
citation_pipeline: CitationPipeline = Node(
default_callback=lambda _: CitationPipeline(llm=llms.get_default())
)
@@ -158,13 +198,14 @@ class AnswerWithContextPipeline(BaseComponent):
qa_template: str = DEFAULT_QA_TEXT_PROMPT
qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT
+ qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT
enable_citation: bool = False
system_prompt: str = ""
lang: str = "English" # support English and Japanese
async def run( # type: ignore
- self, question: str, evidence: str, evidence_mode: int = 0
+ self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
) -> Document:
"""Answer the question based on the evidence
@@ -188,18 +229,30 @@ async def run( # type: ignore
(determined by retrieval pipeline)
evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
"""
- if evidence_mode == 0:
+ if evidence_mode == EVIDENCE_MODE_TEXT:
prompt_template = PromptTemplate(self.qa_template)
- elif evidence_mode == 1:
+ elif evidence_mode == EVIDENCE_MODE_TABLE:
prompt_template = PromptTemplate(self.qa_table_template)
+ elif evidence_mode == EVIDENCE_MODE_FIGURE:
+ prompt_template = PromptTemplate(self.qa_figure_template)
else:
prompt_template = PromptTemplate(self.qa_chatbot_template)
- prompt = prompt_template.populate(
- context=evidence,
- question=question,
- lang=self.lang,
- )
+ images = []
+ if evidence_mode == EVIDENCE_MODE_FIGURE:
+ # isolate image from evidence
+ evidence, images = self.extract_evidence_images(evidence)
+ prompt = prompt_template.populate(
+ context=evidence,
+ question=question,
+ lang=self.lang,
+ )
+ else:
+ prompt = prompt_template.populate(
+ context=evidence,
+ question=question,
+ lang=self.lang,
+ )
citation_task = None
if evidence and self.enable_citation:
@@ -208,23 +261,29 @@ async def run( # type: ignore
)
print("Citation task created")
- messages = []
- if self.system_prompt:
- messages.append(SystemMessage(content=self.system_prompt))
- messages.append(HumanMessage(content=prompt))
-
output = ""
- try:
- # try streaming first
- print("Trying LLM streaming")
- for text in self.llm.stream(messages):
- output += text.text
- self.report_output({"output": text.text})
+ if evidence_mode == EVIDENCE_MODE_FIGURE:
+ for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
+ output += text
+ self.report_output({"output": text})
await asyncio.sleep(0)
- except NotImplementedError:
- print("Streaming is not supported, falling back to normal processing")
- output = self.llm(messages).text
- self.report_output({"output": output})
+ else:
+ messages = []
+ if self.system_prompt:
+ messages.append(SystemMessage(content=self.system_prompt))
+ messages.append(HumanMessage(content=prompt))
+
+ try:
+ # try streaming first
+ print("Trying LLM streaming")
+ for text in self.llm.stream(messages):
+ output += text.text
+ self.report_output({"output": text.text})
+ await asyncio.sleep(0)
+ except NotImplementedError:
+ print("Streaming is not supported, falling back to normal processing")
+ output = self.llm(messages).text
+ self.report_output({"output": output})
# retrieve the citation
print("Waiting for citation task")
@@ -238,6 +297,46 @@ async def run( # type: ignore
return answer
+def extract_evidence_images(self, evidence: str):
+ """Util function to extract and isolate images from context/evidence"""
+ image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
+ matches = re.findall(image_pattern, evidence)
+ context = re.sub(image_pattern, "", evidence)
+ return context, matches
+
+
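The regex above pulls inline base64 figures out of the evidence so they can be sent to the vision model separately while the remaining text flows through the usual prompt. A small self-contained demonstration with a toy payload:

```python
import re

image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
evidence = "<img src='data:image/png;base64,AAAA'/> Figure 1: a caption"
images = re.findall(image_pattern, evidence)
context = re.sub(image_pattern, "", evidence)
assert images == ["data:image/png;base64,AAAA"]
assert "base64" not in context  # caption text survives, payload removed
```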
+class RewriteQuestionPipeline(BaseComponent):
+ """Rewrite user question
+
+ Args:
+ llm: the language model to rewrite question
+ rewrite_template: the prompt template for llm to paraphrase a text input
+ lang: the language of the answer. Currently support English and Japanese
+ """
+
+ llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
+ rewrite_template: str = DEFAULT_REWRITE_PROMPT
+
+ lang: str = "English"
+
+ async def run(self, question: str) -> Document: # type: ignore
+ prompt_template = PromptTemplate(self.rewrite_template)
+ prompt = prompt_template.populate(question=question, lang=self.lang)
+ messages = [
+ SystemMessage(content="You are a helpful assistant"),
+ HumanMessage(content=prompt),
+ ]
+ output = ""
+ for text in self.llm(messages):
+ if "content" in text:
+ output += text[1]
+ self.report_output({"chat_input": text[1]})
+ break
+ await asyncio.sleep(0)
+
+ return Document(text=output)
+
+
class FullQAPipeline(BaseReasoning):
"""Question answering pipeline. Handle from question to answer"""
@@ -248,24 +347,36 @@ class Config:
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
+ rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
+ use_rewrite: bool = False
async def run( # type: ignore
self, message: str, conv_id: str, history: list, **kwargs # type: ignore
) -> Document: # type: ignore
+ import markdown
+
docs = []
doc_ids = []
+ if self.use_rewrite:
+ rewrite = await self.rewrite_pipeline(question=message)
+ message = rewrite.text
+
for retriever in self.retrievers:
for doc in retriever(text=message):
if doc.doc_id not in doc_ids:
docs.append(doc)
doc_ids.append(doc.doc_id)
for doc in docs:
+ # TODO: a better approach to show the information
+ text = markdown.markdown(
+ doc.text, extensions=["markdown.extensions.tables"]
+ )
self.report_output(
{
"evidence": (
""
f"{doc.metadata['file_name']}
"
- f"{doc.text}"
+ f"{text}"
"
"
)
}
@@ -274,7 +385,12 @@ async def run( # type: ignore
evidence_mode, evidence = self.evidence_pipeline(docs).content
answer = await self.answering_pipeline(
- question=message, evidence=evidence, evidence_mode=evidence_mode
+ question=message,
+ history=history,
+ evidence=evidence,
+ evidence_mode=evidence_mode,
+ conv_id=conv_id,
+ **kwargs,
)
# prepare citation
@@ -284,14 +400,29 @@ async def run( # type: ignore
for quote in fact_with_evidence.substring_quote:
for doc in docs:
start_idx = doc.text.find(quote)
- if start_idx >= 0:
+ if start_idx == -1:
+ continue
+
+ end_idx = start_idx + len(quote)
+
+ current_idx = start_idx
+ if "|" not in doc.text[start_idx:end_idx]:
spans[doc.doc_id].append(
- {
- "start": start_idx,
- "end": start_idx + len(quote),
- }
+ {"start": start_idx, "end": end_idx}
)
- break
+ else:
+ while doc.text[current_idx:end_idx].find("|") != -1:
+ match_idx = doc.text[current_idx:end_idx].find("|")
+ spans[doc.doc_id].append(
+ {
+ "start": current_idx,
+ "end": current_idx + match_idx,
+ }
+ )
+ current_idx += match_idx + 2
+ if current_idx > end_idx:
+ break
+ break
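The pipe handling above exists because a highlight span crossing a Markdown table cell boundary would break the rendered table, so a quote containing `|` is split into one span per cell (the `+ 2` appears to skip a `"| "` separator, and the trailing cell is dropped). A toy walkthrough of the splitting loop:

```python
text = "cell a | cell b | cell c"
start, end = 0, len(text)
spans, current = [], start
while text[current:end].find("|") != -1:
    match = text[current:end].find("|")
    spans.append({"start": current, "end": current + match})
    current += match + 2  # skip past the "| " separator
    if current > end:
        break
assert spans == [{"start": 0, "end": 7}, {"start": 9, "end": 16}]
# i.e. "cell a " and "cell b " become separate highlight spans
```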
id2docs = {doc.doc_id: doc for doc in docs}
lack_evidence = True
@@ -310,12 +441,15 @@ async def run( # type: ignore
if idx < len(ss) - 1:
text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
text += id2docs[id].text[ss[-1]["end"] :]
+ text_out = markdown.markdown(
+ text, extensions=["markdown.extensions.tables"]
+ )
self.report_output(
{
"evidence": (
""
f"{id2docs[id].metadata['file_name']}
"
- f"{text}"
+ f"{text_out}"
"
"
)
}
@@ -330,12 +464,15 @@ async def run( # type: ignore
{"evidence": "Retrieved segments without matching evidence:\n"}
)
for id in list(not_detected):
+ text_out = markdown.markdown(
+ id2docs[id].text, extensions=["markdown.extensions.tables"]
+ )
self.report_output(
{
"evidence": (
""
f"{id2docs[id].metadata['file_name']}
"
- f"{id2docs[id].text}"
+ f"{text_out}"
"
"
)
}
@@ -345,7 +482,7 @@ async def run( # type: ignore
return answer
@classmethod
- def get_pipeline(cls, settings, retrievers):
+ def get_pipeline(cls, settings, states, retrievers):
"""Get the reasoning pipeline
Args:
@@ -370,6 +507,11 @@ def get_pipeline(cls, settings, retrievers):
pipeline.answering_pipeline.qa_template = settings[
f"reasoning.options.{_id}.qa_prompt"
]
+ pipeline.use_rewrite = states.get("app", {}).get("regen", False)
+ pipeline.rewrite_pipeline.llm = llms.get_default()
+ pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
+ settings["reasoning.lang"], "English"
+ )
return pipeline
@classmethod