Optional AnswerSetting.max_answer_attempts to allow a new unsure branch #673

Merged · 3 commits · Nov 8, 2024
17 changes: 15 additions & 2 deletions paperqa/agents/env.py
@@ -157,10 +157,22 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
def export_frame(self) -> Frame:
return Frame(state=self.state, info={"query": self._query})

def _has_excess_answer_failures(self) -> bool:
if self._query.settings.answer.max_answer_attempts is None:
return False
return (
sum(
tn == GenerateAnswer.gen_answer.__name__
for s in self.state.tool_history
for tn in s
)
> self._query.settings.answer.max_answer_attempts
)

async def step(
self, action: ToolRequestMessage
) -> tuple[list[Message], float, bool, bool]:
self.state.session.add_tokens(action) # Add usage for action if present
self.state.record_action(action)

# If the action has empty tool_calls, the agent can later take that into account
msgs = cast(
@@ -175,7 +187,8 @@ async def step(
and msg.name == GenerateAnswer.gen_answer.__name__
and GenerateAnswer.did_not_fail_to_answer(msg.content)
for msg in msgs
),
)
or self._has_excess_answer_failures(),
False,
)

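A minimal sketch (not part of the diff) of how the new `_has_excess_answer_failures` check behaves: it counts `gen_answer` calls across the whole `tool_history` and trips only once the count exceeds the (exclusive) `max_answer_attempts` cap. The standalone function and sample histories below are hypothetical; only the counting logic mirrors the PR.

```python
# Hypothetical, self-contained illustration of the attempt-counting logic above.
GEN_ANSWER = "gen_answer"  # matches GenerateAnswer.gen_answer.__name__ in the diff


def has_excess_answer_failures(
    tool_history: list[list[str]], max_answer_attempts: int | None
) -> bool:
    if max_answer_attempts is None:
        return False  # no cap configured, so never end the episode on attempts
    attempts = sum(name == GEN_ANSWER for step in tool_history for name in step)
    return attempts > max_answer_attempts  # cap is exclusive


# With a cap of 2, two attempts are allowed and a third attempt exceeds the cap.
assert not has_excess_answer_failures([["gen_answer"], ["gen_answer"]], 2)
assert has_excess_answer_failures([["gen_answer"]] * 3, 2)
```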
36 changes: 21 additions & 15 deletions paperqa/agents/task.py
@@ -130,24 +130,30 @@ async def step(
messages, reward, done, truncated = await super().step(action)
if not done or not self._evaluation_from_answer:
return messages, reward, done, truncated
# Filter out non-answer messages (in case parallel tool calls)
answer_tool_messages = [
m
for m in messages
if isinstance(m, ToolResponseMessage)
and m.name == GenerateAnswer.gen_answer.__name__
]
if not answer_tool_messages: # No answer, so no positive reward
valid_answers, failed_answer_messages = [], []
for m in messages:
if (
not isinstance(m, ToolResponseMessage)
or m.name != GenerateAnswer.gen_answer.__name__
):
continue # Filter out non-answer messages (in case parallel tool calls)
if answer := GenerateAnswer.extract_answer_from_message(content=m.content):
valid_answers.append(answer)
else:
failed_answer_messages.append(m)
if not valid_answers: # No answer, so no positive reward
return messages, reward, done, truncated
if len(answer_tool_messages) != 1:
if len(valid_answers) != 1:
raise NotImplementedError(
f"Expected just one answer message, got {messages}."
f"Expected just one answer message, got more than one in {messages}."
)
answer = GenerateAnswer.extract_answer_from_message(
content=answer_tool_messages[0].content
)
if not answer:
return messages, reward, done, truncated
answer = valid_answers[0]
if failed_answer_messages:
logger.warning(
"More than one answer detected, discarding failed answer messages"
f" {failed_answer_messages}, continuing with answer {answer}."
)
# Okay, so we have one answer that was not a failed answer. Let's evaluate it
evaluation = await self._evaluation_from_answer(answer)
if evaluation_callback := self._evaluation_callback:
await evaluation_callback(evaluation)
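To make the partitioning in the loop above concrete, here is a small sketch with a stub standing in for `GenerateAnswer.extract_answer_from_message`; the stub's failure heuristic and the sample messages are invented for illustration only.

```python
# Hypothetical stub: real extraction lives in GenerateAnswer.extract_answer_from_message.
def extract_answer(content: str) -> str | None:
    return None if content.startswith("Failed") else content


messages = ["Failed to answer.", "ABC kinase is upregulated in DEF cells."]
valid_answers = [a for m in messages if (a := extract_answer(m))]
failed_answer_messages = [m for m in messages if extract_answer(m) is None]

assert valid_answers == ["ABC kinase is upregulated in DEF cells."]
assert len(failed_answer_messages) == 1  # failed attempts are logged, not evaluated
```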
13 changes: 13 additions & 0 deletions paperqa/agents/tools.py
@@ -7,6 +7,7 @@
import sys
from typing import ClassVar, cast

from aviary.core import ToolRequestMessage
from pydantic import BaseModel, ConfigDict, Field, computed_field

from paperqa.docs import Docs
@@ -36,6 +37,14 @@ class EnvironmentState(BaseModel):

docs: Docs
session: PQASession = Field(..., alias="answer")
tool_history: list[list[str]] = Field(
default_factory=list,
description=(
"History of tool names input to each Environment.step (regardless of being"
" a typo or not), where the outer list is steps, and the inner list matches"
" the order of tool calls at each step."
),
)

# SEE: https://regex101.com/r/RmuVdC/1
STATUS_SEARCH_REGEX_PATTERN: ClassVar[str] = (
@@ -65,6 +74,10 @@ def status(self) -> str:
cost=self.session.cost,
)

def record_action(self, action: ToolRequestMessage) -> None:
self.session.add_tokens(action)
self.tool_history.append([tc.function.name for tc in action.tool_calls])


class NamedTool(BaseModel):
"""Base class to make looking up tools easier."""
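A sketch of how `record_action` grows `tool_history` one inner list per `Environment.step`; the plain-list stand-in below simplifies the aviary `ToolRequestMessage`/`ToolCall` shapes, and the tool names are example data.

```python
# Simplified stand-in for EnvironmentState.record_action: one inner list per step,
# preserving the order of (possibly parallel) tool calls within that step.
tool_history: list[list[str]] = []


def record_action(tool_call_names: list[str]) -> None:
    tool_history.append(tool_call_names)


record_action(["paper_search"])
record_action(["gather_evidence", "gather_evidence"])  # parallel calls in one step
record_action(["gen_answer"])
assert tool_history == [
    ["paper_search"],
    ["gather_evidence", "gather_evidence"],
    ["gen_answer"],
]
```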
7 changes: 7 additions & 0 deletions paperqa/settings.py
@@ -83,6 +83,13 @@ class AnswerSettings(BaseModel):
answer_max_sources: int = Field(
default=5, description="Max number of sources to use for an answer"
)
max_answer_attempts: int | None = Field(
default=None,
description=(
"Optional (exclusive) max number (default is no max) of attempts to"
" generate an answer before declaring a failure."
),
)
answer_length: str = Field(
"about 200 words, but can be longer", description="Length of final answer"
)
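For usage, a short sketch assuming the top-level `Settings` model lives in `paperqa.settings` alongside `AnswerSettings` and exposes it as its `answer` field (as the test below also does): capping attempts at 2 means a third `gen_answer` call ends the episode.

```python
from paperqa.settings import Settings  # assumed import path for the top-level model

settings = Settings()
settings.answer.max_answer_attempts = 2  # default None keeps the old unlimited behavior
```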
5 changes: 4 additions & 1 deletion tests/test_task.py
@@ -159,18 +159,21 @@ async def test_evaluation(
base_query_request.settings.agent.tool_names = {
GenerateAnswer.gen_answer.__name__
}
base_query_request.settings.answer.max_answer_attempts = 2
base_query_request.settings.answer.get_evidence_if_no_contexts = False
dataset = LitQAv2TaskDataset(base_query=base_query_request)
dataset.data = dataset.data[:2] # Save the world: just use two questions
storage_callback = StoreTrajectoriesCallback()
evaluator = Evaluator(
config=EvaluatorConfig(batch_size=len(dataset), max_rollout_steps=2),
config=EvaluatorConfig(batch_size=len(dataset), max_rollout_steps=4),
agent=SimpleAgent(),
dataset=dataset,
callbacks=[storage_callback],
)
await evaluator.evaluate()
for traj in storage_callback.eval_trajectories:
assert not traj.failed
assert traj.done
[PR author comment] This specific assertion will not pass before this PR without this new setting

for step in traj.steps:
assert all(
tc.function.name == GenerateAnswer.gen_answer.__name__
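A back-of-the-envelope check of the test's new numbers (a sketch of the intent, not library code): with `max_answer_attempts = 2`, the environment flags `done` once the cumulative `gen_answer` count exceeds 2, so `max_rollout_steps = 4` leaves room for the trajectory to finish even if every attempt fails, which is what makes `assert traj.done` hold.

```python
max_answer_attempts = 2
max_rollout_steps = 4

# An agent restricted to gen_answer, calling it once per step, needs this many
# steps before the (exclusive) cap is exceeded and the episode is marked done.
steps_to_trip_cap = max_answer_attempts + 1
assert steps_to_trip_cap <= max_rollout_steps
```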