Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,38 @@ You can then use this agent by specifying the path to the file and the class nam
> [!NOTE]
> See the [agents folder](https://github.com/microsoft/tale-suite/tree/main/agents) for more concrete examples.

## 5. Training Your Language Agents on TALES
TALES offers both train splits and test splits, the latter of which make up the games all models in our technical report were evaluated on.

The following is an example of how to import desired environments and allow an agent to play through them.

Note that importing the relevant framework automatically registers all environments in that framework with gym.
You can individually import the frameworks if you want to only evaluate on them one at a time.
For now, we do not include a jericho train split.

```
import gymnasium as gym
from tales import *

# Training splits
train_envs = [env_spec.id for env_spec in gym.envs.registry.values() if "tales/" in env_spec.id and 'train' in env_spec.id]

# Testing splits
envs = [env_spec.id for env_spec in gym.envs.registry.values() if "tales/" in env_spec.id and 'train' not in env_spec.id]

train_env = gym.make(
train_envs[0],
disable_env_checker=True,
admissible_commands=True,
)

test_env = gym.make(
envs[0],
disable_env_checker=True,
admissible_commands=True,
)
```

## Citation
```
@article{cui2025tales,
Expand Down
21 changes: 20 additions & 1 deletion tales/alfworld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,36 @@
from .alfworld_env import ALFWorldTask

environments = []
train_environments = []

for split in ["seen", "unseen"]:
for task_type in TASK_TYPES:
gamefiles = sorted(alfworld_data.get_alfworld_game(task_type, split))
train_gamefiles = gamefiles[1:]
test_gamefiles = [gamefiles[0]]

task_name = task_type.replace("_", " ").title().replace(" ", "")
env_name = f"ALFWorld{task_name}{split.title()}"
environments.append([env_name, "v0"])

gym.register(
id=f"tales/{env_name}-v0",
entry_point="tales.alfworld:ALFWorldTask",
kwargs={"task_type": task_type, "split": split},
kwargs={
"all_gamefiles": test_gamefiles,
"start_gamefile": test_gamefiles[0],
},
)

train_env_name = env_name + "_train"
train_environments.append([train_env_name, "v0"])
gym.register(
id=f"tales/{train_env_name}-v0",
entry_point="tales.alfworld:ALFWorldTask",
kwargs={
"all_gamefiles": train_gamefiles,
"start_gamefile": train_gamefiles[0],
},
)


Expand Down
6 changes: 3 additions & 3 deletions tales/alfworld/alfworld_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def step(self, action):

class ALFWorldTask(ALFWorldEnv):

def __init__(self, task_type, split, *args, **kwargs):
self.gamefiles = sorted(alfworld_data.get_alfworld_game(task_type, split))
super().__init__(self.gamefiles[0], *args, **kwargs)
def __init__(self, all_gamefiles, start_gamefile, *args, **kwargs):
self.gamefiles = all_gamefiles
super().__init__(start_gamefile, *args, **kwargs)

def reset(self, *, seed=None, options=None):
if seed is not None:
Expand Down
183 changes: 0 additions & 183 deletions tales/get_env_splits.py

This file was deleted.

11 changes: 10 additions & 1 deletion tales/scienceworld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .scienceworld_env import TASK_NAMES, ScienceWorldEnv

environments = []
train_environments = []

for task_name in TASK_NAMES:
env_name = f"ScienceWorld{task_name.title().replace('-', '')}"
Expand All @@ -11,7 +12,15 @@
gym.register(
id=f"tales/{env_name}-v0",
entry_point="tales.scienceworld:ScienceWorldEnv",
kwargs={"task_name": task_name},
kwargs={"task_name": task_name, "split": "test"},
)

train_env_name = env_name + "_train"
train_environments.append([train_env_name, "v0"])
gym.register(
id=f"tales/{train_env_name}-v0",
entry_point="tales.scienceworld:ScienceWorldEnv",
kwargs={"task_name": task_name, "split": "train"},
)


Expand Down
6 changes: 4 additions & 2 deletions tales/scienceworld/scienceworld_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@

class ScienceWorldEnv(gym.Env):

def __init__(self, task_name, admissible_commands=False, *args, **kwargs):
def __init__(
self, task_name, admissible_commands=False, split="Test", *args, **kwargs
):
self.task_name = task_name
self.admissible_commands = admissible_commands
self.env = scienceworld.ScienceWorldEnv(self.task_name, envStepLimit=np.inf)
self.variations = scienceworld_data.get_variations(
self.task_name, split="test", env=self.env
self.task_name, split=split, env=self.env
)
self.variation = self.variations[0]

Expand Down
14 changes: 13 additions & 1 deletion tales/textworld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,28 @@
from .textworld_env import TextWorldEnv, TWCookingEnv

environments = []
train_environments = []

# TWCookingEnv
for difficulty in range(1, 10 + 1):
gamefiles = sorted(textworld_data.get_cooking_game(difficulty))
train_gamefiles = gamefiles[1:]
test_gamefiles = [gamefiles[0]]
env_name = f"TWCookingLevel{difficulty}"
environments.append([env_name, "v0"])

gym.register(
id=f"tales/{env_name}-v0",
entry_point="tales.textworld:TWCookingEnv",
kwargs={"difficulty": difficulty},
kwargs={"all_gamefiles": test_gamefiles, "start_gamefile": test_gamefiles[0]},
)

train_env_name = env_name + "_train"
train_environments.append([train_env_name, "v0"])
gym.register(
id=f"tales/{train_env_name}-v0",
entry_point="tales.textworld:TWCookingEnv",
kwargs={"all_gamefiles": train_gamefiles, "start_gamefile": train_gamefiles[0]},
)


Expand Down
6 changes: 3 additions & 3 deletions tales/textworld/textworld_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def step(self, action):


class TWCookingEnv(TextWorldEnv):
def __init__(self, difficulty, *args, **kwargs):
self.gamefiles = sorted(textworld_data.get_cooking_game(difficulty))
super().__init__(self.gamefiles[0], *args, **kwargs)
def __init__(self, all_gamefiles, start_gamefile, *args, **kwargs):
self.gamefiles = all_gamefiles
super().__init__(start_gamefile, *args, **kwargs)

def reset(self, *, seed=None, options=None):
if seed is not None:
Expand Down
11 changes: 10 additions & 1 deletion tales/textworld_express/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .twx_env import TASKS, TextWorldExpressEnv

environments = []
train_environments = []

for task_name, game_name, game_params in TASKS:
env_name = f"TWX{task_name}"
Expand All @@ -11,7 +12,15 @@
gym.register(
id=f"tales/{env_name}-v0",
entry_point="tales.textworld_express:TextWorldExpressEnv",
kwargs={"game_name": game_name, "game_params": game_params},
kwargs={"game_name": game_name, "game_params": game_params, "split": "test"},
)

train_env_name = env_name + "_train"
train_environments.append([train_env_name, "v0"])
gym.register(
id=f"tales/{train_env_name}-v0",
entry_point="tales.textworld_express:TextWorldExpressEnv",
kwargs={"game_name": game_name, "game_params": game_params, "split": "train"},
)


Expand Down