Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fake agent allowing timeouts or exceptions, #672

Merged
merged 4 commits into from
Nov 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 108 additions & 103 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,29 @@ async def run_agent(
return AnswerResponse(session=session, status=agent_status)


async def _run_with_timeout_failure(
    rollout: Callable[[], Awaitable[AgentStatus]],
    query: QueryRequest,
    env: PaperQAEnvironment,
) -> tuple[PQASession, AgentStatus]:
    """Run a rollout under the agent timeout, converting failures into a status.

    Args:
        rollout: Zero-argument coroutine function that performs the trajectory
            and returns the resulting agent status.
        query: Request whose ``settings.agent.timeout`` bounds the rollout and
            whose ``query`` text is used for the fallback answer.
        env: Environment providing ``tools`` (searched for the generate-answer
            tool) and ``state`` (passed to that tool and holding the session).

    Returns:
        A ``(session, status)`` pair, where the session is ``env.state.session``
        and the status is:
        - whatever ``rollout`` returned, if it finished in time,
        - ``AgentStatus.TRUNCATED`` if a ``TimeoutError`` was raised and the
          fallback answer generation succeeded,
        - ``AgentStatus.FAIL`` if the rollout, or the fallback answer
          generation after a timeout, raised any other exception.
    """
    try:
        async with asyncio.timeout(query.settings.agent.timeout):
            status = await rollout()
    except TimeoutError:
        logger.warning(
            f"Agent timeout after {query.settings.agent.timeout}-sec, just answering."
        )
        status = AgentStatus.TRUNCATED
        # The fallback runs inside an exception handler, so the sibling
        # `except Exception` below does NOT protect it. Guard it explicitly so
        # a failure here is reported as FAIL rather than escaping to the caller.
        try:
            # A bare next() raises StopIteration if the tool is absent; catching
            # it here keeps it from leaking out of the coroutine (where it
            # would otherwise be turned into a RuntimeError).
            generate_answer_tool = next(
                tool
                for tool in env.tools
                if tool.info.name == GenerateAnswer.TOOL_FN_NAME
            )
            await generate_answer_tool._tool_fn(
                question=query.query, state=env.state
            )
        except Exception:
            logger.exception("Failed to generate an answer after agent timeout.")
            status = AgentStatus.FAIL
    except Exception:
        logger.exception("Trajectory failed.")
        status = AgentStatus.FAIL
    return env.state.session, status


async def run_fake_agent(
query: QueryRequest,
docs: Docs,
Expand Down Expand Up @@ -184,14 +207,17 @@ async def step(tool: Tool, **call_kwargs) -> None:
if on_env_step_callback:
await on_env_step_callback(obs, reward, done, truncated)

# Seed docs with a few keyword searches
for search in await litellm_get_search_query(
question, llm=query.settings.get_llm(), count=3
):
await step(search_tool, query=search, min_year=None, max_year=None)
await step(gather_evidence_tool, question=question)
await step(generate_answer_tool, question=question)
return env.state.session, AgentStatus.SUCCESS
async def rollout() -> AgentStatus:
# Seed docs with a few keyword searches
for search in await litellm_get_search_query(
question, llm=query.settings.get_llm(), count=3
):
await step(search_tool, query=search, min_year=None, max_year=None)
await step(gather_evidence_tool, question=question)
await step(generate_answer_tool, question=question)
return AgentStatus.SUCCESS

return await _run_with_timeout_failure(rollout, query, env)


async def run_aviary_agent(
Expand All @@ -209,76 +235,64 @@ async def run_aviary_agent(
**env_kwargs,
) -> tuple[PQASession, AgentStatus]:
env = env_class(query, docs, **env_kwargs)
done = False

try:
async with asyncio.timeout(query.settings.agent.timeout):
obs, tools = await env.reset()
if on_env_reset_callback:
await on_env_reset_callback(env.state)

agent_state = ToolSelectorLedger(
messages=(
[
Message(
role="system",
content=query.settings.agent.agent_system_prompt,
)
]
if query.settings.agent.agent_system_prompt
else []
),
tools=tools,
)

timestep, max_timesteps = 0, query.settings.agent.max_timesteps
while not done:
if max_timesteps is not None and timestep >= max_timesteps:
logger.warning(
f"Agent didn't finish within {max_timesteps} timesteps, just"
" answering."
)
generate_answer_tool = next(
filter(
lambda x: x.info.name == GenerateAnswer.TOOL_FN_NAME,
env.tools,
)
async def rollout() -> AgentStatus:
obs, tools = await env.reset()
if on_env_reset_callback:
await on_env_reset_callback(env.state)

agent_state = ToolSelectorLedger(
messages=(
[
Message(
role="system",
content=query.settings.agent.agent_system_prompt,
)
await generate_answer_tool._tool_fn(
question=query.query, state=env.state
)
return env.state.session, AgentStatus.TRUNCATED
agent_state.messages += obs
for attempt in Retrying(
stop=stop_after_attempt(5),
retry=retry_if_exception_type(MalformedMessageError),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
):
with attempt: # Retrying if ToolSelector fails to select a tool
action = await agent(agent_state.messages, tools)
agent_state.messages = [*agent_state.messages, action]
if on_agent_action_callback:
await on_agent_action_callback(action, agent_state)

obs, reward, done, truncated = await env.step(action)
if on_env_step_callback:
await on_env_step_callback(obs, reward, done, truncated)
timestep += 1
status = AgentStatus.SUCCESS
except TimeoutError:
logger.warning(
f"Agent timeout after {query.settings.agent.timeout}-sec, just answering."
]
if query.settings.agent.agent_system_prompt
else []
),
tools=tools,
)
status = AgentStatus.TRUNCATED
generate_answer_tool = next(
filter(lambda x: x.info.name == GenerateAnswer.TOOL_FN_NAME, env.tools)
)
await generate_answer_tool._tool_fn(question=query.query, state=env.state)
except Exception:
logger.exception(f"Agent {agent} failed.")
status = AgentStatus.FAIL
return env.state.session, status

timestep, max_timesteps = 0, query.settings.agent.max_timesteps
done = False
while not done:
if max_timesteps is not None and timestep >= max_timesteps:
logger.warning(
f"Agent didn't finish within {max_timesteps} timesteps, just"
" answering."
)
generate_answer_tool = next(
filter(
lambda x: x.info.name == GenerateAnswer.TOOL_FN_NAME,
env.tools,
)
)
await generate_answer_tool._tool_fn(
question=query.query, state=env.state
)
return AgentStatus.TRUNCATED
agent_state.messages += obs
for attempt in Retrying(
stop=stop_after_attempt(5),
retry=retry_if_exception_type(MalformedMessageError),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
):
with attempt: # Retrying if ToolSelector fails to select a tool
action = await agent(agent_state.messages, tools)
agent_state.messages = [*agent_state.messages, action]
if on_agent_action_callback:
await on_agent_action_callback(action, agent_state)

obs, reward, done, truncated = await env.step(action)
if on_env_step_callback:
await on_env_step_callback(obs, reward, done, truncated)
timestep += 1
return AgentStatus.SUCCESS

return await _run_with_timeout_failure(rollout, query, env)


class LDPRolloutCallback(Callback):
Expand Down Expand Up @@ -328,36 +342,27 @@ async def run_ldp_agent(
# NOTE: don't worry about ldp import checks, because we know Settings.make_ldp_agent
# has already taken place, which checks that ldp is installed

try:
async with asyncio.timeout(query.settings.agent.timeout):
rollout_manager = RolloutManager(
agent,
callbacks=[
ldp_callback_type(
env,
on_env_reset_callback,
on_agent_action_callback,
on_env_step_callback,
)
],
)
await rollout_manager.sample_trajectories(
environments=[env], max_steps=query.settings.agent.max_timesteps
)
status = AgentStatus.SUCCESS
except TimeoutError:
logger.warning(
f"Agent timeout after {query.settings.agent.timeout}-sec, just answering."
async def rollout() -> AgentStatus:
rollout_manager = RolloutManager(
agent,
callbacks=[
ldp_callback_type(
env,
on_env_reset_callback,
on_agent_action_callback,
on_env_step_callback,
)
],
)
status = AgentStatus.TRUNCATED
generate_answer_tool = next(
filter(lambda x: x.info.name == GenerateAnswer.TOOL_FN_NAME, env.tools)
trajs = await rollout_manager.sample_trajectories(
environments=[env], max_steps=query.settings.agent.max_timesteps
)
await generate_answer_tool._tool_fn(question=query.query, state=env.state)
except Exception:
logger.exception(f"Agent {agent} failed.")
status = AgentStatus.FAIL
return env.state.session, status
traj = trajs[0]
if traj.steps[-1].truncated:
return AgentStatus.TRUNCATED
return AgentStatus.SUCCESS

return await _run_with_timeout_failure(rollout, query, env)


async def index_search(
Expand Down