Skip to content

Commit

Permalink
Add example for FrozenLake and switch to ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
saridormi committed Mar 15, 2024
1 parent 103f155 commit db002a0
Show file tree
Hide file tree
Showing 41 changed files with 1,572 additions and 208 deletions.
8 changes: 2 additions & 6 deletions .github/workflows/workflow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,9 @@ jobs:
run: |
poetry install --no-interaction
- name: Lint with Black
- name: Lint with ruff
run: |
poetry run black . --check
- name: Check import styling with isort
run: |
poetry run isort . --check
poetry run ruff check
- name: Check types with mypy
run: |
Expand Down
Empty file.
Empty file.
54 changes: 54 additions & 0 deletions environments/frozen_lake/common/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any, Dict, Optional, Tuple

import gymnasium as gym
from gymnasium.core import ObsType, SupportsFloat
from gymnasium.envs.toy_text.frozen_lake import FrozenLakeEnv
from langchain_core.agents import AgentAction
from langchain_core.callbacks import CallbackManagerForChainRun

from environments.frozen_lake.common.tools import (
MoveTool,
)
from planning_library.utils import get_tools_maps, perform_agent_action


class FrozenLakeEnvWrapper(gym.Wrapper):
    """Gymnasium wrapper that exposes a FrozenLake environment through tools.

    Instead of raw integer actions, ``step`` accepts a LangChain
    ``AgentAction`` (plus callback/logging context) and dispatches it to the
    registered tools via ``perform_agent_action``.
    """

    def __init__(self, env: FrozenLakeEnv):
        """Wrap ``env`` and register the tools that operate on it.

        Args:
            env: The FrozenLake environment to wrap.
        """
        super().__init__(env)
        # Only the move tool is registered at the moment; CheckPositionTool
        # and CheckMapTool exist but are currently not exposed here.
        self.name_to_tool_map, self.color_mapping = get_tools_maps(
            [
                MoveTool(env=self.env.unwrapped),
            ]
        )

    def step(
        self,
        cur_input: Tuple[
            AgentAction, Optional[CallbackManagerForChainRun], Dict[str, Any]
        ],
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Execute one agent action through the registered tools.

        Args:
            cur_input: A 3-tuple of the action to perform, an optional run
                manager for callbacks, and keyword arguments forwarded as
                ``tool_run_kwargs``.

        Returns:
            The ``observation`` attribute of the ``perform_agent_action``
            result. For ``MoveTool`` this is presumably the
            ``(observation, reward, terminated, truncated, info)`` tuple the
            tool returns -- confirm against ``perform_agent_action``.
        """
        action, run_manager, tool_run_logging_kwargs = cur_input
        result = perform_agent_action(
            agent_action=action,
            name_to_tool_map=self.name_to_tool_map,
            color_mapping=self.color_mapping,
            run_manager=run_manager,
            tool_run_kwargs=tool_run_logging_kwargs,
        )
        return result.observation

    def reset(
        self,
        *,
        seed: int | None = None,
        options: Dict[str, Any] | None = None,
    ) -> Tuple[ObsType, Dict[str, Any]]:
        """Reset the environment and optionally replay a trajectory.

        Args:
            seed: Seed forwarded to the wrapped environment's ``reset``.
            options: Options forwarded to the wrapped environment. If it
                contains a ``"trajectory"`` key, its value is iterated and
                every ``AgentAction`` in it is replayed after the reset.

        Returns:
            ``(observation, info)`` -- from the reset itself, or from the last
            replayed action when a trajectory was given.

        Raises:
            TypeError: If a trajectory element is not an ``AgentAction``.
        """
        observation, info = self.env.reset(seed=seed, options=options)

        if options is not None and "trajectory" in options:
            for action in options["trajectory"]:
                # Explicit check instead of ``assert`` so that validation
                # still happens when Python runs with -O.
                if not isinstance(action, AgentAction):
                    raise TypeError(
                        "Trajectory elements must be AgentAction instances, "
                        f"got {type(action).__name__}."
                    )
                # ``step`` expects (action, run_manager, tool_run_kwargs);
                # replay runs without callbacks or extra logging kwargs.
                # NOTE(review): assumes the tool returned the usual 5-tuple.
                observation, _reward, _terminated, _truncated, info = self.step(
                    (action, None, {})
                )
        return observation, info
18 changes: 18 additions & 0 deletions environments/frozen_lake/common/evaluate_output_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import re

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser


class FrozenMapEvaluateOutputParser(BaseOutputParser[float]):
    """Parse a score in ``[0.0, 1.0]`` written as ``[[<number>]]`` in the text."""

    def parse(self, text: str) -> float:
        """Extract the first ``[[...]]`` group from ``text`` and return it as a float.

        Args:
            text: Raw model output expected to contain ``[[<number>]]``.

        Returns:
            The parsed number, guaranteed to lie between 0.0 and 1.0 inclusive.

        Raises:
            OutputParserException: If no ``[[...]]`` group is present, its
                content is not a valid float, or the value is outside
                ``[0.0, 1.0]`` (NaN included).
        """
        match = re.search(r"\[\[(.*?)\]\]", text.strip())
        # A missing marker used to surface as AttributeError on ``None``;
        # report it as a parsing failure like every other malformed output.
        if match is None:
            raise OutputParserException(
                f"Couldn't convert {text} to float between 0 and 1."
            )

        try:
            result = float(match.group(1))
        except ValueError as err:
            raise OutputParserException(
                f"Couldn't convert {text} to float between 0 and 1."
            ) from err

        # Written as a negated chained comparison so that NaN is rejected too.
        if not 0.0 <= result <= 1.0:
            raise OutputParserException(
                f"Couldn't convert {text} to float between 0 and 1."
            )
        return result
137 changes: 137 additions & 0 deletions environments/frozen_lake/common/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from textwrap import dedent
from typing import Any, Literal, Tuple, Type

import gymnasium as gym
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool


class BaseFrozenLakeTool(BaseModel):
    """Common base for tools operating on a FrozenLake environment.

    The environment is stored as a field, but it is excluded from the
    exported representation, so it won't be shown to models.
    """

    # ``exclude=True`` keeps the environment out of the model's exported data.
    env: gym.Env = Field(exclude=True)

    class Config(BaseTool.Config):
        # Reuses BaseTool's configuration unchanged.
        # NOTE(review): presumably needed so the non-pydantic ``gym.Env``
        # type is accepted as a field -- confirm against BaseTool.Config.
        ...


class MoveInput(BaseModel):
    # Argument schema of the ``move`` tool: a single direction, restricted to
    # the four literal values below. Documented with a comment rather than a
    # docstring so the generated schema is not affected.
    direction: Literal["left", "right", "down", "up"] = Field(
        description="Which direction to move.",
    )


class MoveTool(BaseFrozenLakeTool, BaseTool):
    # Tool that moves the player one cell and reports the resulting
    # (row, col) position together with the usual step outputs.
    name = "move"
    description = dedent(
        """
        Moves one step in given direction. Returns the following:
        * observation: current position on the board;
        * reward: 1 when the goal is reached, 0 otherwise;
        * terminated: if True, the game has ended: there's no opportunity to move anymore (either the goal was found or the player has fallen into a hole);
        * truncated: if True, the time limit has been exceeded;
        * info: probability of moving in the wrong direction for the current cell (ice is slippery!)"""
    )
    args_schema: Type[BaseModel] = MoveInput

    @staticmethod
    def _convert_frozenlake_observation_to_position(
        observation: int, nrow: int, ncol: int | None = None
    ) -> Tuple[int, int]:
        """Decode a flat FrozenLake state index into ``(row, col)``.

        FrozenLake encodes the state as ``current_row * ncol + current_col``,
        so the divisor must be the number of columns.

        Args:
            observation: Flat state index returned by the environment.
            nrow: Number of rows of the map. Used as the divisor only when
                ``ncol`` is omitted, which is correct for square maps only
                (kept for backward compatibility).
            ncol: Number of columns of the map; pass it for non-square maps.

        Returns:
            The ``(row, col)`` position, both zero-based.
        """
        width = nrow if ncol is None else ncol
        current_row, current_col = divmod(observation, width)
        return (current_row, current_col)

    @staticmethod
    def _convert_direction_to_frozenlake(direction: str) -> int:
        """Translate a direction name into the FrozenLake integer action.

        Raises:
            ValueError: If ``direction`` is not one of the four known names.
        """
        # Kept local to the function so no extra attribute is added to the
        # pydantic model class.
        actions = {"left": 0, "down": 1, "right": 2, "up": 3}
        try:
            return actions[direction]
        except (KeyError, TypeError):
            # TypeError covers unhashable inputs, which the previous
            # ``match`` statement also reported as ValueError.
            raise ValueError(f"Wrong tool input {direction}.") from None

    def _run(
        self,
        direction: str,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Move one step and return the step outputs with a decoded position.

        Args:
            direction: One of ``"left"``, ``"right"``, ``"down"``, ``"up"``.

        Returns:
            ``(position, reward, terminated, truncated, info)`` where
            ``position`` is the ``(row, col)`` tuple decoded from the
            environment's observation.
        """
        raw_observation, reward, terminated, truncated, info = self.env.step(
            MoveTool._convert_direction_to_frozenlake(direction)
        )
        position = MoveTool._convert_frozenlake_observation_to_position(
            observation=raw_observation,
            nrow=self.env.get_wrapper_attr("nrow"),
            ncol=self.env.get_wrapper_attr("ncol"),
        )
        return position, reward, terminated, truncated, info


# Empty argument schema: the ``check_map`` tool takes no inputs.
class CheckMapInput(BaseModel):
    pass


class CheckMapTool(BaseFrozenLakeTool, BaseTool):
    # Read-only tool returning the map as text. The continuation lines of
    # ``description`` start at column 0 because they are part of the literal.
    name = "check_map"
    description = """Peeks at current map without changing its state.
The map is an n x n grid where different types of cells are denoted by different letters:
* S - start cell
* G - goal cell
* F - frozen cell
* H - hole cell
Example for 2 x 2 case:
SH
FG
"""
    args_schema: Type[BaseModel] = CheckMapInput

    def _run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Render the environment's ``desc`` grid as newline-separated rows.

        Returns:
            ``(map_text, 0, False, False, {})`` -- the same five-slot shape
            as a step result, with zero reward, both flags ``False`` and an
            empty info dict.
        """
        grid = self.env.get_wrapper_attr("desc")
        # Each cell is a bytes value; decode and concatenate row by row.
        rendered_rows = ["".join(cell.decode() for cell in row) for row in grid]
        return "\n".join(rendered_rows), 0, False, False, {}


# Empty argument schema: the ``check_position`` tool takes no inputs.
class CheckPositionInput(BaseModel):
    pass


class CheckPositionTool(BaseFrozenLakeTool, BaseTool):
    # Read-only tool reporting the player's current (row, col) position.
    name = "check_position"
    description = """Peeks at current position map without changing its state."""
    # Uses this tool's own (empty) schema rather than CheckMapInput.
    args_schema: Type[BaseModel] = CheckPositionInput

    def _run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Return the current position without stepping the environment.

        Returns:
            ``(position, 0, False, False, {})`` where ``position`` is the
            ``(row, col)`` tuple decoded from the environment's ``s``
            attribute; the remaining slots mimic a step result with zero
            reward, both flags ``False`` and an empty info dict.
        """
        state = self.env.get_wrapper_attr("s")
        # FrozenLake encodes the state as row * ncol + col, so decode with
        # the column count (this also holds for non-square maps).
        ncol = self.env.get_wrapper_attr("ncol")
        current_row, current_col = divmod(state, ncol)
        return (current_row, current_col), 0, False, False, {}
Loading

0 comments on commit db002a0

Please sign in to comment.