From 9cd976a250be43b01401f1d9ac656118aa38f2e5 Mon Sep 17 00:00:00 2001 From: mskarlin <12701035+mskarlin@users.noreply.github.com> Date: Thu, 29 Aug 2024 11:24:31 -0700 Subject: [PATCH] Agentic workflows, locally indexed search, and CLI (#309) * add agentic workflow module and cli * update pyproject.toml and split google into its own import * remove unused import * update default search chain * Update __init__.py * Update search.py * move reqs into pyproject.toml, addressing PR comments * rename search_chain to openai_get_search_query * remove get_llm_name * move table_formatter to helpers.py * Update paperqa/agents/main.py Co-authored-by: James Braza * Update paperqa/agents/main.py Co-authored-by: James Braza * Update paperqa/agents/main.py Co-authored-by: James Braza * Update paperqa/agents/docs.py Co-authored-by: James Braza * Update paperqa/types.py Co-authored-by: James Braza * rename compute_cost to compute_total_model_token_cost * remove stream_answer * rename to stub_manifest, and use Path for all paths * Update paperqa/llms.py Co-authored-by: James Braza * move SKIP_AGENT_TESTS = False * nix _ = assignments * add test comments * types in conftest.py * split libs into llms * link openai chat timeout to query.timeout * Update paperqa/agents/__init__.py Co-authored-by: James Braza * logging revamp and renaming * Update tests/test_cli.py Co-authored-by: James Braza * Update tests/test_cli.py Co-authored-by: James Braza * move vertex import to func call, add docstring to SupportsPickle * docstring * remove _ = * remove bool return type from set * update gitignore * add config attribute to baase LLMModel class * replace get_current_settings -> get_settings * replace get_current_settings -> get_settings * PR simplifications * remove all stream_* functions * avoid modifying the root logger * re-organize logger import location * move hashlib into utils * refactor strip_answer into Answer object * label circular imports * ensure absolute paths are used in index name * limit select to be used only when DOI is not present in crossref * Update paperqa/agents/search.py Co-authored-by: James Braza * Update paperqa/agents/search.py Co-authored-by: James Braza * Update paperqa/agents/search.py Co-authored-by: James Braza * Update paperqa/agents/search.py Co-authored-by: James Braza * Update paperqa/agents/models.py Co-authored-by: James Braza * reconfigure logging to not prevent propagation * remove newlines in the current year * use required fields as a subset * replace . 
with Path.cwd() --------- Co-authored-by: James Braza --- .github/workflows/build.yml | 2 +- .github/workflows/tests.yml | 2 +- .gitignore | 176 ++++++++- .pre-commit-config.yaml | 3 +- dev-requirements.txt | 22 -- paperqa/agents/__init__.py | 547 +++++++++++++++++++++++++++ paperqa/agents/helpers.py | 238 ++++++++++++ paperqa/agents/main.py | 336 +++++++++++++++++ paperqa/agents/models.py | 515 +++++++++++++++++++++++++ paperqa/agents/prompts.py | 131 +++++++ paperqa/agents/search.py | 487 ++++++++++++++++++++++++ paperqa/agents/tools.py | 395 +++++++++++++++++++ paperqa/clients/__init__.py | 22 +- paperqa/clients/client_models.py | 18 +- paperqa/clients/crossref.py | 8 +- paperqa/clients/semantic_scholar.py | 3 + paperqa/docs.py | 20 +- paperqa/llms.py | 8 +- paperqa/prompts.py | 3 +- paperqa/types.py | 66 +++- paperqa/utils.py | 33 +- pyproject.toml | 54 ++- tests/conftest.py | 44 +++ tests/stub_manifest.csv | 4 + tests/test_agents.py | 562 ++++++++++++++++++++++++++++ tests/test_cli.py | 151 ++++++++ 26 files changed, 3777 insertions(+), 73 deletions(-) delete mode 100644 dev-requirements.txt create mode 100644 paperqa/agents/__init__.py create mode 100644 paperqa/agents/helpers.py create mode 100644 paperqa/agents/main.py create mode 100644 paperqa/agents/models.py create mode 100644 paperqa/agents/prompts.py create mode 100644 paperqa/agents/search.py create mode 100644 paperqa/agents/tools.py create mode 100644 tests/stub_manifest.csv create mode 100644 tests/test_agents.py create mode 100644 tests/test_cli.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index abcc08bf..fc06a6c8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,7 +16,7 @@ jobs: with: python-version: 3.11 cache: pip - - run: pip install .[dev] + - run: pip install .[agents,google,dev,llms] - name: Build a binary wheel and a source tarball run: | python -m build --sdist --wheel --outdir dist/ . 
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 89ff62f6..88553a50 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,7 +20,7 @@ jobs: with: python-version: ${{ matrix.python-version }} cache: pip - - run: pip install .[dev] + - run: pip install .[agents,google,dev,llms] - name: Check pre-commit run: pre-commit run --all-files || ( git status --short ; git diff ; exit 1 ) - name: Run Test diff --git a/.gitignore b/.gitignore index 7edec0e8..963a1992 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,138 @@ +# Swap +[._]*.s[a-v][a-z] +!*.svg # comment out if you don't need vector files +[._]*.sw[a-p] +[._]s[a-rt-v][a-z] +[._]ss[a-gi-z] +[._]sw[a-p] + +# Session +Session.vim +Sessionx.vim + +# Temporary +.netrwhist +*~ +# Auto-generated tag files +tags +# Persistent undo +[._]*.un~ +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon[\r] + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -20,7 +155,6 @@ parts/ sdist/ var/ wheels/ -pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg @@ -50,6 +184,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +cover/ # Translations *.mo @@ -72,6 +207,7 @@ instance/ docs/_build/ # PyBuilder +.pybuilder/ target/ # Jupyter Notebook 
@@ -82,7 +218,9 @@ profile_default/ ipython_config.py # pyenv -.python-version +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. @@ -91,7 +229,24 @@ ipython_config.py # install all needed dependencies. #Pipfile.lock -# PEP 582; used by e.g. github.com/David-OConnor/pyflow +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff @@ -128,8 +283,18 @@ dmypy.json # Pyre type checker .pyre/ -# testing files generated -*.txt.json +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ *.ipynb env @@ -137,3 +302,4 @@ env # Matching pyproject.toml paperqa/version.py tests/example* +tests/test_index/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bc5f44bb..a4ebe25f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.0 + rev: v0.6.2 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -41,6 +41,7 @@ repos: - pydantic~=2.0 # Match pyproject.toml - types-requests - types-setuptools + - types-PyYAML - repo: https://github.com/rbubley/mirrors-prettier rev: v3.3.2 hooks: diff --git a/dev-requirements.txt b/dev-requirements.txt deleted file mode 100644 index a2f226af..00000000 --- a/dev-requirements.txt +++ /dev/null @@ -1,22 +0,0 @@ -anthropic -faiss-cpu -langchain-community -langchain-openai -pymupdf -python-dotenv -pyzotero -requests -sentence_transformers -voyageai - -# Code QA dependencies -build -mypy -pre-commit -pytest -pytest-asyncio -pytest-sugar -pytest-vcr -pytest-timer -types-requests -types-setuptools diff --git a/paperqa/agents/__init__.py b/paperqa/agents/__init__.py new file mode 100644 index 00000000..9a73cd02 --- /dev/null +++ b/paperqa/agents/__init__.py @@ -0,0 +1,547 @@ +from __future__ import annotations + +import ast +import logging +import operator +import os +import shutil +from datetime import datetime +from pathlib import Path +from typing import Any + +import yaml +from typing_extensions import Annotated + +from .. 
import __version__ +from ..utils import get_loop, pqa_directory + +try: + import anyio + import typer + from rich.console import Console + from rich.logging import RichHandler + + from .main import agent_query, search + from .models import AnswerResponse, MismatchedModelsError, QueryRequest + from .search import SearchIndex, get_directory_index + +except ImportError as e: + raise ImportError( + '"agents" module is not installed, please install it using "pip install paper-qa[agents]"' + ) from e + +logger = logging.getLogger(__name__) + +app = typer.Typer() + + +def configure_agent_logging( + verbosity: int = 0, default_level: int = logging.INFO +) -> None: + """Default to INFO level, but suppress loquacious loggers.""" + verbosity_map = { + 0: { + "paperqa.agents.helpers": logging.WARNING, + "paperqa.agents.main": logging.WARNING, + "anthropic": logging.WARNING, + "openai": logging.WARNING, + "httpx": logging.WARNING, + "paperqa.agents.models": logging.WARNING, + } + } + + verbosity_map[1] = verbosity_map[0] | { + "paperqa.agents.main": logging.INFO, + "paperqa.models": logging.INFO, + } + + verbosity_map[2] = verbosity_map[1] | { + "paperqa.agents.helpers": logging.DEBUG, + "paperqa.agents.main": logging.DEBUG, + "paperqa.agents.main.agent_callers": logging.DEBUG, + "paperqa.models": logging.DEBUG, + "paperqa.agents.search": logging.DEBUG, + } + + rich_handler = RichHandler( + rich_tracebacks=True, + markup=True, + show_path=False, + show_level=False, + console=Console(force_terminal=True), + ) + + rich_handler.setFormatter(logging.Formatter("%(message)s", datefmt="[%X]")) + + def is_paperqa_related(logger_name: str) -> bool: + return logger_name.startswith("paperqa") or logger_name in { + "anthropic", + "openai", + "httpx", + } + + for logger_name, logger in logging.Logger.manager.loggerDict.items(): + if isinstance(logger, logging.Logger) and is_paperqa_related(logger_name): + logger.setLevel( + verbosity_map.get(min(verbosity, 2), {}).get(logger_name, default_level) + ) + if not any(isinstance(h, RichHandler) for h in logger.handlers): + logger.addHandler(rich_handler) + + +def get_file_timestamps(path: os.PathLike | str) -> dict[str, str]: + # Get the stats for the file/directory + stats = os.stat(path) + + # Get created time (ctime) + created_time = datetime.fromtimestamp(stats.st_ctime) + + # Get modified time (mtime) + modified_time = datetime.fromtimestamp(stats.st_mtime) + + return { + "created_at": created_time.strftime("%Y-%m-%d %H:%M:%S"), + "modified_at": modified_time.strftime("%Y-%m-%d %H:%M:%S"), + } + + +def parse_dot_to_dict(str_w_dots: str, value: str) -> dict[str, Any]: + parsed: dict[str, Any] = {} + for key in str_w_dots.split(".")[::-1]: + if not parsed: + try: + eval_value = ast.literal_eval(value) + if isinstance(eval_value, (set, list)): + parsed[key] = eval_value + else: + parsed[key] = value + except (ValueError, SyntaxError): + parsed[key] = value + else: + parsed = {key: parsed} + return parsed + + +def pop_nested_dict_recursive(d: dict[str, Any], path: str) -> tuple[Any, bool]: + """ + Pop a value from a nested dictionary (in-place) using a period-separated path. + + Recursively remove empty dictionaries after popping. 
+ """ + keys = path.split(".") + + if len(keys) == 1: + if keys[0] not in d: + raise KeyError(f"Key not found: {keys[0]}") + value = d.pop(keys[0]) + return value, len(d) == 0 + + if keys[0] not in d or not isinstance(d[keys[0]], dict): + raise KeyError(f"Invalid path: {path}") + + value, should_remove = pop_nested_dict_recursive(d[keys[0]], ".".join(keys[1:])) + + if should_remove: + d.pop(keys[0]) + + return value, len(d) == 0 + + +def get_settings( + settings_path: str | os.PathLike | None = None, +) -> dict[str, Any]: + + if settings_path is None: + settings_path = pqa_directory("settings") / "settings.yaml" + + if os.path.exists(settings_path): + with open(settings_path) as f: + return yaml.safe_load(f) + + return {} + + +def merge_dicts(dict_a: dict, dict_b: dict) -> dict: + """ + Merge two dictionaries where if dict_a has a key with a subdictionary. + + dict_b only overwrites the keys in dict_a's subdictionary if they are + also specified in dict_b, but otherwise keeps all the subkeys. + """ + result = dict_a.copy() # Start with a shallow copy of dict_a + + for key, value in dict_b.items(): + if isinstance(value, dict) and key in result and isinstance(result[key], dict): + # If both dict_a and dict_b have a dict for this key, recurse + result[key] = merge_dicts(result[key], value) + else: + # Otherwise, just update the value + result[key] = value + + return result + + +def get_merged_settings( + settings: dict[str, Any], settings_path: Path | None = None +) -> dict[str, Any]: + """Merges a new settings with the current settings saved to file.""" + current_settings = get_settings(settings_path) + + # deal with the nested key case + return merge_dicts(current_settings, settings) + + +@app.command("set") +def set_setting( + variable: Annotated[ + str, + typer.Argument( + help=( + "PaperQA variable to set, see agents.models.QueryRequest object for all settings, " + "nested options can be set using periods, ex. agent_tools.paper_directory" + ) + ), + ], + value: Annotated[ + str, + typer.Argument( + help=( + "Value to set to the variable, will be cast to the correct type automatically." + ) + ), + ], +) -> None: + """Set a persistent PaperQA setting.""" + configure_agent_logging(verbosity=0) + + settings_path = pqa_directory("settings") / "settings.yaml" + + current_settings = get_merged_settings( + parse_dot_to_dict(variable, value), settings_path=settings_path + ) + + try: + QueryRequest(**current_settings) + except MismatchedModelsError: + pass + except ValueError as e: + raise ValueError( + f"{variable} (with value {value}) is not a valid setting." + ) from e + + logger.info(f"{variable} set to {str(value)[:100]}!") + + with open(settings_path, "w") as f: + yaml.dump(current_settings, f) + + +@app.command() +def show( + variable: Annotated[ + str, + typer.Argument( + help=( + "PaperQA variable to show, see agents.models.QueryRequest object for all settings, " + "nested options can be set using periods, ex. agent_tools.paper_directory. " + "Can show all indexes with `indexes` input, answers with `answers` input, " + "and `all` for all settings." 
+ ) + ), + ], + limit: Annotated[ + int, typer.Option(help="Limit results, only used for 'answers'.") + ] = 5, +) -> Any: + """Show a persistent PaperQA setting; special inputs include `indexes`, `answers` and `all`.""" + configure_agent_logging(verbosity=0) + + # handle special case when user wants to see indexes + if variable == "indexes": + for index in os.listdir(pqa_directory("indexes")): + index_times = get_file_timestamps(pqa_directory("indexes") / index) + logger.info(f"{index}, {index_times}") + return os.listdir(pqa_directory("indexes")) + + if variable == "answers": + all_answers = [] + answer_file_location = pqa_directory("indexes") / "answers" / "docs" + if os.path.exists(answer_file_location): + for answer_file in os.listdir(answer_file_location): + all_answers.append( + get_file_timestamps(os.path.join(answer_file_location, answer_file)) + ) + with open(os.path.join(answer_file_location, answer_file)) as f: + answer = yaml.safe_load(f) + all_answers[-1].update({"answer": answer}) + all_answers = sorted( + all_answers, key=operator.itemgetter("modified_at"), reverse=True + )[:limit] + for answer in all_answers: + logger.info( + f"Q: {answer['answer']['answer']['question']}\n---\nA: {answer['answer']['answer']['answer']}\n\n\n" + ) + return all_answers + + current_settings = get_settings(pqa_directory("settings") / "settings.yaml") + + if variable == "all": + logger.info(current_settings) + return current_settings + + try: + value, _ = pop_nested_dict_recursive(current_settings, variable) + except KeyError: + logger.info(f"{variable} is not set.") + return None + else: + logger.info(f"{variable}: {value}") + return value + + +@app.command() +def clear( + variable: Annotated[ + str, + typer.Argument( + help=( + "PaperQA variable to clear, see agents.models.QueryRequest object for all settings, " + "nested options can be set using periods, ex. agent_tools.paper_directory. " + "Index names can also be used if the --index flag is set." + ) + ), + ], + index: Annotated[ + bool, + typer.Option( + "--index", + is_flag=True, + help="Index flag to indicate that this index name should be cleared.", + ), + ] = False, +) -> None: + """Clear a persistent PaperQA setting; include the --index flag to remove an index.""" + configure_agent_logging(verbosity=0) + + settings_path = pqa_directory("settings") / "settings.yaml" + + current_settings = get_settings(settings_path) + + if not index: + _ = pop_nested_dict_recursive(current_settings, variable) + with open(settings_path, "w") as f: + yaml.dump(current_settings, f) + logger.info(f"{variable} cleared!") + + elif variable in os.listdir(pqa_directory("indexes")): + shutil.rmtree(pqa_directory("indexes") / variable) + logger.info(f"Index {variable} cleared!") + else: + logger.info(f"Index {variable} not found!") + + +@app.command() +def ask( + query: Annotated[str, typer.Argument(help=("Question or task to ask of PaperQA"))], + agent_type: Annotated[ + str, + typer.Option( + help=( + "Type of agent to use, for now either " + "`OpenAIFunctionsAgent` or `fake`. `fake` uses " + "a hard-coded tool path (search->gather evidence->answer)." + ) + ), + ] = "fake", + verbosity: Annotated[ + int, typer.Option(help=("Level of verbosity from 0->2 (inclusive)")) + ] = 0, + directory: Annotated[ + Path | None, + typer.Option(help=("Directory of papers or documents to run PaperQA over.")), + ] = None, + index_directory: Annotated[ + Path | None, + typer.Option( + help=( + "Index directory to store paper index and answers. 
Default will be `~/.pqa`" + ) + ), + ] = None, + manifest_file: Annotated[ + Path | None, + typer.Option( + help=( + "Optional manifest file (CSV) location to map a relative " + "`file_location` column to `doi` or `title` columns. " + "If not used, then the file will be read by an LLM " + "which attempts to extract the title, authors and DOI." + ) + ), + ] = None, +) -> AnswerResponse: + """Query PaperQA via an agent.""" + configure_agent_logging(verbosity=verbosity) + + loop = get_loop() + + # override settings file if requested directly + to_merge = {} + + if directory is not None: + to_merge = {"agent_tools": {"paper_directory": directory}} + + if index_directory is not None: + if "agent_tools" not in to_merge: + to_merge = {"agent_tools": {"index_directory": index_directory}} + else: + to_merge["agent_tools"].update({"index_directory": index_directory}) + + if manifest_file is not None: + if "agent_tools" not in to_merge: + to_merge = {"agent_tools": {"manifest_file": manifest_file}} + else: + to_merge["agent_tools"].update({"manifest_file": manifest_file}) + + request = QueryRequest( + query=query, + **get_merged_settings( + to_merge, + settings_path=pqa_directory("settings") / "settings.yaml", + ), + ) + + return loop.run_until_complete( + agent_query( + request, + docs=None, + verbosity=verbosity, + agent_type=agent_type, + index_directory=request.agent_tools.index_directory, + ) + ) + + +@app.command("search") +def search_query( + query: Annotated[str, typer.Argument(help=("Query for keyword search"))], + index_name: Annotated[ + str, + typer.Argument( + help=( + "Name of the index to search, or use `answers`" + " to search all indexed answers" + ) + ), + ] = "answers", + index_directory: Annotated[ + Path | None, + typer.Option( + help=( + "Index directory to store paper index and answers. Default will be `~/.pqa`" + ) + ), + ] = None, +) -> list[tuple[AnswerResponse, str] | tuple[Any, str]]: + """Search using a pre-built PaperQA index.""" + configure_agent_logging(verbosity=0) + + loop = get_loop() + return loop.run_until_complete( + search( + query, + index_name=index_name, + index_directory=index_directory or pqa_directory("indexes"), + ) + ) + + +@app.command("index") +def build_index( + directory: Annotated[ + Path | None, + typer.Argument(help=("Directory of papers or documents to run PaperQA over.")), + ] = None, + index_directory: Annotated[ + Path | None, + typer.Option( + help=( + "Index directory to store paper index and answers. Default will be `~/.pqa`" + ) + ), + ] = None, + manifest_file: Annotated[ + Path | None, + typer.Option( + help=( + "Optional manifest file (CSV) location to map a relative " + "`file_location` column to `doi` or `title` columns. " + "If not used, then the file will be read by an LLM " + "which attempts to extract the title, authors and DOI." 
+ ) + ), + ] = None, + verbosity: Annotated[ + int, typer.Option(help=("Level of verbosity from 0->2 (inclusive)")) + ] = 0, +) -> SearchIndex: + """Build a PaperQA search index, this will also happen automatically upon using `ask`.""" + configure_agent_logging(verbosity=verbosity) + + to_merge = {} + + if directory is not None: + to_merge = {"agent_tools": {"paper_directory": directory}} + + if index_directory is not None: + if "agent_tools" not in to_merge: + to_merge = {"agent_tools": {"index_directory": index_directory}} + else: + to_merge["agent_tools"].update({"index_directory": index_directory}) + + if manifest_file is not None: + if "agent_tools" not in to_merge: + to_merge = {"agent_tools": {"manifest_file": manifest_file}} + else: + to_merge["agent_tools"].update({"manifest_file": manifest_file}) + + configure_agent_logging(verbosity) + + request_settings = QueryRequest( + query="", + **get_merged_settings( + to_merge, + settings_path=pqa_directory("settings") / "settings.yaml", + ), + ) + + loop = get_loop() + + return loop.run_until_complete( + get_directory_index( + directory=anyio.Path(request_settings.agent_tools.paper_directory), + index_directory=request_settings.agent_tools.index_directory, + index_name=request_settings.get_index_name( + request_settings.agent_tools.paper_directory, + request_settings.embedding, + request_settings.parsing_configuration, + ), + manifest_file=( + anyio.Path(request_settings.agent_tools.manifest_file) + if request_settings.agent_tools.manifest_file + else None + ), + embedding=request_settings.embedding, + chunk_chars=request_settings.parsing_configuration.chunksize, + overlap=request_settings.parsing_configuration.overlap, + ) + ) + + +@app.command() +def version(): + configure_agent_logging(verbosity=0) + logger.info(f"PaperQA version: {__version__}") + + +if __name__ == "__main__": + app() diff --git a/paperqa/agents/helpers.py b/paperqa/agents/helpers.py new file mode 100644 index 00000000..50387e5b --- /dev/null +++ b/paperqa/agents/helpers.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +import logging +import os +import re +from datetime import datetime +from typing import Any, cast + +from anthropic import AsyncAnthropic +from openai import AsyncOpenAI +from rich.table import Table + +from .. import ( + AnthropicLLMModel, + Docs, + OpenAILLMModel, + embedding_model_factory, + llm_model_factory, +) +from ..llms import LangchainLLMModel +from .models import AnswerResponse, QueryRequest + +logger = logging.getLogger(__name__) + + +def get_year(ts: datetime | None = None) -> str: + """Get the year from the input datetime, otherwise using the current datetime.""" + if ts is None: + ts = datetime.now() + return ts.strftime("%Y") + + +async def openai_get_search_query( + question: str, + count: int, + template: str | None = None, + llm: str = "gpt-4o-mini", + temperature: float = 1.0, +) -> list[str]: + if isinstance(template, str): + if not ( + "{count}" in template and "{question}" in template and "{date}" in template + ): + logger.warning( + "Template does not contain {count}, {question} and {date} variables. Ignoring template." + ) + template = None + + else: + # partial formatting + search_prompt = template.replace("{date}", get_year()) + + if template is None: + search_prompt = ( + "We want to answer the following question: {question} \n" + "Provide {count} unique keyword searches (one search per line) and year ranges " + "that will find papers to help answer the question. " + "Do not use boolean operators. 
" + "Make sure not to repeat searches without changing the keywords or year ranges. " + "Make some searches broad and some narrow. " + "Use this format: [keyword search], [start year]-[end year]. " + "where end year is optional. " + f"The current year is {get_year()}." + ) + + if "gpt" not in llm: + raise ValueError( + f"Invalid llm: {llm}, note a GPT model must be used for the fake agent search." + ) + client = AsyncOpenAI() + model = OpenAILLMModel(config={"model": llm, "temperature": temperature}) + chain = model.make_chain(client, prompt=search_prompt, skip_system=True) + result = await chain({"question": question, "count": count}) # type: ignore[call-arg] + search_query = result.text + queries = [s for s in search_query.split("\n") if len(s) > 3] # noqa: PLR2004 + # remove "2.", "3.", etc. -- https://regex101.com/r/W2f7F1/1 + queries = [re.sub(r"^\d+\.\s*", "", q) for q in queries] + # remove quotes + return [re.sub(r"\"", "", q) for q in queries] + + +def table_formatter( + objects: list[tuple[AnswerResponse | Docs, str]], max_chars_per_column: int = 2000 +) -> Table: + example_object, _ = objects[0] + if isinstance(example_object, AnswerResponse): + table = Table(title="Prior Answers") + table.add_column("Question", style="cyan") + table.add_column("Answer", style="magenta") + for obj, _ in objects: + table.add_row( + cast(AnswerResponse, obj).answer.question[:max_chars_per_column], + cast(AnswerResponse, obj).answer.answer[:max_chars_per_column], + ) + return table + if isinstance(example_object, Docs): + table = Table(title="PDF Search") + table.add_column("Title", style="cyan") + table.add_column("File", style="magenta") + for obj, filename in objects: + table.add_row( + cast(Docs, obj).texts[0].doc.title[:max_chars_per_column], filename # type: ignore[attr-defined] + ) + return table + raise NotImplementedError( + f"Object type {type(example_object)} can not be converted to table." 
+ ) + + +# Index 0 is for prompt tokens, index 1 is for completion tokens +costs: dict[str, tuple[float, float]] = { + "claude-2": (11.02 / 10**6, 32.68 / 10**6), + "claude-instant-1": (1.63 / 10**6, 5.51 / 10**6), + "claude-3-sonnet-20240229": (3 / 10**6, 15 / 10**6), + "claude-3-5-sonnet-20240620": (3 / 10**6, 15 / 10**6), + "claude-3-opus-20240229": (15 / 10**6, 75 / 10**6), + "babbage-002": (0.0004 / 10**3, 0.0004 / 10**3), + "gpt-3.5-turbo": (0.0010 / 10**3, 0.0020 / 10**3), + "gpt-3.5-turbo-1106": (0.0010 / 10**3, 0.0020 / 10**3), + "gpt-3.5-turbo-0613": (0.0010 / 10**3, 0.0020 / 10**3), + "gpt-3.5-turbo-0301": (0.0010 / 10**3, 0.0020 / 10**3), + "gpt-3.5-turbo-0125": (0.0005 / 10**3, 0.0015 / 10**3), + "gpt-4-1106-preview": (0.010 / 10**3, 0.030 / 10**3), + "gpt-4-0125-preview": (0.010 / 10**3, 0.030 / 10**3), + "gpt-4-turbo-2024-04-09": (10 / 10**6, 30 / 10**6), + "gpt-4-turbo": (10 / 10**6, 30 / 10**6), + "gpt-4": (0.03 / 10**3, 0.06 / 10**3), + "gpt-4-0613": (0.03 / 10**3, 0.06 / 10**3), + "gpt-4-0314": (0.03 / 10**3, 0.06 / 10**3), + "gpt-4o": (2.5 / 10**6, 10 / 10**6), + "gpt-4o-2024-05-13": (5 / 10**6, 15 / 10**6), + "gpt-4o-2024-08-06": (2.5 / 10**6, 10 / 10**6), + "gpt-4o-mini": (0.15 / 10**6, 0.60 / 10**6), + "gemini-1.5-flash": (0.35 / 10**6, 0.35 / 10**6), + "gemini-1.5-pro": (3.5 / 10**6, 10.5 / 10**6), + # supported Anyscale models per + # https://docs.anyscale.com/endpoints/text-generation/query-a-model + "meta-llama/Meta-Llama-3-8B-Instruct": (0.15 / 10**6, 0.15 / 10**6), + "meta-llama/Meta-Llama-3-70B-Instruct": (1.0 / 10**6, 1.0 / 10**6), + "mistralai/Mistral-7B-Instruct-v0.1": (0.15 / 10**6, 0.15 / 10**6), + "mistralai/Mixtral-8x7B-Instruct-v0.1": (1.0 / 10**6, 1.0 / 10**6), + "mistralai/Mixtral-8x22B-Instruct-v0.1": (1.0 / 10**6, 1.0 / 10**6), +} + + +def compute_model_token_cost(model: str, tokens: int, is_completion: bool) -> float: + if model in costs: # Prefer our internal costs model + model_costs: tuple[float, float] = costs[model] + else: + logger.warning(f"Model {model} not found in costs.") + return 0.0 + return tokens * model_costs[int(is_completion)] + + +def compute_total_model_token_cost(token_counts: dict[str, list[int]]) -> float: + """Sum the token counts for each model and return the total cost.""" + cost = 0.0 + for model, tokens in token_counts.items(): + if sum(tokens) > 0: + cost += compute_model_token_cost( + model, tokens=tokens[0], is_completion=False + ) + compute_model_token_cost(model, tokens=tokens[1], is_completion=True) + return cost + + +# the defaults here should be (about) the same as in QueryRequest +def update_doc_models(doc: Docs, request: QueryRequest | None = None): + if request is None: + request = QueryRequest() + client: Any = None + + if request.llm.startswith("gemini"): + doc.llm_model = LangchainLLMModel(name=request.llm) + doc.summary_llm_model = LangchainLLMModel(name=request.summary_llm) + else: + doc.llm_model = llm_model_factory(request.llm) + doc.summary_llm_model = llm_model_factory(request.summary_llm) + + # set temperatures + doc.llm_model.config["temperature"] = request.temperature + doc.summary_llm_model.config["temperature"] = request.temperature + + if isinstance(doc.llm_model, OpenAILLMModel): + if request.llm.startswith( + ("meta-llama/Meta-Llama-3-", "mistralai/Mistral-", "mistralai/Mixtral-") + ): + client = AsyncOpenAI( + base_url=os.environ.get("ANYSCALE_BASE_URL"), + api_key=os.environ.get("ANYSCALE_API_KEY"), + ) + logger.info(f"Using Anyscale (via OpenAI client) for {request.llm}") + else: + client 
= AsyncOpenAI() + elif isinstance(doc.llm_model, AnthropicLLMModel): + client = AsyncAnthropic() + elif isinstance(doc.llm_model, LangchainLLMModel): + from langchain_google_vertexai import ( + ChatVertexAI, + HarmBlockThreshold, + HarmCategory, + ) + + # we have to convert system to human because system is unsupported + # Also we do get blocked content, so adjust thresholds + client = ChatVertexAI( + model=request.llm, + safety_settings={ + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, + }, + convert_system_message_to_human=True, + ) + else: + raise TypeError(f"Unsupported LLM model: {doc.llm_model}") + + doc._client = client # set client, since could be just unpickled. + doc._embedding_client = AsyncOpenAI() # hard coded to OpenAI for now + + doc.texts_index.embedding_model = embedding_model_factory( + request.embedding, **(request.texts_index_embedding_config or {}) + ) + doc.docs_index.embedding_model = embedding_model_factory( + request.embedding, **(request.docs_index_embedding_config or {}) + ) + doc.texts_index.mmr_lambda = request.texts_index_mmr_lambda + doc.docs_index.mmr_lambda = request.docs_index_mmr_lambda + doc.embedding = request.embedding + doc.max_concurrent = request.max_concurrent + doc.prompts = request.prompts + Docs.make_llm_names_consistent(doc) + + logger.debug( + f"update_doc_models: {doc.name}" + f" | {(doc.llm_model.config)} | {(doc.summary_llm_model.config)}" + f" | {doc.docs_index.__class__}" + ) diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py new file mode 100644 index 00000000..6acd9485 --- /dev/null +++ b/paperqa/agents/main.py @@ -0,0 +1,336 @@ +from __future__ import annotations + +import logging +import os +from typing import Any, cast +from unittest.mock import patch + +from langchain.agents import AgentExecutor, BaseSingleActionAgent, ZeroShotAgent +from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent +from langchain_community.callbacks import OpenAICallbackHandler +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler, BaseCallbackManager, Callbacks +from langchain_core.messages import SystemMessage +from langchain_openai import ChatOpenAI +from rich.console import Console + +from ..docs import Docs +from ..types import Answer +from ..utils import pqa_directory +from .helpers import openai_get_search_query, table_formatter, update_doc_models +from .models import ( + AgentCallback, + AgentStatus, + AnswerResponse, + QueryRequest, + SimpleProfiler, +) +from .search import SearchDocumentStorage, SearchIndex +from .tools import ( + EmptyDocsError, + GatherEvidenceTool, + GenerateAnswerTool, + PaperSearchTool, + SharedToolState, + query_to_tools, + status, +) + +logger = logging.getLogger(__name__) +agent_logger = logging.getLogger(__name__ + ".agent_callers") + + +async def agent_query( + query: str | QueryRequest, + docs: Docs | None = None, + agent_type: str = "OpenAIFunctionsAgent", + verbosity: int = 0, + index_directory: str | os.PathLike | None = None, +) -> AnswerResponse: + + if isinstance(query, str): + query = QueryRequest(query=query) + + if docs is None: + docs = Docs() + + if index_directory is None: + index_directory = pqa_directory("indexes") + + # in-place modification 
of the docs object to match query + update_doc_models( + docs, + query, + ) + + search_index = SearchIndex( + fields=SearchIndex.REQUIRED_FIELDS | {"question"}, + index_name="answers", + index_directory=index_directory, + storage=SearchDocumentStorage.JSON_MODEL_DUMP, + ) + + response = await run_agent(docs, query, agent_type) + + agent_logger.debug(f"agent_response: {response}") + truncation_chars = 1_000_000 if verbosity > 1 else 1500 * (verbosity + 1) + agent_logger.info( + f"[bold blue]Answer: {response.answer.answer[:truncation_chars]}" + f'{"...(truncated)" if len(response.answer.answer) > truncation_chars else ""}[/bold blue]' + ) + + await search_index.add_document( + { + "file_location": str(response.answer.id), + "body": response.answer.answer or "", + "question": response.answer.question, + }, + document=response, + ) + + await search_index.save_index() + + return response + + +async def run_agent( + docs: Docs, + query: QueryRequest, + agent_type: str = "OpenAIFunctionsAgent", +) -> AnswerResponse: + """ + Run an agent. + + Args: + docs: Docs to run upon. + query: Query to answer. + agent_type: Agent type to look up in LANGCHAIN_AGENT_TYPES, or "fake" to run + a hard-coded tool sequence (search, gather evidence, answer). + + Returns: + AnswerResponse containing the resultant answer, token counts, and agent status. + """ + profiler = SimpleProfiler() + outer_profile_name = f"agent-{agent_type}-{query.agent_llm}" + profiler.start(outer_profile_name) + + logger.info( + f"Beginning agent {agent_type!r} run with question {query.query!r} and full" + f" query {query.model_dump()}." + ) + + if agent_type == "fake": + answer, agent_status = await run_fake_agent(query, docs) + else: + answer, agent_status = await run_langchain_agent( + query, docs, agent_type, profiler + ) + + if "cannot answer" in answer.answer.lower() and agent_status != AgentStatus.TIMEOUT: + agent_status = AgentStatus.UNSURE + # stop after, so overall isn't reported as a long-running step. + logger.info( + f"Finished agent {agent_type!r} run with question {query.query!r} and status" + f" {agent_status}." 
+ ) + return AnswerResponse( + answer=answer, + usage=answer.token_counts, + status=agent_status, + ) + + +async def run_fake_agent( + query: QueryRequest, + docs: Docs, +) -> tuple[Answer, AgentStatus]: + answer = Answer(question=query.query, dockey_filter=set(), id=query.id) + tools = query_to_tools(query, state=SharedToolState(docs=docs, answer=answer)) + search_tool = cast( + PaperSearchTool, + next( + filter( + lambda x: x.name == PaperSearchTool.__fields__["name"].default, tools + ) + ), + ) + gather_evidence_tool = cast( + GatherEvidenceTool, + next( + filter( + lambda x: x.name == GatherEvidenceTool.__fields__["name"].default, tools + ) + ), + ) + + generate_answer_tool = cast( + GenerateAnswerTool, + next( + filter( + lambda x: x.name == GenerateAnswerTool.__fields__["name"].default, tools + ) + ), + ) + # seed docs with keyword search + for search in await openai_get_search_query( + answer.question, llm=query.llm, count=3 + ): + await search_tool.arun(search) + + await gather_evidence_tool.arun(tool_input=answer.question) + + await generate_answer_tool.arun(tool_input=answer.question) + + return answer, AgentStatus.SUCCESS + + +LANGCHAIN_AGENT_TYPES: dict[str, type[BaseSingleActionAgent]] = { + "ReactAgent": ZeroShotAgent, + "OpenAIFunctionsAgent": OpenAIFunctionsAgent, +} + + +async def run_langchain_agent( + query: QueryRequest, + docs: Docs, + agent_type: str, + profiler: SimpleProfiler, + timeout: float | None = None, # noqa: ASYNC109 +) -> tuple[Answer, AgentStatus]: + answer = Answer(question=query.query, dockey_filter=set(), id=query.id) + shared_callbacks: list[BaseCallbackHandler] = [ + AgentCallback( + profiler, name=f"step-{agent_type}-{query.agent_llm}", answer_id=answer.id + ), + ] + tools = query_to_tools( + query, + state=SharedToolState(docs=docs, answer=answer), + callbacks=shared_callbacks, + ) + try: + search_tool = next( + filter( + lambda x: x.name == PaperSearchTool.__fields__["name"].default, tools + ) + ) + except StopIteration: + search_tool = None + answer_tool = cast( + GenerateAnswerTool, + next( + filter( + lambda x: x.name == GenerateAnswerTool.__fields__["name"].default, tools + ) + ), + ) + + # optionally use the search tool before the agent + if search_tool is not None and query.agent_tools.should_pre_search: + logger.debug("Running search tool before agent choice.") + await search_tool.arun(answer.question) + else: + logger.debug("Skipping search tool before agent choice.") + + llm = ChatOpenAI( + model=query.agent_llm, + request_timeout=timeout or query.agent_tools.timeout / 2.0, + temperature=query.temperature, + ) + agent_status = AgentStatus.SUCCESS + cost_callback = OpenAICallbackHandler() + agent_instance = LANGCHAIN_AGENT_TYPES[agent_type].from_llm_and_tools( + llm, + tools, + system_message=( + SystemMessage(content=query.agent_tools.agent_system_prompt) + if query.agent_tools.agent_system_prompt + else None + ), + ) + orig_aplan = agent_instance.aplan + agent_exec_instance = AgentExecutor.from_agent_and_tools( + tools=tools, + agent=agent_instance, + return_intermediate_steps=True, + handle_parsing_errors=True, + max_execution_time=query.agent_tools.timeout, + callbacks=[*shared_callbacks, cost_callback], + **(query.agent_tools.agent_config or {}), + ) + + async def aplan_with_injected_callbacks( + intermediate_steps: list[tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs, + ) -> AgentAction | AgentFinish: + # Work around https://github.com/langchain-ai/langchain/issues/22703 + for callback in 
cast(list[BaseCallbackHandler], agent_exec_instance.callbacks): + cast(BaseCallbackManager, callbacks).add_handler(callback, inherit=False) + return await orig_aplan(intermediate_steps, callbacks, **kwargs) + + try: + # Patch at instance (not class) level to avoid concurrency issues, and we have + # to patch the dict to work around Pydantic's BaseModel.__setattr__'s validations + with patch.dict( + agent_instance.__dict__, {"aplan": aplan_with_injected_callbacks} + ): + call_response = await agent_exec_instance.ainvoke( + input={ + # NOTE: str.format still works even if the prompt doesn't have + # template fields like 'status' or 'gen_answer_tool_name' + "input": query.agent_tools.agent_prompt.format( + question=answer.question, + status=await status(docs, answer), + gen_answer_tool_name=answer_tool.name, + ) + } + ) + except TimeoutError: + call_response = {"output": "Agent stopped", "intermediate_steps": []} + except EmptyDocsError: + call_response = { + "output": "Agent failed due to failed search", + "intermediate_steps": [], + } + agent_status = AgentStatus.FAIL + + async with profiler.timer("agent-accounting"): + # TODO: move agent trace to LangChain callback + if "Agent stopped" in call_response["output"]: + # Log that this agent has gone over timeout, and then answer directly + logger.warning( + f"Agent timeout after {query.agent_tools.timeout}-sec, just answering." + ) + await answer_tool.arun(answer.question) + agent_status = AgentStatus.TIMEOUT + + return answer, agent_status + + +async def search( + query: str, + index_name: str = "answers", + index_directory: str | os.PathLike | None = None, +) -> list[tuple[AnswerResponse, str] | tuple[Any, str]]: + + search_index = SearchIndex( + ["file_location", "body", "question"], + index_name=index_name, + index_directory=index_directory or pqa_directory("indexes"), + storage=SearchDocumentStorage.JSON_MODEL_DUMP, + ) + + results = [ + (AnswerResponse(**a[0]) if index_name == "answers" else a[0], a[1]) + for a in await search_index.query(query=query, keep_filenames=True) + ] + + if results: + console = Console(record=True) + # Render the table to a string + console.print(table_formatter(results)) + else: + agent_logger.info("No results found.") + + return results diff --git a/paperqa/agents/models.py b/paperqa/agents/models.py new file mode 100644 index 00000000..c1434966 --- /dev/null +++ b/paperqa/agents/models.py @@ -0,0 +1,515 @@ +from __future__ import annotations + +import asyncio +import logging +import os +import time +from contextlib import asynccontextmanager +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any, ClassVar, Collection, assert_never +from uuid import UUID, uuid4 + +from langchain_core.callbacks import AsyncCallbackHandler +from langchain_core.messages import BaseMessage, messages_to_dict +from langchain_core.outputs import ChatGeneration, LLMResult +from openai import AsyncOpenAI +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + ValidationInfo, + computed_field, + field_validator, + model_validator, +) +from typing_extensions import Protocol + +from .. import ( + Answer, + OpenAILLMModel, + PromptCollection, + llm_model_factory, +) +from ..utils import hexdigest +from ..version import __version__ +from .prompts import STATIC_PROMPTS + +logger = logging.getLogger(__name__) + + +class SupportsPickle(Protocol): + """Type protocol for typing any object that supports pickling.""" + + def __reduce__(self) -> str | tuple[Any, ...]: ... 
+ def __getstate__(self) -> object: ... + def __setstate__(self, state: object) -> None: ... + + +class AgentStatus(str, Enum): + # FAIL - no answer could be generated + FAIL = "fail" + # SUCCESS - answer was generated + SUCCESS = "success" + # TIMEOUT - agent took too long, but an answer was generated + TIMEOUT = "timeout" + # UNSURE - the agent was unsure, but an answer is present + UNSURE = "unsure" + + +class AgentPromptCollection(BaseModel): + """Configuration for the agent.""" + + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + + agent_system_prompt: str | None = Field( + # Matching https://github.com/langchain-ai/langchain/blob/langchain%3D%3D0.2.3/libs/langchain/langchain/agents/openai_functions_agent/base.py#L213-L215 + default="You are a helpful AI assistant.", + description="Optional system prompt message to precede the below agent_prompt.", + ) + + # TODO: make this prompt more minimalist, instead improving tool descriptions so + # how to use them together can be intuited, and exposing them for configuration + agent_prompt: str = ( + "Answer question: {question}" + "\n\nSearch for papers, gather evidence, collect papers cited in evidence then re-gather evidence, and answer." + " Gathering evidence will do nothing if you have not done a new search or collected new papers." + " If you do not have enough evidence to generate a good answer, you can:" + "\n- Search for more papers (preferred)" + "\n- Collect papers cited by previous evidence (preferred)" + "\n- Gather more evidence using a different phrase" + "\nIf you search for more papers or collect new papers cited by previous evidence," + " remember to gather evidence again." + " Once you have five or more pieces of evidence from multiple sources, or you have tried a few times, " + "call {gen_answer_tool_name} tool. The {gen_answer_tool_name} tool output is visible to the user, " + "so you do not need to restate the answer and can simply terminate if the answer looks sufficient. " + "The current status of evidence/papers/cost is {status}" + ) + paper_directory: str | os.PathLike = Field( + default=Path.cwd(), + description=( + "Local directory which contains the papers to be indexed and searched." + ), + ) + index_directory: str | os.PathLike | None = Field( + default=None, + description=( + "Directory to store the PQA generated search index, configuration, and answer indexes." + ), + ) + manifest_file: str | os.PathLike | None = Field( + default=None, + description=( + "Optional manifest CSV, containing columns which are attributes for a DocDetails object. " + "Only 'file_location','doi', and 'title' will be used when indexing." + ), + ) + search_count: int = 8 + wipe_context_on_answer_failure: bool = True + timeout: float = Field( + default=500.0, + description=( + "Matches LangChain AgentExecutor.max_execution_time (seconds), the timeout" + " on agent execution." + ), + ) + should_pre_search: bool = Field( + default=False, + description="If set to true, run the search tool before invoking agent.", + ) + agent_config: dict[str, Any] | None = Field( + default=None, + description="Optional keyword argument configuration for the agent.", + ) + tool_names: Collection[str] | None = Field( + default=None, + description=( + "Optional override on the tools to provide the agent. Leaving as the" + " default of None will use a minimal toolset of the paper search, gather" + " evidence, collect cited papers from evidence, and gen answer. 
If passing tool" + " names (non-default route), at least the gen answer tool must be supplied." + ), + ) + + @field_validator("tool_names") + @classmethod + def validate_tool_names(cls, v: set[str] | None) -> set[str] | None: + if v is None: + return None + # imported here to avoid circular imports + from .main import GenerateAnswerTool + + answer_tool_name = GenerateAnswerTool.__fields__["name"].default + if answer_tool_name not in v: + raise ValueError( + f"If using an override, must contain at least the {answer_tool_name}." + ) + return v + + +class ParsingOptions(str, Enum): + PAPERQA_DEFAULT = "paperqa_default" + + def available_for_inference(self) -> list[ParsingOptions]: + return [self.PAPERQA_DEFAULT] # type: ignore[list-item] + + def get_parse_type(self, config: ParsingConfiguration) -> str: + if self == ParsingOptions.PAPERQA_DEFAULT: + return config.parser_version_string + assert_never() + + +class ChunkingOptions(str, Enum): + SIMPLE_OVERLAP = "simple_overlap" + + @property + def valid_parsings(self) -> list[ParsingOptions]: + # Note that SIMPLE_OVERLAP must be valid for all by default + # TODO: implement for future parsing options + valid_parsing_dict: dict[str, list[ParsingOptions]] = {} + return valid_parsing_dict.get(self.value, []) + + +class ImpossibleParsingError(Exception): + """Error to throw when a parsing is impossible.""" + + LOG_METHOD_NAME: ClassVar[str] = "warning" + + +class ParsingConfiguration(BaseModel): + """Holds a superset of params and methods needed for each algorithm.""" + + ordered_parser_preferences: list[ParsingOptions] = [ + ParsingOptions.PAPERQA_DEFAULT, + ] + chunksize: int = 6000 + overlap: int = 100 + chunking_algorithm: ChunkingOptions = ChunkingOptions.SIMPLE_OVERLAP + + def chunk_type(self, chunking_selection: ChunkingOptions | None = None) -> str: + """Future chunking implementations (i.e. 
by section) will get an elif clause here.""" + if chunking_selection is None: + chunking_selection = self.chunking_algorithm + if chunking_selection == ChunkingOptions.SIMPLE_OVERLAP: + return ( + f"{self.parser_version_string}|{chunking_selection.value}" + f"|tokens={self.chunksize}|overlap={self.overlap}" + ) + assert_never() + + @property + def parser_version_string(self) -> str: + return f"paperqa-{__version__}" + + def is_chunking_valid_for_parsing(self, parsing: str): + # must map the parsings because they won't include versions by default + return ( + self.chunking_algorithm == ChunkingOptions.SIMPLE_OVERLAP + or parsing + in { # type: ignore[unreachable] + p.get_parse_type(self) for p in self.chunking_algorithm.valid_parsings + } + ) + + +class MismatchedModelsError(Exception): + """Error to throw when model clients clash .""" + + LOG_METHOD_NAME: ClassVar[str] = "warning" + + +class QueryRequest(BaseModel): + query: str = "" + id: UUID = Field( + default_factory=uuid4, + description="Identifier which will be propagated to the Answer object.", + ) + llm: str = "gpt-4o-2024-08-06" + agent_llm: str = Field( + default="gpt-4o-2024-08-06", + description="Chat model to use for agent planning", + ) + summary_llm: str = "gpt-4o-2024-08-06" + length: str = "about 200 words, but can be longer if necessary" + summary_length: str = "about 100 words" + max_sources: int = 10 + consider_sources: int = 16 + named_prompt: str | None = None + # if you change this to something other than default + # modify code below in update_prompts + prompts: PromptCollection = Field( + default=STATIC_PROMPTS["default"], validate_default=True + ) + agent_tools: AgentPromptCollection = Field(default_factory=AgentPromptCollection) + texts_index_mmr_lambda: float = 1.0 + texts_index_embedding_config: dict[str, Any] | None = None + docs_index_mmr_lambda: float = 1.0 + docs_index_embedding_config: dict[str, Any] | None = None + parsing_configuration: ParsingConfiguration = ParsingConfiguration() + embedding: str = "text-embedding-3-small" + # concurrent number of summary calls to use inside Doc object + max_concurrent: int = 20 + temperature: float = 0.0 + summary_temperature: float = 0.0 + # at what size should we start using adoc_match? 
+ adoc_match_threshold: int = 500 + # Should we filter out "Extra Background Information" citations + # which come from pre-step in paper-qa algorithm + filter_extra_background: bool = True + # provides post-hoc linkage of request to a docs object + # NOTE: this isn't a unique field, on the user to keep straight + _docs_name: str | None = PrivateAttr(default=None) + + # strict validation for now + model_config = ConfigDict(extra="forbid") + + @computed_field # type: ignore[misc] + @property + def docs_name(self) -> str | None: + return self._docs_name + + @model_validator(mode="after") + def llm_models_must_match(self) -> QueryRequest: + llm = llm_model_factory(self.llm) + summary_llm = llm_model_factory(self.summary_llm) + if type(llm) is not type(summary_llm): + raise MismatchedModelsError( + f"Answer LLM and summary LLM types must match: {type(llm)} != {type(summary_llm)}" + ) + return self + + @field_validator("prompts") + def update_prompts( + cls, # noqa: N805 + v: PromptCollection, + info: ValidationInfo, + ) -> PromptCollection: + values = info.data + if values["named_prompt"] is not None: + if values["named_prompt"] not in STATIC_PROMPTS: + raise ValueError( + f"Named prompt {values['named_prompt']} not in {list(STATIC_PROMPTS.keys())}" + ) + v = STATIC_PROMPTS[values["named_prompt"]] + if values["summary_llm"] == "none": + v.skip_summary = True + # for simplicity (it is not used anywhere) + # so that Docs doesn't break when we don't have a summary_llm + values["summary_llm"] = "gpt-4o-mini" + return v + + def set_docs_name(self, docs_name: str) -> None: + """Set the internal docs name for tracking.""" + self._docs_name = docs_name + + @staticmethod + def get_index_name( + paper_directory: str | os.PathLike, + embedding: str, + parsing_configuration: ParsingConfiguration, + ) -> str: + + # index name should use an absolute path + # this way two different folders where the + # user locally uses '.' will make different indexes + if isinstance(paper_directory, Path): + paper_directory = str(paper_directory.absolute()) + + index_fields = "|".join( + [ + str(paper_directory), # cast for typing + embedding, + str(parsing_configuration.chunksize), + str(parsing_configuration.overlap), + parsing_configuration.chunking_algorithm, + ] + ) + + return f"pqa_index_{hexdigest(index_fields)}" + + +class AnswerResponse(BaseModel): + answer: Answer + usage: dict[str, list[int]] + bibtex: dict[str, str] | None = None + status: AgentStatus + timing_info: dict[str, dict[str, float]] | None = None + duration: float = 0.0 + # A placeholder for interesting statistics we can show users + # about the answer, such as the number of sources used, etc. + stats: dict[str, str] | None = None + + @field_validator("answer") + def strip_answer( + cls, v: Answer, info: ValidationInfo # noqa: ARG002, N805 + ) -> Answer: + # This modifies in place, this is fine + # because when a response is being constructed, + # we should be done with the Answer object + v.filter_content_for_user() + return v + + async def get_summary(self, llm_model="gpt-4-turbo") -> str: + sys_prompt = ( + "Revise the answer to a question to be a concise SMS message. " + "Use abbreviations or emojis if necessary." 
+ ) + model = OpenAILLMModel(config={"model": llm_model, "temperature": 0.1}) + chain = model.make_chain( + AsyncOpenAI(), prompt="{question}\n\n{answer}", system_prompt=sys_prompt + ) + result = await chain({"question": self.answer.question, "answer": self.answer.answer}) # type: ignore[call-arg] + return result.text.strip() + + +class TimerData(BaseModel): + start_time: float = Field(default_factory=time.time) # noqa: FURB111 + durations: list[float] = Field(default_factory=list) + + +class SimpleProfiler(BaseModel): + """Basic profiler with start/stop and named timers. + + The format for this logger needs to be strictly followed, as downstream google + cloud monitoring is based on the following + # [Profiling] {**name** of timer} | {**elapsed** time of function} | {**__version__** of PaperQA} + """ + + timers: dict[str, list[float]] = {} + running_timers: dict[str, TimerData] = {} + uid: UUID = Field(default_factory=uuid4) + + @asynccontextmanager + async def timer(self, name: str): + start_time = asyncio.get_running_loop().time() + try: + yield + finally: + end_time = asyncio.get_running_loop().time() + elapsed = end_time - start_time + self.timers.setdefault(name, []).append(elapsed) + logger.info( + f"[Profiling] | UUID: {self.uid} | NAME: {name} | TIME: {elapsed:.3f}s | VERSION: {__version__}" + ) + + def start(self, name: str) -> None: + try: + self.running_timers[name] = TimerData() + except RuntimeError: # No running event loop (not in async) + self.running_timers[name] = TimerData(start_time=time.time()) + + def stop(self, name: str): + timer_data = self.running_timers.pop(name, None) + if timer_data: + try: + t_stop: float = asyncio.get_running_loop().time() + except RuntimeError: # No running event loop (not in async) + t_stop = time.time() + elapsed = t_stop - timer_data.start_time + self.timers.setdefault(name, []).append(elapsed) + logger.info( + f"[Profiling] | UUID: {self.uid} | NAME: {name} | TIME: {elapsed:.3f}s | VERSION: {__version__}" + ) + else: + logger.warning(f"Timer {name} not running") + + def results(self) -> dict[str, dict[str, float]]: + result = {} + for name, durations in self.timers.items(): + mean = sum(durations) / len(durations) + result[name] = { + "low": min(durations), + "mean": mean, + "max": max(durations), + "total": sum(durations), + } + return result + + +class AgentCallback(AsyncCallbackHandler): + """ + Callback handler used to monitor the agent, for debugging. + + Its various capabilities include: + - Chain start --> error/stop: profile runtime + - Tool start: count tool invocations + - LLM start --> error/stop: insert into LLMResultDB + + NOTE: this is not a thread safe implementation since start(s)/end(s) mutate self. 
+    """
+
+    def __init__(
+        self, profiler: SimpleProfiler, name: str, answer_id: UUID, **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self.profiler = profiler
+        self.name = name
+        self._tool_starts: list[str] = []
+        self._answer_id = answer_id
+        # This will be None before/after a completion, and a dict during one
+        self._llm_result_db_kwargs: dict[str, Any] | None = None
+
+    @property
+    def tool_invocations(self) -> list[str]:
+        return self._tool_starts
+
+    async def on_chain_start(self, *args, **kwargs) -> None:
+        await super().on_chain_start(*args, **kwargs)
+        self.profiler.start(self.name)
+
+    async def on_chain_end(self, *args, **kwargs) -> None:
+        await super().on_chain_end(*args, **kwargs)
+        self.profiler.stop(self.name)
+
+    async def on_chain_error(self, *args, **kwargs) -> None:
+        await super().on_chain_error(*args, **kwargs)
+        self.profiler.stop(self.name)
+
+    async def on_tool_start(
+        self, serialized: dict[str, Any], input_str: str, **kwargs
+    ) -> None:
+        await super().on_tool_start(serialized, input_str, **kwargs)
+        self._tool_starts.append(serialized["name"])
+
+    async def on_chat_model_start(
+        self,
+        serialized: dict[str, Any],  # noqa: ARG002
+        messages: list[list[BaseMessage]],
+        **kwargs,
+    ) -> None:
+        # NOTE: don't call super(), as it changes semantics
+        if len(messages) != 1:
+            raise NotImplementedError(f"Didn't handle shape of messages {messages}.")
+        self._llm_result_db_kwargs = {
+            "answer_id": self._answer_id,
+            "name": f"tool_selection:{len(messages[0])}",
+            "prompt": {
+                "messages": messages_to_dict(messages[0]),
+                # SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-functions
+                "functions": kwargs["invocation_params"]["functions"],
+                "tool_history": self.tool_invocations,
+            },
+            "model": kwargs["invocation_params"]["model"],
+            "date": datetime.now().isoformat(),
+        }
+
+    async def on_llm_end(self, response: LLMResult, **kwargs) -> None:
+        await super().on_llm_end(response, **kwargs)
+        if (
+            len(response.generations) != 1
+            or len(response.generations[0]) != 1
+            or not isinstance(response.generations[0][0], ChatGeneration)
+        ):
+            raise NotImplementedError(
+                f"Didn't handle shape of generations {response.generations}."
+            )
+        if self._llm_result_db_kwargs is None:
+            raise NotImplementedError(
+                "There should have been an LLM result populated here by now."
+            )
+        if not isinstance(response.llm_output, dict):
+            raise NotImplementedError(
+                f"Expected llm_output to be a dict, but got {response.llm_output}."
+            )
diff --git a/paperqa/agents/prompts.py b/paperqa/agents/prompts.py
new file mode 100644
index 00000000..11477fc6
--- /dev/null
+++ b/paperqa/agents/prompts.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+from paperqa import PromptCollection
+
+# I wanted to make this an Enum
+# but there is a metaclass conflict
+# so we must instead have some logic here
+# and some logic on named_prompt in the QueryRequest model
+# https://github.com/pydantic/pydantic/issues/2173
+STATIC_PROMPTS: dict[str, PromptCollection] = {
+    "default": PromptCollection(
+        qa=(
+            "Answer the question below with the context.\n\n"
+            "Context:\n\n{context}\n\n----\n\n"
+            "Question: {question}\n\n"
+            "Write an answer based on the context. "
+            "If the context provides insufficient information and "
+            "the question cannot be directly answered, reply "
+            '"I cannot answer." '
+            "For each part of your answer, indicate which sources most support "
+            "it via citation keys at the end of sentences, "
+            "like (Example2012Example pages 3-4). Only cite from the context "
+            "below and only use the valid keys. "
+            "Write in the style of a direct email containing only key details, equations, and quantities. "
+            'Avoid using adverb phrases like "furthermore", "additionally", and "moreover." '
+            "This will go directly onto a website for public viewing, so do not include any "
+            "process details about following these instructions.\n\n"
+            "Answer ({answer_length}):\n"
+        ),
+        select=(
+            "Select papers that may help answer the question below. "
+            "Papers are listed as $KEY: $PAPER_INFO. "
+            "Return a list of keys, separated by commas. "
+            'Return "None", if no papers are applicable. '
+            "Choose papers that are relevant, from reputable sources, and timely "
+            "(if the question requires timely information). \n\n"
+            "Question: {question}\n\n"
+            "Papers: {papers}\n\n"
+            "Selected keys:"
+        ),
+        pre=(
+            "We are collecting background information for the question/task below. "
+            "Provide a brief summary of definitions, acronyms, or background information (about 50 words) that "
+            "could help answer the question. Do not answer it directly. Ignore formatting instructions. "
+            "Do not answer if there is nothing to contribute. "
+            "\n\nQuestion:\n{question}\n\n"
+        ),
+        post=None,
+        system=(
+            "Answer in a direct and concise tone. "
+            "Your audience is an expert, so be highly specific. "
+            "If there are ambiguous terms or acronyms, be explicit."
+        ),
+        skip_summary=False,
+        json_summary=True,
+        summary_json=(
+            "Excerpt from {citation}\n\n----\n\n{text}\n\n----\n\n"
+            "Query: {question}\n\n"
+        ),
+        summary_json_system="Provide a summary of the excerpt that could help answer the question based on the excerpt. "  # noqa: E501
+        "The excerpt may be irrelevant. Do not directly answer the question - only summarize relevant information. "
+        "Respond with the following JSON format:\n\n"
+        '{{\n"summary": "...",\n"relevance_score": "..."}}\n\n'
+        "where `summary` is relevant information from text ({summary_length}), "
+        "and `relevance_score` is "
+        "the relevance of `summary` to answer the question (integer out of 10).",
+    ),
+    "wikicrow": PromptCollection(
+        qa=(
+            "Answer the question below with the context.\n\n"
+            "Context:\n\n{context}\n\n----\n\n"
+            "Question: {question}\n\n"
+            "Write an answer based on the context. "
+            "If the context provides insufficient information and "
+            "the question cannot be directly answered, reply "
+            '"I cannot answer." '
+            "For each part of your answer, indicate which sources most support "
+            "it via citation keys at the end of sentences, "
+            "like (Example2012Example pages 3-4). Only cite from the context "
+            "below and only use the valid keys. Write in the style of a "
+            "Wikipedia article, with concise sentences and coherent paragraphs. "
+            "The context comes from a variety of sources and is only a summary, "
+            "so there may be inaccuracies or ambiguities. Make sure the gene_names exactly match "
+            "the gene name in the question before using a context. "
+            "This answer will go directly onto "
+            "Wikipedia, so do not add any extraneous information.\n\n"
+            "Answer ({answer_length}):"
+        ),
+        select=(
+            "Select papers that may help answer the question below. "
+            "Papers are listed as $KEY: $PAPER_INFO. "
+            "Return a list of keys, separated by commas. "
+            'Return "None", if no papers are applicable. '
+            "Choose papers that are relevant, from reputable sources, and timely "
+            "(if the question requires timely information). \n\n"
+            "Question: {question}\n\n"
+            "Papers: {papers}\n\n"
+            "Selected keys:"
+        ),
+        pre=(
+            "We are collecting background information for the question/task below. "
+            "Provide a brief summary of definitions, acronyms, or background information (about 50 words) that "
+            "could help answer the question. Do not answer it directly. Ignore formatting instructions. "
+            "Do not answer if there is nothing to contribute. "
+            "\n\nQuestion:\n{question}\n\n"
+        ),
+        post=None,
+        system=(
+            "Answer in a direct and concise tone. "
+            "Your audience is an expert, so be highly specific. "
+            "If there are ambiguous terms or acronyms, be explicit."
+        ),
+        skip_summary=False,
+        json_summary=True,
+        summary_json=(
+            "Excerpt from {citation}\n\n----\n\n{text}\n\n----\n\n"
+            "Query: {question}\n\n"
+        ),
+        summary_json_system="Provide a summary of the excerpt that could help answer the question based on the excerpt. "  # noqa: E501
+        "The excerpt may be irrelevant. Do not directly answer the question - only summarize relevant information. "
+        "Respond with the following JSON format:\n\n"
+        '{{\n"summary": "...",\n"gene_name": "...",\n"relevance_score": "..."}}\n\n'
+        "where `summary` is relevant information from text ({summary_length}), "
+        "`gene_name` is the gene discussed in the excerpt (may be different than query), "
+        "and `relevance_score` is "
+        "the relevance of `summary` to answer the question (integer out of 10).",
+    ),
+}
+
+# for backwards compatibility
+STATIC_PROMPTS["json"] = STATIC_PROMPTS["default"]
diff --git a/paperqa/agents/search.py b/paperqa/agents/search.py
new file mode 100644
index 00000000..59ac1d59
--- /dev/null
+++ b/paperqa/agents/search.py
@@ -0,0 +1,487 @@
+from __future__ import annotations
+
+import csv
+import json
+import logging
+import os
+import pathlib
+import pickle
+import zlib
+from enum import Enum, auto
+from io import StringIO
+from typing import Any, ClassVar, Collection
+from uuid import UUID
+
+import anyio
+from pydantic import BaseModel
+from tantivy import Document, Index, Schema, SchemaBuilder, Searcher
+from tenacity import (
+    RetryError,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
+
+from ..docs import Docs
+from ..types import DocDetails
+from ..utils import hexdigest, pqa_directory
+from .models import SupportsPickle
+
+logger = logging.getLogger(__name__)
+
+PQA_INDEX_ABSOLUTE_PATHS = (
+    os.environ.get("PQA_INDEX_ABSOLUTE_PATHS", "true").lower() == "true"
+)
+
+
+class AsyncRetryError(Exception):
+    """Flags a retry for another tenacity attempt."""
+
+
+class RobustEncoder(json.JSONEncoder):
+    """JSON encoder that can handle UUID and set objects."""
+
+    def default(self, obj):
+        if isinstance(obj, UUID):
+            # if the obj is uuid, we simply return the value of uuid
+            return str(obj)
+        if isinstance(obj, set):
+            return list(obj)
+        return json.JSONEncoder.default(self, obj)
+
+
+class SearchDocumentStorage(str, Enum):
+    JSON_MODEL_DUMP = auto()
+    PICKLE_COMPRESSED = auto()
+    PICKLE_UNCOMPRESSED = auto()
+
+    def extension(self) -> str:
+        if self == SearchDocumentStorage.JSON_MODEL_DUMP:
+            return "json"
+        if self == SearchDocumentStorage.PICKLE_COMPRESSED:
+            return "zip"
+        return "pkl"
+
+    def write_to_string(self, data: BaseModel | SupportsPickle) -> bytes:
+        if self == SearchDocumentStorage.JSON_MODEL_DUMP:
+            if isinstance(data, BaseModel):
+                return json.dumps(data.model_dump(), cls=RobustEncoder).encode("utf-8")
+            raise ValueError("JSON_MODEL_DUMP requires a BaseModel object.")
+        if self == 
SearchDocumentStorage.PICKLE_COMPRESSED: + return zlib.compress(pickle.dumps(data)) + return pickle.dumps(data) + + def read_from_string(self, data: str | bytes) -> BaseModel | SupportsPickle: + if self == SearchDocumentStorage.JSON_MODEL_DUMP: + return json.loads(data) + if self == SearchDocumentStorage.PICKLE_COMPRESSED: + return pickle.loads(zlib.decompress(data)) # type: ignore[arg-type] # noqa: S301 + return pickle.loads(data) # type: ignore[arg-type] # noqa: S301 + + +class SearchIndex: + + REQUIRED_FIELDS: ClassVar[set[str]] = {"file_location", "body"} + + def __init__( + self, + fields: Collection[str] | None = None, + index_name: str = "pqa_index", + index_directory: str | os.PathLike | None = None, + storage: SearchDocumentStorage = SearchDocumentStorage.PICKLE_COMPRESSED, + ): + if fields is None: + fields = self.REQUIRED_FIELDS + self.fields = fields + if not all(f in self.fields for f in self.REQUIRED_FIELDS): + raise ValueError( + f"{self.REQUIRED_FIELDS} must be included in search index fields." + ) + if index_directory is None: + index_directory = pqa_directory("indexes") + self.index_name = index_name + self._index_directory = index_directory + self._schema = None + self._index = None + self._searcher = None + self._index_files: dict[str, str] = {} + self.changed = False + self.storage = storage + + async def init_directory(self): + await anyio.Path(await self.index_directory).mkdir(parents=True, exist_ok=True) + + @staticmethod + async def extend_and_make_directory(base: anyio.Path, *dirs: str) -> anyio.Path: + directory = base.joinpath(*dirs) + await directory.mkdir(parents=True, exist_ok=True) + return directory + + @property + async def index_directory(self) -> anyio.Path: + return await self.extend_and_make_directory( + anyio.Path(self._index_directory), self.index_name + ) + + @property + async def index_filename(self) -> anyio.Path: + return await self.extend_and_make_directory(await self.index_directory, "index") + + @property + async def docs_index_directory(self) -> anyio.Path: + return await self.extend_and_make_directory(await self.index_directory, "docs") + + @property + async def file_index_filename(self) -> anyio.Path: + return (await self.index_directory) / "files.zip" + + @property + def schema(self) -> Schema: + if not self._schema: + schema_builder = SchemaBuilder() + for field in self.fields: + schema_builder.add_text_field(field, stored=True) + self._schema = schema_builder.build() + return self._schema + + @property + async def index(self) -> Index: + if not self._index: + index_path = await self.index_filename + if await (index_path / "meta.json").exists(): + self._index = Index.open(str(index_path)) + else: + self._index = Index(self.schema, str(index_path)) + return self._index + + @property + async def searcher(self) -> Searcher: + if not self._searcher: + index = await self.index + index.reload() + self._searcher = index.searcher() + return self._searcher + + @property + async def index_files(self) -> dict[str, str]: + if not self._index_files: + file_index_path = await self.file_index_filename + if await file_index_path.exists(): + async with await anyio.open_file(file_index_path, "rb") as f: + content = await f.read() + self._index_files = pickle.loads( # noqa: S301 + zlib.decompress(content) + ) + return self._index_files + + @staticmethod + def filehash(body: str) -> str: + return hexdigest(body) + + async def filecheck(self, filename: str, body: str | None = None): + filehash = None + if body: + filehash = self.filehash(body) + index_files = 
await self.index_files + return bool( + index_files.get(filename) + and (filehash is None or index_files[filename] == filehash) + ) + + async def add_document( + self, index_doc: dict, document: Any | None = None, max_retries: int = 1000 + ): + @retry( + stop=stop_after_attempt(max_retries), + wait=wait_random_exponential(multiplier=0.25, max=60), + retry=retry_if_exception_type(AsyncRetryError), + reraise=True, + ) + async def _add_document_with_retry(): + if not await self.filecheck(index_doc["file_location"], index_doc["body"]): + try: + index = await self.index + writer = index.writer() + writer.add_document(Document.from_dict(index_doc)) + writer.commit() + + filehash = self.filehash(index_doc["body"]) + (await self.index_files)[index_doc["file_location"]] = filehash + + if document: + docs_index_dir = await self.docs_index_directory + async with await anyio.open_file( + docs_index_dir / f"{filehash}.{self.storage.extension()}", + "wb", + ) as f: + await f.write(self.storage.write_to_string(document)) + + self.changed = True + except ValueError as e: + if "Failed to acquire Lockfile: LockBusy." in str(e): + raise AsyncRetryError("Failed to acquire lock") from e + raise + + try: + await _add_document_with_retry() + except RetryError: + logger.exception( + f"Failed to add document after {max_retries} attempts: {index_doc['file_location']}" + ) + raise + + # Success + + @staticmethod + @retry( + stop=stop_after_attempt(1000), + wait=wait_random_exponential(multiplier=0.25, max=60), + retry=retry_if_exception_type(AsyncRetryError), + reraise=True, + ) + def delete_document(index: Index, file_location: str) -> None: + try: + writer = index.writer() + writer.delete_documents("file_location", file_location) + writer.commit() + except ValueError as e: + if "Failed to acquire Lockfile: LockBusy." in str(e): + raise AsyncRetryError("Failed to acquire lock") from e + raise + + async def remove_from_index(self, file_location: str) -> None: + index_files = await self.index_files + if index_files.get(file_location): + index = await self.index + self.delete_document(index, file_location) + filehash = index_files.pop(file_location) + docs_index_dir = await self.docs_index_directory + # TODO: since the directory is part of the filehash these + # are always missing. Unsure of how to get around this. 
+ await (docs_index_dir / f"{filehash}.{self.storage.extension()}").unlink( + missing_ok=True + ) + + self.changed = True + + async def save_index(self) -> None: + file_index_path = await self.file_index_filename + async with await anyio.open_file(file_index_path, "wb") as f: + await f.write(zlib.compress(pickle.dumps(await self.index_files))) + + async def get_saved_object( + self, file_location: str, keep_filenames: bool = False + ) -> Any | None | tuple[Any, str]: + index_files = await self.index_files + filehash = index_files.get(file_location) + if filehash: + docs_index_dir = await self.docs_index_directory + async with await anyio.open_file( + docs_index_dir / f"{filehash}.{self.storage.extension()}", "rb" + ) as f: + content = await f.read() + if keep_filenames: + return self.storage.read_from_string(content), file_location + return self.storage.read_from_string(content) + return None + + def clean_query(self, query: str) -> str: + for replace in {"*", "[", "]"}: + query = query.replace(replace, "") + return query + + async def query( + self, + query: str, + top_n: int = 10, + offset: int = 0, + min_score: float = 0.0, + keep_filenames: bool = False, + field_subset: list[str] | None = None, + ) -> list[Any]: + query_fields = field_subset or self.fields + searcher = await self.searcher + index = await self.index + results = [ + s[1] + for s in searcher.search( + index.parse_query(self.clean_query(query), query_fields), top_n + ).hits + if s[0] > min_score + ][offset : offset + top_n] + search_index_docs = [searcher.doc(result) for result in results] + return [ + result + for result in [ + await self.get_saved_object( + doc["file_location"][0], keep_filenames=keep_filenames + ) + for doc in search_index_docs + ] + if result is not None + ] + + +async def maybe_get_manifest(filename: anyio.Path | None) -> dict[str, DocDetails]: + if not filename: + return {} + if filename.suffix == ".csv": + try: + async with await anyio.open_file(filename, mode="r") as file: + content = await file.read() + reader = csv.DictReader(StringIO(content)) + records = [DocDetails(**row) for row in reader] + return {str(r.file_location): r for r in records if r.file_location} + except FileNotFoundError: + logging.warning(f"Manifest file at {filename} could not be found.") + except Exception: + logging.exception(f"Error reading manifest file {filename}") + else: + logging.error(f"Invalid manifest file type: {filename.suffix}") + + return {} + + +async def process_file( + file_path: anyio.Path, + search_index: SearchIndex, + metadata: dict[str, Any], + semaphore: anyio.Semaphore, + docs_kwargs: dict[str, Any], +) -> None: + + async with semaphore: + + file_name = file_path.name + if not await search_index.filecheck(str(file_path)): + logger.info(f"New file to index: {file_name}...") + + doi, title = None, None + if file_name in metadata: + doi, title = metadata[file_name].doi, metadata[file_name].title + + # note extras are forbidden in docs + chunk_chars, overlap = 3000, 250 + if "chunk_chars" in docs_kwargs: + chunk_chars = int(docs_kwargs.pop("chunk_chars")) + if "overlap" in docs_kwargs: + overlap = int(docs_kwargs.pop("overlap")) + + tmp_docs = Docs(**docs_kwargs) + try: + await tmp_docs.aadd( + path=pathlib.Path(file_path), + title=title, + doi=doi, + chunk_chars=chunk_chars, + overlap=overlap, + fields=["title", "author", "journal", "year"], + use_doc_details=True, + ) + except ValueError: + logger.exception( + f"Error parsing {file_name}, skipping index for this file." 
+ ) + (await search_index.index_files)[str(file_path)] = "ERROR" + await search_index.save_index() + return + + this_doc = next(iter(tmp_docs.docs.values())) + + if isinstance(this_doc, DocDetails): + title = this_doc.title or file_name + year = this_doc.year or "Unknown year" + else: + title, year = file_name, "Unknown year" + + await search_index.add_document( + { + "title": title, + "year": year, + "file_location": str(file_path), + "body": "".join([t.text for t in tmp_docs.texts]), + }, + document=tmp_docs, + ) + await search_index.save_index() + logger.info(f"Complete ({title}).") + + +async def get_directory_index( + directory: anyio.Path, + manifest_file: anyio.Path | None = None, + index_name: str = "pqa_index", + index_directory: str | os.PathLike | None = None, + sync_index_w_directory: bool = True, + use_absolute_directory_path: bool = PQA_INDEX_ABSOLUTE_PATHS, + max_concurrency: int = 30, + **docs_kwargs, +) -> SearchIndex: + """ + Create a Tantivy index from a directory of text files. + + Args: + directory: Directory to index. + manifest_file: File with metadata for each document. + index_name: Name of the index. + index_directory: Directory to store the index. + sync_index_w_directory: Sync the index with the directory. (i.e. delete files not in directory) + use_absolute_directory_path: Use the absolute path for the directory. + docs_kwargs: Keyword arguments for the Docs object. + max_concurrency: maximum number of files to be simultaneously indexed. + """ + semaphore = anyio.Semaphore(max_concurrency) + + if use_absolute_directory_path: + directory = await directory.absolute() + + search_index = SearchIndex( + fields=SearchIndex.REQUIRED_FIELDS | {"title", "year"}, + index_name=index_name, + index_directory=index_directory or pqa_directory("indexes"), + ) + + metadata = await maybe_get_manifest(manifest_file) + + valid_files = [ + file + async for file in directory.iterdir() + if file.suffix in {".txt", ".pdf", ".html"} + ] + index_files = await search_index.index_files + + if missing := (set(index_files.keys()) - {str(f) for f in valid_files}): + if sync_index_w_directory: + for missing_file in missing: + logger.warning( + f"[bold red]Removing {missing_file} from index.[/bold red]" + ) + await search_index.remove_from_index(missing_file) + logger.warning("[bold red]Files removed![/bold red]") + else: + logger.warning( + f"[bold red]Indexed files are missing from index folder ({directory}).[/bold red]" + ) + logger.warning(f"[bold red]files: {missing}[/bold red]") + + async with anyio.create_task_group() as tg: + for file_path in valid_files: + if sync_index_w_directory: + tg.start_soon( + process_file, + file_path, + search_index, + metadata, + semaphore, + docs_kwargs, + ) + else: + logger.debug(f"New file {file_path.name} found in directory.") + + if search_index.changed: + await search_index.save_index() + else: + logger.debug("No changes to index.") + + return search_index diff --git a/paperqa/agents/tools.py b/paperqa/agents/tools.py new file mode 100644 index 00000000..df8ec9a6 --- /dev/null +++ b/paperqa/agents/tools.py @@ -0,0 +1,395 @@ +from __future__ import annotations +import inspect +import logging +import os +import re +import sys +from typing import ClassVar +import anyio +from langchain_core.callbacks import BaseCallbackHandler + +from langchain.tools import BaseTool +from paperqa import Answer, Docs +from ..utils import pqa_directory +from pydantic import BaseModel, ConfigDict, Field + +# ruff: noqa: I001 +from pydantic.v1 import ( # TODO: move to Pydantic 
v2 after LangChain moves to Pydantic v2, SEE: https://github.com/langchain-ai/langchain/issues/16306
+    BaseModel as BaseModelV1,
+    Field as FieldV1,
+)
+
+from .helpers import compute_total_model_token_cost, get_year
+from .search import get_directory_index
+from .models import ParsingConfiguration, QueryRequest, SimpleProfiler
+
+logger = logging.getLogger(__name__)
+
+
+async def status(docs: Docs, answer: Answer, relevant_score_cutoff: int = 5) -> str:
+    """Create a string that provides a summary of the input doc/answer."""
+    answer.cost = compute_total_model_token_cost(answer.token_counts)
+    return (
+        f"Status: Paper Count={len(docs.docs)}"
+        f" | Relevant Papers={len({c.text.doc.dockey for c in answer.contexts if c.score > relevant_score_cutoff})}"
+        f" | Current Evidence={len([c for c in answer.contexts if c.score > relevant_score_cutoff])}"
+        f" | Current Cost=${answer.cost:.2f}"
+    )
+
+
+class SharedToolState(BaseModel):
+    """Shared collection of variables for a collection of tools. We use this to avoid
+    the fact that pydantic treats dictionaries as values, instead of references.
+    """  # noqa: D205
+
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+
+    answer: Answer
+    docs: Docs
+    profiler: SimpleProfiler = Field(default_factory=SimpleProfiler)
+
+    # SEE: https://regex101.com/r/RmuVdC/1
+    STATUS_SEARCH_REGEX_PATTERN: ClassVar[str] = (
+        r"Status: Paper Count=(\d+) \| Relevant Papers=(\d+) \| Current Evidence=(\d+)"
+    )
+
+    async def get_status(self) -> str:
+        return await status(self.docs, self.answer)
+
+
+def _time_tool(func):
+    """Decorator to time the execution of a tool method.
+    Assumes that the tool has a shared state.
+    """  # noqa: D205
+
+    async def wrapper(self, *args, **kwargs):
+        async with self.shared_state.profiler.timer(self.name):
+            return await func(self, *args, **kwargs)
+
+    return wrapper
+
+
+class PaperSearchTool(BaseTool):
+    class InputSchema(
+        BaseModelV1  # TODO: move to Pydantic v2 after LangChain moves to Pydantic v2, SEE: https://github.com/langchain-ai/langchain/issues/16306
+    ):
+        query: str = FieldV1(
+            description=(
+                "A search query in this format: [query], [start year]-[end year]. "
+                "You may include years as the last word in the query, "
+                "e.g. 'machine learning 2020' or 'machine learning 2010-2020'. "
+                f"The current year is {get_year()}. "
+                "The query portion can be a specific phrase, complete sentence, "
+                "or general keywords, e.g. 'machine learning for immunology'."
+            )
+        )
+
+    paper_directory: str | os.PathLike = "."
+    index_directory: str | os.PathLike | None = None
+    manifest_file: str | os.PathLike | None = None
+    name: str = "paper_search"
+    args_schema: type[BaseModelV1] | None = InputSchema
+    description: str = (
+        "Search for papers to increase the paper count. You can call this a second "
+        "time with a different search to gather more papers."
+    )
+
+    shared_state: SharedToolState
+    return_paper_metadata: bool = False
+    # Second item being True means specify a year range in the search
+    search_type: tuple[str, bool] = ("google", False)
+    search_count: int = 8
+    previous_searches: dict[str, int] = FieldV1(default_factory=dict)
+    embedding: str = "text-embedding-3-small"
+    parsing_configuration: ParsingConfiguration = FieldV1(
+        default_factory=ParsingConfiguration
+    )
+
+    def _run(self, query: str) -> str:
+        raise NotImplementedError
+
+    @_time_tool
+    async def _arun(self, query: str) -> str:
+        """
+        Run asynchronously, in-place mutating `self.shared_state.docs`.
+ + Args: + query: Search keywords followed by optional year or year range + (e.g. COVID-19 vaccines, 2022). + + Returns: + String describing searched papers and the current status. + """ + # get offset if we've done this search before (continuation of search) + # or mark this search as new (so offset 0) + logger.info(f"Starting paper search for '{query}'.") + search_key = query + if search_key in self.previous_searches: + offset = self.previous_searches[search_key] + else: + offset = self.previous_searches[search_key] = 0 + + # Preprocess inputs to make ScrapeRequest + keywords = query.replace('"', "") # Remove quotes + year: str | None = None + last_word = keywords.split(" ")[-1] + if re.match(r"\d{4}(-\d{4})?", last_word): + keywords = keywords[: -len(last_word)].removesuffix(",").strip() + if self.search_type[1]: + year = last_word + if "-" not in year: + year = year + "-" + year # Convert to date range (e.g. 2022-2022) + index = await get_directory_index( + directory=anyio.Path(self.paper_directory), + index_name=QueryRequest.get_index_name( + self.paper_directory, self.embedding, self.parsing_configuration + ), + index_directory=self.index_directory, + manifest_file=( + anyio.Path(self.manifest_file) if self.manifest_file else None + ), + embedding=self.embedding, + chunk_chars=self.parsing_configuration.chunksize, + overlap=self.parsing_configuration.overlap, + ) + + results = await index.query( + keywords, + top_n=self.search_count, + offset=offset, + field_subset=[f for f in index.fields if f != "year"], + ) + + logger.info(f"Search for '{keywords}' returned {len(results)} papers.") + # combine all the resulting doc objects into one and update the state + # there's only one doc per result, so we can just take the first one + all_docs = [] + for r in results: + this_doc = next(iter(r.docs.values())) + all_docs.append(this_doc) + await self.shared_state.docs.aadd_texts(texts=r.texts, doc=this_doc) + + status = await self.shared_state.get_status() + + logger.info(status) + + # mark how far we've searched so that continuation will start at the right place + self.previous_searches[search_key] += self.search_count + + if self.return_paper_metadata: + retrieved_papers = "\n".join([f"{x.title} ({x.year})" for x in all_docs]) + return f"Retrieved Papers:\n{retrieved_papers}\n\n{status}" + return status + + +class EmptyDocsError(RuntimeError): + """Error to throw when we needed docs to be present.""" + + +class GatherEvidenceTool(BaseTool): + class InputSchema( + BaseModelV1 # TODO: move to Pydantic v2 after LangChain moves to Pydantic v2, SEE: https://github.com/langchain-ai/langchain/issues/16306 + ): + question: str = FieldV1(description="Specific question to gather evidence for.") + + name: str = "gather_evidence" + args_schema: type[BaseModelV1] | None = InputSchema + description: str = ( + "Gather evidence from previous papers given a specific question. " + "This will increase evidence and relevant paper counts. " + "Only invoke when paper count is above zero." 
+    )
+
+    shared_state: SharedToolState
+    query: QueryRequest
+
+    def _run(self, query: str) -> str:
+        raise NotImplementedError
+
+    @_time_tool
+    async def _arun(self, question: str) -> str:
+        if not self.shared_state.docs.docs:
+            raise EmptyDocsError("Not gathering evidence due to having no papers.")
+
+        logger.info(f"Gathering and ranking evidence for '{question}'.")
+
+        # first we see if we'd like to filter any docs for relevance
+        # at the citation level
+        if len(self.shared_state.docs.docs) >= self.query.adoc_match_threshold:
+            doc_keys_to_keep = await self.shared_state.docs.adoc_match(
+                question,
+                rerank=True,  # want to set it explicitly
+                answer=self.shared_state.answer,
+            )
+        else:
+            doc_keys_to_keep = set(self.shared_state.docs.docs.keys())
+
+        self.shared_state.answer.dockey_filter = doc_keys_to_keep
+
+        # swap out the question
+        # TODO: evaluate how often the agent changes the question
+        old = self.shared_state.answer.question
+        self.shared_state.answer.question = question
+
+        # generator, so run it
+        l0 = len(self.shared_state.answer.contexts)
+
+        # set jit so that the index is rebuilt; helps if the texts have changed
+        self.shared_state.docs.jit_texts_index = True
+        # ensure length is set correctly
+        self.shared_state.answer.summary_length = self.query.summary_length
+        # TODO: refactor answer out of this...
+        self.shared_state.answer = await self.shared_state.docs.aget_evidence(
+            answer=self.shared_state.answer,
+            max_sources=self.query.max_sources,
+            k=self.query.consider_sources,
+            detailed_citations=True,
+        )
+        l1 = len(self.shared_state.answer.contexts)
+        self.shared_state.answer.question = old
+        sorted_contexts = sorted(
+            self.shared_state.answer.contexts, key=lambda x: x.score, reverse=True
+        )
+        best_evidence = ""
+        if len(sorted_contexts) > 0:
+            best_evidence = f" Best evidence:\n\n{sorted_contexts[0].context}"
+        status = await self.shared_state.get_status()
+        logger.info(status)
+        return f"Added {l1 - l0} pieces of evidence.{best_evidence}\n\n" + status
+
+
+class GenerateAnswerTool(BaseTool):
+    class InputSchema(
+        BaseModelV1  # TODO: move to Pydantic v2 after LangChain moves to Pydantic v2, SEE: https://github.com/langchain-ai/langchain/issues/16306
+    ):
+        question: str = FieldV1(description="Question to be answered.")
+
+    name: str = "gen_answer"
+    args_schema: type[BaseModelV1] | None = InputSchema
+    description: str = (
+        "Ask a model to propose an answer using current evidence. "
+        "The tool may fail, "
+        "indicating that better or different evidence should be found. "
+        "Having more than one piece of evidence or relevant papers is best."
+    )
+    shared_state: SharedToolState
+    wipe_context_on_answer_failure: bool = True
+    query: QueryRequest
+
+    FAILED_TO_ANSWER: ClassVar[str] = "Failed to answer question."
+
+    @classmethod
+    def did_not_fail_to_answer(cls, message: str) -> bool:
+        return not message.startswith(cls.FAILED_TO_ANSWER)
+
+    def _run(self, query: str) -> str:
+        raise NotImplementedError
+
+    @_time_tool
+    async def _arun(self, question: str) -> str:
+        logger.info(f"Generating answer for '{question}'.")
+        # TODO: Should we allow the agent to change the question?
+ # self.answer.question = query + self.shared_state.answer.answer_length = self.query.length + self.shared_state.answer = await self.shared_state.docs.aquery( + self.query.query, + k=self.query.consider_sources, + max_sources=self.query.max_sources, + answer=self.shared_state.answer, + ) + + if self.query.filter_extra_background: + # filter out "(Extra Background Information)" from the answer + self.shared_state.answer.answer = re.sub( + r"\([Ee]xtra [Bb]ackground [Ii]nformation\)", + "", + self.shared_state.answer.answer, + ) + + if "cannot answer" in self.shared_state.answer.answer.lower(): + if self.wipe_context_on_answer_failure: + self.shared_state.answer.contexts = [] + self.shared_state.answer.dockey_filter = None + self.shared_state.answer.context = "" + status = await self.shared_state.get_status() + logger.info(status) + return f"{self.FAILED_TO_ANSWER} | " + status + status = await self.shared_state.get_status() + logger.info(status) + return f"{self.shared_state.answer.answer} | {status}" + + # NOTE: can match failure to answer or an actual answer + ANSWER_SPLIT_REGEX_PATTERN: ClassVar[str] = ( + r" \| " + SharedToolState.STATUS_SEARCH_REGEX_PATTERN + ) + + @classmethod + def extract_answer_from_message(cls, content: str) -> str: + """Extract the answer from a message content.""" + answer, *rest = re.split( + pattern=cls.ANSWER_SPLIT_REGEX_PATTERN, string=content, maxsplit=1 + ) + if len(rest) != 4 or not cls.did_not_fail_to_answer(answer): # noqa: PLR2004 + return "" + return answer + + +AVAILABLE_TOOL_NAME_TO_CLASS: dict[str, type[BaseTool]] = { + cls.__fields__["name"].default: cls + for _, cls in inspect.getmembers( + sys.modules[__name__], + predicate=lambda v: inspect.isclass(v) + and issubclass(v, BaseTool) + and v is not BaseTool, + ) +} + + +def query_to_tools( + query: QueryRequest, + state: SharedToolState, + callbacks: list[BaseCallbackHandler] | None = None, +) -> list[BaseTool]: + if query.agent_tools.tool_names is None: + tool_types: list[type[BaseTool]] = [ + PaperSearchTool, + GatherEvidenceTool, + GenerateAnswerTool, + ] + else: + tool_types = [ + AVAILABLE_TOOL_NAME_TO_CLASS[name] + for name in set(query.agent_tools.tool_names) + ] + tools: list[BaseTool] = [] + for tool_type in tool_types: + if issubclass(tool_type, PaperSearchTool): + tools.append( + PaperSearchTool( + shared_state=state, + search_count=query.agent_tools.search_count, + embedding=query.embedding, + parsing_configuration=query.parsing_configuration, + paper_directory=query.agent_tools.paper_directory, + index_directory=query.agent_tools.index_directory + or pqa_directory("indexes"), + manifest_file=query.agent_tools.manifest_file, + callbacks=callbacks, + ) + ) + elif issubclass(tool_type, GatherEvidenceTool): + tools.append( + GatherEvidenceTool(shared_state=state, query=query, callbacks=callbacks) + ) + elif issubclass(tool_type, GenerateAnswerTool): + tools.append( + GenerateAnswerTool( + shared_state=state, + wipe_context_on_answer_failure=query.agent_tools.wipe_context_on_answer_failure, + query=query, + callbacks=callbacks, + ) + ) + else: + tools.append(tool_type(shared_state=state)) + return tools diff --git a/paperqa/clients/__init__.py b/paperqa/clients/__init__.py index e3956f98..654715dc 100644 --- a/paperqa/clients/__init__.py +++ b/paperqa/clients/__init__.py @@ -150,6 +150,9 @@ async def query(self, **kwargs) -> DocDetails | None: ) break + if self._session is None: + await session.close() + return doc_details async def bulk_query( @@ -160,18 +163,31 @@ async def bulk_query( 
) async def upgrade_doc_to_doc_details(self, doc: Doc, **kwargs) -> DocDetails: + + # note we have some extra fields which may have come from reading the doc text, + # but aren't in the doc object, we add them here too. + extra_fields = { + k: v for k, v in kwargs.items() if k in {"title", "authors", "doi"} + } + # abuse our doc_details object to be an int if it's empty + # our __add__ operation supports int by doing nothing + extra_doc: int | DocDetails = ( + 0 if not extra_fields else DocDetails(**extra_fields) + ) + if doc_details := await self.query(**kwargs): if doc.overwrite_fields_from_metadata: - return doc_details + return extra_doc + doc_details + # hard overwrite the details from the prior object doc_details.dockey = doc.dockey doc_details.doc_id = doc.dockey doc_details.docname = doc.docname doc_details.key = doc.docname doc_details.citation = doc.citation - return doc_details + return extra_doc + doc_details # if we can't get metadata, just return the doc, but don't overwrite any fields prior_doc = doc.model_dump() prior_doc["overwrite_fields_from_metadata"] = False - return DocDetails(**prior_doc) + return DocDetails(**(prior_doc | extra_fields)) diff --git a/paperqa/clients/client_models.py b/paperqa/clients/client_models.py index bdb8b254..36e0765b 100644 --- a/paperqa/clients/client_models.py +++ b/paperqa/clients/client_models.py @@ -63,9 +63,17 @@ class DOIQuery(ClientQuery): @model_validator(mode="before") @classmethod - def ensure_fields_are_present(cls, data: dict[str, Any]) -> dict[str, Any]: + def add_doi_to_fields_and_validate(cls, data: dict[str, Any]) -> dict[str, Any]: + if (fields := data.get("fields")) and "doi" not in fields: fields.append("doi") + + # sometimes the DOI has a URL prefix, remove it + remove_urls = ["https://doi.org/", "http://dx.doi.org/"] + for url in remove_urls: + if data["doi"].startswith(url): + data["doi"] = data["doi"].replace(url, "") + return data @@ -101,14 +109,14 @@ async def query(self, query: dict) -> DocDetails | None: # DOINotFoundError means the paper doesn't exist in the source, the timeout is to prevent # this service from failing us when it's down or slow. except DOINotFoundError: - logger.exception( + logger.warning( f"Metadata not found for " f"{client_query.doi if isinstance(client_query, DOIQuery) else client_query.title}" - " in Crossref." + f" in {self.__class__.__name__}." ) except TimeoutError: - logger.exception( - f"Request to Crossref for " + logger.warning( + f"Request to {self.__class__.__name__} for " f"{client_query.doi if isinstance(client_query, DOIQuery) else client_query.title}" " timed out." 
) diff --git a/paperqa/clients/crossref.py b/paperqa/clients/crossref.py index 7e587514..e85770ef 100644 --- a/paperqa/clients/crossref.py +++ b/paperqa/clients/crossref.py @@ -14,6 +14,7 @@ from ..types import CITATION_FALLBACK_DATA, DocDetails from ..utils import ( bibtex_field_extract, + create_bibtex_key, remove_substrings, strings_similarity, union_collections_to_ordered_list, @@ -136,7 +137,9 @@ async def doi_to_bibtex( ] # replace the key if all the fragments are present if all(fragments): - new_key = remove_substrings(("".join(fragments)), FORBIDDEN_KEY_CHARACTERS) + new_key = create_bibtex_key( + author=fragments[0].split(), year=fragments[1], title=fragments[2] + ) # we use the count parameter below to ensure only the 1st entry is replaced return data.replace(key, new_key, 1) @@ -265,7 +268,8 @@ async def get_doc_details_from_crossref( # noqa: C901, PLR0912 query_bibtex = True - if fields: + # note we only do field selection if querying on title + if fields and title: # crossref has a special endpoint for bibtex, so we don't need to request it here if "bibtex" not in fields: query_bibtex = False diff --git a/paperqa/clients/semantic_scholar.py b/paperqa/clients/semantic_scholar.py index d11fc9cb..95e283a4 100644 --- a/paperqa/clients/semantic_scholar.py +++ b/paperqa/clients/semantic_scholar.py @@ -264,6 +264,9 @@ async def get_s2_doc_details_from_doi( session=session, headers=semantic_scholar_headers(), timeout=SEMANTIC_SCHOLAR_API_REQUEST_TIMEOUT, + http_exception_mappings={ + HTTPStatus.NOT_FOUND: DOINotFoundError(f"Could not find DOI for {doi}.") + }, ) return await parse_s2_to_doc_details(details, session) diff --git a/paperqa/docs.py b/paperqa/docs.py index 214653f0..9d49aee7 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -13,6 +13,7 @@ from openai import AsyncOpenAI from pydantic import BaseModel, ConfigDict, Field, model_validator +from tenacity import retry, stop_after_attempt try: import voyageai @@ -62,11 +63,11 @@ # this is just to reduce None checks/type checks -async def empty_callback(result: LLMResult): # noqa: ARG001 +async def empty_callback(result: LLMResult): pass -async def print_callback(result: LLMResult): # noqa: ARG001 +async def print_callback(result: LLMResult): pass @@ -390,6 +391,7 @@ async def aadd( # noqa: C901, PLR0912, PLR0915 disable_check: bool = False, dockey: DocKey | None = None, chunk_chars: int = 3000, + overlap: int = 250, title: str | None = None, doi: str | None = None, authors: list[str] | None = None, @@ -408,7 +410,7 @@ async def aadd( # noqa: C901, PLR0912, PLR0915 ) # peak first chunk fake_doc = Doc(docname="", citation="", dockey=dockey) - texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100) + texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=overlap) if len(texts) == 0: raise ValueError(f"Could not read document {path}. 
Is it empty?") chain_result = await cite_chain({"text": texts[0].text}, None) @@ -450,7 +452,10 @@ async def aadd( # noqa: C901, PLR0912, PLR0915 ) chain_result = await structured_cite_chain({"citation": citation}, None) with contextlib.suppress(json.JSONDecodeError): - citation_json = json.loads(chain_result.text) + clean_text = chain_result.text.strip("`") + if clean_text.startswith("json"): + clean_text = clean_text.replace("json", "", 1) + citation_json = json.loads(clean_text) if citation_title := citation_json.get("title"): title = citation_title if citation_doi := citation_json.get("doi"): @@ -492,7 +497,7 @@ async def aadd( # noqa: C901, PLR0912, PLR0915 or (not disable_check and not maybe_is_text(texts[0].text)) ): raise ValueError( - f"This does not look like a text document: {path}. Path disable_check to ignore this error." + f"This does not look like a text document: {path}. Pass disable_check to ignore this error." ) if await self.aadd_texts(texts, doc): return docname @@ -575,6 +580,11 @@ def delete( self.deleted_dockeys.add(dockey) self.texts = list(filter(lambda x: x.doc.dockey != dockey, self.texts)) + # no state modifications in adoc_match--only answer is changed + @retry( + stop=stop_after_attempt(3), + reraise=True, + ) async def adoc_match( self, query: str, diff --git a/paperqa/llms.py b/paperqa/llms.py index 15f9a319..3172b0f3 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -221,6 +221,7 @@ class LLMModel(ABC, BaseModel): llm_type: str | None = None name: str + config: dict = Field(default={}) async def acomplete(self, client: Any, prompt: str) -> str: raise NotImplementedError @@ -887,12 +888,11 @@ def llm_model_factory(llm: str) -> LLMModel: if llm != "default": if is_openai_model(llm): return OpenAILLMModel(config={"model": llm}) - elif llm.startswith("langchain"): # noqa: RET505 + if llm.startswith("langchain") or "gemini" in llm: return LangchainLLMModel() - elif "claude" in llm: + if "claude" in llm: return AnthropicLLMModel(config={"model": llm}) - else: - raise ValueError(f"Could not guess model type for {llm}. ") + raise ValueError(f"Could not guess model type for {llm}. ") return OpenAILLMModel() diff --git a/paperqa/prompts.py b/paperqa/prompts.py index 3782ce17..ea4b3de0 100644 --- a/paperqa/prompts.py +++ b/paperqa/prompts.py @@ -19,8 +19,7 @@ "Context (with relevance scores):\n\n{context}\n\n----\n\n" "Question: {question}\n\n" "Write an answer based on the context. 
" - "If the context provides insufficient information and " - "the question cannot be directly answered, reply " + "If the context provides insufficient information reply " '"I cannot answer."' "For each part of your answer, indicate which sources most support " "it via citation keys at the end of sentences, " diff --git a/paperqa/types.py b/paperqa/types.py index b0ef0489..3df2a43d 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os import re from datetime import datetime from typing import Any, Callable, ClassVar, Collection @@ -9,6 +10,7 @@ import tiktoken from pybtex.database import BibliographyData, Entry, Person from pybtex.database.input.bibtex import Parser +from pybtex.scanner import PybtexSyntaxError from pydantic import ( BaseModel, ConfigDict, @@ -266,6 +268,21 @@ def get_unique_docs_from_contexts(self, score_threshold: int = 0) -> set[Doc]: for c in filter(lambda x: x.score >= score_threshold, self.contexts) } + def filter_content_for_user(self) -> None: + """Filter out extra items (inplace) that do not need to be returned to the user.""" + self.contexts = [ + Context( + context=c.context, + score=c.score, + text=Text( + text="", + **c.text.model_dump(exclude={"text", "embedding", "doc"}), + doc=Doc(**c.text.doc.model_dump(exclude={"embedding"})), + ), + ) + for c in self.contexts + ] + class ChunkMetadata(BaseModel): """Metadata for chunking algorithm.""" @@ -360,11 +377,17 @@ class DocDetails(Doc): doi: str | None = None doi_url: str | None = None doc_id: str | None = None + file_location: str | os.PathLike | None = None other: dict[str, Any] = Field( default_factory=dict, description="Other metadata besides the above standardized fields.", ) UNDEFINED_JOURNAL_QUALITY: ClassVar[int] = -1 + DOI_URL_FORMATS: ClassVar[Collection[str]] = { + "https://doi.org/", + "http://dx.doi.org/", + } + AUTHOR_NAMES_TO_REMOVE: ClassVar[Collection[str]] = {"et al", "et al."} @field_validator("key") @classmethod @@ -372,9 +395,13 @@ def clean_key(cls, value: str) -> str: # Replace HTML tags with empty string return re.sub(pattern=r"<\/?\w{1,10}>", repl="", string=value) - @staticmethod - def lowercase_doi_and_populate_doc_id(data: dict[str, Any]) -> dict[str, Any]: + @classmethod + def lowercase_doi_and_populate_doc_id(cls, data: dict[str, Any]) -> dict[str, Any]: if doi := data.get("doi"): + remove_urls = cls.DOI_URL_FORMATS + for url in remove_urls: + if doi.startswith(url): + doi = doi.replace(url, "") data["doi"] = doi.lower() data["doc_id"] = encode_id(doi.lower()) else: @@ -419,6 +446,7 @@ def inject_clean_doi_url_into_data(data: dict[str, Any]) -> dict[str, Any]: if doi and not doi_url: doi_url = "https://doi.org/" + doi + # ensure the modern doi url is used if doi_url: data["doi_url"] = doi_url.replace( "http://dx.doi.org/", "https://doi.org/" @@ -426,6 +454,16 @@ def inject_clean_doi_url_into_data(data: dict[str, Any]) -> dict[str, Any]: return data + @classmethod + def remove_invalid_authors(cls, data: dict[str, Any]) -> dict[str, Any]: + """Capture and cull strange author names.""" + if authors := data.get("authors"): + data["authors"] = [ + a for a in authors if a.lower() not in cls.AUTHOR_NAMES_TO_REMOVE + ] + + return data + @staticmethod def overwrite_docname_dockey_for_compatibility_w_doc( data: dict[str, Any] @@ -477,10 +515,13 @@ def populate_bibtex_key_citation( # noqa: C901, PLR0912 data["other"]["bibtex_source"] = ["self_generated"] else: data["other"] = {"bibtex_source": ["self_generated"]} - - 
existing_entry = next( - iter(Parser().parse_string(data["bibtex"]).entries.values()) - ) + try: + existing_entry = next( + iter(Parser().parse_string(data["bibtex"]).entries.values()) + ) + except PybtexSyntaxError: + logger.warning(f"Failed to parse bibtex for {data['bibtex']}.") + existing_entry = None entry_data = { "title": data.get("title") or CITATION_FALLBACK_DATA["title"], @@ -505,7 +546,9 @@ def populate_bibtex_key_citation( # noqa: C901, PLR0912 } entry_data = {k: v for k, v in entry_data.items() if v} try: - new_entry = Entry(data.get("bibtex_type", "article"), fields=entry_data) + new_entry = Entry( + data.get("bibtex_type", "article") or "article", fields=entry_data + ) if existing_entry: new_entry = cls.merge_bibtex_entries(existing_entry, new_entry) # add in authors manually into the entry @@ -519,7 +562,9 @@ def populate_bibtex_key_citation( # noqa: C901, PLR0912 if data.get("overwrite_fields_from_metadata", True): data["citation"] = None except Exception: - logger.exception(f"Failed to generate bibtex for {data}") + logger.warning( + f"Failed to generate bibtex for {data.get('docname') or data.get('citation')}" + ) if not data.get("citation"): data["citation"] = format_bibtex( data["bibtex"], clean=True, missing_replacements=CITATION_FALLBACK_DATA # type: ignore[arg-type] @@ -530,6 +575,7 @@ def populate_bibtex_key_citation( # noqa: C901, PLR0912 @classmethod def validate_all_fields(cls, data: dict[str, Any]) -> dict[str, Any]: data = cls.lowercase_doi_and_populate_doc_id(data) + data = cls.remove_invalid_authors(data) data = cls.misc_string_cleaning(data) data = cls.inject_clean_doi_url_into_data(data) data = cls.populate_bibtex_key_citation(data) @@ -646,7 +692,9 @@ def __add__(self, other: DocDetails | int) -> DocDetails: # noqa: C901 # pre-prints / arXiv versions of papers that are not as up-to-date merged_data[field] = ( other_value - if (other_value is not None and PREFER_OTHER) + if ( + (other_value is not None and other_value != []) and PREFER_OTHER + ) else self_value ) diff --git a/paperqa/utils.py b/paperqa/utils.py index 19f521e0..d92a9b45 100644 --- a/paperqa/utils.py +++ b/paperqa/utils.py @@ -6,6 +6,7 @@ import json import logging import math +import os import re import string from collections.abc import Iterable @@ -92,11 +93,15 @@ def count_pdf_pages(file_path: StrPath) -> int: return num_pages -def md5sum(file_path: StrPath) -> str: - import hashlib +def hexdigest(data: str | bytes) -> str: + if isinstance(data, str): + return hashlib.md5(data.encode("utf-8")).hexdigest() # noqa: S324 + return hashlib.md5(data).hexdigest() # noqa: S324 + +def md5sum(file_path: StrPath) -> str: with open(file_path, "rb") as f: - return hashlib.md5(f.read()).hexdigest() # noqa: S324 + return hexdigest(f.read()) async def gather_with_concurrency(n: int, coros: list[Coroutine]) -> list[Any]: @@ -367,15 +372,17 @@ def bibtex_field_extract( def create_bibtex_key(author: list[str], year: str, title: str) -> str: - FORBIDDEN_KEY_CHARACTERS = {"_", " ", "-", "/", "'", "`", ":"} + FORBIDDEN_KEY_CHARACTERS = {"_", " ", "-", "/", "'", "`", ":", ",", "\n"} author_rep = ( author[0].split()[-1].casefold() if "Unknown" not in author[0] else "unknownauthors" ) - key = ( - f"{author_rep}{year}{''.join([t.casefold() for t in title.split()[:3]])[:100]}" - ) + # we don't want a bibtex-parsing induced line break in the key + # so we cap it to 100+50+4 = 154 characters max + # 50 for the author, 100 for the first three title words, 4 for the year + # the first three title words are just 
emulating the s2 convention + key = f"{author_rep[:50]}{year}{''.join([t.casefold() for t in title.split()[:3]])[:100]}" return remove_substrings(key, FORBIDDEN_KEY_CHARACTERS) @@ -398,7 +405,7 @@ async def _get_with_retrying( params: dict[str, Any], session: aiohttp.ClientSession, headers: dict[str, str] | None = None, - timeout: float = 10.0, + timeout: float = 10.0, # noqa: ASYNC109 http_exception_mappings: dict[HTTPStatus | int, Exception] | None = None, ) -> dict[str, Any]: """Get from a URL with retrying protection.""" @@ -419,3 +426,13 @@ async def _get_with_retrying( def union_collections_to_ordered_list(collections: Iterable) -> list: return sorted(reduce(lambda x, y: set(x) | set(y), collections)) + + +def pqa_directory(name: str) -> Path: + if pqa_home := os.environ.get("PQA_HOME"): + directory = Path(pqa_home) / ".pqa" / name + else: + directory = Path.home() / ".pqa" / name + + directory.mkdir(parents=True, exist_ok=True) + return directory diff --git a/pyproject.toml b/pyproject.toml index b8e5bbd6..82f84a7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,14 +26,16 @@ dependencies = [ "pybtex", "pydantic~=2.0", "pypdf", + "tenacity", "tiktoken>=0.4.0", ] description = "LLM Chain for answering questions from docs" -dynamic = ["optional-dependencies", "version"] +dynamic = ["version"] keywords = ["question answering"] license = {file = "LICENSE"} maintainers = [ {email = "jamesbraza@gmail.com", name = "James Braza"}, + {email = "michael.skarlinski@gmail.com", name = "Michael Skarlinski"}, {email = "white.d.andrew@gmail.com", name = "Andrew White"}, ] name = "paper-qa" @@ -41,6 +43,51 @@ readme = "README.md" requires-python = ">=3.8" urls = {repository = "https://github.com/whitead/paper-qa"} +[project.optional-dependencies] +agents = [ + "anyio", + "langchain-community", + "langchain-openai", + "pymupdf", + "requests", + "tantivy", + "typer", + "typing_extensions", +] +dev = [ + "build", + "mypy", + "pre-commit", + "pytest", + "pytest-asyncio", + "pytest-subtests", + "pytest-sugar", + "pytest-timer", + "pytest-vcr", + "python-dotenv", + "pyzotero", + "requests", + "types-PyYAML", + "types-requests", + "types-setuptools", +] +google = [ + "langchain-google-vertexai", + "vertexai", +] +llms = [ + "anthropic", + "faiss-cpu", + "langchain-community", + "langchain-openai", + "pymupdf", + "sentence_transformers", + "voyageai", +] + +[project.scripts] +pqa = "paperqa.agents:app" + [tool.codespell] check-filenames = true check-hidden = true @@ -115,6 +162,7 @@ unsafe-fixes = true [tool.ruff.lint] explicit-preview-rules = true extend-select = [ + "C420", "FURB110", "FURB113", "FURB116", @@ -137,7 +185,6 @@ extend-select = [ "PLR6201", "PLW0108", "RUF022", - "RUF025", ] # List of rule codes that are unsupported by Ruff, but should be preserved when # (e.g.) validating # noqa directives. Useful for retaining # noqa directives @@ -215,9 +262,6 @@ max-doc-length = 120 # Match line-length # defaults when analyzing docstring sections. 
convention = "google" -[tool.setuptools.dynamic.optional-dependencies.dev] -file = ["dev-requirements.txt"] - [tool.setuptools.packages.find] include = ["paperqa*"] diff --git a/tests/conftest.py b/tests/conftest.py index f52d7170..2c171f5e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,17 @@ +from __future__ import annotations + +import os +import shutil +from pathlib import Path +from typing import Generator +from unittest.mock import patch + import pytest from dotenv import load_dotenv from paperqa.clients.crossref import CROSSREF_HEADER_KEY from paperqa.clients.semantic_scholar import SEMANTIC_SCHOLAR_HEADER_KEY +from paperqa.types import Answer @pytest.fixture(autouse=True, scope="session") @@ -15,3 +24,38 @@ def vcr_config(): return { "filter_headers": [CROSSREF_HEADER_KEY, SEMANTIC_SCHOLAR_HEADER_KEY], } + + +@pytest.fixture +def tmp_path_cleanup( + tmp_path: Path, +) -> Generator[Path, None, None]: + yield tmp_path + # Cleanup after the test + if tmp_path.exists(): + shutil.rmtree(tmp_path, ignore_errors=True) + + +@pytest.fixture +def agent_home_dir( + tmp_path_cleanup: str | os.PathLike, +) -> Generator[str | os.PathLike, None, None]: + """Set up a unique temporary folder for the agent module.""" + with patch.dict("os.environ", {"PQA_HOME": str(tmp_path_cleanup)}): + yield tmp_path_cleanup + + +@pytest.fixture +def agent_index_dir(agent_home_dir: Path) -> Path: + return agent_home_dir / ".pqa" / "indexes" + + +@pytest.fixture(name="agent_test_kit") +def fixture_stub_answer() -> Answer: + return Answer(question="What is is a self-explanatory model?") + + +@pytest.fixture(name="stub_paper_path", scope="session") +def fixture_stub_paper_path() -> Path: + # Corresponds with https://www.semanticscholar.org/paper/A-Perspective-on-Explanations-of-Molecular-Models-Wellawatte-Gandhi/1db1bde653658ec9b30858ae14650b8f9c9d438b + return Path(__file__).parent / "paper.pdf" diff --git a/tests/stub_manifest.csv b/tests/stub_manifest.csv new file mode 100644 index 00000000..d6cf6f34 --- /dev/null +++ b/tests/stub_manifest.csv @@ -0,0 +1,4 @@ +file_location,doi,title +"paper.pdf","10.1021/acs.jctc.2c01235","A Perspective on Explanations of Molecular Prediction Models" +"example.txt",,"Frederick Bates (Wikipedia article)" +"example2.txt",,"Barack Obama (Wikipedia article)" diff --git a/tests/test_agents.py b/tests/test_agents.py new file mode 100644 index 00000000..9879d15f --- /dev/null +++ b/tests/test_agents.py @@ -0,0 +1,562 @@ +from __future__ import annotations + +import itertools +import json +import logging +import re +from pathlib import Path +from typing import Any, cast +from unittest.mock import patch + +import anyio +import pytest +from pydantic import ValidationError +from pytest_subtests import SubTests + +from paperqa.docs import Docs +from paperqa.llms import LangchainLLMModel +from paperqa.types import Answer, Context, Doc, PromptCollection, Text +from paperqa.utils import get_year + +try: + from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent + from langchain_openai import ChatOpenAI + from tenacity import Retrying, retry_if_exception_type, stop_after_attempt + + from paperqa.agents import agent_query + from paperqa.agents.helpers import ( + compute_total_model_token_cost, + update_doc_models, + ) + from paperqa.agents.models import ( + AgentPromptCollection, + AgentStatus, + AnswerResponse, + MismatchedModelsError, + QueryRequest, + ) + from paperqa.agents.prompts import STATIC_PROMPTS + from paperqa.agents.search import 
get_directory_index + from paperqa.agents.tools import ( + GatherEvidenceTool, + GenerateAnswerTool, + PaperSearchTool, + SharedToolState, + ) +except ImportError: + pytest.skip("agents module is not installed", allow_module_level=True) + + +PAPER_DIRECTORY = Path(__file__).parent + + +@pytest.mark.asyncio +async def test_get_directory_index(agent_index_dir): + index = await get_directory_index( + directory=anyio.Path(PAPER_DIRECTORY), + index_name="pqa_index_0", + index_directory=agent_index_dir, + ) + assert index.fields == [ + "title", + "file_location", + "body", + "year", + ], "Incorrect fields in index" + assert len(await index.index_files) == 4, "Incorrect number of index files" + results = await index.query(query="who is Frederick Bates?") + # added docs.keys come from md5 hash of the file location + assert results[0].docs.keys() == {"dab5b86dea3bd4c7ffe05a9f33ae95f7"} + + +@pytest.mark.asyncio +async def test_get_directory_index_w_manifest(agent_index_dir): + index = await get_directory_index( + directory=anyio.Path(PAPER_DIRECTORY), + index_name="pqa_index_0", + index_directory=agent_index_dir, + manifest_file=anyio.Path(PAPER_DIRECTORY) / "stub_manifest.csv", + ) + assert index.fields == [ + "title", + "file_location", + "body", + "year", + ], "Incorrect fields in index" + # 4 = example.txt + example2.txt + paper.pdf + example.html + assert len(await index.index_files) == 4, "Incorrect number of index files" + results = await index.query(query="who is Barack Obama?") + top_result = next(iter(results[0].docs.values())) + assert top_result.dockey == "af2c9acf6018e62398fc6efc4f0a04b4" + # note: this title comes from the manifest, so we know it worked + assert top_result.title == "Barack Obama (Wikipedia article)" + + +@pytest.mark.flaky(reruns=3, only_rerun=["AssertionError", "httpx.RemoteProtocolError"]) +@pytest.mark.parametrize("agent_type", ["OpenAIFunctionsAgent", "fake"]) +@pytest.mark.asyncio +async def test_agent_types(agent_index_dir, agent_type): + + question = "How can you use XAI for chemical property prediction?" 
+ + request = QueryRequest( + query=question, + consider_sources=10, + max_sources=2, + embedding="sparse", + agent_tools=AgentPromptCollection( + search_count=2, + paper_directory=PAPER_DIRECTORY, + index_directory=agent_index_dir, + ), + ) + response = await agent_query( + request, agent_type=agent_type, index_directory=agent_index_dir + ) + assert response.answer.answer != "I cannot answer", "Answer not generated" + assert len(response.answer.context) >= 1, "No contexts were found" + assert response.answer.question == question + + +@pytest.mark.asyncio +async def test_timeout(agent_index_dir): + response = await agent_query( + QueryRequest( + query="Are COVID-19 vaccines effective?", + llm="gpt-4o-mini", + prompts=PromptCollection(pre=None), + # We just need one tool to test the timeout, gen_answer is not that fast + agent_tools=AgentPromptCollection( + timeout=0.001, + tool_names={"gen_answer"}, + paper_directory=PAPER_DIRECTORY, + index_directory=agent_index_dir, + ), + ), + Docs(), + ) + # ensure that GenerateAnswerTool was called + assert response.status == AgentStatus.TIMEOUT, "Agent did not timeout" + assert "I cannot answer" in response.answer.answer + + +@pytest.mark.asyncio +async def test_propagate_options(agent_index_dir) -> None: + llm_name = "gpt-4o-mini" + default_llm_names = { + cls.model_fields[name].default # type: ignore[attr-defined] + for name, cls in itertools.product(("llm", "summary_llm"), (QueryRequest, Docs)) + } + assert ( + llm_name not in default_llm_names + ), f"Assertions require not matching a default LLM name in {default_llm_names}." + query = QueryRequest( + query="What is is a self-explanatory model?", + summary_llm=llm_name, + llm=llm_name, + max_sources=5, + consider_sources=6, + length="400 words", + prompts=PromptCollection( + pre=None, system="End all responses with ###", skip_summary=True + ), + # NOTE: this is testing that if our prompt forgets template fields (e.g. status), + # the code still runs, despite the presence of extra keyword arguments to format + agent_tools=AgentPromptCollection( + paper_directory=PAPER_DIRECTORY, + index_directory=agent_index_dir, + agent_prompt=( + "Answer question: {question}. Search for papers, gather evidence, and" + " answer. If you do not have enough evidence, you can search for more" + " papers (preferred) or gather more evidence with a different phrase." + " You may rephrase or break-up the question in those steps. Once you" + " have five or more pieces of evidence from multiple sources, or you" + " have tried a few times, call {gen_answer_tool_name} tool. The" + " {gen_answer_tool_name} tool output is visible to the user, so you do" + " not need to restate the answer and can simply terminate if the answer" + " looks sufficient." 
+ ), + tool_names={"paper_search", "gather_evidence", "gen_answer"}, + ), + ) + for attempt in Retrying( + stop=stop_after_attempt(3), retry=retry_if_exception_type(AssertionError) + ): + with attempt: + docs = Docs(llm=llm_name, summary_llm=llm_name) + docs.prompts = query.prompts # this line happens in main + response = await agent_query(query, docs, agent_type="fake") + assert response.status == AgentStatus.SUCCESS, "Agent did not succeed" + result = response.answer + assert len(result.answer) > 200, "Answer did not return any results" + assert result.answer_length == query.length, "Answer length did not propagate" + assert "###" in result.answer, "Answer did not propagate system prompt" + + +@pytest.mark.flaky(reruns=3, only_rerun=["AssertionError"]) +@pytest.mark.asyncio +async def test_mixing_langchain_clients(caplog, agent_index_dir) -> None: + docs = Docs() + query = QueryRequest( + query="What is is a self-explanatory model?", + max_sources=2, + consider_sources=3, + llm="gemini-1.5-flash", + summary_llm="gemini-1.5-flash", + agent_tools=AgentPromptCollection( + paper_directory=PAPER_DIRECTORY, index_directory=agent_index_dir + ), + ) + update_doc_models(docs, query) + with caplog.at_level(logging.WARNING): + response = await agent_query(query, docs) + assert response.status == AgentStatus.SUCCESS, "Agent did not succeed" + assert not [ + msg for (*_, msg) in caplog.record_tuples if "error" in msg.lower() + ], "Expected clean logs" + + +@pytest.mark.asyncio +async def test_gather_evidence_rejects_empty_docs() -> None: + # Patch GenerateAnswerTool._arun so that if this tool is chosen first, we + # don't give a 'cannot answer' response. A 'cannot answer' response can + # lead to an unsure status, which will break this test's assertions. Since + # this test is about a GatherEvidenceTool edge case, defeating + # GenerateAnswerTool is fine + with patch.object( + GenerateAnswerTool, "_arun", return_value="Failed to answer question." 
+ ): + response = await agent_query( + query=QueryRequest( + query="Are COVID-19 vaccines effective?", + agent_tools=AgentPromptCollection( + tool_names={"gather_evidence", "gen_answer"} + ), + ), + docs=Docs(), + ) + assert response.status == AgentStatus.FAIL, "Agent should have registered a failure" + + +@pytest.mark.flaky(reruns=3, only_rerun=["AssertionError"]) +@pytest.mark.asyncio +async def test_agent_sharing_state( + fixture_stub_answer, subtests: SubTests, agent_index_dir +) -> None: + tool_state = SharedToolState(docs=Docs(), answer=fixture_stub_answer) + search_count = 3 # Keep low for speed + query = QueryRequest( + query=fixture_stub_answer.question, + consider_sources=2, + max_sources=1, + agent_tools=AgentPromptCollection( + search_count=search_count, + index_directory=agent_index_dir, + paper_directory=PAPER_DIRECTORY, + ), + ) + + with subtests.test(msg=PaperSearchTool.__name__): + tool = PaperSearchTool( + shared_state=tool_state, + search_count=search_count, + index_directory=agent_index_dir, + paper_directory=PAPER_DIRECTORY, + ) + await tool.arun("XAI self explanatory model") + assert tool_state.docs.docs, "Search did not save any papers" + assert all( + (isinstance(d, Doc) or issubclass(d, Doc)) # type: ignore[unreachable] + for d in tool_state.docs.docs.values() + ), "Document type or DOI propagation failure" + + with subtests.test(msg=GatherEvidenceTool.__name__): + assert ( + not fixture_stub_answer.contexts + ), "No contexts is required for a later assertion" + + tool = GatherEvidenceTool(shared_state=tool_state, query=query) + await tool.arun(fixture_stub_answer.question) + assert ( + len(fixture_stub_answer.dockey_filter) > 0 + ), "Filter did not preserve reference" + assert fixture_stub_answer.contexts, "Evidence did not return any results" + + with subtests.test(msg=f"{GenerateAnswerTool.__name__} working"): + tool = GenerateAnswerTool(shared_state=tool_state, query=query) + result = await tool.arun(fixture_stub_answer.question) + assert re.search( + pattern=SharedToolState.STATUS_SEARCH_REGEX_PATTERN, string=result + ) + assert ( + len(fixture_stub_answer.answer) > 200 + ), "Answer did not return any results" + assert ( + GenerateAnswerTool.extract_answer_from_message(result) + == fixture_stub_answer.answer + ), "Failed to regex extract answer from result" + assert ( + len(fixture_stub_answer.contexts) <= query.max_sources + ), "Answer has more sources than expected" + + with subtests.test(msg=f"{GenerateAnswerTool.__name__} misconfigured query"): + query.consider_sources = 1 # k + query.max_sources = 5 + tool = GenerateAnswerTool(shared_state=tool_state, query=query) + with pytest.raises(ValueError, match="k should be greater than max_sources"): + await tool.arun("Are COVID-19 vaccines effective?") + + +def test_functions() -> None: + """Check the functions schema passed to OpenAI.""" + shared_tool_state = SharedToolState( + answer=Answer(question="stub"), + docs=Docs(), + ) + agent = OpenAIFunctionsAgent.from_llm_and_tools( + llm=ChatOpenAI(model="gpt-4-turbo-2024-04-09"), + tools=[ + PaperSearchTool(shared_state=shared_tool_state), + GatherEvidenceTool(shared_state=shared_tool_state, query=QueryRequest()), + GenerateAnswerTool(shared_state=shared_tool_state, query=QueryRequest()), + ], + ) + assert cast(OpenAIFunctionsAgent, agent).functions == [ + { + "name": "paper_search", + "description": ( + "Search for papers to increase the paper count. You can call this a second " + "time with an different search to gather more papers." 
+ ), + "parameters": { + "type": "object", + "properties": { + "query": { + "description": ( + "A search query in this format: [query], [start year]-[end year]. " + "You may include years as the last word in the query, " + "e.g. 'machine learning 2020' or 'machine learning 2010-2020'. " + f"The current year is {get_year()}. " + "The query portion can be a specific phrase, complete sentence, " + "or general keywords, e.g. 'machine learning for immunology'." + ), + "type": "string", + } + }, + "required": ["query"], + }, + }, + { + "name": "gather_evidence", + "description": ( + "Gather evidence from previous papers given a specific question. " + "This will increase evidence and relevant paper counts. " + "Only invoke when paper count is above zero." + ), + "parameters": { + "type": "object", + "properties": { + "question": { + "description": "Specific question to gather evidence for.", + "type": "string", + } + }, + "required": ["question"], + }, + }, + { + "name": "gen_answer", + "description": ( + "Ask a model to propose an answer answer using current evidence. " + "The tool may fail, " + "indicating that better or different evidence should be found. " + "Having more than one piece of evidence or relevant papers is best." + ), + "parameters": { + "type": "object", + "properties": { + "question": { + "description": "Question to be answered.", + "type": "string", + } + }, + "required": ["question"], + }, + }, + ] + + +def test_instruct_model(): + docs = Docs(name="tmp") + query = QueryRequest( + summary_llm="gpt-3.5-turbo-instruct", + query="Are COVID-19 vaccines effective?", + llm="gpt-3.5-turbo-instruct", + ) + update_doc_models(docs, query) + docs.query("Are COVID-19 vaccines effective?") + query.llm = "gpt-3.5-turbo-instruct" + query.summary_llm = "gpt-3.5-turbo-instruct" + update_doc_models(docs, query) + docs.query("Are COVID-19 vaccines effective?") + + +def test_anthropic_model(): + docs = Docs(name="tmp") + query = QueryRequest( + summary_llm="claude-3-sonnet-20240229", + query="Are COVID-19 vaccines effective?", + llm="claude-3-sonnet-20240229", + ) + update_doc_models(docs, query) + answer = docs.query("Are COVID-19 vaccines effective?") + + # make sure we can compute cost with this model. 
+ compute_total_model_token_cost(answer.token_counts) + + +def test_embeddings_anthropic(): + docs = Docs(name="tmp") + query = QueryRequest( + summary_llm="claude-3-sonnet-20240229", + query="Are COVID-19 vaccines effective?", + llm="claude-3-sonnet-20240229", + embedding="sparse", + ) + update_doc_models(docs, query) + _ = docs.query("Are COVID-19 vaccines effective?") + + query = QueryRequest( + summary_llm="claude-3-sonnet-20240229", + query="Are COVID-19 vaccines effective?", + llm="claude-3-sonnet-20240229", + embedding="hybrid-text-embedding-3-small", + ) + update_doc_models(docs, query) + _ = docs.query("Are COVID-19 vaccines effective?") + + +@pytest.mark.asyncio +async def test_gemini_model_construction( + stub_paper_path: Path, +) -> None: + docs = Docs(name="tmp") + query = QueryRequest( + summary_llm="gemini-1.5-pro", + llm="gemini-1.5-pro", + embedding="sparse", + ) + update_doc_models(docs, query) + assert isinstance(docs.llm_model, LangchainLLMModel) # We use LangChain for Gemini + assert docs.llm_model.name == "gemini-1.5-pro", "Gemini Model: model not created" + assert "model" not in docs.llm_model.config, "model should not be in config" + + # now try using it + await docs.aadd(stub_paper_path) + answer = await docs.aget_evidence( + Answer(question="Are COVID-19 vaccines effective?") + ) + assert answer.contexts, "Gemini Model: no contexts returned" + + +def test_query_request_summary(): + """Test that we can set summary llm to none and it will skip summary.""" + request = QueryRequest(query="Are COVID-19 vaccines effective?", summary_llm="none") + assert request.summary_llm == "gpt-4o-mini" + assert request.prompts.skip_summary, "Summary should be skipped with none llm" + + +def test_query_request_preset_prompts(): + """Test that we can set the prompt using our preset prompts.""" + request = QueryRequest( + query="Are COVID-19 vaccines effective?", + prompts=PromptCollection( + summary_json_system=r"{gene_name} {summary} {relevance_score}" + ), + ) + assert "gene_name" in request.prompts.summary_json_system + + +def test_query_request_docs_name_serialized(): + """Test that the query request has a docs_name property.""" + request = QueryRequest(query="Are COVID-19 vaccines effective?") + request_data = json.loads(request.model_dump_json()) + assert "docs_name" in request_data + assert request_data["docs_name"] is None + request.set_docs_name("my_doc") + request_data = json.loads(request.model_dump_json()) + assert request_data["docs_name"] == "my_doc" + + +def test_query_request_model_mismatch(): + with pytest.raises( + MismatchedModelsError, + match=( + "Answer LLM and summary LLM types must match: " + " != " + ), + ): + _ = QueryRequest( + summary_llm="gpt-4o-mini", + query="Are COVID-19 vaccines effective?", + llm="claude-3-sonnet-20240229", + embedding="sparse", + ) + + +def test_answers_are_striped(): + """Test that answers are striped.""" + answer = Answer( + question="What is the meaning of life?", + contexts=[ + Context( + context="bla", + text=Text( + name="text", + text="The meaning of life is 42.", + embedding=[43.3, 34.2], + doc=Doc( + docname="foo", + citation="bar", + dockey="baz", + embedding=[43.1, 65.2], + ), + ), + score=3, + ) + ], + ) + response = AnswerResponse(answer=answer, usage={}, bibtex={}, status="success") + + assert response.answer.contexts[0].text.embedding is None + assert response.answer.contexts[0].text.text == "" # type: ignore[unreachable,unused-ignore] + assert response.answer.contexts[0].text.doc is not None + assert 
response.answer.contexts[0].text.doc.embedding is None + # make sure it serializes + response.model_dump_json() + + +def test_prompts_are_set(): + assert ( + STATIC_PROMPTS["json"].summary_json_system + != PromptCollection().summary_json_system + ) + + +@pytest.mark.parametrize( + ("kwargs", "result"), + [ + ({}, None), + ({"tool_names": {GenerateAnswerTool.__fields__["name"].default}}, None), + ({"tool_names": set()}, ValidationError), + ({"tool_names": {PaperSearchTool.__fields__["name"].default}}, ValidationError), + ], +) +def test_agent_prompt_collection_validations( + kwargs: dict[str, Any], result: type[Exception] | None +) -> None: + if result is None: + AgentPromptCollection(**kwargs) + else: + with pytest.raises(result): + AgentPromptCollection(**kwargs) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..60ff0681 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,151 @@ +from pathlib import Path + +import pytest + +try: + from paperqa.agents import ask, build_index, clear, search_query, set_setting, show + from paperqa.agents.models import AnswerResponse + from paperqa.agents.search import SearchIndex +except ImportError: + pytest.skip("agents module is not installed", allow_module_level=True) + + +def test_cli_set(agent_index_dir: Path): # noqa: ARG001 + set_setting("temperature", "0.5") + assert show("temperature") == "0.5", "Temperature not properly set" + + with pytest.raises(ValueError) as excinfo: # noqa: PT011 + set_setting("temperature", "abc") + assert "temperature (with value abc) is not a valid setting." in str(excinfo.value) + + # ensure we can do nested settings + set_setting("agent_tools.paper_directory", "my_directory") + assert ( + show("agent_tools.paper_directory") == "my_directory" + ), "Nested setting not properly set" + + # ensure we can set settings which fail specific validations + # normally we'd get a failure for model mixing, but this is reserved for runtime + set_setting("llm", "claude-3-5-sonnet-20240620") + assert show("llm") == "claude-3-5-sonnet-20240620", "Setting not properly set" + + # test that we're able to collection structures like lists + set_setting( + "parsing_configuration.ordered_parser_preferences", '["paperqa_default"]' + ) + assert show("parsing_configuration.ordered_parser_preferences") == [ + "paperqa_default" + ], "List setting not properly set" + + +@pytest.mark.asyncio +async def test_cli_show(agent_index_dir: Path): + + # make empty index + assert not show("indexes"), "No indexes should be present" + + # creates a new index/file + si = SearchIndex(index_directory=agent_index_dir) + await si.init_directory() + + set_setting("temperature", "0.5") + set_setting("agent_tools.paper_directory", "my_directory") + + assert show("temperature") == "0.5", "Temperature not properly set" + + assert not show("fake_variable"), "Fake variable should not be set" + + assert show("all") == { + "temperature": "0.5", + "agent_tools": {"paper_directory": "my_directory"}, + }, "All settings not properly set" + + assert show("indexes") == ["pqa_index"], "Index not properly set" + + assert show("answers") == [], "Answers should be empty" + + +@pytest.mark.asyncio +async def test_cli_clear(agent_index_dir: Path): + + set_setting("temperature", "0.5") + assert show("temperature") == "0.5", "Temperature not properly set" + + clear("temperature") + assert show("temperature") is None, "Temperature not properly cleared" + + # set a nested variable + set_setting("prompts.qa", "Answer my question!") + assert show("prompts.qa") 
== "Answer my question!", "Prompt not properly set" + clear("prompts.qa") + assert show("prompts.qa") is None, "Prompt not properly cleared" + + # creates a new index/file + si = SearchIndex(index_directory=agent_index_dir) + await si.init_directory() + + clear("pqa_index", index=True) + + assert show("indexes") == [], "Index not properly cleared" + + +def test_cli_ask(agent_index_dir: Path): + set_setting("consider_sources", "10") + set_setting("max_sources", "2") + set_setting("embedding", "sparse") + set_setting("agent_tools.search_count", "1") + answer = ask( + "How can you use XAI for chemical property prediction?", + directory=Path(__file__).parent, + index_directory=agent_index_dir, + ) + assert isinstance(answer, AnswerResponse), "Answer not properly returned" + assert ( + "I cannot answer" not in answer.answer.answer + ), "An answer should be generated." + assert len(answer.answer.context) >= 1, "No contexts were found." + answers = search_query( + "How can you use XAI for chemical property prediction?", + index_directory=agent_index_dir, + ) + answer = answers[0][0] + assert isinstance(answer, AnswerResponse), "Answer not properly returned" + assert ( + "I cannot answer" not in answer.answer.answer + ), "An answer should be generated." + assert len(answer.answer.context) >= 1, "No contexts were found." + assert len(show("answers")) == 1, "An answer should be returned" + + +def test_cli_index(agent_index_dir: Path, caplog): + + build_index(directory=Path(__file__).parent, index_directory=agent_index_dir) + + caplog.clear() + + ask( + "How can you use XAI for chemical property prediction?", + directory=Path(__file__).parent, + index_directory=agent_index_dir, + ) + + # ensure we have no indexing logs after starting the search + for record in caplog.records: + if "Metadata not found" in record.msg: + raise AssertionError( + "Indexing logs should not be present after search starts" + ) + + caplog.clear() + + # running again should not trigger any indexing + build_index(directory=Path(__file__).parent, index_directory=agent_index_dir) + assert not caplog.records, "Indexing should not be triggered again" + + # now we want to change the settings + set_setting("embedding", "sparse") + + # running again should now re-trigger an indexing + build_index(directory=Path(__file__).parent, index_directory=agent_index_dir) + + assert caplog.records, "Indexing should be triggered again"