diff --git a/.gitignore b/.gitignore
index 2e3fb01ca..711a741e7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -452,6 +452,7 @@ $RECYCLE.BIN/
 .theflow/
 
 # End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,vim,emacs,visualstudiocode,pycharm
+*.py[coid]
 
 logs/
 .gitsecret/keys/random_seed
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 21356cef6..3f68b56f3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -52,7 +52,12 @@ repos:
     hooks:
       - id: mypy
        additional_dependencies:
-          [types-PyYAML==6.0.12.11, "types-requests", "sqlmodel"]
+          [
+            types-PyYAML==6.0.12.11,
+            "types-requests",
+            "sqlmodel",
+            "types-Markdown",
+          ]
        args: ["--check-untyped-defs", "--ignore-missing-imports"]
        exclude: "^templates/"
 - repo: https://github.com/codespell-project/codespell
diff --git a/libs/kotaemon/kotaemon/indices/ingests/files.py b/libs/kotaemon/kotaemon/indices/ingests/files.py
index ed00e5cb7..75f944e4d 100644
--- a/libs/kotaemon/kotaemon/indices/ingests/files.py
+++ b/libs/kotaemon/kotaemon/indices/ingests/files.py
@@ -7,6 +7,7 @@
 from kotaemon.indices.extractors import BaseDocParser
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
 from kotaemon.loaders import (
+    AdobeReader,
     DirectoryReader,
     MathpixPDFReader,
     OCRReader,
@@ -41,7 +42,7 @@ class DocumentIngestor(BaseComponent):
     The default file extractors are stored in `KH_DEFAULT_FILE_EXTRACTORS`
     """
 
-    pdf_mode: str = "normal"  # "normal", "mathpix", "ocr"
+    pdf_mode: str = "normal"  # "normal", "mathpix", "ocr", "multimodal"
     doc_parsers: list[BaseDocParser] = Param(default_callback=lambda _: [])
     text_splitter: BaseSplitter = TokenSplitter.withx(
         chunk_size=1024,
@@ -61,6 +62,8 @@ def _get_reader(self, input_files: list[str | Path]):
             pass  # use default loader of llama-index which is pypdf
         elif self.pdf_mode == "ocr":
             file_extractors[".pdf"] = OCRReader()
+        elif self.pdf_mode == "multimodal":
+            file_extractors[".pdf"] = AdobeReader()
         else:
             file_extractors[".pdf"] = MathpixPDFReader()
diff --git a/libs/kotaemon/kotaemon/indices/qa/citation.py b/libs/kotaemon/kotaemon/indices/qa/citation.py
index f1a53c797..3192a07fa 100644
--- a/libs/kotaemon/kotaemon/indices/qa/citation.py
+++ b/libs/kotaemon/kotaemon/indices/qa/citation.py
@@ -104,18 +104,16 @@ def invoke(self, context: str, question: str):
             print("CitationPipeline: invoking LLM")
             llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
             print("CitationPipeline: finish invoking LLM")
+            if not llm_output.messages:
+                return None
+            function_output = llm_output.messages[0].additional_kwargs["function_call"][
+                "arguments"
+            ]
+            output = QuestionAnswer.parse_raw(function_output)
         except Exception as e:
             print(e)
             return None
 
-        if not llm_output.messages:
-            return None
-
-        function_output = llm_output.messages[0].additional_kwargs["function_call"][
-            "arguments"
-        ]
-        output = QuestionAnswer.parse_raw(function_output)
-
         return output
 
     async def ainvoke(self, context: str, question: str):
diff --git a/libs/kotaemon/kotaemon/loaders/__init__.py b/libs/kotaemon/kotaemon/loaders/__init__.py
index d742b52f2..a59d71315 100644
--- a/libs/kotaemon/kotaemon/loaders/__init__.py
+++ b/libs/kotaemon/kotaemon/loaders/__init__.py
@@ -1,10 +1,11 @@
+from .adobe_loader import AdobeReader
 from .base import AutoReader, BaseReader
 from .composite_loader import DirectoryReader
 from .docx_loader import DocxReader
 from .excel_loader import PandasExcelReader
 from .html_loader import HtmlReader
 from .mathpix_loader import MathpixPDFReader
-from .ocr_loader import OCRReader
+from .ocr_loader import ImageReader, OCRReader
 from .unstructured_loader import UnstructuredReader
 
 __all__ = [
@@ -12,9 +13,11 @@
     "BaseReader",
     "PandasExcelReader",
     "MathpixPDFReader",
+    "ImageReader",
     "OCRReader",
     "DirectoryReader",
    "UnstructuredReader",
    "DocxReader",
    "HtmlReader",
+    "AdobeReader",
 ]
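Not part of the patch — a minimal usage sketch of the new `pdf_mode`, assuming the usual theflow component-call convention for `DocumentIngestor`; the input path is hypothetical:

```python
from kotaemon.indices.ingests.files import DocumentIngestor

# "multimodal" routes .pdf files to the new AdobeReader (added below);
# other extensions keep their default loaders.
ingestor = DocumentIngestor(pdf_mode="multimodal")
nodes = ingestor(["path/to/report.pdf"])  # hypothetical input file
```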
{title})" + ) + tables.append((page_number, table_content, table_caption)) + + elif re.search(self.figure_regex, item_path): + figure_caption = ( + item_text + f"\n(Figure in Page {page_number}. {title})" + ) + figure_content = parse_figure_paths(file_paths) + if not figure_content: + continue + figures.append([page_number, figure_content, figure_caption]) + + else: + if item_text and "Table" not in item_path and "Figure" not in item_path: + texts[page_number].append(item_text) + + # get figure caption using GPT-4V + figure_captions = generate_figure_captions( + self.vlm_endpoint, + [item[1] for item in figures], + self.max_figures_to_caption, + ) + for item, caption in zip(figures, figure_captions): + # update figure caption + item[2] += " " + caption + + # Wrap elements with Document + documents = [] + + # join plain text elements + for page_number, txts in texts.items(): + documents.append( + Document( + text="\n".join(txts), + metadata={ + "page_label": page_number, + "file_name": filename, + "file_path": filepath, + }, + ) + ) + + # table elements + for page_number, table_content, table_caption in tables: + documents.append( + Document( + text=table_caption, + metadata={ + "table_origin": table_content, + "type": "table", + "page_label": page_number, + "file_name": filename, + "file_path": filepath, + }, + metadata_template="", + metadata_seperator="", + ) + ) + + # figure elements + for page_number, figure_content, figure_caption in figures: + documents.append( + Document( + text=figure_caption, + metadata={ + "image_origin": figure_content, + "type": "image", + "page_label": page_number, + "file_name": filename, + "file_path": filepath, + }, + metadata_template="", + metadata_seperator="", + ) + ) + return documents diff --git a/libs/kotaemon/kotaemon/loaders/ocr_loader.py b/libs/kotaemon/kotaemon/loaders/ocr_loader.py index e68971768..bb1ac5dca 100644 --- a/libs/kotaemon/kotaemon/loaders/ocr_loader.py +++ b/libs/kotaemon/kotaemon/loaders/ocr_loader.py @@ -125,3 +125,70 @@ def load_data( ) return documents + + +class ImageReader(BaseReader): + """Read PDF using OCR, with high focus on table extraction + + Example: + ```python + >> from knowledgehub.loaders import OCRReader + >> reader = OCRReader() + >> documents = reader.load_data("path/to/pdf") + ``` + + Args: + endpoint: URL to FullOCR endpoint. If not provided, will look for + environment variable `OCR_READER_ENDPOINT` or use the default + `knowledgehub.loaders.ocr_loader.DEFAULT_OCR_ENDPOINT` + (http://127.0.0.1:8000/v2/ai/infer/) + use_ocr: whether to use OCR to read text (e.g: from images, tables) in the PDF + If False, only the table and text within table cells will be extracted. 
+ """ + + def __init__(self, endpoint: Optional[str] = None): + """Init the OCR reader with OCR endpoint (FullOCR pipeline)""" + super().__init__() + self.ocr_endpoint = endpoint or os.getenv( + "OCR_READER_ENDPOINT", DEFAULT_OCR_ENDPOINT + ) + + def load_data( + self, file_path: Path, extra_info: Optional[dict] = None, **kwargs + ) -> List[Document]: + """Load data using OCR reader + + Args: + file_path (Path): Path to PDF file + debug_path (Path): Path to store debug image output + artifact_path (Path): Path to OCR endpoints artifacts directory + + Returns: + List[Document]: list of documents extracted from the PDF file + """ + file_path = Path(file_path).resolve() + + with file_path.open("rb") as content: + files = {"input": content} + data = {"job_id": uuid4(), "table_only": False} + + # call the API from FullOCR endpoint + if "response_content" in kwargs: + # overriding response content if specified + ocr_results = kwargs["response_content"] + else: + # call original API + resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data) + ocr_results = resp.json()["result"] + + extra_info = extra_info or {} + result = [] + for ocr_result in ocr_results: + result.append( + Document( + content=ocr_result["csv_string"], + metadata=extra_info, + ) + ) + + return result diff --git a/libs/kotaemon/kotaemon/loaders/utils/adobe.py b/libs/kotaemon/kotaemon/loaders/utils/adobe.py new file mode 100644 index 000000000..f1adcd5a7 --- /dev/null +++ b/libs/kotaemon/kotaemon/loaders/utils/adobe.py @@ -0,0 +1,246 @@ +# need pip install pdfservices-sdk==2.3.0 + +import base64 +import json +import logging +import os +import tempfile +import zipfile +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List, Union + +import pandas as pd +from decouple import config + +from kotaemon.loaders.utils.gpt4v import generate_gpt4v + + +def request_adobe_service(file_path: str, output_path: str = "") -> str: + """Main function to call the adobe service, and unzip the results. + Args: + file_path (str): path to the pdf file + output_path (str): path to store the results + + Returns: + output_path (str): path to the results + + """ + try: + from adobe.pdfservices.operation.auth.credentials import Credentials + from adobe.pdfservices.operation.exception.exceptions import ( + SdkException, + ServiceApiException, + ServiceUsageException, + ) + from adobe.pdfservices.operation.execution_context import ExecutionContext + from adobe.pdfservices.operation.io.file_ref import FileRef + from adobe.pdfservices.operation.pdfops.extract_pdf_operation import ( + ExtractPDFOperation, + ) + from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_element_type import ( # noqa: E501 + ExtractElementType, + ) + from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_pdf_options import ( # noqa: E501 + ExtractPDFOptions, + ) + from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_renditions_element_type import ( # noqa: E501 + ExtractRenditionsElementType, + ) + except ImportError: + raise ImportError( + "pdfservices-sdk is not installed. " + "Please install it by running `pip install pdfservices-sdk" + "@git+https://github.com/niallcm/pdfservices-python-sdk.git" + "@bump-and-unfreeze-requirements`" + ) + + if not output_path: + output_path = tempfile.mkdtemp() + + try: + # Initial setup, create credentials instance. 
diff --git a/libs/kotaemon/kotaemon/loaders/utils/adobe.py b/libs/kotaemon/kotaemon/loaders/utils/adobe.py
new file mode 100644
index 000000000..f1adcd5a7
--- /dev/null
+++ b/libs/kotaemon/kotaemon/loaders/utils/adobe.py
@@ -0,0 +1,246 @@
+# need pip install pdfservices-sdk==2.3.0
+
+import base64
+import json
+import logging
+import os
+import tempfile
+import zipfile
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import List, Union
+
+import pandas as pd
+from decouple import config
+
+from kotaemon.loaders.utils.gpt4v import generate_gpt4v
+
+
+def request_adobe_service(file_path: str, output_path: str = "") -> str:
+    """Main function to call the Adobe service and unzip the results.
+
+    Args:
+        file_path (str): path to the pdf file
+        output_path (str): path to store the results
+
+    Returns:
+        output_path (str): path to the results
+
+    """
+    try:
+        from adobe.pdfservices.operation.auth.credentials import Credentials
+        from adobe.pdfservices.operation.exception.exceptions import (
+            SdkException,
+            ServiceApiException,
+            ServiceUsageException,
+        )
+        from adobe.pdfservices.operation.execution_context import ExecutionContext
+        from adobe.pdfservices.operation.io.file_ref import FileRef
+        from adobe.pdfservices.operation.pdfops.extract_pdf_operation import (
+            ExtractPDFOperation,
+        )
+        from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_element_type import (  # noqa: E501
+            ExtractElementType,
+        )
+        from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_pdf_options import (  # noqa: E501
+            ExtractPDFOptions,
+        )
+        from adobe.pdfservices.operation.pdfops.options.extractpdf.extract_renditions_element_type import (  # noqa: E501
+            ExtractRenditionsElementType,
+        )
+    except ImportError:
+        raise ImportError(
+            "pdfservices-sdk is not installed. "
+            "Please install it by running `pip install pdfservices-sdk"
+            "@git+https://github.com/niallcm/pdfservices-python-sdk.git"
+            "@bump-and-unfreeze-requirements`"
+        )
+
+    if not output_path:
+        output_path = tempfile.mkdtemp()
+
+    try:
+        # Initial setup, create credentials instance.
+        credentials = (
+            Credentials.service_principal_credentials_builder()
+            .with_client_id(config("PDF_SERVICES_CLIENT_ID", default=""))
+            .with_client_secret(config("PDF_SERVICES_CLIENT_SECRET", default=""))
+            .build()
+        )
+
+        # Create an ExecutionContext using credentials
+        # and create a new operation instance.
+        execution_context = ExecutionContext.create(credentials)
+        extract_pdf_operation = ExtractPDFOperation.create_new()
+
+        # Set operation input from a source file.
+        source = FileRef.create_from_local_file(file_path)
+        extract_pdf_operation.set_input(source)
+
+        # Build ExtractPDF options and set them into the operation
+        extract_pdf_options: ExtractPDFOptions = (
+            ExtractPDFOptions.builder()
+            .with_elements_to_extract(
+                [ExtractElementType.TEXT, ExtractElementType.TABLES]
+            )
+            .with_elements_to_extract_renditions(
+                [
+                    ExtractRenditionsElementType.TABLES,
+                    ExtractRenditionsElementType.FIGURES,
+                ]
+            )
+            .build()
+        )
+        extract_pdf_operation.set_options(extract_pdf_options)
+
+        # Execute the operation.
+        result: FileRef = extract_pdf_operation.execute(execution_context)
+
+        # Save the result to the specified location.
+        zip_file_path = os.path.join(
+            output_path, "ExtractTextTableWithFigureTableRendition.zip"
+        )
+        result.save_as(zip_file_path)
+        # Open the ZIP file
+        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
+            # Extract all contents to the destination folder
+            zip_ref.extractall(output_path)
+    except (ServiceApiException, ServiceUsageException, SdkException):
+        logging.exception("Exception encountered while executing operation")
+
+    return output_path
+
+
+def make_markdown_table(table_as_list: List[str]) -> str:
+    """
+    Convert a table from python list representation to markdown format.
+    The input list consists of rows of the table; the first row is the header.
+
+    Args:
+        table_as_list: list of table rows
+            Example: [["Name", "Age", "Height"],
+                      ["Jake", 20, 5'10],
+                      ["Mary", 21, 5'7]]
+
+    Returns:
+        markdown representation of the table
+    """
+    markdown = "\n" + str("| ")
+
+    for e in table_as_list[0]:
+        to_add = " " + str(e) + str(" |")
+        markdown += to_add
+    markdown += "\n"
+
+    markdown += "| "
+    for i in range(len(table_as_list[0])):
+        markdown += str("--- | ")
+    markdown += "\n"
+
+    for entry in table_as_list[1:]:
+        markdown += str("| ")
+        for e in entry:
+            to_add = str(e) + str(" | ")
+            markdown += to_add
+        markdown += "\n"
+
+    return markdown + "\n"
+
+
+def load_json(input_path: Union[str, Path]) -> dict:
+    """Load json file"""
+    with open(input_path, "r") as fi:
+        data = json.load(fi)
+
+    return data
+
+
+def load_excel(input_path: Union[str, Path]) -> str:
+    """Load excel file and convert to markdown"""
+
+    df = pd.read_excel(input_path).fillna("")
+    # Convert dataframe to a list of rows
+    row_list = [df.columns.values.tolist()] + df.values.tolist()
+
+    for item_id, item in enumerate(row_list[0]):
+        if "Unnamed" in item:
+            row_list[0][item_id] = ""
+
+    for row in row_list:
+        for item_id, item in enumerate(row):
+            row[item_id] = str(item).replace("_x000D_", " ").replace("\n", " ").strip()
+
+    markdown_str = make_markdown_table(row_list)
+    return markdown_str
+
+
+def encode_image_base64(image_path: Union[str, Path]) -> Union[bytes, str]:
+    """Convert image to base64"""
+
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode("utf-8")
+
+
+def parse_table_paths(file_paths: List[Path]) -> str:
+    """Read the table stored in an excel file given the file path"""
+
+    content = ""
+    for path in file_paths:
+        if path.suffix == ".xlsx":
+            content = load_excel(path)
+            break
+    return content
+
+
+def parse_figure_paths(file_paths: List[Path]) -> Union[bytes, str]:
+    """Read and convert an image to base64 given the image path"""
+
+    content = ""
+    for path in file_paths:
+        if path.suffix == ".png":
+            base64_image = encode_image_base64(path)
+            content = f"data:image/png;base64,{base64_image}"  # type: ignore
+            break
+    return content
+
+
+def generate_single_figure_caption(vlm_endpoint: str, figure: str) -> str:
+    """Summarize a single figure using GPT-4V"""
+    if figure:
+        output = generate_gpt4v(
+            endpoint=vlm_endpoint,
+            prompt="Provide a short 2 sentence summary of this image?",
+            images=figure,
+        )
+        if "sorry" in output.lower():
+            output = ""
+    else:
+        output = ""
+    return output
+
+
+def generate_figure_captions(
+    vlm_endpoint: str, figures: List, max_figures_to_process: int
+) -> List:
+    """Summarize several figures using GPT-4V.
+
+    Args:
+        vlm_endpoint (str): endpoint to the vision language model service
+        figures (List): list of base64 images
+        max_figures_to_process (int): the maximum number of figures to summarize;
+            the rest are ignored.
+
+    Returns:
+        results (List[str]): list of all figure captions, with empty strings for
+            ignored figures.
+    """
+    to_gen_figures = figures[:max_figures_to_process]
+    other_figures = figures[max_figures_to_process:]
+
+    with ThreadPoolExecutor() as executor:
+        # pass the loop variable as a submit() argument: a bare `lambda:`
+        # would capture `figure` late and caption the same image repeatedly
+        futures = [
+            executor.submit(generate_single_figure_caption, vlm_endpoint, figure)
+            for figure in to_gen_figures
+        ]
+
+        results = [future.result() for future in futures]
+    return results + [""] * len(other_figures)
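To make `make_markdown_table` above concrete, here is its behavior on the docstring's own example data (output spacing approximate):

```python
from kotaemon.loaders.utils.adobe import make_markdown_table

rows = [
    ["Name", "Age", "Height"],
    ["Jake", 20, "5'10"],
    ["Mary", 21, "5'7"],
]
print(make_markdown_table(rows))
# | Name | Age | Height |
# | --- | --- | --- |
# | Jake | 20 | 5'10 |
# | Mary | 21 | 5'7 |
```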
+ """ + to_gen_figures = figures[:max_figures_to_process] + other_figures = figures[max_figures_to_process:] + + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit( + lambda: generate_single_figure_caption(vlm_endpoint, figure) + ) + for figure in to_gen_figures + ] + + results = [future.result() for future in futures] + return results + [""] * len(other_figures) diff --git a/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py b/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py new file mode 100644 index 000000000..1e219d660 --- /dev/null +++ b/libs/kotaemon/kotaemon/loaders/utils/gpt4v.py @@ -0,0 +1,96 @@ +import json +from typing import Any, List + +import requests +from decouple import config + + +def generate_gpt4v( + endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512 +) -> str: + # OpenAI API Key + api_key = config("AZURE_OPENAI_API_KEY", default="") + headers = {"Content-Type": "application/json", "api-key": api_key} + + if isinstance(images, str): + images = [images] + + payload = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + ] + + [ + { + "type": "image_url", + "image_url": {"url": image}, + } + for image in images + ], + } + ], + "max_tokens": max_tokens, + } + + try: + response = requests.post(endpoint, headers=headers, json=payload) + output = response.json() + output = output["choices"][0]["message"]["content"] + except Exception: + output = "" + return output + + +def stream_gpt4v( + endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512 +) -> Any: + # OpenAI API Key + api_key = config("AZURE_OPENAI_API_KEY", default="") + headers = {"Content-Type": "application/json", "api-key": api_key} + + if isinstance(images, str): + images = [images] + + payload = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + ] + + [ + { + "type": "image_url", + "image_url": {"url": image}, + } + for image in images + ], + } + ], + "max_tokens": max_tokens, + "stream": True, + } + try: + response = requests.post(endpoint, headers=headers, json=payload, stream=True) + assert response.status_code == 200, str(response.content) + output = "" + for line in response.iter_lines(): + if line: + if line.startswith(b"\xef\xbb\xbf"): + line = line[9:] + else: + line = line[6:] + try: + if line == "[DONE]": + break + line = json.loads(line.decode("utf-8")) + except Exception: + break + if len(line["choices"]): + output += line["choices"][0]["delta"].get("content", "") + yield line["choices"][0]["delta"].get("content", "") + except Exception: + output = "" + return output diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml index e1e30280d..73c3e8ab7 100644 --- a/libs/kotaemon/pyproject.toml +++ b/libs/kotaemon/pyproject.toml @@ -60,6 +60,7 @@ adv = [ "cohere", "elasticsearch", "llama-cpp-python", + "pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements", ] dev = [ "ipython", @@ -69,6 +70,7 @@ dev = [ "flake8", "sphinx", "coverage", + "python-decouple" ] all = ["kotaemon[adv,dev]"] diff --git a/libs/kotaemon/tests/_test_multimodal_reader.py b/libs/kotaemon/tests/_test_multimodal_reader.py new file mode 100644 index 000000000..b07786f03 --- /dev/null +++ b/libs/kotaemon/tests/_test_multimodal_reader.py @@ -0,0 +1,21 @@ +# TODO: This test is broken and should be rewritten +from pathlib import Path + +from kotaemon.loaders import AdobeReader + +# from dotenv import load_dotenv + + +input_file = 
diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml
index e1e30280d..73c3e8ab7 100644
--- a/libs/kotaemon/pyproject.toml
+++ b/libs/kotaemon/pyproject.toml
@@ -60,6 +60,7 @@ adv = [
     "cohere",
     "elasticsearch",
     "llama-cpp-python",
+    "pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements",
 ]
 dev = [
     "ipython",
@@ -69,6 +70,7 @@ dev = [
     "flake8",
     "sphinx",
     "coverage",
+    "python-decouple"
 ]
 all = ["kotaemon[adv,dev]"]
diff --git a/libs/kotaemon/tests/_test_multimodal_reader.py b/libs/kotaemon/tests/_test_multimodal_reader.py
new file mode 100644
index 000000000..b07786f03
--- /dev/null
+++ b/libs/kotaemon/tests/_test_multimodal_reader.py
@@ -0,0 +1,21 @@
+# TODO: This test is broken and should be rewritten
+from pathlib import Path
+
+from kotaemon.loaders import AdobeReader
+
+# from dotenv import load_dotenv
+
+
+input_file = Path(__file__).parent / "resources" / "multimodal.pdf"
+
+# load_dotenv()
+
+
+def test_adobe_reader():
+    reader = AdobeReader()
+    documents = reader.load_data(input_file)
+    table_docs = [doc for doc in documents if doc.metadata.get("type", "") == "table"]
+    assert len(table_docs) == 2
+
+    figure_docs = [doc for doc in documents if doc.metadata.get("type", "") == "image"]
+    assert len(figure_docs) == 2
diff --git a/libs/kotaemon/tests/resources/multimodal.pdf b/libs/kotaemon/tests/resources/multimodal.pdf
new file mode 100644
index 000000000..29c2bdc96
Binary files /dev/null and b/libs/kotaemon/tests/resources/multimodal.pdf differ
diff --git a/libs/ktem/flowsettings.py b/libs/ktem/flowsettings.py
index 268d63fab..2e26cff1a 100644
--- a/libs/ktem/flowsettings.py
+++ b/libs/ktem/flowsettings.py
@@ -123,6 +123,11 @@
 
 KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
+KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
+    config("AZURE_OPENAI_ENDPOINT", default=""),
+    config("OPENAI_VISION_DEPLOYMENT_NAME", default="gpt-4-vision"),
+    config("OPENAI_API_VERSION", default=""),
+)
 
 
 SETTINGS_APP = {
diff --git a/libs/ktem/ktem/app.py b/libs/ktem/ktem/app.py
index 0a39fa60b..64e8a9da0 100644
--- a/libs/ktem/ktem/app.py
+++ b/libs/ktem/ktem/app.py
@@ -17,6 +17,7 @@ class BaseApp:
     The main application contains app-level information:
         - setting state
+        - dynamic conversation state
         - user id
 
     Also contains registering methods for:
@@ -228,7 +229,9 @@ def on_register_events(self):
     def _on_app_created(self):
         """Called when the app is created"""
 
-    def as_gradio_component(self) -> Optional[gr.components.Component]:
+    def as_gradio_component(
+        self,
+    ) -> Optional[gr.components.Component | list[gr.components.Component]]:
         """Return the gradio components responsible for events
 
         Note: in ideal scenario, this method shouldn't be necessary.
diff --git a/libs/ktem/ktem/index/base.py b/libs/ktem/ktem/index/base.py
index 50bdd9e44..518376264 100644
--- a/libs/ktem/ktem/index/base.py
+++ b/libs/ktem/ktem/index/base.py
@@ -1,6 +1,6 @@
 import abc
 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 if TYPE_CHECKING:
     from ktem.app import BasePage
@@ -57,7 +57,7 @@ def __init__(self, app, id, name, config):
         self._app = app
         self.id = id
         self.name = name
-        self._config = config  # admin settings
+        self.config = config  # admin settings
 
     def on_create(self):
         """Create the index for the first time"""
@@ -121,7 +121,7 @@ def get_indexing_pipeline(self, settings: dict) -> "BaseComponent":
         ...
 
     def get_retriever_pipelines(
-        self, settings: dict, selected: Optional[list]
+        self, settings: dict, selected: Any = None
     ) -> list["BaseComponent"]:
         """Return the retriever pipelines to retrieve the entity from the index"""
         return []
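For reference, with hypothetical values the `KH_VLM_ENDPOINT` template above renders as shown below. Note that `DEFAULT_VLM_ENDPOINT` in `adobe_loader.py` omits the slash before `openai/`, so it implicitly assumes `AZURE_OPENAI_ENDPOINT` carries a trailing slash:

```python
endpoint = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
    "https://my-resource.openai.azure.com",  # AZURE_OPENAI_ENDPOINT (hypothetical)
    "gpt-4-vision",                          # OPENAI_VISION_DEPLOYMENT_NAME
    "2023-12-01-preview",                    # OPENAI_API_VERSION (hypothetical)
)
# -> https://my-resource.openai.azure.com/openai/deployments/gpt-4-vision/
#    chat/completions?api-version=2023-12-01-preview
```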
diff --git a/libs/ktem/ktem/index/file/base.py b/libs/ktem/ktem/index/file/base.py
index 5f8e6f4fa..4f28f51ac 100644
--- a/libs/ktem/ktem/index/file/base.py
+++ b/libs/ktem/ktem/index/file/base.py
@@ -127,3 +127,11 @@ def get_filestorage_path(self, rel_paths: str | list[str]) -> list[str]:
             the absolute file storage path to the file
         """
         raise NotImplementedError
+
+    def warning(self, msg):
+        """Log a warning message
+
+        Args:
+            msg: the message to log
+        """
+        print(msg)
diff --git a/libs/ktem/ktem/index/file/index.py b/libs/ktem/ktem/index/file/index.py
index ab1f35a3a..5fe395596 100644
--- a/libs/ktem/ktem/index/file/index.py
+++ b/libs/ktem/ktem/index/file/index.py
@@ -13,7 +13,6 @@
 from kotaemon.storages import BaseDocumentStore, BaseVectorStore
 
 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
-from .ui import FileIndexPage, FileSelector
 
 
 class FileIndex(BaseIndex):
@@ -77,9 +76,15 @@ def __init__(self, app, id: int, name: str, config: dict):
 
         self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
         self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
+        self._selector_ui_cls: Type
+        self._selector_ui: Any = None
+        self._index_ui_cls: Type
+        self._index_ui: Any = None
 
         self._setup_indexing_cls()
         self._setup_retriever_cls()
+        self._setup_file_index_ui_cls()
+        self._setup_file_selector_ui_cls()
 
         self._default_settings: dict[str, dict] = {}
         self._setting_mappings: dict[str, dict] = {}
@@ -91,14 +96,14 @@ def _setup_indexing_cls(self):
 
         The indexing class is retrieved in the following order, stopping at
         the first one found:
-            - `FILE_INDEX_PIPELINE` in self._config
+            - `FILE_INDEX_PIPELINE` in self.config
             - `FILE_INDEX_{id}_PIPELINE` in the flowsettings
             - `FILE_INDEX_PIPELINE` in the flowsettings
             - The default .pipelines.IndexDocumentPipeline
         """
-        if "FILE_INDEX_PIPELINE" in self._config:
+        if "FILE_INDEX_PIPELINE" in self.config:
             self._indexing_pipeline_cls = import_dotted_string(
-                self._config["FILE_INDEX_PIPELINE"], safe=False
+                self.config["FILE_INDEX_PIPELINE"], safe=False
             )
             return
 
@@ -125,15 +130,15 @@ def _setup_retriever_cls(self):
 
         The retriever classes are retrieved in the following order, stopping
         at the first one found:
-            - `FILE_INDEX_RETRIEVER_PIPELINES` in self._config
+            - `FILE_INDEX_RETRIEVER_PIPELINES` in self.config
             - `FILE_INDEX_{id}_RETRIEVER_PIPELINES` in the flowsettings
             - `FILE_INDEX_RETRIEVER_PIPELINES` in the flowsettings
             - The default .pipelines.DocumentRetrievalPipeline
         """
-        if "FILE_INDEX_RETRIEVER_PIPELINES" in self._config:
+        if "FILE_INDEX_RETRIEVER_PIPELINES" in self.config:
             self._retriever_pipeline_cls = [
                 import_dotted_string(each, safe=False)
-                for each in self._config["FILE_INDEX_RETRIEVER_PIPELINES"]
+                for each in self.config["FILE_INDEX_RETRIEVER_PIPELINES"]
             ]
             return
 
@@ -157,6 +162,76 @@ def _setup_retriever_cls(self):
 
         self._retriever_pipeline_cls = [DocumentRetrievalPipeline]
 
+    def _setup_file_selector_ui_cls(self):
+        """Retrieve the file selector UI class for the file index
+
+        The selector UI class is retrieved in the following order, stopping
+        at the first one found:
+            - `FILE_INDEX_SELECTOR_UI` in self.config
+            - `FILE_INDEX_{id}_SELECTOR_UI` in the flowsettings
+            - `FILE_INDEX_SELECTOR_UI` in the flowsettings
+            - The default .ui.FileSelector
+        """
+        if "FILE_INDEX_SELECTOR_UI" in self.config:
+            self._selector_ui_cls = import_dotted_string(
+                self.config["FILE_INDEX_SELECTOR_UI"], safe=False
+            )
+            return
+
+        if hasattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"):
+            self._selector_ui_cls = import_dotted_string(
+                getattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"),
+                safe=False,
+            )
+            return
+
+        if hasattr(flowsettings, "FILE_INDEX_SELECTOR_UI"):
+            self._selector_ui_cls = import_dotted_string(
+                getattr(flowsettings, "FILE_INDEX_SELECTOR_UI"), safe=False
+            )
+            return
+
+        from .ui import FileSelector
+
+        self._selector_ui_cls = FileSelector
+
+    def _setup_file_index_ui_cls(self):
+        """Retrieve the index UI class
+
+        The index UI class is retrieved in the following order, stopping at
+        the first one found:
+            - `FILE_INDEX_UI` in self.config
+            - `FILE_INDEX_{id}_UI` in the flowsettings
+            - `FILE_INDEX_UI` in the flowsettings
+            - The default .ui.FileIndexPage
+        """
+        if "FILE_INDEX_UI" in self.config:
+            self._index_ui_cls = import_dotted_string(
+                self.config["FILE_INDEX_UI"], safe=False
+            )
+            return
+
+        if hasattr(flowsettings, f"FILE_INDEX_{self.id}_UI"):
+            self._index_ui_cls = import_dotted_string(
+                getattr(flowsettings, f"FILE_INDEX_{self.id}_UI"),
+                safe=False,
+            )
+            return
+
+        if hasattr(flowsettings, "FILE_INDEX_UI"):
+            self._index_ui_cls = import_dotted_string(
+                getattr(flowsettings, "FILE_INDEX_UI"), safe=False
+            )
+            return
+
+        from .ui import FileIndexPage
+
+        self._index_ui_cls = FileIndexPage
+
     def on_create(self):
         """Create the index for the first time
 
@@ -165,6 +240,13 @@ def on_create(self):
         2. Create the vectorstore
         3. Create the docstore
         """
+        file_types_str = self.config.get(
+            "supported_file_types",
+            self.get_admin_settings()["supported_file_types"]["value"],
+        )
+        file_types = [each.strip() for each in file_types_str.split(",")]
+        self.config["supported_file_types"] = file_types
+
         self._resources["Source"].metadata.create_all(engine)  # type: ignore
         self._resources["Index"].metadata.create_all(engine)  # type: ignore
         self._fs_path.mkdir(parents=True, exist_ok=True)
@@ -180,10 +262,14 @@ def on_delete(self):
         shutil.rmtree(self._fs_path)
 
     def get_selector_component_ui(self):
-        return FileSelector(self._app, self)
+        if self._selector_ui is None:
+            self._selector_ui = self._selector_ui_cls(self._app, self)
+        return self._selector_ui
 
     def get_index_page_ui(self):
-        return FileIndexPage(self._app, self)
+        if self._index_ui is None:
+            self._index_ui = self._index_ui_cls(self._app, self)
+        return self._index_ui
 
     def get_user_settings(self):
         if self._default_settings:
@@ -210,7 +296,31 @@ def get_admin_settings(cls):
                 "value": embedding_default,
                 "component": "dropdown",
                 "choices": embedding_choices,
-            }
+            },
+            "supported_file_types": {
+                "name": "Supported file types",
+                "value": (
+                    "image, .pdf, .txt, .csv, .xlsx, .doc, .docx, .pptx, .html, .zip"
+                ),
+                "component": "text",
+            },
+            "max_file_size": {
+                "name": "Max file size (MB) - set 0 to disable",
+                "value": 1000,
+                "component": "number",
+            },
+            "max_number_of_files": {
+                "name": "Max number of files that can be indexed - set 0 to disable",
+                "value": 0,
+                "component": "number",
+            },
+            "max_number_of_text_length": {
+                "name": (
+                    "Max amount of characters that can be indexed - set 0 to disable"
+                ),
+                "value": 0,
+                "component": "number",
+            },
         }
 
     def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing:
@@ -224,14 +334,15 @@ def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing:
             else:
                 stripped_settings[key] = value
 
-        obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self._config)
+        obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
         obj.set_resources(resources=self._resources)
 
         return obj
 
     def get_retriever_pipelines(
-        self, settings: dict, selected: Optional[list] = None
+        self, settings: dict, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
+        # retrieval settings
         prefix = f"index.options.{self.id}."
         stripped_settings = {}
         for key, value in settings.items():
@@ -240,9 +351,12 @@ def get_retriever_pipelines(
             else:
                 stripped_settings[key] = value
 
+        # transform selected id
+        selected_ids: Optional[list[str]] = self._selector_ui.get_selected_ids(selected)
+
         retrievers = []
         for cls in self._retriever_pipeline_cls:
-            obj = cls.get_pipeline(stripped_settings, self._config, selected)
+            obj = cls.get_pipeline(stripped_settings, self.config, selected_ids)
             if obj is None:
                 continue
             obj.set_resources(self._resources)
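The `on_create` hook above normalizes the admin setting from a comma-separated string into a list; reduced to its essentials:

```python
file_types_str = "image, .pdf, .txt, .csv"  # value from the admin settings
file_types = [each.strip() for each in file_types_str.split(",")]
assert file_types == ["image", ".pdf", ".txt", ".csv"]
```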
diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py
index c6c7accd5..13036f33f 100644
--- a/libs/ktem/ktem/index/file/pipelines.py
+++ b/libs/ktem/ktem/index/file/pipelines.py
@@ -9,6 +9,7 @@
 from pathlib import Path
 from typing import Optional
 
+import gradio as gr
 from ktem.components import embeddings, filestorage_path
 from ktem.db.models import engine
 from llama_index.vector_stores import (
@@ -18,7 +19,7 @@
     MetadataFilters,
 )
 from llama_index.vector_stores.types import VectorStoreQueryMode
-from sqlalchemy import select
+from sqlalchemy import delete, select
 from sqlalchemy.orm import Session
 from theflow.settings import settings
 from theflow.utils.modules import import_dotted_string
@@ -279,6 +280,7 @@ def run(
         to_index: list[str] = []
         file_to_hash: dict[str, str] = {}
         errors = []
+        to_update = []
 
         for file_path in file_paths:
             abs_path = str(Path(file_path).resolve())
@@ -291,16 +293,26 @@ def run(
                 statement = select(Source).where(Source.name == Path(abs_path).name)
                 item = session.execute(statement).first()
 
-            if item and not reindex:
-                errors.append(Path(abs_path).name)
-                continue
+            if item:
+                if not reindex:
+                    errors.append(Path(abs_path).name)
+                    continue
+                else:
+                    to_update.append(Path(abs_path).name)
 
             to_index.append(abs_path)
 
         if errors:
+            error_files = ", ".join(errors)
+            if len(error_files) > 100:
+                error_files = error_files[:80] + "..."
             print(
-                "Files already exist. Please rename/remove them or enable reindex.\n"
-                f"{errors}"
+                "Skipping files that already exist. Please rename/remove them or "
+                f"enable reindex:\n{errors}"
+            )
+            self.warning(
+                "Skipping files that already exist. Please rename/remove them or "
+                f"enable reindex:\n{error_files}"
             )
 
         if not to_index:
@@ -310,9 +322,19 @@ def run(
         for path in to_index:
             shutil.copy(path, filestorage_path / file_to_hash[path])
 
-        # prepare record info
+        # extract the file & prepare record info
         file_to_source: dict = {}
+        extraction_errors = []
+        nodes = []
         for file_path, file_hash in file_to_hash.items():
+            if str(Path(file_path).resolve()) not in to_index:
+                continue
+
+            extraction_result = self.file_ingestor(file_path)
+            if not extraction_result:
+                extraction_errors.append(Path(file_path).name)
+                continue
+            nodes.extend(extraction_result)
             source = Source(
                 name=Path(file_path).name,
                 path=file_hash,
@@ -320,9 +342,23 @@ def run(
             )
             file_to_source[file_path] = source
 
-        # extract the files
-        nodes = self.file_ingestor(to_index)
-        print("Extracted", len(to_index), "files into", len(nodes), "nodes")
+        if extraction_errors:
+            msg = "Failed to extract these files: {}".format(
+                ", ".join(extraction_errors)
+            )
+            print(msg)
+            self.warning(msg)
+
+        if not nodes:
+            return [], []
+
+        print(
+            "Extracted",
+            len(to_index) - len(extraction_errors),
+            "files into",
+            len(nodes),
+            "nodes",
+        )
 
         # index the files
         print("Indexing the files into vector store")
@@ -332,7 +368,11 @@ def run(
         # persist to the index
         print("Persisting the vector and the document into index")
         file_ids = []
+        to_update = list(set(to_update))
         with Session(engine) as session:
+            if to_update:
+                session.execute(delete(Source).where(Source.name.in_(to_update)))
+
             for source in file_to_source.values():
                 session.add(source)
             session.commit()
@@ -378,6 +418,7 @@ def get_user_settings(cls) -> dict:
                     ("PDF text parser", "normal"),
                     ("Mathpix", "mathpix"),
                     ("Advanced ocr", "ocr"),
+                    ("Multimodal parser", "multimodal"),
                 ],
                 "component": "dropdown",
             },
@@ -403,3 +444,6 @@ def set_resources(self, resources: dict):
         super().set_resources(resources)
         self.indexing_vector_pipeline.vector_store = self._VS
         self.indexing_vector_pipeline.doc_store = self._DS
+
+    def warning(self, msg):
+        gr.Warning(msg)
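The `warning()` indirection used here follows a simple override pattern: the base indexing class logs to stdout so pipelines stay usable headless, while this UI-facing pipeline swaps in `gr.Warning`. The essentials, as a sketch (class names are illustrative, not from the patch):

```python
import gradio as gr

class HeadlessIndexing:
    def warning(self, msg: str) -> None:
        # default: plain console logging, no UI dependency
        print(msg)

class UIIndexing(HeadlessIndexing):
    def warning(self, msg: str) -> None:
        # surface the same message as a toast in the Gradio UI
        gr.Warning(msg)
```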
diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py
index 9da2b4a14..11d491f03 100644
--- a/libs/ktem/ktem/index/file/ui.py
+++ b/libs/ktem/ktem/index/file/ui.py
@@ -1,29 +1,48 @@
 import os
 import tempfile
+from pathlib import Path
 
 import gradio as gr
 import pandas as pd
+from gradio.data_classes import FileData
+from gradio.utils import NamedString
 from ktem.app import BasePage
 from ktem.db.engine import engine
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 
+class File(gr.File):
+    """Subclass of gr.File to maintain the original filename
+
+    The issue happens when a user uploads a file with a name like: !@#$%%^&*().pdf
+    """
+
+    def _process_single_file(self, f: FileData) -> NamedString | bytes:
+        file_name = f.path
+        if self.type == "filepath":
+            if f.orig_name and Path(file_name).name != f.orig_name:
+                file_name = str(Path(file_name).parent / f.orig_name)
+                os.rename(f.path, file_name)
+            file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
+            file.name = file_name
+            return NamedString(file_name)
+        elif self.type == "binary":
+            with open(file_name, "rb") as file_data:
+                return file_data.read()
+        else:
+            raise ValueError(
+                "Unknown type: "
+                + str(self.type)
+                + ". Please choose from: 'filepath', 'binary'."
+            )
+
+
 class DirectoryUpload(BasePage):
-    def __init__(self, app):
-        self._app = app
-        self._supported_file_types = [
-            "image",
-            ".pdf",
-            ".txt",
-            ".csv",
-            ".xlsx",
-            ".doc",
-            ".docx",
-            ".pptx",
-            ".html",
-            ".zip",
-        ]
+    def __init__(self, app, index):
+        super().__init__(app)
+        self._index = index
+        self._supported_file_types = self._index.config.get("supported_file_types", [])
         self.on_building_ui()
 
     def on_building_ui(self):
@@ -50,18 +69,7 @@ class FileIndexPage(BasePage):
     def __init__(self, app, index):
         super().__init__(app)
         self._index = index
-        self._supported_file_types = [
-            "image",
-            ".pdf",
-            ".txt",
-            ".csv",
-            ".xlsx",
-            ".doc",
-            ".docx",
-            ".pptx",
-            ".html",
-            ".zip",
-        ]
+        self._supported_file_types = self._index.config.get("supported_file_types", [])
         self.selected_panel_false = "Selected file: (please select above)"
         self.selected_panel_true = "Selected file: {name}"
         # TODO: on_building_ui is not correctly named if it's always called in
@@ -69,13 +77,32 @@ def __init__(self, app, index):
         self.public_events = [f"onFileIndex{index.id}Changed"]
         self.on_building_ui()
 
+    def upload_instruction(self) -> str:
+        msgs = []
+        if self._supported_file_types:
+            msgs.append(
+                f"- Supported file types: {', '.join(self._supported_file_types)}"
+            )
+
+        if max_file_size := self._index.config.get("max_file_size", 0):
+            msgs.append(f"- Maximum file size: {max_file_size} MB")
+
+        if max_number_of_files := self._index.config.get("max_number_of_files", 0):
+            msgs.append(f"- The index can have at most {max_number_of_files} files")
+
+        if msgs:
+            return "\n".join(msgs)
+
+        return ""
+
     def on_building_ui(self):
         """Build the UI of the app"""
-        with gr.Accordion(label="File upload", open=False):
-            gr.Markdown(
-                f"Supported file types: {', '.join(self._supported_file_types)}",
-            )
-            self.files = gr.File(
+        with gr.Accordion(label="File upload", open=True) as self.upload:
+            msg = self.upload_instruction()
+            if msg:
+                gr.Markdown(msg)
+
+            self.files = File(
                 file_types=self._supported_file_types,
                 file_count="multiple",
                 container=False,
@@ -98,18 +125,20 @@ def on_building_ui(self):
                 interactive=False,
             )
 
-            with gr.Row():
+            with gr.Row() as self.selection_info:
                 self.selected_file_id = gr.State(value=None)
                 self.selected_panel = gr.Markdown(self.selected_panel_false)
                 self.deselect_button = gr.Button("Deselect", visible=False)
 
-            with gr.Row():
+            with gr.Row() as self.tools:
                 with gr.Column():
                     self.view_button = gr.Button("View Text (WIP)")
                 with gr.Column():
                     self.delete_button = gr.Button("Delete")
                     with gr.Row():
-                        self.delete_yes = gr.Button("Confirm Delete", visible=False)
+                        self.delete_yes = gr.Button(
+                            "Confirm Delete", variant="primary", visible=False
+                        )
                         self.delete_no = gr.Button("Cancel", visible=False)
 
     def on_subscribe_public_events(self):
@@ -242,10 +271,12 @@ def on_register_events(self):
                 self._app.settings_state,
             ],
             outputs=[self.file_output],
+            concurrency_limit=20,
         ).then(
             fn=self.list_file,
             inputs=None,
             outputs=[self.file_list_state, self.file_list],
+            concurrency_limit=20,
         )
         for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
             onUploaded = onUploaded.then(**event)
@@ -274,6 +305,15 @@ def index_fn(self, files, reindex: bool, settings):
             selected_files: the list of files already selected
             settings: the settings of the app
         """
+        if not files:
+            gr.Info("No uploaded file")
+            return gr.update()
+
+        errors = self.validate(files)
+        if errors:
+            gr.Warning(", ".join(errors))
+            return gr.update()
+
         gr.Info(f"Start indexing {len(files)} files...")
 
         # get the pipeline
@@ -409,6 +449,35 @@ def interact_file_list(self, list_files, ev: gr.SelectData):
             name=list_files["name"][ev.index[0]]
         )
 
+    def validate(self, files: list[str]):
+        """Validate if the files are valid"""
+        paths = [Path(file) for file in files]
+        errors = []
+        if max_file_size := self._index.config.get("max_file_size", 0):
+            errors_max_size = []
+            for path in paths:
+                if path.stat().st_size > max_file_size * 1e6:
+                    errors_max_size.append(path.name)
+            if errors_max_size:
+                str_errors = ", ".join(errors_max_size)
+                if len(str_errors) > 60:
+                    str_errors = str_errors[:55] + "..."
+                errors.append(
+                    f"Maximum file size ({max_file_size} MB) exceeded: {str_errors}"
+                )
+
+        if max_number_of_files := self._index.config.get("max_number_of_files", 0):
+            with Session(engine) as session:
+                current_num_files = session.query(
+                    self._index._db_tables["Source"].id
+                ).count()
+            if len(paths) + current_num_files > max_number_of_files:
+                errors.append(
+                    f"Maximum number of files ({max_number_of_files}) will be exceeded"
+                )
+
+        return errors
+
 
 class FileSelector(BasePage):
     """File selector UI in the Chat page"""
@@ -430,6 +499,9 @@ def on_building_ui(self):
     def as_gradio_component(self):
         return self.selector
 
+    def get_selected_ids(self, selected):
+        return selected
+
     def load_files(self, selected_files):
         options = []
         available_ids = []
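`validate()` measures sizes in decimal megabytes (1 MB = 1e6 bytes), so a 1000 MB limit admits files up to 1,000,000,000 bytes. A quick check with made-up numbers:

```python
max_file_size = 1000             # MB, from the index admin settings
file_size_bytes = 1_250_000_000  # hypothetical upload
print(file_size_bytes > max_file_size * 1e6)  # True -> rejected
```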
diff --git a/libs/ktem/ktem/index/manager.py b/libs/ktem/ktem/index/manager.py
index 72c4f9948..af1c8d484 100644
--- a/libs/ktem/ktem/index/manager.py
+++ b/libs/ktem/ktem/index/manager.py
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import Optional, Type
 
 from ktem.db.models import engine
 from sqlmodel import Session, select
@@ -49,15 +49,19 @@ def build_index(self, name: str, config: dict, index_type: str, id=None):
         Returns:
             BaseIndex: the index object
         """
+        index_cls = import_dotted_string(index_type, safe=False)
+        index = index_cls(app=self._app, id=id, name=name, config=config)
+        index.on_create()
+
         with Session(engine) as session:
-            index_entry = Index(id=id, name=name, config=config, index_type=index_type)
+            index_entry = Index(
+                id=index.id, name=index.name, config=index.config, index_type=index_type
+            )
             session.add(index_entry)
             session.commit()
             session.refresh(index_entry)
 
-        index_cls = import_dotted_string(index_type, safe=False)
-        index = index_cls(app=self._app, id=id, name=name, config=config)
-        index.on_create()
+        index.id = index_entry.id
 
         return index
 
@@ -77,7 +81,7 @@ def start_index(self, id: int, name: str, config: dict, index_type: str):
         self._indices.append(index)
         return index
 
-    def exists(self, id: int) -> bool:
+    def exists(self, id: Optional[int] = None, name: Optional[str] = None) -> bool:
         """Check if the index exists
 
         Args:
@@ -86,9 +90,19 @@ def exists(self, id: int) -> bool:
         Returns:
             bool: True if the index exists, False otherwise
         """
-        with Session(engine) as session:
-            index = session.get(Index, id)
-            return index is not None
+        if id:
+            with Session(engine) as session:
+                index = session.get(Index, id)
+                return index is not None
+
+        if name:
+            with Session(engine) as session:
+                index = session.exec(
+                    select(Index).where(Index.name == name)
+                ).one_or_none()
+                return index is not None
+
+        return False
 
     def on_application_startup(self):
         """This method is called by the base application when the application starts
diff --git a/libs/ktem/ktem/main.py b/libs/ktem/ktem/main.py
index c375ed7ed..1d76d0498 100644
--- a/libs/ktem/ktem/main.py
+++ b/libs/ktem/ktem/main.py
@@ -27,7 +27,7 @@ def ui(self):
         if self.f_user_management:
             from ktem.pages.login import LoginPage
 
-            with gr.Tab("Login", elem_id="login-tab") as self._tabs["login-tab"]:
+            with gr.Tab("Welcome", elem_id="login-tab") as self._tabs["login-tab"]:
                 self.login_page = LoginPage(self)
 
         with gr.Tab(
@@ -62,6 +62,9 @@ def ui(self):
 
     def on_subscribe_public_events(self):
         if self.f_user_management:
+            from ktem.db.engine import engine
+            from ktem.db.models import User
+            from sqlmodel import Session, select
 
             def signed_in_out(user_id):
                 if not user_id:
@@ -73,14 +76,31 @@ def signed_in_out(user_id):
                         )
                         for k in self._tabs.keys()
                     )
-                return list(
-                    (
-                        gr.update(visible=True)
-                        if k != "login-tab"
-                        else gr.update(visible=False)
-                    )
-                    for k in self._tabs.keys()
-                )
+
+                with Session(engine) as session:
+                    user = session.exec(select(User).where(User.id == user_id)).first()
+                    if user is None:
+                        return list(
+                            (
+                                gr.update(visible=True)
+                                if k == "login-tab"
+                                else gr.update(visible=False)
+                            )
+                            for k in self._tabs.keys()
+                        )
+
+                    is_admin = user.admin
+
+                tabs_update = []
+                for k in self._tabs.keys():
+                    if k == "login-tab":
+                        tabs_update.append(gr.update(visible=False))
+                    elif k == "admin-tab":
+                        tabs_update.append(gr.update(visible=is_admin))
+                    else:
+                        tabs_update.append(gr.update(visible=True))
+
+                return tabs_update
 
             self.subscribe_event(
                 name="onSignIn",
list") - self.btn_list_user = gr.Button("Refresh user list") - self.state_user_list = gr.State(value=None) - self.user_list = gr.DataFrame( - headers=["id", "name", "admin"], - interactive=False, - ) - - with gr.Row(): - self.selected_user_id = gr.State(value=None) - self.selected_panel = gr.Markdown(self.selected_panel_false) - self.deselect_button = gr.Button("Deselect", visible=False) - - with gr.Group(): - self.btn_delete = gr.Button("Delete user") - with gr.Row(): - self.btn_delete_yes = gr.Button("Confirm", visible=False) - self.btn_delete_no = gr.Button("Cancel", visible=False) - - gr.Markdown("## User details") - self.usn_edit = gr.Textbox(label="Username") - self.pwd_edit = gr.Textbox(label="Password", type="password") - self.pwd_cnf_edit = gr.Textbox(label="Confirm password", type="password") - self.admin_edit = gr.Checkbox(label="Admin") - self.btn_edit_save = gr.Button("Save") - def on_register_events(self): self.btn_new.click( self.create_user, inputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new], - outputs=None, - ) - self.btn_list_user.click( - self.list_users, inputs=None, outputs=[self.state_user_list, self.user_list] + outputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new], + ).then( + self.list_users, + inputs=self._app.user_id, + outputs=[self.state_user_list, self.user_list], ) self.user_list.select( self.select_user, inputs=self.user_list, - outputs=[self.selected_user_id, self.selected_panel], + outputs=[self.selected_user_id], show_progress="hidden", ) - self.selected_panel.change( + self.selected_user_id.change( self.on_selected_user_change, inputs=[self.selected_user_id], outputs=[ - self.deselect_button, + self._selected_panel, + self._selected_panel_btn, # delete section self.btn_delete, self.btn_delete_yes, @@ -197,12 +202,6 @@ def on_register_events(self): ], show_progress="hidden", ) - self.deselect_button.click( - lambda: (None, self.selected_panel_false), - inputs=None, - outputs=[self.selected_user_id, self.selected_panel], - show_progress="hidden", - ) self.btn_delete.click( self.on_btn_delete_click, inputs=[self.selected_user_id], @@ -211,9 +210,13 @@ def on_register_events(self): ) self.btn_delete_yes.click( self.delete_user, - inputs=[self.selected_user_id], - outputs=[self.selected_user_id, self.selected_panel], + inputs=[self._app.user_id, self.selected_user_id], + outputs=[self.selected_user_id], show_progress="hidden", + ).then( + self.list_users, + inputs=self._app.user_id, + outputs=[self.state_user_list, self.user_list], ) self.btn_delete_no.click( lambda: ( @@ -234,21 +237,53 @@ def on_register_events(self): self.pwd_cnf_edit, self.admin_edit, ], - outputs=None, + outputs=[self.pwd_edit, self.pwd_cnf_edit], show_progress="hidden", + ).then( + self.list_users, + inputs=self._app.user_id, + outputs=[self.state_user_list, self.user_list], + ) + self.btn_close.click( + lambda: -1, + outputs=[self.selected_user_id], + ) + + def on_subscribe_public_events(self): + self._app.subscribe_event( + name="onSignIn", + definition={ + "fn": self.list_users, + "inputs": [self._app.user_id], + "outputs": [self.state_user_list, self.user_list], + }, + ) + self._app.subscribe_event( + name="onSignOut", + definition={ + "fn": lambda: ("", "", "", None, None, -1), + "outputs": [ + self.usn_new, + self.pwd_new, + self.pwd_cnf_new, + self.state_user_list, + self.user_list, + self.selected_user_id, + ], + }, ) def create_user(self, usn, pwd, pwd_cnf): errors = validate_username(usn) if errors: gr.Warning(errors) - return + return usn, pwd, pwd_cnf errors = 
validate_password(pwd, pwd_cnf) print(errors) if errors: gr.Warning(errors) - return + return usn, pwd, pwd_cnf with Session(engine) as session: statement = select(User).where(User.username_lower == usn.lower()) @@ -265,8 +300,22 @@ def create_user(self, usn, pwd, pwd_cnf): session.commit() gr.Info(f'User "{usn}" created successfully') - def list_users(self): + return "", "", "" + + def list_users(self, user_id): + if user_id is None: + return [], pd.DataFrame.from_records( + [{"id": "-", "username": "-", "admin": "-"}] + ) + with Session(engine) as session: + statement = select(User).where(User.id == user_id) + user = session.exec(statement).one() + if not user.admin: + return [], pd.DataFrame.from_records( + [{"id": "-", "username": "-", "admin": "-"}] + ) + statement = select(User) results = [ {"id": user.id, "username": user.username, "admin": user.admin} @@ -284,18 +333,17 @@ def list_users(self): def select_user(self, user_list, ev: gr.SelectData): if ev.value == "-" and ev.index[0] == 0: gr.Info("No user is loaded. Please refresh the user list") - return None, self.selected_panel_false + return -1 if not ev.selected: - return None, self.selected_panel_false + return -1 - return user_list["id"][ev.index[0]], self.selected_panel_true.format( - name=user_list["username"][ev.index[0]] - ) + return user_list["id"][ev.index[0]] def on_selected_user_change(self, selected_user_id): - if selected_user_id is None: - deselect_button = gr.update(visible=False) + if selected_user_id == -1: + _selected_panel = gr.update(visible=False) + _selected_panel_btn = gr.update(visible=False) btn_delete = gr.update(visible=True) btn_delete_yes = gr.update(visible=False) btn_delete_no = gr.update(visible=False) @@ -304,7 +352,8 @@ def on_selected_user_change(self, selected_user_id): pwd_cnf_edit = gr.update(value="") admin_edit = gr.update(value=False) else: - deselect_button = gr.update(visible=True) + _selected_panel = gr.update(visible=True) + _selected_panel_btn = gr.update(visible=True) btn_delete = gr.update(visible=True) btn_delete_yes = gr.update(visible=False) btn_delete_no = gr.update(visible=False) @@ -319,7 +368,8 @@ def on_selected_user_change(self, selected_user_id): admin_edit = gr.update(value=user.admin) return ( - deselect_button, + _selected_panel, + _selected_panel_btn, btn_delete, btn_delete_yes, btn_delete_no, @@ -344,17 +394,16 @@ def on_btn_delete_click(self, selected_user_id): return btn_delete, btn_delete_yes, btn_delete_no def save_user(self, selected_user_id, usn, pwd, pwd_cnf, admin): - if usn: - errors = validate_username(usn) - if errors: - gr.Warning(errors) - return + errors = validate_username(usn) + if errors: + gr.Warning(errors) + return pwd, pwd_cnf if pwd: errors = validate_password(pwd, pwd_cnf) if errors: gr.Warning(errors) - return + return pwd, pwd_cnf with Session(engine) as session: statement = select(User).where(User.id == int(selected_user_id)) @@ -367,11 +416,17 @@ def save_user(self, selected_user_id, usn, pwd, pwd_cnf, admin): session.commit() gr.Info(f'User "{usn}" updated successfully') - def delete_user(self, selected_user_id): + return "", "" + + def delete_user(self, current_user, selected_user_id): + if current_user == selected_user_id: + gr.Warning("You cannot delete yourself") + return selected_user_id + with Session(engine) as session: statement = select(User).where(User.id == int(selected_user_id)) user = session.exec(statement).one() session.delete(user) session.commit() gr.Info(f'User "{user.username}" deleted successfully') - return None, 
self.selected_panel_false + return -1 diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index 6648c2fbc..a83bd837d 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -7,8 +7,11 @@ from ktem.components import reasonings from ktem.db.models import Conversation, engine from sqlmodel import Session, select +from theflow.settings import settings as flowsettings from .chat_panel import ChatPanel +from .chat_suggestion import ChatSuggestion +from .common import STATE from .control import ConversationControl from .report import ReportIssue @@ -21,27 +24,43 @@ def __init__(self, app): def on_building_ui(self): with gr.Row(): + self.chat_state = gr.State(STATE) with gr.Column(scale=1): self.chat_control = ConversationControl(self._app) + if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False): + self.chat_suggestion = ChatSuggestion(self._app) + for index in self._app.index_manager.indices: - index.selector = -1 + index.selector = None index_ui = index.get_selector_component_ui() if not index_ui: + # the index doesn't have a selector UI component continue - index_ui.unrender() + index_ui.unrender() # need to rerender later within Accordion with gr.Accordion(label=f"{index.name} Index", open=False): index_ui.render() gr_index = index_ui.as_gradio_component() if gr_index: - index.selector = len(self._indices_input) - self._indices_input.append(gr_index) + if isinstance(gr_index, list): + index.selector = tuple( + range( + len(self._indices_input), + len(self._indices_input) + len(gr_index), + ) + ) + self._indices_input.extend(gr_index) + else: + index.selector = len(self._indices_input) + self._indices_input.append(gr_index) setattr(self, f"_index_{index.id}", index_ui) self.report_issue = ReportIssue(self._app) + with gr.Column(scale=6): self.chat_panel = ChatPanel(self._app) + with gr.Column(scale=3): with gr.Accordion(label="Information panel", open=True): self.info_panel = gr.HTML(elem_id="chat-info-panel") @@ -52,32 +71,77 @@ def on_register_events(self): self.chat_panel.text_input.submit, self.chat_panel.submit_btn.click, ], - fn=self.chat_panel.submit_msg, - inputs=[self.chat_panel.text_input, self.chat_panel.chatbot], - outputs=[self.chat_panel.text_input, self.chat_panel.chatbot], + fn=self.submit_msg, + inputs=[ + self.chat_panel.text_input, + self.chat_panel.chatbot, + self._app.user_id, + self.chat_control.conversation_id, + self.chat_control.conversation_rn, + ], + outputs=[ + self.chat_panel.text_input, + self.chat_panel.chatbot, + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + ], + concurrency_limit=20, show_progress="hidden", - ).then( + ).success( fn=self.chat_fn, inputs=[ self.chat_control.conversation_id, self.chat_panel.chatbot, self._app.settings_state, + self.chat_state, ] + self._indices_input, outputs=[ - self.chat_panel.text_input, self.chat_panel.chatbot, self.info_panel, + self.chat_state, ], + concurrency_limit=20, show_progress="minimal", ).then( fn=self.update_data_source, inputs=[ self.chat_control.conversation_id, self.chat_panel.chatbot, + self.chat_state, ] + self._indices_input, outputs=None, + concurrency_limit=20, + ) + + self.chat_panel.regen_btn.click( + fn=self.regen_fn, + inputs=[ + self.chat_control.conversation_id, + self.chat_panel.chatbot, + self._app.settings_state, + self.chat_state, + ] + + self._indices_input, + outputs=[ + self.chat_panel.chatbot, + self.info_panel, + self.chat_state, + ], + 
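The `strip` to `replace` change in `validate_username` fixes a false rejection: `strip("_")` only removes leading and trailing underscores, so any interior underscore still failed `isalnum()`. A quick check (hypothetical username):

```python
usn = "john_doe"
print(usn.strip("_").isalnum())        # False -> the old check wrongly rejected it
print(usn.replace("_", "").isalnum())  # True  -> the new check accepts it
```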
concurrency_limit=20, + show_progress="minimal", + ).then( + fn=self.update_data_source, + inputs=[ + self.chat_control.conversation_id, + self.chat_panel.chatbot, + self.chat_state, + ] + + self._indices_input, + outputs=None, + concurrency_limit=20, ) self.chat_panel.chatbot.like( @@ -86,7 +150,12 @@ def on_register_events(self): outputs=None, ) - self.chat_control.conversation.change( + self.chat_control.btn_new.click( + self.chat_control.new_conv, + inputs=self._app.user_id, + outputs=[self.chat_control.conversation_id, self.chat_control.conversation], + show_progress="hidden", + ).then( self.chat_control.select_conv, inputs=[self.chat_control.conversation], outputs=[ @@ -94,11 +163,71 @@ def on_register_events(self): self.chat_control.conversation, self.chat_control.conversation_rn, self.chat_panel.chatbot, + self.info_panel, + self.chat_state, ] + self._indices_input, show_progress="hidden", ) + self.chat_control.btn_del.click( + lambda id: self.toggle_delete(id), + inputs=[self.chat_control.conversation_id], + outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], + ) + self.chat_control.btn_del_conf.click( + self.chat_control.delete_conv, + inputs=[self.chat_control.conversation_id, self._app.user_id], + outputs=[self.chat_control.conversation_id, self.chat_control.conversation], + show_progress="hidden", + ).then( + self.chat_control.select_conv, + inputs=[self.chat_control.conversation], + outputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + self.chat_panel.chatbot, + self.info_panel, + ] + + self._indices_input, + show_progress="hidden", + ).then( + lambda: self.toggle_delete(""), + outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], + ) + self.chat_control.btn_del_cnl.click( + lambda: self.toggle_delete(""), + outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], + ) + self.chat_control.conversation_rn_btn.click( + self.chat_control.rename_conv, + inputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation_rn, + self._app.user_id, + ], + outputs=[self.chat_control.conversation, self.chat_control.conversation], + show_progress="hidden", + ) + + self.chat_control.conversation.select( + self.chat_control.select_conv, + inputs=[self.chat_control.conversation], + outputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + self.chat_panel.chatbot, + self.info_panel, + ] + + self._indices_input, + show_progress="hidden", + ).then( + lambda: self.toggle_delete(""), + outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], + ) + self.report_issue.report_btn.click( self.report_issue.report, inputs=[ @@ -109,12 +238,79 @@ def on_register_events(self): self.chat_panel.chatbot, self._app.settings_state, self._app.user_id, + self.info_panel, + self.chat_state, ] + self._indices_input, outputs=None, ) + if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False): + self.chat_suggestion.example.select( + self.chat_suggestion.select_example, + outputs=[self.chat_panel.text_input], + show_progress="hidden", + ) + + def submit_msg(self, chat_input, chat_history, user_id, conv_id, conv_name): + """Submit a message to the chatbot""" + if not chat_input: + raise ValueError("Input is empty") + + if not conv_id: + id_, update = self.chat_control.new_conv(user_id) + with Session(engine) as session: + statement = select(Conversation).where(Conversation.id == id_) + name = 
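The event wiring above relies on Gradio's chaining semantics: `.success()` runs only if the previous handler returned without raising (so an empty input in `submit_msg` halts the whole chain), while `.then()` runs unconditionally; `concurrency_limit=20` allows up to 20 sessions in the handler at once. A toy sketch of the pattern, assuming a Gradio 4-style API and illustrative component names:

```python
import gradio as gr

def validate(msg):
    if not msg:
        raise ValueError("Input is empty")  # halts the .success() chain
    return msg

def respond(msg):
    return f"echo: {msg}"

with gr.Blocks() as demo:
    box = gr.Textbox()
    out = gr.Textbox()
    chain = box.submit(validate, inputs=box, outputs=box, show_progress="hidden")
    chain = chain.success(respond, inputs=box, outputs=out, concurrency_limit=20)
    chain.then(lambda: None, outputs=None)  # runs whether or not respond() raised
```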
session.exec(statement).one().name + new_conv_id = id_ + conv_update = update + new_conv_name = name + else: + new_conv_id = conv_id + conv_update = gr.update() + new_conv_name = conv_name + + return ( + "", + chat_history + [(chat_input, None)], + new_conv_id, + conv_update, + new_conv_name, + ) - def update_data_source(self, convo_id, messages, *selecteds): + def toggle_delete(self, conv_id): + if conv_id: + return gr.update(visible=False), gr.update(visible=True) + else: + return gr.update(visible=True), gr.update(visible=False) + + def on_subscribe_public_events(self): + if self._app.f_user_management: + self._app.subscribe_event( + name="onSignIn", + definition={ + "fn": self.chat_control.reload_conv, + "inputs": [self._app.user_id], + "outputs": [self.chat_control.conversation], + "show_progress": "hidden", + }, + ) + + self._app.subscribe_event( + name="onSignOut", + definition={ + "fn": lambda: self.chat_control.select_conv(""), + "outputs": [ + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + self.chat_panel.chatbot, + ] + + self._indices_input, + "show_progress": "hidden", + }, + ) + + def update_data_source(self, convo_id, messages, state, *selecteds): """Update the data source""" if not convo_id: gr.Warning("No conversation selected") @@ -122,8 +318,12 @@ def update_data_source(self, convo_id, messages, *selecteds): selecteds_ = {} for index in self._app.index_manager.indices: - if index.selector != -1: + if index.selector is None: + continue + if isinstance(index.selector, int): selecteds_[str(index.id)] = selecteds[index.selector] + else: + selecteds_[str(index.id)] = [selecteds[i] for i in index.selector] with Session(engine) as session: statement = select(Conversation).where(Conversation.id == convo_id) @@ -133,6 +333,7 @@ def update_data_source(self, convo_id, messages, *selecteds): result.data_source = { "selected": selecteds_, "messages": messages, + "state": state, "likes": deepcopy(data_source.get("likes", [])), } session.add(result) @@ -152,33 +353,45 @@ def is_liked(self, convo_id, liked: gr.LikeData): session.add(result) session.commit() - def create_pipeline(self, settings: dict, *selecteds): + def create_pipeline(self, settings: dict, state: dict, *selecteds): """Create the pipeline from settings Args: settings: the settings of the app + is_regen: whether the regen button is clicked selected: the list of file ids that will be served as context. 
If None, then consider using all files Returns: - the pipeline objects + - the pipeline objects """ + reasoning_mode = settings["reasoning.use"] + reasoning_cls = reasonings[reasoning_mode] + reasoning_id = reasoning_cls.get_info()["id"] + # get retrievers retrievers = [] for index in self._app.index_manager.indices: index_selected = [] - if index.selector != -1: + if isinstance(index.selector, int): index_selected = selecteds[index.selector] + if isinstance(index.selector, tuple): + for i in index.selector: + index_selected.append(selecteds[i]) iretrievers = index.get_retriever_pipelines(settings, index_selected) retrievers += iretrievers - reasoning_mode = settings["reasoning.use"] - reasoning_cls = reasonings[reasoning_mode] - pipeline = reasoning_cls.get_pipeline(settings, retrievers) + # prepare states + reasoning_state = { + "app": deepcopy(state["app"]), + "pipeline": deepcopy(state.get(reasoning_id, {})), + } - return pipeline + pipeline = reasoning_cls.get_pipeline(settings, reasoning_state, retrievers) - async def chat_fn(self, conversation_id, chat_history, settings, *selecteds): + return pipeline, reasoning_state + + async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds): """Chat function""" chat_input = chat_history[-1][0] chat_history = chat_history[:-1] @@ -186,7 +399,7 @@ async def chat_fn(self, conversation_id, chat_history, settings, *selecteds): queue: asyncio.Queue[Optional[dict]] = asyncio.Queue() # construct the pipeline - pipeline = self.create_pipeline(settings, *selecteds) + pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds) pipeline.set_output_queue(queue) asyncio.create_task(pipeline(chat_input, conversation_id, chat_history)) @@ -198,7 +411,8 @@ async def chat_fn(self, conversation_id, chat_history, settings, *selecteds): try: response = queue.get_nowait() except Exception: - yield "", chat_history + [(chat_input, text or "Thinking ...")], refs + state[pipeline.get_info()["id"]] = reasoning_state["pipeline"] + yield chat_history + [(chat_input, text or "Thinking ...")], refs, state continue if response is None: @@ -207,7 +421,11 @@ async def chat_fn(self, conversation_id, chat_history, settings, *selecteds): break if "output" in response: - text += response["output"] + if response["output"] is None: + text = "" + else: + text += response["output"] + if "evidence" in response: if response["evidence"] is None: refs = "" @@ -218,4 +436,25 @@ async def chat_fn(self, conversation_id, chat_history, settings, *selecteds): print(f"Len refs: {len(refs)}") len_ref = len(refs) - yield "", chat_history + [(chat_input, text)], refs + state[pipeline.get_info()["id"]] = reasoning_state["pipeline"] + yield chat_history + [(chat_input, text)], refs, state + + async def regen_fn( + self, conversation_id, chat_history, settings, state, *selecteds + ): + """Regen function""" + if not chat_history: + gr.Warning("Empty chat") + yield chat_history, "", state + return + + state["app"]["regen"] = True + async for chat, refs, state in self.chat_fn( + conversation_id, chat_history, settings, state, *selecteds + ): + new_state = deepcopy(state) + new_state["app"]["regen"] = False + yield chat, refs, new_state + else: + state["app"]["regen"] = False + yield chat_history, "", state diff --git a/libs/ktem/ktem/pages/chat/chat_panel.py b/libs/ktem/ktem/pages/chat/chat_panel.py index f4cfc5bbe..55b9258e9 100644 --- a/libs/ktem/ktem/pages/chat/chat_panel.py +++ b/libs/ktem/ktem/pages/chat/chat_panel.py @@ -19,6 +19,7 @@ def 
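To recap the state plumbing that `chat_fn` and `regen_fn` share: each conversation carries a dict with an `"app"` section (holding e.g. the `regen` flag) plus one section per reasoning pipeline, keyed by the pipeline id; `create_pipeline` deep-copies both into a `reasoning_state` snapshot, and the chat loop writes the pipeline section back into the conversation state on every yield. A compact sketch, with `"simple"` as an illustrative pipeline id:

```python
from copy import deepcopy

STATE = {"app": {"regen": False}}  # per-conversation default

def prepare_reasoning_state(state: dict, reasoning_id: str) -> dict:
    # mirror of the snapshot taken in create_pipeline()
    return {
        "app": deepcopy(state["app"]),
        "pipeline": deepcopy(state.get(reasoning_id, {})),
    }

state = deepcopy(STATE)
reasoning_state = prepare_reasoning_state(state, "simple")
reasoning_state["pipeline"]["scratch"] = "partial answer"

# what chat_fn does on each yield:
state["simple"] = reasoning_state["pipeline"]
assert state == {"app": {"regen": False}, "simple": {"scratch": "partial answer"}}
```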
on_building_ui(self): placeholder="Chat input", scale=15, container=False ) self.submit_btn = gr.Button(value="Send", scale=1, min_width=10) + self.regen_btn = gr.Button(value="Regen", scale=1, min_width=10) def submit_msg(self, chat_input, chat_history): """Submit a message to the chatbot""" diff --git a/libs/ktem/ktem/pages/chat/chat_suggestion.py b/libs/ktem/ktem/pages/chat/chat_suggestion.py new file mode 100644 index 000000000..23332c034 --- /dev/null +++ b/libs/ktem/ktem/pages/chat/chat_suggestion.py @@ -0,0 +1,26 @@ +import gradio as gr +from ktem.app import BasePage +from theflow.settings import settings as flowsettings + + +class ChatSuggestion(BasePage): + def __init__(self, app): + self._app = app + self.on_building_ui() + + def on_building_ui(self): + chat_samples = getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION_SAMPLES", []) + chat_samples = [[each] for each in chat_samples] + with gr.Accordion(label="Chat Suggestion", open=False) as self.accordion: + self.example = gr.DataFrame( + value=chat_samples, + headers=["Sample"], + interactive=False, + wrap=True, + ) + + def as_gradio_component(self): + return self.example + + def select_example(self, ev: gr.SelectData): + return ev.value diff --git a/libs/ktem/ktem/pages/chat/common.py b/libs/ktem/ktem/pages/chat/common.py new file mode 100644 index 000000000..a2fc0dcee --- /dev/null +++ b/libs/ktem/ktem/pages/chat/common.py @@ -0,0 +1,4 @@ +DEFAULT_APPLICATION_STATE = {"regen": False} +STATE = { + "app": DEFAULT_APPLICATION_STATE, +} diff --git a/libs/ktem/ktem/pages/chat/control.py b/libs/ktem/ktem/pages/chat/control.py index a0b256125..f2ed99bb1 100644 --- a/libs/ktem/ktem/pages/chat/control.py +++ b/libs/ktem/ktem/pages/chat/control.py @@ -5,9 +5,22 @@ from ktem.db.models import Conversation, engine from sqlmodel import Session, select +from .common import STATE + logger = logging.getLogger(__name__) +def is_conv_name_valid(name): + """Check if the conversation name is valid""" + errors = [] + if len(name) == 0: + errors.append("Name cannot be empty") + elif len(name) > 40: + errors.append("Name cannot be longer than 40 characters") + + return "; ".join(errors) + + class ConversationControl(BasePage): """Manage conversation""" @@ -26,9 +39,17 @@ def on_building_ui(self): interactive=True, ) - with gr.Row(): - self.conversation_new_btn = gr.Button(value="New", min_width=10) - self.conversation_del_btn = gr.Button(value="Delete", min_width=10) + with gr.Row() as self._new_delete: + self.btn_new = gr.Button(value="New", min_width=10) + self.btn_del = gr.Button(value="Delete", min_width=10) + + with gr.Row(visible=False) as self._delete_confirm: + self.btn_del_conf = gr.Button( + value="Delete", + variant="primary", + min_width=10, + ) + self.btn_del_cnl = gr.Button(value="Cancel", min_width=10) with gr.Row(): self.conversation_rn = gr.Text( @@ -50,48 +71,6 @@ def on_building_ui(self): # outputs=[current_state], # ) - def on_subscribe_public_events(self): - if self._app.f_user_management: - self._app.subscribe_event( - name="onSignIn", - definition={ - "fn": self.reload_conv, - "inputs": [self._app.user_id], - "outputs": [self.conversation], - "show_progress": "hidden", - }, - ) - - self._app.subscribe_event( - name="onSignOut", - definition={ - "fn": self.reload_conv, - "inputs": [self._app.user_id], - "outputs": [self.conversation], - "show_progress": "hidden", - }, - ) - - def on_register_events(self): - self.conversation_new_btn.click( - self.new_conv, - inputs=self._app.user_id, - outputs=[self.conversation_id, 
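The confirm-before-delete flow above swaps the visibility of two button rows instead of opening a modal. A reduced sketch of the pattern with illustrative names:

```python
import gradio as gr

def toggle_delete(conv_id: str):
    # a selected conversation shows the confirm row; otherwise the default row
    return gr.update(visible=not conv_id), gr.update(visible=bool(conv_id))

with gr.Blocks() as demo:
    with gr.Row() as row_default:
        btn_new = gr.Button("New", min_width=10)
        btn_del = gr.Button("Delete", min_width=10)
    with gr.Row(visible=False) as row_confirm:
        btn_del_conf = gr.Button("Delete", variant="primary", min_width=10)
        btn_del_cnl = gr.Button("Cancel", min_width=10)
    conv_id = gr.Textbox(visible=False)
    btn_del.click(toggle_delete, inputs=conv_id, outputs=[row_default, row_confirm])
    btn_del_cnl.click(lambda: toggle_delete(""), outputs=[row_default, row_confirm])
```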
self.conversation], - show_progress="hidden", - ) - self.conversation_del_btn.click( - self.delete_conv, - inputs=[self.conversation_id, self._app.user_id], - outputs=[self.conversation_id, self.conversation], - show_progress="hidden", - ) - self.conversation_rn_btn.click( - self.rename_conv, - inputs=[self.conversation_id, self.conversation_rn, self._app.user_id], - outputs=[self.conversation, self.conversation], - show_progress="hidden", - ) - def load_chat_history(self, user_id): """Reload chat history""" options = [] @@ -110,7 +89,7 @@ def load_chat_history(self, user_id): def reload_conv(self, user_id): conv_list = self.load_chat_history(user_id) if conv_list: - return gr.update(value=conv_list[0][1], choices=conv_list) + return gr.update(value=None, choices=conv_list) else: return gr.update(value=None, choices=[]) @@ -131,10 +110,15 @@ def new_conv(self, user_id): return id_, gr.update(value=id_, choices=history) def delete_conv(self, conversation_id, user_id): - """Create new chat""" + """Delete the selected conversation""" + if not conversation_id: + gr.Warning("No conversation selected.") + return None, gr.update() + if user_id is None: gr.Warning("Please sign in first (Settings → User Settings)") return None, gr.update() + with Session(engine) as session: statement = select(Conversation).where(Conversation.id == conversation_id) result = session.exec(statement).one() @@ -159,27 +143,44 @@ def select_conv(self, conversation_id): name = result.name selected = result.data_source.get("selected", {}) chats = result.data_source.get("messages", []) + info_panel = "" + state = result.data_source.get("state", STATE) except Exception as e: logger.warning(e) id_ = "" name = "" selected = {} chats = [] + info_panel = "" + state = STATE indices = [] for index in self._app.index_manager.indices: # assume that the index has selector - if index.selector == -1: + if index.selector is None: continue - indices.append(selected.get(str(index.id), [])) + if isinstance(index.selector, int): + indices.append(selected.get(str(index.id), [])) + if isinstance(index.selector, tuple): + indices.extend(selected.get(str(index.id), [[]] * len(index.selector))) - return id_, id_, name, chats, *indices + return id_, id_, name, chats, info_panel, state, *indices def rename_conv(self, conversation_id, new_name, user_id): """Rename the conversation""" if user_id is None: gr.Warning("Please sign in first (Settings → User Settings)") return gr.update(), "" + + if not conversation_id: + gr.Warning("No conversation selected.") + return gr.update(), "" + + errors = is_conv_name_valid(new_name) + if errors: + gr.Warning(errors) + return gr.update(), conversation_id + with Session(engine) as session: statement = select(Conversation).where(Conversation.id == conversation_id) result = session.exec(statement).one() diff --git a/libs/ktem/ktem/pages/chat/report.py b/libs/ktem/ktem/pages/chat/report.py index 46d9e3c84..dfe030146 100644 --- a/libs/ktem/ktem/pages/chat/report.py +++ b/libs/ktem/ktem/pages/chat/report.py @@ -48,12 +48,19 @@ def report( chat_history: list, settings: dict, user_id: Optional[int], - *selecteds + info_panel: str, + chat_state: dict, + *selecteds, ): selecteds_ = {} for index in self._app.index_manager.indices: - if index.selector != -1: - selecteds_[str(index.id)] = selecteds[index.selector] + if index.selector is not None: + if isinstance(index.selector, int): + selecteds_[str(index.id)] = selecteds[index.selector] + elif isinstance(index.selector, tuple): + selecteds_[str(index.id)] = [selecteds[_] 
for _ in index.selector] + else: + print(f"Unknown selector type: {index.selector}") with Session(engine) as session: issue = IssueReport( @@ -65,6 +72,8 @@ def report( chat={ "conv_id": conv_id, "chat_history": chat_history, + "info_panel": info_panel, + "chat_state": chat_state, "selecteds": selecteds_, }, settings=settings, diff --git a/libs/ktem/ktem/pages/login.py b/libs/ktem/ktem/pages/login.py index 6fe15d0c3..d5c57e5a4 100644 --- a/libs/ktem/ktem/pages/login.py +++ b/libs/ktem/ktem/pages/login.py @@ -31,11 +31,10 @@ def __init__(self, app): self.on_building_ui() def on_building_ui(self): - gr.Markdown("Welcome to Kotaemon") - self.usn = gr.Textbox(label="Username") - self.pwd = gr.Textbox(label="Password", type="password") - self.btn_login = gr.Button("Login") - self._dummy = gr.State() + gr.Markdown("# Welcome to Kotaemon") + self.usn = gr.Textbox(label="Username", visible=False) + self.pwd = gr.Textbox(label="Password", type="password", visible=False) + self.btn_login = gr.Button("Login", visible=False) def on_register_events(self): onSignIn = gr.on( @@ -45,24 +44,56 @@ def on_register_events(self): outputs=[self._app.user_id, self.usn, self.pwd], show_progress="hidden", js=signin_js, + ).then( + self.toggle_login_visibility, + inputs=[self._app.user_id], + outputs=[self.usn, self.pwd, self.btn_login], ) for event in self._app.get_event("onSignIn"): onSignIn = onSignIn.success(**event) + def toggle_login_visibility(self, user_id): + return ( + gr.update(visible=user_id is None), + gr.update(visible=user_id is None), + gr.update(visible=user_id is None), + ) + def _on_app_created(self): - self._app.app.load( - None, - inputs=None, - outputs=[self.usn, self.pwd], + onSignIn = self._app.app.load( + self.login, + inputs=[self.usn, self.pwd], + outputs=[self._app.user_id, self.usn, self.pwd], + show_progress="hidden", js=fetch_creds, + ).then( + self.toggle_login_visibility, + inputs=[self._app.user_id], + outputs=[self.usn, self.pwd, self.btn_login], + ) + for event in self._app.get_event("onSignIn"): + onSignIn = onSignIn.success(**event) + + def on_subscribe_public_events(self): + self._app.subscribe_event( + name="onSignOut", + definition={ + "fn": self.toggle_login_visibility, + "inputs": [self._app.user_id], + "outputs": [self.usn, self.pwd, self.btn_login], + "show_progress": "hidden", + }, ) def login(self, usn, pwd): + if not usn or not pwd: + return None, usn, pwd hashed_password = hashlib.sha256(pwd.encode()).hexdigest() with Session(engine) as session: stmt = select(User).where( - User.username_lower == usn.lower(), User.password == hashed_password + User.username_lower == usn.lower().strip(), + User.password == hashed_password, ) result = session.exec(stmt).all() if result: diff --git a/libs/ktem/ktem/pages/settings.py b/libs/ktem/ktem/pages/settings.py index 0fce2e8f7..20912cb44 100644 --- a/libs/ktem/ktem/pages/settings.py +++ b/libs/ktem/ktem/pages/settings.py @@ -164,9 +164,14 @@ def on_register_events(self): show_progress="hidden", ) onSignOutClick = self.signout.click( - lambda: (None, "Current user: ___"), + lambda: (None, "Current user: ___", "", ""), inputs=None, - outputs=[self._user_id, self.current_name], + outputs=[ + self._user_id, + self.current_name, + self.password_change, + self.password_change_confirm, + ], show_progress="hidden", js=signout_js, ).then( @@ -192,8 +197,12 @@ def user_tab(self): self.password_change_btn = gr.Button("Change password", interactive=True) def change_password(self, user_id, password, password_confirm): - if password != 
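For clarity on the credential check in `login()`: passwords are stored as SHA-256 hex digests and usernames are matched on a lower-cased, stripped form, so the comparison never touches the plaintext. A minimal sketch:

```python
import hashlib

def hash_password(pwd: str) -> str:
    # same digest scheme as the stored User.password column
    return hashlib.sha256(pwd.encode()).hexdigest()

stored_password = hash_password("s3cret")
assert hash_password("s3cret") == stored_password
assert len(stored_password) == 64  # hex digest of a 256-bit hash

# username normalisation applied before the lookup
assert "Alice ".lower().strip() == "alice"
```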
password_confirm: - gr.Warning("Password does not match") + from ktem.pages.admin.user import validate_password + + errors = validate_password(password, password_confirm) + if errors: + print(errors) + gr.Warning(errors) return password, password_confirm with Session(engine) as session: diff --git a/libs/ktem/ktem/reasoning/base.py b/libs/ktem/ktem/reasoning/base.py index 80cf01698..6d6e48648 100644 --- a/libs/ktem/ktem/reasoning/base.py +++ b/libs/ktem/ktem/reasoning/base.py @@ -34,12 +34,16 @@ def get_user_settings(cls) -> dict: @classmethod def get_pipeline( - cls, user_settings: dict, retrievers: Optional[list["BaseComponent"]] = None + cls, + user_settings: dict, + state: dict, + retrievers: Optional[list["BaseComponent"]] = None, ) -> "BaseReasoning": """Get the reasoning pipeline for the app to execute Args: user_setting: user settings + state: conversation state retrievers (list): List of retrievers """ return cls() diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 1970c3996..a5a895eeb 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -1,11 +1,12 @@ import asyncio +import html import logging +import re from collections import defaultdict from functools import partial import tiktoken from ktem.llms.manager import llms -from ktem.reasoning.base import BaseReasoning from kotaemon.base import ( BaseComponent, @@ -18,9 +19,17 @@ from kotaemon.indices.qa.citation import CitationPipeline from kotaemon.indices.splitters import TokenSplitter from kotaemon.llms import ChatLLM, PromptTemplate +from kotaemon.loaders.utils.gpt4v import stream_gpt4v + +from .base import BaseReasoning logger = logging.getLogger(__name__) +EVIDENCE_MODE_TEXT = 0 +EVIDENCE_MODE_TABLE = 1 +EVIDENCE_MODE_CHATBOT = 2 +EVIDENCE_MODE_FIGURE = 3 + class PrepareEvidencePipeline(BaseComponent): """Prepare the evidence text from the list of retrieved documents @@ -46,7 +55,7 @@ class PrepareEvidencePipeline(BaseComponent): def run(self, docs: list[RetrievedDocument]) -> Document: evidence = "" table_found = 0 - evidence_mode = 0 + evidence_mode = EVIDENCE_MODE_TEXT for _id, retrieved_item in enumerate(docs): retrieved_content = "" @@ -55,7 +64,7 @@ def run(self, docs: list[RetrievedDocument]) -> Document: if page: source += f" (Page {page})" if retrieved_item.metadata.get("type", "") == "table": - evidence_mode = 1 # table + evidence_mode = EVIDENCE_MODE_TABLE if table_found < 5: retrieved_content = retrieved_item.metadata.get("table_origin", "") if retrieved_content not in evidence: @@ -66,13 +75,23 @@ def run(self, docs: list[RetrievedDocument]) -> Document: + "\n
" ) elif retrieved_item.metadata.get("type", "") == "chatbot": - evidence_mode = 2 # chatbot + evidence_mode = EVIDENCE_MODE_CHATBOT retrieved_content = retrieved_item.metadata["window"] evidence += ( f"
<br><b>Chatbot scenario from {filename} (Row {page})</b>\n" + retrieved_content + "\n
" ) + elif retrieved_item.metadata.get("type", "") == "image": + evidence_mode = EVIDENCE_MODE_FIGURE + retrieved_content = retrieved_item.metadata.get("image_origin", "") + retrieved_caption = html.escape(retrieved_item.get_content()) + evidence += ( + f"
<br><b>Figure from {source}</b>\n" + f"<img width='64' src='{retrieved_content}' alt='{retrieved_caption}'>" + "\n<br>
" + ) else: if "window" in retrieved_item.metadata: retrieved_content = retrieved_item.metadata["window"] @@ -90,12 +109,13 @@ def run(self, docs: list[RetrievedDocument]) -> Document: print(retrieved_item.metadata) print("Score", retrieved_item.metadata.get("relevance_score", None)) - # trim context by trim_len - print("len (original)", len(evidence)) - if evidence: - texts = self.trim_func([Document(text=evidence)]) - evidence = texts[0].text - print("len (trimmed)", len(evidence)) + if evidence_mode != EVIDENCE_MODE_FIGURE: + # trim context by trim_len + print("len (original)", len(evidence)) + if evidence: + texts = self.trim_func([Document(text=evidence)]) + evidence = texts[0].text + print("len (trimmed)", len(evidence)) print(f"PrepareEvidence with input {docs}\nOutput: {evidence}\n") @@ -134,6 +154,25 @@ def run(self, docs: list[RetrievedDocument]) -> Document: "Answer:" ) +DEFAULT_QA_FIGURE_PROMPT = ( + "Use the given context: texts, tables, and figures below to answer the question. " + "If you don't know the answer, just say that you don't know. " + "Give answer in {lang}.\n\n" + "Context: \n" + "{context}\n" + "Question: {question}\n" + "Answer: " +) + +DEFAULT_REWRITE_PROMPT = ( + "Given the following question, rephrase and expand it " + "to help you do better answering. Maintain all information " + "in the original question. Keep the question as concise as possible. " + "Give answer in {lang}\n" + "Original question: {question}\n" + "Rephrased question: " +) + class AnswerWithContextPipeline(BaseComponent): """Answer the question based on the evidence @@ -151,6 +190,7 @@ class AnswerWithContextPipeline(BaseComponent): """ llm: ChatLLM = Node(default_callback=lambda _: llms.get_default()) + vlm_endpoint: str = "" citation_pipeline: CitationPipeline = Node( default_callback=lambda _: CitationPipeline(llm=llms.get_default()) ) @@ -158,13 +198,14 @@ class AnswerWithContextPipeline(BaseComponent): qa_template: str = DEFAULT_QA_TEXT_PROMPT qa_table_template: str = DEFAULT_QA_TABLE_PROMPT qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT + qa_figure_template: str = DEFAULT_QA_FIGURE_PROMPT enable_citation: bool = False system_prompt: str = "" lang: str = "English" # support English and Japanese async def run( # type: ignore - self, question: str, evidence: str, evidence_mode: int = 0 + self, question: str, evidence: str, evidence_mode: int = 0, **kwargs ) -> Document: """Answer the question based on the evidence @@ -188,18 +229,30 @@ async def run( # type: ignore (determined by retrieval pipeline) evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot """ - if evidence_mode == 0: + if evidence_mode == EVIDENCE_MODE_TEXT: prompt_template = PromptTemplate(self.qa_template) - elif evidence_mode == 1: + elif evidence_mode == EVIDENCE_MODE_TABLE: prompt_template = PromptTemplate(self.qa_table_template) + elif evidence_mode == EVIDENCE_MODE_FIGURE: + prompt_template = PromptTemplate(self.qa_figure_template) else: prompt_template = PromptTemplate(self.qa_chatbot_template) - prompt = prompt_template.populate( - context=evidence, - question=question, - lang=self.lang, - ) + images = [] + if evidence_mode == EVIDENCE_MODE_FIGURE: + # isolate image from evidence + evidence, images = self.extract_evidence_images(evidence) + prompt = prompt_template.populate( + context=evidence, + question=question, + lang=self.lang, + ) + else: + prompt = prompt_template.populate( + context=evidence, + question=question, + lang=self.lang, + ) citation_task = None if evidence and 
self.enable_citation: @@ -208,23 +261,29 @@ async def run( # type: ignore ) print("Citation task created") - messages = [] - if self.system_prompt: - messages.append(SystemMessage(content=self.system_prompt)) - messages.append(HumanMessage(content=prompt)) - output = "" - try: - # try streaming first - print("Trying LLM streaming") - for text in self.llm.stream(messages): - output += text.text - self.report_output({"output": text.text}) + if evidence_mode == EVIDENCE_MODE_FIGURE: + for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768): + output += text + self.report_output({"output": text}) await asyncio.sleep(0) - except NotImplementedError: - print("Streaming is not supported, falling back to normal processing") - output = self.llm(messages).text - self.report_output({"output": output}) + else: + messages = [] + if self.system_prompt: + messages.append(SystemMessage(content=self.system_prompt)) + messages.append(HumanMessage(content=prompt)) + + try: + # try streaming first + print("Trying LLM streaming") + for text in self.llm.stream(messages): + output += text.text + self.report_output({"output": text.text}) + await asyncio.sleep(0) + except NotImplementedError: + print("Streaming is not supported, falling back to normal processing") + output = self.llm(messages).text + self.report_output({"output": output}) # retrieve the citation print("Waiting for citation task") @@ -238,6 +297,46 @@ async def run( # type: ignore return answer +def extract_evidence_images(self, evidence: str): + """Util function to extract and isolate images from context/evidence""" + image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'" + matches = re.findall(image_pattern, evidence) + context = re.sub(image_pattern, "", evidence) + return context, matches + + +class RewriteQuestionPipeline(BaseComponent): + """Rewrite user question + + Args: + llm: the language model to rewrite question + rewrite_template: the prompt template for llm to paraphrase a text input + lang: the language of the answer. Currently support English and Japanese + """ + + llm: ChatLLM = Node(default_callback=lambda _: llms.get_default()) + rewrite_template: str = DEFAULT_REWRITE_PROMPT + + lang: str = "English" + + async def run(self, question: str) -> Document: # type: ignore + prompt_template = PromptTemplate(self.rewrite_template) + prompt = prompt_template.populate(question=question, lang=self.lang) + messages = [ + SystemMessage(content="You are a helpful assistant"), + HumanMessage(content=prompt), + ] + output = "" + for text in self.llm(messages): + if "content" in text: + output += text[1] + self.report_output({"chat_input": text[1]}) + break + await asyncio.sleep(0) + + return Document(text=output) + + class FullQAPipeline(BaseReasoning): """Question answering pipeline. 
Handle from question to answer""" @@ -248,24 +347,36 @@ class Config: evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx() answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx() + rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx() + use_rewrite: bool = False async def run( # type: ignore self, message: str, conv_id: str, history: list, **kwargs # type: ignore ) -> Document: # type: ignore + import markdown + docs = [] doc_ids = [] + if self.use_rewrite: + rewrite = await self.rewrite_pipeline(question=message) + message = rewrite.text + for retriever in self.retrievers: for doc in retriever(text=message): if doc.doc_id not in doc_ids: docs.append(doc) doc_ids.append(doc.doc_id) for doc in docs: + # TODO: a better approach to show the information + text = markdown.markdown( + doc.text, extensions=["markdown.extensions.tables"] + ) self.report_output( { "evidence": ( "
" f"{doc.metadata['file_name']}" - f"{doc.text}" + f"{text}" "

" ) } @@ -274,7 +385,12 @@ async def run( # type: ignore evidence_mode, evidence = self.evidence_pipeline(docs).content answer = await self.answering_pipeline( - question=message, evidence=evidence, evidence_mode=evidence_mode + question=message, + history=history, + evidence=evidence, + evidence_mode=evidence_mode, + conv_id=conv_id, + **kwargs, ) # prepare citation @@ -284,14 +400,29 @@ async def run( # type: ignore for quote in fact_with_evidence.substring_quote: for doc in docs: start_idx = doc.text.find(quote) - if start_idx >= 0: + if start_idx == -1: + continue + + end_idx = start_idx + len(quote) + + current_idx = start_idx + if "|" not in doc.text[start_idx:end_idx]: spans[doc.doc_id].append( - { - "start": start_idx, - "end": start_idx + len(quote), - } + {"start": start_idx, "end": end_idx} ) - break + else: + while doc.text[current_idx:end_idx].find("|") != -1: + match_idx = doc.text[current_idx:end_idx].find("|") + spans[doc.doc_id].append( + { + "start": current_idx, + "end": current_idx + match_idx, + } + ) + current_idx += match_idx + 2 + if current_idx > end_idx: + break + break id2docs = {doc.doc_id: doc for doc in docs} lack_evidence = True @@ -310,12 +441,15 @@ async def run( # type: ignore if idx < len(ss) - 1: text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]] text += id2docs[id].text[ss[-1]["end"] :] + text_out = markdown.markdown( + text, extensions=["markdown.extensions.tables"] + ) self.report_output( { "evidence": ( "
" f"{id2docs[id].metadata['file_name']}" - f"{text}" + f"{text_out}" "

" ) } @@ -330,12 +464,15 @@ async def run( # type: ignore {"evidence": "Retrieved segments without matching evidence:\n"} ) for id in list(not_detected): + text_out = markdown.markdown( + id2docs[id].text, extensions=["markdown.extensions.tables"] + ) self.report_output( { "evidence": ( "
" f"{id2docs[id].metadata['file_name']}" - f"{id2docs[id].text}" + f"{text_out}" "

" ) } @@ -345,7 +482,7 @@ async def run( # type: ignore return answer @classmethod - def get_pipeline(cls, settings, retrievers): + def get_pipeline(cls, settings, states, retrievers): """Get the reasoning pipeline Args: @@ -370,6 +507,11 @@ def get_pipeline(cls, settings, retrievers): pipeline.answering_pipeline.qa_template = settings[ f"reasoning.options.{_id}.qa_prompt" ] + pipeline.use_rewrite = states.get("app", {}).get("regen", False) + pipeline.rewrite_pipeline.llm = llms.get_default() + pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get( + settings["reasoning.lang"], "English" + ) return pipeline @classmethod