Skip to content

Commit

Permalink
Removed GenerateAnswer.FAILED_TO_ANSWER as it's unnecessary (#691)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Nov 15, 2024
1 parent 417e666 commit ca36785
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 34 deletions.
4 changes: 2 additions & 2 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from paperqa.docs import Docs
from paperqa.llms import EmbeddingModel, LiteLLMModel
from paperqa.settings import Settings
from paperqa.types import PQASession
from paperqa.types import PQASession, check_could_not_answer
from paperqa.utils import get_year

from .models import QueryRequest
Expand Down Expand Up @@ -193,7 +193,7 @@ async def step(
any(
isinstance(msg, ToolResponseMessage)
and msg.name == GenerateAnswer.gen_answer.__name__
and GenerateAnswer.did_not_fail_to_answer(msg.content)
and not check_could_not_answer(msg.content)
for msg in response_messages
)
or self._has_excess_answer_failures(),
Expand Down
26 changes: 9 additions & 17 deletions paperqa/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from paperqa.docs import Docs
from paperqa.llms import EmbeddingModel, LiteLLMModel
from paperqa.settings import Settings
from paperqa.types import DocDetails, PQASession
from paperqa.types import DocDetails, PQASession, check_could_not_answer

from .search import get_directory_index

Expand Down Expand Up @@ -268,14 +268,6 @@ class GenerateAnswer(NamedTool):
summary_llm_model: LiteLLMModel
embedding_model: EmbeddingModel

# This is not an answer to assign to the current PQASession,
# but a status for the agent message history
FAILED_TO_ANSWER: ClassVar[str] = "Failed to answer question."

@classmethod
def did_not_fail_to_answer(cls, message: str | None) -> bool:
return not (message or "").startswith(cls.FAILED_TO_ANSWER)

async def gen_answer(self, question: str, state: EnvironmentState) -> str:
"""
Ask a model to propose an answer using current evidence.
Expand Down Expand Up @@ -313,13 +305,13 @@ async def gen_answer(self, question: str, state: EnvironmentState) -> str:
),
)

if state.session.could_not_answer:
if self.settings.agent.wipe_context_on_answer_failure:
state.session.contexts = []
state.session.context = ""
answer = self.FAILED_TO_ANSWER
else:
answer = state.session.answer
if (
state.session.could_not_answer
and self.settings.agent.wipe_context_on_answer_failure
):
state.session.contexts = []
state.session.context = ""
answer = state.session.answer
status = state.status
logger.info(status)

Expand All @@ -346,7 +338,7 @@ def extract_answer_from_message(cls, content: str) -> str:
answer, *rest = re.split(
pattern=cls.ANSWER_SPLIT_REGEX_PATTERN, string=content, maxsplit=1
)
if len(rest) != 4 or not cls.did_not_fail_to_answer(answer): # noqa: PLR2004
if len(rest) != 4 or check_could_not_answer(answer): # noqa: PLR2004
return ""
return answer

Expand Down
6 changes: 5 additions & 1 deletion paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def __str__(self) -> str:
return self.context


def check_could_not_answer(answer: str) -> bool:
    """Return True when the answer text signals an inability to answer.

    The check is case-insensitive and matches the phrase "cannot answer"
    anywhere in the string.
    """
    normalized = answer.lower()
    return normalized.find("cannot answer") >= 0


class PQASession(BaseModel):
"""A class to hold session about researching/answering."""

Expand Down Expand Up @@ -254,7 +258,7 @@ def filter_content_for_user(self) -> None:

@property
def could_not_answer(self) -> bool:
    """Whether this session's answer indicates a failure to answer."""
    # Delegate to the shared module-level check so the agent message
    # handling and this property stay in agreement.
    answer_text = self.answer
    return check_could_not_answer(answer_text)


# for backwards compatibility
Expand Down
24 changes: 10 additions & 14 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import tempfile
import time
from copy import deepcopy
from functools import wraps
from pathlib import Path
from typing import cast
from unittest.mock import AsyncMock, patch
Expand Down Expand Up @@ -44,7 +45,7 @@
from paperqa.docs import Docs
from paperqa.prompts import CONTEXT_INNER_PROMPT_NOT_DETAILED
from paperqa.settings import AgentSettings, IndexSettings, Settings
from paperqa.types import Context, Doc, PQASession, Text
from paperqa.types import Context, Doc, PQASession, Text, check_could_not_answer
from paperqa.utils import extract_thought, get_year, md5sum


Expand Down Expand Up @@ -401,19 +402,14 @@ async def test_propagate_options(agent_test_settings: Settings) -> None:
async def test_gather_evidence_rejects_empty_docs(
agent_test_settings: Settings,
) -> None:

# Stub replacement for GenerateAnswer.gen_answer used by the test below.
# NOTE(review): @wraps presumably preserves the original method's name and
# docstring for the agent framework's tool introspection — confirm.
@wraps(GenerateAnswer.gen_answer)
async def gen_answer(self, question: str, state) -> str:  # noqa: ARG001
    # Always report an inability to answer, regardless of inputs.
    return "I cannot answer."

# Patch GenerateAnswerTool.gen_answer so that if this tool is chosen first,
# we don't give a 'cannot answer' response. A 'cannot answer' response can
# lead to an unsure status, which will break this test's assertions. Since
# this test is about a GatherEvidenceTool edge case, defeating
# GenerateAnswerTool is fine
original_doc = GenerateAnswer.gen_answer.__doc__
with patch.object(
GenerateAnswer,
"gen_answer",
return_value="Failed to answer question.",
autospec=True,
) as mock_gen_answer:
mock_gen_answer.__doc__ = original_doc
# we keep running until we get truncated
with patch.object(GenerateAnswer, "gen_answer", gen_answer):
agent_test_settings.agent = AgentSettings(
tool_names={"gather_evidence", "gen_answer"},
max_timesteps=3,
Expand Down Expand Up @@ -784,7 +780,7 @@ async def test_empty_tool_calls(self, agent_test_settings: Settings) -> None:
obs, _, done, truncated = await env.step(ToolRequestMessage())
assert len(obs) == 1
assert obs[0].content
assert GenerateAnswer.did_not_fail_to_answer(obs[0].content)
assert not check_could_not_answer(obs[0].content)
assert "0 tool calls" in obs[0].content
assert done
assert not truncated

0 comments on commit ca36785

Please sign in to comment.