Skip to content

Commit 4a24b15

Browse files
committed
image obs rendering
1 parent e07118a commit 4a24b15

File tree

10 files changed

+157
-55
lines changed

10 files changed

+157
-55
lines changed

src/xminigrid/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .registration import make, register, registered_environments
33

44
# TODO: add __all__
5-
__version__ = "0.6.0"
5+
__version__ = "0.7.0"
66

77
# ---------- XLand-MiniGrid environments ----------
88

src/xminigrid/core/constants.py

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,44 @@
11
import jax.numpy as jnp
22
from flax import struct
33

4+
NUM_ACTIONS = 6
5+
46
# GRID: [tile, color]
57
NUM_LAYERS = 2
6-
NUM_TILES = 15
7-
NUM_COLORS = 14
8-
NUM_ACTIONS = 6
8+
NUM_TILES = 13
9+
NUM_COLORS = 12
910

1011

11-
# TODO: do we really need END_OF_MAP? seem like unseen can be used instead...
1212
# enums, kinda...
1313
class Tiles(struct.PyTreeNode):
1414
EMPTY: int = struct.field(pytree_node=False, default=0)
15-
END_OF_MAP: int = struct.field(pytree_node=False, default=1)
16-
UNSEEN: int = struct.field(pytree_node=False, default=2)
17-
FLOOR: int = struct.field(pytree_node=False, default=3)
18-
WALL: int = struct.field(pytree_node=False, default=4)
19-
BALL: int = struct.field(pytree_node=False, default=5)
20-
SQUARE: int = struct.field(pytree_node=False, default=6)
21-
PYRAMID: int = struct.field(pytree_node=False, default=7)
22-
GOAL: int = struct.field(pytree_node=False, default=8)
23-
KEY: int = struct.field(pytree_node=False, default=9)
24-
DOOR_LOCKED: int = struct.field(pytree_node=False, default=10)
25-
DOOR_CLOSED: int = struct.field(pytree_node=False, default=11)
26-
DOOR_OPEN: int = struct.field(pytree_node=False, default=12)
27-
HEX: int = struct.field(pytree_node=False, default=13)
28-
STAR: int = struct.field(pytree_node=False, default=14)
15+
FLOOR: int = struct.field(pytree_node=False, default=1)
16+
WALL: int = struct.field(pytree_node=False, default=2)
17+
BALL: int = struct.field(pytree_node=False, default=3)
18+
SQUARE: int = struct.field(pytree_node=False, default=4)
19+
PYRAMID: int = struct.field(pytree_node=False, default=5)
20+
GOAL: int = struct.field(pytree_node=False, default=6)
21+
KEY: int = struct.field(pytree_node=False, default=7)
22+
DOOR_LOCKED: int = struct.field(pytree_node=False, default=8)
23+
DOOR_CLOSED: int = struct.field(pytree_node=False, default=9)
24+
DOOR_OPEN: int = struct.field(pytree_node=False, default=10)
25+
HEX: int = struct.field(pytree_node=False, default=11)
26+
STAR: int = struct.field(pytree_node=False, default=12)
2927

3028

3129
class Colors(struct.PyTreeNode):
3230
EMPTY: int = struct.field(pytree_node=False, default=0)
33-
END_OF_MAP: int = struct.field(pytree_node=False, default=1)
34-
UNSEEN: int = struct.field(pytree_node=False, default=2)
35-
RED: int = struct.field(pytree_node=False, default=3)
36-
GREEN: int = struct.field(pytree_node=False, default=4)
37-
BLUE: int = struct.field(pytree_node=False, default=5)
38-
PURPLE: int = struct.field(pytree_node=False, default=6)
39-
YELLOW: int = struct.field(pytree_node=False, default=7)
40-
GREY: int = struct.field(pytree_node=False, default=8)
41-
BLACK: int = struct.field(pytree_node=False, default=9)
42-
ORANGE: int = struct.field(pytree_node=False, default=10)
43-
WHITE: int = struct.field(pytree_node=False, default=11)
44-
BROWN: int = struct.field(pytree_node=False, default=12)
45-
PINK: int = struct.field(pytree_node=False, default=13)
31+
RED: int = struct.field(pytree_node=False, default=1)
32+
GREEN: int = struct.field(pytree_node=False, default=2)
33+
BLUE: int = struct.field(pytree_node=False, default=3)
34+
PURPLE: int = struct.field(pytree_node=False, default=4)
35+
YELLOW: int = struct.field(pytree_node=False, default=5)
36+
GREY: int = struct.field(pytree_node=False, default=6)
37+
BLACK: int = struct.field(pytree_node=False, default=7)
38+
ORANGE: int = struct.field(pytree_node=False, default=8)
39+
WHITE: int = struct.field(pytree_node=False, default=9)
40+
BROWN: int = struct.field(pytree_node=False, default=10)
41+
PINK: int = struct.field(pytree_node=False, default=11)
4642

4743

4844
# Only ~100 combinations so far, better to preallocate them
@@ -65,7 +61,6 @@ class Colors(struct.PyTreeNode):
6561

6662
WALKABLE = jnp.array(
6763
(
68-
Tiles.EMPTY,
6964
Tiles.FLOOR,
7065
Tiles.GOAL,
7166
Tiles.DOOR_OPEN,
@@ -83,12 +78,7 @@ class Colors(struct.PyTreeNode):
8378
)
8479
)
8580

86-
FREE_TO_PUT_DOWN = jnp.array(
87-
(
88-
Tiles.EMPTY,
89-
Tiles.FLOOR,
90-
)
91-
)
81+
FREE_TO_PUT_DOWN = jnp.array((Tiles.FLOOR,))
9282

9383
LOS_BLOCKING = jnp.array(
9484
(

src/xminigrid/core/grid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Callable, Union
3+
from typing import Callable
44

55
import jax
66
import jax.numpy as jnp
@@ -22,7 +22,7 @@ def equal(tile1: Tile, tile2: Tile) -> Tile:
2222

2323
def get_neighbouring_tiles(grid: GridState, y: IntOrArray, x: IntOrArray) -> tuple[Tile, Tile, Tile, Tile]:
2424
# end_of_map = TILES_REGISTRY[Tiles.END_OF_MAP, Colors.END_OF_MAP]
25-
end_of_map = Tiles.END_OF_MAP
25+
end_of_map = Tiles.EMPTY
2626

2727
up_tile = grid.at[y - 1, x].get(mode="fill", fill_value=end_of_map)
2828
right_tile = grid.at[y, x + 1].get(mode="fill", fill_value=end_of_map)

src/xminigrid/core/observation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def crop_field_of_view(grid: GridState, agent: AgentState, height: int, width: i
1212
grid = jnp.pad(
1313
grid,
1414
pad_width=((height, height), (width, width), (0, 0)),
15-
constant_values=Tiles.END_OF_MAP,
15+
constant_values=Tiles.EMPTY,
1616
)
1717
# account for padding
1818
y = agent.position[0] + height
@@ -110,8 +110,8 @@ def minigrid_field_of_view(grid: GridState, agent: AgentState, height: int, widt
110110
fov_grid = crop_field_of_view(grid, agent, height, width)
111111
fov_grid = align_with_up(fov_grid, agent.direction)
112112
mask = generate_viz_mask_minigrid(fov_grid)
113-
# set UNSEEN value for all layers (including colors, as UNSEEN color has same id value)
114-
fov_grid = jnp.where(mask[..., None], fov_grid, Tiles.UNSEEN)
113+
# set EMPTY as unseen value for all layers (including colors, as EMPTY color has same id value)
114+
fov_grid = jnp.where(mask[..., None], fov_grid, Tiles.EMPTY)
115115

116116
# TODO: should we even do this? Agent with good memory can remember what he picked up.
117117
# WARN: this can overwrite tile the agent is on, GOAL for example.

src/xminigrid/envs/xland.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
init_tiles=jnp.array(((TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY],))),
2828
)
2929

30-
_empty_tile = TILES_REGISTRY[Tiles.EMPTY, Colors.EMPTY]
3130
_wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]
3231
# colors for doors between rooms
3332
_allowed_colors = jnp.array(

src/xminigrid/experimental/img_obs.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# jit-compatible RGB observations. Currently experimental!
2+
# if it proves useful and necessary in the future, I will consider rewriting env.render in such style also
3+
from __future__ import annotations
4+
5+
import os
6+
7+
import jax
8+
import jax.numpy as jnp
9+
import numpy as np
10+
11+
from ..benchmarks import load_bz2_pickle, save_bz2_pickle
12+
from ..core.constants import NUM_COLORS, NUM_LAYERS, TILES_REGISTRY
13+
from ..rendering.rgb_render import render_tile
14+
from ..wrappers import Wrapper
15+
16+
CACHE_PATH = os.environ.get("XLAND_MINIGRID_CACHE", os.path.expanduser("~/.xland_minigrid"))
17+
18+
19+
def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np.ndarray]:
    """Pre-render every tile in ``tiles``, both plain and with the agent on top.

    Args:
        tiles: array indexed as ``tiles[y, x]``; each entry is converted to a
            tuple and passed to ``render_tile`` as its ``tile`` argument.
        tile_size: side length, in pixels, of one rendered tile. It is coerced
            to ``int`` because it is used as an array dimension.

    Returns:
        ``(cache, agent_cache)``: two ``uint8`` arrays of shape
        ``(tiles.shape[0], tiles.shape[1], tile_size, tile_size, 3)``.
        ``cache[y, x]`` holds the image of ``tiles[y, x]``; ``agent_cache[y, x]``
        holds the same tile rendered with ``agent_direction=0``.
    """
    # array shapes must be integers, so normalise once instead of at every call site
    tile_size = int(tile_size)
    n_rows, n_cols = tiles.shape[0], tiles.shape[1]
    cache_shape = (n_rows, n_cols, tile_size, tile_size, 3)

    # 255 is the largest uint8 value. A fill value of -1 is out of range for
    # uint8 and only works on NumPy versions that silently wrap it around.
    cache = np.full(cache_shape, dtype=np.uint8, fill_value=255)
    agent_cache = np.full(cache_shape, dtype=np.uint8, fill_value=255)

    for y in range(n_rows):
        for x in range(n_cols):
            # render_tile receives a tuple rather than an array row
            # (presumably so the tile is hashable — TODO confirm against render_tile)
            tile = tuple(tiles[y, x])

            # plain tile
            cache[y, x] = render_tile(
                tile=tile,
                agent_direction=None,
                highlight=False,
                tile_size=tile_size,
            )
            # same tile with the agent drawn on top, always using direction 0
            agent_cache[y, x] = render_tile(
                tile=tile,
                agent_direction=0,
                highlight=False,
                tile_size=tile_size,
            )

    return cache, agent_cache
44+
45+
46+
# building cache of pre-rendered tiles
TILE_SIZE = 32

# NOTE(review): the cache file name encodes neither TILE_SIZE nor the contents of
# TILES_REGISTRY, so a file written by a different version is loaded as-is.
# Consider versioning the file name.
cache_path = os.path.join(CACHE_PATH, "render_cache")
if not os.path.exists(cache_path):
    # make sure the cache directory exists before writing into it
    # (assumes save_bz2_pickle does not create parent directories — TODO confirm)
    os.makedirs(CACHE_PATH, exist_ok=True)

    TILE_CACHE, TILE_W_AGENT_CACHE = build_cache(np.asarray(TILES_REGISTRY), tile_size=TILE_SIZE)
    # merge the two leading registry axes so a tile can be looked up by a single flat index
    TILE_CACHE = jnp.asarray(TILE_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3)
    TILE_W_AGENT_CACHE = jnp.asarray(TILE_W_AGENT_CACHE).reshape(-1, TILE_SIZE, TILE_SIZE, 3)

    save_bz2_pickle({"tile_cache": TILE_CACHE, "tile_agent_cache": TILE_W_AGENT_CACHE}, cache_path)

# read the cache file once and take both tables from the same loaded object
_render_cache = load_bz2_pickle(cache_path)
TILE_CACHE = _render_cache["tile_cache"]
TILE_W_AGENT_CACHE = _render_cache["tile_agent_cache"]
59+
60+
61+
# rendering with cached tiles
62+
def _render_obs(obs: jax.Array) -> jax.Array:
    """Turn a symbolic observation into an RGB image using the pre-rendered tiles.

    Layer 0 and layer 1 of ``obs`` are combined into one flat index per cell
    (``layer0 * NUM_COLORS + layer1``) which selects an image from TILE_CACHE.
    The cell at the bottom-centre of the view is replaced by its
    "tile with agent" image from TILE_W_AGENT_CACHE.

    Assumes ``obs`` is square in its first two axes, since ``obs.shape[0]`` is
    used for both — TODO confirm against the environments.

    Returns:
        Array of shape ``[view_size * TILE_SIZE, view_size * TILE_SIZE, 3]``.
    """
    view_size = obs.shape[0]
    agent_row = view_size - 1
    agent_col = view_size // 2
    side = view_size * TILE_SIZE

    # one flat cache index per grid cell
    flat_idxs = obs[:, :, 0] * NUM_COLORS + obs[:, :, 1]

    # look up the image of every cell: [view_size, view_size, tile_size, tile_size, 3]
    tiles_rgb = jnp.take(TILE_CACHE, flat_idxs, axis=0)

    # overwrite the agent's cell with the variant that has the agent drawn on it
    agent_rgb = TILE_W_AGENT_CACHE[flat_idxs[agent_row, agent_col]]
    tiles_rgb = tiles_rgb.at[agent_row, agent_col].set(agent_rgb)

    # interleave grid rows with pixel rows, then collapse into a single image
    stitched = tiles_rgb.transpose((0, 2, 1, 3, 4))
    return stitched.reshape(side, side, 3)
76+
77+
78+
class RGBImgObservationWrapper(Wrapper):
    """Wrapper that replaces each timestep's observation with its RGB rendering.

    The observation returned by the wrapped environment is passed through
    ``_render_obs`` on both ``reset`` and ``step``.
    """

    def observation_shape(self, params):
        """Return the ``(height, width, channels)`` shape of the rendered observation."""
        side = params.view_size * TILE_SIZE
        # _render_obs reshapes its output to (..., 3), i.e. RGB. The channel count
        # is therefore 3, not the number of symbolic grid layers.
        return side, side, 3

    def reset(self, params, key):
        """Reset the wrapped environment and render the initial observation."""
        timestep = self._env.reset(params, key)
        timestep = timestep.replace(observation=_render_obs(timestep.observation))
        return timestep

    def step(self, params, timestep, action):
        """Step the wrapped environment and render the resulting observation."""
        timestep = self._env.step(params, timestep, action)
        timestep = timestep.replace(observation=_render_obs(timestep.observation))
        return timestep

src/xminigrid/manual_control.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414

1515

1616
class ManualControl:
17-
def __init__(self, env: Environment[EnvParamsT, EnvCarryT], env_params: EnvParamsT):
17+
def __init__(self, env: Environment[EnvParamsT, EnvCarryT], env_params: EnvParamsT, agent_view: bool = False):
1818
self.env = env
1919
self.env_params = env_params
20+
self.agent_view = agent_view
2021

2122
self._reset = jax.jit(self.env.reset)
2223
self._step = jax.jit(self.env.step)
@@ -33,7 +34,10 @@ def __init__(self, env: Environment[EnvParamsT, EnvCarryT], env_params: EnvParam
3334
def render(self) -> None:
3435
assert self.timestep is not None
3536

36-
img = self.env.render(self.env_params, self.timestep)
37+
if self.agent_view:
38+
img = self.timestep.observation
39+
else:
40+
img = self.env.render(self.env_params, self.timestep)
3741
# [h, w, c] -> [w, h, c]
3842
img = np.transpose(img, axes=(1, 0, 2))
3943

@@ -144,14 +148,20 @@ def close(self) -> None:
144148
parser.add_argument("--env-id", type=str, default="MiniGrid-Empty-5x5", choices=xminigrid.registered_environments())
145149
parser.add_argument("--benchmark-id", type=str, default="trivial-1m", choices=xminigrid.registered_benchmarks())
146150
parser.add_argument("--ruleset-id", type=int, default=0)
151+
parser.add_argument("--agent-view", action="store_true")
147152

148153
args = parser.parse_args()
149154
env, env_params = xminigrid.make(args.env_id)
150155
env = GymAutoResetWrapper(env)
151156

157+
if args.agent_view:
158+
from xminigrid.experimental.img_obs import RGBImgObservationWrapper
159+
160+
env = RGBImgObservationWrapper(env)
161+
152162
if "XLand" in args.env_id:
153163
bench = xminigrid.load_benchmark(args.benchmark_id)
154164
env_params = env_params.replace(ruleset=bench.get_ruleset(args.ruleset_id))
155165

156-
control = ManualControl(env=env, env_params=env_params)
166+
control = ManualControl(env=env, env_params=env_params, agent_view=args.agent_view)
157167
control.start()

src/xminigrid/rendering/rgb_render.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from ..core.constants import Colors, Tiles
8-
from ..types import AgentState, GridState, IntOrArray
8+
from ..types import AgentState, IntOrArray
99
from .utils import (
1010
downsample,
1111
fill_coords,
@@ -18,6 +18,7 @@
1818
)
1919

2020
COLORS_MAP = {
21+
Colors.EMPTY: np.array((255, 255, 255)), # just a placeholder
2122
Colors.RED: np.array((255, 0, 0)),
2223
Colors.GREEN: np.array((0, 255, 0)),
2324
Colors.BLUE: np.array((0, 0, 255)),
@@ -32,6 +33,16 @@
3233
}
3334

3435

36+
def _render_empty(img: np.ndarray, color: int):
    """Draw the marker used for EMPTY tiles onto ``img`` in place.

    Six red rectangles are filled: two stacked bars in the middle of the tile
    and a thin frame along all four edges. The ``color`` argument is accepted
    to match the other tile renderers but is not used.
    """
    red = COLORS_MAP[Colors.RED]
    edge = 0.031  # relative thickness of the frame

    # (min, max, min, max) arguments for point_in_rect, in drawing order
    rects = (
        # two bars in the middle of the tile
        (0.45, 0.55, 0.2, 0.65),
        (0.45, 0.55, 0.7, 0.85),
        # frame around the tile
        (0, edge, 0, 1),
        (0, 1, 0, edge),
        (1 - edge, 1, 0, 1),
        (0, 1, 1 - edge, 1),
    )
    for rect in rects:
        fill_coords(img, point_in_rect(*rect), red)
44+
45+
3546
def _render_floor(img: np.ndarray, color: int):
3647
# draw the grid lines (top and left edges)
3748
fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
@@ -165,7 +176,7 @@ def _render_player(img: np.ndarray, direction: int):
165176
Tiles.DOOR_LOCKED: _render_door_locked,
166177
Tiles.DOOR_CLOSED: _render_door_closed,
167178
Tiles.DOOR_OPEN: _render_door_open,
168-
Tiles.EMPTY: lambda img, color: img,
179+
Tiles.EMPTY: _render_empty,
169180
}
170181

171182

@@ -196,7 +207,7 @@ def get_highlight_mask(grid: np.ndarray, agent: AgentState | None, view_size: in
196207

197208
@functools.cache
198209
def render_tile(
199-
tile: np.ndarray, agent_direction: int | None = None, highlight: bool = False, tile_size: int = 32, subdivs: int = 3
210+
tile: tuple, agent_direction: int | None = None, highlight: bool = False, tile_size: int = 32, subdivs: int = 3
200211
) -> np.ndarray:
201212
img = np.full((tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8, fill_value=255)
202213
# draw tile

src/xminigrid/rendering/text_render.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from ..types import AgentState, RuleSet
66

77
COLOR_NAMES = {
8-
Colors.END_OF_MAP: "red",
9-
Colors.UNSEEN: "white",
108
Colors.EMPTY: "white",
119
Colors.RED: "red",
1210
Colors.GREEN: "green",
@@ -22,8 +20,6 @@
2220
}
2321

2422
TILE_STR = {
25-
Tiles.END_OF_MAP: "!",
26-
Tiles.UNSEEN: "?",
2723
Tiles.EMPTY: " ",
2824
Tiles.FLOOR: ".",
2925
Tiles.WALL: "☰",

src/xminigrid/wrappers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ def __init__(self, env: Environment[EnvParamsT, EnvCarryT]):
1616
def default_params(self, **kwargs) -> EnvParamsT:
1717
return self._env.default_params(**kwargs)
1818

19+
def num_actions(self, params: EnvParamsT) -> int:
    """Return the number of actions reported by the wrapped environment."""
    wrapped_env = self._env
    return wrapped_env.num_actions(params)
21+
22+
def observation_shape(self, params: EnvParamsT) -> tuple[int, int, int]:
    """Return the observation shape reported by the wrapped environment."""
    wrapped_env = self._env
    return wrapped_env.observation_shape(params)
24+
1925
def time_limit(self, params: EnvParamsT) -> int:
2026
return self._env.time_limit(params)
2127

0 commit comments

Comments
 (0)