diff --git a/src/xminigrid/experimental/img_obs.py b/src/xminigrid/experimental/img_obs.py index 759bf0a..cc854b4 100644 --- a/src/xminigrid/experimental/img_obs.py +++ b/src/xminigrid/experimental/img_obs.py @@ -14,6 +14,7 @@ from ..wrappers import Wrapper CACHE_PATH = os.environ.get("XLAND_MINIGRID_CACHE", os.path.expanduser("~/.xland_minigrid")) +FORCE_RELOAD = os.environ.get("XLAND_MINIGRID_RELOAD_CACHE", False) def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np.ndarray]: @@ -47,7 +48,8 @@ def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np. TILE_SIZE = 32 cache_path = os.path.join(CACHE_PATH, "render_cache") -if not os.path.exists(cache_path): + +if not os.path.exists(cache_path) or FORCE_RELOAD: os.makedirs(CACHE_PATH, exist_ok=True) print("Building rendering cache, may take a while...") TILE_CACHE, TILE_W_AGENT_CACHE = build_cache(np.asarray(TILES_REGISTRY), tile_size=TILE_SIZE) @@ -80,7 +82,7 @@ def _render_obs(obs: jax.Array) -> jax.Array: class RGBImgObservationWrapper(Wrapper): def observation_shape(self, params): - return params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, NUM_LAYERS + return params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, 3 def reset(self, params, key): timestep = self._env.reset(params, key)