Skip to content

Commit

Permalink
add force cache reload, fix bug in wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Mar 6, 2024
1 parent c783db6 commit 21b486f
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/xminigrid/experimental/img_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..wrappers import Wrapper

CACHE_PATH = os.environ.get("XLAND_MINIGRID_CACHE", os.path.expanduser("~/.xland_minigrid"))
FORCE_RELOAD = os.environ.get("XLAND_MINIGRID_RELOAD_CACHE", False)


def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -47,7 +48,8 @@ def build_cache(tiles: np.ndarray, tile_size: int = 32) -> tuple[np.ndarray, np.
TILE_SIZE = 32

cache_path = os.path.join(CACHE_PATH, "render_cache")
if not os.path.exists(cache_path):

if not os.path.exists(cache_path) or FORCE_RELOAD:
os.makedirs(CACHE_PATH, exist_ok=True)
print("Building rendering cache, may take a while...")
TILE_CACHE, TILE_W_AGENT_CACHE = build_cache(np.asarray(TILES_REGISTRY), tile_size=TILE_SIZE)
Expand Down Expand Up @@ -80,7 +82,7 @@ def _render_obs(obs: jax.Array) -> jax.Array:

class RGBImgObservationWrapper(Wrapper):
def observation_shape(self, params):
return params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, NUM_LAYERS
return params.view_size * TILE_SIZE, params.view_size * TILE_SIZE, 3

def reset(self, params, key):
timestep = self._env.reset(params, key)
Expand Down

0 comments on commit 21b486f

Please sign in to comment.