initial benchmarks
Howuhh committed Dec 21, 2023
1 parent a8a35b2 commit aa8eefc
Showing 9 changed files with 265 additions and 68 deletions.
20 changes: 13 additions & 7 deletions README.md
@@ -16,6 +16,9 @@
<a href="https://github.com/astral-sh/ruff">
<img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json"/>
</a>
<a href="https://twitter.com/vladkurenkov/status/1731709425524543550">
<img src="https://badgen.net/badge/icon/twitter?icon=twitter&label"/>
</a>
<a target="_blank" href="https://colab.research.google.com/github/corl-team/xland-minigrid/blob/main/examples/walkthrough.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
@@ -88,7 +91,7 @@ key = jax.random.PRNGKey(0)
reset_key, ruleset_key = jax.random.split(key)

# to list available benchmarks: xminigrid.registered_benchmarks()
benchmark = xminigrid.load_benchmark(name="Trivial")
benchmark = xminigrid.load_benchmark(name="trivial-1m")
# choosing ruleset, see section on rules and goals
ruleset = benchmark.sample_ruleset(ruleset_key)
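
# hypothetical continuation of the quickstart (the env id and the exact
# reset signature below are assumptions; the full example is collapsed in this diff):
env, env_params = xminigrid.make("XLand-MiniGrid-R1-9x9")
env_params = env_params.replace(ruleset=ruleset)
timestep = env.reset(env_params, reset_key)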

@@ -150,11 +153,11 @@ While composing rules and goals by hand is flexible, it can quickly become cumbe
Besides, it's hard to express efficiently in a JAX-compatible way due to the large number of heterogeneous computations.

To avoid significant overhead during training and facilitate reliable comparisons between agents,
we pre-sampled several benchmarks with up to **one million unique tasks**, following the procedure used to train DeepMind
we pre-sampled several benchmarks with up to **five million unique tasks**, following the procedure used to train the DeepMind
AdA agent from the original XLand. These benchmarks differ in their generation configs, producing distributions with
varying levels of diversity and average task difficulty. They can be used for different purposes; for example,
the `Trivial` benchmark can be used to debug your agents, allowing very quick iterations. However, we would caution
against treating benchmarks as a progression from simple to complex. They are just different 🤷.
the `trivial-1m` benchmark can be used to debug your agents, allowing very quick iterations. However, we would caution
against treating benchmarks as a progression from simple to complex. They are just different 🤷.

Pre-sampled benchmarks are hosted on [HuggingFace](https://huggingface.co/datasets/Howuhh/xland_minigrid/tree/main) and will be downloaded and cached on the first use:

Expand All @@ -165,13 +168,16 @@ from xminigrid.benchmarks import Benchmark

# downloading to the path specified by XLAND_MINIGRID_DATA,
# ~/.xland_minigrid by default
benchmark: Benchmark = xminigrid.load_benchmark(name="Trivial")
benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m")
# reuses the cached version on subsequent uses
benchmark: Benchmark = xminigrid.load_benchmark(name="Trivial")
benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m")

# users can sample or get specific rulesets
benchmark.sample_ruleset(jax.random.PRNGKey(0))
benchmark.get_ruleset(ruleset_id=benchmark.num_rulesets() - 1)

# or split them into train & test
train, test = benchmark.shuffle(key=jax.random.PRNGKey(0)).split(prop=0.8)
```
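
For meta-RL style training you typically need one ruleset per parallel environment. A minimal sketch, assuming `sample_ruleset` is `vmap`-compatible (rulesets are pytrees of JAX arrays); the batch size here is illustrative:

```python
import jax

import xminigrid

benchmark = xminigrid.load_benchmark(name="trivial-1m")

# one key, and thus one ruleset, per parallel environment
batch_keys = jax.random.split(jax.random.PRNGKey(0), num=8192)
rulesets = jax.vmap(benchmark.sample_ruleset)(batch_keys)
```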

We also provide the [script](scripts/ruleset_generator.py) used to generate these benchmarks. Users can adapt it for their own purposes (see `scripts/generate_benchmarks.sh` in this commit for example invocations):
@@ -181,7 +187,7 @@ python scripts/ruleset_generator.py --help

An in-depth description of all available benchmarks is provided [here (soon)]().

**P.S.** Currently only one benchmark is available. We will release more after some testing and configs balancing. Stay tuned!
**P.S.** Be aware that benchmarks can change, as we are currently testing and balancing them!

## Environments 🌍

36 changes: 33 additions & 3 deletions examples/walkthrough.ipynb
@@ -814,7 +814,7 @@
"source": [
"While composing rules and goals by hand is flexible, it can quickly become cumbersome. Besides, it's hard to express efficiently in a JAX-compatible way due to the high number of heterogeneous computations\n",
"\n",
"To avoid significant overhead during training and facilitate reliable comparisons between agents, we pre-sampled several benchmarks with up to **one million unique** tasks, following the procedure used to train [DeepMind AdA](https://sites.google.com/view/adaptive-agent/) agent from the original XLand. These benchmarks differ in the generation configs, producing distributions with varying levels of diversity and average difficulty of the tasks. They can be used for different purposes, for example the Trivial benchmark can be used to debug your agents, allowing very quick iterations. \n",
"To avoid significant overhead during training and facilitate reliable comparisons between agents, we pre-sampled several benchmarks with up to **five million unique** tasks (apart from the randomization of object positions during reset), following the procedure used to train [DeepMind AdA](https://sites.google.com/view/adaptive-agent/) agent from the original XLand. These benchmarks differ in the generation configs, producing distributions with varying levels of diversity and average difficulty of the tasks. They can be used for different purposes, for example the `trivial-1m` benchmark can be used to debug your agents, allowing quick iterations. \n",
"\n",
"**Generation protocol**:\n",
"\n",
@@ -921,7 +921,7 @@
"source": [
"print(\"Benchmarks available:\", xminigrid.registered_benchmarks())\n",
"\n",
"benchmark = xminigrid.load_benchmark(name=\"Trivial\")\n",
"benchmark = xminigrid.load_benchmark(name=\"trivial-1m\")\n",
"print(\"Total rulesets:\", benchmark.num_rulesets())\n",
"print(\"Ruleset with id 128: \\n\", benchmark.get_ruleset(ruleset_id=128))\n",
"print(\"Random ruleset: \\n\", benchmark.sample_ruleset(jax.random.PRNGKey(0)))"
@@ -966,7 +966,7 @@
"outputs": [],
"source": [
"# example path, can be any your valid path\n",
"bechmark_path = os.path.join(DATA_PATH, NAME2HFFILENAME[\"Trivial\"])\n",
"bechmark_path = os.path.join(DATA_PATH, NAME2HFFILENAME[\"trivial-1m\"])\n",
"\n",
"rulesets_clear = load_bz2_pickle(bechmark_path)\n",
"loaded_benchmark = Benchmark(\n",
@@ -977,6 +977,36 @@
")"
]
},
{
"cell_type": "markdown",
"id": "0a08d28a-60fd-4a09-acb2-7e4d42538b6e",
"metadata": {},
"source": [
"You also my need splitting functionality to test generalization of your agents. For this users can use `split` or `filter_split`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "653b678b-65c1-47a8-9644-4f2d519d9e55",
"metadata": {},
"outputs": [],
"source": [
"train, test = benchmark.shuffle(key=jax.random.PRNGKey(0)).split(prop=0.8)\n",
"\n",
"# or, by some function:\n",
"def cond_fn(goal, rules):\n",
" # 0 index in the encoding is the ID\n",
" return jnp.logical_not(\n",
" jnp.logical_and(\n",
" jnp.greater_equal(goal[0], 7),\n",
" jnp.less_equal(goal[0], 14)\n",
" )\n",
" )\n",
" \n",
"train, test = benchmark.filter_split(fn=cond_fn)"
]
},
{
"cell_type": "markdown",
"id": "4786beac-9bea-4bbc-bcf9-92c21b5770d2",
56 changes: 56 additions & 0 deletions scripts/generate_benchmarks.sh
@@ -0,0 +1,56 @@
# This can take a lot of time. Generate only the benchmarks you need!
# TODO: provide the same for the 5M benchmarks
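
# Flag semantics below are inferred from scripts/ruleset_generator.py, so treat them as a best guess:
#   --chain_depth: depth of the production-rule chain behind the goal
#   --prune_chain / --prune_prob: randomly drop rules from the chain
#   --num_distractor_rules / --num_distractor_objects: task-irrelevant rules and objects
#   --total_rulesets: number of unique rulesets to sample
#   --save_path: output file name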

# trivial
python scripts/ruleset_generator.py \
--chain_depth=0 \
--num_distractor_objects=3 \
--total_rulesets=1_000_000 \
--save_path="trivial_1m"


# small
python scripts/ruleset_generator.py \
--prune_chain \
--prune_prob=0.3 \
--chain_depth=1 \
--sample_distractor_rules \
--num_distractor_rules=2 \
--num_distractor_objects=2 \
--total_rulesets=1_000_000 \
--save_path="small_1m"

# medium
python scripts/ruleset_generator.py \
--prune_chain \
--prune_prob=0.3 \
--chain_depth=2 \
--sample_distractor_rules \
--num_distractor_rules=3 \
--num_distractor_objects=0 \
--total_rulesets=1_000_000 \
--save_path="medium_1m"


# high
python scripts/ruleset_generator.py \
--prune_chain \
--prune_prob=0.1 \
--chain_depth=3 \
--sample_distractor_rules \
--num_distractor_rules=4 \
--num_distractor_objects=1 \
--total_rulesets=1_000_000 \
--save_path="high_1m"


# medium + distractors
python scripts/ruleset_generator.py \
--prune_chain \
--prune_prob=0.8 \
--chain_depth=2 \
--sample_distractor_rules \
--num_distractor_rules=4 \
--num_distractor_objects=2 \
--total_rulesets=1_000_000 \
--save_path="medium_dist_1m"
122 changes: 94 additions & 28 deletions scripts/ruleset_generator.py
@@ -10,26 +10,62 @@
from tqdm.auto import tqdm, trange
from xminigrid.benchmarks import save_bz2_pickle
from xminigrid.core.constants import Colors, Tiles
from xminigrid.core.goals import AgentHoldGoal, AgentNearGoal, TileNearGoal
from xminigrid.core.goals import (
AgentHoldGoal,
AgentNearDownGoal,
AgentNearGoal,
AgentNearLeftGoal,
AgentNearRightGoal,
AgentNearUpGoal,
TileNearDownGoal,
TileNearGoal,
TileNearLeftGoal,
TileNearRightGoal,
TileNearUpGoal,
)
from xminigrid.core.grid import pad_along_axis
from xminigrid.core.rules import AgentHoldRule, AgentNearRule, EmptyRule, TileNearRule

COLORS = [Colors.RED, Colors.GREEN, Colors.BLUE, Colors.PURPLE, Colors.YELLOW, Colors.GREY, Colors.WHITE]
from xminigrid.core.rules import (
AgentHoldRule,
AgentNearDownRule,
AgentNearLeftRule,
AgentNearRightRule,
AgentNearRule,
AgentNearUpRule,
EmptyRule,
TileNearDownRule,
TileNearLeftRule,
TileNearRightRule,
TileNearRule,
TileNearUpRule,
)

COLORS = [
Colors.RED,
Colors.GREEN,
Colors.BLUE,
Colors.PURPLE,
Colors.YELLOW,
Colors.GREY,
Colors.WHITE,
Colors.BROWN,
Colors.PINK,
Colors.ORANGE,
]

# we need to distinguish between these to avoid sampling a
# near(goal, goal) goal or rule, as goal tiles are not pickable
NEAR_TILES_LHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.GOAL], COLORS))
NEAR_TILES_LHS = list(
product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX, Tiles.GOAL], COLORS)
)
# these are pickable!
NEAR_TILES_RHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY], COLORS))
NEAR_TILES_RHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

HOLD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

HOLD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY], COLORS))
PROD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY], COLORS))
# to imitate a disappearance production rule
PROD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))
PROD_TILES = PROD_TILES + [(Tiles.FLOOR, Colors.BLACK)]

GOALS = (AgentHoldGoal, AgentNearGoal, TileNearGoal)
RULES = (AgentHoldRule, AgentNearRule, TileNearRule)


def encode(ruleset):
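# flattens the goal encoding together with all rule encodings
# (presumably used to deduplicate sampled rulesets)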
flatten_encoding = jnp.concatenate([ruleset["goal"].encode(), *[r.encode() for r in ruleset["rules"]]]).tolist()
Expand All @@ -41,42 +77,72 @@ def diff(list1, list2):


def sample_goal():
goal_idx = random.randint(0, 2)
goals = (
AgentHoldGoal,
# agent near variations
AgentNearGoal,
AgentNearUpGoal,
AgentNearDownGoal,
AgentNearLeftGoal,
AgentNearRightGoal,
# tile near variations
TileNearGoal,
TileNearUpGoal,
TileNearDownGoal,
TileNearLeftGoal,
TileNearRightGoal,
)
goal_idx = random.randint(0, 10)
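# index 0: hold goal; 1-5: agent-near (+ directional) goals; 6-10: tile-near (+ directional) goals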
if goal_idx == 0:
tile = random.choice(HOLD_TILES)
goal = AgentHoldGoal(tile=jnp.array(tile))
goal = goals[0](tile=jnp.array(tile))
return goal, (tile,)
elif goal_idx == 1:
elif 1 <= goal_idx <= 5:
tile = random.choice(NEAR_TILES_LHS)
goal = AgentNearGoal(tile=jnp.array(tile))
goal = goals[goal_idx](tile=jnp.array(tile))
return goal, (tile,)
elif goal_idx == 2:
elif 6 <= goal_idx <= 10:
tile_a = random.choice(NEAR_TILES_LHS)
tile_b = random.choice(NEAR_TILES_RHS)
goal = TileNearGoal(tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b))
goal = goals[goal_idx](tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b))
return goal, (tile_a, tile_b)
else:
raise RuntimeError(f"Unknown goal, should be one of: {GOALS}")
raise RuntimeError("Unknown goal")


def sample_rule(prod_tile, used_tiles):
rule_idx = random.randint(0, 2)
rules = (
AgentHoldRule,
# agent near variations
AgentNearRule,
AgentNearUpRule,
AgentNearDownRule,
AgentNearLeftRule,
AgentNearRightRule,
# tile near variations
TileNearRule,
TileNearUpRule,
TileNearDownRule,
TileNearLeftRule,
TileNearRightRule,
)
rule_idx = random.randint(0, 10)
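# index 0: hold rule; 1-5: agent-near (+ directional) rules; 6-10: tile-near (+ directional) rules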

if rule_idx == 0:
tile = random.choice(diff(HOLD_TILES, used_tiles))
rule = AgentHoldRule(tile=jnp.array(tile), prod_tile=jnp.array(prod_tile))
rule = rules[rule_idx](tile=jnp.array(tile), prod_tile=jnp.array(prod_tile))
return rule, (tile,)
elif rule_idx == 1:
tile = random.choice(diff(NEAR_TILES_LHS, used_tiles))
rule = AgentNearRule(tile=jnp.array(tile), prod_tile=jnp.array(prod_tile))
elif 1 <= rule_idx <= 5:
tile = random.choice(diff(HOLD_TILES, used_tiles))
rule = rules[rule_idx](tile=jnp.array(tile), prod_tile=jnp.array(prod_tile))
return rule, (tile,)
elif rule_idx == 2:
elif 6 <= rule_idx <= 10:
tile_a = random.choice(diff(NEAR_TILES_LHS, used_tiles))
tile_b = random.choice(diff(NEAR_TILES_RHS, used_tiles))

rule = TileNearRule(tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b), prod_tile=jnp.array(prod_tile))
rule = rules[rule_idx](tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b), prod_tile=jnp.array(prod_tile))
return rule, (tile_a, tile_b)
else:
raise RuntimeError(f"Unknown rule, should be one of: {RULES}")
raise RuntimeError("Unknown rule")


# See Appendix A.2 in "Human-timescale adaptation in an open-ended task space" for the sampling procedure.
@@ -158,7 +224,7 @@ def sample_ruleset(
"init_tiles": init_tiles,
# additional info (for example for biasing sampling by number of rules)
# you can add other fields if needed, just copy-paste this file!
"num_rules": num_levels,
"num_rules": len([r for r in rules if not isinstance(r, EmptyRule)]),
}


2 changes: 1 addition & 1 deletion src/xminigrid/__init__.py
@@ -2,7 +2,7 @@
from .registration import make, register, registered_environments

# TODO: add __all__
__version__ = "0.2.0"
__version__ = "0.3.0"

# ---------- XLand-MiniGrid environments ----------
# TODO: reconsider grid sizes and time limits after the benchmarks are generated.