From ca367859124239293d12c219936285c62f945919 Mon Sep 17 00:00:00 2001 From: James Braza Date: Fri, 15 Nov 2024 10:06:36 -0800 Subject: [PATCH] Removed `GenerateAnswer.FAILED_TO_ANSWER` as its unnecessary (#691) --- paperqa/agents/env.py | 4 ++-- paperqa/agents/tools.py | 26 +++++++++----------------- paperqa/types.py | 6 +++++- tests/test_agents.py | 24 ++++++++++-------------- 4 files changed, 26 insertions(+), 34 deletions(-) diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py index 517a0d43..e541fa31 100644 --- a/paperqa/agents/env.py +++ b/paperqa/agents/env.py @@ -15,7 +15,7 @@ from paperqa.docs import Docs from paperqa.llms import EmbeddingModel, LiteLLMModel from paperqa.settings import Settings -from paperqa.types import PQASession +from paperqa.types import PQASession, check_could_not_answer from paperqa.utils import get_year from .models import QueryRequest @@ -193,7 +193,7 @@ async def step( any( isinstance(msg, ToolResponseMessage) and msg.name == GenerateAnswer.gen_answer.__name__ - and GenerateAnswer.did_not_fail_to_answer(msg.content) + and not check_could_not_answer(msg.content) for msg in response_messages ) or self._has_excess_answer_failures(), diff --git a/paperqa/agents/tools.py b/paperqa/agents/tools.py index b5da70e7..ff58cd95 100644 --- a/paperqa/agents/tools.py +++ b/paperqa/agents/tools.py @@ -13,7 +13,7 @@ from paperqa.docs import Docs from paperqa.llms import EmbeddingModel, LiteLLMModel from paperqa.settings import Settings -from paperqa.types import DocDetails, PQASession +from paperqa.types import DocDetails, PQASession, check_could_not_answer from .search import get_directory_index @@ -268,14 +268,6 @@ class GenerateAnswer(NamedTool): summary_llm_model: LiteLLMModel embedding_model: EmbeddingModel - # This is not an answer to assign to the current PQASession, - # but a status for the agent message history - FAILED_TO_ANSWER: ClassVar[str] = "Failed to answer question." - - @classmethod - def did_not_fail_to_answer(cls, message: str | None) -> bool: - return not (message or "").startswith(cls.FAILED_TO_ANSWER) - async def gen_answer(self, question: str, state: EnvironmentState) -> str: """ Ask a model to propose an answer using current evidence. @@ -313,13 +305,13 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str: ), ) - if state.session.could_not_answer: - if self.settings.agent.wipe_context_on_answer_failure: - state.session.contexts = [] - state.session.context = "" - answer = self.FAILED_TO_ANSWER - else: - answer = state.session.answer + if ( + state.session.could_not_answer + and self.settings.agent.wipe_context_on_answer_failure + ): + state.session.contexts = [] + state.session.context = "" + answer = state.session.answer status = state.status logger.info(status) @@ -346,7 +338,7 @@ def extract_answer_from_message(cls, content: str) -> str: answer, *rest = re.split( pattern=cls.ANSWER_SPLIT_REGEX_PATTERN, string=content, maxsplit=1 ) - if len(rest) != 4 or not cls.did_not_fail_to_answer(answer): # noqa: PLR2004 + if len(rest) != 4 or check_could_not_answer(answer): # noqa: PLR2004 return "" return answer diff --git a/paperqa/types.py b/paperqa/types.py index 9af01b7f..3726c192 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -158,6 +158,10 @@ def __str__(self) -> str: return self.context +def check_could_not_answer(answer: str) -> bool: + return "cannot answer" in answer.lower() + + class PQASession(BaseModel): """A class to hold session about researching/answering.""" @@ -254,7 +258,7 @@ def filter_content_for_user(self) -> None: @property def could_not_answer(self) -> bool: - return "cannot answer" in self.answer.lower() + return check_could_not_answer(self.answer) # for backwards compatibility diff --git a/tests/test_agents.py b/tests/test_agents.py index 5636e8cd..13236dcc 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -9,6 +9,7 @@ import tempfile import time from copy import deepcopy +from functools import wraps from pathlib import Path from typing import cast from unittest.mock import AsyncMock, patch @@ -44,7 +45,7 @@ from paperqa.docs import Docs from paperqa.prompts import CONTEXT_INNER_PROMPT_NOT_DETAILED from paperqa.settings import AgentSettings, IndexSettings, Settings -from paperqa.types import Context, Doc, PQASession, Text +from paperqa.types import Context, Doc, PQASession, Text, check_could_not_answer from paperqa.utils import extract_thought, get_year, md5sum @@ -401,19 +402,14 @@ async def test_propagate_options(agent_test_settings: Settings) -> None: async def test_gather_evidence_rejects_empty_docs( agent_test_settings: Settings, ) -> None: + + @wraps(GenerateAnswer.gen_answer) + async def gen_answer(self, question: str, state) -> str: # noqa: ARG001 + return "I cannot answer." + # Patch GenerateAnswerTool.gen_answer so that if this tool is chosen first, - # we don't give a 'cannot answer' response. A 'cannot answer' response can - # lead to an unsure status, which will break this test's assertions. Since - # this test is about a GatherEvidenceTool edge case, defeating - # GenerateAnswerTool is fine - original_doc = GenerateAnswer.gen_answer.__doc__ - with patch.object( - GenerateAnswer, - "gen_answer", - return_value="Failed to answer question.", - autospec=True, - ) as mock_gen_answer: - mock_gen_answer.__doc__ = original_doc + # we keep running until we get truncated + with patch.object(GenerateAnswer, "gen_answer", gen_answer): agent_test_settings.agent = AgentSettings( tool_names={"gather_evidence", "gen_answer"}, max_timesteps=3, @@ -784,7 +780,7 @@ async def test_empty_tool_calls(self, agent_test_settings: Settings) -> None: obs, _, done, truncated = await env.step(ToolRequestMessage()) assert len(obs) == 1 assert obs[0].content - assert GenerateAnswer.did_not_fail_to_answer(obs[0].content) + assert not check_could_not_answer(obs[0].content) assert "0 tool calls" in obs[0].content assert done assert not truncated