Skip to content

Commit

Permalink
Fix formatting and typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
saridormi committed Jun 19, 2024
1 parent 260100c commit 0776031
Show file tree
Hide file tree
Showing 56 changed files with 629 additions and 875 deletions.
10 changes: 7 additions & 3 deletions .github/workflows/workflow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ jobs:
- name: Lint with ruff
run: |
poetry run ruff check
poetry run ruff check --config pyproject.toml
- name: Check formatting with ruff
run: |
poetry run ruff format --check
poetry run ruff format --check --config pyproject.toml
- name: Check types with mypy
run: |
poetry run mypy .
poetry run mypy . --config-file pyproject.toml
- name: Check types with pyright
run: |
poetry run pyright
134 changes: 74 additions & 60 deletions environments/alfworld/adapt.ipynb

Large diffs are not rendered by default.

34 changes: 15 additions & 19 deletions environments/alfworld/common/environment.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from __future__ import annotations

from typing import Any, Dict, Optional, Sequence, SupportsFloat, Tuple

import gymnasium as gym
import yaml # type: ignore[import-untyped]
from langchain_core.agents import AgentAction
from langchain_core.callbacks import CallbackManager
from langchain_core.tools import BaseTool
import alfworld.agents.environment as environment # type: ignore[import-untyped]
import yaml # type: ignore[import-untyped]
from typing import Dict, Any, Tuple, Optional, Sequence
from gymnasium.core import SupportsFloat
from .tools import get_alfworld_tools
from planning_library.action_executors import LangchainActionExecutor
from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped]
from langchain_core.callbacks import CallbackManager

import alfworld.agents.environment as environment # type: ignore[import-untyped]
from alfworld.agents.environment.alfred_tw_env import AlfredTWEnv # type: ignore[import-untyped]
from planning_library.action_executors import LangchainActionExecutor

from .tools import get_alfworld_tools


class ALFWorldEnv(gym.Env[str, Tuple[AgentAction, Optional[CallbackManager]]]):
Expand All @@ -21,13 +23,9 @@ def __init__(
):
with open(config_path) as reader:
config = yaml.safe_load(reader)
self._alfworld_env: AlfredTWEnv = getattr(environment, config["env"]["type"])(
config, train_eval="train"
)
self._alfworld_env: AlfredTWEnv = getattr(environment, config["env"]["type"])(config, train_eval="train")
self.env: TextworldBatchGymEnv = self._alfworld_env.init_env(batch_size=1)
self._action_executor = LangchainActionExecutor(
tools=get_alfworld_tools(env=self.env)
)
self._action_executor = LangchainActionExecutor(tools=get_alfworld_tools(env=self.env))

@property
def tools(self) -> Sequence[BaseTool]:
Expand All @@ -37,10 +35,10 @@ def seed(self, seed: Optional[int] = None):
self.env.seed(seed)

def step(
self, inputs: Tuple[AgentAction, Optional[CallbackManager]]
self, action: Tuple[AgentAction, Optional[CallbackManager]]
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
action, run_manager = inputs
result = self._action_executor.execute(action, run_manager=run_manager)
lc_action, run_manager = action
result = self._action_executor.execute(lc_action, run_manager=run_manager)
try:
observation, reward, terminated, truncated, info = result.observation
except ValueError:
Expand All @@ -59,9 +57,7 @@ def reset(
) -> Tuple[str, Dict[str, Any]]:
if not options or "next_episode" not in options or not options["next_episode"]:
self.env = self._alfworld_env.init_env(batch_size=1)
self._action_executor = LangchainActionExecutor(
tools=get_alfworld_tools(env=self.env)
)
self._action_executor = LangchainActionExecutor(tools=get_alfworld_tools(env=self.env))

obs, infos = self.env.reset()
observation = obs[0]
Expand Down
4 changes: 1 addition & 3 deletions environments/alfworld/common/evaluate_output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,4 @@ def parse(self, text: str) -> float:
raise ValueError("The given number is out of (0.0, 1.0) range.")
return result
except ValueError:
raise OutputParserException(
f"Couldn't convert {text} to float between 0 and 1."
)
raise OutputParserException(f"Couldn't convert {text} to float between 0 and 1.")
48 changes: 21 additions & 27 deletions environments/alfworld/common/tools.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import Any, Dict, List, SupportsFloat, Tuple, Type

from langchain.pydantic_v1 import BaseModel
from langchain.tools import BaseTool
from typing import Type, Any, Tuple, Dict, List
from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped]

from gymnasium.core import SupportsFloat
from .tools_utils import (
BaseALFWorldTool,
ReceptableInput,
ObjectOrReceptableInput,
ObjectAndReceptableInput,
EmptyInput,
ObjectAndReceptableInput,
ObjectOrReceptableInput,
ReceptableInput,
)
from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped]


def get_alfworld_tools(env: TextworldBatchGymEnv) -> List[BaseTool]:
Expand All @@ -33,7 +33,7 @@ def get_alfworld_tools(env: TextworldBatchGymEnv) -> List[BaseTool]:
class GoToTool(BaseALFWorldTool, BaseTool):
name = "goto"
description = """Go to the specified receptable (static object)."""
args_schema: Type[BaseModel] = ReceptableInput
args_schema: Type[BaseModel] = ReceptableInput # type: ignore

def _run(
self,
Expand All @@ -42,16 +42,14 @@ def _run(
*args: Any,
**kwargs: Any,
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
obs, scores, dones, infos = self.env.step(
[f"go to {receptable_type} {receptable_id}"]
)
obs, scores, dones, infos = self.env.step([f"go to {receptable_type} {receptable_id}"])
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos}


class OpenTool(BaseALFWorldTool, BaseTool):
name = "open"
description = """Open a specified receptable (static object). Only works when you're near a receptable and when it is closed."""
args_schema: Type[BaseModel] = ReceptableInput
args_schema: Type[BaseModel] = ReceptableInput # type: ignore

def _run(
self,
Expand All @@ -60,16 +58,14 @@ def _run(
*args: Any,
**kwargs: Any,
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
obs, scores, dones, infos = self.env.step(
[f"open {receptable_type} {receptable_id}"]
)
obs, scores, dones, infos = self.env.step([f"open {receptable_type} {receptable_id}"])
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos}


class CloseTool(BaseALFWorldTool, BaseTool):
name = "close"
description = """Close a specified receptable (static object). Only available when you're near a receptable and when it is closed."""
args_schema: Type[BaseModel] = ReceptableInput
args_schema: Type[BaseModel] = ReceptableInput # type: ignore

def _run(
self,
Expand All @@ -78,16 +74,14 @@ def _run(
*args: Any,
**kwargs: Any,
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
obs, scores, dones, infos = self.env.step(
[f"close {receptable_type} {receptable_id}"]
)
obs, scores, dones, infos = self.env.step([f"close {receptable_type} {receptable_id}"])
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos}


class TakeTool(BaseALFWorldTool, BaseTool):
name = "take"
description = """Pick up the specified portable object from the specified receptable (static object). Only works when you're near the specified receptable and the specified object is present in/on the receptable."""
args_schema: Type[BaseModel] = ObjectAndReceptableInput
args_schema: Type[BaseModel] = ObjectAndReceptableInput # type: ignore

def _run(
self,
Expand All @@ -107,7 +101,7 @@ def _run(
class PutTool(BaseALFWorldTool, BaseTool):
name = "put"
description = """Put the specified portable object in/on the specified receptable (static object). Only available when you're near the specified receptable and carry the specified portable object in your inventory."""
args_schema: Type[BaseModel] = ObjectAndReceptableInput
args_schema: Type[BaseModel] = ObjectAndReceptableInput # type: ignore

def _run(
self,
Expand All @@ -127,7 +121,7 @@ def _run(
class ToggleTool(BaseALFWorldTool, BaseTool):
name = "toggle"
description = """Toggle the specified object on/off (can be either a portable object or a static receptable). Only available when you're near the specified receptable/portable object or carry the specified portable object."""
args_schema: Type[BaseModel] = ObjectOrReceptableInput
args_schema: Type[BaseModel] = ObjectOrReceptableInput # type: ignore

def _run(
self,
Expand All @@ -143,7 +137,7 @@ def _run(
class HeatTool(BaseALFWorldTool, BaseTool):
name = "heat"
description = """Heat the portable object via the receptable (static object). Only available when you're already near the receptable and the portable object is in/on the receptable."""
args_schema: Type[BaseModel] = ObjectAndReceptableInput
args_schema: Type[BaseModel] = ObjectAndReceptableInput # type: ignore

def _run(
self,
Expand All @@ -163,7 +157,7 @@ def _run(
class CoolTool(BaseALFWorldTool, BaseTool):
name = "cool"
description = """Cool the portable object via the receptable (static object). Only available when you're already near a receptable and the portable object is in/on the receptable."""
args_schema: Type[BaseModel] = ObjectAndReceptableInput
args_schema: Type[BaseModel] = ObjectAndReceptableInput # type: ignore

def _run(
self,
Expand All @@ -183,7 +177,7 @@ def _run(
class CleanTool(BaseALFWorldTool, BaseTool):
name = "clean"
description = """Clean the portable object via the receptable (static object). Only available when you're already near a receptable and the portable object is in/on the receptable."""
args_schema: Type[BaseModel] = ObjectAndReceptableInput
args_schema: Type[BaseModel] = ObjectAndReceptableInput # type: ignore

def _run(
self,
Expand All @@ -203,7 +197,7 @@ def _run(
class ExamineTool(BaseALFWorldTool, BaseTool):
name = "examine"
description = """Examine the specified object (can be either a portable object or a static receptable). Only available when you're near the receptable/portable object or carry the specified portable object."""
args_schema: Type[BaseModel] = ObjectOrReceptableInput
args_schema: Type[BaseModel] = ObjectOrReceptableInput # type: ignore

def _run(
self,
Expand All @@ -219,7 +213,7 @@ def _run(
class InventoryTool(BaseALFWorldTool, BaseTool):
name = "inventory"
description = """Check if you are carrying any portable objects."""
args_schema: Type[BaseModel] = EmptyInput
args_schema: Type[BaseModel] = EmptyInput # type: ignore

def _run(
self,
Expand All @@ -233,7 +227,7 @@ def _run(
class LookTool(BaseALFWorldTool, BaseTool):
name = "look"
description = """Check your surroundings."""
args_schema: Type[BaseModel] = EmptyInput
args_schema: Type[BaseModel] = EmptyInput # type: ignore

def _run(
self,
Expand Down
8 changes: 2 additions & 6 deletions environments/alfworld/common/tools_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ class EmptyInput(BaseModel): ...


class ObjectInput(BaseModel):
object_type: str = Field(
description="A type of the portable object.", examples=["apple", "mug"]
)
object_type: str = Field(description="A type of the portable object.", examples=["apple", "mug"])
object_id: int = Field(
description="A specific number associated with the object (e.g., when there are "
"several mugs in the room, those would be mug 1 and mug 2).",
Expand All @@ -30,9 +28,7 @@ class ReceptableInput(BaseModel):


class ObjectAndReceptableInput(BaseModel):
object_type: str = Field(
description="A type of the portable object.", examples=["apple", "mug"]
)
object_type: str = Field(description="A type of the portable object.", examples=["apple", "mug"])
object_id: int = Field(
description="A specific number associated with the object (e.g., when there are "
"several mugs in the room, those would be mug 1 and mug 2).",
Expand Down
2 changes: 1 addition & 1 deletion environments/frozen_lake/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .tools import MoveTool, CheckMapTool, CheckPositionTool
from .environment import FrozenLakeEnvWrapper
from .evaluate_output_parser import FrozenMapEvaluateOutputParser
from .tools import CheckMapTool, CheckPositionTool, MoveTool

__all__ = [
"MoveTool",
Expand Down
17 changes: 9 additions & 8 deletions environments/frozen_lake/common/environment.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from __future__ import annotations
from typing import Any, Dict, Tuple, Sequence, Optional

from typing import Any, Dict, Optional, Sequence, SupportsFloat, Tuple

import gymnasium as gym
from gymnasium.core import ObsType, SupportsFloat
from gymnasium.envs.toy_text.frozen_lake import FrozenLakeEnv
from langchain_core.agents import AgentAction
from langchain_core.tools import BaseTool
from langchain_core.callbacks import CallbackManager
from langchain_core.tools import BaseTool

from .tools import MoveTool
from planning_library.action_executors import LangchainActionExecutor

from .tools import MoveTool


class FrozenLakeEnvWrapper(gym.Wrapper):
def __init__(self, env: FrozenLakeEnv):
Expand All @@ -22,18 +23,18 @@ def tools(self) -> Sequence[BaseTool]:
return self._action_executor.tools

def step(
self, inputs: Tuple[AgentAction, Optional[CallbackManager]]
self, action: Tuple[AgentAction, Optional[CallbackManager]]
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
action, run_manager = inputs
result = self._action_executor.execute(action)
lc_action, run_manager = action
result = self._action_executor.execute(lc_action, run_manager=run_manager)
return result.observation

def reset(
self,
*,
seed: int | None = None,
options: Dict[str, Any] | None = None,
) -> Tuple[ObsType, Dict[str, Any]]:
) -> Tuple[str, Dict[str, Any]]:
observation, info = self.env.reset(seed=seed, options=options)

if options is not None and "trajectory" in options:
Expand Down
4 changes: 1 addition & 3 deletions environments/frozen_lake/common/evaluate_output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,4 @@ def parse(self, text: str) -> float:
raise ValueError("The given number is out of (0.0, 1.0) range.")
return result
except ValueError:
raise OutputParserException(
f"Couldn't convert {text} to float between 0 and 1."
)
raise OutputParserException(f"Couldn't convert {text} to float between 0 and 1.")
Loading

0 comments on commit 0776031

Please sign in to comment.