Optional AnswerSettings.max_answer_attempts to allow a new unsure branch #673

Merged: 3 commits, Nov 8, 2024
17 changes: 15 additions & 2 deletions paperqa/agents/env.py
@@ -157,10 +157,22 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
     def export_frame(self) -> Frame:
         return Frame(state=self.state, info={"query": self._query})
 
+    def _has_excess_answer_failures(self) -> bool:
+        if self._query.settings.answer.max_answer_attempts is None:
+            return False
+        return (
+            sum(
+                tn == GenerateAnswer.gen_answer.__name__
+                for s in self.state.tool_history
+                for tn in s
+            )
+            > self._query.settings.answer.max_answer_attempts
+        )
+
     async def step(
         self, action: ToolRequestMessage
     ) -> tuple[list[Message], float, bool, bool]:
-        self.state.session.add_tokens(action)  # Add usage for action if present
+        self.state.record_action(action)
 
         # If the action has empty tool_calls, the agent can later take that into account
         msgs = cast(
@@ -175,7 +187,8 @@ async def step
                 and msg.name == GenerateAnswer.gen_answer.__name__
                 and GenerateAnswer.did_not_fail_to_answer(msg.content)
                 for msg in msgs
-            ),
+            )
+            or self._has_excess_answer_failures(),
             False,
         )
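
Taken together, the env.py changes route every action through the new record_action bookkeeping and add a second termination condition: the episode now also ends once the number of gen_answer calls across all steps exceeds the (exclusive) cap. A minimal standalone sketch of that counting logic, with the environment wiring stripped out and the literal string "gen_answer" standing in for GenerateAnswer.gen_answer.__name__:

tool_history = [["paper_search"], ["gen_answer"], ["gen_answer"], ["gen_answer"]]
max_answer_attempts = 2  # exclusive cap; None disables the check

def has_excess_answer_failures() -> bool:
    if max_answer_attempts is None:
        return False
    # Flatten the per-step tool-call lists and count answer attempts
    return (
        sum(name == "gen_answer" for step in tool_history for name in step)
        > max_answer_attempts
    )

print(has_excess_answer_failures())  # True: three gen_answer calls exceed the cap of 2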

36 changes: 21 additions & 15 deletions paperqa/agents/task.py
@@ -130,24 +130,30 @@ async def step(
         messages, reward, done, truncated = await super().step(action)
         if not done or not self._evaluation_from_answer:
             return messages, reward, done, truncated
-        # Filter out non-answer messages (in case parallel tool calls)
-        answer_tool_messages = [
-            m
-            for m in messages
-            if isinstance(m, ToolResponseMessage)
-            and m.name == GenerateAnswer.gen_answer.__name__
-        ]
-        if not answer_tool_messages:  # No answer, so no positive reward
+        valid_answers, failed_answer_messages = [], []
+        for m in messages:
+            if (
+                not isinstance(m, ToolResponseMessage)
+                or m.name != GenerateAnswer.gen_answer.__name__
+            ):
+                continue  # Filter out non-answer messages (in case parallel tool calls)
+            if answer := GenerateAnswer.extract_answer_from_message(content=m.content):
+                valid_answers.append(answer)
+            else:
+                failed_answer_messages.append(m)
+        if not valid_answers:  # No answer, so no positive reward
             return messages, reward, done, truncated
-        if len(answer_tool_messages) != 1:
+        if len(valid_answers) != 1:
             raise NotImplementedError(
-                f"Expected just one answer message, got {messages}."
+                f"Expected just one answer message, got more than one in {messages}."
             )
-        answer = GenerateAnswer.extract_answer_from_message(
-            content=answer_tool_messages[0].content
-        )
-        if not answer:
-            return messages, reward, done, truncated
+        answer = valid_answers[0]
+        if failed_answer_messages:
+            logger.warning(
+                "More than one answer detected, discarding failed answer messages"
+                f" {failed_answer_messages}, continuing with answer {answer}."
+            )
+        # Okay, so we have one answer that was not a failed answer. Let's evaluate it
         evaluation = await self._evaluation_from_answer(answer)
         if evaluation_callback := self._evaluation_callback:
             await evaluation_callback(evaluation)
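
This refactor changes the failure handling: previously a single unparseable answer message ended the step with no positive reward, whereas now parseable answers are separated from failed attempts, failed ones are discarded with a warning, and an error is raised only if more than one valid answer arrives. A rough self-contained sketch of that triage, using a stand-in message type and a hypothetical failure sentinel in place of GenerateAnswer.extract_answer_from_message:

from dataclasses import dataclass

@dataclass
class ToolResponseMessage:
    name: str
    content: str

def extract_answer(content: str) -> str | None:
    # Stand-in parser: treat a sentinel phrase as a failed answer
    return None if "Failed to answer" in content else content

messages = [
    ToolResponseMessage("gen_answer", "Failed to answer question."),
    ToolResponseMessage("gen_answer", "ATP synthase is inhibited by oligomycin."),
    ToolResponseMessage("paper_search", "Found 8 papers."),
]
valid_answers, failed_answer_messages = [], []
for m in messages:
    if m.name != "gen_answer":
        continue  # skip non-answer messages (parallel tool calls)
    if answer := extract_answer(m.content):
        valid_answers.append(answer)
    else:
        failed_answer_messages.append(m)
assert valid_answers == ["ATP synthase is inhibited by oligomycin."]
assert len(failed_answer_messages) == 1
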
13 changes: 13 additions & 0 deletions paperqa/agents/tools.py
@@ -7,6 +7,7 @@
 import sys
 from typing import ClassVar, cast
 
+from aviary.core import ToolRequestMessage
 from pydantic import BaseModel, ConfigDict, Field, computed_field
 
 from paperqa.docs import Docs
@@ -36,6 +37,14 @@ class EnvironmentState(BaseModel):
 
     docs: Docs
     session: PQASession = Field(..., alias="answer")
+    tool_history: list[list[str]] = Field(
+        default_factory=list,
+        description=(
+            "History of tool names input to each Environment.step (regardless of being"
+            " a typo or not), where the outer list is steps, and the inner list matches"
+            " the order of tool calls at each step."
+        ),
+    )
 
     # SEE: https://regex101.com/r/RmuVdC/1
     STATUS_SEARCH_REGEX_PATTERN: ClassVar[str] = (
@@ -65,6 +74,10 @@ def status(self) -> str:
             cost=self.session.cost,
         )
 
+    def record_action(self, action: ToolRequestMessage) -> None:
+        self.session.add_tokens(action)
+        self.tool_history.append([tc.function.name for tc in action.tool_calls])
+
 
 class NamedTool(BaseModel):
     """Base class to make looking up tools easier."""
7 changes: 7 additions & 0 deletions paperqa/settings.py
@@ -83,6 +83,13 @@ class AnswerSettings(BaseModel):
     answer_max_sources: int = Field(
         default=5, description="Max number of sources to use for an answer"
     )
+    max_answer_attempts: int | None = Field(
+        default=None,
+        description=(
+            "Optional (exclusive) max number (default is no max) of attempts to"
+            " generate an answer before declaring a failure."
+        ),
+    )
     answer_length: str = Field(
         "about 200 words, but can be longer", description="Length of final answer"
     )
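
A minimal usage sketch, assuming the package's top-level Settings export; note the cap is exclusive, so a value of 2 permits two answer attempts and ends the episode once a third is recorded:

from paperqa import Settings

settings = Settings()
settings.answer.max_answer_attempts = 2  # exclusive: a 3rd gen_answer call trips the check
# The default of None keeps the prior behavior: no cap on answer attempts.
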
5 changes: 4 additions & 1 deletion tests/test_task.py
@@ -159,18 +159,21 @@ async def test_evaluation(
         base_query_request.settings.agent.tool_names = {
             GenerateAnswer.gen_answer.__name__
         }
+        base_query_request.settings.answer.max_answer_attempts = 2
+        base_query_request.settings.answer.get_evidence_if_no_contexts = False
         dataset = LitQAv2TaskDataset(base_query=base_query_request)
         dataset.data = dataset.data[:2]  # Save the world: just use two questions
         storage_callback = StoreTrajectoriesCallback()
         evaluator = Evaluator(
-            config=EvaluatorConfig(batch_size=len(dataset), max_rollout_steps=2),
+            config=EvaluatorConfig(batch_size=len(dataset), max_rollout_steps=4),
             agent=SimpleAgent(),
             dataset=dataset,
             callbacks=[storage_callback],
         )
         await evaluator.evaluate()
         for traj in storage_callback.eval_trajectories:
             assert not traj.failed
+            assert traj.done

[Inline review comment from the PR author on the "assert traj.done" line above: "This specific assertion will not pass before this PR without this new setting."]

             for step in traj.steps:
                 assert all(
                     tc.function.name == GenerateAnswer.gen_answer.__name__
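
The raised max_rollout_steps pairs with the new cap: with max_answer_attempts=2 and the agent restricted to gen_answer, the environment reports done on the third attempt, so four rollout steps leave headroom for a clean termination. A small sketch of that arithmetic (simplified, hypothetical loop):

def steps_until_stopped(max_answer_attempts: int) -> int:
    tool_history: list[list[str]] = []
    for step_number in range(1, 100):
        tool_history.append(["gen_answer"])  # an unsure agent answers every step
        attempts = sum(n == "gen_answer" for s in tool_history for n in s)
        if attempts > max_answer_attempts:  # exclusive cap, as in env.py
            return step_number
    raise RuntimeError("never stopped")

print(steps_until_stopped(2))  # 3, comfortably within max_rollout_steps=4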