Skip to content

Commit

Permalink
Polish Reflexion & tool execution
Browse files Browse the repository at this point in the history
Switch from custom tool-calling methods to ToolExecutor from LangGraph (because I didn't know how to configure callbacks in the custom methods so that tool runs render nicely in LangSmith)

Use BaseChatMessageHistory for self-reflections instead of a simple in-memory option
  • Loading branch information
saridormi committed Mar 18, 2024
1 parent b8ec947 commit e798a65
Show file tree
Hide file tree
Showing 14 changed files with 164 additions and 251 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/workflow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ jobs:
run: |
poetry run ruff check
- name: Check formatting with ruff
run: |
poetry run ruff format --check
- name: Check types with mypy
run: |
poetry run mypy .
40 changes: 12 additions & 28 deletions environments/frozen_lake/common/environment.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,28 @@
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Tuple, Sequence

import gymnasium as gym
from gymnasium.core import ObsType, SupportsFloat
from gymnasium.envs.toy_text.frozen_lake import FrozenLakeEnv
from langchain_core.agents import AgentAction
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.tools import BaseTool

from .tools import (
MoveTool,
)
from planning_library.utils import get_tools_maps, perform_agent_action
from .tools import MoveTool
from planning_library.action_executors import DefaultActionExecutor


class FrozenLakeEnvWrapper(gym.Wrapper):
def __init__(self, env: FrozenLakeEnv):
super().__init__(env)
self.name_to_tool_map, self.color_mapping = get_tools_maps(
[
MoveTool(env=self.env.unwrapped), # type: ignore[call-arg]
]
)
# CheckPositionTool(env=self.env.unwrapped),])
# CheckMapTool(env=self.env.unwrapped)])
self._action_executor = DefaultActionExecutor(tools=[MoveTool(env=self)]) # type: ignore[call-arg]

@property
def tools(self) -> Sequence[BaseTool]:
return self._action_executor.tools

def step(
self,
cur_input: Tuple[
AgentAction, Optional[CallbackManagerForChainRun], Dict[str, Any]
],
self, action: AgentAction
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
action, run_manager, tool_run_logging_kwargs = cur_input
result = perform_agent_action(
agent_action=action,
name_to_tool_map=self.name_to_tool_map,
color_mapping=self.color_mapping,
run_manager=run_manager,
tool_run_kwargs=tool_run_logging_kwargs,
)
result = self._action_executor.execute(action)
return result.observation

def reset(
Expand All @@ -50,7 +36,5 @@ def reset(
if options is not None and "trajectory" in options:
for action in options["trajectory"]:
assert isinstance(action, AgentAction)
observation, reward, terminated, truncated, info = self.step(
(action, None, {})
)
observation, reward, terminated, truncated, info = self.step(action)
return observation, info
61 changes: 60 additions & 1 deletion environments/frozen_lake/common/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _run(
*args: Any,
**kwargs: Any,
) -> Tuple[Tuple[int, int], SupportsFloat, bool, bool, Dict[str, Any]]:
_observation, reward, terminated, truncated, info = self.env.step(
_observation, reward, terminated, truncated, info = self.env.unwrapped.step(
MoveTool._convert_direction_to_frozenlake(direction)
)
nrow = self.env.get_wrapper_attr("nrow")
Expand All @@ -75,6 +75,65 @@ def _run(
return observation, reward, terminated, truncated, info


# Argument schema used by LookTool (via its `args_schema`): a single `direction` field.
# Documented with comments rather than a class docstring so the schema text is unchanged.
class LookInput(BaseModel):
    # Only the four literal values below are accepted; no default is given.
    direction: Literal["left", "right", "down", "up"] = Field(
        description="Which direction to look at."
    )


class LookTool(BaseFrozenLakeTool, BaseTool):
    # Tool that reports the content of the cell adjacent to the agent's current
    # position. It only reads `nrow`, `desc` and `s` from the environment and
    # never calls `step`.
    name = "look"
    description = dedent("""
    Peeks at the adjacent cell in given direction. The following options are possible:
    * out of bounds - it's not possible to move in the given direction from the current cell;
    * S - starting cell;
    * H - hole;
    * F - frozen cell;
    * G - goal.
    """)
    args_schema: Type[BaseModel] = LookInput

    def _run(
        self,
        direction: str,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Return what lies in the cell next to the agent in ``direction``.

        Args:
            direction: One of ``"left"``, ``"right"``, ``"down"``, ``"up"``.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            A step-like 5-tuple ``(observation, reward, terminated, truncated, info)``
            where ``observation`` is either ``"out of bounds"`` or the decoded
            content of the neighbouring cell, and the remaining items are always
            ``0``, ``False``, ``False`` and ``{}``.

        Raises:
            ValueError: If ``direction`` is not one of the four supported values.
        """
        size = self.env.get_wrapper_attr("nrow")
        grid = self.env.get_wrapper_attr("desc")
        pos_x, pos_y = MoveTool._convert_frozenlake_observation_to_position(
            observation=self.env.get_wrapper_attr("s"), nrow=size
        )

        # NOTE(review): both axes are bounded by `nrow`, which presumes a square
        # map, and the grid is indexed as grid[x][y] -- confirm this matches the
        # axis order returned by the MoveTool helper.
        if direction == "left":
            at_edge = pos_x == 0
            neighbour = (pos_x - 1, pos_y)
        elif direction == "right":
            at_edge = pos_x == size - 1
            neighbour = (pos_x + 1, pos_y)
        elif direction == "down":
            at_edge = pos_y == size - 1
            neighbour = (pos_x, pos_y + 1)
        elif direction == "up":
            at_edge = pos_y == 0
            neighbour = (pos_x, pos_y - 1)
        else:
            raise ValueError(
                "Wrong direction; expected one of: 'left', 'right', 'down', 'up'."
            )

        if at_edge:
            observation = "out of bounds"
        else:
            first, second = neighbour
            # Cells are stored as bytes; decode to hand back a plain string.
            observation = grid[first][second].decode()

        # Looking has no effect on the episode: constant reward/flags, empty info.
        info: Dict[str, Any] = {}
        return observation, 0, False, False, info


# Empty argument schema: declares no fields of its own.
# Presumably the input model for a map-checking tool that takes no arguments;
# the tool itself is not visible here -- confirm against its `args_schema`.
class CheckMapInput(BaseModel): ...


Expand Down
Empty file.
51 changes: 7 additions & 44 deletions planning_library/action_executors/base_action_executor.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,74 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, overload
from typing import List, overload, Sequence

from langchain_core.agents import AgentAction, AgentFinish, AgentStep
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.agents import AgentAction, AgentStep
from langchain_core.tools import BaseTool


class BaseActionExecutor(ABC):
@property
@abstractmethod
def tools(self) -> Sequence[BaseTool]: ...

@overload
def execute(
self,
actions: List[AgentAction],
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
verbose: bool = True,
tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
**kwargs,
) -> List[AgentStep]: ...

@overload
def execute(
self,
actions: AgentAction,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
verbose: bool = True,
tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
**kwargs,
) -> AgentStep: ...

@abstractmethod
def execute(
self,
actions: List[AgentAction] | AgentAction | AgentFinish,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
verbose: bool = True,
tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
actions: List[AgentAction] | AgentAction,
**kwargs,
) -> List[AgentStep] | AgentStep:
"""Performs actions.
Args:
actions: Currently proposed actions. Can be: multi-action, single action, finishing.
name_to_tool_map: Mapping from tool names to actual tools, used for calling tools based on agent's output.
color_mapping: Mapping from tool names to colors, used for logging purposes when calling tools.
run_manager: Callback for the current run.
Returns:
* List[AgentStep] - for multi-action thoughts (List[AgentAction])
* AgentStep - for single-action thoughts (AgentAction)
* None - for finishing thoughts (AgentFinish)
"""
...

@overload
async def aexecute(
self,
actions: List[AgentAction],
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
verbose: bool = True,
tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**kwargs,
) -> List[AgentStep]: ...

@overload
async def aexecute(
self,
actions: AgentAction,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
verbose: bool = True,
tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**kwargs,
) -> AgentStep: ...

@abstractmethod
async def aexecute(
self,
actions: List[AgentAction] | AgentAction,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
verbose: bool = True,
tool_run_logging_kwargs: Optional[Dict[str, Any]] = None,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**kwargs,
) -> List[AgentStep] | AgentStep:
"""Performs actions asynchronously.
Args:
actions: Currently proposed actions. Can be: multi-action, single action, finishing.
name_to_tool_map: Mapping from tool names to actual tools, used for calling tools based on agent's output.
color_mapping: Mapping from tool names to colors, used for logging purposes when calling tools.
run_manager: Callback for the current run.
Returns:
* List[AgentStep] - for multi-action thoughts (List[AgentAction])
* AgentStep - for single-action thoughts (AgentAction)
* None - for finishing thoughts (AgentFinish)
"""
...
Loading

0 comments on commit e798a65

Please sign in to comment.