Skip to content

Commit

Permalink
updates for typings + paper
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Jan 30, 2024
1 parent 870cd5f commit abf181c
Show file tree
Hide file tree
Showing 12 changed files with 171 additions and 69 deletions.
10 changes: 1 addition & 9 deletions scripts/ruleset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,6 @@ def sample_ruleset(
# you can add other fields if needed, just copy-paste this file!
# saving counts, as later they will be padded to the same size
"num_rules": len([r for r in rules if not isinstance(r, EmptyRule)]),
"num_init_tiles": len(init_tiles),
"max_chain_depth": num_levels,
"num_distractor_rules": num_distractor_rules,
}


Expand Down Expand Up @@ -281,13 +278,11 @@ def sample_ruleset(
"rules": jnp.vstack([r.encode() for r in ruleset["rules"]]),
"init_tiles": jnp.array(ruleset["init_tiles"], dtype=jnp.uint8),
"num_rules": jnp.asarray(ruleset["num_rules"], dtype=jnp.uint8),
"num_init_tiles": jnp.asarray(ruleset["num_init_tiles"], dtype=jnp.uint8),
"max_chain_depth": jnp.asarray(ruleset["max_chain_depth"], dtype=jnp.uint8),
"num_distractor_rules": jnp.asarray(ruleset["num_distractor_rules"], dtype=jnp.uint8),
}
)
unique_rulesets_encodings.add(encode(ruleset))

del unique_rulesets_encodings
# concatenating padded rulesets, for convenient sampling in jax
# as in jax we can not retrieve single item from the list/pytree under jit
# also all rulesets in one benchmark should have same shapes to work under jit
Expand All @@ -306,9 +301,6 @@ def sample_ruleset(
"rules": jnp.vstack([pad_along_axis(r["rules"], pad_to=max_rules)[None, ...] for r in rulesets]),
"init_tiles": jnp.vstack([pad_along_axis(r["init_tiles"], pad_to=max_tiles)[None, ...] for r in rulesets]),
"num_rules": jnp.vstack([r["num_rules"] for r in rulesets]),
"num_init_tiles": jnp.vstack([r["num_init_tiles"] for r in rulesets]),
"max_chain_depth": jnp.vstack([r["max_chain_depth"] for r in rulesets]),
"num_distractor_rules": jnp.vstack([r["num_distractor_rules"] for r in rulesets]),
}
print("Saving...")
save_bz2_pickle(concat_rulesets, args.save_path, protocol=-1)
Expand Down
78 changes: 76 additions & 2 deletions src/xminigrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,84 @@
from .registration import make, register, registered_environments

# TODO: add __all__
__version__ = "0.5.0"
__version__ = "0.5.1"

# ---------- XLand-MiniGrid environments ----------

# WARN: TMP, only for FPS measurements
# register(
# id="MiniGrid-1Rules",
# entry_point="xminigrid.envs.xland_tmp:XLandMiniGrid",
# num_rules=1,
# height=16,
# width=16,
# )
#
# register(
# id="MiniGrid-3Rules",
# entry_point="xminigrid.envs.xland_tmp:XLandMiniGrid",
# num_rules=2,
# height=16,
# width=16,
# )
#
# register(
# id="MiniGrid-6Rules",
# entry_point="xminigrid.envs.xland_tmp:XLandMiniGrid",
# num_rules=6,
# height=16,
# width=16,
# )
#
# register(
# id="MiniGrid-12Rules",
# entry_point="xminigrid.envs.xland_tmp:XLandMiniGrid",
# num_rules=12,
# height=16,
# width=16,
# )
#
# register(
# id="MiniGrid-24Rules",
# entry_point="xminigrid.envs.xland_tmp:XLandMiniGrid",
# num_rules=24,
# height=16,
# width=16,
# )

# register(
# id="XLand-MiniGrid-R1-8x8",
# entry_point="xminigrid.envs.xland:XLandMiniGrid",
# grid_type="R1",
# height=8,
# width=8,
# )
#
# register(
# id="XLand-MiniGrid-R1-16x16",
# entry_point="xminigrid.envs.xland:XLandMiniGrid",
# grid_type="R1",
# height=16,
# width=16,
# )
#
# register(
# id="XLand-MiniGrid-R1-32x32",
# entry_point="xminigrid.envs.xland:XLandMiniGrid",
# grid_type="R1",
# height=32,
# width=32,
# )
#
# register(
# id="XLand-MiniGrid-R1-64x64",
# entry_point="xminigrid.envs.xland:XLandMiniGrid",
# grid_type="R1",
# height=64,
# width=64,
# )


# TODO: reconsider grid sizes and time limits after the benchmarks are generated.
# Should be enough space for initial tiles even in the hardest setting
register(
Expand Down Expand Up @@ -48,7 +123,6 @@
width=13,
)


register(
id="XLand-MiniGrid-R2-17x17",
entry_point="xminigrid.envs.xland:XLandMiniGrid",
Expand Down
53 changes: 27 additions & 26 deletions src/xminigrid/benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,31 @@
from __future__ import annotations

import bz2
import os
import pickle
import urllib.request
from typing import Callable, Dict
from typing import Callable

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from flax import struct
from jax.random import KeyArray
from tqdm.auto import tqdm

from .types import RuleSet

HF_REPO_ID = os.environ.get("XLAND_MINIGRID_HF_REPO_ID", "Howuhh/xland_minigrid")
DATA_PATH = os.environ.get("XLAND_MINIGRID_DATA", os.path.expanduser("~/.xland_minigrid"))


NAME2HFFILENAME = {
# 1M pre-sampled tasks
"trivial-1m": "trivial_1m",
"small-1m": "small_1m",
"small-dist-1m": "small_dist_1m",
"medium-1m": "medium_1m",
"medium-3m": "medium_3m",
"high-1m": "high_1m",
# 5M pre-sampled tasks (TODO)
"trivial-5m": "",
"small-5m": "",
"small-dist-5m": "",
"medium-5m": "",
"high-5m": "",
"high-3m": "high_3m",
}


Expand All @@ -46,28 +43,39 @@ def num_rulesets(self) -> int:
def get_ruleset(self, ruleset_id: int | jax.Array) -> RuleSet:
    """Return the ruleset stored at position ``ruleset_id``.

    Forwards this benchmark's ``goals``, ``rules`` and ``init_tiles`` arrays,
    together with the requested index, to the module-level ``get_ruleset``.
    """
    goals, rules, init_tiles = self.goals, self.rules, self.init_tiles
    return get_ruleset(goals, rules, init_tiles, ruleset_id)

def sample_ruleset(self, key: jax.Array) -> RuleSet:
def sample_ruleset(self, key: KeyArray) -> RuleSet:
    """Pick one ruleset at random and return it.

    A scalar index is drawn with ``jax.random.randint`` from the half-open
    range ``[0, self.num_rulesets())`` and then looked up via
    ``self.get_ruleset``.
    """
    upper_bound = self.num_rulesets()
    sampled_id = jax.random.randint(key, shape=(), minval=0, maxval=upper_bound)
    return self.get_ruleset(sampled_id)

def shuffle(self, key) -> "Benchmark":
def shuffle(self, key: KeyArray) -> Benchmark:
    """Return this pytree with every leaf reordered by one random permutation.

    The permutation is drawn over ``len(self.num_rules)`` indices and the same
    index order is applied to each leaf, so entries stay aligned across leaves.
    """
    size = len(self.num_rules)
    order = jax.random.permutation(key, jnp.arange(size))

    def reorder(leaf):
        return leaf[order]

    return jtu.tree_map(reorder, self)

def split(self, prop: float) -> tuple["Benchmark", "Benchmark"]:
def split(self, prop: float) -> tuple[Benchmark, Benchmark]:
    """Cut this pytree into a leading and a trailing part.

    The cut position is ``round(len(self.num_rules) * prop)``; the first
    result holds every leaf sliced up to that position, the second holds the
    remainder.
    """
    cut = round(len(self.num_rules) * prop)

    def take(window):
        # Apply the same slice to every leaf so the parts stay aligned.
        return jtu.tree_map(lambda leaf: leaf[window], self)

    return take(slice(None, cut)), take(slice(cut, None))

def filter_split(self, fn: Callable[[jax.Array, jax.Array], bool]) -> tuple["Benchmark", "Benchmark"]:
def filter_split(self, fn: Callable[[jax.Array, jax.Array], bool]) -> tuple[Benchmark, Benchmark]:
    """Partition this pytree by a per-entry predicate.

    ``fn(single_goal, single_rules) -> bool`` is vmapped over ``self.goals``
    and ``self.rules``. The first result keeps the entries where the predicate
    holds, the second keeps the entries where it does not.
    """
    selected = jax.vmap(fn)(self.goals, self.rules)
    rejected = ~selected
    matching = jtu.tree_map(lambda leaf: leaf[selected], self)
    remaining = jtu.tree_map(lambda leaf: leaf[rejected], self)
    return matching, remaining


def load_benchmark_from_path(path: str) -> Benchmark:
    """Build a ``Benchmark`` from the pickled dictionary stored at ``path``.

    The file is read with ``load_bz2_pickle``; only the ``goals``, ``rules``,
    ``init_tiles`` and ``num_rules`` entries are used. A missing entry raises
    ``KeyError``.
    """
    stored = load_bz2_pickle(path)
    field_names = ("goals", "rules", "init_tiles", "num_rules")
    return Benchmark(**{name: stored[name] for name in field_names})


def load_benchmark(name: str) -> Benchmark:
if name not in NAME2HFFILENAME:
raise RuntimeError(f"Unknown benchmark. Registered: {registered_benchmarks()}")
Expand All @@ -78,21 +86,14 @@ def load_benchmark(name: str) -> Benchmark:
if not os.path.exists(path):
_download_from_hf(HF_REPO_ID, NAME2HFFILENAME[name])

benchmark_dict = load_bz2_pickle(path)
benchmark = Benchmark(
goals=benchmark_dict["goals"],
rules=benchmark_dict["rules"],
init_tiles=benchmark_dict["init_tiles"],
num_rules=benchmark_dict["num_rules"],
)
return benchmark
return load_benchmark_from_path(path)


def registered_benchmarks():
def registered_benchmarks() -> tuple[str, ...]:
    """Return the names of all benchmarks listed in ``NAME2HFFILENAME``."""
    # Iterating a dict yields its keys, in insertion order.
    return tuple(NAME2HFFILENAME)


def _download_from_hf(repo_id: str, filename: str):
def _download_from_hf(repo_id: str, filename: str) -> None:
dataset_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{filename}"

save_path = os.path.join(DATA_PATH, filename)
Expand All @@ -115,7 +116,7 @@ def get_ruleset(
goals: jax.Array,
rules: jax.Array,
init_tiles: jax.Array,
ruleset_id: int,
ruleset_id: int | jax.Array,
) -> RuleSet:
goal = jax.lax.dynamic_index_in_dim(goals, ruleset_id, keepdims=False)
rules = jax.lax.dynamic_index_in_dim(rules, ruleset_id, keepdims=False)
Expand All @@ -124,12 +125,12 @@ def get_ruleset(
return RuleSet(goal=goal, rules=rules, init_tiles=init_tiles)


def save_bz2_pickle(ruleset, path, protocol=-1) -> None:
def save_bz2_pickle(ruleset: dict[str, jax.Array], path: str, protocol: int = -1) -> None:
    """Pickle ``ruleset`` into a bz2-compressed file at ``path``.

    Args:
        ruleset: object to serialize.
        path: destination file; opened in binary write mode.
        protocol: pickle protocol number, passed through unchanged.
    """
    with bz2.open(path, mode="wb") as sink:
        pickler = pickle.Pickler(sink, protocol=protocol)
        pickler.dump(ruleset)


def load_bz2_pickle(path) -> Dict[str, jax.Array]:
def load_bz2_pickle(path: str) -> dict[str, jax.Array]:
    """Read and unpickle the bz2-compressed file at ``path``.

    Returns whatever object was pickled into the file.
    """
    # NOTE(review): unpickling can execute arbitrary code. This loader appears
    # to be used on downloaded benchmark files, so only point it at trusted
    # sources - confirm with the maintainers.
    with bz2.open(path, mode="rb") as source:
        return pickle.load(source)
1 change: 1 addition & 0 deletions src/xminigrid/core/goals.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def check_goal(
return check


# should I inherit from abc.ABC?
class BaseGoal(struct.PyTreeNode):
@abc.abstractmethod
def __call__(self, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array) -> jax.Array:
Expand Down
5 changes: 3 additions & 2 deletions src/xminigrid/core/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax
import jax.numpy as jnp
from jax.random import KeyArray

from ..types import GridState, Tile
from .constants import FREE_TO_PUT_DOWN, LOS_BLOCKING, PICKABLE, TILES_REGISTRY, WALKABLE, Colors, Tiles
Expand Down Expand Up @@ -158,7 +159,7 @@ def coordinates_mask(grid: GridState, address: tuple[int, int], comparison_fn: C
return mask


def sample_coordinates(key: jax.Array, grid: GridState, num: int, mask: jax.Array | None = None) -> jax.Array:
def sample_coordinates(key: KeyArray, grid: GridState, num: int, mask: jax.Array | None = None) -> jax.Array:
if mask is None:
mask = jnp.ones((grid.shape[0], grid.shape[1]), dtype=jnp.bool_)

Expand All @@ -174,7 +175,7 @@ def sample_coordinates(key: jax.Array, grid: GridState, num: int, mask: jax.Arra
return coords


def sample_direction(key: jax.Array) -> jax.Array:
def sample_direction(key: KeyArray) -> jax.Array:
    """Draw a scalar integer uniformly from ``[0, 4)`` using ``key``."""
    lowest, upper_bound = 0, 4
    return jax.random.randint(key, shape=(), minval=lowest, maxval=upper_bound)


Expand Down
17 changes: 12 additions & 5 deletions src/xminigrid/environment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from __future__ import annotations

from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax.random import KeyArray

from .core.actions import take_action
from .core.constants import NUM_ACTIONS, NUM_LAYERS
Expand All @@ -23,24 +29,25 @@ class EnvParams(struct.PyTreeNode):
render_mode: str = struct.field(pytree_node=False, default="rgb_array")


# TODO: add generic type hints (on env params)
class Environment:
def default_params(self, **kwargs) -> EnvParams:
def default_params(self, **kwargs: Any) -> EnvParams:
    """Return a default ``EnvParams`` with the given fields overridden.

    Args:
        **kwargs: field values forwarded to ``EnvParams.replace``.
    """
    defaults = EnvParams()
    return defaults.replace(**kwargs)

def num_actions(self, params: EnvParams) -> int:
    """Return ``NUM_ACTIONS`` as a plain ``int``; ``params`` is not consulted."""
    action_count = int(NUM_ACTIONS)
    return action_count

def observation_shape(self, params: EnvParams) -> tuple[int, int, int]:
    """Return the observation shape ``(view_size, view_size, NUM_LAYERS)``."""
    side = params.view_size
    return side, side, NUM_LAYERS

# TODO: NOT sure that this should be hardcoded like that...
def time_limit(self, params: EnvParams) -> int:
    """Return the time limit, computed as ``3 * params.height * params.width``."""
    multiplier = 3
    return multiplier * params.height * params.width

def _generate_problem(self, params: EnvParams, key: jax.Array) -> State:
def _generate_problem(self, params: EnvParams, key: KeyArray) -> State:
return NotImplemented

def reset(self, params: EnvParams, key: jax.Array) -> TimeStep:
def reset(self, params: EnvParams, key: KeyArray) -> TimeStep:
state = self._generate_problem(params, key)
timestep = TimeStep(
state=state,
Expand Down Expand Up @@ -81,7 +88,7 @@ def step(self, params: EnvParams, timestep: TimeStep, action: int) -> TimeStep:
)
return timestep

def render(self, params: EnvParams, timestep: TimeStep):
def render(self, params: EnvParams, timestep: TimeStep) -> np.ndarray | str:
if params.render_mode == "rgb_array":
return rgb_render(timestep.state.grid, timestep.state.agent, params.view_size)
elif params.render_mode == "rich_text":
Expand Down
Loading

0 comments on commit abf181c

Please sign in to comment.