Skip to content

Commit

Permalink
Add ALFWorld example for Reflexion
Browse files Browse the repository at this point in the history
  • Loading branch information
saridormi committed Mar 26, 2024
1 parent d14b40f commit 869c7b4
Show file tree
Hide file tree
Showing 14 changed files with 2,920 additions and 2 deletions.
Empty file.
Empty file.
55 changes: 55 additions & 0 deletions environments/alfworld/common/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import gymnasium as gym
from langchain_core.agents import AgentAction
from langchain_core.tools import BaseTool
import alfworld.agents.environment as environment # type: ignore[import-untyped]
import yaml # type: ignore[import-untyped]
from typing import Dict, Any, Tuple, Optional, Sequence
from gymnasium.core import SupportsFloat
from .tools import get_alfworld_tools
from planning_library.action_executors import DefaultActionExecutor
from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped]


class ALFWorldEnv(gym.Env[str, AgentAction]):
    """Gymnasium wrapper around a single (batch size 1) ALFWorld text environment.

    Actions are LangChain ``AgentAction`` objects that are dispatched to the
    ALFWorld tools through a ``DefaultActionExecutor``; observations are the
    textual feedback produced by the game.
    """

    def __init__(
        self,
        config_path: str,
    ):
        """Build the underlying ALFWorld environment from a YAML config.

        Args:
            config_path: Path to the ALFWorld YAML configuration file. Its
                ``env.type`` entry names the environment class looked up in
                ``alfworld.agents.environment``.
        """
        with open(config_path) as reader:
            config = yaml.safe_load(reader)

        env = getattr(environment, config["env"]["type"])(config, train_eval="train")
        self.env: TextworldBatchGymEnv = env.init_env(batch_size=1)
        self._action_executor = DefaultActionExecutor(
            tools=get_alfworld_tools(env=self.env)
        )

    @property
    def tools(self) -> Sequence[BaseTool]:
        """Tools exposed by the action executor (one per ALFWorld command)."""
        return self._action_executor.tools

    def seed(self, seed: Optional[int] = None):
        """Forward ``seed`` to the wrapped ALFWorld environment."""
        self.env.seed(seed)

    def step(
        self, action: AgentAction
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Execute one agent action through the matching tool.

        Args:
            action: The tool invocation chosen by the agent.

        Returns:
            The ``(observation, reward, terminated, truncated, info)`` tuple
            produced by the tool.
        """
        result = self._action_executor.execute(action)
        observation, reward, terminated, truncated, info = result.observation
        # Pass the tool's info through instead of discarding it, so callers
        # (including the trajectory replay in ``reset``) see the real values.
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        *,
        seed: int | None = None,
        options: Dict[str, Any] | None = None,
    ) -> Tuple[str, Dict[str, Any]]:
        """Start a new episode, optionally replaying a list of actions.

        Args:
            seed: When given, used to seed both this wrapper and the wrapped
                ALFWorld environment before resetting.
            options: Optional settings. If it contains a ``"trajectory"`` key,
                its value is iterated and every ``AgentAction`` in it is
                executed in order right after the reset.

        Returns:
            The latest observation and info dictionary: those of the reset
            itself, or of the last replayed action when a non-empty
            trajectory is supplied.
        """
        super().reset(seed=seed)
        if seed is not None:
            # Honor the gymnasium contract: a seed passed to reset() must
            # actually reach the environment that owns the randomness.
            self.seed(seed)

        obs, infos = self.env.reset()
        # The wrapped environment is batched; keep the single entry.
        observation = obs[0]
        info = {key: infos[key][0] for key in infos}

        if options is not None and "trajectory" in options:
            for action in options["trajectory"]:
                assert isinstance(action, AgentAction)
                observation, _, _, _, info = self.step(action)
        return observation, info
20 changes: 20 additions & 0 deletions environments/alfworld/common/evaluate_output_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import re

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser


class ALFWorldEvaluateOutputParser(BaseOutputParser[float]):
    """Extract a score in the inclusive range [0, 1] written as ``[[number]]``."""

    def parse(self, text: str) -> float:
        """Parse the first ``[[...]]`` group of ``text`` into a float score.

        Args:
            text: Raw model output expected to contain ``[[number]]``.

        Returns:
            The parsed number, guaranteed to satisfy ``0.0 <= value <= 1.0``.

        Raises:
            OutputParserException: If the pattern is missing, the captured
                text is not a number, or the number (including NaN and
                infinities) lies outside the allowed range.
        """
        try:
            match = re.search(r"\[\[(.*?)\]\]", text.strip())
            if not match:
                raise ValueError("Pattern [[number]] not found.")
            result = float(match.group(1))
            # Positive range check: NaN fails every comparison, so the
            # previous "< 0 or > 1" form let "[[nan]]" through as valid.
            if not 0.0 <= result <= 1.0:
                raise ValueError("The given number is out of [0.0, 1.0] range.")
            return result
        except ValueError as e:
            raise OutputParserException(
                f"Couldn't convert {text} to float between 0 and 1."
            ) from e
244 changes: 244 additions & 0 deletions environments/alfworld/common/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
from langchain.pydantic_v1 import BaseModel
from langchain.tools import BaseTool
from typing import Type, Any, Tuple, Dict, List

from gymnasium.core import SupportsFloat
from .tools_utils import (
BaseALFWorldTool,
ReceptableInput,
ObjectOrReceptableInput,
ObjectAndReceptableInput,
EmptyInput,
)
from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped]


def get_alfworld_tools(env: TextworldBatchGymEnv) -> List[BaseTool]:
    """Create one instance of every ALFWorld tool, all bound to ``env``.

    Args:
        env: The batched TextWorld environment the tools send commands to.

    Returns:
        The tools in a fixed order: goto, open, close, take, put, toggle,
        heat, cool, clean, examine, inventory, look.
    """
    tool_classes = (
        GoToTool,
        OpenTool,
        CloseTool,
        TakeTool,
        PutTool,
        ToggleTool,
        HeatTool,
        CoolTool,
        CleanTool,
        ExamineTool,
        InventoryTool,
        LookTool,
    )
    return [tool_cls(env=env) for tool_cls in tool_classes]  # type: ignore[call-arg]


class GoToTool(BaseALFWorldTool, BaseTool):
    # Tool that moves the agent to a receptable via the "go to" command.
    name = "goto"
    description = """Go to a specified receptable (static object)."""
    args_schema: Type[BaseModel] = ReceptableInput

    def _run(
        self,
        receptable_type: str,
        receptable_id: int,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "go to <type> <id>" and return a gymnasium-style step tuple.

        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        command = f"go to {receptable_type} {receptable_id}"
        batch_obs, batch_scores, batch_dones, batch_infos = self.env.step([command])
        observation = batch_obs[0]
        score = batch_scores[0]
        done = batch_dones[0]
        info = {key: batch_infos[key][0] for key in batch_infos}
        return observation, score, done, False, info


class OpenTool(BaseALFWorldTool, BaseTool):
    # Tool that opens a receptable via the "open" command.
    name = "open"
    description = """Open a specified receptable (static object). Only available when you're already near a receptable."""
    args_schema: Type[BaseModel] = ReceptableInput

    def _run(
        self,
        receptable_type: str,
        receptable_id: int,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "open <type> <id>" and return a gymnasium-style step tuple.

        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        command = f"open {receptable_type} {receptable_id}"
        batch_obs, batch_scores, batch_dones, batch_infos = self.env.step([command])
        observation = batch_obs[0]
        score = batch_scores[0]
        done = batch_dones[0]
        info = {key: batch_infos[key][0] for key in batch_infos}
        return observation, score, done, False, info


class CloseTool(BaseALFWorldTool, BaseTool):
    # Tool that closes a receptable via the "close" command.
    name = "close"
    description = """Close a specified receptable (static object). Only available when you're already near a receptable."""
    args_schema: Type[BaseModel] = ReceptableInput

    def _run(
        self,
        receptable_type: str,
        receptable_id: int,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "close <type> <id>" and return a gymnasium-style step tuple.

        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        command = f"close {receptable_type} {receptable_id}"
        batch_obs, batch_scores, batch_dones, batch_infos = self.env.step([command])
        observation = batch_obs[0]
        score = batch_scores[0]
        done = batch_dones[0]
        info = {key: batch_infos[key][0] for key in batch_infos}
        return observation, score, done, False, info


class TakeTool(BaseALFWorldTool, BaseTool):
    # Tool that picks an object up from a receptable via the "take" command.
    name = "take"
    description = """Pick up a specified portable object from a specified receptable (static object). Only available when you're already near a receptable."""
    args_schema: Type[BaseModel] = ObjectAndReceptableInput

    def _run(
        self,
        object_type: str,
        object_id: int,
        receptable_type: str,
        receptable_id: int,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "take <object> from <receptable>" and return a step tuple.

        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        command = (
            f"take {object_type} {object_id} from {receptable_type} {receptable_id}"
        )
        batch_obs, batch_scores, batch_dones, batch_infos = self.env.step([command])
        observation = batch_obs[0]
        score = batch_scores[0]
        done = batch_dones[0]
        info = {key: batch_infos[key][0] for key in batch_infos}
        return observation, score, done, False, info


class PutTool(BaseALFWorldTool, BaseTool):
    # Tool that places a carried object in/on a receptable via the "put" command.
    name = "put"
    # The description is shown to the language model; it previously contained
    # the Cyrillic typo "in/щт" instead of "in/on".
    description = """Put a specified portable object in/on a specified receptable (static object). Only available when you're already near a receptable and carry a portable object in your inventory."""
    args_schema: Type[BaseModel] = ObjectAndReceptableInput

    def _run(
        self,
        object_type: str,
        object_id: int,
        receptable_type: str,
        receptable_id: int,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "put <object> in/on <receptable>" and return a step tuple.

        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        obs, scores, dones, infos = self.env.step(
            [f"put {object_type} {object_id} in/on {receptable_type} {receptable_id}"]
        )
        return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos}


class ToggleTool(BaseALFWorldTool, BaseTool):
    # Tool that switches an object or receptable on/off via the "toggle" command.
    # NOTE(review): ALFWorld prompts commonly phrase this action as
    # "use <object>" - confirm the engine accepts the "toggle" verb.
    name = "toggle"
    description = """Toggle a specified object on/off (can be either a portable object or a static receptable). Only available when you're already near a receptable/a portable object or carry a portable object."""
    args_schema: Type[BaseModel] = ObjectOrReceptableInput

    def _run(
        self,
        type: str,
        id: int,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "toggle <type> <id>" and return a gymnasium-style step tuple.

        The parameter names mirror the fields of ``ObjectOrReceptableInput``.
        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        command = f"toggle {type} {id}"
        batch_obs, batch_scores, batch_dones, batch_infos = self.env.step([command])
        observation = batch_obs[0]
        score = batch_scores[0]
        done = batch_dones[0]
        info = {key: batch_infos[key][0] for key in batch_infos}
        return observation, score, done, False, info


class HeatTool(BaseALFWorldTool, BaseTool):
    # Tool that heats a carried object with a receptable via the "heat" command.
    name = "heat"
    description = """Heat a portable object via a receptable (static object). Only available when you're already near a receptable and carry a portable object."""
    args_schema: Type[BaseModel] = ObjectAndReceptableInput

    def _run(
        self,
        object_type: str,
        object_id: int,
        receptable_type: str,
        receptable_id: int,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "heat <object> with <receptable>" and return a step tuple.

        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        command = (
            f"heat {object_type} {object_id} with {receptable_type} {receptable_id}"
        )
        batch_obs, batch_scores, batch_dones, batch_infos = self.env.step([command])
        observation = batch_obs[0]
        score = batch_scores[0]
        done = batch_dones[0]
        info = {key: batch_infos[key][0] for key in batch_infos}
        return observation, score, done, False, info


class CoolTool(BaseALFWorldTool, BaseTool):
    # Tool that cools a carried object with a receptable via the "cool" command.
    name = "cool"
    description = """Cool a portable object via a receptable (static object). Only available when you're already near a receptable and carry a portable object."""
    args_schema: Type[BaseModel] = ObjectAndReceptableInput

    def _run(
        self,
        object_type: str,
        object_id: int,
        receptable_type: str,
        receptable_id: int,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "cool <object> with <receptable>" and return a step tuple.

        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        command = (
            f"cool {object_type} {object_id} with {receptable_type} {receptable_id}"
        )
        batch_obs, batch_scores, batch_dones, batch_infos = self.env.step([command])
        observation = batch_obs[0]
        score = batch_scores[0]
        done = batch_dones[0]
        info = {key: batch_infos[key][0] for key in batch_infos}
        return observation, score, done, False, info


class CleanTool(BaseALFWorldTool, BaseTool):
    # Tool that cleans an object with a receptable via the "clean" command.
    name = "clean"
    description = """Clean a portable object via a receptable (static object). Only available when you're already near a receptable and a portable object or carry a portable object."""
    args_schema: Type[BaseModel] = ObjectAndReceptableInput

    def _run(
        self,
        object_type: str,
        object_id: int,
        receptable_type: str,
        receptable_id: int,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "clean <object> with <receptable>" and return a step tuple.

        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        command = (
            f"clean {object_type} {object_id} with {receptable_type} {receptable_id}"
        )
        batch_obs, batch_scores, batch_dones, batch_infos = self.env.step([command])
        observation = batch_obs[0]
        score = batch_scores[0]
        done = batch_dones[0]
        info = {key: batch_infos[key][0] for key in batch_infos}
        return observation, score, done, False, info


class ExamineTool(BaseALFWorldTool, BaseTool):
    # Tool that inspects an object or receptable via the "examine" command.
    name = "examine"
    description = """Examine a specified object (can be either a portable object or a static receptable). Only available when you're already near a receptable/a portable object or carry a portable object."""
    args_schema: Type[BaseModel] = ObjectOrReceptableInput

    def _run(
        self,
        type: str,
        id: int,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "examine <type> <id>" and return a gymnasium-style step tuple.

        The parameter names mirror the fields of ``ObjectOrReceptableInput``.
        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        command = f"examine {type} {id}"
        batch_obs, batch_scores, batch_dones, batch_infos = self.env.step([command])
        observation = batch_obs[0]
        score = batch_scores[0]
        done = batch_dones[0]
        info = {key: batch_infos[key][0] for key in batch_infos}
        return observation, score, done, False, info


class InventoryTool(BaseALFWorldTool, BaseTool):
    # Tool that lists what the agent carries via the "inventory" command.
    name = "inventory"
    description = """Check if you are carrying any portable objects."""
    args_schema: Type[BaseModel] = EmptyInput

    def _run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "inventory" and return a gymnasium-style step tuple.

        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        command = "inventory"
        batch_obs, batch_scores, batch_dones, batch_infos = self.env.step([command])
        observation = batch_obs[0]
        score = batch_scores[0]
        done = batch_dones[0]
        info = {key: batch_infos[key][0] for key in batch_infos}
        return observation, score, done, False, info


class LookTool(BaseALFWorldTool, BaseTool):
    # Tool that describes the agent's surroundings via the "look" command.
    name = "look"
    description = """Check your surroundings."""
    args_schema: Type[BaseModel] = EmptyInput

    def _run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Send "look" and return a gymnasium-style step tuple.

        The command is sent as a one-element list, so element 0 of every
        returned collection is used. ``truncated`` is always ``False``.
        """
        command = "look"
        batch_obs, batch_scores, batch_dones, batch_infos = self.env.step([command])
        observation = batch_obs[0]
        score = batch_scores[0]
        done = batch_dones[0]
        info = {key: batch_infos[key][0] for key in batch_infos}
        return observation, score, done, False, info
72 changes: 72 additions & 0 deletions environments/alfworld/common/tools_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool
from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped]


class EmptyInput(BaseModel):
    # Argument schema for tools that take no arguments.
    pass


class ObjectInput(BaseModel):
    # Argument schema identifying one portable object by type and number.
    # Documented with comments rather than a docstring so the generated
    # pydantic schema stays exactly as it was.
    object_type: str = Field(
        description="A type of the portable object.",
        examples=["the apple", "the mug"],
    )
    object_id: int = Field(
        description=(
            "A specific number associated with the object (e.g., when there are "
            "several mugs in the room, those would be mug 1 and mug 2)."
        ),
        examples=[1, 2, 3],
    )


class ReceptableInput(BaseModel):
    # Argument schema identifying one receptable by type and number.
    # Documented with comments rather than a docstring so the generated
    # pydantic schema stays exactly as it was.
    receptable_type: str = Field(
        description="A type of the receptable.",
        examples=["the coffee table", "the drawer", "the countertop"],
    )
    receptable_id: int = Field(
        description=(
            "A specific number associated with the receptable (e.g., when there are "
            "several drawers in the room, those would be drawer 1 and drawer 2)."
        ),
        examples=[1, 2, 3],
    )


class ObjectAndReceptableInput(BaseModel):
    # Argument schema naming both a portable object and a receptable.
    # Field order (object first, receptable second) is kept as-is, and
    # comments are used instead of a docstring so the generated pydantic
    # schema stays exactly as it was.
    object_type: str = Field(
        description="A type of the portable object.",
        examples=["the apple", "the mug"],
    )
    object_id: int = Field(
        description=(
            "A specific number associated with the object (e.g., when there are "
            "several mugs in the room, those would be mug 1 and mug 2)."
        ),
        examples=[1, 2, 3],
    )
    receptable_type: str = Field(
        description="A type of the receptable.",
        examples=["the coffee table", "the drawer", "the countertop"],
    )
    receptable_id: int = Field(
        description=(
            "A specific number associated with the receptable (e.g., when there are "
            "several drawers in the room, those would be drawer 1 and drawer 2)."
        ),
        examples=[1, 2, 3],
    )


class ObjectOrReceptableInput(BaseModel):
    # Argument schema for tools that accept either a portable object or a
    # static receptable. The field names `type` and `id` are part of the
    # schema, so they are kept even though they shadow builtins.
    type: str = Field(
        description=(
            "A type of the object (might be either a portable object or a static one)."
        ),
        examples=["the apple", "the coffee table"],
    )
    id: int = Field(
        description=(
            "A specific number associated with the object (e.g., when there are "
            "several mugs in the room, those would be mug 1 and mug 2)."
        ),
        examples=[1, 2, 3],
    )


class BaseALFWorldTool(BaseModel):
    """Base tool for an ALFWorld environment.
    Environment is present as a field, but it won't be shown to models."""

    # Environment the concrete tools send their text commands to. It is
    # declared with exclude=True.
    env: TextworldBatchGymEnv = Field(exclude=True)

    class Config(BaseTool.Config):
        # Reuse BaseTool's pydantic configuration without any overrides.
        ...
Loading

0 comments on commit 869c7b4

Please sign in to comment.