-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example for FrozenLake and switch to ruff
- Loading branch information
Showing
41 changed files
with
1,572 additions
and
208 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import Any, Dict, Optional, Tuple | ||
|
||
import gymnasium as gym | ||
from gymnasium.core import ObsType, SupportsFloat | ||
from gymnasium.envs.toy_text.frozen_lake import FrozenLakeEnv | ||
from langchain_core.agents import AgentAction | ||
from langchain_core.callbacks import CallbackManagerForChainRun | ||
|
||
from environments.frozen_lake.common.tools import ( | ||
MoveTool, | ||
) | ||
from planning_library.utils import get_tools_maps, perform_agent_action | ||
|
||
|
||
class FrozenLakeEnvWrapper(gym.Wrapper): | ||
def __init__(self, env: FrozenLakeEnv): | ||
super().__init__(env) | ||
self.name_to_tool_map, self.color_mapping = get_tools_maps( | ||
[ | ||
MoveTool(env=self.env.unwrapped), | ||
] | ||
) | ||
# CheckPositionTool(env=self.env.unwrapped),]) | ||
# CheckMapTool(env=self.env.unwrapped)]) | ||
|
||
def step( | ||
self, | ||
cur_input: Tuple[ | ||
AgentAction, Optional[CallbackManagerForChainRun], Dict[str, Any] | ||
], | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
action, run_manager, tool_run_logging_kwargs = cur_input | ||
result = perform_agent_action( | ||
agent_action=action, | ||
name_to_tool_map=self.name_to_tool_map, | ||
color_mapping=self.color_mapping, | ||
run_manager=run_manager, | ||
tool_run_kwargs=tool_run_logging_kwargs, | ||
) | ||
return result.observation | ||
|
||
def reset( | ||
self, | ||
*, | ||
seed: int | None = None, | ||
options: Dict[str, Any] | None = None, | ||
) -> Tuple[ObsType, Dict[str, Any]]: | ||
result = self.env.reset(seed=seed, options=options) | ||
|
||
if options is not None and "trajectory" in options: | ||
for action in options["trajectory"]: | ||
assert isinstance(action, AgentAction) | ||
result = self.step(action) | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import re | ||
|
||
from langchain_core.exceptions import OutputParserException | ||
from langchain_core.output_parsers import BaseOutputParser | ||
|
||
|
||
class FrozenMapEvaluateOutputParser(BaseOutputParser[float]): | ||
def parse(self, text: str) -> float: | ||
try: | ||
match = re.search(r"\[\[(.*?)\]\]", text.strip()) | ||
result = float(match.groups()[0]) | ||
if result < 0.0 or result > 1.0: | ||
raise ValueError("The given number is out of (0.0, 1.0) range.") | ||
return result | ||
except ValueError: | ||
raise OutputParserException( | ||
f"Couldn't convert {text} to float between 0 and 1." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
from textwrap import dedent | ||
from typing import Any, Literal, Tuple, Type | ||
|
||
import gymnasium as gym | ||
from langchain.pydantic_v1 import BaseModel, Field | ||
from langchain.tools import BaseTool | ||
|
||
|
||
class BaseFrozenLakeTool(BaseModel): | ||
"""Base tool for a FrozenLake environment. | ||
Environment is present as a field, but it won't be shown to models.""" | ||
|
||
env: gym.Env = Field(exclude=True) | ||
|
||
class Config(BaseTool.Config): | ||
pass | ||
|
||
|
||
class MoveInput(BaseModel): | ||
direction: Literal["left", "right", "down", "up"] = Field( | ||
description="Which direction to move." | ||
) | ||
|
||
|
||
class MoveTool(BaseFrozenLakeTool, BaseTool): | ||
name = "move" | ||
description = dedent( | ||
""" | ||
Moves one step in given direction. Returns the following: | ||
* observation: current position on the board; | ||
* reward: 1 when the goal is reached, 0 otherwise; | ||
* terminated: if True, the game has ended: there's no opportunity to move anymore (either the goal was found or the player has fallen into a hole); | ||
* truncated: if True, the time limit has been exceeded; | ||
* info: probability of moving in the wrong direction for the current cell (ice is slippery!)""" | ||
) | ||
args_schema: Type[BaseModel] = MoveInput | ||
|
||
@staticmethod | ||
def _convert_frozenlake_observation_to_position( | ||
observation: int, nrow: int | ||
) -> Tuple[int, int]: | ||
# FrozenLake: observation = current_row * nrow + current_col | ||
current_row, current_col = observation // nrow, observation % nrow | ||
return (current_row, current_col) | ||
|
||
@staticmethod | ||
def _convert_direction_to_frozenlake(direction: str) -> int: | ||
match direction: | ||
case "left": | ||
return 0 | ||
case "down": | ||
return 1 | ||
case "right": | ||
return 2 | ||
case "up": | ||
return 3 | ||
case _: | ||
raise ValueError(f"Wrong tool input {direction}.") | ||
|
||
def _run( | ||
self, | ||
direction: str, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Any: | ||
_observation, reward, terminated, truncated, info = self.env.step( | ||
MoveTool._convert_direction_to_frozenlake(direction) | ||
) | ||
nrow = self.env.get_wrapper_attr("nrow") | ||
observation = MoveTool._convert_frozenlake_observation_to_position( | ||
observation=_observation, nrow=nrow | ||
) | ||
return observation, reward, terminated, truncated, info | ||
|
||
|
||
class CheckMapInput(BaseModel): ... | ||
|
||
|
||
class CheckMapTool(BaseFrozenLakeTool, BaseTool): | ||
name = "check_map" | ||
description = """Peeks at current map without changing its state. | ||
The map is an n x n grid where different types of cells are denoted by different letters: | ||
* S - start cell | ||
* G - goal cell | ||
* F - frozen cell | ||
* H - hole cell | ||
Example for 2 x 2 case: | ||
SH | ||
FG | ||
""" | ||
args_schema: Type[BaseModel] = CheckMapInput | ||
|
||
def _run( | ||
self, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Any: | ||
observation, reward, terminated, truncated, info = ( | ||
"\n".join( | ||
"".join(x.decode() for x in y) | ||
for y in self.env.get_wrapper_attr("desc") | ||
), | ||
0, | ||
False, | ||
False, | ||
{}, | ||
) | ||
return observation, reward, terminated, truncated, info | ||
|
||
|
||
class CheckPositionInput(BaseModel): ... | ||
|
||
|
||
class CheckPositionTool(BaseFrozenLakeTool, BaseTool): | ||
name = "check_position" | ||
description = """Peeks at current position map without changing its state.""" | ||
args_schema: Type[BaseModel] = CheckMapInput | ||
|
||
def _run( | ||
self, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Any: | ||
observation, reward, terminated, truncated, info = ( | ||
MoveTool._convert_frozenlake_observation_to_position( | ||
self.env.get_wrapper_attr("s"), nrow=self.env.get_wrapper_attr("nrow") | ||
), | ||
0, | ||
False, | ||
False, | ||
{}, | ||
) | ||
return observation, reward, terminated, truncated, info |
Oops, something went wrong.