-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement a simple strategy and an initial version of ADaPT strategy
- Loading branch information
Showing
10 changed files
with
479 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .adapt_strategy import ADaPTStrategy | ||
|
||
__all__ = ["ADaPTStrategy"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
from typing import Dict, Iterator, List, Optional, Tuple, Union, Sequence, AsyncIterator | ||
|
||
from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent | ||
from langchain_core.agents import AgentAction, AgentFinish | ||
from langchain_core.callbacks import ( | ||
AsyncCallbackManagerForChainRun, | ||
CallbackManagerForChainRun, | ||
) | ||
from langchain_core.runnables import Runnable | ||
from langchain_core.tools import BaseTool | ||
|
||
from ...action_executors import BaseActionExecutor | ||
from ..base_strategy import BaseCustomStrategy | ||
from .components import BaseADaPTExecutor, BaseADaPTPlanner | ||
from planning_library.strategies.adapt.utils import ADaPTTask | ||
|
||
|
||
class ADaPTStrategy(BaseCustomStrategy): | ||
"""ADaPT strategy. | ||
Based on "ADaPT: As-Needed Decomposition and Planning with Language Models" by Prasad et al. | ||
""" | ||
|
||
executor: BaseADaPTExecutor | ||
planner: BaseADaPTPlanner | ||
max_depth: int | ||
|
||
@staticmethod | ||
def create( | ||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent], | ||
tools: Sequence[BaseTool], | ||
action_executor: Optional[BaseActionExecutor] = None, | ||
planner_runnable: Optional[Runnable] = None, | ||
max_depth: int = 20, | ||
**kwargs, | ||
) -> "ADaPTStrategy": | ||
"""Creates an instance of ADaPT strategy. | ||
Args: | ||
agent: The agent to run for proposing thoughts at each DFS step. | ||
tools: The valid tools the agent can call. | ||
""" | ||
raise NotImplementedError() | ||
|
||
def _adapt_step( | ||
self, | ||
current_task: ADaPTTask, | ||
run_manager: Optional[CallbackManagerForChainRun] = None, | ||
) -> Tuple[bool, AgentFinish, List[Tuple[AgentAction, str]]]: | ||
"""Performs an iteration of ADaPT strategy. | ||
Args: | ||
current_task: The input for the current step. | ||
run_manager: Callback for the current run. | ||
""" | ||
# 1: if we're too deep in task decomposition, finish early | ||
if current_task["depth"] > self.max_depth: | ||
return ( | ||
False, | ||
AgentFinish( | ||
return_values={}, log="Maximum decomposition depth reached." | ||
), | ||
[], | ||
) | ||
|
||
# 2: run task through executor | ||
is_completed, agent_outcome, intermediate_steps = self.executor.execute( | ||
inputs=current_task["inputs"], | ||
run_manager=run_manager.get_child( | ||
tag=f"executor:depth_{current_task['depth']}" | ||
) | ||
if run_manager | ||
else None, | ||
) | ||
|
||
# if executor estimated successful completion of a task, wrap up | ||
if is_completed: | ||
return True, agent_outcome, intermediate_steps | ||
else: | ||
# otherwise, call planner to further decompose a current task | ||
plan = self.planner.plan( | ||
inputs=current_task["inputs"], | ||
current_depth=current_task["depth"], | ||
agent_outcome=agent_outcome, | ||
intermediate_steps=intermediate_steps, | ||
run_manager=run_manager.get_child( | ||
tag=f"executor:depth_{current_task['depth']}" | ||
) | ||
if run_manager | ||
else None, | ||
) | ||
if plan["logic"] == "and": | ||
intermediate_steps = [] | ||
for task in plan["subtasks"]: | ||
cur_is_completed, cur_agent_outcome, cur_intermediate_steps = ( | ||
self._adapt_step(current_task=task, run_manager=run_manager) | ||
) | ||
if not cur_is_completed: | ||
agent_outcome = AgentFinish( | ||
return_values=cur_agent_outcome.return_values, | ||
log=f"Couldn't solve the task. Last log: {cur_agent_outcome.log}", | ||
) | ||
intermediate_steps.extend(cur_intermediate_steps) | ||
return False, agent_outcome, intermediate_steps | ||
|
||
agent_outcome = AgentFinish( | ||
return_values={}, log="Task solved successfully!" | ||
) | ||
return True, agent_outcome, intermediate_steps | ||
|
||
raise NotImplementedError("Currently, only `and` logic is supported.") | ||
|
||
def _run_strategy( | ||
self, | ||
inputs: Dict[str, str], | ||
name_to_tool_map: Dict[str, BaseTool], | ||
color_mapping: Dict[str, str], | ||
run_manager: Optional[CallbackManagerForChainRun] = None, | ||
) -> Iterator[Tuple[AgentFinish, List[Tuple[AgentAction, str]]]]: | ||
_, agent_outcome, intermediate_steps = self._adapt_step( | ||
current_task={"inputs": inputs, "depth": 0}, run_manager=run_manager | ||
) | ||
yield agent_outcome, intermediate_steps | ||
|
||
async def _adapt_astep( | ||
self, | ||
current_task: ADaPTTask, | ||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, | ||
) -> Tuple[bool, AgentFinish, List[Tuple[AgentAction, str]]]: | ||
"""Performs an iteration of ADaPT strategy asynchronously. | ||
Args: | ||
current_task: The input on the current step. | ||
run_manager: Callback for the current run. | ||
""" | ||
# 1: if we're too deep in task decomposition, finish early | ||
if current_task["depth"] > self.max_depth: | ||
return ( | ||
False, | ||
AgentFinish( | ||
return_values={}, log="Maximum decomposition depth reached." | ||
), | ||
[], | ||
) | ||
|
||
# 2: run task through executor | ||
is_completed, agent_outcome, intermediate_steps = await self.executor.aexecute( | ||
inputs=current_task["inputs"], | ||
run_manager=run_manager.get_child( | ||
tag=f"executor:depth_{current_task['depth']}" | ||
) | ||
if run_manager | ||
else None, | ||
) | ||
|
||
# if executor estimated successful completion of a task, wrap up | ||
if is_completed: | ||
return True, agent_outcome, intermediate_steps | ||
else: | ||
# otherwise, call planner to further decompose a current task | ||
plan = await self.planner.aplan( | ||
inputs=current_task["inputs"], | ||
current_depth=current_task["depth"], | ||
agent_outcome=agent_outcome, | ||
intermediate_steps=intermediate_steps, | ||
run_manager=run_manager.get_child( | ||
tag=f"executor:depth_{current_task['depth']}" | ||
) | ||
if run_manager | ||
else None, | ||
) | ||
if plan["logic"] == "and": | ||
intermediate_steps = [] | ||
for task in plan["subtasks"]: | ||
( | ||
cur_is_completed, | ||
cur_agent_outcome, | ||
cur_intermediate_steps, | ||
) = await self._adapt_astep( | ||
current_task=task, run_manager=run_manager | ||
) | ||
if not cur_is_completed: | ||
agent_outcome = AgentFinish( | ||
return_values=cur_agent_outcome.return_values, | ||
log=f"Couldn't solve the task. Last log: {cur_agent_outcome.log}", | ||
) | ||
intermediate_steps.extend(cur_intermediate_steps) | ||
return False, agent_outcome, intermediate_steps | ||
|
||
agent_outcome = AgentFinish( | ||
return_values={}, log="Task solved successfully!" | ||
) | ||
return True, agent_outcome, intermediate_steps | ||
|
||
raise NotImplementedError("Currently, only `and` logic is supported.") | ||
|
||
async def _arun_strategy( | ||
self, | ||
inputs: Dict[str, str], | ||
name_to_tool_map: Dict[str, BaseTool], | ||
color_mapping: Dict[str, str], | ||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, | ||
) -> AsyncIterator[Tuple[AgentFinish, List[Tuple[AgentAction, str]]]]: | ||
_, agent_outcome, intermediate_steps = await self._adapt_astep( | ||
current_task={"inputs": inputs, "depth": 0}, run_manager=run_manager | ||
) | ||
yield agent_outcome, intermediate_steps |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .executors import BaseADaPTExecutor | ||
from .planners import BaseADaPTPlanner | ||
|
||
__all__ = ["BaseADaPTExecutor", "BaseADaPTPlanner"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from abc import ABC, abstractmethod | ||
from langchain_core.callbacks import ( | ||
CallbackManager, | ||
AsyncCallbackManager, | ||
) | ||
from typing import Optional, Tuple, List, Dict, Any | ||
from langchain_core.agents import AgentAction, AgentFinish | ||
from planning_library.strategies import BaseCustomStrategy | ||
|
||
|
||
class BaseADaPTExecutor(ABC): | ||
@abstractmethod | ||
def execute( | ||
self, | ||
inputs: Dict[str, Any], | ||
run_manager: Optional[CallbackManager] = None, | ||
) -> Tuple[bool, AgentFinish, List[Tuple[AgentAction, str]]]: ... | ||
|
||
@abstractmethod | ||
async def aexecute( | ||
self, | ||
inputs: Dict[str, Any], | ||
run_manager: Optional[AsyncCallbackManager] = None, | ||
) -> Tuple[bool, AgentFinish, List[Tuple[AgentAction, str]]]: ... | ||
|
||
|
||
class StrategyADaPTExecutor(BaseADaPTExecutor): | ||
def __init__(self, strategy: BaseCustomStrategy): | ||
self._executor = strategy | ||
|
||
def _is_completed(self, outcome: AgentFinish) -> bool: | ||
return "task completed" in outcome.log.lower() | ||
|
||
def execute( | ||
self, | ||
inputs: Dict[str, Any], | ||
run_manager: Optional[CallbackManager] = None, | ||
) -> Tuple[bool, AgentFinish, List[Tuple[AgentAction, str]]]: | ||
outputs = self._executor.invoke(**inputs)["outputs"] | ||
intermediate_steps = outputs.get("intermediate_steps", []) | ||
finish_log = outputs.get("finish_log", "") | ||
del outputs["intermediate_steps"] | ||
del outputs["finish_log"] | ||
outcome = AgentFinish(return_values=outputs, log=finish_log) | ||
is_completed = self._is_completed(outcome) | ||
return is_completed, outcome, intermediate_steps | ||
|
||
async def aexecute( | ||
self, | ||
inputs: Dict[str, Any], | ||
run_manager: Optional[AsyncCallbackManager] = None, | ||
) -> Tuple[bool, AgentFinish, List[Tuple[AgentAction, str]]]: | ||
outputs = await self._executor.ainvoke(**inputs) | ||
intermediate_steps = outputs.get("intermediate_steps", []) | ||
finish_log = outputs.get("finish_log", "") | ||
del outputs["intermediate_steps"] | ||
del outputs["finish_log"] | ||
outcome = AgentFinish(return_values=outputs, log=finish_log) | ||
is_completed = self._is_completed(outcome) | ||
return is_completed, outcome, intermediate_steps |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from abc import ABC, abstractmethod | ||
from langchain_core.callbacks import ( | ||
CallbackManager, | ||
AsyncCallbackManager, | ||
) | ||
from typing import Optional, Tuple, List, Dict, Any | ||
from langchain_core.agents import AgentAction, AgentFinish | ||
from planning_library.strategies.adapt.utils import ADaPTPlan | ||
from langchain_core.runnables import Runnable | ||
|
||
|
||
class BaseADaPTPlanner(ABC): | ||
@abstractmethod | ||
def plan( | ||
self, | ||
inputs: Dict[str, Any], | ||
current_depth: int, | ||
agent_outcome: AgentFinish, | ||
intermediate_steps: List[Tuple[AgentAction, str]], | ||
run_manager: Optional[CallbackManager] = None, | ||
) -> ADaPTPlan: ... | ||
|
||
@abstractmethod | ||
async def aplan( | ||
self, | ||
inputs: Dict[str, Any], | ||
current_depth: int, | ||
agent_outcome: AgentFinish, | ||
intermediate_steps: List[Tuple[AgentAction, str]], | ||
run_manager: Optional[AsyncCallbackManager] = None, | ||
) -> ADaPTPlan: ... | ||
|
||
|
||
class RunnableADaPTPlanner(BaseADaPTPlanner): | ||
def __init__(self, runnable: Runnable[Dict[str, Any], ADaPTPlan]): | ||
self.runnable = runnable | ||
|
||
def plan( | ||
self, | ||
inputs: Dict[str, Any], | ||
current_depth: int, | ||
agent_outcome: AgentFinish, | ||
intermediate_steps: List[Tuple[AgentAction, str]], | ||
run_manager: Optional[CallbackManager] = None, | ||
) -> ADaPTPlan: | ||
return self.runnable.invoke( | ||
{ | ||
**inputs, | ||
"current_depth": current_depth, | ||
"agent_outcome": agent_outcome, | ||
"intermediate_steps": intermediate_steps, | ||
}, | ||
{"callbacks": run_manager} if run_manager else {}, | ||
) | ||
|
||
async def aplan( | ||
self, | ||
inputs: Dict[str, Any], | ||
current_depth: int, | ||
agent_outcome: AgentFinish, | ||
intermediate_steps: List[Tuple[AgentAction, str]], | ||
run_manager: Optional[AsyncCallbackManager] = None, | ||
) -> ADaPTPlan: | ||
return await self.runnable.ainvoke( | ||
{ | ||
**inputs, | ||
"current_depth": current_depth, | ||
"agent_outcome": agent_outcome, | ||
"intermediate_steps": intermediate_steps, | ||
}, | ||
{"callbacks": run_manager} if run_manager else {}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .typing_utils import ADaPTTask, ADaPTPlan | ||
|
||
__all__ = ["ADaPTTask", "ADaPTPlan"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from typing_extensions import TypedDict | ||
from typing import Dict, Any, List, Literal | ||
|
||
|
||
class ADaPTTask(TypedDict): | ||
inputs: Dict[str, Any] | ||
depth: int | ||
|
||
|
||
class ADaPTPlan(TypedDict): | ||
subtasks: List[ADaPTTask] | ||
logic: Literal["and", "or"] |
Oops, something went wrong.