From e798a65c8abef3eed682a9a37e31ea6d3dcb3d70 Mon Sep 17 00:00:00 2001
From: Alexandra Eliseeva
Date: Mon, 18 Mar 2024 21:54:35 +0100
Subject: [PATCH] Polish Reflexion & tool execution

Switch from custom tool-calling methods to ToolExecutor from LangGraph
(because I didn't know how to configure callbacks in the custom methods so
that tool calls render nicely in LangSmith).

Use BaseChatMessageHistory for self-reflections instead of a simple
in-memory option.
---
 .github/workflows/workflow.yaml | 4 +
 .../frozen_lake/common/environment.py | 40 ++----
 environments/frozen_lake/common/tools.py | 61 ++++++++-
 .../frozen_lake/reflexion/__init__.py | 0
 .../action_executors/base_action_executor.py | 51 ++------
 .../default_action_executor.py | 116 +++++-------------
 .../gymnasium_action_executor.py | 58 ++-------
 planning_library/strategies/base_strategy.py | 9 +-
 .../strategies/reflexion/components/actors.py | 9 +-
 .../reflexion/components/self_reflections.py | 13 +-
 .../strategies/reflexion/reflexion_graph.py | 35 +++---
 .../reflexion/reflexion_strategy.py | 8 +-
 .../strategies/tot_dfs/tot_strategy.py | 9 +-
 .../strategies/tot_dfs/utils/tot_node.py | 2 +-
 14 files changed, 164 insertions(+), 251 deletions(-)
 create mode 100644 environments/frozen_lake/reflexion/__init__.py

diff --git a/.github/workflows/workflow.yaml b/.github/workflows/workflow.yaml
index 2aa5f3e..38bb16b 100644
--- a/.github/workflows/workflow.yaml
+++ b/.github/workflows/workflow.yaml
@@ -25,6 +25,10 @@ jobs:
         run: |
           poetry run ruff check

+      - name: Check formatting with ruff
+        run: |
+          poetry run ruff format --check
+
       - name: Check types with mypy
         run: |
           poetry run mypy .
diff --git a/environments/frozen_lake/common/environment.py b/environments/frozen_lake/common/environment.py
index 0b35220..df7aac5 100644
--- a/environments/frozen_lake/common/environment.py
+++ b/environments/frozen_lake/common/environment.py
@@ -1,42 +1,28 @@
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Tuple, Sequence

 import gymnasium as gym
 from gymnasium.core import ObsType, SupportsFloat
 from gymnasium.envs.toy_text.frozen_lake import FrozenLakeEnv
 from langchain_core.agents import AgentAction
-from langchain_core.callbacks import CallbackManagerForChainRun
+from langchain_core.tools import BaseTool

-from .tools import (
-    MoveTool,
-)
-from planning_library.utils import get_tools_maps, perform_agent_action
+from .tools import MoveTool
+from planning_library.action_executors import DefaultActionExecutor


 class FrozenLakeEnvWrapper(gym.Wrapper):
     def __init__(self, env: FrozenLakeEnv):
         super().__init__(env)
-        self.name_to_tool_map, self.color_mapping = get_tools_maps(
-            [
-                MoveTool(env=self.env.unwrapped),  # type: ignore[call-arg]
-            ]
-        )
-        # CheckPositionTool(env=self.env.unwrapped),])
-        # CheckMapTool(env=self.env.unwrapped)])
+        self._action_executor = DefaultActionExecutor(tools=[MoveTool(env=self)])  # type: ignore[call-arg]
+
+    @property
+    def tools(self) -> Sequence[BaseTool]:
+        return self._action_executor.tools

     def step(
-        self,
-        cur_input: Tuple[
-            AgentAction, Optional[CallbackManagerForChainRun], Dict[str, Any]
-        ],
+        self, action: AgentAction
     ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
-        action, run_manager, tool_run_logging_kwargs = cur_input
-        result = perform_agent_action(
-            agent_action=action,
-            name_to_tool_map=self.name_to_tool_map,
-            color_mapping=self.color_mapping,
-            run_manager=run_manager,
-            tool_run_kwargs=tool_run_logging_kwargs,
-        )
+        result = self._action_executor.execute(action)
         return result.observation
     def reset(
@@ -50,7 +36,5 @@ def reset(
         if options is not None and "trajectory" in options:
             for action in options["trajectory"]:
                 assert isinstance(action, AgentAction)
-                observation, reward, terminated, truncated, info = self.step(
-                    (action, None, {})
-                )
+                observation, reward, terminated, truncated, info = self.step(action)
         return observation, info
diff --git a/environments/frozen_lake/common/tools.py b/environments/frozen_lake/common/tools.py
index 4ba11a9..e427efa 100644
--- a/environments/frozen_lake/common/tools.py
+++ b/environments/frozen_lake/common/tools.py
@@ -65,7 +65,7 @@ def _run(
         *args: Any,
         **kwargs: Any,
     ) -> Tuple[Tuple[int, int], SupportsFloat, bool, bool, Dict[str, Any]]:
-        _observation, reward, terminated, truncated, info = self.env.step(
+        _observation, reward, terminated, truncated, info = self.env.unwrapped.step(
             MoveTool._convert_direction_to_frozenlake(direction)
         )
         nrow = self.env.get_wrapper_attr("nrow")
@@ -75,6 +75,65 @@ def _run(
         return observation, reward, terminated, truncated, info


+class LookInput(BaseModel):
+    direction: Literal["left", "right", "down", "up"] = Field(
+        description="Which direction to look at."
+    )
+
+
+class LookTool(BaseFrozenLakeTool, BaseTool):
+    name = "look"
+    description = dedent("""
+    Peeks at the adjacent cell in the given direction. The following options are possible:
+    * out of bounds - it's not possible to move in the given direction from the current cell;
+    * S - starting cell;
+    * H - hole;
+    * F - frozen cell;
+    * G - goal.
+    """)
+    args_schema: Type[BaseModel] = LookInput
+
+    def _run(
+        self,
+        direction: str,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
+        nrow = self.env.get_wrapper_attr("nrow")
+        board = self.env.get_wrapper_attr("desc")
+        x, y = MoveTool._convert_frozenlake_observation_to_position(
+            observation=self.env.get_wrapper_attr("s"), nrow=nrow
+        )
+
+        match direction:
+            case "left":
+                observation = "out of bounds" if x == 0 else board[x - 1][y].decode()
+            case "right":
+                observation = (
+                    "out of bounds" if x == nrow - 1 else board[x + 1][y].decode()
+                )
+            case "down":
+                observation = (
+                    "out of bounds" if y == nrow - 1 else board[x][y + 1].decode()
+                )
+            case "up":
+                observation = "out of bounds" if y == 0 else board[x][y - 1].decode()
+            case _:
+                raise ValueError(
+                    "Wrong direction; expected one of: 'left', 'right', 'down', 'up'."
+                )
+
+        info: Dict[str, Any]
+        reward, terminated, truncated, info = (
+            0,
+            False,
+            False,
+            {},
+        )
+
+        return observation, reward, terminated, truncated, info
+
+
 class CheckMapInput(BaseModel): ...
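For orientation, here is a minimal usage sketch of the reworked wrapper and tools. This is hypothetical and not part of the patch: it assumes the standard FrozenLake-v1 registration, that MoveTool is registered under the name "move", and that its args schema has a single `direction` field like LookInput's.

import gymnasium as gym
from langchain_core.agents import AgentAction

from environments.frozen_lake.common.environment import FrozenLakeEnvWrapper

# Hypothetical sketch, not part of the patch.
env = FrozenLakeEnvWrapper(gym.make("FrozenLake-v1", is_slippery=False).unwrapped)
observation, info = env.reset(seed=42)

# step() now takes a bare AgentAction: callbacks and logging are handled by
# DefaultActionExecutor (LangGraph's ToolExecutor) instead of being threaded
# through the environment as an (action, run_manager, kwargs) tuple.
action = AgentAction(tool="move", tool_input={"direction": "right"}, log="")
observation, reward, terminated, truncated, info = env.step(action)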
diff --git a/environments/frozen_lake/reflexion/__init__.py b/environments/frozen_lake/reflexion/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/planning_library/action_executors/base_action_executor.py b/planning_library/action_executors/base_action_executor.py
index a65f553..dfc6803 100644
--- a/planning_library/action_executors/base_action_executor.py
+++ b/planning_library/action_executors/base_action_executor.py
@@ -1,24 +1,19 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, overload
+from typing import List, overload, Sequence

-from langchain_core.agents import AgentAction, AgentFinish, AgentStep
-from langchain_core.callbacks import (
-    AsyncCallbackManagerForChainRun,
-    CallbackManagerForChainRun,
-)
+from langchain_core.agents import AgentAction, AgentStep
 from langchain_core.tools import BaseTool


 class BaseActionExecutor(ABC):
+    @property
+    @abstractmethod
+    def tools(self) -> Sequence[BaseTool]: ...
+
     @overload
     def execute(
         self,
         actions: List[AgentAction],
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        verbose: bool = True,
-        tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
-        run_manager: Optional[CallbackManagerForChainRun] = None,
         **kwargs,
     ) -> List[AgentStep]: ...

@@ -26,37 +21,23 @@ def execute(
     def execute(
         self,
         actions: AgentAction,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        verbose: bool = True,
-        tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
-        run_manager: Optional[CallbackManagerForChainRun] = None,
         **kwargs,
     ) -> AgentStep: ...

     @abstractmethod
     def execute(
         self,
-        actions: List[AgentAction] | AgentAction | AgentFinish,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        verbose: bool = True,
-        tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
-        run_manager: Optional[CallbackManagerForChainRun] = None,
+        actions: List[AgentAction] | AgentAction,
         **kwargs,
     ) -> List[AgentStep] | AgentStep:
         """Performs actions.

         Args:
-            actions: Currently proposed actions. Can be: multi-action, single action, finishing.
-            name_to_tool_map: Mapping from tool names to actual tools, used for calling tools based on agent's output.
-            color_mapping: Mapping from tool names to colors, used for logging purposes when calling tools.
-            run_manager: Callback for the current run.
+            actions: Currently proposed actions. Can be: multi-action, single action.

         Returns:
             * List[AgentStep] - for multi-action thoughts (List[AgentAction])
             * AgentStep - for single-action thoughts (AgentAction)
-            * None - for finishing thoughts (AgentFinish)
         """
         ...

     @overload
     async def aexecute(
         self,
         actions: List[AgentAction],
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        verbose: bool = True,
-        tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
-        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
         **kwargs,
     ) -> List[AgentStep]: ...

@@ -76,11 +53,6 @@ async def aexecute(
     async def aexecute(
         self,
         actions: AgentAction,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        verbose: bool = True,
-        tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
-        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
         **kwargs,
     ) -> AgentStep: ...
@@ -88,24 +60,15 @@ async def aexecute(
     async def aexecute(
         self,
         actions: List[AgentAction] | AgentAction,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        verbose: bool = True,
-        tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
-        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
         **kwargs,
     ) -> List[AgentStep] | AgentStep:
         """Performs actions asynchronously.

         Args:
-            actions: Currently proposed actions. Can be: multi-action, single action, finishing.
-            name_to_tool_map: Mapping from tool names to actual tools, used for calling tools based on agent's output.
-            color_mapping: Mapping from tool names to colors, used for logging purposes when calling tools.
-            run_manager: Callback for the current run.
+            actions: Currently proposed actions. Can be: multi-action, single action.

         Returns:
             * List[AgentStep] - for multi-action thoughts (List[AgentAction])
             * AgentStep - for single-action thoughts (AgentAction)
-            * None - for finishing thoughts (AgentFinish)
         """
         ...
diff --git a/planning_library/action_executors/default_action_executor.py b/planning_library/action_executors/default_action_executor.py
index c1572f5..be6189a 100644
--- a/planning_library/action_executors/default_action_executor.py
+++ b/planning_library/action_executors/default_action_executor.py
@@ -1,27 +1,23 @@
-import asyncio
-from typing import Any, Dict, List, Optional, overload
+from typing import List, overload, Sequence

 from langchain_core.agents import AgentAction, AgentStep
-from langchain_core.callbacks import (
-    AsyncCallbackManagerForChainRun,
-    CallbackManagerForChainRun,
-)
 from langchain_core.tools import BaseTool
-
-from ..utils.actions_utils import aperform_agent_action, perform_agent_action
+from langgraph.prebuilt.tool_executor import ToolExecutor  # type: ignore[import-untyped]

 from .base_action_executor import BaseActionExecutor


 class DefaultActionExecutor(BaseActionExecutor):
+    def __init__(self, tools: Sequence[BaseTool]):
+        self._tool_executor = ToolExecutor(tools)
+
+    @property
+    def tools(self) -> Sequence[BaseTool]:
+        return self._tool_executor.tools
+
     @overload
     def execute(
         self,
         actions: List[AgentAction],
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        verbose: bool = True,
-        tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
-        run_manager: Optional[CallbackManagerForChainRun] = None,
         **kwargs,
     ) -> List[AgentStep]: ...

@@ -29,57 +25,29 @@ def execute(
     def execute(
         self,
         actions: AgentAction,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        verbose: bool = True,
-        tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
-        run_manager: Optional[CallbackManagerForChainRun] = None,
         **kwargs,
     ) -> AgentStep: ...
def execute( self, actions: List[AgentAction] | AgentAction, - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], - verbose: bool = True, - tool_run_logging_kwargs: Optional[Dict[str, Any]] = None, - run_manager: Optional[CallbackManagerForChainRun] = None, **kwargs, ) -> List[AgentStep] | AgentStep: - if isinstance(actions, AgentAction): - tool_result = perform_agent_action( - agent_action=actions, - name_to_tool_map=name_to_tool_map, - color_mapping=color_mapping, - verbose=verbose, - tool_run_kwargs=tool_run_logging_kwargs, - run_manager=run_manager, - ) - return tool_result - elif isinstance(actions, list): - observations = [] - for action in actions: - tool_result = perform_agent_action( - agent_action=action, - name_to_tool_map=name_to_tool_map, - color_mapping=color_mapping, - verbose=verbose, - tool_run_kwargs=tool_run_logging_kwargs, - run_manager=run_manager, - ) - observations.append(tool_result) - return observations + observations = self._tool_executor.invoke(actions) + if isinstance(observations, list): + assert isinstance(actions, list) + return [ + AgentStep(action=action, observation=observation) + for action, observation in zip(actions, observations) + ] + + assert isinstance(actions, AgentAction) + return AgentStep(action=actions, observation=observations) @overload async def aexecute( self, actions: List[AgentAction], - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], - verbose: bool = True, - tool_run_logging_kwargs: Optional[Dict[str, Any]] = None, - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, **kwargs, ) -> List[AgentStep]: ... @@ -87,48 +55,20 @@ async def aexecute( async def aexecute( self, actions: AgentAction, - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], - verbose: bool = True, - tool_run_logging_kwargs: Optional[Dict[str, Any]] = None, - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, **kwargs, ) -> AgentStep: ... 
async def aexecute( self, actions: List[AgentAction] | AgentAction, - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], - verbose: bool = True, - tool_run_logging_kwargs: Optional[Dict[str, Any]] = None, - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, **kwargs, ) -> List[AgentStep] | AgentStep: - if isinstance(actions, AgentAction): - tool_result = await aperform_agent_action( - agent_action=actions, - name_to_tool_map=name_to_tool_map, - color_mapping=color_mapping, - verbose=verbose, - tool_run_kwargs=tool_run_logging_kwargs, - run_manager=run_manager, - ) - return tool_result - elif isinstance(actions, list): - # TODO: no idea why mypy complains - with asyncio.TaskGroup() as tg: # type: ignore[attr-defined] - tool_results = [ - tg.create_task( - aperform_agent_action( - agent_action=action, - name_to_tool_map=name_to_tool_map, - color_mapping=color_mapping, - verbose=verbose, - tool_run_kwargs=tool_run_logging_kwargs, - run_manager=run_manager, - ) - ) - for action in actions - ] - return [task.result() for task in tool_results] + observations = await self._tool_executor.ainvoke(actions) + if isinstance(observations, list): + assert isinstance(actions, list) + return [ + AgentStep(action=action, observation=observation) + for action, observation in zip(actions, observations) + ] + assert isinstance(actions, AgentAction) + return AgentStep(action=actions, observation=observations) diff --git a/planning_library/action_executors/gymnasium_action_executor.py b/planning_library/action_executors/gymnasium_action_executor.py index c768cd9..f4557e9 100644 --- a/planning_library/action_executors/gymnasium_action_executor.py +++ b/planning_library/action_executors/gymnasium_action_executor.py @@ -1,15 +1,11 @@ -from typing import Any, Dict, List, Optional, Tuple, overload +from typing import List, Optional, overload, Sequence import gymnasium as gym from gymnasium.core import ObsType from langchain_core.agents import AgentAction, AgentStep -from langchain_core.callbacks import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain_core.tools import BaseTool from .base_action_executor import BaseActionExecutor +from langchain_core.tools import BaseTool class GymnasiumActionExecutor(BaseActionExecutor): @@ -17,22 +13,21 @@ def __init__( self, env: gym.Env[ ObsType, - Tuple[AgentAction, Optional[CallbackManagerForChainRun], Dict[str, Any]], + AgentAction, ], seed: Optional[int] = None, ): self._env = env self._seed = seed + @property + def tools(self) -> Sequence[BaseTool]: + return self._env.get_wrapper_attr("tools") + @overload def execute( self, actions: List[AgentAction], - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], - verbose: bool = True, - tool_run_logging_kwargs: Optional[Dict[str, Any]] = None, - run_manager: Optional[CallbackManagerForChainRun] = None, reset_env_before_action: bool = False, **reset_kwargs, ) -> List[AgentStep]: ... @@ -41,11 +36,6 @@ def execute( def execute( self, actions: AgentAction, - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], - verbose: bool = True, - tool_run_logging_kwargs: Optional[Dict[str, Any]] = None, - run_manager: Optional[CallbackManagerForChainRun] = None, reset_env_before_action: bool = False, **reset_kwargs, ) -> AgentStep: ... 
@@ -53,25 +43,14 @@ def execute( def execute( self, actions: List[AgentAction] | AgentAction, - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], - verbose: bool = True, - tool_run_logging_kwargs: Optional[Dict[str, Any]] = None, - run_manager: Optional[CallbackManagerForChainRun] = None, reset_env_before_action: bool = False, **reset_kwargs, ) -> List[AgentStep] | AgentStep: - tool_run_logging_kwargs = ( - {} if tool_run_logging_kwargs is None else tool_run_logging_kwargs - ) - if reset_env_before_action: self._env.reset(seed=self._seed, options=reset_kwargs) if isinstance(actions, AgentAction): - observation, reward, terminated, truncated, info = self._env.step( - (actions, run_manager, tool_run_logging_kwargs) - ) + observation, reward, terminated, truncated, info = self._env.step(actions) return AgentStep( action=actions, @@ -83,14 +62,10 @@ def execute( "info": info, }, ) + return [ self.execute( actions=action, - name_to_tool_map=name_to_tool_map, - color_mapping=color_mapping, - verbose=verbose, - tool_run_logging_kwargs=tool_run_logging_kwargs, - run_manager=run_manager, reset_env_before_action=reset_env_before_action, **reset_kwargs, ) @@ -101,11 +76,6 @@ def execute( async def aexecute( self, actions: List[AgentAction], - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], - verbose: bool = True, - tool_run_logging_kwargs: Optional[Dict[str, Any]] = None, - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, reset_before_action: bool = False, **reset_kwargs, ) -> List[AgentStep]: ... @@ -114,11 +84,6 @@ async def aexecute( async def aexecute( self, actions: AgentAction, - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], - verbose: bool = True, - tool_run_logging_kwargs: Optional[Dict[str, Any]] = None, - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, reset_before_action: bool = False, **reset_kwargs, ) -> AgentStep: ... 
@@ -126,11 +91,6 @@ async def aexecute( async def aexecute( self, actions: List[AgentAction] | AgentAction, - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], - verbose: bool = True, - tool_run_logging_kwargs: Optional[Dict[str, Any]] = None, - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, reset_before_action: bool = False, **reset_kwargs, ) -> List[AgentStep] | AgentStep: diff --git a/planning_library/strategies/base_strategy.py b/planning_library/strategies/base_strategy.py index f32d363..b42ef69 100644 --- a/planning_library/strategies/base_strategy.py +++ b/planning_library/strategies/base_strategy.py @@ -24,18 +24,21 @@ from langchain_core.tools import BaseTool from langgraph.pregel import Pregel # type: ignore[import-untyped] -from planning_library.action_executors import BaseActionExecutor, DefaultActionExecutor +from planning_library.action_executors import BaseActionExecutor from planning_library.utils.actions_utils import get_tools_maps class BaseCustomStrategy(Chain, ABC): agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] - tools: Sequence[BaseTool] - action_executor: BaseActionExecutor = DefaultActionExecutor() + action_executor: BaseActionExecutor return_intermediate_steps: bool = False max_iterations: int = 15 verbose: bool = True + @property + def tools(self) -> Sequence[BaseTool]: + return self.action_executor.tools + @property def input_keys(self) -> List[str]: """Return the input keys.""" diff --git a/planning_library/strategies/reflexion/components/actors.py b/planning_library/strategies/reflexion/components/actors.py index 7d90374..04fed61 100644 --- a/planning_library/strategies/reflexion/components/actors.py +++ b/planning_library/strategies/reflexion/components/actors.py @@ -3,6 +3,7 @@ from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.messages import BaseMessage class BaseActor(ABC): @@ -11,7 +12,7 @@ def act( self, inputs: Dict[str, Any], intermediate_steps: List[Tuple[AgentAction, str]], - self_reflections: Sequence[str], + self_reflections: Sequence[BaseMessage], **kwargs, ) -> Union[List[AgentAction], AgentAction, AgentFinish]: ... @@ -20,7 +21,7 @@ async def aact( self, inputs: Dict[str, Any], intermediate_steps: List[Tuple[AgentAction, str]], - self_reflections: Sequence[str], + self_reflections: Sequence[BaseMessage], **kwargs, ) -> Union[List[AgentAction], AgentAction, AgentFinish]: ... 
@@ -33,7 +34,7 @@ def act( self, inputs: Dict[str, Any], intermediate_steps: List[Tuple[AgentAction, str]], - self_reflections: Sequence[str], + self_reflections: Sequence[BaseMessage], **kwargs, ) -> Union[List[AgentAction], AgentAction, AgentFinish]: return self.agent.plan( @@ -46,7 +47,7 @@ async def aact( self, inputs: Dict[str, Any], intermediate_steps: List[Tuple[AgentAction, str]], - self_reflections: Sequence[str], + self_reflections: Sequence[BaseMessage], **kwargs, ) -> Union[List[AgentAction], AgentAction, AgentFinish]: return await self.agent.aplan( diff --git a/planning_library/strategies/reflexion/components/self_reflections.py b/planning_library/strategies/reflexion/components/self_reflections.py index 2f511c2..645b2b1 100644 --- a/planning_library/strategies/reflexion/components/self_reflections.py +++ b/planning_library/strategies/reflexion/components/self_reflections.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Sequence from langchain_core.agents import AgentAction, AgentFinish from langchain_core.runnables import Runnable +from langchain_core.messages import BaseMessage class BaseSelfReflection(ABC): @@ -13,7 +14,7 @@ def self_reflect( intermediate_steps: List[Tuple[AgentAction, str]], agent_outcome: AgentFinish, evaluator_score: Any, - ) -> str: ... + ) -> Sequence[BaseMessage]: ... @abstractmethod async def aself_reflect( @@ -22,12 +23,12 @@ async def aself_reflect( intermediate_steps: List[Tuple[AgentAction, str]], agent_outcome: AgentFinish, evaluator_score: Any, - ) -> str: + ) -> Sequence[BaseMessage]: pass class RunnableSelfReflection(BaseSelfReflection): - def __init__(self, llm_chain: Runnable[Dict[str, Any], str]): + def __init__(self, llm_chain: Runnable[Dict[str, Any], Sequence[BaseMessage]]): self.llm_chain = llm_chain def self_reflect( @@ -36,7 +37,7 @@ def self_reflect( intermediate_steps: List[Tuple[AgentAction, str]], agent_outcome: AgentFinish, evaluator_score: Any, - ) -> str: + ) -> Sequence[BaseMessage]: return self.llm_chain.invoke( { "inputs": inputs, @@ -52,7 +53,7 @@ async def aself_reflect( intermediate_steps: List[Tuple[AgentAction, str]], agent_outcome: AgentFinish, evaluator_score: Any, - ) -> str: + ) -> Sequence[BaseMessage]: return await self.llm_chain.ainvoke( { "inputs": inputs, diff --git a/planning_library/strategies/reflexion/reflexion_graph.py b/planning_library/strategies/reflexion/reflexion_graph.py index 02c962a..5f587bb 100644 --- a/planning_library/strategies/reflexion/reflexion_graph.py +++ b/planning_library/strategies/reflexion/reflexion_graph.py @@ -6,20 +6,20 @@ List, Literal, Optional, - Sequence, Tuple, TypedDict, Union, ) +from langchain.memory import ChatMessageHistory +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage from langchain_core.agents import AgentAction, AgentFinish, AgentStep from langchain_core.runnables import RunnableLambda -from langchain_core.tools import BaseTool from langgraph.graph import END, StateGraph # type: ignore[import] from langgraph.pregel import Pregel # type: ignore[import-untyped] from ...action_executors import BaseActionExecutor -from ...utils import get_tools_maps from .components.actors import BaseActor from .components.evaluators import ReflexionEvaluator from .components.self_reflections import BaseSelfReflection @@ -32,18 +32,24 @@ class ReflexionState(TypedDict): agent_outcome: Optional[Union[List[AgentAction], AgentAction, 
AgentFinish]] evaluator_score: Any evaluator_should_continue: Optional[bool] - self_reflections: List[str] + self_reflection_memory: BaseChatMessageHistory + self_reflections: List[BaseMessage] intermediate_steps: List[Tuple[AgentAction, str]] iteration: int class ReflexionNodes: @staticmethod - def init(state: ReflexionState) -> ReflexionState: + def init( + state: ReflexionState, memory: Optional[BaseChatMessageHistory] = None + ) -> ReflexionState: """The entry node in the graph. Initializes the state correctly.""" state["agent_outcome"] = None state["evaluator_score"] = None state["evaluator_should_continue"] = None + state["self_reflection_memory"] = ( + ChatMessageHistory() if memory is None else memory + ) state["self_reflections"] = [] state["intermediate_steps"] = [] state["iteration"] = 1 @@ -60,6 +66,7 @@ def re_init( state["evaluator_should_continue"] = None state["intermediate_steps"] = [] state["iteration"] += 1 + state["self_reflections"] = state["self_reflection_memory"].messages if reset_environment: reset_environment(state["inputs"]) @@ -93,8 +100,6 @@ async def aact(state: ReflexionState, actor: BaseActor) -> ReflexionState: def execute_actions( state: ReflexionState, action_executor: BaseActionExecutor, - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], ) -> ReflexionState: """Synchronous version of executing actions as previously requested by an agent.""" assert ( @@ -106,8 +111,6 @@ def execute_actions( observation = action_executor.execute( actions=state["agent_outcome"], - name_to_tool_map=name_to_tool_map, - color_mapping=color_mapping, ) if isinstance(observation, AgentStep): @@ -124,8 +127,6 @@ def execute_actions( async def aexecute_actions( state: ReflexionState, action_executor: BaseActionExecutor, - name_to_tool_map: Dict[str, BaseTool], - color_mapping: Dict[str, str], ) -> ReflexionState: """Asynchronous version of executing tools as previously requested by an agent.""" assert ( @@ -137,8 +138,6 @@ async def aexecute_actions( observation = await action_executor.aexecute( actions=state["agent_outcome"], - name_to_tool_map=name_to_tool_map, - color_mapping=color_mapping, ) if isinstance(observation, AgentStep): @@ -201,7 +200,7 @@ def self_reflect( agent_outcome=state["agent_outcome"], evaluator_score=state["evaluator_score"], ) - state["self_reflections"].append(reflection) + state["self_reflection_memory"].add_messages(reflection) return state @staticmethod @@ -219,7 +218,7 @@ async def aself_reflect( agent_outcome=state["agent_outcome"], evaluator_score=state["evaluator_score"], ) - state["self_reflections"].append(reflection) + await state["self_reflection_memory"].aadd_messages(reflection) return state @@ -262,7 +261,6 @@ def create_reflexion_graph( evaluator: ReflexionEvaluator, self_reflection: BaseSelfReflection, action_executor: BaseActionExecutor, - tools: Sequence[BaseTool], max_iterations: Optional[int], reset_environment: Optional[Callable[[Dict[str, Any]], None]], ) -> Pregel: @@ -282,21 +280,16 @@ def create_reflexion_graph( ), ) - name_to_tool_map, color_mapping = get_tools_maps(tools) builder.add_node( "execute_actions", RunnableLambda( partial( ReflexionNodes.execute_actions, action_executor=action_executor, - name_to_tool_map=name_to_tool_map, - color_mapping=color_mapping, ), afunc=partial( ReflexionNodes.aexecute_actions, action_executor=action_executor, - name_to_tool_map=name_to_tool_map, - color_mapping=color_mapping, ), ), ) diff --git a/planning_library/strategies/reflexion/reflexion_strategy.py 
b/planning_library/strategies/reflexion/reflexion_strategy.py index 95ae102..091e873 100644 --- a/planning_library/strategies/reflexion/reflexion_strategy.py +++ b/planning_library/strategies/reflexion/reflexion_strategy.py @@ -29,7 +29,7 @@ class ReflexionStrategy(BaseLangGraphStrategy): def create( agent: Runnable, tools: Sequence[BaseTool], - action_executor: BaseActionExecutor = DefaultActionExecutor(), + action_executor: Optional[BaseActionExecutor] = None, evaluator_runnable: Optional[Runnable[ReflexionEvaluatorInput, Any]] = None, self_reflection_runnable: Optional[Runnable[Dict[str, Any], Any]] = None, max_iterations: Optional[int] = None, @@ -47,7 +47,7 @@ def create( Args: agent: The agent to run for proposing thoughts at each DFS step. tools: The valid tools the agent can call. - action_executor: The class responsible for actually executing actions. By default, simply calls LangChain tools. + action_executor: The class responsible for actually executing actions. evaluator_runnable: Runnable that powers an evaluator. If None, the default model will be used. self_reflection_runnable: Runnable that powers self-reflection. If None, the default model will be used. max_iterations: Maximum number of iterations. If None, no restrictions on the number of iterations are imposed. @@ -74,12 +74,14 @@ def create( self_reflection = RunnableSelfReflection(self_reflection_runnable) + if action_executor is None: + action_executor = DefaultActionExecutor(tools) + return create_reflexion_graph( actor=actor, evaluator=evaluator, self_reflection=self_reflection, action_executor=action_executor, - tools=tools, max_iterations=max_iterations, reset_environment=reset_environment, ) diff --git a/planning_library/strategies/tot_dfs/tot_strategy.py b/planning_library/strategies/tot_dfs/tot_strategy.py index 188c319..f718196 100644 --- a/planning_library/strategies/tot_dfs/tot_strategy.py +++ b/planning_library/strategies/tot_dfs/tot_strategy.py @@ -6,9 +6,9 @@ Iterator, List, Optional, - Sequence, Tuple, Union, + Sequence, ) from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent @@ -53,7 +53,7 @@ class TreeOfThoughtsDFSStrategy(BaseCustomStrategy): def create( agent: Union[BaseSingleActionAgent, BaseMultiActionAgent], tools: Sequence[BaseTool], - action_executor: BaseActionExecutor = DefaultActionExecutor(), + action_executor: Optional[BaseActionExecutor] = None, evaluator_runnable: Optional[Runnable] = None, value_threshold: float = 0.5, max_thoughts: int = 3, @@ -80,9 +80,11 @@ def create( "Default runnable for thought evaluator is not supported yet." 
) + if action_executor is None: + action_executor = DefaultActionExecutor(tools) + strategy = TreeOfThoughtsDFSStrategy( agent=agent, - tools=tools, thought_generator=AgentThoughtGenerator(), thought_evaluator=ThoughtEvaluator( backbone=RunnableThoughtEvaluator(evaluator_runnable), @@ -209,6 +211,7 @@ def _run_strategy( new_node = ToTNode( parent=cur_node, thought=new_thought, observation=observation ) + cur_node.children.append(new_node) if isinstance(new_thought, AgentFinish): self.terminals.append(new_node) diff --git a/planning_library/strategies/tot_dfs/utils/tot_node.py b/planning_library/strategies/tot_dfs/utils/tot_node.py index 8f729f7..0821437 100644 --- a/planning_library/strategies/tot_dfs/utils/tot_node.py +++ b/planning_library/strategies/tot_dfs/utils/tot_node.py @@ -38,7 +38,7 @@ def trajectory(self) -> List[Tuple[AgentAction, str]]: trajectory_actions.append( (node.observation.action, node.observation.observation) ) - elif isinstance(node.thought, AgentFinish): + elif isinstance(node.thought, AgentFinish) and node is not self: raise ValueError("AgentFinish detected as non-terminal node.") node = node.parent
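To see how the pieces compose end to end, here is a minimal wiring sketch. This is hypothetical and not part of the patch; the agent construction, the "task" input key, and the invoke payload are assumptions based on the signatures shown above.

import gymnasium as gym

from environments.frozen_lake.common.environment import FrozenLakeEnvWrapper
from planning_library.strategies.reflexion.reflexion_strategy import ReflexionStrategy

env = FrozenLakeEnvWrapper(gym.make("FrozenLake-v1", is_slippery=False).unwrapped)

# Assumption: `agent` is any LangChain Runnable that plans tool calls;
# its construction is elided because it is orthogonal to this patch.
agent = ...

reflexion = ReflexionStrategy.create(
    agent=agent,
    tools=env.tools,       # tools now come from the wrapper / action executor
    action_executor=None,  # None -> create() builds DefaultActionExecutor(tools)
    max_iterations=3,
)

# Self-reflections now live in a BaseChatMessageHistory (ChatMessageHistory by
# default), so each re_init re-reads memory.messages instead of a plain list
# of strings.
result = reflexion.invoke({"inputs": {"task": "Reach the goal G without falling into a hole."}})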