-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
2,920 additions
and
2 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import gymnasium as gym | ||
from langchain_core.agents import AgentAction | ||
from langchain_core.tools import BaseTool | ||
import alfworld.agents.environment as environment # type: ignore[import-untyped] | ||
import yaml # type: ignore[import-untyped] | ||
from typing import Dict, Any, Tuple, Optional, Sequence | ||
from gymnasium.core import SupportsFloat | ||
from .tools import get_alfworld_tools | ||
from planning_library.action_executors import DefaultActionExecutor | ||
from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped] | ||
|
||
|
||
class ALFWorldEnv(gym.Env[str, AgentAction]): | ||
def __init__( | ||
self, | ||
config_path: str, | ||
): | ||
with open(config_path) as reader: | ||
config = yaml.safe_load(reader) | ||
|
||
env = getattr(environment, config["env"]["type"])(config, train_eval="train") | ||
self.env: TextworldBatchGymEnv = env.init_env(batch_size=1) | ||
self._action_executor = DefaultActionExecutor( | ||
tools=get_alfworld_tools(env=self.env) | ||
) | ||
|
||
@property | ||
def tools(self) -> Sequence[BaseTool]: | ||
return self._action_executor.tools | ||
|
||
def seed(self, seed: Optional[int] = None): | ||
self.env.seed(seed) | ||
|
||
def step( | ||
self, action: AgentAction | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
result = self._action_executor.execute(action) | ||
observation, reward, terminated, truncated, info = result.observation | ||
return observation, reward, terminated, truncated, {} | ||
|
||
def reset( | ||
self, | ||
*, | ||
seed: int | None = None, | ||
options: Dict[str, Any] | None = None, | ||
) -> Tuple[str, Dict[str, Any]]: | ||
obs, infos = self.env.reset() | ||
observation = obs[0] | ||
info = {key: infos[key][0] for key in infos} | ||
|
||
if options is not None and "trajectory" in options: | ||
for action in options["trajectory"]: | ||
assert isinstance(action, AgentAction) | ||
observation, reward, terminated, truncated, info = self.step(action) | ||
return observation, info |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import re | ||
|
||
from langchain_core.exceptions import OutputParserException | ||
from langchain_core.output_parsers import BaseOutputParser | ||
|
||
|
||
class ALFWorldEvaluateOutputParser(BaseOutputParser[float]): | ||
def parse(self, text: str) -> float: | ||
try: | ||
match = re.search(r"\[\[(.*?)\]\]", text.strip()) | ||
if not match: | ||
raise ValueError("Pattern [[number]] not found.") | ||
result = float(match.groups()[0]) | ||
if result < 0.0 or result > 1.0: | ||
raise ValueError("The given number is out of (0.0, 1.0) range.") | ||
return result | ||
except ValueError: | ||
raise OutputParserException( | ||
f"Couldn't convert {text} to float between 0 and 1." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
from langchain.pydantic_v1 import BaseModel | ||
from langchain.tools import BaseTool | ||
from typing import Type, Any, Tuple, Dict, List | ||
|
||
from gymnasium.core import SupportsFloat | ||
from .tools_utils import ( | ||
BaseALFWorldTool, | ||
ReceptableInput, | ||
ObjectOrReceptableInput, | ||
ObjectAndReceptableInput, | ||
EmptyInput, | ||
) | ||
from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped] | ||
|
||
|
||
def get_alfworld_tools(env: TextworldBatchGymEnv) -> List[BaseTool]: | ||
return [ | ||
GoToTool(env=env), # type: ignore[call-arg] | ||
OpenTool(env=env), # type: ignore[call-arg] | ||
CloseTool(env=env), # type: ignore[call-arg] | ||
TakeTool(env=env), # type: ignore[call-arg] | ||
PutTool(env=env), # type: ignore[call-arg] | ||
ToggleTool(env=env), # type: ignore[call-arg] | ||
HeatTool(env=env), # type: ignore[call-arg] | ||
CoolTool(env=env), # type: ignore[call-arg] | ||
CleanTool(env=env), # type: ignore[call-arg] | ||
ExamineTool(env=env), # type: ignore[call-arg] | ||
InventoryTool(env=env), # type: ignore[call-arg] | ||
LookTool(env=env), # type: ignore[call-arg] | ||
] | ||
|
||
|
||
class GoToTool(BaseALFWorldTool, BaseTool): | ||
name = "goto" | ||
description = """Go to a specified receptable (static object).""" | ||
args_schema: Type[BaseModel] = ReceptableInput | ||
|
||
def _run( | ||
self, | ||
receptable_type: str, | ||
receptable_id: int, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step( | ||
[f"go to {receptable_type} {receptable_id}"] | ||
) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} | ||
|
||
|
||
class OpenTool(BaseALFWorldTool, BaseTool): | ||
name = "open" | ||
description = """Open a specified receptable (static object). Only available when you're already near a receptable.""" | ||
args_schema: Type[BaseModel] = ReceptableInput | ||
|
||
def _run( | ||
self, | ||
receptable_type: str, | ||
receptable_id: int, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step( | ||
[f"open {receptable_type} {receptable_id}"] | ||
) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} | ||
|
||
|
||
class CloseTool(BaseALFWorldTool, BaseTool): | ||
name = "close" | ||
description = """Close a specified receptable (static object). Only available when you're already near a receptable.""" | ||
args_schema: Type[BaseModel] = ReceptableInput | ||
|
||
def _run( | ||
self, | ||
receptable_type: str, | ||
receptable_id: int, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step( | ||
[f"close {receptable_type} {receptable_id}"] | ||
) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} | ||
|
||
|
||
class TakeTool(BaseALFWorldTool, BaseTool): | ||
name = "take" | ||
description = """Pick up a specified portable object from a specified receptable (static object). Only available when you're already near a receptable.""" | ||
args_schema: Type[BaseModel] = ObjectAndReceptableInput | ||
|
||
def _run( | ||
self, | ||
object_type: str, | ||
object_id: int, | ||
receptable_type: str, | ||
receptable_id: int, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step( | ||
[f"take {object_type} {object_id} from {receptable_type} {receptable_id}"] | ||
) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} | ||
|
||
|
||
class PutTool(BaseALFWorldTool, BaseTool): | ||
name = "put" | ||
description = """Put a specified portable object in/щт a specified receptable (static object). Only available when you're already near a receptable and carry a portable object in your inventory.""" | ||
args_schema: Type[BaseModel] = ObjectAndReceptableInput | ||
|
||
def _run( | ||
self, | ||
object_type: str, | ||
object_id: int, | ||
receptable_type: str, | ||
receptable_id: int, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step( | ||
[f"put {object_type} {object_id} in/on {receptable_type} {receptable_id}"] | ||
) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} | ||
|
||
|
||
class ToggleTool(BaseALFWorldTool, BaseTool): | ||
name = "toggle" | ||
description = """Toggle a specified object on/off (can be either a portable object or a static receptable). Only available when you're already near a receptable/a portable object or carry a portable object.""" | ||
args_schema: Type[BaseModel] = ObjectOrReceptableInput | ||
|
||
def _run( | ||
self, | ||
type: str, | ||
id: int, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step([f"toggle {type} {id}"]) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} | ||
|
||
|
||
class HeatTool(BaseALFWorldTool, BaseTool): | ||
name = "heat" | ||
description = """Heat a portable object via a receptable (static object). Only available when you're already near a receptable and carry a portable object.""" | ||
args_schema: Type[BaseModel] = ObjectAndReceptableInput | ||
|
||
def _run( | ||
self, | ||
object_type: str, | ||
object_id: int, | ||
receptable_type: str, | ||
receptable_id: int, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step( | ||
[f"heat {object_type} {object_id} with {receptable_type} {receptable_id}"] | ||
) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} | ||
|
||
|
||
class CoolTool(BaseALFWorldTool, BaseTool): | ||
name = "cool" | ||
description = """Cool a portable object via a receptable (static object). Only available when you're already near a receptable and carry a portable object.""" | ||
args_schema: Type[BaseModel] = ObjectAndReceptableInput | ||
|
||
def _run( | ||
self, | ||
object_type: str, | ||
object_id: int, | ||
receptable_type: str, | ||
receptable_id: int, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step( | ||
[f"cool {object_type} {object_id} with {receptable_type} {receptable_id}"] | ||
) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} | ||
|
||
|
||
class CleanTool(BaseALFWorldTool, BaseTool): | ||
name = "clean" | ||
description = """Clean a portable object via a receptable (static object). Only available when you're already near a receptable and a portable object or carry a portable object.""" | ||
args_schema: Type[BaseModel] = ObjectAndReceptableInput | ||
|
||
def _run( | ||
self, | ||
object_type: str, | ||
object_id: int, | ||
receptable_type: str, | ||
receptable_id: int, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step( | ||
[f"clean {object_type} {object_id} with {receptable_type} {receptable_id}"] | ||
) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} | ||
|
||
|
||
class ExamineTool(BaseALFWorldTool, BaseTool): | ||
name = "examine" | ||
description = """Examine a specified object (can be either a portable object or a static receptable). Only available when you're already near a receptable/a portable object or carry a portable object.""" | ||
args_schema: Type[BaseModel] = ObjectOrReceptableInput | ||
|
||
def _run( | ||
self, | ||
type: str, | ||
id: int, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step([f"examine {type} {id}"]) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} | ||
|
||
|
||
class InventoryTool(BaseALFWorldTool, BaseTool): | ||
name = "inventory" | ||
description = """Check if you are carrying any portable objects.""" | ||
args_schema: Type[BaseModel] = EmptyInput | ||
|
||
def _run( | ||
self, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step(["inventory"]) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} | ||
|
||
|
||
class LookTool(BaseALFWorldTool, BaseTool): | ||
name = "look" | ||
description = """Check your surroundings.""" | ||
args_schema: Type[BaseModel] = EmptyInput | ||
|
||
def _run( | ||
self, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
obs, scores, dones, infos = self.env.step(["look"]) | ||
return obs[0], scores[0], dones[0], False, {key: infos[key][0] for key in infos} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from langchain.pydantic_v1 import BaseModel, Field | ||
from langchain.tools import BaseTool | ||
from textworld.gym.envs.textworld_batch import TextworldBatchGymEnv # type: ignore[import-untyped] | ||
|
||
|
||
class EmptyInput(BaseModel): ... | ||
|
||
|
||
class ObjectInput(BaseModel): | ||
object_type: str = Field( | ||
description="A type of the portable object.", examples=["the apple", "the mug"] | ||
) | ||
object_id: int = Field( | ||
description="A specific number associated with the object (e.g., when there are " | ||
"several mugs in the room, those would be mug 1 and mug 2).", | ||
examples=[1, 2, 3], | ||
) | ||
|
||
|
||
class ReceptableInput(BaseModel): | ||
receptable_type: str = Field( | ||
description="A type of the receptable.", | ||
examples=["the coffee table", "the drawer", "the countertop"], | ||
) | ||
receptable_id: int = Field( | ||
description="A specific number associated with the receptable (e.g., when there are " | ||
"several drawers in the room, those would be drawer 1 and drawer 2).", | ||
examples=[1, 2, 3], | ||
) | ||
|
||
|
||
class ObjectAndReceptableInput(BaseModel): | ||
object_type: str = Field( | ||
description="A type of the portable object.", examples=["the apple", "the mug"] | ||
) | ||
object_id: int = Field( | ||
description="A specific number associated with the object (e.g., when there are " | ||
"several mugs in the room, those would be mug 1 and mug 2).", | ||
examples=[1, 2, 3], | ||
) | ||
receptable_type: str = Field( | ||
description="A type of the receptable.", | ||
examples=["the coffee table", "the drawer", "the countertop"], | ||
) | ||
receptable_id: int = Field( | ||
description="A specific number associated with the receptable (e.g., when there are " | ||
"several drawers in the room, those would be drawer 1 and drawer 2).", | ||
examples=[1, 2, 3], | ||
) | ||
|
||
|
||
class ObjectOrReceptableInput(BaseModel): | ||
type: str = Field( | ||
description="A type of the object (might be either a portable object or a static one).", | ||
examples=["the apple", "the coffee table"], | ||
) | ||
id: int = Field( | ||
description="A specific number associated with the object (e.g., when there are " | ||
"several mugs in the room, those would be mug 1 and mug 2).", | ||
examples=[1, 2, 3], | ||
) | ||
|
||
|
||
class BaseALFWorldTool(BaseModel): | ||
"""Base tool for an ALFWorld environment. | ||
Environment is present as a field, but it won't be shown to models.""" | ||
|
||
env: TextworldBatchGymEnv = Field(exclude=True) | ||
|
||
class Config(BaseTool.Config): | ||
pass |
Oops, something went wrong.