Skip to content

Commit

Permalink
Fix ADaPT logic and extend action executor interface with reset method
Browse files Browse the repository at this point in the history
  • Loading branch information
saridormi committed Mar 25, 2024
1 parent 9c025a0 commit 7a0b1c8
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 30 deletions.
12 changes: 8 additions & 4 deletions planning_library/action_executors/base_action_executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, overload, Sequence
from typing import List, overload, Sequence, Optional

from langchain_core.agents import AgentAction, AgentStep
from langchain_core.tools import BaseTool
Expand All @@ -10,6 +10,11 @@ class BaseActionExecutor(ABC):
@abstractmethod
def tools(self) -> Sequence[BaseTool]:
    """The tools available to this executor; each implementation defines where they come from."""
    ...

@abstractmethod
def reset(self, actions: Optional[List[AgentAction]] = None, **kwargs) -> None:
    """Resets the current state. If actions are passed, will also execute them.

    Args:
        actions: Optional list of actions to (re)execute against the freshly
            reset state. How they are executed is implementation-specific.
        **kwargs: Implementation-specific reset options.
    """
    ...

@overload
def execute(
self,
Expand All @@ -33,8 +38,7 @@ def execute(
"""Performs actions.
Args:
actions: Currently proposed actions. Can be: multi-action, single action, finishing.
run_manager: Callback for the current run.
actions: Currently proposed actions. Can be: multi-action, single action.
Returns:
* List[AgentStep] - for multi-action thoughts (List[AgentAction])
Expand Down Expand Up @@ -65,7 +69,7 @@ async def aexecute(
"""Performs actions asynchronously.
Args:
actions: Currently proposed actions. Can be: multi-action, single action, finishing.
actions: Currently proposed actions. Can be: multi-action, single action.
Returns:
* List[AgentStep] - for multi-action thoughts (List[AgentAction])
Expand Down
9 changes: 8 additions & 1 deletion planning_library/action_executors/default_action_executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, overload, Sequence
from typing import List, overload, Sequence, Optional

from langchain_core.agents import AgentAction, AgentStep
from langchain_core.tools import BaseTool
Expand All @@ -10,6 +10,13 @@ class DefaultActionExecutor(BaseActionExecutor):
def __init__(self, tools: Sequence[BaseTool]):
self._tool_executor = ToolExecutor(tools)

def reset(self, actions: Optional[List[AgentAction]] = None, **kwargs) -> None:
    """Reset the executor's state, optionally replaying the given actions.

    ``DefaultActionExecutor`` keeps no internal state, so resetting is a
    no-op: any ``actions`` and keyword options are accepted and ignored.
    """
    pass

@property
def tools(self) -> Sequence[BaseTool]:
    """The tools registered with the underlying ToolExecutor."""
    return self._tool_executor.tools
Expand Down
11 changes: 10 additions & 1 deletion planning_library/action_executors/gymnasium_action_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ def __init__(
def tools(self) -> Sequence[BaseTool]:
return self._env.get_wrapper_attr("tools")

def reset(self, actions: Optional[List[AgentAction]] = None, **kwargs) -> None:
    """Reset the underlying environment using this executor's fixed seed.

    All keyword arguments are forwarded to the environment as reset
    options. When ``actions`` is non-empty, it is additionally forwarded
    under the ``"actions"`` option key so the environment can replay it.
    """
    reset_options = {**kwargs}
    if actions:
        reset_options["actions"] = actions

    self._env.reset(seed=self._seed, options=reset_options)

@overload
def execute(
self,
Expand All @@ -47,7 +56,7 @@ def execute(
**reset_kwargs,
) -> List[AgentStep] | AgentStep:
if reset_env_before_action:
self._env.reset(seed=self._seed, options=reset_kwargs)
self.reset(**reset_kwargs)

if isinstance(actions, AgentAction):
observation, reward, terminated, truncated, info = self._env.step(actions)
Expand Down
79 changes: 55 additions & 24 deletions planning_library/strategies/adapt/adapt_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ def create(
def _adapt_step(
self,
current_task: ADaPTTask,
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[bool, AgentFinish, List[Tuple[AgentAction, str]]]:
"""Performs an iteration of ADaPT strategy.
Args:
current_task: The input for the current step.
current_task: The input for the current iteration. It can either be the original input or a subtask of a plan generated on a previous step.
intermediate_steps: A list of actions taken before the current iteration.
run_manager: Callback for the current run.
"""
# 1: if we're too deep in task decomposition, finish early
Expand All @@ -60,11 +62,11 @@ def _adapt_step(
AgentFinish(
return_values={}, log="Maximum decomposition depth reached."
),
[],
intermediate_steps,
)

# 2: run task through executor
is_completed, agent_outcome, intermediate_steps = self.executor.execute(
is_completed, cur_agent_outcome, cur_intermediate_steps = self.executor.execute(
inputs=current_task["inputs"],
run_manager=run_manager.get_child(
tag=f"executor:depth_{current_task['depth']}"
Expand All @@ -73,35 +75,46 @@ def _adapt_step(
else None,
)

# if executor estimated successful completion of a task, wrap up
# 3.1: if executor estimated successful completion of a task, wrap up
if is_completed:
return True, agent_outcome, intermediate_steps
intermediate_steps.extend(cur_intermediate_steps)
return True, cur_agent_outcome, intermediate_steps
else:
# otherwise, call planner to further decompose a current task
# 3.2: otherwise:
# clean up the environment
self.action_executor.reset(actions=[step[0] for step in intermediate_steps])

# call a planner to further decompose a current task
plan = self.planner.plan(
inputs=current_task["inputs"],
current_depth=current_task["depth"],
agent_outcome=agent_outcome,
intermediate_steps=intermediate_steps,
agent_outcome=cur_agent_outcome,
intermediate_steps=cur_intermediate_steps,
run_manager=run_manager.get_child(
tag=f"executor:depth_{current_task['depth']}"
)
if run_manager
else None,
)
# when AND logic is given, execute tasks sequentially
if plan["logic"] == "and":
intermediate_steps = []
for task in plan["subtasks"]:
cur_is_completed, cur_agent_outcome, cur_intermediate_steps = (
self._adapt_step(current_task=task, run_manager=run_manager)
self._adapt_step(
current_task=task,
run_manager=run_manager,
intermediate_steps=intermediate_steps,
)
)

if not cur_is_completed:
agent_outcome = AgentFinish(
return_values=cur_agent_outcome.return_values,
log=f"Couldn't solve the task. Last log: {cur_agent_outcome.log}",
)
intermediate_steps.extend(cur_intermediate_steps)
return False, agent_outcome, intermediate_steps
else:
intermediate_steps.extend(cur_intermediate_steps)

agent_outcome = AgentFinish(
return_values={}, log="Task solved successfully!"
Expand All @@ -118,19 +131,23 @@ def _run_strategy(
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Iterator[Tuple[AgentFinish, List[Tuple[AgentAction, str]]]]:
_, agent_outcome, intermediate_steps = self._adapt_step(
current_task={"inputs": inputs, "depth": 0}, run_manager=run_manager
current_task={"inputs": inputs, "depth": 0},
run_manager=run_manager,
intermediate_steps=[],
)
yield agent_outcome, intermediate_steps

async def _adapt_astep(
self,
current_task: ADaPTTask,
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Tuple[bool, AgentFinish, List[Tuple[AgentAction, str]]]:
"""Performs an iteration of ADaPT strategy asynchronously.
Args:
current_task: The input on the current step.
current_task: The input for the current iteration. It can either be the original input or a subtask of a plan generated on a previous step.
intermediate_steps: A list of actions taken before the current iteration.
run_manager: Callback for the current run.
"""
# 1: if we're too deep in task decomposition, finish early
Expand All @@ -140,11 +157,15 @@ async def _adapt_astep(
AgentFinish(
return_values={}, log="Maximum decomposition depth reached."
),
[],
intermediate_steps,
)

# 2: run task through executor
is_completed, agent_outcome, intermediate_steps = await self.executor.aexecute(
(
is_completed,
cur_agent_outcome,
cur_intermediate_steps,
) = await self.executor.aexecute(
inputs=current_task["inputs"],
run_manager=run_manager.get_child(
tag=f"executor:depth_{current_task['depth']}"
Expand All @@ -153,39 +174,47 @@ async def _adapt_astep(
else None,
)

# if executor estimated successful completion of a task, wrap up
# 3.1: if executor estimated successful completion of a task, wrap up
if is_completed:
return True, agent_outcome, intermediate_steps
intermediate_steps.extend(cur_intermediate_steps)
return True, cur_agent_outcome, intermediate_steps
else:
# otherwise, call planner to further decompose a current task
# 3.2: otherwise:
# clean up the environment
self.action_executor.reset(actions=[step[0] for step in intermediate_steps])

plan = await self.planner.aplan(
inputs=current_task["inputs"],
current_depth=current_task["depth"],
agent_outcome=agent_outcome,
intermediate_steps=intermediate_steps,
agent_outcome=cur_agent_outcome,
intermediate_steps=cur_intermediate_steps,
run_manager=run_manager.get_child(
tag=f"executor:depth_{current_task['depth']}"
)
if run_manager
else None,
)
# when AND logic is given, execute tasks sequentially
if plan["logic"] == "and":
intermediate_steps = []
for task in plan["subtasks"]:
(
cur_is_completed,
cur_agent_outcome,
cur_intermediate_steps,
) = await self._adapt_astep(
current_task=task, run_manager=run_manager
current_task=task,
run_manager=run_manager,
intermediate_steps=intermediate_steps,
)

if not cur_is_completed:
agent_outcome = AgentFinish(
return_values=cur_agent_outcome.return_values,
log=f"Couldn't solve the task. Last log: {cur_agent_outcome.log}",
)
intermediate_steps.extend(cur_intermediate_steps)
return False, agent_outcome, intermediate_steps
else:
intermediate_steps.extend(cur_intermediate_steps)

agent_outcome = AgentFinish(
return_values={}, log="Task solved successfully!"
Expand All @@ -202,6 +231,8 @@ async def _arun_strategy(
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> AsyncIterator[Tuple[AgentFinish, List[Tuple[AgentAction, str]]]]:
_, agent_outcome, intermediate_steps = await self._adapt_astep(
current_task={"inputs": inputs, "depth": 0}, run_manager=run_manager
current_task={"inputs": inputs, "depth": 0},
run_manager=run_manager,
intermediate_steps=[],
)
yield agent_outcome, intermediate_steps

0 comments on commit 7a0b1c8

Please sign in to comment.